diff --git a/pkg/persistence/compound_queries.go b/pkg/persistence/compound_queries.go index 7639025..0f23ed0 100644 --- a/pkg/persistence/compound_queries.go +++ b/pkg/persistence/compound_queries.go @@ -197,6 +197,7 @@ type TweetDetailView struct { TweetTrove ParentIDs []TweetID MainTweetID TweetID + ThreadIDs []TweetID ReplyChains [][]TweetID } @@ -247,6 +248,39 @@ func (p Profile) GetTweetDetail(id TweetID, current_user_id UserID) (TweetDetail } } + // Threaded replies + stmt, err = p.DB.Preparex(` + with recursive thread_replies(id) as ( + values(?) + union all + select tweets.id from tweets + join thread_replies on tweets.in_reply_to_id = thread_replies.id + where tweets.user_id = ? + ) + + select ` + TWEETS_ALL_SQL_FIELDS + ` + from tweets + left join tombstone_types on tweets.tombstone_type = tombstone_types.rowid + left join likes on tweets.id = likes.tweet_id and likes.user_id = ? + inner join thread_replies on tweets.id = thread_replies.id + order by id asc`) + + if err != nil { + panic(err) + } + defer stmt.Close() + var reply_thread []Tweet + err = stmt.Select(&reply_thread, id, ret.Tweets[ret.MainTweetID].UserID, current_user_id) + if err != nil { + panic(err) + } + for _, tweet := range reply_thread { + ret.Tweets[tweet.ID] = tweet + if tweet.ID != ret.MainTweetID { + ret.ThreadIDs = append(ret.ThreadIDs, tweet.ID) + } + } + var replies []Tweet stmt, err = p.DB.Preparex( `select ` + TWEETS_ALL_SQL_FIELDS + ` @@ -254,16 +288,22 @@ func (p Profile) GetTweetDetail(id TweetID, current_user_id UserID) (TweetDetail left join tombstone_types on tweets.tombstone_type = tombstone_types.rowid left join likes on tweets.id = likes.tweet_id and likes.user_id = ? where in_reply_to_id = ? + and id != ? -- skip the main Thread if there is one order by num_likes desc limit 50`) if err != nil { panic(err) } defer stmt.Close() - err = stmt.Select(&replies, current_user_id, id) + thread_top_id := TweetID(0) + if len(ret.ThreadIDs) > 0 { + thread_top_id = ret.ThreadIDs[0] + } + err = stmt.Select(&replies, current_user_id, id, thread_top_id) if err != nil { panic(err) } + if len(replies) > 0 { reply_1_ids := []interface{}{} for _, r := range replies { diff --git a/pkg/persistence/compound_queries_test.go b/pkg/persistence/compound_queries_test.go index a83c8b3..f24f2bc 100644 --- a/pkg/persistence/compound_queries_test.go +++ b/pkg/persistence/compound_queries_test.go @@ -256,3 +256,39 @@ func TestTweetDetailWithParents(t *testing.T) { require.Len(tweet_detail.ReplyChains, 0) } + +func TestTweetDetailWithThread(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + + profile, err := persistence.LoadProfile("../../sample_data/profile") + require.NoError(err) + + tweet_detail, err := profile.GetTweetDetail(TweetID(1698762403163304110), UserID(0)) + require.NoError(err) + + assert.Len(tweet_detail.Retweets, 0) + + assert.Len(tweet_detail.Tweets, 11) + + expected_thread := []TweetID{ + 1698762405268902217, 1698762406929781161, 1698762408410390772, 1698762409974857832, + 1698762411853971851, 1698762413393236329, 1698762414957666416, + } + + assert.Equal(expected_thread, tweet_detail.ThreadIDs) + + for _, id := range expected_thread { + _, is_ok := tweet_detail.Tweets[id] + assert.True(is_ok) + } + + assert.Len(tweet_detail.Users, 2) + _, is_ok := tweet_detail.Users[1458284524761075714] + assert.True(is_ok) + _, is_ok = tweet_detail.Users[534463724] + assert.True(is_ok) + + require.Len(tweet_detail.ReplyChains, 1) // Should not include the Thread replies + assert.Equal(tweet_detail.ReplyChains[0][0], TweetID(1698792233619562866)) +}