Add media queries

Alessio 2021-08-01 15:52:04 -07:00
parent 35cf1f1296
commit 083f5f72e0
8 changed files with 168 additions and 17 deletions

View File

@@ -1,15 +1,55 @@
package persistence
import (
"fmt"
"database/sql"
"offline_twitter/scraper"
)
/**
* Save an Image. If it's a new Image (no rowid), does an insert; otherwise, does an update.
*
* args:
* - img: the Image to save
*
* returns:
* - the sql.Result (LastInsertId() gives the new rowid), plus an error if the query failed
*/
func (p Profile) SaveImage(img scraper.Image) (sql.Result, error) {
if img.ID == 0 {
// New image
return p.DB.Exec("insert into images (tweet_id, filename) values (?, ?)", img.TweetID, img.Filename)
} else {
// Updating an existing image
return p.DB.Exec("update images set filename=?, is_downloaded=? where rowid=?", img.Filename, img.IsDownloaded, img.ID)
}
}
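Callers that need the rowid assigned to a newly inserted row can read it back from the returned sql.Result, the same way the tests added in this commit do. A minimal sketch; the helper name is illustrative and not part of this commit:

// Sketch: save a new Image, then copy the rowid SQLite assigned back onto the struct.
func (p Profile) saveImageAndSetID(img *scraper.Image) error {
	result, err := p.SaveImage(*img)
	if err != nil {
		return err
	}
	last_insert, err := result.LastInsertId()
	if err != nil {
		return err
	}
	img.ID = scraper.ImageID(last_insert)
	return nil
}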
/**
* Save a Video. If it's a new Video (no rowid), does an insert; otherwise, does an update.
*
* args:
* - vid: the Video to save
*
* returns:
* - the sql.Result (LastInsertId() gives the new rowid), plus an error if the query failed
*/
func (p Profile) SaveVideo(vid scraper.Video) (sql.Result, error) {
if vid.ID == 0 {
// New video
return p.DB.Exec("insert into videos (tweet_id, filename) values (?, ?)", vid.TweetID, vid.Filename)
} else {
// Updating an existing video
return p.DB.Exec("update videos set filename=?, is_downloaded=? where rowid=?", vid.Filename, vid.IsDownloaded, vid.ID)
}
}
/**
* Get the list of images for a tweet
*/
func (p Profile) GetImagesForTweet(t scraper.Tweet) (imgs []scraper.Image, err error) {
stmt, err := p.DB.Prepare("select filename, is_downloaded from images where tweet_id=?")
stmt, err := p.DB.Prepare("select rowid, filename, is_downloaded from images where tweet_id=?")
if err != nil {
return
}
@@ -21,7 +61,7 @@ func (p Profile) GetImagesForTweet(t scraper.Tweet) (imgs []scraper.Image, err error) {
var img scraper.Image
for rows.Next() {
err = rows.Scan(&img.Filename, &img.IsDownloaded)
err = rows.Scan(&img.ID, &img.Filename, &img.IsDownloaded)
if err != nil {
return
}
@@ -36,7 +76,7 @@ func (p Profile) GetImagesForTweet(t scraper.Tweet) (imgs []scraper.Image, err error) {
* Get the list of videos for a tweet
*/
func (p Profile) GetVideosForTweet(t scraper.Tweet) (vids []scraper.Video, err error) {
stmt, err := p.DB.Prepare("select filename, is_downloaded from videos where tweet_id=?")
stmt, err := p.DB.Prepare("select rowid, filename, is_downloaded from videos where tweet_id=?")
if err != nil {
return
}
@@ -47,7 +87,7 @@ func (p Profile) GetVideosForTweet(t scraper.Tweet) (vids []scraper.Video, err error) {
}
var vid scraper.Video
for rows.Next() {
err = rows.Scan(&vid.Filename, &vid.IsDownloaded)
err = rows.Scan(&vid.ID, &vid.Filename, &vid.IsDownloaded)
if err != nil {
return
}

View File

@@ -0,0 +1,104 @@
package persistence_test
import (
"testing"
"math/rand"
"fmt"
"time"
"github.com/go-test/deep"
"offline_twitter/scraper"
)
/**
* Create an Image, save it, reload it, and make sure it comes back the same
*/
func TestSaveAndLoadImage(t *testing.T) {
profile_path := "test_profiles/TestMediaQueries"
profile := create_or_load_profile(profile_path)
tweet := create_stable_tweet()
// Create a fresh Image to test on
rand.Seed(time.Now().UnixNano())
filename := fmt.Sprint(rand.Int())
img := scraper.Image{TweetID: tweet.ID, Filename: filename, IsDownloaded: false}
// Save the Image
result, err := profile.SaveImage(img)
if err != nil {
t.Fatalf("Failed to save the image: %s", err.Error())
}
last_insert, err := result.LastInsertId()
if err != nil {
t.Fatalf("last insert??? %s", err.Error())
}
img.ID = scraper.ImageID(last_insert)
// Reload the Image
imgs, err := profile.GetImagesForTweet(tweet)
if err != nil {
t.Fatalf("Could not load images: %s", err.Error())
}
var new_img scraper.Image
for index := range imgs {
if imgs[index].ID == img.ID {
new_img = imgs[index]
}
}
if new_img.ID != img.ID {
t.Fatalf("Could not find image for some reason: %d, %d; %+v", new_img.ID, img.ID, imgs)
}
if diff := deep.Equal(img, new_img); diff != nil {
t.Error(diff)
}
}
/**
* Create a Video, save it, reload it, and make sure it comes back the same
*/
func TestSaveAndLoadVideo(t *testing.T) {
profile_path := "test_profiles/TestMediaQueries"
profile := create_or_load_profile(profile_path)
tweet := create_stable_tweet()
// Create a fresh Video to test on
rand.Seed(time.Now().UnixNano())
filename := fmt.Sprint(rand.Int())
vid := scraper.Video{TweetID: tweet.ID, Filename: filename, IsDownloaded: false}
// Save the Video
result, err := profile.SaveVideo(vid)
if err != nil {
t.Fatalf("Failed to save the video: %s", err.Error())
}
last_insert, err := result.LastInsertId()
if err != nil {
t.Fatalf("last insert??? %s", err.Error())
}
vid.ID = scraper.VideoID(last_insert)
// Reload the Video
vids, err := profile.GetVideosForTweet(tweet)
if err != nil {
t.Fatalf("Could not load videos: %s", err.Error())
}
var new_vid scraper.Video
for index := range vids {
if vids[index].ID == vid.ID {
new_vid = vids[index]
}
}
if new_vid.ID != vid.ID {
t.Fatalf("Could not find video for some reason: %d, %d; %+v", new_vid.ID, vid.ID, vids)
}
if diff := deep.Equal(vid, new_vid); diff != nil {
t.Error(diff)
}
}

View File

@@ -60,19 +60,17 @@ create table urls (rowid integer primary key,
create table images (rowid integer primary key,
tweet_id integer not null,
filename text not null,
filename text not null unique,
is_downloaded boolean default 0,
unique (tweet_id, filename)
foreign key(tweet_id) references tweets(id)
);
create table videos (rowid integer primary key,
tweet_id integer not null,
filename text not null,
filename text not null unique,
is_downloaded boolean default 0,
unique (tweet_id, filename)
foreign key(tweet_id) references tweets(id)
);
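Moving the constraint from unique (tweet_id, filename) to a plain unique on filename means a filename can no longer be reused even under a different tweet_id. A minimal sketch of the new behavior, assuming a SQLite driver such as github.com/mattn/go-sqlite3 (this diff does not show which driver the project uses):

package main

import (
	"database/sql"
	"fmt"

	_ "github.com/mattn/go-sqlite3" // assumed driver, for illustration only
)

func main() {
	db, _ := sql.Open("sqlite3", ":memory:")
	// Same images table as above, minus the foreign key, so the demo is standalone
	db.Exec(`create table images (rowid integer primary key,
	         tweet_id integer not null,
	         filename text not null unique,
	         is_downloaded boolean default 0)`)
	db.Exec("insert into images (tweet_id, filename) values (1, 'photo.jpg')")
	_, err := db.Exec("insert into images (tweet_id, filename) values (2, 'photo.jpg')")
	fmt.Println(err) // UNIQUE constraint failed: images.filename
}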

View File

@@ -39,13 +39,13 @@ func (p Profile) SaveTweet(t scraper.Tweet) error {
}
}
for _, image := range t.Images {
_, err := db.Exec("insert into images (tweet_id, filename) values (?, ?) on conflict do nothing", t.ID, image.Filename)
_, err := p.SaveImage(image)
if err != nil {
return err
}
}
for _, video := range t.Videos {
_, err := db.Exec("insert into videos (tweet_id, filename) values (?, ?) on conflict do nothing", t.ID, video.Filename)
_, err := p.SaveVideo(video)
if err != nil {
return err
}

View File

@@ -15,8 +15,6 @@ func TestSaveAndLoadTweet(t *testing.T) {
profile := create_or_load_profile(profile_path)
tweet := create_dummy_tweet()
user := create_stable_user()
tweet.UserID = user.ID
// Save the tweet
err := profile.SaveTweet(tweet)
@@ -30,6 +28,15 @@ func TestSaveAndLoadTweet(t *testing.T) {
t.Fatalf("Failed to load the tweet: %s", err.Error())
}
// Spoof the image and video IDs
// TODO: This feels clumsy-- possible bad design
for i := range tweet.Images {
tweet.Images[i].ID = new_tweet.Images[i].ID
}
for i := range tweet.Videos {
tweet.Videos[i].ID = new_tweet.Videos[i].ID
}
if diff := deep.Equal(tweet, new_tweet); diff != nil {
t.Error(diff)
}
@@ -43,8 +50,6 @@ func TestIsTweetInDatabase(t *testing.T) {
profile := create_or_load_profile(profile_path)
tweet := create_dummy_tweet()
user := create_stable_user()
tweet.UserID = user.ID
exists := profile.IsTweetInDatabase(tweet.ID)
if exists {
@@ -68,8 +73,6 @@ func TestLoadUserForTweet(t *testing.T) {
profile := create_or_load_profile(profile_path)
tweet := create_dummy_tweet()
user := create_stable_user()
tweet.UserID = user.ID
// Save the tweet
err := profile.SaveTweet(tweet)

View File

@@ -117,7 +117,7 @@ func create_dummy_tweet() scraper.Tweet {
return scraper.Tweet{
ID: tweet_id,
UserID: "user",
UserID: scraper.UserID("-1"),
Text: "text",
PostedAt: time.Now().Truncate(1e9), // Round to nearest second
NumLikes: 1,

View File

@@ -1,6 +1,9 @@
package scraper
type ImageID int
type Image struct {
ID ImageID
TweetID TweetID
Filename string
IsDownloaded bool

View File

@@ -1,6 +1,9 @@
package scraper
type VideoID int
type Video struct {
ID VideoID
TweetID TweetID
Filename string
IsDownloaded bool
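Declaring ImageID and VideoID as distinct named types, rather than bare ints, lets the compiler catch accidental mix-ups between the two kinds of rowid. A minimal illustration, using local stand-in types of the same shape rather than the scraper package itself:

package main

import "fmt"

// Stand-ins with the same underlying type as scraper.ImageID and scraper.VideoID
type ImageID int
type VideoID int

func main() {
	var img_id ImageID = 10
	var vid_id VideoID = 20
	// img_id = vid_id       // would not compile: cannot use vid_id (VideoID) as ImageID
	img_id = ImageID(vid_id) // an explicit conversion is required to cross the type boundary
	fmt.Println(img_id)
}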