Create DB management / versioning package
	
		
			
	
		
	
	
		
	
		
			Some checks failed
		
		
	
	
		
			
				
	
				CI / release-test (push) Failing after 7s
				
			
		
		
	
	
				
					
				
			
		
			Some checks failed
		
		
	
	CI / release-test (push) Failing after 7s
				
			This commit is contained in:
		
							parent
							
								
									7d044062f7
								
							
						
					
					
						commit
						aa89a4ff1f
					
				
							
								
								
									
										127
									
								
								pkg/db/connect.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										127
									
								
								pkg/db/connect.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,127 @@ | |||||||
|  | package db | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	_ "embed" | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"os" | ||||||
|  | 
 | ||||||
|  | 	"github.com/jmoiron/sqlx" | ||||||
|  | 	_ "github.com/mattn/go-sqlite3" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | var ( | ||||||
|  | 	// Tracks whether the DB connector has been initialized | ||||||
|  | 	is_initialized bool | ||||||
|  | 
 | ||||||
|  | 	// The SQL schema of the database under management | ||||||
|  | 	sql_schema *string | ||||||
|  | 
 | ||||||
|  | 	// Database starts at version 0.  First migration brings us to version 1 | ||||||
|  | 	migrations     *[]string | ||||||
|  | 	version_number uint | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | var ( | ||||||
|  | 	ErrTargetExists = errors.New("target already exists") | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // Colors for terminal output | ||||||
|  | const ( | ||||||
|  | 	ColorReset  = "\033[0m" | ||||||
|  | 	ColorBlack  = "\033[30m" | ||||||
|  | 	ColorRed    = "\033[31m" | ||||||
|  | 	ColorGreen  = "\033[32m" | ||||||
|  | 	ColorYellow = "\033[33m" | ||||||
|  | 	ColorBlue   = "\033[34m" | ||||||
|  | 	ColorPurple = "\033[35m" | ||||||
|  | 	ColorCyan   = "\033[36m" | ||||||
|  | 	ColorGray   = "\033[37m" | ||||||
|  | 	ColorWhite  = "\033[97m" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func Init(schema *string, migrationsList *[]string) { | ||||||
|  | 	sql_schema = schema | ||||||
|  | 	migrations = migrationsList | ||||||
|  | 	version_number = uint(len(*migrations)) | ||||||
|  | 	is_initialized = true | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func Create(path string) (*sqlx.DB, error) { | ||||||
|  | 	// First check if the path already exists | ||||||
|  | 	_, err := os.Stat(path) | ||||||
|  | 	if err == nil { | ||||||
|  | 		return nil, ErrTargetExists | ||||||
|  | 	} else if !errors.Is(err, os.ErrNotExist) { | ||||||
|  | 		return nil, fmt.Errorf("path error: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Create DB file | ||||||
|  | 	fmt.Printf("Creating.............   %s\n", path) | ||||||
|  | 	db := sqlx.MustOpen("sqlite3", path+"?_foreign_keys=on&_journal_mode=WAL") | ||||||
|  | 	db.MustExec(*sql_schema) | ||||||
|  | 
 | ||||||
|  | 	return db, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func Connect(path string) (*sqlx.DB, error) { | ||||||
|  | 	db := sqlx.MustOpen("sqlite3", fmt.Sprintf("%s?_foreign_keys=on&_journal_mode=WAL", path)) | ||||||
|  | 	err := CheckAndUpdateVersion(db) | ||||||
|  | 	return db, err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func CheckAndUpdateVersion(db *sqlx.DB) error { | ||||||
|  | 	var version uint | ||||||
|  | 	err := db.Get(&version, "select version from db_version") | ||||||
|  | 	if err != nil { | ||||||
|  | 		return fmt.Errorf("couldn't check database version: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if version > version_number { | ||||||
|  | 		return VersionMismatchError{version_number, version} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if version_number > version { | ||||||
|  | 		fmt.Print(ColorYellow) | ||||||
|  | 		fmt.Printf("================================================\n") | ||||||
|  | 		fmt.Printf("Database version is out of date.  Upgrading database from version %d to version %d!\n", version, | ||||||
|  | 			version_number) | ||||||
|  | 		fmt.Print(ColorReset) | ||||||
|  | 		UpgradeFromXToY(db, version, version_number) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Run all the migrations from version X to version Y, and update the `database_version` table's `version_number` | ||||||
|  | func UpgradeFromXToY(db *sqlx.DB, x uint, y uint) { | ||||||
|  | 	for i := x; i < y; i++ { | ||||||
|  | 		fmt.Print(ColorCyan) | ||||||
|  | 		fmt.Println((*migrations)[i]) | ||||||
|  | 		fmt.Print(ColorReset) | ||||||
|  | 
 | ||||||
|  | 		db.MustExec((*migrations)[i]) | ||||||
|  | 		db.MustExec("update db_version set version = ?", i+1) | ||||||
|  | 
 | ||||||
|  | 		fmt.Print(ColorYellow) | ||||||
|  | 		fmt.Printf("Now at database schema version %d.\n", i+1) | ||||||
|  | 		fmt.Print(ColorReset) | ||||||
|  | 	} | ||||||
|  | 	fmt.Print(ColorGreen) | ||||||
|  | 	fmt.Printf("================================================\n") | ||||||
|  | 	fmt.Printf("Database version has been upgraded to version %d.\n", y) | ||||||
|  | 	fmt.Print(ColorReset) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type VersionMismatchError struct { | ||||||
|  | 	EngineVersion   uint | ||||||
|  | 	DatabaseVersion uint | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (e VersionMismatchError) Error() string { | ||||||
|  | 	return fmt.Sprintf( | ||||||
|  | 		`This profile was created with database schema version %d, which is newer than this application's database schema version, %d. | ||||||
|  | Please upgrade this application to a newer version to use this profile.  Or downgrade the profile's schema version, somehow.`, | ||||||
|  | 		e.DatabaseVersion, e.EngineVersion, | ||||||
|  | 	) | ||||||
|  | } | ||||||
							
								
								
									
										85
									
								
								pkg/db/connect_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										85
									
								
								pkg/db/connect_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,85 @@ | |||||||
|  | package db_test | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"os" | ||||||
|  | 	"path/filepath" | ||||||
|  | 	"testing" | ||||||
|  | 
 | ||||||
|  | 	"github.com/jmoiron/sqlx" | ||||||
|  | 	"github.com/stretchr/testify/assert" | ||||||
|  | 	"github.com/stretchr/testify/require" | ||||||
|  | 
 | ||||||
|  | 	"gas_stack/pkg/db" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func TestCreateAndConnectToDB(t *testing.T) { | ||||||
|  | 	assert := assert.New(t) | ||||||
|  | 
 | ||||||
|  | 	_schema_sql, err := os.ReadFile("../../sample_data/test_schemas/food.sql") | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	schema_sql := string(_schema_sql) | ||||||
|  | 	migrations := []string{} | ||||||
|  | 
 | ||||||
|  | 	db.Init(&schema_sql, &migrations) | ||||||
|  | 
 | ||||||
|  | 	data_dir := "../../sample_data/data" | ||||||
|  | 	// Create then connect to a new empty DB | ||||||
|  | 	dbPath := filepath.Join(data_dir, "test.db") | ||||||
|  | 	_ = os.Remove(dbPath) // Delete it if it exists | ||||||
|  | 
 | ||||||
|  | 	_, err = db.Create(dbPath) | ||||||
|  | 	assert.NoError(err) | ||||||
|  | 	_, err = db.Connect(dbPath) | ||||||
|  | 	assert.NoError(err) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestVersionUpgrade(t *testing.T) { | ||||||
|  | 	require := require.New(t) | ||||||
|  | 
 | ||||||
|  | 	// Create a little schema | ||||||
|  | 	initial_schema := ` | ||||||
|  | 		create table items(rowid integer primary key); | ||||||
|  | 		create table db_version (version integer primary key) strict, without rowid; | ||||||
|  | 		insert into db_version values(0); | ||||||
|  | 	` | ||||||
|  | 	migrations := []string{} | ||||||
|  | 	db.Init(&initial_schema, &migrations) | ||||||
|  | 
 | ||||||
|  | 	connection, err := db.Create(":memory:") | ||||||
|  | 	require.NoError(err) | ||||||
|  | 
 | ||||||
|  | 	get_version := func(c *sqlx.DB) (ret int) { | ||||||
|  | 		// TODO: this should be a function exposed by the `db` package itself | ||||||
|  | 		c.Get(&ret, "select version from db_version") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var items []int | ||||||
|  | 	connection.Select(&items, "select * from items") | ||||||
|  | 	require.Len(items, 0) | ||||||
|  | 	require.Equal(0, get_version(connection)) | ||||||
|  | 
 | ||||||
|  | 	// Create a migration to add a new Item | ||||||
|  | 	migrations = append(migrations, "insert into items (rowid) values (1)") | ||||||
|  | 	db.Init(&initial_schema, &migrations) // Reinitialize with the new migration | ||||||
|  | 	db.UpgradeFromXToY(connection, uint(len(migrations)-1), uint(len(migrations))) | ||||||
|  | 
 | ||||||
|  | 	var items2 []int | ||||||
|  | 	connection.Select(&items2, "select * from items") | ||||||
|  | 	require.Len(items2, 1) | ||||||
|  | 	require.Equal(1, get_version(connection)) | ||||||
|  | 
 | ||||||
|  | 	// Create a migration to add a new Item | ||||||
|  | 	migrations = append(migrations, `alter table items add column name string default 'asdf'`) | ||||||
|  | 	db.Init(&initial_schema, &migrations) // Reinitialize with the new migration | ||||||
|  | 	db.UpgradeFromXToY(connection, uint(len(migrations)-1), uint(len(migrations))) | ||||||
|  | 
 | ||||||
|  | 	var items3 []struct { | ||||||
|  | 		ID   uint64 `db:"rowid"` | ||||||
|  | 		Name string `db:"name"` | ||||||
|  | 	} | ||||||
|  | 	connection.Select(&items3, "select * from items") | ||||||
|  | 	require.Len(items2, 1) | ||||||
|  | 	assert.Equal(t, "asdf", items3[0].Name) | ||||||
|  | 	require.Equal(2, get_version(connection)) | ||||||
|  | } | ||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user