diff --git a/persistence/user_queries.go b/persistence/user_queries.go new file mode 100644 index 0000000..5ab8003 --- /dev/null +++ b/persistence/user_queries.go @@ -0,0 +1,159 @@ +package persistence + +import ( + "fmt" + "database/sql" + "time" + "offline_twitter/scraper" +) + +/** + * Save the given User to the database. + * If the User is already in the database, it will update most of its attributes (follower count, etc) + * + * args: + * - u: the User + */ +func (p Profile) SaveUser(u scraper.User) error { + db := p.DB + + tx, err := db.Begin() + if err != nil { + return err + } + _, err = db.Exec(` + insert into users (id, display_name, handle, bio, following_count, followers_count, location, website, join_date, is_private, is_verified, profile_image_url, banner_image_url, pinned_tweet_id) + values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + on conflict do update + set bio=?, + following_count=?, + followers_count=?, + location=?, + website=?, + is_private=?, + is_verified=?, + profile_image_url=?, + banner_image_url=?, + pinned_tweet_id=? + `, + u.ID, u.DisplayName, u.Handle, u.Bio, u.FollowingCount, u.FollowersCount, u.Location, u.Website, u.JoinDate.Unix(), u.IsPrivate, u.IsVerified, u.ProfileImageUrl, u.BannerImageUrl, u.PinnedTweetID, u.Bio, u.FollowingCount, u.FollowersCount, u.Location, u.Website, u.IsPrivate, u.IsVerified, u.ProfileImageUrl, u.BannerImageUrl, u.PinnedTweetID, + ) + if err != nil { + return err + } + + err = tx.Commit() + if err != nil { + return err + } + + return nil +} + + +/** + * Check if the database has a User with the given user handle. + * + * args: + * - handle: the user handle to search for + * + * returns: + * - true if there is such a User in the database, false otherwise + */ +func (p Profile) UserExists(handle scraper.UserHandle) bool { + db := p.DB + + var dummy string + err := db.QueryRow("select 1 from users where lower(handle) = lower(?)", handle).Scan(&dummy) + if err != nil { + if err != sql.ErrNoRows { + // A real error + panic(err) + } + return false + } + return true +} + +/** + * Helper function. Create a User from a Row. + */ +func parse_user_from_row(row *sql.Row) (scraper.User, error) { + var u scraper.User + var joinDate int64 + var user_id int64 + var pinned_tweet_id int64 + + err := row.Scan(&user_id, &u.DisplayName, &u.Handle, &u.Bio, &u.FollowingCount, &u.FollowersCount, &u.Location, &u.Website, &joinDate, &u.IsPrivate, &u.IsVerified, &u.ProfileImageUrl, &u.BannerImageUrl, &pinned_tweet_id) + if err != nil { + return u, err + } + + u.ID = scraper.UserID(fmt.Sprint(user_id)) + u.JoinDate = time.Unix(joinDate, 0) + u.PinnedTweetID = scraper.TweetID(fmt.Sprint(pinned_tweet_id)) + + return u, nil +} + + +/** + * Retrieve a User from the database, by handle. + * + * args: + * - handle: the user handle to search for + * + * returns: + * - the User, if it exists + */ +func (p Profile) GetUserByHandle(handle scraper.UserHandle) (scraper.User, error) { + db := p.DB + + stmt, err := db.Prepare(` + select id, display_name, handle, bio, following_count, followers_count, location, website, join_date, is_private, is_verified, profile_image_url, banner_image_url, pinned_tweet_id + from users + where handle = ? + `) + if err != nil { + return scraper.User{}, err + } + defer stmt.Close() + + row := stmt.QueryRow(handle) + ret, err := parse_user_from_row(row) + if err == sql.ErrNoRows { + return ret, ErrNotInDatabase{"User", handle} + } + return ret, nil +} + + +/** + * Retrieve a User from the database, by user ID. + * + * args: + * - id: the user ID to search for + * + * returns: + * - the User, if it exists + */ +func (p Profile) GetUserByID(id scraper.UserID) (scraper.User, error) { + db := p.DB + + stmt, err := db.Prepare(` + select id, display_name, handle, bio, following_count, followers_count, location, website, join_date, is_private, is_verified, profile_image_url, banner_image_url, pinned_tweet_id + from users + where id = ? + `) + if err != nil { + return scraper.User{}, err + } + defer stmt.Close() + + row := stmt.QueryRow(id) + ret, err := parse_user_from_row(row) + if err == sql.ErrNoRows { + return ret, ErrNotInDatabase{"User", id} + } + return ret, err +} diff --git a/persistence/user_queries_test.go b/persistence/user_queries_test.go new file mode 100644 index 0000000..ea52c07 --- /dev/null +++ b/persistence/user_queries_test.go @@ -0,0 +1,115 @@ +package persistence_test + +import ( + "fmt" + "testing" + "time" + "math/rand" + + "github.com/go-test/deep" + + "offline_twitter/scraper" + "offline_twitter/persistence" +) + +/** + * Helper function + */ +func create_or_load_profile(profile_path string) persistence.Profile { + var profile persistence.Profile + var err error + + if !file_exists(profile_path) { + profile, err = persistence.NewProfile(profile_path) + } else { + profile, err = persistence.LoadProfile(profile_path) + } + if err != nil { + panic(err) + } + return profile +} + + +/** + * Create a user, save it, reload it, and make sure it comes back the same + */ +func TestSaveAndLoadUser(t *testing.T) { + profile_path := "test_profiles/TestUserQueries" + profile := create_or_load_profile(profile_path) + + // Generate a new random user ID + rand.Seed(time.Now().UnixNano()) + userID := fmt.Sprint(rand.Int()) + + fake_user := scraper.User{ + ID: scraper.UserID(userID), + DisplayName: "display name", + Handle: scraper.UserHandle("handle" + userID), + Bio: "bio", + FollowersCount: 0, + FollowingCount: 1000, + Location: "location", + Website:"website", + JoinDate: time.Now().Truncate(1e9), // Round to nearest second + IsVerified: false, + IsPrivate: true, + ProfileImageUrl: "profile image url", + BannerImageUrl: "banner image url", + PinnedTweetID: scraper.TweetID("234"), + } + + // Save the user, then reload it and ensure it's the same + err := profile.SaveUser(fake_user) + if err != nil { + panic(err) + } + new_fake_user, err := profile.GetUserByID(scraper.UserID(userID)) + if err != nil { + panic(err) + } + + if diff := deep.Equal(new_fake_user, fake_user); diff != nil { + t.Error(diff) + } + + // Same thing, but get by handle + new_fake_user2, err := profile.GetUserByHandle(scraper.UserHandle(fake_user.Handle)) + if err != nil { + panic(err) + } + + if diff := deep.Equal(new_fake_user2, fake_user); diff != nil { + t.Error(diff) + } +} + + +/** + * Should correctly report whether the user exists in the database + */ +func TestUserExists(t *testing.T) { + profile_path := "test_profiles/TestUserQueries" + profile := create_or_load_profile(profile_path) + + // Generate a new random user ID + rand.Seed(time.Now().UnixNano()) + userID := fmt.Sprint(rand.Int()) + + user := scraper.User{} + user.ID = scraper.UserID(userID) + user.Handle = scraper.UserHandle("handle" + userID) + + exists := profile.UserExists(scraper.UserHandle(user.Handle)) + if exists { + t.Errorf("It shouldn't exist, but it does: %s", userID) + } + err := profile.SaveUser(user) + if err != nil { + panic(err) + } + exists = profile.UserExists(scraper.UserHandle(user.Handle)) + if !exists { + t.Errorf("It should exist, but it doesn't: %s", userID) + } +}