db: remove implicit singleton pattern, move functions under a struct object
All checks were successful
CI / build-docker (push) Successful in 7s
CI / build-docker-bootstrap (push) Has been skipped
CI / release-test (push) Successful in 25s

This commit is contained in:
wispem-wantex 2025-11-09 00:01:15 -08:00
parent bfb7073cf6
commit a81dfb69a5
2 changed files with 29 additions and 28 deletions

View File

@ -10,7 +10,7 @@ import (
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
) )
var ( type DBConfig struct {
// // Tracks whether the DB connector has been initialized // // Tracks whether the DB connector has been initialized
// is_initialized bool // is_initialized bool
@ -20,7 +20,7 @@ var (
// Database starts at version 0. First migration brings us to version 1 // Database starts at version 0. First migration brings us to version 1
migrations *[]string migrations *[]string
version_number uint version_number uint
) }
// Colors for terminal output // Colors for terminal output
const ( const (
@ -36,14 +36,15 @@ const (
ColorWhite = "\033[97m" ColorWhite = "\033[97m"
) )
func Init(schema *string, migrationsList *[]string) { func Init(schema *string, migrationsList *[]string) DBConfig {
sql_schema = schema return DBConfig{
migrations = migrationsList sql_schema: schema,
version_number = uint(len(*migrations)) migrations: migrationsList,
// is_initialized = true version_number: uint(len(*migrationsList)),
}
} }
func Create(path string) (*sqlx.DB, error) { func (c DBConfig) Create(path string) (*sqlx.DB, error) {
// First check if the path already exists // First check if the path already exists
_, err := os.Stat(path) _, err := os.Stat(path)
if err == nil { if err == nil {
@ -63,14 +64,14 @@ func Create(path string) (*sqlx.DB, error) {
if _, err = db.Exec("pragma foreign_keys=on; pragma journal_mode=WAL;"); err != nil { if _, err = db.Exec("pragma foreign_keys=on; pragma journal_mode=WAL;"); err != nil {
return nil, fmt.Errorf("running pragma statements: %w", err) return nil, fmt.Errorf("running pragma statements: %w", err)
} }
if _, err = db.Exec(*sql_schema); err != nil { if _, err = db.Exec(*c.sql_schema); err != nil {
return nil, fmt.Errorf("creating schema: %w", err) return nil, fmt.Errorf("creating schema: %w", err)
} }
return db, nil return db, nil
} }
func Connect(path string) (*sqlx.DB, error) { func (c DBConfig) Connect(path string) (*sqlx.DB, error) {
db, err := sqlx.Open("sqlite3", path) db, err := sqlx.Open("sqlite3", path)
if err != nil { if err != nil {
return nil, fmt.Errorf("opening db: %w", err) return nil, fmt.Errorf("opening db: %w", err)
@ -79,38 +80,38 @@ func Connect(path string) (*sqlx.DB, error) {
return nil, fmt.Errorf("running pragma statements: %w", err) return nil, fmt.Errorf("running pragma statements: %w", err)
} }
err = CheckAndUpdateVersion(db) err = c.CheckAndUpdateVersion(db)
return db, err return db, err
} }
func CheckAndUpdateVersion(db *sqlx.DB) error { func (c DBConfig) CheckAndUpdateVersion(db *sqlx.DB) error {
var version uint var version uint
err := db.Get(&version, "select version from db_version") err := db.Get(&version, "select version from db_version")
if err != nil { if err != nil {
return fmt.Errorf("couldn't check database version: %w", err) return fmt.Errorf("couldn't check database version: %w", err)
} }
if version > version_number { if version > c.version_number {
return VersionMismatchError{version_number, version} return VersionMismatchError{c.version_number, version}
} }
if version_number > version { if c.version_number > version {
fmt.Print(ColorYellow) fmt.Print(ColorYellow)
fmt.Printf("================================================\n") fmt.Printf("================================================\n")
fmt.Printf("Database version is out of date. Upgrading database from version %d to version %d!\n", version, fmt.Printf("Database version is out of date. Upgrading database from version %d to version %d!\n", version,
version_number) c.version_number)
fmt.Print(ColorReset) fmt.Print(ColorReset)
UpgradeFromXToY(db, version, version_number) c.UpgradeFromXToY(db, version, c.version_number)
} }
return nil return nil
} }
// UpgradeFromXToY runs all the migrations from version X to version Y, and update the `database_version` table's `version_number` // UpgradeFromXToY runs all the migrations from version X to version Y, and update the `database_version` table's `version_number`
func UpgradeFromXToY(db *sqlx.DB, x uint, y uint) { func (c DBConfig) UpgradeFromXToY(db *sqlx.DB, x uint, y uint) {
for i := x; i < y; i++ { for i := x; i < y; i++ {
fmt.Print(ColorCyan) fmt.Print(ColorCyan)
fmt.Println((*migrations)[i]) fmt.Println((*c.migrations)[i])
fmt.Print(ColorReset) fmt.Print(ColorReset)
// Execute the migration in a transaction // Execute the migration in a transaction
@ -118,7 +119,7 @@ func UpgradeFromXToY(db *sqlx.DB, x uint, y uint) {
if err != nil { if err != nil {
panic(err) panic(err)
} }
tx.MustExec((*migrations)[i]) tx.MustExec((*c.migrations)[i])
tx.MustExec("update db_version set version = ?", i+1) tx.MustExec("update db_version set version = ?", i+1)
if err := tx.Commit(); err != nil { if err := tx.Commit(); err != nil {
panic(err) panic(err)

View File

@ -20,7 +20,7 @@ func TestCreateAndConnectToDB(t *testing.T) {
schema_sql := string(_schema_sql) schema_sql := string(_schema_sql)
migrations := []string{} migrations := []string{}
db.Init(&schema_sql, &migrations) config := db.Init(&schema_sql, &migrations)
data_dir := "../../sample_data/data" data_dir := "../../sample_data/data"
_ = os.MkdirAll(data_dir, os.FileMode(0o644)) _ = os.MkdirAll(data_dir, os.FileMode(0o644))
@ -28,9 +28,9 @@ func TestCreateAndConnectToDB(t *testing.T) {
dbPath := filepath.Join(data_dir, "test.db") dbPath := filepath.Join(data_dir, "test.db")
_ = os.Remove(dbPath) // Delete it if it exists _ = os.Remove(dbPath) // Delete it if it exists
_, err = db.Create(dbPath) _, err = config.Create(dbPath)
assert.NoError(err) assert.NoError(err)
_, err = db.Connect(dbPath) _, err = config.Connect(dbPath)
assert.NoError(err) assert.NoError(err)
} }
@ -44,9 +44,9 @@ func TestVersionUpgrade(t *testing.T) {
insert into db_version values(0); insert into db_version values(0);
` `
migrations := []string{} migrations := []string{}
db.Init(&initial_schema, &migrations) config := db.Init(&initial_schema, &migrations)
connection, err := db.Create(":memory:") connection, err := config.Create(":memory:")
require.NoError(err) require.NoError(err)
get_version := func(c *sqlx.DB) (ret int) { get_version := func(c *sqlx.DB) (ret int) {
@ -63,7 +63,7 @@ func TestVersionUpgrade(t *testing.T) {
// Create a migration to add a new Item // Create a migration to add a new Item
migrations = append(migrations, "insert into items (rowid) values (1)") migrations = append(migrations, "insert into items (rowid) values (1)")
db.Init(&initial_schema, &migrations) // Reinitialize with the new migration db.Init(&initial_schema, &migrations) // Reinitialize with the new migration
db.UpgradeFromXToY(connection, uint(len(migrations)-1), uint(len(migrations))) config.UpgradeFromXToY(connection, uint(len(migrations)-1), uint(len(migrations)))
var items2 []int var items2 []int
require.NoError(connection.Select(&items2, "select * from items")) require.NoError(connection.Select(&items2, "select * from items"))
@ -72,8 +72,8 @@ func TestVersionUpgrade(t *testing.T) {
// Create a migration to add a new Item // Create a migration to add a new Item
migrations = append(migrations, `alter table items add column name string default 'asdf'`) migrations = append(migrations, `alter table items add column name string default 'asdf'`)
db.Init(&initial_schema, &migrations) // Reinitialize with the new migration config = db.Init(&initial_schema, &migrations) // Reinitialize with the new migration
db.UpgradeFromXToY(connection, uint(len(migrations)-1), uint(len(migrations))) config.UpgradeFromXToY(connection, uint(len(migrations)-1), uint(len(migrations)))
var items3 []struct { var items3 []struct {
ID uint64 `db:"rowid"` ID uint64 `db:"rowid"`