Compare commits
2 Commits
3e164bcdfa
...
378b86b7f1
Author | SHA1 | Date | |
---|---|---|---|
378b86b7f1 | |||
e7ee10deb1 |
@ -100,8 +100,16 @@ func UpgradeFromXToY(db *sqlx.DB, x uint, y uint) {
|
|||||||
fmt.Println((*migrations)[i])
|
fmt.Println((*migrations)[i])
|
||||||
fmt.Print(ColorReset)
|
fmt.Print(ColorReset)
|
||||||
|
|
||||||
db.MustExec((*migrations)[i])
|
// Execute the migration in a transaction
|
||||||
db.MustExec("update db_version set version = ?", i+1)
|
tx, err := db.Beginx()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
tx.MustExec((*migrations)[i])
|
||||||
|
tx.MustExec("update db_version set version = ?", i+1)
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
fmt.Print(ColorYellow)
|
fmt.Print(ColorYellow)
|
||||||
fmt.Printf("Now at database schema version %d.\n", i+1)
|
fmt.Printf("Now at database schema version %d.\n", i+1)
|
||||||
|
@ -1,14 +1,16 @@
|
|||||||
package sqlgenerate
|
package schema
|
||||||
|
|
||||||
import (
|
import (
|
||||||
_ "embed"
|
"gas_stack/pkg/textutils"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/jinzhu/inflection"
|
"github.com/jinzhu/inflection"
|
||||||
"github.com/jmoiron/sqlx"
|
"github.com/jmoiron/sqlx"
|
||||||
|
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// InitDB creates an in-memory DB from a given schema string.
|
||||||
func InitDB(sql_schema string) *sqlx.DB {
|
func InitDB(sql_schema string) *sqlx.DB {
|
||||||
db := sqlx.MustOpen("sqlite3", ":memory:")
|
db := sqlx.MustOpen("sqlite3", ":memory:")
|
||||||
db.MustExec(sql_schema)
|
db.MustExec(sql_schema)
|
||||||
@ -39,24 +41,20 @@ func InitDB(sql_schema string) *sqlx.DB {
|
|||||||
|
|
||||||
// SchemaFromDB takes a DB connection, checks its schema metadata tables, and returns a Schema.
|
// SchemaFromDB takes a DB connection, checks its schema metadata tables, and returns a Schema.
|
||||||
func SchemaFromDB(db *sqlx.DB) Schema {
|
func SchemaFromDB(db *sqlx.DB) Schema {
|
||||||
return ParseSchema(db)
|
|
||||||
}
|
|
||||||
func ParseSchema(db *sqlx.DB) Schema {
|
|
||||||
ret := Schema{}
|
ret := Schema{}
|
||||||
|
|
||||||
var table_list []string
|
var tables []Table
|
||||||
err := db.Select(&table_list, `select name from tables`)
|
err := db.Select(&tables, `select name, strict from tables`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, table_name := range table_list {
|
for _, tbl := range tables {
|
||||||
tbl := Table{TableName: table_name}
|
tbl.TypeName = textutils.SnakeToCamel(inflection.Singular(tbl.TableName))
|
||||||
tbl.TypeName = snakeToCamel(inflection.Singular(table_name))
|
|
||||||
tbl.TypeIDName = tbl.TypeName + "ID"
|
tbl.TypeIDName = tbl.TypeName + "ID"
|
||||||
tbl.VarName = strings.ToLower(string(table_name[0]))
|
tbl.VarName = strings.ToLower(string(tbl.TableName[0]))
|
||||||
|
|
||||||
err := db.Select(&tbl.Columns, `select * from columns where table_name = ?`, table_name)
|
err := db.Select(&tbl.Columns, `select * from columns where table_name = ?`, tbl.TableName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
@ -64,11 +62,3 @@ func ParseSchema(db *sqlx.DB) Schema {
|
|||||||
}
|
}
|
||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
func snakeToCamel(s string) string {
|
|
||||||
parts := strings.Split(s, "_")
|
|
||||||
for i := 0; i < len(parts); i++ {
|
|
||||||
parts[i] = strings.ToUpper(string(parts[i][0])) + parts[i][1:]
|
|
||||||
}
|
|
||||||
return strings.Join(parts, "")
|
|
||||||
}
|
|
@ -1,4 +1,4 @@
|
|||||||
package sqlgenerate_test
|
package schema_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
@ -7,7 +7,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"gas_stack/pkg/sqlgenerate"
|
"gas_stack/pkg/schema"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParseSchema(t *testing.T) {
|
func TestParseSchema(t *testing.T) {
|
||||||
@ -15,8 +15,8 @@ func TestParseSchema(t *testing.T) {
|
|||||||
schema_sql, err := os.ReadFile("../../sample_data/test_schemas/food.sql")
|
schema_sql, err := os.ReadFile("../../sample_data/test_schemas/food.sql")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
db := sqlgenerate.InitDB(string(schema_sql))
|
db := schema.InitDB(string(schema_sql))
|
||||||
schema := sqlgenerate.ParseSchema(db)
|
schema := schema.SchemaFromDB(db)
|
||||||
expected_tbls := []string{"food_types", "foods", "units", "ingredients", "recipes", "iterations", "db_version"}
|
expected_tbls := []string{"food_types", "foods", "units", "ingredients", "recipes", "iterations", "db_version"}
|
||||||
for _, tbl_name := range expected_tbls {
|
for _, tbl_name := range expected_tbls {
|
||||||
_, is_ok := schema[tbl_name]
|
_, is_ok := schema[tbl_name]
|
||||||
@ -24,6 +24,10 @@ func TestParseSchema(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
foods := schema["foods"]
|
foods := schema["foods"]
|
||||||
|
assert.Equal(foods.TableName, "foods")
|
||||||
|
assert.Equal(foods.TypeName, "Food")
|
||||||
|
assert.Equal(foods.TypeIDName, "FoodID")
|
||||||
|
assert.Equal(foods.IsStrict, true)
|
||||||
assert.Len(foods.Columns, 20)
|
assert.Len(foods.Columns, 20)
|
||||||
assert.Equal(foods.Columns[0].Name, "rowid")
|
assert.Equal(foods.Columns[0].Name, "rowid")
|
||||||
assert.Equal(foods.Columns[0].Type, "integer")
|
assert.Equal(foods.Columns[0].Type, "integer")
|
@ -1,10 +1,4 @@
|
|||||||
package sqlgenerate
|
package schema
|
||||||
|
|
||||||
import (
|
|
||||||
_ "embed"
|
|
||||||
|
|
||||||
_ "github.com/mattn/go-sqlite3"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Column represents a single column in a table.
|
// Column represents a single column in a table.
|
||||||
type Column struct {
|
type Column struct {
|
11
pkg/textutils/snake.go
Normal file
11
pkg/textutils/snake.go
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
package textutils
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
func SnakeToCamel(s string) string {
|
||||||
|
parts := strings.Split(s, "_")
|
||||||
|
for i := range len(parts) {
|
||||||
|
parts[i] = strings.ToUpper(string(parts[i][0])) + parts[i][1:]
|
||||||
|
}
|
||||||
|
return strings.Join(parts, "")
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user