From 1c67a9039d3fa42e25f8d3185f5cf00d189d5b48 Mon Sep 17 00:00:00 2001 From: Alessio Date: Mon, 16 Aug 2021 20:37:35 -0700 Subject: [PATCH] Add retweet queries --- persistence/retweet_queries.go | 49 +++++++++++++++++++++++++++++ persistence/retweet_queries_test.go | 37 ++++++++++++++++++++++ persistence/schema.sql | 2 +- persistence/utils_test.go | 31 ++++++++++++++++++ 4 files changed, 118 insertions(+), 1 deletion(-) create mode 100644 persistence/retweet_queries.go create mode 100644 persistence/retweet_queries_test.go diff --git a/persistence/retweet_queries.go b/persistence/retweet_queries.go new file mode 100644 index 0000000..1799102 --- /dev/null +++ b/persistence/retweet_queries.go @@ -0,0 +1,49 @@ +package persistence + +import ( + "time" + + "offline_twitter/scraper" +) + +/** + * Save a Retweet. Do nothing if it already exists, because none of its parameters are modifiable. + */ +func (p Profile) SaveRetweet(r scraper.Retweet) error { + _, err := p.DB.Exec(` + insert into retweets (retweet_id, tweet_id, retweeted_by, retweeted_at) + values (?, ?, ?, ?) + on conflict do nothing + `, + r.RetweetID, r.TweetID, r.RetweetedByID, r.RetweetedAt.Unix(), + ) + return err +} + + +/** + * Retrieve a Retweet by ID + */ +func (p Profile) GetRetweetById(id scraper.TweetID) (scraper.Retweet, error) { + stmt, err := p.DB.Prepare(` + select retweet_id, tweet_id, retweeted_by, retweeted_at + from retweets + where retweet_id = ? + `) + if err != nil { + return scraper.Retweet{}, err + } + defer stmt.Close() + + var r scraper.Retweet + var retweeted_at int + + row := stmt.QueryRow(id) + err = row.Scan(&r.RetweetID, &r.TweetID, &r.RetweetedByID, &retweeted_at) + if err != nil { + return scraper.Retweet{}, err + } + + r.RetweetedAt = time.Unix(int64(retweeted_at), 0) + return r, nil +} diff --git a/persistence/retweet_queries_test.go b/persistence/retweet_queries_test.go new file mode 100644 index 0000000..6ad04fe --- /dev/null +++ b/persistence/retweet_queries_test.go @@ -0,0 +1,37 @@ +package persistence_test + +import ( + "testing" + + "github.com/go-test/deep" +) + + +func TestSaveAndLoadRetweet(t *testing.T) { + profile_path := "test_profiles/TestRetweetQueries" + profile := create_or_load_profile(profile_path) + + tweet := create_dummy_tweet() + err := profile.SaveTweet(tweet) + if err != nil { + t.Fatalf("Failed to save the tweet: %s", err.Error()) + } + + rt := create_dummy_retweet(tweet.ID) + + // Save the Retweet + err = profile.SaveRetweet(rt) + if err != nil { + t.Fatalf("Failed to save the retweet: %s", err.Error()) + } + + // Reload the Retweet + new_rt, err := profile.GetRetweetById(rt.RetweetID) + if err != nil { + t.Fatalf("Failed to load the retweet: %s", err.Error()) + } + + if diff := deep.Equal(rt, new_rt); diff != nil { + t.Error(diff) + } +} diff --git a/persistence/schema.sql b/persistence/schema.sql index 5d32825..70207d4 100644 --- a/persistence/schema.sql +++ b/persistence/schema.sql @@ -44,7 +44,7 @@ create table tweets (rowid integer primary key, ); create table retweets(rowid integer primary key, - retweet_id integer not null, + retweet_id integer not null unique, tweet_id integer not null, retweeted_by integer not null, retweeted_at integer not null, diff --git a/persistence/utils_test.go b/persistence/utils_test.go index 51a6f46..3bb24af 100644 --- a/persistence/utils_test.go +++ b/persistence/utils_test.go @@ -26,6 +26,10 @@ func create_or_load_profile(profile_path string) persistence.Profile { panic(err) } err = profile.SaveTweet(create_stable_tweet()) + if err != nil { + panic(err) + } + err = profile.SaveRetweet(create_stable_retweet()) } else { profile, err = persistence.LoadProfile(profile_path) } @@ -114,6 +118,18 @@ func create_stable_tweet() scraper.Tweet { } } +/** + * Create a stable retweet with a fixed ID and parameters + */ +func create_stable_retweet() scraper.Retweet { + retweet_id := scraper.TweetID(-1) + return scraper.Retweet{ + RetweetID: retweet_id, + TweetID: -1, + RetweetedByID: -1, + RetweetedAt: time.Unix(20000000, 0), + } +} /** * Create a new user with a random ID and handle @@ -173,3 +189,18 @@ func create_dummy_tweet() scraper.Tweet { Hashtags: []string{"hash1", "hash2"}, } } + +/** + * Create a new retweet with a random ID for a given TweetID + */ +func create_dummy_retweet(tweet_id scraper.TweetID) scraper.Retweet { + rand.Seed(time.Now().UnixNano()) + retweet_id := scraper.TweetID(rand.Int()) + + return scraper.Retweet{ + RetweetID: retweet_id, + TweetID: tweet_id, + RetweetedByID: -1, + RetweetedAt: time.Unix(20000000, 0), + } +}