Create DB management / versioning package
	
		
			
	
		
	
	
		
	
		
			Some checks failed
		
		
	
	
		
			
				
	
				CI / release-test (push) Failing after 7s
				
			
		
		
	
	
				
					
				
			
		
			Some checks failed
		
		
	
	CI / release-test (push) Failing after 7s
				
			This commit is contained in:
		
							parent
							
								
									7d044062f7
								
							
						
					
					
						commit
						aa89a4ff1f
					
				
							
								
								
									
										127
									
								
								pkg/db/connect.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										127
									
								
								pkg/db/connect.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,127 @@ | ||||
| package db | ||||
| 
 | ||||
| import ( | ||||
| 	_ "embed" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"os" | ||||
| 
 | ||||
| 	"github.com/jmoiron/sqlx" | ||||
| 	_ "github.com/mattn/go-sqlite3" | ||||
| ) | ||||
| 
 | ||||
| var ( | ||||
| 	// Tracks whether the DB connector has been initialized | ||||
| 	is_initialized bool | ||||
| 
 | ||||
| 	// The SQL schema of the database under management | ||||
| 	sql_schema *string | ||||
| 
 | ||||
| 	// Database starts at version 0.  First migration brings us to version 1 | ||||
| 	migrations     *[]string | ||||
| 	version_number uint | ||||
| ) | ||||
| 
 | ||||
| var ( | ||||
| 	ErrTargetExists = errors.New("target already exists") | ||||
| ) | ||||
| 
 | ||||
| // Colors for terminal output | ||||
| const ( | ||||
| 	ColorReset  = "\033[0m" | ||||
| 	ColorBlack  = "\033[30m" | ||||
| 	ColorRed    = "\033[31m" | ||||
| 	ColorGreen  = "\033[32m" | ||||
| 	ColorYellow = "\033[33m" | ||||
| 	ColorBlue   = "\033[34m" | ||||
| 	ColorPurple = "\033[35m" | ||||
| 	ColorCyan   = "\033[36m" | ||||
| 	ColorGray   = "\033[37m" | ||||
| 	ColorWhite  = "\033[97m" | ||||
| ) | ||||
| 
 | ||||
| func Init(schema *string, migrationsList *[]string) { | ||||
| 	sql_schema = schema | ||||
| 	migrations = migrationsList | ||||
| 	version_number = uint(len(*migrations)) | ||||
| 	is_initialized = true | ||||
| } | ||||
| 
 | ||||
| func Create(path string) (*sqlx.DB, error) { | ||||
| 	// First check if the path already exists | ||||
| 	_, err := os.Stat(path) | ||||
| 	if err == nil { | ||||
| 		return nil, ErrTargetExists | ||||
| 	} else if !errors.Is(err, os.ErrNotExist) { | ||||
| 		return nil, fmt.Errorf("path error: %w", err) | ||||
| 	} | ||||
| 
 | ||||
| 	// Create DB file | ||||
| 	fmt.Printf("Creating.............   %s\n", path) | ||||
| 	db := sqlx.MustOpen("sqlite3", path+"?_foreign_keys=on&_journal_mode=WAL") | ||||
| 	db.MustExec(*sql_schema) | ||||
| 
 | ||||
| 	return db, nil | ||||
| } | ||||
| 
 | ||||
| func Connect(path string) (*sqlx.DB, error) { | ||||
| 	db := sqlx.MustOpen("sqlite3", fmt.Sprintf("%s?_foreign_keys=on&_journal_mode=WAL", path)) | ||||
| 	err := CheckAndUpdateVersion(db) | ||||
| 	return db, err | ||||
| } | ||||
| 
 | ||||
