Add way to scrape unread count; add db query to count unread notifs since a given sort index

This commit is contained in:
Alessio 2024-09-02 15:02:27 -07:00
parent 665e6a31dd
commit 9c7174a84a
6 changed files with 46 additions and 5 deletions

View File

@ -599,7 +599,7 @@ func send_dm_reacc(room_id string, in_reply_to_id int, reacc string) {
} }
func get_notifications(how_many int) { func get_notifications(how_many int) {
trove, err := api.GetNotifications(how_many) trove, _, err := api.GetNotifications(how_many)
if err != nil && !errors.Is(err, scraper.END_OF_FEED) { if err != nil && !errors.Is(err, scraper.END_OF_FEED) {
panic(err) panic(err)
} }

View File

@ -115,3 +115,12 @@ func (p Profile) CheckNotificationScrapesNeeded(trove TweetTrove) []Notification
} }
return ret return ret
} }
func (p Profile) GetUnreadNotificationsCount(since_sort_index int64) int {
var ret int
err := p.DB.Get(&ret, `select count(*) from notifications where sort_index > ?`, since_sort_index)
if err != nil {
panic(err)
}
return ret
}

View File

@ -4,6 +4,10 @@ import (
"testing" "testing"
"github.com/go-test/deep" "github.com/go-test/deep"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gitlab.com/offline-twitter/twitter_offline_engine/pkg/persistence"
) )
func TestSaveAndLoadNotification(t *testing.T) { func TestSaveAndLoadNotification(t *testing.T) {
@ -20,3 +24,14 @@ func TestSaveAndLoadNotification(t *testing.T) {
t.Error(diff) t.Error(diff)
} }
} }
func TestGetUnreadNotificationsCount(t *testing.T) {
require := require.New(t)
assert := assert.New(t)
profile, err := persistence.LoadProfile("../../sample_data/profile")
require.NoError(err)
unread_notifs_count := profile.GetUnreadNotificationsCount(1724372973735)
assert.Equal(2, unread_notifs_count)
}

View File

@ -399,6 +399,9 @@ type TweetResponse struct {
ReplaceEntry struct { ReplaceEntry struct {
Entry Entry Entry Entry
} `json:"replaceEntry"` } `json:"replaceEntry"`
MarkEntriesUnreadGreaterThanSortIndex struct {
SortIndex int64 `json:"sortIndex,string"`
} `json:"markEntriesUnreadGreaterThanSortIndex"`
} `json:"instructions"` } `json:"instructions"`
} `json:"timeline"` } `json:"timeline"`
} }

View File

@ -29,10 +29,11 @@ func (api *API) GetNotificationsPage(cursor string) (TweetResponse, error) {
return result, err return result, err
} }
func (api *API) GetNotifications(how_many int) (TweetTrove, error) { // Second return value is last unread notification sort-index. `0` will be returned if there is none.
func (api *API) GetNotifications(how_many int) (TweetTrove, int64, error) {
resp, err := api.GetNotificationsPage("") resp, err := api.GetNotificationsPage("")
if err != nil { if err != nil {
return TweetTrove{}, err return TweetTrove{}, 0, err
} }
trove, err := resp.ToTweetTroveAsNotifications(api.UserID) trove, err := resp.ToTweetTroveAsNotifications(api.UserID)
if err != nil { if err != nil {
@ -45,7 +46,7 @@ func (api *API) GetNotifications(how_many int) (TweetTrove, error) {
log.Warnf("Rate limited!") log.Warnf("Rate limited!")
break break
} else if err != nil { } else if err != nil {
return TweetTrove{}, err return TweetTrove{}, 0, err
} }
if resp.IsEndOfFeed() { if resp.IsEndOfFeed() {
log.Infof("End of feed!") log.Infof("End of feed!")
@ -59,7 +60,17 @@ func (api *API) GetNotifications(how_many int) (TweetTrove, error) {
trove.MergeWith(new_trove) trove.MergeWith(new_trove)
} }
return trove, nil return trove, resp.CheckUnreadNotifications(), nil
}
// Check a Notifications result for unread notifications. Returns `0` if there are none.
func (t TweetResponse) CheckUnreadNotifications() int64 {
for _, instr := range t.Timeline.Instructions {
if instr.MarkEntriesUnreadGreaterThanSortIndex.SortIndex != 0 {
return instr.MarkEntriesUnreadGreaterThanSortIndex.SortIndex
}
}
return 0
} }
func (api *API) GetNotificationDetailForAll(trove TweetTrove, to_scrape []NotificationID) (TweetTrove, error) { func (api *API) GetNotificationDetailForAll(trove TweetTrove, to_scrape []NotificationID) (TweetTrove, error) {

View File

@ -152,6 +152,9 @@ func TestParseNotificationsPage(t *testing.T) {
assert.True(is_ok) assert.True(is_ok)
} }
// Test unread notifs
assert.Equal(int64(1724566381021), resp.CheckUnreadNotifications())
// Test cursor-bottom // Test cursor-bottom
bottom_cursor := resp.GetCursor() bottom_cursor := resp.GetCursor()
assert.Equal("DAACDAABCgABFKncQJGVgAQIAAIAAAABCAADSQ3bEQgABIsN6BEACwACAAAAC0FaRkxRSXFNLTJJAAA", bottom_cursor) assert.Equal("DAACDAABCgABFKncQJGVgAQIAAIAAAABCAADSQ3bEQgABIsN6BEACwACAAAAC0FaRkxRSXFNLTJJAAA", bottom_cursor)