db: remove implicit singleton pattern, move functions under a struct object
This commit is contained in:
parent
bfb7073cf6
commit
a81dfb69a5
@ -10,7 +10,7 @@ import (
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
var (
|
||||
type DBConfig struct {
|
||||
// // Tracks whether the DB connector has been initialized
|
||||
// is_initialized bool
|
||||
|
||||
@ -20,7 +20,7 @@ var (
|
||||
// Database starts at version 0. First migration brings us to version 1
|
||||
migrations *[]string
|
||||
version_number uint
|
||||
)
|
||||
}
|
||||
|
||||
// Colors for terminal output
|
||||
const (
|
||||
@ -36,14 +36,15 @@ const (
|
||||
ColorWhite = "\033[97m"
|
||||
)
|
||||
|
||||
func Init(schema *string, migrationsList *[]string) {
|
||||
sql_schema = schema
|
||||
migrations = migrationsList
|
||||
version_number = uint(len(*migrations))
|
||||
// is_initialized = true
|
||||
func Init(schema *string, migrationsList *[]string) DBConfig {
|
||||
return DBConfig{
|
||||
sql_schema: schema,
|
||||
migrations: migrationsList,
|
||||
version_number: uint(len(*migrationsList)),
|
||||
}
|
||||
}
|
||||
|
||||
func Create(path string) (*sqlx.DB, error) {
|
||||
func (c DBConfig) Create(path string) (*sqlx.DB, error) {
|
||||
// First check if the path already exists
|
||||
_, err := os.Stat(path)
|
||||
if err == nil {
|
||||
@ -63,14 +64,14 @@ func Create(path string) (*sqlx.DB, error) {
|
||||
if _, err = db.Exec("pragma foreign_keys=on; pragma journal_mode=WAL;"); err != nil {
|
||||
return nil, fmt.Errorf("running pragma statements: %w", err)
|
||||
}
|
||||
if _, err = db.Exec(*sql_schema); err != nil {
|
||||
if _, err = db.Exec(*c.sql_schema); err != nil {
|
||||
return nil, fmt.Errorf("creating schema: %w", err)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func Connect(path string) (*sqlx.DB, error) {
|
||||
func (c DBConfig) Connect(path string) (*sqlx.DB, error) {
|
||||
db, err := sqlx.Open("sqlite3", path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("opening db: %w", err)
|
||||
@ -79,38 +80,38 @@ func Connect(path string) (*sqlx.DB, error) {
|
||||
return nil, fmt.Errorf("running pragma statements: %w", err)
|
||||
}
|
||||
|
||||
err = CheckAndUpdateVersion(db)
|
||||
err = c.CheckAndUpdateVersion(db)
|
||||
return db, err
|
||||
}
|
||||
|
||||
func CheckAndUpdateVersion(db *sqlx.DB) error {
|
||||
func (c DBConfig) CheckAndUpdateVersion(db *sqlx.DB) error {
|
||||
var version uint
|
||||
err := db.Get(&version, "select version from db_version")
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't check database version: %w", err)
|
||||
}
|
||||
|
||||
if version > version_number {
|
||||
return VersionMismatchError{version_number, version}
|
||||
if version > c.version_number {
|
||||
return VersionMismatchError{c.version_number, version}
|
||||
}
|
||||
|
||||
if version_number > version {
|
||||
if c.version_number > version {
|
||||
fmt.Print(ColorYellow)
|
||||
fmt.Printf("================================================\n")
|
||||
fmt.Printf("Database version is out of date. Upgrading database from version %d to version %d!\n", version,
|
||||
version_number)
|
||||
c.version_number)
|
||||
fmt.Print(ColorReset)
|
||||
UpgradeFromXToY(db, version, version_number)
|
||||
c.UpgradeFromXToY(db, version, c.version_number)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpgradeFromXToY runs all the migrations from version X to version Y, and update the `database_version` table's `version_number`
|
||||
func UpgradeFromXToY(db *sqlx.DB, x uint, y uint) {
|
||||
func (c DBConfig) UpgradeFromXToY(db *sqlx.DB, x uint, y uint) {
|
||||
for i := x; i < y; i++ {
|
||||
fmt.Print(ColorCyan)
|
||||
fmt.Println((*migrations)[i])
|
||||
fmt.Println((*c.migrations)[i])
|
||||
fmt.Print(ColorReset)
|
||||
|
||||
// Execute the migration in a transaction
|
||||
@ -118,7 +119,7 @@ func UpgradeFromXToY(db *sqlx.DB, x uint, y uint) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
tx.MustExec((*migrations)[i])
|
||||
tx.MustExec((*c.migrations)[i])
|
||||
tx.MustExec("update db_version set version = ?", i+1)
|
||||
if err := tx.Commit(); err != nil {
|
||||
panic(err)
|
||||
|
||||
@ -20,7 +20,7 @@ func TestCreateAndConnectToDB(t *testing.T) {
|
||||
schema_sql := string(_schema_sql)
|
||||
migrations := []string{}
|
||||
|
||||
db.Init(&schema_sql, &migrations)
|
||||
config := db.Init(&schema_sql, &migrations)
|
||||
|
||||
data_dir := "../../sample_data/data"
|
||||
_ = os.MkdirAll(data_dir, os.FileMode(0o644))
|
||||
@ -28,9 +28,9 @@ func TestCreateAndConnectToDB(t *testing.T) {
|
||||
dbPath := filepath.Join(data_dir, "test.db")
|
||||
_ = os.Remove(dbPath) // Delete it if it exists
|
||||
|
||||
_, err = db.Create(dbPath)
|
||||
_, err = config.Create(dbPath)
|
||||
assert.NoError(err)
|
||||
_, err = db.Connect(dbPath)
|
||||
_, err = config.Connect(dbPath)
|
||||
assert.NoError(err)
|
||||
}
|
||||
|
||||
@ -44,9 +44,9 @@ func TestVersionUpgrade(t *testing.T) {
|
||||
insert into db_version values(0);
|
||||
`
|
||||
migrations := []string{}
|
||||
db.Init(&initial_schema, &migrations)
|
||||
config := db.Init(&initial_schema, &migrations)
|
||||
|
||||
connection, err := db.Create(":memory:")
|
||||
connection, err := config.Create(":memory:")
|
||||
require.NoError(err)
|
||||
|
||||
get_version := func(c *sqlx.DB) (ret int) {
|
||||
@ -63,7 +63,7 @@ func TestVersionUpgrade(t *testing.T) {
|
||||
// Create a migration to add a new Item
|
||||
migrations = append(migrations, "insert into items (rowid) values (1)")
|
||||
db.Init(&initial_schema, &migrations) // Reinitialize with the new migration
|
||||
db.UpgradeFromXToY(connection, uint(len(migrations)-1), uint(len(migrations)))
|
||||
config.UpgradeFromXToY(connection, uint(len(migrations)-1), uint(len(migrations)))
|
||||
|
||||
var items2 []int
|
||||
require.NoError(connection.Select(&items2, "select * from items"))
|
||||
@ -72,8 +72,8 @@ func TestVersionUpgrade(t *testing.T) {
|
||||
|
||||
// Create a migration to add a new Item
|
||||
migrations = append(migrations, `alter table items add column name string default 'asdf'`)
|
||||
db.Init(&initial_schema, &migrations) // Reinitialize with the new migration
|
||||
db.UpgradeFromXToY(connection, uint(len(migrations)-1), uint(len(migrations)))
|
||||
config = db.Init(&initial_schema, &migrations) // Reinitialize with the new migration
|
||||
config.UpgradeFromXToY(connection, uint(len(migrations)-1), uint(len(migrations)))
|
||||
|
||||
var items3 []struct {
|
||||
ID uint64 `db:"rowid"`
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user