diff --git a/cmd/twitter/main.go b/cmd/twitter/main.go index 54dbd23..a9a0fc9 100644 --- a/cmd/twitter/main.go +++ b/cmd/twitter/main.go @@ -599,7 +599,7 @@ func send_dm_reacc(room_id string, in_reply_to_id int, reacc string) { } func get_notifications(how_many int) { - trove, err := api.GetNotifications(how_many) + trove, _, err := api.GetNotifications(how_many) if err != nil && !errors.Is(err, scraper.END_OF_FEED) { panic(err) } diff --git a/pkg/persistence/notification_queries.go b/pkg/persistence/notification_queries.go index ea14c2f..2280104 100644 --- a/pkg/persistence/notification_queries.go +++ b/pkg/persistence/notification_queries.go @@ -115,3 +115,12 @@ func (p Profile) CheckNotificationScrapesNeeded(trove TweetTrove) []Notification } return ret } + +func (p Profile) GetUnreadNotificationsCount(since_sort_index int64) int { + var ret int + err := p.DB.Get(&ret, `select count(*) from notifications where sort_index > ?`, since_sort_index) + if err != nil { + panic(err) + } + return ret +} diff --git a/pkg/persistence/notification_queries_test.go b/pkg/persistence/notification_queries_test.go index ffe2af2..235a431 100644 --- a/pkg/persistence/notification_queries_test.go +++ b/pkg/persistence/notification_queries_test.go @@ -4,6 +4,10 @@ import ( "testing" "github.com/go-test/deep" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "gitlab.com/offline-twitter/twitter_offline_engine/pkg/persistence" ) func TestSaveAndLoadNotification(t *testing.T) { @@ -20,3 +24,14 @@ func TestSaveAndLoadNotification(t *testing.T) { t.Error(diff) } } + +func TestGetUnreadNotificationsCount(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + + profile, err := persistence.LoadProfile("../../sample_data/profile") + require.NoError(err) + + unread_notifs_count := profile.GetUnreadNotificationsCount(1724372973735) + assert.Equal(2, unread_notifs_count) +} diff --git a/pkg/scraper/api_types.go b/pkg/scraper/api_types.go index b5d94a0..28a5bb5 100644 --- a/pkg/scraper/api_types.go +++ b/pkg/scraper/api_types.go @@ -399,6 +399,9 @@ type TweetResponse struct { ReplaceEntry struct { Entry Entry } `json:"replaceEntry"` + MarkEntriesUnreadGreaterThanSortIndex struct { + SortIndex int64 `json:"sortIndex,string"` + } `json:"markEntriesUnreadGreaterThanSortIndex"` } `json:"instructions"` } `json:"timeline"` } diff --git a/pkg/scraper/api_types_notifications.go b/pkg/scraper/api_types_notifications.go index 57c573a..0f035f1 100644 --- a/pkg/scraper/api_types_notifications.go +++ b/pkg/scraper/api_types_notifications.go @@ -29,10 +29,11 @@ func (api *API) GetNotificationsPage(cursor string) (TweetResponse, error) { return result, err } -func (api *API) GetNotifications(how_many int) (TweetTrove, error) { +// Second return value is last unread notification sort-index. `0` will be returned if there is none. +func (api *API) GetNotifications(how_many int) (TweetTrove, int64, error) { resp, err := api.GetNotificationsPage("") if err != nil { - return TweetTrove{}, err + return TweetTrove{}, 0, err } trove, err := resp.ToTweetTroveAsNotifications(api.UserID) if err != nil { @@ -45,7 +46,7 @@ func (api *API) GetNotifications(how_many int) (TweetTrove, error) { log.Warnf("Rate limited!") break } else if err != nil { - return TweetTrove{}, err + return TweetTrove{}, 0, err } if resp.IsEndOfFeed() { log.Infof("End of feed!") @@ -59,7 +60,17 @@ func (api *API) GetNotifications(how_many int) (TweetTrove, error) { trove.MergeWith(new_trove) } - return trove, nil + return trove, resp.CheckUnreadNotifications(), nil +} + +// Check a Notifications result for unread notifications. Returns `0` if there are none. +func (t TweetResponse) CheckUnreadNotifications() int64 { + for _, instr := range t.Timeline.Instructions { + if instr.MarkEntriesUnreadGreaterThanSortIndex.SortIndex != 0 { + return instr.MarkEntriesUnreadGreaterThanSortIndex.SortIndex + } + } + return 0 } func (api *API) GetNotificationDetailForAll(trove TweetTrove, to_scrape []NotificationID) (TweetTrove, error) { diff --git a/pkg/scraper/api_types_notifications_test.go b/pkg/scraper/api_types_notifications_test.go index 33093e7..24ef1fe 100644 --- a/pkg/scraper/api_types_notifications_test.go +++ b/pkg/scraper/api_types_notifications_test.go @@ -152,6 +152,9 @@ func TestParseNotificationsPage(t *testing.T) { assert.True(is_ok) } + // Test unread notifs + assert.Equal(int64(1724566381021), resp.CheckUnreadNotifications()) + // Test cursor-bottom bottom_cursor := resp.GetCursor() assert.Equal("DAACDAABCgABFKncQJGVgAQIAAIAAAABCAADSQ3bEQgABIsN6BEACwACAAAAC0FaRkxRSXFNLTJJAAA", bottom_cursor)