diff --git a/cmd/tests.sh b/cmd/tests.sh index 2ecb84a..98bbf0e 100755 --- a/cmd/tests.sh +++ b/cmd/tests.sh @@ -184,6 +184,12 @@ test $(find link_preview_images | wc -l) = $initial_link_preview_images_count # tw search "from:michaelmalice constitution" test $(sqlite3 twitter.db "select count(*) from tweets where user_id = 44067298 and text like '%constitution%'") -gt "30" # Not sure exactly how many +tw fetch_tweet 1465534109573390348 +test $(sqlite3 twitter.db "select count(*) from polls where tweet_id = 1465534109573390348") = "1" +test "$(sqlite3 twitter.db "select choice1, choice2, choice3, choice4 from polls where tweet_id = 1465534109573390348")" = "Tribal armband|Marijuana leaf|Butterfly|Maple leaf" +test "$(sqlite3 twitter.db "select choice1_votes, choice2_votes, choice3_votes, choice4_votes from polls where tweet_id = 1465534109573390348")" = "1593|624|778|1138" + + # TODO: Maybe this file should be broken up into multiple test scripts echo -e "\033[32mAll tests passed. Finished successfully.\033[0m" diff --git a/cmd/twitter/main.go b/cmd/twitter/main.go index 8a9bf20..2606150 100644 --- a/cmd/twitter/main.go +++ b/cmd/twitter/main.go @@ -171,7 +171,7 @@ func fetch_tweet_conversation(tweet_identifier string) { for _, t := range tweets { err = profile.SaveTweet(t) if err != nil { - die("Error saving tweet: " + err.Error(), false, 4) + die(fmt.Sprintf("Error saving tweet (id %d): %s", t.ID, err.Error()), false, 4) } err = profile.DownloadTweetContentFor(&t) if err != nil { diff --git a/persistence/media_queries.go b/persistence/media_queries.go index 1224c8c..d40f0b3 100644 --- a/persistence/media_queries.go +++ b/persistence/media_queries.go @@ -65,8 +65,8 @@ func (p Profile) SaveUrl(url scraper.Url) error { */ func (p Profile) SavePoll(poll scraper.Poll) error { _, err := p.DB.Exec(` - insert into polls (tweet_id, num_choices, choice1, choice1_votes, choice2, choice2_votes, choice3, choice3_votes, choice4, choice4_votes, voting_duration, voting_ends_at, last_scraped_at) - values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + insert into polls (id, tweet_id, num_choices, choice1, choice1_votes, choice2, choice2_votes, choice3, choice3_votes, choice4, choice4_votes, voting_duration, voting_ends_at, last_scraped_at) + values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) on conflict do update set choice1_votes=?, choice2_votes=?, @@ -74,7 +74,7 @@ func (p Profile) SavePoll(poll scraper.Poll) error { choice4_votes=?, last_scraped_at=? `, - poll.TweetID, poll.NumChoices, poll.Choice1, poll.Choice1_Votes, poll.Choice2, poll.Choice2_Votes, poll.Choice3, poll.Choice3_Votes, poll.Choice4, poll.Choice4_Votes, poll.VotingDuration, poll.VotingEndsAt.Unix(), poll.LastUpdatedAt.Unix(), + poll.ID, poll.TweetID, poll.NumChoices, poll.Choice1, poll.Choice1_Votes, poll.Choice2, poll.Choice2_Votes, poll.Choice3, poll.Choice3_Votes, poll.Choice4, poll.Choice4_Votes, poll.VotingDuration, poll.VotingEndsAt.Unix(), poll.LastUpdatedAt.Unix(), poll.Choice1_Votes, poll.Choice2_Votes, poll.Choice3_Votes, poll.Choice4_Votes, poll.LastUpdatedAt.Unix(), ) return err @@ -162,7 +162,7 @@ func (p Profile) GetUrlsForTweet(t scraper.Tweet) (urls []scraper.Url, err error * Get the list of Polls for a Tweet */ func (p Profile) GetPollsForTweet(t scraper.Tweet) (polls []scraper.Poll, err error) { - stmt, err := p.DB.Prepare("select num_choices, choice1, choice1_votes, choice2, choice2_votes, choice3, choice3_votes, choice4, choice4_votes, voting_duration, voting_ends_at, last_scraped_at from polls where tweet_id=?") + stmt, err := p.DB.Prepare("select id, num_choices, choice1, choice1_votes, choice2, choice2_votes, choice3, choice3_votes, choice4, choice4_votes, voting_duration, voting_ends_at, last_scraped_at from polls where tweet_id=?") if err != nil { return } @@ -175,7 +175,7 @@ func (p Profile) GetPollsForTweet(t scraper.Tweet) (polls []scraper.Poll, err er var voting_ends_at int var last_scraped_at int for rows.Next() { - err = rows.Scan(&poll.NumChoices, &poll.Choice1, &poll.Choice1_Votes, &poll.Choice2, &poll.Choice2_Votes, &poll.Choice3, &poll.Choice3_Votes, &poll.Choice4, &poll.Choice4_Votes, &poll.VotingDuration, &voting_ends_at, &last_scraped_at) + err = rows.Scan(&poll.ID, &poll.NumChoices, &poll.Choice1, &poll.Choice1_Votes, &poll.Choice2, &poll.Choice2_Votes, &poll.Choice3, &poll.Choice3_Votes, &poll.Choice4, &poll.Choice4_Votes, &poll.VotingDuration, &voting_ends_at, &last_scraped_at) if err != nil { return } diff --git a/persistence/media_queries_test.go b/persistence/media_queries_test.go index d3610b4..5b28177 100644 --- a/persistence/media_queries_test.go +++ b/persistence/media_queries_test.go @@ -272,14 +272,51 @@ func TestSaveAndLoadPoll(t *testing.T) { var new_poll scraper.Poll for index := range polls { - if polls[index].Choice1 == poll.Choice1 { + if polls[index].ID == poll.ID { new_poll = polls[index] } } - if new_poll.Choice1 != poll.Choice1 { - t.Fatalf("Could not find poll for some reason: %s, %s; %+v", new_poll.Choice1, poll.Choice1, polls) + if new_poll.ID != poll.ID { + t.Fatalf("Could not find poll for some reason: %d, %d; %+v", new_poll.ID, poll.ID, polls) } if diff := deep.Equal(poll, new_poll); diff != nil { t.Error(diff) } } + +/** + * Change an Poll, save the changes, reload it, and check if it comes back the same + */ +func TestModifyPoll(t *testing.T) { + profile_path := "test_profiles/TestMediaQueries" + profile := create_or_load_profile(profile_path) + + tweet := create_stable_tweet() + poll := tweet.Polls[0] + + if poll.Choice1 != "-1" { + t.Fatalf("Got the wrong Poll back: wanted %q, got %q", "-1", poll.Choice1) + } + + poll.Choice1_Votes = 1200 // Increment it by 200 votes + + // Save the changes + err := profile.SavePoll(poll) + if err != nil { + t.Error(err) + } + + // Reload it + polls, err := profile.GetPollsForTweet(tweet) + if err != nil { + t.Fatalf("Could not load polls: %s", err.Error()) + } + new_poll := polls[0] + if new_poll.Choice1 != "-1" { + t.Fatalf("Got the wrong poll back: wanted %s, got %s!", "-1", new_poll.Choice1) + } + + if diff := deep.Equal(poll, new_poll); diff != nil { + t.Error(diff) + } +} diff --git a/persistence/schema.sql b/persistence/schema.sql index a80083d..575a146 100644 --- a/persistence/schema.sql +++ b/persistence/schema.sql @@ -82,6 +82,7 @@ create table urls (rowid integer primary key, ); create table polls (rowid integer primary key, + id integer unique not null check(typeof(id) = 'integer'), tweet_id integer not null, num_choices integer not null, diff --git a/persistence/utils_test.go b/persistence/utils_test.go index f8d73dc..66ef18e 100644 --- a/persistence/utils_test.go +++ b/persistence/utils_test.go @@ -126,6 +126,7 @@ func create_url_from_id(id int) scraper.Url { func create_poll_from_id(id int) scraper.Poll { s := fmt.Sprint(id) return scraper.Poll{ + ID: scraper.PollID(id), TweetID: -1, NumChoices: 2, Choice1: s, diff --git a/scraper/poll.go b/scraper/poll.go index 2e1bd16..dbc8b6a 100644 --- a/scraper/poll.go +++ b/scraper/poll.go @@ -4,10 +4,14 @@ import ( "time" "strings" "strconv" + "net/url" ) +type PollID int64 + type Poll struct { - TweetID TweetID + ID PollID + TweetID TweetID NumChoices int Choice1 string @@ -26,6 +30,12 @@ type Poll struct { } func ParseAPIPoll(apiCard APICard) Poll { + card_url, err := url.Parse(apiCard.ShortenedUrl) + if err != nil { + panic(err) + } + id := int_or_panic(card_url.Hostname()) + voting_ends_at, err := time.Parse(time.RFC3339, apiCard.BindingValues.EndDatetimeUTC.StringValue) if err != nil { panic(err) @@ -36,6 +46,7 @@ func ParseAPIPoll(apiCard APICard) Poll { } ret := Poll{} + ret.ID = PollID(id) ret.NumChoices = parse_num_choices(apiCard.Name) ret.VotingDuration = int_or_panic(apiCard.BindingValues.DurationMinutes.StringValue) * 60 ret.VotingEndsAt = voting_ends_at diff --git a/scraper/poll_test.go b/scraper/poll_test.go index 97a88b8..833ac22 100644 --- a/scraper/poll_test.go +++ b/scraper/poll_test.go @@ -20,6 +20,9 @@ func TestParsePoll2Choices(t *testing.T) { } poll := scraper.ParseAPIPoll(apiCard) + if poll.ID != 1457419248461131776 { + t.Errorf("Expected ID %d, got %d", 1457419248461131776, poll.ID) + } if poll.NumChoices != 2 { t.Errorf("Expected %d choices, got %d", 2, poll.NumChoices) } @@ -61,6 +64,9 @@ func TestParsePoll4Choices(t *testing.T) { } poll := scraper.ParseAPIPoll(apiCard) + if poll.ID != 1455611588854140929 { + t.Errorf("Expected ID %d, got %d", 1455611588854140929, poll.ID) + } if poll.NumChoices != 4 { t.Errorf("Expected %d choices, got %d", 4, poll.NumChoices) }