diff --git a/go.mod b/go.mod index 390a3ed..f707df1 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module git.offline-twitter.com/offline-labs/gas-stack go 1.22.5 require ( + github.com/go-test/deep v1.1.1 github.com/jinzhu/inflection v1.0.0 github.com/jmoiron/sqlx v1.4.0 github.com/mattn/go-sqlite3 v1.14.24 diff --git a/go.sum b/go.sum index b34c4f2..b50a71e 100644 --- a/go.sum +++ b/go.sum @@ -5,6 +5,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= +github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U= +github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= diff --git a/pkg/schema/migration_verification_test.go b/pkg/schema/migration_verification_test.go new file mode 100644 index 0000000..2c302c7 --- /dev/null +++ b/pkg/schema/migration_verification_test.go @@ -0,0 +1,160 @@ +package schema_test + +import ( + "slices" + "testing" + + "git.offline-twitter.com/offline-labs/gas-stack/pkg/db" + "git.offline-twitter.com/offline-labs/gas-stack/pkg/flowutils" + "git.offline-twitter.com/offline-labs/gas-stack/pkg/schema" + "github.com/go-test/deep" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + baseSchema = ` + create table db_version ( + version integer primary key + ) strict, without rowid; + insert into db_version values(0); + + create table t1 ( + rowid integer primary key, + data1 integer not null, + data2 text not null + ); + ` + + fullSchema = ` + create table db_version ( + version integer primary key + ) strict, without rowid; + insert into db_version values(2); + + create table t1 ( + rowid integer primary key, + data1 integer not null, + data2 text not null, + data3 integer + ); + create table t2 ( + rowid integer primary key + ); + ` +) + +func TestVerifyCorrectMigration(t *testing.T) { + db1 := schema.InitDB(fullSchema) + db1Schema := schema.SchemaFromDB(db1) + + t.Run("migrate in 1 step", func(t *testing.T) { + migration := ` + create table t2 ( + rowid integer primary key + ); + alter table t1 add column data3 integer; + ` + db2Config := db.Init(&baseSchema, &[]string{migration}) + db2 := flowutils.Must(db2Config.Create(":memory:")) + require.NoError(t, db2Config.CheckAndUpdateVersion(db2)) + db2Schema := schema.SchemaFromDB(db2) + + if diff := deep.Equal(db1Schema, db2Schema); diff != nil { + t.Error(diff) + } + }) + + t.Run("migrate in 2 steps", func(t *testing.T) { + migration1 := ` + create table t2 ( + rowid integer primary key + ); + ` + migration2 := ` + alter table t1 add column data3 integer; + ` + + db2Config := db.Init(&baseSchema, &[]string{migration1, migration2}) + db2 := flowutils.Must(db2Config.Create(":memory:")) + require.NoError(t, db2Config.CheckAndUpdateVersion(db2)) + db2Schema := schema.SchemaFromDB(db2) + + if diff := deep.Equal(db1Schema, db2Schema); diff != nil { + t.Error(diff) + } + }) +} + +func TestIncorrectMigrations(t *testing.T) { + db1 := schema.InitDB(fullSchema) + db1Schema := schema.SchemaFromDB(db1) + + t.Run("missing migration", func(t *testing.T) { + db2Config := db.Init(&baseSchema, &[]string{}) + db2 := flowutils.Must(db2Config.Create(":memory:")) + require.NoError(t, db2Config.CheckAndUpdateVersion(db2)) + db2Schema := schema.SchemaFromDB(db2) + + // Missing a table + assert.Len(t, db1Schema.Tables, len(db2Schema.Tables)+1) + assert.Contains(t, db1Schema.Tables, "t2") + assert.NotContains(t, db2Schema.Tables, "t2") + + // Missing the new column + assert.Len(t, db1Schema.Tables["t1"].Columns, len(db2Schema.Tables["t1"].Columns)+1) + assert.True(t, slices.ContainsFunc(db1Schema.Tables["t1"].Columns, func(c schema.Column) bool { return c.Name == "data3" }), "t2") + assert.False(t, slices.ContainsFunc(db2Schema.Tables["t1"].Columns, func(c schema.Column) bool { return c.Name == "data3" }), "t2") + }) + + t.Run("incomplete migration", func(t *testing.T) { + db2Config := db.Init(&baseSchema, &[]string{` + create table t2 ( + rowid integer primary key + ); + `}) + db2 := flowutils.Must(db2Config.Create(":memory:")) + require.NoError(t, db2Config.CheckAndUpdateVersion(db2)) + db2Schema := schema.SchemaFromDB(db2) + + // Has the new table + assert.Len(t, db1Schema.Tables, len(db2Schema.Tables)) + assert.Contains(t, db1Schema.Tables, "t2") + assert.Contains(t, db2Schema.Tables, "t2") + + // Still missing the new column + assert.Len(t, db1Schema.Tables["t1"].Columns, len(db2Schema.Tables["t1"].Columns)+1) + assert.True(t, slices.ContainsFunc(db1Schema.Tables["t1"].Columns, func(c schema.Column) bool { return c.Name == "data3" }), "t2") + assert.False(t, slices.ContainsFunc(db2Schema.Tables["t1"].Columns, func(c schema.Column) bool { return c.Name == "data3" }), "t2") + }) + + t.Run("incorrect migration (wrong data type)", func(t *testing.T) { + db2Config := db.Init(&baseSchema, &[]string{` + create table t2 ( + rowid integer primary key + ); + alter table t1 add column data3 text; + `}) + db2 := flowutils.Must(db2Config.Create(":memory:")) + require.NoError(t, db2Config.CheckAndUpdateVersion(db2)) + db2Schema := schema.SchemaFromDB(db2) + + // Has the new table + assert.Len(t, db1Schema.Tables, len(db2Schema.Tables)) + assert.Contains(t, db1Schema.Tables, "t2") + assert.Contains(t, db2Schema.Tables, "t2") + + // Has the right column, but it's the wrong type + assert.Len(t, db1Schema.Tables["t1"].Columns, len(db2Schema.Tables["t1"].Columns)) + col1 := db1Schema.Tables["t1"].Columns[slices.IndexFunc(db1Schema.Tables["t1"].Columns, func(c schema.Column) bool { return c.Name == "data3" })] + col2 := db2Schema.Tables["t1"].Columns[slices.IndexFunc(db2Schema.Tables["t1"].Columns, func(c schema.Column) bool { return c.Name == "data3" })] + + assert.NotEqual(t, col1, col2) + assert.Equal(t, col1.Type, "integer") // Full schema has an integer column + assert.Equal(t, col2.Type, "text") // Migration incorrectly uses a text column + + // Other than that they are equal + col2.Type = "integer" + assert.Equal(t, col1, col2) + }) +}