diff --git a/.golangci.yaml b/.golangci.yaml index fbdc37a..54ea499 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -23,7 +23,6 @@ linters: - unused - varcheck - whitespace - # - wsl # - wrapcheck - lll - godox @@ -546,48 +545,6 @@ linters-settings: # - .Wrapf( # - .WithMessage( - wsl: - # See https://github.com/bombsimon/wsl/blob/master/doc/configuration.md for documentation of available settings. - - # Controls if you're allowed to cuddle multiple variable declarations. - allow-cuddle-declarations: true - - # # Controls if you may cuddle assignments and anything without needing an empty line between them. - # # Default: false - # allow-assign-and-anything: false - # # Controls if you may cuddle assignments and calls without needing an empty line between them. - # # Default: true - # allow-assign-and-call: true - # # Controls if you may cuddle assignments even if they span over multiple lines. - # # Default: true - # allow-multiline-assign: true - # # This option allows whitespace after each comment group that begins a block. - # # Default: false - # allow-separated-leading-comment: false - # # Controls if blocks can end with comments. - # # This is not encouraged sine it's usually code smell but might be useful do improve understanding or learning purposes. - # # To be allowed there must be no whitespace between the comment and the last statement or the comment and the closing brace. - # # Default: false - # allow-trailing-comment: false - # # Can be set to force trailing newlines at the end of case blocks to improve readability. - # # If the number of lines (including comments) in a case block exceeds this number - # # a linter error will be yielded if the case does not end with a newline. - # # Default: 0 - # force-case-trailing-whitespace: 0 - # # Enforces that an `if` statement checking an error variable is cuddled - # # with the line that assigned that error variable. - # # Default: false - # force-err-cuddling: false - # # Enforces that an assignment which is actually a short declaration (using `:=`) - # # is only allowed to cuddle with other short declarations, and not plain assignments, blocks, etc. - # # This rule helps make declarations stand out by themselves, much the same as grouping var statement. - # # Default: false - # force-short-decl-cuddling: false - # # Controls if the checks for slice append should be "strict" - # # in the sense that it will only allow these assignments to be cuddled with variables being appended. - # # Default: true - # strict-append: true - # # The custom section can be used to define linter plugins to be loaded at runtime. # # See README doc for more info. # custom: diff --git a/go.mod b/go.mod index bdcd5ad..a2381b0 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.16 require ( github.com/go-test/deep v1.0.7 github.com/jarcoal/httpmock v1.1.0 + github.com/jmoiron/sqlx v1.3.4 github.com/mattn/go-sqlite3 v1.14.7 github.com/sirupsen/logrus v1.8.1 github.com/stretchr/testify v1.7.0 diff --git a/go.sum b/go.sum index df38f49..9579380 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,17 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= +github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-test/deep v1.0.7 h1:/VSMRlnY/JSyqxQUzQLKVMAskpY/NZKFA5j2P+0pP2M= github.com/go-test/deep v1.0.7/go.mod h1:QV8Hv/iy04NyLBxAdO9njL0iVPN1S4d/A3NVv1V36o8= github.com/jarcoal/httpmock v1.1.0 h1:F47ChZj1Y2zFsCXxNkBPwNNKnAyOATcdQibk0qEdVCE= github.com/jarcoal/httpmock v1.1.0/go.mod h1:ATjnClrvW/3tijVmpL/va5Z3aAyGvqU3gCT8nX0Txik= +github.com/jmoiron/sqlx v1.3.4 h1:wv+0IJZfL5z0uZoUjlpKgHkgaFSYD+r9CfrXjEXsO7w= +github.com/jmoiron/sqlx v1.3.4/go.mod h1:2BljVx/86SuTyjE+aPYlHCTNvZrnJXghYGpNiXLBMCQ= +github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0= +github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mattn/go-sqlite3 v1.14.7 h1:fxWBnXkxfM6sRiuH3bqJ4CfzZojMOLVc0UTsTglEghA= github.com/mattn/go-sqlite3 v1.14.7/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/persistence/media_queries.go b/persistence/media_queries.go index 5f7a965..c72b770 100644 --- a/persistence/media_queries.go +++ b/persistence/media_queries.go @@ -92,25 +92,7 @@ func (p Profile) SavePoll(poll scraper.Poll) error { * Get the list of images for a tweet */ func (p Profile) GetImagesForTweet(t scraper.Tweet) (imgs []scraper.Image, err error) { - stmt, err := p.DB.Prepare("select id, width, height, remote_url, local_filename, is_downloaded from images where tweet_id=?") - if err != nil { - return - } - defer stmt.Close() - rows, err := stmt.Query(t.ID) - if err != nil { - return - } - var img scraper.Image - - for rows.Next() { - err = rows.Scan(&img.ID, &img.Width, &img.Height, &img.RemoteURL, &img.LocalFilename, &img.IsDownloaded) - if err != nil { - return - } - img.TweetID = t.ID - imgs = append(imgs, img) - } + err = p.DB.Select(&imgs, "select id, tweet_id, width, height, remote_url, local_filename, is_downloaded from images where tweet_id=?", t.ID) return } @@ -118,93 +100,38 @@ func (p Profile) GetImagesForTweet(t scraper.Tweet) (imgs []scraper.Image, err e * Get the list of videos for a tweet */ func (p Profile) GetVideosForTweet(t scraper.Tweet) (vids []scraper.Video, err error) { - stmt, err := p.DB.Prepare(` - select id, width, height, remote_url, local_filename, thumbnail_remote_url, thumbnail_local_filename, duration, view_count, + err = p.DB.Select(&vids, ` + select id, tweet_id, width, height, remote_url, local_filename, thumbnail_remote_url, thumbnail_local_filename, duration, view_count, is_downloaded, is_gif from videos where tweet_id = ? - `) - if err != nil { - return - } - defer stmt.Close() - rows, err := stmt.Query(t.ID) - if err != nil { - return - } - var vid scraper.Video - for rows.Next() { - err = rows.Scan(&vid.ID, &vid.Width, &vid.Height, &vid.RemoteURL, &vid.LocalFilename, &vid.ThumbnailRemoteUrl, - &vid.ThumbnailLocalPath, &vid.Duration, &vid.ViewCount, &vid.IsDownloaded, &vid.IsGif) - if err != nil { - return - } - vid.TweetID = t.ID - vids = append(vids, vid) - } - return + `, t.ID) + return } /** * Get the list of Urls for a Tweet */ func (p Profile) GetUrlsForTweet(t scraper.Tweet) (urls []scraper.Url, err error) { - stmt, err := p.DB.Prepare(` - select domain, text, short_text, title, description, creator_id, site_id, thumbnail_width, thumbnail_height, thumbnail_remote_url, + err = p.DB.Select(&urls, ` + select tweet_id, domain, text, short_text, title, description, creator_id, site_id, thumbnail_width, thumbnail_height, thumbnail_remote_url, thumbnail_local_path, has_card, has_thumbnail, is_content_downloaded from urls where tweet_id = ? order by rowid - `) - if err != nil { - return - } - defer stmt.Close() - rows, err := stmt.Query(t.ID) - if err != nil { - return - } - var url scraper.Url - for rows.Next() { - err = rows.Scan(&url.Domain, &url.Text, &url.ShortText, &url.Title, &url.Description, &url.CreatorID, &url.SiteID, - &url.ThumbnailWidth, &url.ThumbnailHeight, &url.ThumbnailRemoteUrl, &url.ThumbnailLocalPath, &url.HasCard, - &url.HasThumbnail, &url.IsContentDownloaded) - if err != nil { - return - } - url.TweetID = t.ID - urls = append(urls, url) - } - return + `, t.ID) + return } /** * Get the list of Polls for a Tweet */ func (p Profile) GetPollsForTweet(t scraper.Tweet) (polls []scraper.Poll, err error) { - stmt, err := p.DB.Prepare(` - select id, num_choices, choice1, choice1_votes, choice2, choice2_votes, choice3, choice3_votes, choice4, choice4_votes, + err = p.DB.Select(&polls, ` + select id, tweet_id, num_choices, choice1, choice1_votes, choice2, choice2_votes, choice3, choice3_votes, choice4, choice4_votes, voting_duration, voting_ends_at, last_scraped_at from polls where tweet_id = ? - `) - if err != nil { - return - } - defer stmt.Close() - rows, err := stmt.Query(t.ID) - if err != nil { - return - } - var poll scraper.Poll - for rows.Next() { - err = rows.Scan(&poll.ID, &poll.NumChoices, &poll.Choice1, &poll.Choice1_Votes, &poll.Choice2, &poll.Choice2_Votes, &poll.Choice3, - &poll.Choice3_Votes, &poll.Choice4, &poll.Choice4_Votes, &poll.VotingDuration, &poll.VotingEndsAt, &poll.LastUpdatedAt) - if err != nil { - return - } - poll.TweetID = t.ID - polls = append(polls, poll) - } - return + `, t.ID) + return } diff --git a/persistence/profile.go b/persistence/profile.go index 1650de2..9b84f71 100644 --- a/persistence/profile.go +++ b/persistence/profile.go @@ -5,7 +5,8 @@ import ( "fmt" "os" "path" - "database/sql" + sql "github.com/jmoiron/sqlx" + "github.com/jmoiron/sqlx/reflectx" _ "github.com/mattn/go-sqlite3" "gopkg.in/yaml.v2" ) @@ -64,15 +65,10 @@ func NewProfile(target_dir string) (Profile, error) { // Create `twitter.db` fmt.Printf("Creating............. %s\n", sqlite_file) - db, err := sql.Open("sqlite3", sqlite_file + "?_foreign_keys=on") - if err != nil { - return Profile{}, err - } - _, err = db.Exec(sql_init) - if err != nil { - return Profile{}, err - } + db := sql.MustOpen("sqlite3", sqlite_file+"?_foreign_keys=on") + db.MustExec(sql_init) InitializeDatabaseVersion(db) + db.Mapper = reflectx.NewMapperFunc("db", ToSnakeCase) // Create `settings.yaml` fmt.Printf("Creating............. %s\n", settings_file) @@ -156,10 +152,8 @@ func LoadProfile(profile_dir string) (Profile, error) { return Profile{}, err } - db, err := sql.Open("sqlite3", sqlite_file+"?_foreign_keys=on&_journal_mode=WAL") - if err != nil { - return Profile{}, err - } + db := sql.MustOpen("sqlite3", fmt.Sprintf("%s?_foreign_keys=on&_journal_mode=WAL", sqlite_file)) + db.Mapper = reflectx.NewMapperFunc("db", ToSnakeCase) ret := Profile{ ProfileDir: profile_dir, diff --git a/persistence/profile_test.go b/persistence/profile_test.go index d8aab7d..3952e68 100644 --- a/persistence/profile_test.go +++ b/persistence/profile_test.go @@ -111,3 +111,14 @@ func TestLoadProfile(t *testing.T) { assert.Equal(t, profile_path, profile.ProfileDir) } + +/** + * Test the ToSnakeCase implementation + */ +func TestSnakeCase(t *testing.T) { + assert := assert.New(t) + + assert.Equal("tweet_id", persistence.ToSnakeCase("TweetID")) + assert.Equal("i_am_a_computer", persistence.ToSnakeCase("IAmAComputer")) + assert.Equal("choice1_votes", persistence.ToSnakeCase("Choice1_Votes")) +} diff --git a/persistence/retweet_queries.go b/persistence/retweet_queries.go index 0877d3a..0ccffc5 100644 --- a/persistence/retweet_queries.go +++ b/persistence/retweet_queries.go @@ -22,22 +22,11 @@ func (p Profile) SaveRetweet(r scraper.Retweet) error { * Retrieve a Retweet by ID */ func (p Profile) GetRetweetById(id scraper.TweetID) (scraper.Retweet, error) { - stmt, err := p.DB.Prepare(` + var r scraper.Retweet + err := p.DB.Get(&r, ` select retweet_id, tweet_id, retweeted_by, retweeted_at from retweets where retweet_id = ? - `) - if err != nil { - return scraper.Retweet{}, err - } - defer stmt.Close() - - var r scraper.Retweet - - row := stmt.QueryRow(id) - err = row.Scan(&r.RetweetID, &r.TweetID, &r.RetweetedByID, &r.RetweetedAt) - if err != nil { - return scraper.Retweet{}, err - } - return r, nil + `, id) + return r, err } diff --git a/persistence/tweet_queries.go b/persistence/tweet_queries.go index c9b903a..89c5496 100644 --- a/persistence/tweet_queries.go +++ b/persistence/tweet_queries.go @@ -10,11 +10,9 @@ import ( func (p Profile) SaveTweet(t scraper.Tweet) error { db := p.DB - tx, err := db.Begin() - if err != nil { - return err - } - _, err = db.Exec(` + tx := db.MustBegin() + + _, err := db.Exec(` insert into tweets (id, user_id, text, posted_at, num_likes, num_retweets, num_replies, num_quote_tweets, in_reply_to_id, quoted_tweet_id, mentions, reply_mentions, hashtags, tombstone_type, is_stub, is_content_downloaded, is_conversation_scraped, last_scraped_at) diff --git a/persistence/user_queries.go b/persistence/user_queries.go index a6e2b2d..500a031 100644 --- a/persistence/user_queries.go +++ b/persistence/user_queries.go @@ -74,33 +74,20 @@ func (p Profile) SaveUser(u *scraper.User) error { * - true if there is such a User in the database, false otherwise */ func (p Profile) UserExists(handle scraper.UserHandle) bool { - db := p.DB + db := p.DB - var dummy string - err := db.QueryRow("select 1 from users where lower(handle) = lower(?)", handle).Scan(&dummy) - if err != nil { - if err != sql.ErrNoRows { - // A real error - panic(err) - } - return false - } - return true + var dummy string + err := db.QueryRow("select 1 from users where lower(handle) = lower(?)", handle).Scan(&dummy) + if err != nil { + if err != sql.ErrNoRows { + // A real error + panic(err) + } + return false + } + return true } -/** - * Helper function. Create a User from a Row. - */ -func parse_user_from_row(row *sql.Row) (scraper.User, error) { - var u scraper.User - - err := row.Scan(&u.ID, &u.DisplayName, &u.Handle, &u.Bio, &u.FollowingCount, &u.FollowersCount, &u.Location, &u.Website, &u.JoinDate, - &u.IsPrivate, &u.IsVerified, &u.IsBanned, &u.ProfileImageUrl, &u.ProfileImageLocalPath, &u.BannerImageUrl, - &u.BannerImageLocalPath, &u.PinnedTweetID, &u.IsContentDownloaded, &u.IsFollowed) - return u, err -} - - /** * Retrieve a User from the database, by handle. * @@ -111,29 +98,23 @@ func parse_user_from_row(row *sql.Row) (scraper.User, error) { * - the User, if it exists */ func (p Profile) GetUserByHandle(handle scraper.UserHandle) (scraper.User, error) { - db := p.DB + db := p.DB - stmt, err := db.Prepare(` + var ret scraper.User + err := db.Get(&ret, ` select id, display_name, handle, bio, following_count, followers_count, location, website, join_date, is_private, is_verified, is_banned, profile_image_url, profile_image_local_path, banner_image_url, banner_image_local_path, pinned_tweet_id, is_content_downloaded, is_followed from users where lower(handle) = lower(?) - `) - if err != nil { - return scraper.User{}, err - } - defer stmt.Close() + `, handle) - row := stmt.QueryRow(handle) - ret, err := parse_user_from_row(row) - if err == sql.ErrNoRows { - return ret, ErrNotInDatabase{"User", handle} - } - return ret, nil + if err == sql.ErrNoRows { + return ret, ErrNotInDatabase{"User", handle} + } + return ret, nil } - /** * Retrieve a User from the database, by user ID. * @@ -144,26 +125,21 @@ func (p Profile) GetUserByHandle(handle scraper.UserHandle) (scraper.User, error * - the User, if it exists */ func (p Profile) GetUserByID(id scraper.UserID) (scraper.User, error) { - db := p.DB + db := p.DB - stmt, err := db.Prepare(` + var ret scraper.User + + err := db.Get(&ret, ` select id, display_name, handle, bio, following_count, followers_count, location, website, join_date, is_private, is_verified, is_banned, profile_image_url, profile_image_local_path, banner_image_url, banner_image_local_path, pinned_tweet_id, is_content_downloaded, is_followed from users where id = ? - `) - if err != nil { - return scraper.User{}, err - } - defer stmt.Close() - - row := stmt.QueryRow(id) - ret, err := parse_user_from_row(row) - if err == sql.ErrNoRows { - return ret, ErrNotInDatabase{"User", id} - } - return ret, err + `, id) + if err == sql.ErrNoRows { + return ret, ErrNotInDatabase{"User", id} + } + return ret, err } /** diff --git a/persistence/utils.go b/persistence/utils.go index f164a56..44a8199 100644 --- a/persistence/utils.go +++ b/persistence/utils.go @@ -1,8 +1,10 @@ package persistence import ( - "fmt" "errors" + "fmt" + "regexp" + "strings" "os" ) @@ -28,3 +30,12 @@ func file_exists(path string) bool { panic(err) } } + +/** + * https://stackoverflow.com/questions/56616196/how-to-convert-camel-case-string-to-snake-case#56616250 + */ +func ToSnakeCase(str string) string { + snake := regexp.MustCompile("(.)_?([A-Z][a-z]+)").ReplaceAllString(str, "${1}_${2}") + snake = regexp.MustCompile("([a-z0-9])_?([A-Z])").ReplaceAllString(snake, "${1}_${2}") + return strings.ToLower(snake) +} diff --git a/persistence/versions.go b/persistence/versions.go index cb5d6cf..c231607 100644 --- a/persistence/versions.go +++ b/persistence/versions.go @@ -2,7 +2,7 @@ package persistence import ( "fmt" - "database/sql" + sql "github.com/jmoiron/sqlx" "offline_twitter/terminal_utils" ) @@ -77,10 +77,7 @@ var MIGRATIONS = []string{ * Subsequent updates should change the number, not insert a new row. */ func InitializeDatabaseVersion(db *sql.DB) { - _, err := db.Exec("insert into database_version (version_number) values (?)", ENGINE_DATABASE_VERSION) - if err != nil { - panic(err) - } + db.MustExec("insert into database_version (version_number) values (?)", ENGINE_DATABASE_VERSION) } func (p Profile) GetDatabaseVersion() (int, error) { @@ -126,16 +123,11 @@ func (p Profile) UpgradeFromXToY(x int, y int) error { fmt.Println(MIGRATIONS[i]) fmt.Printf(terminal_utils.COLOR_RESET) - _, err := p.DB.Exec(MIGRATIONS[i]) - if err != nil { - return err - } - _, err = p.DB.Exec("update database_version set version_number = ?", i+1) - if err != nil { - return err - } + p.DB.MustExec(MIGRATIONS[i]) + p.DB.MustExec("update database_version set version_number = ?", i+1) + fmt.Printf(terminal_utils.COLOR_YELLOW) - fmt.Printf("Now at database schema version %d.\n", i + 1) + fmt.Printf("Now at database schema version %d.\n", i+1) fmt.Printf(terminal_utils.COLOR_RESET) } fmt.Printf(terminal_utils.COLOR_GREEN) diff --git a/scraper/poll.go b/scraper/poll.go index 0edae23..6acd166 100644 --- a/scraper/poll.go +++ b/scraper/poll.go @@ -25,7 +25,7 @@ type Poll struct { VotingDuration int // In seconds VotingEndsAt Timestamp - LastUpdatedAt Timestamp + LastUpdatedAt Timestamp `db:"last_scraped_at"` } func ParseAPIPoll(apiCard APICard) Poll { diff --git a/scraper/retweet.go b/scraper/retweet.go index 77be9ae..39b3979 100644 --- a/scraper/retweet.go +++ b/scraper/retweet.go @@ -4,7 +4,7 @@ type Retweet struct { RetweetID TweetID TweetID TweetID Tweet *Tweet - RetweetedByID UserID + RetweetedByID UserID `db:"retweeted_by"` RetweetedBy *User RetweetedAt Timestamp } diff --git a/scraper/video.go b/scraper/video.go index a71d5ee..e2179f8 100644 --- a/scraper/video.go +++ b/scraper/video.go @@ -20,7 +20,7 @@ type Video struct { LocalFilename string ThumbnailRemoteUrl string - ThumbnailLocalPath string + ThumbnailLocalPath string `db:"thumbnail_local_filename"` Duration int // milliseconds ViewCount int