From a81dfb69a5b51abb7503ef84faeeaf8eb524ac90 Mon Sep 17 00:00:00 2001 From: wispem-wantex Date: Sun, 9 Nov 2025 00:01:15 -0800 Subject: [PATCH] db: remove implicit singleton pattern, move functions under a struct object --- pkg/db/connect.go | 41 +++++++++++++++++++++-------------------- pkg/db/connect_test.go | 16 ++++++++-------- 2 files changed, 29 insertions(+), 28 deletions(-) diff --git a/pkg/db/connect.go b/pkg/db/connect.go index 6990cf6..3d7d5de 100644 --- a/pkg/db/connect.go +++ b/pkg/db/connect.go @@ -10,7 +10,7 @@ import ( _ "github.com/mattn/go-sqlite3" ) -var ( +type DBConfig struct { // // Tracks whether the DB connector has been initialized // is_initialized bool @@ -20,7 +20,7 @@ var ( // Database starts at version 0. First migration brings us to version 1 migrations *[]string version_number uint -) +} // Colors for terminal output const ( @@ -36,14 +36,15 @@ const ( ColorWhite = "\033[97m" ) -func Init(schema *string, migrationsList *[]string) { - sql_schema = schema - migrations = migrationsList - version_number = uint(len(*migrations)) - // is_initialized = true +func Init(schema *string, migrationsList *[]string) DBConfig { + return DBConfig{ + sql_schema: schema, + migrations: migrationsList, + version_number: uint(len(*migrationsList)), + } } -func Create(path string) (*sqlx.DB, error) { +func (c DBConfig) Create(path string) (*sqlx.DB, error) { // First check if the path already exists _, err := os.Stat(path) if err == nil { @@ -63,14 +64,14 @@ func Create(path string) (*sqlx.DB, error) { if _, err = db.Exec("pragma foreign_keys=on; pragma journal_mode=WAL;"); err != nil { return nil, fmt.Errorf("running pragma statements: %w", err) } - if _, err = db.Exec(*sql_schema); err != nil { + if _, err = db.Exec(*c.sql_schema); err != nil { return nil, fmt.Errorf("creating schema: %w", err) } return db, nil } -func Connect(path string) (*sqlx.DB, error) { +func (c DBConfig) Connect(path string) (*sqlx.DB, error) { db, err := sqlx.Open("sqlite3", path) if err != nil { return nil, fmt.Errorf("opening db: %w", err) @@ -79,38 +80,38 @@ func Connect(path string) (*sqlx.DB, error) { return nil, fmt.Errorf("running pragma statements: %w", err) } - err = CheckAndUpdateVersion(db) + err = c.CheckAndUpdateVersion(db) return db, err } -func CheckAndUpdateVersion(db *sqlx.DB) error { +func (c DBConfig) CheckAndUpdateVersion(db *sqlx.DB) error { var version uint err := db.Get(&version, "select version from db_version") if err != nil { return fmt.Errorf("couldn't check database version: %w", err) } - if version > version_number { - return VersionMismatchError{version_number, version} + if version > c.version_number { + return VersionMismatchError{c.version_number, version} } - if version_number > version { + if c.version_number > version { fmt.Print(ColorYellow) fmt.Printf("================================================\n") fmt.Printf("Database version is out of date. Upgrading database from version %d to version %d!\n", version, - version_number) + c.version_number) fmt.Print(ColorReset) - UpgradeFromXToY(db, version, version_number) + c.UpgradeFromXToY(db, version, c.version_number) } return nil } // UpgradeFromXToY runs all the migrations from version X to version Y, and update the `database_version` table's `version_number` -func UpgradeFromXToY(db *sqlx.DB, x uint, y uint) { +func (c DBConfig) UpgradeFromXToY(db *sqlx.DB, x uint, y uint) { for i := x; i < y; i++ { fmt.Print(ColorCyan) - fmt.Println((*migrations)[i]) + fmt.Println((*c.migrations)[i]) fmt.Print(ColorReset) // Execute the migration in a transaction @@ -118,7 +119,7 @@ func UpgradeFromXToY(db *sqlx.DB, x uint, y uint) { if err != nil { panic(err) } - tx.MustExec((*migrations)[i]) + tx.MustExec((*c.migrations)[i]) tx.MustExec("update db_version set version = ?", i+1) if err := tx.Commit(); err != nil { panic(err) diff --git a/pkg/db/connect_test.go b/pkg/db/connect_test.go index ea9f0d7..40ebf59 100644 --- a/pkg/db/connect_test.go +++ b/pkg/db/connect_test.go @@ -20,7 +20,7 @@ func TestCreateAndConnectToDB(t *testing.T) { schema_sql := string(_schema_sql) migrations := []string{} - db.Init(&schema_sql, &migrations) + config := db.Init(&schema_sql, &migrations) data_dir := "../../sample_data/data" _ = os.MkdirAll(data_dir, os.FileMode(0o644)) @@ -28,9 +28,9 @@ func TestCreateAndConnectToDB(t *testing.T) { dbPath := filepath.Join(data_dir, "test.db") _ = os.Remove(dbPath) // Delete it if it exists - _, err = db.Create(dbPath) + _, err = config.Create(dbPath) assert.NoError(err) - _, err = db.Connect(dbPath) + _, err = config.Connect(dbPath) assert.NoError(err) } @@ -44,9 +44,9 @@ func TestVersionUpgrade(t *testing.T) { insert into db_version values(0); ` migrations := []string{} - db.Init(&initial_schema, &migrations) + config := db.Init(&initial_schema, &migrations) - connection, err := db.Create(":memory:") + connection, err := config.Create(":memory:") require.NoError(err) get_version := func(c *sqlx.DB) (ret int) { @@ -63,7 +63,7 @@ func TestVersionUpgrade(t *testing.T) { // Create a migration to add a new Item migrations = append(migrations, "insert into items (rowid) values (1)") db.Init(&initial_schema, &migrations) // Reinitialize with the new migration - db.UpgradeFromXToY(connection, uint(len(migrations)-1), uint(len(migrations))) + config.UpgradeFromXToY(connection, uint(len(migrations)-1), uint(len(migrations))) var items2 []int require.NoError(connection.Select(&items2, "select * from items")) @@ -72,8 +72,8 @@ func TestVersionUpgrade(t *testing.T) { // Create a migration to add a new Item migrations = append(migrations, `alter table items add column name string default 'asdf'`) - db.Init(&initial_schema, &migrations) // Reinitialize with the new migration - db.UpgradeFromXToY(connection, uint(len(migrations)-1), uint(len(migrations))) + config = db.Init(&initial_schema, &migrations) // Reinitialize with the new migration + config.UpgradeFromXToY(connection, uint(len(migrations)-1), uint(len(migrations))) var items3 []struct { ID uint64 `db:"rowid"`