package db_test

import (
	"os"
	"path/filepath"
	"testing"

	"github.com/jmoiron/sqlx"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"gas_stack/pkg/db"
)

// TestCreateAndConnectToDB initializes the db package with the sample "food"
// schema, creates a fresh database file under sample_data/data, and then opens
// that same file again with db.Connect.
func TestCreateAndConnectToDB(t *testing.T) {
	assert := assert.New(t)

	schemaBytes, err := os.ReadFile("../../sample_data/test_schemas/food.sql")
	require.NoError(t, err)
	schemaSQL := string(schemaBytes)
	migrations := []string{}
	db.Init(&schemaSQL, &migrations)

	dataDir := "../../sample_data/data"

	// Create then connect to a new empty DB
	dbPath := filepath.Join(dataDir, "test.db")
	_ = os.Remove(dbPath) // Delete it if it exists; a missing file is not an error here

	created, err := db.Create(dbPath)
	// require (not assert): the handle is closed below, and Connect is
	// meaningless if Create already failed.
	require.NoError(t, err)
	t.Cleanup(func() { assert.NoError(created.Close()) })

	connected, err := db.Connect(dbPath)
	require.NoError(t, err)
	t.Cleanup(func() { assert.NoError(connected.Close()) })
}

// TestVersionUpgrade builds a tiny in-memory schema at version 0, then appends
// migrations one at a time and checks that db.UpgradeFromXToY applies each one
// and that the value in the db_version table advances accordingly.
func TestVersionUpgrade(t *testing.T) {
	require := require.New(t)

	// Create a little schema
	initialSchema := `
		create table items(rowid integer primary key);
		create table db_version (version integer primary key) strict, without rowid;
		insert into db_version values(0);
	`
	migrations := []string{}
	db.Init(&initialSchema, &migrations)

	connection, err := db.Create(":memory:")
	require.NoError(err)
	t.Cleanup(func() { assert.NoError(t, connection.Close()) })

	// getVersion reads the single row of db_version, failing the test if the
	// query itself errors (so a broken query is not mistaken for version 0).
	getVersion := func(c *sqlx.DB) int {
		// TODO: this should be a function exposed by the `db` package itself
		var version int
		require.NoError(c.Get(&version, "select version from db_version"))
		return version
	}

	var items []int
	require.NoError(connection.Select(&items, "select * from items"))
	require.Len(items, 0)
	require.Equal(0, getVersion(connection))

	// Create a migration to add a new Item
	migrations = append(migrations, "insert into items (rowid) values (1)")
	db.Init(&initialSchema, &migrations) // Reinitialize with the new migration
	db.UpgradeFromXToY(connection, uint(len(migrations)-1), uint(len(migrations)))

	var items2 []int
	require.NoError(connection.Select(&items2, "select * from items"))
	require.Len(items2, 1)
	require.Equal(1, getVersion(connection))

	// Create a migration to add a new column with a default value
	migrations = append(migrations, `alter table items add column name string default 'asdf'`)
	db.Init(&initialSchema, &migrations) // Reinitialize with the new migration
	db.UpgradeFromXToY(connection, uint(len(migrations)-1), uint(len(migrations)))

	var items3 []struct {
		ID   uint64 `db:"rowid"`
		Name string `db:"name"`
	}
	require.NoError(connection.Select(&items3, "select * from items"))
	// Check items3 (not items2) so the index below cannot panic.
	require.Len(items3, 1)
	require.Equal("asdf", items3[0].Name)
	require.Equal(2, getVersion(connection))
}