Create DB management / versioning package
Some checks failed
CI / release-test (push) Failing after 7s
Some checks failed
CI / release-test (push) Failing after 7s
This commit is contained in:
parent
7d044062f7
commit
aa89a4ff1f
127
pkg/db/connect.go
Normal file
127
pkg/db/connect.go
Normal file
@ -0,0 +1,127 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
var (
|
||||
// Tracks whether the DB connector has been initialized
|
||||
is_initialized bool
|
||||
|
||||
// The SQL schema of the database under management
|
||||
sql_schema *string
|
||||
|
||||
// Database starts at version 0. First migration brings us to version 1
|
||||
migrations *[]string
|
||||
version_number uint
|
||||
)
|
||||
|
||||
var (
|
||||
ErrTargetExists = errors.New("target already exists")
|
||||
)
|
||||
|
||||
// Colors for terminal output
|
||||
const (
|
||||
ColorReset = "\033[0m"
|
||||
ColorBlack = "\033[30m"
|
||||
ColorRed = "\033[31m"
|
||||
ColorGreen = "\033[32m"
|
||||
ColorYellow = "\033[33m"
|
||||
ColorBlue = "\033[34m"
|
||||
ColorPurple = "\033[35m"
|
||||
ColorCyan = "\033[36m"
|
||||
ColorGray = "\033[37m"
|
||||
ColorWhite = "\033[97m"
|
||||
)
|
||||
|
||||
func Init(schema *string, migrationsList *[]string) {
|
||||
sql_schema = schema
|
||||
migrations = migrationsList
|
||||
version_number = uint(len(*migrations))
|
||||
is_initialized = true
|
||||
}
|
||||
|
||||
func Create(path string) (*sqlx.DB, error) {
|
||||
// First check if the path already exists
|
||||
_, err := os.Stat(path)
|
||||
if err == nil {
|
||||
return nil, ErrTargetExists
|
||||
} else if !errors.Is(err, os.ErrNotExist) {
|
||||
return nil, fmt.Errorf("path error: %w", err)
|
||||
}
|
||||
|
||||
// Create DB file
|
||||
fmt.Printf("Creating............. %s\n", path)
|
||||
db := sqlx.MustOpen("sqlite3", path+"?_foreign_keys=on&_journal_mode=WAL")
|
||||
db.MustExec(*sql_schema)
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func Connect(path string) (*sqlx.DB, error) {
|
||||
db := sqlx.MustOpen("sqlite3", fmt.Sprintf("%s?_foreign_keys=on&_journal_mode=WAL", path))
|
||||
err := CheckAndUpdateVersion(db)
|
||||
return db, err
|
||||
}
|
||||
|
||||
func CheckAndUpdateVersion(db *sqlx.DB) error {
|
||||
var version uint
|
||||
err := db.Get(&version, "select version from db_version")
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't check database version: %w", err)
|
||||
}
|
||||
|
||||
if version > version_number {
|
||||
return VersionMismatchError{version_number, version}
|
||||
}
|
||||
|
||||
if version_number > version {
|
||||
fmt.Print(ColorYellow)
|
||||
fmt.Printf("================================================\n")
|
||||
fmt.Printf("Database version is out of date. Upgrading database from version %d to version %d!\n", version,
|
||||
version_number)
|
||||
fmt.Print(ColorReset)
|
||||
UpgradeFromXToY(db, version, version_number)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Run all the migrations from version X to version Y, and update the `database_version` table's `version_number`
|
||||
func UpgradeFromXToY(db *sqlx.DB, x uint, y uint) {
|
||||
for i := x; i < y; i++ {
|
||||
fmt.Print(ColorCyan)
|
||||
fmt.Println((*migrations)[i])
|
||||
fmt.Print(ColorReset)
|
||||
|
||||
db.MustExec((*migrations)[i])
|
||||
db.MustExec("update db_version set version = ?", i+1)
|
||||
|
||||
fmt.Print(ColorYellow)
|
||||
fmt.Printf("Now at database schema version %d.\n", i+1)
|
||||
fmt.Print(ColorReset)
|
||||
}
|
||||
fmt.Print(ColorGreen)
|
||||
fmt.Printf("================================================\n")
|
||||
fmt.Printf("Database version has been upgraded to version %d.\n", y)
|
||||
fmt.Print(ColorReset)
|
||||
}
|
||||
|
||||
type VersionMismatchError struct {
|
||||
EngineVersion uint
|
||||
DatabaseVersion uint
|
||||
}
|
||||
|
||||
func (e VersionMismatchError) Error() string {
|
||||
return fmt.Sprintf(
|
||||
`This profile was created with database schema version %d, which is newer than this application's database schema version, %d.
|
||||
Please upgrade this application to a newer version to use this profile. Or downgrade the profile's schema version, somehow.`,
|
||||
e.DatabaseVersion, e.EngineVersion,
|
||||
)
|
||||
}
|
85
pkg/db/connect_test.go
Normal file
85
pkg/db/connect_test.go
Normal file
@ -0,0 +1,85 @@
|
||||
package db_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"gas_stack/pkg/db"
|
||||
)
|
||||
|
||||
func TestCreateAndConnectToDB(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
_schema_sql, err := os.ReadFile("../../sample_data/test_schemas/food.sql")
|
||||
require.NoError(t, err)
|
||||
schema_sql := string(_schema_sql)
|
||||
migrations := []string{}
|
||||
|
||||
db.Init(&schema_sql, &migrations)
|
||||
|
||||
data_dir := "../../sample_data/data"
|
||||
// Create then connect to a new empty DB
|
||||
dbPath := filepath.Join(data_dir, "test.db")
|
||||
_ = os.Remove(dbPath) // Delete it if it exists
|
||||
|
||||
_, err = db.Create(dbPath)
|
||||
assert.NoError(err)
|
||||
_, err = db.Connect(dbPath)
|
||||
assert.NoError(err)
|
||||
}
|
||||
|
||||
func TestVersionUpgrade(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
// Create a little schema
|
||||
initial_schema := `
|
||||
create table items(rowid integer primary key);
|
||||
create table db_version (version integer primary key) strict, without rowid;
|
||||
insert into db_version values(0);
|
||||
`
|
||||
migrations := []string{}
|
||||
db.Init(&initial_schema, &migrations)
|
||||
|
||||
connection, err := db.Create(":memory:")
|
||||
require.NoError(err)
|
||||
|
||||
get_version := func(c *sqlx.DB) (ret int) {
|
||||
// TODO: this should be a function exposed by the `db` package itself
|
||||
c.Get(&ret, "select version from db_version")
|
||||
return
|
||||
}
|
||||
|
||||
var items []int
|
||||
connection.Select(&items, "select * from items")
|
||||
require.Len(items, 0)
|
||||
require.Equal(0, get_version(connection))
|
||||
|
||||
// Create a migration to add a new Item
|
||||
migrations = append(migrations, "insert into items (rowid) values (1)")
|
||||
db.Init(&initial_schema, &migrations) // Reinitialize with the new migration
|
||||
db.UpgradeFromXToY(connection, uint(len(migrations)-1), uint(len(migrations)))
|
||||
|
||||
var items2 []int
|
||||
connection.Select(&items2, "select * from items")
|
||||
require.Len(items2, 1)
|
||||
require.Equal(1, get_version(connection))
|
||||
|
||||
// Create a migration to add a new Item
|
||||
migrations = append(migrations, `alter table items add column name string default 'asdf'`)
|
||||
db.Init(&initial_schema, &migrations) // Reinitialize with the new migration
|
||||
db.UpgradeFromXToY(connection, uint(len(migrations)-1), uint(len(migrations)))
|
||||
|
||||
var items3 []struct {
|
||||
ID uint64 `db:"rowid"`
|
||||
Name string `db:"name"`
|
||||
}
|
||||
connection.Select(&items3, "select * from items")
|
||||
require.Len(items2, 1)
|
||||
assert.Equal(t, "asdf", items3[0].Name)
|
||||
require.Equal(2, get_version(connection))
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user