db: remove implicit singleton pattern, move functions under a struct object
All checks were successful
CI / build-docker (push) Successful in 7s
CI / build-docker-bootstrap (push) Has been skipped
CI / release-test (push) Successful in 25s

This commit is contained in:
wispem-wantex 2025-11-09 00:01:15 -08:00
parent bfb7073cf6
commit a81dfb69a5
2 changed files with 29 additions and 28 deletions

View File

@ -10,7 +10,7 @@ import (
_ "github.com/mattn/go-sqlite3"
)
var (
type DBConfig struct {
// // Tracks whether the DB connector has been initialized
// is_initialized bool
@ -20,7 +20,7 @@ var (
// Database starts at version 0. First migration brings us to version 1
migrations *[]string
version_number uint
)
}
// Colors for terminal output
const (
@ -36,14 +36,15 @@ const (
ColorWhite = "\033[97m"
)
func Init(schema *string, migrationsList *[]string) {
sql_schema = schema
migrations = migrationsList
version_number = uint(len(*migrations))
// is_initialized = true
func Init(schema *string, migrationsList *[]string) DBConfig {
return DBConfig{
sql_schema: schema,
migrations: migrationsList,
version_number: uint(len(*migrationsList)),
}
}
func Create(path string) (*sqlx.DB, error) {
func (c DBConfig) Create(path string) (*sqlx.DB, error) {
// First check if the path already exists
_, err := os.Stat(path)
if err == nil {
@ -63,14 +64,14 @@ func Create(path string) (*sqlx.DB, error) {
if _, err = db.Exec("pragma foreign_keys=on; pragma journal_mode=WAL;"); err != nil {
return nil, fmt.Errorf("running pragma statements: %w", err)
}
if _, err = db.Exec(*sql_schema); err != nil {
if _, err = db.Exec(*c.sql_schema); err != nil {
return nil, fmt.Errorf("creating schema: %w", err)
}
return db, nil
}
func Connect(path string) (*sqlx.DB, error) {
func (c DBConfig) Connect(path string) (*sqlx.DB, error) {
db, err := sqlx.Open("sqlite3", path)
if err != nil {
return nil, fmt.Errorf("opening db: %w", err)
@ -79,38 +80,38 @@ func Connect(path string) (*sqlx.DB, error) {
return nil, fmt.Errorf("running pragma statements: %w", err)
}
err = CheckAndUpdateVersion(db)
err = c.CheckAndUpdateVersion(db)
return db, err
}
func CheckAndUpdateVersion(db *sqlx.DB) error {
func (c DBConfig) CheckAndUpdateVersion(db *sqlx.DB) error {
var version uint
err := db.Get(&version, "select version from db_version")
if err != nil {
return fmt.Errorf("couldn't check database version: %w", err)
}
if version > version_number {
return VersionMismatchError{version_number, version}
if version > c.version_number {
return VersionMismatchError{c.version_number, version}
}
if version_number > version {
if c.version_number > version {
fmt.Print(ColorYellow)
fmt.Printf("================================================\n")
fmt.Printf("Database version is out of date. Upgrading database from version %d to version %d!\n", version,
version_number)
c.version_number)
fmt.Print(ColorReset)
UpgradeFromXToY(db, version, version_number)
c.UpgradeFromXToY(db, version, c.version_number)
}
return nil
}
// UpgradeFromXToY runs all the migrations from version X to version Y, and update the `database_version` table's `version_number`
func UpgradeFromXToY(db *sqlx.DB, x uint, y uint) {
func (c DBConfig) UpgradeFromXToY(db *sqlx.DB, x uint, y uint) {
for i := x; i < y; i++ {
fmt.Print(ColorCyan)
fmt.Println((*migrations)[i])
fmt.Println((*c.migrations)[i])
fmt.Print(ColorReset)
// Execute the migration in a transaction
@ -118,7 +119,7 @@ func UpgradeFromXToY(db *sqlx.DB, x uint, y uint) {
if err != nil {
panic(err)
}
tx.MustExec((*migrations)[i])
tx.MustExec((*c.migrations)[i])
tx.MustExec("update db_version set version = ?", i+1)
if err := tx.Commit(); err != nil {
panic(err)

View File

@ -20,7 +20,7 @@ func TestCreateAndConnectToDB(t *testing.T) {
schema_sql := string(_schema_sql)
migrations := []string{}
db.Init(&schema_sql, &migrations)
config := db.Init(&schema_sql, &migrations)
data_dir := "../../sample_data/data"
_ = os.MkdirAll(data_dir, os.FileMode(0o644))
@ -28,9 +28,9 @@ func TestCreateAndConnectToDB(t *testing.T) {
dbPath := filepath.Join(data_dir, "test.db")
_ = os.Remove(dbPath) // Delete it if it exists
_, err = db.Create(dbPath)
_, err = config.Create(dbPath)
assert.NoError(err)
_, err = db.Connect(dbPath)
_, err = config.Connect(dbPath)
assert.NoError(err)
}
@ -44,9 +44,9 @@ func TestVersionUpgrade(t *testing.T) {
insert into db_version values(0);
`
migrations := []string{}
db.Init(&initial_schema, &migrations)
config := db.Init(&initial_schema, &migrations)
connection, err := db.Create(":memory:")
connection, err := config.Create(":memory:")
require.NoError(err)
get_version := func(c *sqlx.DB) (ret int) {
@ -63,7 +63,7 @@ func TestVersionUpgrade(t *testing.T) {
// Create a migration to add a new Item
migrations = append(migrations, "insert into items (rowid) values (1)")
db.Init(&initial_schema, &migrations) // Reinitialize with the new migration
db.UpgradeFromXToY(connection, uint(len(migrations)-1), uint(len(migrations)))
config.UpgradeFromXToY(connection, uint(len(migrations)-1), uint(len(migrations)))
var items2 []int
require.NoError(connection.Select(&items2, "select * from items"))
@ -72,8 +72,8 @@ func TestVersionUpgrade(t *testing.T) {
// Create a migration to add a new Item
migrations = append(migrations, `alter table items add column name string default 'asdf'`)
db.Init(&initial_schema, &migrations) // Reinitialize with the new migration
db.UpgradeFromXToY(connection, uint(len(migrations)-1), uint(len(migrations)))
config = db.Init(&initial_schema, &migrations) // Reinitialize with the new migration
config.UpgradeFromXToY(connection, uint(len(migrations)-1), uint(len(migrations)))
var items3 []struct {
ID uint64 `db:"rowid"`