128 lines
3.2 KiB
Go
128 lines
3.2 KiB
Go
package db
|
|
|
|
import (
|
|
_ "embed"
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
|
|
"github.com/jmoiron/sqlx"
|
|
_ "github.com/mattn/go-sqlite3"
|
|
)
|
|
|
|
// Package-level state configured once by Init and read by Create, Connect,
// CheckAndUpdateVersion and UpgradeFromXToY.
var (
	// migrations points at the ordered migration scripts. Script i moves the
	// database from schema version i to version i+1: the database starts at
	// version 0 and the first migration brings it to version 1.
	migrations *[]string

	// version_number is the schema version this build expects; Init sets it
	// to the number of migrations.
	version_number uint

	// sql_schema points at the SQL executed by Create to build a new database.
	sql_schema *string

	// is_initialized becomes true once Init has run.
	is_initialized bool
)
|
|
|
|
// ErrTargetExists is returned by Create when the requested database path is
// already occupied.
var ErrTargetExists = errors.New("target already exists")
|
|
|
|
// Colors for terminal output
|
|
const (
|
|
ColorReset = "\033[0m"
|
|
ColorBlack = "\033[30m"
|
|
ColorRed = "\033[31m"
|
|
ColorGreen = "\033[32m"
|
|
ColorYellow = "\033[33m"
|
|
ColorBlue = "\033[34m"
|
|
ColorPurple = "\033[35m"
|
|
ColorCyan = "\033[36m"
|
|
ColorGray = "\033[37m"
|
|
ColorWhite = "\033[97m"
|
|
)
|
|
|
|
func Init(schema *string, migrationsList *[]string) {
|
|
sql_schema = schema
|
|
migrations = migrationsList
|
|
version_number = uint(len(*migrations))
|
|
is_initialized = true
|
|
}
|
|
|
|
func Create(path string) (*sqlx.DB, error) {
|
|
// First check if the path already exists
|
|
_, err := os.Stat(path)
|
|
if err == nil {
|
|
return nil, ErrTargetExists
|
|
} else if !errors.Is(err, os.ErrNotExist) {
|
|
return nil, fmt.Errorf("path error: %w", err)
|
|
}
|
|
|
|
// Create DB file
|
|
fmt.Printf("Creating............. %s\n", path)
|
|
db := sqlx.MustOpen("sqlite3", path+"?_foreign_keys=on&_journal_mode=WAL")
|
|
db.MustExec(*sql_schema)
|
|
|
|
return db, nil
|
|
}
|
|
|
|
func Connect(path string) (*sqlx.DB, error) {
|
|
db := sqlx.MustOpen("sqlite3", fmt.Sprintf("%s?_foreign_keys=on&_journal_mode=WAL", path))
|
|
err := CheckAndUpdateVersion(db)
|
|
return db, err
|
|
}
|
|
|
|
func CheckAndUpdateVersion(db *sqlx.DB) error {
|
|
var version uint
|
|
err := db.Get(&version, "select version from db_version")
|
|
if err != nil {
|
|
return fmt.Errorf("couldn't check database version: %w", err)
|
|
}
|
|
|
|
if version > version_number {
|
|
return VersionMismatchError{version_number, version}
|
|
}
|
|
|
|
if version_number > version {
|
|
fmt.Print(ColorYellow)
|
|
fmt.Printf("================================================\n")
|
|
fmt.Printf("Database version is out of date. Upgrading database from version %d to version %d!\n", version,
|
|
version_number)
|
|
fmt.Print(ColorReset)
|
|
UpgradeFromXToY(db, version, version_number)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Run all the migrations from version X to version Y, and update the `database_version` table's `version_number`
|
|
func UpgradeFromXToY(db *sqlx.DB, x uint, y uint) {
|
|
for i := x; i < y; i++ {
|
|
fmt.Print(ColorCyan)
|
|
fmt.Println((*migrations)[i])
|
|
fmt.Print(ColorReset)
|
|
|
|
db.MustExec((*migrations)[i])
|
|
db.MustExec("update db_version set version = ?", i+1)
|
|
|
|
fmt.Print(ColorYellow)
|
|
fmt.Printf("Now at database schema version %d.\n", i+1)
|
|
fmt.Print(ColorReset)
|
|
}
|
|
fmt.Print(ColorGreen)
|
|
fmt.Printf("================================================\n")
|
|
fmt.Printf("Database version has been upgraded to version %d.\n", y)
|
|
fmt.Print(ColorReset)
|
|
}
|
|
|
|
// VersionMismatchError is returned by CheckAndUpdateVersion when the database
// records a newer schema version than this build of the application expects.
type VersionMismatchError struct {
	EngineVersion   uint // schema version this build expects
	DatabaseVersion uint // schema version stored in the database
}

// Error returns a user-facing explanation of the mismatch, advising an
// application upgrade.
func (e VersionMismatchError) Error() string {
	const format = "This profile was created with database schema version %d, which is newer than this application's database schema version, %d.\n" +
		"Please upgrade this application to a newer version to use this profile. Or downgrade the profile's schema version, somehow."
	return fmt.Sprintf(format, e.DatabaseVersion, e.EngineVersion)
}
|