diff --git a/internal/webserver/server_test.go b/internal/webserver/server_test.go index 89d3a17..bd8a5a1 100644 --- a/internal/webserver/server_test.go +++ b/internal/webserver/server_test.go @@ -570,7 +570,8 @@ func TestLists(t *testing.T) { // Messages // -------- -func TestMessages(t *testing.T) { +// Loading the index page should work if you're logged in +func TestMessagesIndexPage(t *testing.T) { assert := assert.New(t) require := require.New(t) @@ -587,12 +588,23 @@ func TestMessages(t *testing.T) { require.NoError(err) assert.Len(cascadia.QueryAll(root, selector(".chat-list .chat")), 2) assert.Len(cascadia.QueryAll(root, selector(".chat-view .dm-message-and-reacts-container")), 0) // No messages until you click on one +} + +// Open a chat room +func TestMessagesRoom(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + // Boilerplate for setting an active user + app := webserver.NewApp(profile) + app.IsScrapingDisabled = true + app.ActiveUser = scraper.User{ID: 1488963321701171204, Handle: "Offline_Twatter"} // Simulate a login // Chat detail - recorder = httptest.NewRecorder() + recorder := httptest.NewRecorder() app.ServeHTTP(recorder, httptest.NewRequest("GET", "/messages/1488963321701171204-1178839081222115328", nil)) - resp = recorder.Result() - root, err = html.Parse(resp.Body) + resp := recorder.Result() + root, err := html.Parse(resp.Body) require.NoError(err) assert.Len(cascadia.QueryAll(root, selector(".chat-list .chat")), 2) // Chat list still renders assert.Len(cascadia.QueryAll(root, selector("#chat-view .dm-message-and-reacts-container")), 5) diff --git a/pkg/persistence/dm_queries.go b/pkg/persistence/dm_queries.go index 00fc8cc..3a3f6aa 100644 --- a/pkg/persistence/dm_queries.go +++ b/pkg/persistence/dm_queries.go @@ -36,7 +36,7 @@ func (p Profile) SaveChatRoom(r DMChatRoom) error { last_read_event_id, is_chat_settings_valid, is_notifications_disabled, - is_mention_notifications_disabled, + is_mention_notifications_disabled, is_read_only, is_trusted, is_muted, @@ -47,12 +47,12 @@ func (p Profile) SaveChatRoom(r DMChatRoom) error { :last_read_event_id, :is_chat_settings_valid, :is_notifications_disabled, - :is_mention_notifications_disabled, + :is_mention_notifications_disabled, :is_read_only, :is_trusted, :is_muted, :status) - on conflict do update + on conflict do update set last_read_event_id=:last_read_event_id, is_chat_settings_valid=:is_chat_settings_valid, is_notifications_disabled=:is_notifications_disabled, @@ -73,10 +73,10 @@ func (p Profile) SaveChatRoom(r DMChatRoom) error { func (p Profile) GetChatRoom(id DMChatRoomID) (ret DMChatRoom, err error) { err = p.DB.Get(&ret, ` - select id, type, last_messaged_at, is_nsfw, created_at, created_by_user_id, name, avatar_image_remote_url, avatar_image_local_path - from chat_rooms - where id = ? - `, id) + select id, type, last_messaged_at, is_nsfw, created_at, created_by_user_id, name, avatar_image_remote_url, avatar_image_local_path + from chat_rooms + where id = ? + `, id) if err != nil { return ret, fmt.Errorf("Error getting chat room (%s):\n %w", id, err) } @@ -187,9 +187,9 @@ func (p Profile) GetChatRoomsPreview(id UserID) DMChatView { var msg DMMessage q, args, err := sqlx.Named(` select id, chat_room_id, sender_id, sent_at, request_id, text, in_reply_to_id, embedded_tweet_id - from chat_messages - where chat_room_id = :room_id - and sent_at = (select max(sent_at) from chat_messages where chat_room_id = :room_id) + from chat_messages + where chat_room_id = :room_id + and sent_at = (select max(sent_at) from chat_messages where chat_room_id = :room_id) `, struct { ID DMChatRoomID `db:"room_id"` }{ID: room.ID}) @@ -213,8 +213,8 @@ func (p Profile) GetChatRoomsPreview(id UserID) DMChatView { err = p.DB.Select(&participants, ` select chat_room_id, user_id, last_read_event_id, is_chat_settings_valid, is_notifications_disabled, is_mention_notifications_disabled, is_read_only, is_trusted, is_muted, status, `+USERS_ALL_SQL_FIELDS+` - from chat_room_participants join users on chat_room_participants.user_id = users.id - where chat_room_id = ? + from chat_room_participants join users on chat_room_participants.user_id = users.id + where chat_room_id = ? `, room.ID) if err != nil { panic(err) @@ -240,8 +240,8 @@ func (p Profile) GetChatRoomContents(id DMChatRoomID) DMChatView { err := p.DB.Get(&room, ` select id, type, last_messaged_at, is_nsfw, created_at, created_by_user_id, name, avatar_image_remote_url, avatar_image_local_path - from chat_rooms - where id = ? + from chat_rooms + where id = ? `, id) if err != nil { panic(err) @@ -256,8 +256,8 @@ func (p Profile) GetChatRoomContents(id DMChatRoomID) DMChatView { err = p.DB.Select(&participants, ` select chat_room_id, user_id, last_read_event_id, is_chat_settings_valid, is_notifications_disabled, is_mention_notifications_disabled, is_read_only, is_trusted, is_muted, status, `+USERS_ALL_SQL_FIELDS+` - from chat_room_participants join users on chat_room_participants.user_id = users.id - where chat_room_id = ? + from chat_room_participants join users on chat_room_participants.user_id = users.id + where chat_room_id = ? `, room.ID) if err != nil { panic(err) @@ -272,10 +272,10 @@ func (p Profile) GetChatRoomContents(id DMChatRoomID) DMChatView { var msgs []DMMessage err = p.DB.Select(&msgs, ` select id, chat_room_id, sender_id, sent_at, request_id, text, in_reply_to_id, embedded_tweet_id - from chat_messages - where chat_room_id = :room_id - order by sent_at desc - limit 50 + from chat_messages + where chat_room_id = ? + order by sent_at desc + limit 50 `, room.ID) if err != nil { panic(err) @@ -288,56 +288,62 @@ func (p Profile) GetChatRoomContents(id DMChatRoomID) DMChatView { } // Set last message ID on chat room - room.LastMessageID = ret.MessageIDs[len(ret.MessageIDs)-1] + if len(ret.MessageIDs) > 0 { + // If there's no messages, it should be OK to have LastMessageID = 0, since this is only used + // to generate previews + room.LastMessageID = ret.MessageIDs[len(ret.MessageIDs)-1] + } // Put the room in the Trove ret.Rooms[room.ID] = room - // Fetch all reaccs - var reaccs []DMReaction - message_ids_copy := make([]interface{}, len(ret.MessageIDs)) - for i, id := range ret.MessageIDs { - message_ids_copy[i] = id - } - err = p.DB.Select(&reaccs, ` - select id, message_id, sender_id, sent_at, emoji - from chat_message_reactions - where message_id in (`+strings.Repeat("?,", len(ret.MessageIDs)-1)+`?) - `, message_ids_copy...) - if err != nil { - panic(err) - } - for _, reacc := range reaccs { - msg := ret.Messages[reacc.DMMessageID] - msg.Reactions[reacc.SenderID] = reacc - ret.Messages[reacc.DMMessageID] = msg - } - - // Fetch all embedded tweets - embedded_tweet_ids := []interface{}{} - for _, m := range ret.Messages { - if m.EmbeddedTweetID != 0 { - embedded_tweet_ids = append(embedded_tweet_ids, m.EmbeddedTweetID) + if len(ret.MessageIDs) > 0 { + // Fetch all reaccs + var reaccs []DMReaction + message_ids_copy := make([]interface{}, len(ret.MessageIDs)) + for i, id := range ret.MessageIDs { + message_ids_copy[i] = id } - } - if len(embedded_tweet_ids) > 0 { - var embedded_tweets []Tweet - err = p.DB.Select(&embedded_tweets, ` - select `+TWEETS_ALL_SQL_FIELDS+` - from tweets - left join tombstone_types on tweets.tombstone_type = tombstone_types.rowid - left join likes on tweets.id = likes.tweet_id and likes.user_id = ? - where id in (`+strings.Repeat("?,", len(embedded_tweet_ids)-1)+`?)`, - append([]interface{}{UserID(0)}, embedded_tweet_ids...)...) + err = p.DB.Select(&reaccs, ` + select id, message_id, sender_id, sent_at, emoji + from chat_message_reactions + where message_id in (`+strings.Repeat("?,", len(ret.MessageIDs)-1)+`?) + `, message_ids_copy...) if err != nil { panic(err) } - for _, t := range embedded_tweets { - ret.Tweets[t.ID] = t + for _, reacc := range reaccs { + msg := ret.Messages[reacc.DMMessageID] + msg.Reactions[reacc.SenderID] = reacc + ret.Messages[reacc.DMMessageID] = msg } - } - p.fill_content(&ret.DMTrove.TweetTrove, UserID(0)) + // Fetch all embedded tweets + embedded_tweet_ids := []interface{}{} + for _, m := range ret.Messages { + if m.EmbeddedTweetID != 0 { + embedded_tweet_ids = append(embedded_tweet_ids, m.EmbeddedTweetID) + } + } + if len(embedded_tweet_ids) > 0 { + var embedded_tweets []Tweet + err = p.DB.Select(&embedded_tweets, ` + select `+TWEETS_ALL_SQL_FIELDS+` + from tweets + left join tombstone_types on tweets.tombstone_type = tombstone_types.rowid + left join likes on tweets.id = likes.tweet_id and likes.user_id = ? + where id in (`+strings.Repeat("?,", len(embedded_tweet_ids)-1)+`?)`, + append([]interface{}{UserID(0)}, embedded_tweet_ids...)...) + if err != nil { + panic(err) + } + for _, t := range embedded_tweets { + ret.Tweets[t.ID] = t + } + } + + p.fill_content(&ret.DMTrove.TweetTrove, UserID(0)) + } return ret }