| func CheckAndUpdateVersion(db *sqlx.DB) error { | ||||
| 	var version uint | ||||
| 	err := db.Get(&version, "select version from db_version") | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("couldn't check database version: %w", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if version > version_number { | ||||
| 		return VersionMismatchError{version_number, version} | ||||
| 	} | ||||
| 
 | ||||
| 	if version_number > version { | ||||
| 		fmt.Print(ColorYellow) | ||||
| 		fmt.Printf("================================================\n") | ||||
| 		fmt.Printf("Database version is out of date.  Upgrading database from version %d to version %d!\n", version, | ||||
| 			version_number) | ||||
| 		fmt.Print(ColorReset) | ||||
| 		UpgradeFromXToY(db, version, version_number) | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // Run all the migrations from version X to version Y, and update the `database_version` table's `version_number` | ||||
| func UpgradeFromXToY(db *sqlx.DB, x uint, y uint) { | ||||
| 	for i := x; i < y; i++ { | ||||
| 		fmt.Print(ColorCyan) | ||||
| 		fmt.Println((*migrations)[i]) | ||||
| 		fmt.Print(ColorReset) | ||||
| 
 | ||||
| 		db.MustExec((*migrations)[i]) | ||||
| 		db.MustExec("update db_version set version = ?", i+1) | ||||
| 
 | ||||
| 		fmt.Print(ColorYellow) | ||||
| 		fmt.Printf("Now at database schema version %d.\n", i+1) | ||||
| 		fmt.Print(ColorReset) | ||||
| 	} | ||||
| 	fmt.Print(ColorGreen) | ||||
| 	fmt.Printf("================================================\n") | ||||
| 	fmt.Printf("Database version has been upgraded to version %d.\n", y) | ||||
| 	fmt.Print(ColorReset) | ||||
| } | ||||
| 
 | ||||
| type VersionMismatchError struct { | ||||
| 	EngineVersion   uint | ||||
| 	DatabaseVersion uint | ||||
| } | ||||
| 
 | ||||
| func (e VersionMismatchError) Error() string { | ||||
| 	return fmt.Sprintf( | ||||
| 		`This profile was created with database schema version %d, which is newer than this application's database schema version, %d. | ||||
| Please upgrade this application to a newer version to use this profile.  Or downgrade the profile's schema version, somehow.`, | ||||
| 		e.DatabaseVersion, e.EngineVersion, | ||||
| 	) | ||||
| } | ||||
							
								
								
									
										85
									
								
								pkg/db/connect_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										85
									
								
								pkg/db/connect_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,85 @@ | ||||
| package db_test | ||||
| 
 | ||||
| import ( | ||||
| 	"os" | ||||
| 	"path/filepath" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/jmoiron/sqlx" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| 
 | ||||
| 	"gas_stack/pkg/db" | ||||
| ) | ||||
| 
 | ||||
| func TestCreateAndConnectToDB(t *testing.T) { | ||||
| 	assert := assert.New(t) | ||||
| 
 | ||||
| 	_schema_sql, err := os.ReadFile("../../sample_data/test_schemas/food.sql") | ||||
| 	require.NoError(t, err) | ||||
| 	schema_sql := string(_schema_sql) | ||||
| 	migrations := []string{} | ||||
| 
 | ||||
| 	db.Init(&schema_sql, &migrations) | ||||
| 
 | ||||
| 	data_dir := "../../sample_data/data" | ||||
| 	// Create then connect to a new empty DB | ||||
| 	dbPath := filepath.Join(data_dir, "test.db") | ||||
| 	_ = os.Remove(dbPath) // Delete it if it exists | ||||
| 
 | ||||
| 	_, err = db.Create(dbPath) | ||||
| 	assert.NoError(err) | ||||
| 	_, err = db.Connect(dbPath) | ||||
| 	assert.NoError(err) | ||||
| } | ||||
| 
 | ||||
| func TestVersionUpgrade(t *testing.T) { | ||||
| 	require := require.New(t) | ||||
| 
 | ||||
| 	// Create a little schema | ||||
| 	initial_schema := ` | ||||
| 		create table items(rowid integer primary key); | ||||
| 		create table db_version (version integer primary key) strict, without rowid; | ||||
| 		insert into db_version values(0); | ||||
| 	` | ||||
| 	migrations := []string{} | ||||
| 	db.Init(&initial_schema, &migrations) | ||||
| 
 | ||||
| 	connection, err := db.Create(":memory:") | ||||
| 	require.NoError(err) | ||||
| 
 | ||||
| 	get_version := func(c *sqlx.DB) (ret int) { | ||||
| 		// TODO: this should be a function exposed by the `db` package itself | ||||
| 		c.Get(&ret, "select version from db_version") | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	var items []int | ||||
| 	connection.Select(&items, "select * from items") | ||||
| 	require.Len(items, 0) | ||||
| 	require.Equal(0, get_version(connection)) | ||||
| 
 | ||||
| 	// Create a migration to add a new Item | ||||
| 	migrations = append(migrations, "insert into items (rowid) values (1)") | ||||
| 	db.Init(&initial_schema, &migrations) // Reinitialize with the new migration | ||||
| 	db.UpgradeFromXToY(connection, uint(len(migrations)-1), uint(len(migrations))) | ||||
| 
 | ||||
| 	var items2 []int | ||||
| 	connection.Select(&items2, "select * from items") | ||||
| 	require.Len(items2, 1) | ||||
| 	require.Equal(1, get_version(connection)) | ||||
| 
 | ||||
| 	// Create a migration to add a new Item | ||||
| 	migrations = append(migrations, `alter table items add column name string default 'asdf'`) | ||||
| 	db.Init(&initial_schema, &migrations) // Reinitialize with the new migration | ||||
| 	db.UpgradeFromXToY(connection, uint(len(migrations)-1), uint(len(migrations))) | ||||
| 
 | ||||
| 	var items3 []struct { | ||||
| 		ID   uint64 `db:"rowid"` | ||||
| 		Name string `db:"name"` | ||||
| 	} | ||||
| 	connection.Select(&items3, "select * from items") | ||||
| 	require.Len(items2, 1) | ||||
| 	assert.Equal(t, "asdf", items3[0].Name) | ||||
| 	require.Equal(2, get_version(connection)) | ||||
| } | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user