diff --git a/persistence/media_queries.go b/persistence/media_queries.go index b826f23..9ffbd59 100644 --- a/persistence/media_queries.go +++ b/persistence/media_queries.go @@ -1,15 +1,55 @@ package persistence import ( + "fmt" + "database/sql" "offline_twitter/scraper" ) +/** + * Save an Image. If it's a new Image (no rowid), does an insert; otherwise, does an update. + * + * args: + * - img: the Image to save + * + * returns: + * - the rowid + */ +func (p Profile) SaveImage(img scraper.Image) (sql.Result, error) { + if img.ID == 0 { + // New image + return p.DB.Exec("insert into images (tweet_id, filename) values (?, ?)", img.TweetID, img.Filename) + } else { + // Updating an existing image + return p.DB.Exec("update images set filename=?, is_downloaded=? where rowid=?", img.Filename, img.IsDownloaded, img.ID) + } +} + +/** + * Save a Video. If it's a new Video (no rowid), does an insert; otherwise, does an update. + * + * args: + * - img: the Video to save + * + * returns: + * - the rowid + */ +func (p Profile) SaveVideo(vid scraper.Video) (sql.Result, error) { + if vid.ID == 0 { + // New image + return p.DB.Exec("insert into videos (tweet_id, filename) values (?, ?)", vid.TweetID, vid.Filename) + } else { + // Updating an existing image + return p.DB.Exec("update videos set filename=?, is_downloaded=? where rowid=?", vid.Filename, vid.IsDownloaded, vid.ID) + } +} + /** * Get the list of images for a tweet */ func (p Profile) GetImagesForTweet(t scraper.Tweet) (imgs []scraper.Image, err error) { - stmt, err := p.DB.Prepare("select filename, is_downloaded from images where tweet_id=?") + stmt, err := p.DB.Prepare("select rowid, filename, is_downloaded from images where tweet_id=?") if err != nil { return } @@ -21,7 +61,7 @@ func (p Profile) GetImagesForTweet(t scraper.Tweet) (imgs []scraper.Image, err e var img scraper.Image for rows.Next() { - err = rows.Scan(&img.Filename, &img.IsDownloaded) + err = rows.Scan(&img.ID, &img.Filename, &img.IsDownloaded) if err != nil { return } @@ -36,7 +76,7 @@ func (p Profile) GetImagesForTweet(t scraper.Tweet) (imgs []scraper.Image, err e * Get the list of videos for a tweet */ func (p Profile) GetVideosForTweet(t scraper.Tweet) (vids []scraper.Video, err error) { - stmt, err := p.DB.Prepare("select filename, is_downloaded from videos where tweet_id=?") + stmt, err := p.DB.Prepare("select rowid, filename, is_downloaded from videos where tweet_id=?") if err != nil { return } @@ -47,7 +87,7 @@ func (p Profile) GetVideosForTweet(t scraper.Tweet) (vids []scraper.Video, err e } var vid scraper.Video for rows.Next() { - err = rows.Scan(&vid.Filename, &vid.IsDownloaded) + err = rows.Scan(&vid.ID, &vid.Filename, &vid.IsDownloaded) if err != nil { return } diff --git a/persistence/media_queries_test.go b/persistence/media_queries_test.go new file mode 100644 index 0000000..55592ab --- /dev/null +++ b/persistence/media_queries_test.go @@ -0,0 +1,104 @@ +package persistence_test + +import ( + "testing" + "math/rand" + "fmt" + "time" + + "github.com/go-test/deep" + + "offline_twitter/scraper" +) + + +/** + * Create an Image, save it, reload it, and make sure it comes back the same + */ +func TestSaveAndLoadImage(t *testing.T) { + profile_path := "test_profiles/TestMediaQueries" + profile := create_or_load_profile(profile_path) + + tweet := create_stable_tweet() + + // Create a fresh Image to test on + rand.Seed(time.Now().UnixNano()) + filename := fmt.Sprint(rand.Int()) + img := scraper.Image{TweetID: tweet.ID, Filename: filename, IsDownloaded: false} + + // Save the Image + result, err := profile.SaveImage(img) + if err != nil { + t.Fatalf("Failed to save the image: %s", err.Error()) + } + last_insert, err := result.LastInsertId() + if err != nil { + t.Fatalf("last insert??? %s", err.Error()) + } + img.ID = scraper.ImageID(last_insert) + + // Reload the Image + imgs, err := profile.GetImagesForTweet(tweet) + if err != nil { + t.Fatalf("Could not load images: %s", err.Error()) + } + + var new_img scraper.Image + for index := range imgs { + if imgs[index].ID == img.ID { + new_img = imgs[index] + } + } + if new_img.ID != img.ID { + t.Fatalf("Could not find image for some reason: %d, %d; %+v", new_img.ID, img.ID, imgs) + } + if diff := deep.Equal(img, new_img); diff != nil { + t.Error(diff) + } +} + + +/** + * Create an Video, save it, reload it, and make sure it comes back the same + */ +func TestSaveAndLoadVideo(t *testing.T) { + profile_path := "test_profiles/TestMediaQueries" + profile := create_or_load_profile(profile_path) + + tweet := create_stable_tweet() + + // Create a fresh Video to test on + rand.Seed(time.Now().UnixNano()) + filename := fmt.Sprint(rand.Int()) + vid := scraper.Video{TweetID: tweet.ID, Filename: filename, IsDownloaded: false} + + // Save the Video + result, err := profile.SaveVideo(vid) + if err != nil { + t.Fatalf("Failed to save the video: %s", err.Error()) + } + last_insert, err := result.LastInsertId() + if err != nil { + t.Fatalf("last insert??? %s", err.Error()) + } + vid.ID = scraper.VideoID(last_insert) + + // Reload the Video + vids, err := profile.GetVideosForTweet(tweet) + if err != nil { + t.Fatalf("Could not load videos: %s", err.Error()) + } + + var new_vid scraper.Video + for index := range vids { + if vids[index].ID == vid.ID { + new_vid = vids[index] + } + } + if new_vid.ID != vid.ID { + t.Fatalf("Could not find video for some reason: %d, %d; %+v", new_vid.ID, vid.ID, vids) + } + if diff := deep.Equal(vid, new_vid); diff != nil { + t.Error(diff) + } +} diff --git a/persistence/schema.sql b/persistence/schema.sql index 8b2ad9c..0179b62 100644 --- a/persistence/schema.sql +++ b/persistence/schema.sql @@ -60,19 +60,17 @@ create table urls (rowid integer primary key, create table images (rowid integer primary key, tweet_id integer not null, - filename text not null, + filename text not null unique, is_downloaded boolean default 0, - unique (tweet_id, filename) foreign key(tweet_id) references tweets(id) ); create table videos (rowid integer primary key, tweet_id integer not null, - filename text not null, + filename text not null unique, is_downloaded boolean default 0, - unique (tweet_id, filename) foreign key(tweet_id) references tweets(id) ); diff --git a/persistence/tweet_queries.go b/persistence/tweet_queries.go index 725a990..fb34323 100644 --- a/persistence/tweet_queries.go +++ b/persistence/tweet_queries.go @@ -39,13 +39,13 @@ func (p Profile) SaveTweet(t scraper.Tweet) error { } } for _, image := range t.Images { - _, err := db.Exec("insert into images (tweet_id, filename) values (?, ?) on conflict do nothing", t.ID, image.Filename) + _, err := p.SaveImage(image) if err != nil { return err } } for _, video := range t.Videos { - _, err := db.Exec("insert into videos (tweet_id, filename) values (?, ?) on conflict do nothing", t.ID, video.Filename) + _, err := p.SaveVideo(video) if err != nil { return err } diff --git a/persistence/tweet_queries_test.go b/persistence/tweet_queries_test.go index 25d69f1..cfa8efa 100644 --- a/persistence/tweet_queries_test.go +++ b/persistence/tweet_queries_test.go @@ -15,8 +15,6 @@ func TestSaveAndLoadTweet(t *testing.T) { profile := create_or_load_profile(profile_path) tweet := create_dummy_tweet() - user := create_stable_user() - tweet.UserID = user.ID // Save the tweet err := profile.SaveTweet(tweet) @@ -30,6 +28,15 @@ func TestSaveAndLoadTweet(t *testing.T) { t.Fatalf("Failed to load the tweet: %s", err.Error()) } + // Spoof the image and video IDs + // TODO: This feels clumsy-- possible bad design + for i := range tweet.Images { + tweet.Images[i].ID = new_tweet.Images[i].ID + } + for i := range tweet.Videos { + tweet.Videos[i].ID = new_tweet.Videos[i].ID + } + if diff := deep.Equal(tweet, new_tweet); diff != nil { t.Error(diff) } @@ -43,8 +50,6 @@ func TestIsTweetInDatabase(t *testing.T) { profile := create_or_load_profile(profile_path) tweet := create_dummy_tweet() - user := create_stable_user() - tweet.UserID = user.ID exists := profile.IsTweetInDatabase(tweet.ID) if exists { @@ -68,8 +73,6 @@ func TestLoadUserForTweet(t *testing.T) { profile := create_or_load_profile(profile_path) tweet := create_dummy_tweet() - user := create_stable_user() - tweet.UserID = user.ID // Save the tweet err := profile.SaveTweet(tweet) diff --git a/persistence/utils_test.go b/persistence/utils_test.go index 04b35c9..316f96e 100644 --- a/persistence/utils_test.go +++ b/persistence/utils_test.go @@ -117,7 +117,7 @@ func create_dummy_tweet() scraper.Tweet { return scraper.Tweet{ ID: tweet_id, - UserID: "user", + UserID: scraper.UserID("-1"), Text: "text", PostedAt: time.Now().Truncate(1e9), // Round to nearest second NumLikes: 1, diff --git a/scraper/image.go b/scraper/image.go index 71fd963..40d910c 100644 --- a/scraper/image.go +++ b/scraper/image.go @@ -1,6 +1,9 @@ package scraper +type ImageID int + type Image struct { + ID ImageID TweetID TweetID Filename string IsDownloaded bool diff --git a/scraper/video.go b/scraper/video.go index a847efd..c437cf5 100644 --- a/scraper/video.go +++ b/scraper/video.go @@ -1,6 +1,9 @@ package scraper +type VideoID int + type Video struct { + ID VideoID TweetID TweetID Filename string IsDownloaded bool