db: remove implicit singleton pattern, move functions under a struct object
This commit is contained in:
parent
bfb7073cf6
commit
a81dfb69a5
@ -10,7 +10,7 @@ import (
|
|||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
type DBConfig struct {
|
||||||
// // Tracks whether the DB connector has been initialized
|
// // Tracks whether the DB connector has been initialized
|
||||||
// is_initialized bool
|
// is_initialized bool
|
||||||
|
|
||||||
@ -20,7 +20,7 @@ var (
|
|||||||
// Database starts at version 0. First migration brings us to version 1
|
// Database starts at version 0. First migration brings us to version 1
|
||||||
migrations *[]string
|
migrations *[]string
|
||||||
version_number uint
|
version_number uint
|
||||||
)
|
}
|
||||||
|
|
||||||
// Colors for terminal output
|
// Colors for terminal output
|
||||||
const (
|
const (
|
||||||
@ -36,14 +36,15 @@ const (
|
|||||||
ColorWhite = "\033[97m"
|
ColorWhite = "\033[97m"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Init(schema *string, migrationsList *[]string) {
|
func Init(schema *string, migrationsList *[]string) DBConfig {
|
||||||
sql_schema = schema
|
return DBConfig{
|
||||||
migrations = migrationsList
|
sql_schema: schema,
|
||||||
version_number = uint(len(*migrations))
|
migrations: migrationsList,
|
||||||
// is_initialized = true
|
version_number: uint(len(*migrationsList)),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Create(path string) (*sqlx.DB, error) {
|
func (c DBConfig) Create(path string) (*sqlx.DB, error) {
|
||||||
// First check if the path already exists
|
// First check if the path already exists
|
||||||
_, err := os.Stat(path)
|
_, err := os.Stat(path)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@ -63,14 +64,14 @@ func Create(path string) (*sqlx.DB, error) {
|
|||||||
if _, err = db.Exec("pragma foreign_keys=on; pragma journal_mode=WAL;"); err != nil {
|
if _, err = db.Exec("pragma foreign_keys=on; pragma journal_mode=WAL;"); err != nil {
|
||||||
return nil, fmt.Errorf("running pragma statements: %w", err)
|
return nil, fmt.Errorf("running pragma statements: %w", err)
|
||||||
}
|
}
|
||||||
if _, err = db.Exec(*sql_schema); err != nil {
|
if _, err = db.Exec(*c.sql_schema); err != nil {
|
||||||
return nil, fmt.Errorf("creating schema: %w", err)
|
return nil, fmt.Errorf("creating schema: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return db, nil
|
return db, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func Connect(path string) (*sqlx.DB, error) {
|
func (c DBConfig) Connect(path string) (*sqlx.DB, error) {
|
||||||
db, err := sqlx.Open("sqlite3", path)
|
db, err := sqlx.Open("sqlite3", path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("opening db: %w", err)
|
return nil, fmt.Errorf("opening db: %w", err)
|
||||||
@ -79,38 +80,38 @@ func Connect(path string) (*sqlx.DB, error) {
|
|||||||
return nil, fmt.Errorf("running pragma statements: %w", err)
|
return nil, fmt.Errorf("running pragma statements: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = CheckAndUpdateVersion(db)
|
err = c.CheckAndUpdateVersion(db)
|
||||||
return db, err
|
return db, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func CheckAndUpdateVersion(db *sqlx.DB) error {
|
func (c DBConfig) CheckAndUpdateVersion(db *sqlx.DB) error {
|
||||||
var version uint
|
var version uint
|
||||||
err := db.Get(&version, "select version from db_version")
|
err := db.Get(&version, "select version from db_version")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("couldn't check database version: %w", err)
|
return fmt.Errorf("couldn't check database version: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if version > version_number {
|
if version > c.version_number {
|
||||||
return VersionMismatchError{version_number, version}
|
return VersionMismatchError{c.version_number, version}
|
||||||
}
|
}
|
||||||
|
|
||||||
if version_number > version {
|
if c.version_number > version {
|
||||||
fmt.Print(ColorYellow)
|
fmt.Print(ColorYellow)
|
||||||
fmt.Printf("================================================\n")
|
fmt.Printf("================================================\n")
|
||||||
fmt.Printf("Database version is out of date. Upgrading database from version %d to version %d!\n", version,
|
fmt.Printf("Database version is out of date. Upgrading database from version %d to version %d!\n", version,
|
||||||
version_number)
|
c.version_number)
|
||||||
fmt.Print(ColorReset)
|
fmt.Print(ColorReset)
|
||||||
UpgradeFromXToY(db, version, version_number)
|
c.UpgradeFromXToY(db, version, c.version_number)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpgradeFromXToY runs all the migrations from version X to version Y, and update the `database_version` table's `version_number`
|
// UpgradeFromXToY runs all the migrations from version X to version Y, and update the `database_version` table's `version_number`
|
||||||
func UpgradeFromXToY(db *sqlx.DB, x uint, y uint) {
|
func (c DBConfig) UpgradeFromXToY(db *sqlx.DB, x uint, y uint) {
|
||||||
for i := x; i < y; i++ {
|
for i := x; i < y; i++ {
|
||||||
fmt.Print(ColorCyan)
|
fmt.Print(ColorCyan)
|
||||||
fmt.Println((*migrations)[i])
|
fmt.Println((*c.migrations)[i])
|
||||||
fmt.Print(ColorReset)
|
fmt.Print(ColorReset)
|
||||||
|
|
||||||
// Execute the migration in a transaction
|
// Execute the migration in a transaction
|
||||||
@ -118,7 +119,7 @@ func UpgradeFromXToY(db *sqlx.DB, x uint, y uint) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
tx.MustExec((*migrations)[i])
|
tx.MustExec((*c.migrations)[i])
|
||||||
tx.MustExec("update db_version set version = ?", i+1)
|
tx.MustExec("update db_version set version = ?", i+1)
|
||||||
if err := tx.Commit(); err != nil {
|
if err := tx.Commit(); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
|
|||||||
@ -20,7 +20,7 @@ func TestCreateAndConnectToDB(t *testing.T) {
|
|||||||
schema_sql := string(_schema_sql)
|
schema_sql := string(_schema_sql)
|
||||||
migrations := []string{}
|
migrations := []string{}
|
||||||
|
|
||||||
db.Init(&schema_sql, &migrations)
|
config := db.Init(&schema_sql, &migrations)
|
||||||
|
|
||||||
data_dir := "../../sample_data/data"
|
data_dir := "../../sample_data/data"
|
||||||
_ = os.MkdirAll(data_dir, os.FileMode(0o644))
|
_ = os.MkdirAll(data_dir, os.FileMode(0o644))
|
||||||
@ -28,9 +28,9 @@ func TestCreateAndConnectToDB(t *testing.T) {
|
|||||||
dbPath := filepath.Join(data_dir, "test.db")
|
dbPath := filepath.Join(data_dir, "test.db")
|
||||||
_ = os.Remove(dbPath) // Delete it if it exists
|
_ = os.Remove(dbPath) // Delete it if it exists
|
||||||
|
|
||||||
_, err = db.Create(dbPath)
|
_, err = config.Create(dbPath)
|
||||||
assert.NoError(err)
|
assert.NoError(err)
|
||||||
_, err = db.Connect(dbPath)
|
_, err = config.Connect(dbPath)
|
||||||
assert.NoError(err)
|
assert.NoError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -44,9 +44,9 @@ func TestVersionUpgrade(t *testing.T) {
|
|||||||
insert into db_version values(0);
|
insert into db_version values(0);
|
||||||
`
|
`
|
||||||
migrations := []string{}
|
migrations := []string{}
|
||||||
db.Init(&initial_schema, &migrations)
|
config := db.Init(&initial_schema, &migrations)
|
||||||
|
|
||||||
connection, err := db.Create(":memory:")
|
connection, err := config.Create(":memory:")
|
||||||
require.NoError(err)
|
require.NoError(err)
|
||||||
|
|
||||||
get_version := func(c *sqlx.DB) (ret int) {
|
get_version := func(c *sqlx.DB) (ret int) {
|
||||||
@ -63,7 +63,7 @@ func TestVersionUpgrade(t *testing.T) {
|
|||||||
// Create a migration to add a new Item
|
// Create a migration to add a new Item
|
||||||
migrations = append(migrations, "insert into items (rowid) values (1)")
|
migrations = append(migrations, "insert into items (rowid) values (1)")
|
||||||
db.Init(&initial_schema, &migrations) // Reinitialize with the new migration
|
db.Init(&initial_schema, &migrations) // Reinitialize with the new migration
|
||||||
db.UpgradeFromXToY(connection, uint(len(migrations)-1), uint(len(migrations)))
|
config.UpgradeFromXToY(connection, uint(len(migrations)-1), uint(len(migrations)))
|
||||||
|
|
||||||
var items2 []int
|
var items2 []int
|
||||||
require.NoError(connection.Select(&items2, "select * from items"))
|
require.NoError(connection.Select(&items2, "select * from items"))
|
||||||
@ -72,8 +72,8 @@ func TestVersionUpgrade(t *testing.T) {
|
|||||||
|
|
||||||
// Create a migration to add a new Item
|
// Create a migration to add a new Item
|
||||||
migrations = append(migrations, `alter table items add column name string default 'asdf'`)
|
migrations = append(migrations, `alter table items add column name string default 'asdf'`)
|
||||||
db.Init(&initial_schema, &migrations) // Reinitialize with the new migration
|
config = db.Init(&initial_schema, &migrations) // Reinitialize with the new migration
|
||||||
db.UpgradeFromXToY(connection, uint(len(migrations)-1), uint(len(migrations)))
|
config.UpgradeFromXToY(connection, uint(len(migrations)-1), uint(len(migrations)))
|
||||||
|
|
||||||
var items3 []struct {
|
var items3 []struct {
|
||||||
ID uint64 `db:"rowid"`
|
ID uint64 `db:"rowid"`
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user