From 0bc5a995ade2051fc0bd905658c180d7d98287f4 Mon Sep 17 00:00:00 2001 From: Alessio Date: Sun, 24 Mar 2024 20:02:40 -0700 Subject: [PATCH] REFACTOR: remove some redundant code in dm_queries.go --- pkg/persistence/dm_queries.go | 441 +++++++++++++---------------- pkg/persistence/dm_queries_test.go | 9 +- 2 files changed, 202 insertions(+), 248 deletions(-) diff --git a/pkg/persistence/dm_queries.go b/pkg/persistence/dm_queries.go index 2ec1f1c..b641e09 100644 --- a/pkg/persistence/dm_queries.go +++ b/pkg/persistence/dm_queries.go @@ -11,6 +11,12 @@ import ( . "gitlab.com/offline-twitter/twitter_offline_engine/pkg/scraper" ) +const ( + CHAT_MESSAGES_ALL_SQL_FIELDS = "id, chat_room_id, sender_id, sent_at, request_id, text, in_reply_to_id, embedded_tweet_id" + CHAT_ROOMS_ALL_SQL_FIELDS = "id, type, last_messaged_at, is_nsfw, created_at, created_by_user_id, name, avatar_image_remote_url, avatar_image_local_path" + CHAT_ROOM_PARTICIPANTS_ALL_SQL_FIELDS = "chat_room_id, user_id, last_read_event_id, is_chat_settings_valid, is_notifications_disabled, is_mention_notifications_disabled, is_read_only, is_trusted, is_muted, status" +) + func (p Profile) SaveChatRoom(r DMChatRoom) error { _, err := p.DB.NamedExec(` insert into chat_rooms (id, type, last_messaged_at, is_nsfw, created_at, created_by_user_id, name, @@ -70,32 +76,14 @@ func (p Profile) SaveChatRoom(r DMChatRoom) error { return nil } -func (p Profile) GetChatRoom(id DMChatRoomID) (ret DMChatRoom, err error) { - err = p.DB.Get(&ret, ` - select id, type, last_messaged_at, is_nsfw, created_at, created_by_user_id, name, avatar_image_remote_url, avatar_image_local_path - from chat_rooms - where id = ? - `, id) - if err != nil { - return ret, fmt.Errorf("Error getting chat room (%s):\n %w", id, err) - } - - participants := []DMChatParticipant{} - err = p.DB.Select(&participants, ` - select chat_room_id, user_id, last_read_event_id, is_chat_settings_valid, is_notifications_disabled, - is_mention_notifications_disabled, is_read_only, is_trusted, is_muted, status - from chat_room_participants - where chat_room_id = ? - `, id, - ) - if err != nil { - return ret, fmt.Errorf("Error getting chat room participants (%s):\n %w", id, err) - } - ret.Participants = make(map[UserID]DMChatParticipant) - for _, p := range participants { - ret.Participants[p.UserID] = p - } - return ret, nil +// Get a chat room with participants. +// +// Since this function is only used for tests and to confirm the room exists in `fetch_dm` and +// `send_dm` command-line subcommands, it doesn't need to be super efficient. So just reuse the +// full DMChatRoom fetch function and throw away the stuff we don't need +func (p Profile) GetChatRoom(id DMChatRoomID) (room DMChatRoom, err error) { + chat_view := p.GetChatRoomContents(id, 0) + return chat_view.Rooms[id], nil } func (p Profile) SaveChatMessage(m DMMessage) error { @@ -179,9 +167,12 @@ func (p Profile) SaveChatMessage(m DMMessage) error { return nil } +// Get a single chat message, filling its attachment contents. +// +// This function is only used in tests. func (p Profile) GetChatMessage(id DMMessageID) (ret DMMessage, err error) { err = p.DB.Get(&ret, ` - select id, chat_room_id, sender_id, sent_at, request_id, text, in_reply_to_id, embedded_tweet_id + select `+CHAT_MESSAGES_ALL_SQL_FIELDS+` from chat_messages where id = ? `, id, @@ -189,56 +180,14 @@ func (p Profile) GetChatMessage(id DMMessageID) (ret DMMessage, err error) { if err != nil { return ret, fmt.Errorf("Error getting chat message %d:\n %w", id, err) } - - // Reactions - reaccs := []DMReaction{} - err = p.DB.Select(&reaccs, ` - select id, message_id, sender_id, sent_at, emoji - from chat_message_reactions - where message_id = ? - `, id, - ) - if err != nil { - return ret, fmt.Errorf("Error getting reactions to chat message %d:\n %w", id, err) - } ret.Reactions = make(map[UserID]DMReaction) - for _, r := range reaccs { - ret.Reactions[r.SenderID] = r - } - // Images - err = p.DB.Select(&ret.Images, ` - select id, chat_message_id, width, height, remote_url, local_filename, is_downloaded - from chat_message_images - where chat_message_id = ? - `, ret.ID) - if err != nil { - return ret, fmt.Errorf("Error getting images for chat messsage %d:\n %w", id, err) - } + // This is a bit circuitous, but it doesn't matter because this function is only used in tests + trove := NewDMTrove() + trove.Messages[ret.ID] = ret + p.fill_dm_contents(&trove) - // Videos - err = p.DB.Select(&ret.Videos, ` - select id, chat_message_id, width, height, remote_url, local_filename, thumbnail_remote_url, thumbnail_local_filename, - duration, view_count, is_downloaded, is_blocked_by_dmca, is_gif - from chat_message_videos - where chat_message_id = ? - `, ret.ID) - if err != nil { - return ret, fmt.Errorf("Error getting videos for chat messsage %d:\n %w", id, err) - } - - // Urls - err = p.DB.Select(&ret.Urls, ` - select chat_message_id, domain, text, short_text, title, description, creator_id, site_id, thumbnail_width, thumbnail_height, - thumbnail_remote_url, thumbnail_local_path, has_card, has_thumbnail, is_content_downloaded - from chat_message_urls - where chat_message_id = ? - `, ret.ID) - if err != nil { - return ret, fmt.Errorf("Error getting urls for chat messsage %d:\n %w", id, err) - } - - return ret, nil + return trove.Messages[ret.ID], nil } type DMChatView struct { @@ -256,13 +205,14 @@ func NewDMChatView() DMChatView { } } +// Get the list of chat rooms the given user is in, including participants and latest message preview func (p Profile) GetChatRoomsPreview(id UserID) DMChatView { ret := NewDMChatView() + // Get the list of rooms var rooms []DMChatRoom err := p.DB.Select(&rooms, ` - select id, type, last_messaged_at, is_nsfw, created_at, created_by_user_id, name, - avatar_image_remote_url, avatar_image_local_path + select `+CHAT_ROOMS_ALL_SQL_FIELDS+` from chat_rooms where exists (select 1 from chat_room_participants where chat_room_id = chat_rooms.id and user_id = ?) order by last_messaged_at desc @@ -270,11 +220,13 @@ func (p Profile) GetChatRoomsPreview(id UserID) DMChatView { if err != nil { panic(err) } + + // Fill data for the rooms for _, room := range rooms { // Fetch the latest message var msg DMMessage q, args, err := sqlx.Named(` - select id, chat_room_id, sender_id, sent_at, request_id, text, in_reply_to_id, embedded_tweet_id + select `+CHAT_MESSAGES_ALL_SQL_FIELDS+` from chat_messages where chat_room_id = :room_id and sent_at = (select max(sent_at) from chat_messages where chat_room_id = :room_id) @@ -293,25 +245,7 @@ func (p Profile) GetChatRoomsPreview(id UserID) DMChatView { } // Fetch the participants - // DUPE chat-room-participants-SQL - var participants []struct { - DMChatParticipant - User - } - err = p.DB.Select(&participants, ` - select chat_room_id, user_id, last_read_event_id, is_chat_settings_valid, is_notifications_disabled, - is_mention_notifications_disabled, is_read_only, is_trusted, is_muted, status, `+USERS_ALL_SQL_FIELDS+` - from chat_room_participants join users on chat_room_participants.user_id = users.id - where chat_room_id = ? - `, room.ID) - if err != nil { - panic(err) - } - room.Participants = make(map[UserID]DMChatParticipant) - for _, participant := range participants { - room.Participants[participant.User.ID] = participant.DMChatParticipant - ret.Users[participant.User.ID] = participant.User - } + p.fill_chat_room_participants(&room, &ret.DMTrove) // Add everything to the Trove room.LastMessageID = msg.ID @@ -322,12 +256,12 @@ func (p Profile) GetChatRoomsPreview(id UserID) DMChatView { return ret } +// Get chat room detail, including participants and messages func (p Profile) GetChatRoomContents(id DMChatRoomID, latest_timestamp int) DMChatView { ret := NewDMChatView() var room DMChatRoom err := p.DB.Get(&room, ` - select id, type, last_messaged_at, is_nsfw, created_at, created_by_user_id, name, - avatar_image_remote_url, avatar_image_local_path + select `+CHAT_ROOMS_ALL_SQL_FIELDS+` from chat_rooms where id = ? `, id) @@ -335,31 +269,10 @@ func (p Profile) GetChatRoomContents(id DMChatRoomID, latest_timestamp int) DMCh panic(err) } - // Fetch the participants - // DUPE chat-room-participants-SQL - var participants []struct { - DMChatParticipant - User - } - err = p.DB.Select(&participants, ` - select chat_room_id, user_id, last_read_event_id, is_chat_settings_valid, is_notifications_disabled, - is_mention_notifications_disabled, is_read_only, is_trusted, is_muted, status, `+USERS_ALL_SQL_FIELDS+` - from chat_room_participants join users on chat_room_participants.user_id = users.id - where chat_room_id = ? - `, room.ID) - if err != nil { - panic(err) - } - room.Participants = make(map[UserID]DMChatParticipant) - for _, participant := range participants { - room.Participants[participant.User.ID] = participant.DMChatParticipant - ret.Users[participant.User.ID] = participant.User - } - // Fetch all messages var msgs []DMMessage err = p.DB.Select(&msgs, ` - select id, chat_room_id, sender_id, sent_at, request_id, text, in_reply_to_id, embedded_tweet_id + select `+CHAT_MESSAGES_ALL_SQL_FIELDS+` from chat_messages where chat_room_id = ? and sent_at > ? @@ -371,11 +284,14 @@ func (p Profile) GetChatRoomContents(id DMChatRoomID, latest_timestamp int) DMCh } ret.MessageIDs = make([]DMMessageID, len(msgs)) for i, msg := range msgs { - ret.MessageIDs[len(ret.MessageIDs)-i-1] = msg.ID + ret.MessageIDs[len(ret.MessageIDs)-i-1] = msg.ID // Reverse order msg.Reactions = make(map[UserID]DMReaction) ret.Messages[msg.ID] = msg } + // Fetch the participants + p.fill_chat_room_participants(&room, &ret.DMTrove) + // Set last message ID on chat room if len(ret.MessageIDs) > 0 { // If there's no messages, it should be OK to have LastMessageID = 0, since this is only used @@ -386,130 +302,163 @@ func (p Profile) GetChatRoomContents(id DMChatRoomID, latest_timestamp int) DMCh // Put the room in the Trove ret.Rooms[room.ID] = room - if len(ret.MessageIDs) > 0 { - // Fetch all reaccs - var reaccs []DMReaction - message_ids_copy := make([]interface{}, len(ret.MessageIDs)) - for i, id := range ret.MessageIDs { - message_ids_copy[i] = id - } - err = p.DB.Select(&reaccs, ` - select id, message_id, sender_id, sent_at, emoji - from chat_message_reactions - where message_id in (`+strings.Repeat("?,", len(ret.MessageIDs)-1)+`?) - `, message_ids_copy...) - if err != nil { - panic(err) - } - for _, reacc := range reaccs { - msg := ret.Messages[reacc.DMMessageID] - msg.Reactions[reacc.SenderID] = reacc - ret.Messages[reacc.DMMessageID] = msg - } - - // Images - var images []Image - err = p.DB.Select(&images, ` - select id, chat_message_id, width, height, remote_url, local_filename, is_downloaded - from chat_message_images - where chat_message_id in (`+strings.Repeat("?,", len(ret.MessageIDs)-1)+`?) - `, message_ids_copy...) - if err != nil { - panic(err) - } - for _, img := range images { - msg := ret.Messages[img.DMMessageID] - msg.Images = []Image{img} - ret.Messages[msg.ID] = msg - } - - // Videos - var videos []Video - err = p.DB.Select(&videos, ` - select id, chat_message_id, width, height, remote_url, local_filename, thumbnail_remote_url, thumbnail_local_filename, - duration, view_count, is_downloaded, is_blocked_by_dmca, is_gif - from chat_message_videos - where chat_message_id in (`+strings.Repeat("?,", len(ret.MessageIDs)-1)+`?) - `, message_ids_copy...) - if err != nil { - panic(err) - } - for _, vid := range videos { - println("asdfasfasdf") - msg := ret.Messages[vid.DMMessageID] - msg.Videos = []Video{vid} - ret.Messages[msg.ID] = msg - } - - // Urls - var urls []Url - err = p.DB.Select(&urls, ` - select chat_message_id, domain, text, short_text, title, description, creator_id, site_id, thumbnail_width, thumbnail_height, - thumbnail_remote_url, thumbnail_local_path, has_card, has_thumbnail, is_content_downloaded - from chat_message_urls - where chat_message_id in (`+strings.Repeat("?,", len(ret.MessageIDs)-1)+`?) - `, message_ids_copy...) - if err != nil { - panic(err) - } - for _, url := range urls { - msg := ret.Messages[url.DMMessageID] - msg.Urls = []Url{url} - ret.Messages[msg.ID] = msg - } - - // Fetch all embedded tweets - embedded_tweet_ids := []interface{}{} - for _, m := range ret.Messages { - if m.EmbeddedTweetID != 0 { - embedded_tweet_ids = append(embedded_tweet_ids, m.EmbeddedTweetID) - } - } - if len(embedded_tweet_ids) > 0 { - var embedded_tweets []Tweet - err = p.DB.Select(&embedded_tweets, ` - select `+TWEETS_ALL_SQL_FIELDS+` - from tweets - left join tombstone_types on tweets.tombstone_type = tombstone_types.rowid - left join likes on tweets.id = likes.tweet_id and likes.user_id = ? - where id in (`+strings.Repeat("?,", len(embedded_tweet_ids)-1)+`?)`, - append([]interface{}{UserID(0)}, embedded_tweet_ids...)...) - if err != nil { - panic(err) - } - for _, t := range embedded_tweets { - ret.Tweets[t.ID] = t - } - } - - // Fetch replied-to message previews - replied_message_ids := []interface{}{} - for _, m := range ret.Messages { - if m.InReplyToID != 0 { - // Don't clobber if it's already been fetched - if _, is_ok := ret.Messages[m.InReplyToID]; !is_ok { - replied_message_ids = append(replied_message_ids, m.InReplyToID) - } - } - } - if len(replied_message_ids) > 0 { - var replied_msgs []DMMessage - err = p.DB.Select(&replied_msgs, ` - select id, chat_room_id, sender_id, sent_at, request_id, text, in_reply_to_id, embedded_tweet_id - from chat_messages - where id in (`+strings.Repeat("?,", len(replied_message_ids)-1)+`?)`, - replied_message_ids...) - if err != nil { - panic(err) - } - for _, msg := range replied_msgs { - msg.Reactions = make(map[UserID]DMReaction) - ret.Messages[msg.ID] = msg - } - } - - p.fill_content(&ret.DMTrove.TweetTrove, UserID(0)) - } - + p.fill_dm_contents(&ret.DMTrove) return ret } + +// Fetch the chat participants and insert it into the DMChatRoom. Inserts user information +// into the DMTrove. +func (p Profile) fill_chat_room_participants(room *DMChatRoom, trove *DMTrove) { + var participants []struct { + DMChatParticipant + User + } + err := p.DB.Select(&participants, ` + select ` + CHAT_ROOM_PARTICIPANTS_ALL_SQL_FIELDS + `, `+USERS_ALL_SQL_FIELDS+` + from chat_room_participants join users on chat_room_participants.user_id = users.id + where chat_room_id = ? + `, room.ID) + if err != nil { + panic(err) + } + room.Participants = make(map[UserID]DMChatParticipant) + for _, p := range participants { + room.Participants[p.User.ID] = p.DMChatParticipant + trove.Users[p.User.ID] = p.User + } +} + +// Fetch reaccs, attachments/embeds and replied-to messages and add them to the DMTrove +func (p Profile) fill_dm_contents(trove *DMTrove) { + // Skip processing if there's no messages whomst'd've contents to fetch + if len(trove.Messages) == 0 { + return + } + + // Fetch all reaccs + var reaccs []DMReaction + message_ids := []interface{}{} + for _, msg := range trove.Messages { + message_ids = append(message_ids, msg.ID) + } + err := p.DB.Select(&reaccs, ` + select id, message_id, sender_id, sent_at, emoji + from chat_message_reactions + where message_id in (`+strings.Repeat("?,", len(trove.Messages)-1)+`?) + `, message_ids...) + if err != nil { + panic(err) + } + for _, reacc := range reaccs { + msg := trove.Messages[reacc.DMMessageID] + msg.Reactions[reacc.SenderID] = reacc + trove.Messages[reacc.DMMessageID] = msg + } + + // Images + var images []Image + err = p.DB.Select(&images, ` + select id, chat_message_id, width, height, remote_url, local_filename, is_downloaded + from chat_message_images + where chat_message_id in (`+strings.Repeat("?,", len(trove.Messages)-1)+`?) + `, message_ids...) + if err != nil { + panic(err) + } + for _, img := range images { + msg := trove.Messages[img.DMMessageID] + msg.Images = []Image{img} + trove.Messages[msg.ID] = msg + } + + // Videos + var videos []Video + err = p.DB.Select(&videos, ` + select id, chat_message_id, width, height, remote_url, local_filename, thumbnail_remote_url, thumbnail_local_filename, + duration, view_count, is_downloaded, is_blocked_by_dmca, is_gif + from chat_message_videos + where chat_message_id in (`+strings.Repeat("?,", len(trove.Messages)-1)+`?) + `, message_ids...) + if err != nil { + panic(err) + } + for _, vid := range videos { + println("asdfasfasdf") + msg := trove.Messages[vid.DMMessageID] + msg.Videos = []Video{vid} + trove.Messages[msg.ID] = msg + } + + // Urls + var urls []Url + err = p.DB.Select(&urls, ` + select chat_message_id, domain, text, short_text, title, description, creator_id, site_id, thumbnail_width, thumbnail_height, + thumbnail_remote_url, thumbnail_local_path, has_card, has_thumbnail, is_content_downloaded + from chat_message_urls + where chat_message_id in (`+strings.Repeat("?,", len(trove.Messages)-1)+`?) + `, message_ids...) + if err != nil { + panic(err) + } + for _, url := range urls { + msg := trove.Messages[url.DMMessageID] + msg.Urls = []Url{url} + trove.Messages[msg.ID] = msg + } + + // Fetch all embedded tweets + embedded_tweet_ids := []interface{}{} + for _, m := range trove.Messages { + if m.EmbeddedTweetID != 0 { + embedded_tweet_ids = append(embedded_tweet_ids, m.EmbeddedTweetID) + } + } + if len(embedded_tweet_ids) > 0 { + var embedded_tweets []Tweet + err = p.DB.Select(&embedded_tweets, ` + select `+TWEETS_ALL_SQL_FIELDS+` + from tweets + left join tombstone_types on tweets.tombstone_type = tombstone_types.rowid + left join likes on tweets.id = likes.tweet_id and likes.user_id = ? + where id in (`+strings.Repeat("?,", len(embedded_tweet_ids)-1)+`?)`, + append([]interface{}{UserID(0)}, embedded_tweet_ids...)...) + if err != nil { + panic(err) + } + for _, t := range embedded_tweets { + trove.Tweets[t.ID] = t + } + } + + // Fetch replied-to message previews + replied_message_ids := []interface{}{} + for _, m := range trove.Messages { + if m.InReplyToID != 0 { + // Don't clobber if it's already been fetched + if _, is_ok := trove.Messages[m.InReplyToID]; !is_ok { + replied_message_ids = append(replied_message_ids, m.InReplyToID) + } + } + } + if len(replied_message_ids) > 0 { + var replied_msgs []DMMessage + err = p.DB.Select(&replied_msgs, ` + select `+CHAT_MESSAGES_ALL_SQL_FIELDS+` + from chat_messages + where id in (`+strings.Repeat("?,", len(replied_message_ids)-1)+`?)`, + replied_message_ids...) + if err != nil { + panic(err) + } + for _, msg := range replied_msgs { + msg.Reactions = make(map[UserID]DMReaction) + trove.Messages[msg.ID] = msg + } + } + + p.fill_content(&trove.TweetTrove, UserID(0)) +} + + + } +} diff --git a/pkg/persistence/dm_queries_test.go b/pkg/persistence/dm_queries_test.go index f68ac9d..1b170ce 100644 --- a/pkg/persistence/dm_queries_test.go +++ b/pkg/persistence/dm_queries_test.go @@ -173,13 +173,16 @@ func TestGetChatRoomsPreview(t *testing.T) { require.True(is_ok) assert.Equal(msg.Text, "This looks pretty good huh") + // Participants require.Len(room.Participants, 2) for _, user_id := range []UserID{1458284524761075714, 1488963321701171204} { participant, is_ok := room.Participants[user_id] require.True(is_ok) assert.Equal(participant.IsChatSettingsValid, participant.UserID == 1488963321701171204) - _, is_ok = chat_view.Users[user_id] + u, is_ok := chat_view.Users[user_id] require.True(is_ok) + assert.Equal(u.ID, user_id) + assert.NotEqual(u.Handle, "") // Make sure it's filled out } } @@ -202,8 +205,10 @@ func TestGetChatRoomContents(t *testing.T) { participant, is_ok := room.Participants[user_id] require.True(is_ok) assert.Equal(participant.IsChatSettingsValid, participant.UserID == 1488963321701171204) - _, is_ok = chat_view.Users[user_id] + u, is_ok := chat_view.Users[user_id] require.True(is_ok) + assert.Equal(u.ID, user_id) + assert.NotEqual(u.Handle, "") // Make sure it's filled out } // Messages