diff --git a/pkg/db/connect.go b/pkg/db/connect.go new file mode 100644 index 0000000..60f9515 --- /dev/null +++ b/pkg/db/connect.go @@ -0,0 +1,127 @@ +package db + +import ( + _ "embed" + "errors" + "fmt" + "os" + + "github.com/jmoiron/sqlx" + _ "github.com/mattn/go-sqlite3" +) + +var ( + // Tracks whether the DB connector has been initialized + is_initialized bool + + // The SQL schema of the database under management + sql_schema *string + + // Database starts at version 0. First migration brings us to version 1 + migrations *[]string + version_number uint +) + +var ( + ErrTargetExists = errors.New("target already exists") +) + +// Colors for terminal output +const ( + ColorReset = "\033[0m" + ColorBlack = "\033[30m" + ColorRed = "\033[31m" + ColorGreen = "\033[32m" + ColorYellow = "\033[33m" + ColorBlue = "\033[34m" + ColorPurple = "\033[35m" + ColorCyan = "\033[36m" + ColorGray = "\033[37m" + ColorWhite = "\033[97m" +) + +func Init(schema *string, migrationsList *[]string) { + sql_schema = schema + migrations = migrationsList + version_number = uint(len(*migrations)) + is_initialized = true +} + +func Create(path string) (*sqlx.DB, error) { + // First check if the path already exists + _, err := os.Stat(path) + if err == nil { + return nil, ErrTargetExists + } else if !errors.Is(err, os.ErrNotExist) { + return nil, fmt.Errorf("path error: %w", err) + } + + // Create DB file + fmt.Printf("Creating............. %s\n", path) + db := sqlx.MustOpen("sqlite3", path+"?_foreign_keys=on&_journal_mode=WAL") + db.MustExec(*sql_schema) + + return db, nil +} + +func Connect(path string) (*sqlx.DB, error) { + db := sqlx.MustOpen("sqlite3", fmt.Sprintf("%s?_foreign_keys=on&_journal_mode=WAL", path)) + err := CheckAndUpdateVersion(db) + return db, err +} + +func CheckAndUpdateVersion(db *sqlx.DB) error { + var version uint + err := db.Get(&version, "select version from db_version") + if err != nil { + return fmt.Errorf("couldn't check database version: %w", err) + } + + if version > version_number { + return VersionMismatchError{version_number, version} + } + + if version_number > version { + fmt.Print(ColorYellow) + fmt.Printf("================================================\n") + fmt.Printf("Database version is out of date. Upgrading database from version %d to version %d!\n", version, + version_number) + fmt.Print(ColorReset) + UpgradeFromXToY(db, version, version_number) + } + + return nil +} + +// Run all the migrations from version X to version Y, and update the `database_version` table's `version_number` +func UpgradeFromXToY(db *sqlx.DB, x uint, y uint) { + for i := x; i < y; i++ { + fmt.Print(ColorCyan) + fmt.Println((*migrations)[i]) + fmt.Print(ColorReset) + + db.MustExec((*migrations)[i]) + db.MustExec("update db_version set version = ?", i+1) + + fmt.Print(ColorYellow) + fmt.Printf("Now at database schema version %d.\n", i+1) + fmt.Print(ColorReset) + } + fmt.Print(ColorGreen) + fmt.Printf("================================================\n") + fmt.Printf("Database version has been upgraded to version %d.\n", y) + fmt.Print(ColorReset) +} + +type VersionMismatchError struct { + EngineVersion uint + DatabaseVersion uint +} + +func (e VersionMismatchError) Error() string { + return fmt.Sprintf( + `This profile was created with database schema version %d, which is newer than this application's database schema version, %d. +Please upgrade this application to a newer version to use this profile. Or downgrade the profile's schema version, somehow.`, + e.DatabaseVersion, e.EngineVersion, + ) +} diff --git a/pkg/db/connect_test.go b/pkg/db/connect_test.go new file mode 100644 index 0000000..11c5df2 --- /dev/null +++ b/pkg/db/connect_test.go @@ -0,0 +1,85 @@ +package db_test + +import ( + "os" + "path/filepath" + "testing" + + "github.com/jmoiron/sqlx" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "gas_stack/pkg/db" +) + +func TestCreateAndConnectToDB(t *testing.T) { + assert := assert.New(t) + + _schema_sql, err := os.ReadFile("../../sample_data/test_schemas/food.sql") + require.NoError(t, err) + schema_sql := string(_schema_sql) + migrations := []string{} + + db.Init(&schema_sql, &migrations) + + data_dir := "../../sample_data/data" + // Create then connect to a new empty DB + dbPath := filepath.Join(data_dir, "test.db") + _ = os.Remove(dbPath) // Delete it if it exists + + _, err = db.Create(dbPath) + assert.NoError(err) + _, err = db.Connect(dbPath) + assert.NoError(err) +} + +func TestVersionUpgrade(t *testing.T) { + require := require.New(t) + + // Create a little schema + initial_schema := ` + create table items(rowid integer primary key); + create table db_version (version integer primary key) strict, without rowid; + insert into db_version values(0); + ` + migrations := []string{} + db.Init(&initial_schema, &migrations) + + connection, err := db.Create(":memory:") + require.NoError(err) + + get_version := func(c *sqlx.DB) (ret int) { + // TODO: this should be a function exposed by the `db` package itself + c.Get(&ret, "select version from db_version") + return + } + + var items []int + connection.Select(&items, "select * from items") + require.Len(items, 0) + require.Equal(0, get_version(connection)) + + // Create a migration to add a new Item + migrations = append(migrations, "insert into items (rowid) values (1)") + db.Init(&initial_schema, &migrations) // Reinitialize with the new migration + db.UpgradeFromXToY(connection, uint(len(migrations)-1), uint(len(migrations))) + + var items2 []int + connection.Select(&items2, "select * from items") + require.Len(items2, 1) + require.Equal(1, get_version(connection)) + + // Create a migration to add a new Item + migrations = append(migrations, `alter table items add column name string default 'asdf'`) + db.Init(&initial_schema, &migrations) // Reinitialize with the new migration + db.UpgradeFromXToY(connection, uint(len(migrations)-1), uint(len(migrations))) + + var items3 []struct { + ID uint64 `db:"rowid"` + Name string `db:"name"` + } + connection.Select(&items3, "select * from items") + require.Len(items2, 1) + assert.Equal(t, "asdf", items3[0].Name) + require.Equal(2, get_version(connection)) +}