diff --git a/pkg/sqlgenerate/schema_parse.go b/pkg/schema/parse.go similarity index 71% rename from pkg/sqlgenerate/schema_parse.go rename to pkg/schema/parse.go index e9fac37..f03c075 100644 --- a/pkg/sqlgenerate/schema_parse.go +++ b/pkg/schema/parse.go @@ -1,14 +1,16 @@ -package sqlgenerate +package schema import ( - _ "embed" + "gas_stack/pkg/textutils" "strings" "github.com/jinzhu/inflection" "github.com/jmoiron/sqlx" + _ "github.com/mattn/go-sqlite3" ) +// InitDB creates an in-memory DB from a given schema string. func InitDB(sql_schema string) *sqlx.DB { db := sqlx.MustOpen("sqlite3", ":memory:") db.MustExec(sql_schema) @@ -39,24 +41,20 @@ func InitDB(sql_schema string) *sqlx.DB { // SchemaFromDB takes a DB connection, checks its schema metadata tables, and returns a Schema. func SchemaFromDB(db *sqlx.DB) Schema { - return ParseSchema(db) -} -func ParseSchema(db *sqlx.DB) Schema { ret := Schema{} - var table_list []string - err := db.Select(&table_list, `select name from tables`) + var tables []Table + err := db.Select(&tables, `select name, strict from tables`) if err != nil { panic(err) } - for _, table_name := range table_list { - tbl := Table{TableName: table_name} - tbl.TypeName = snakeToCamel(inflection.Singular(table_name)) + for _, tbl := range tables { + tbl.TypeName = textutils.SnakeToCamel(inflection.Singular(tbl.TableName)) tbl.TypeIDName = tbl.TypeName + "ID" - tbl.VarName = strings.ToLower(string(table_name[0])) + tbl.VarName = strings.ToLower(string(tbl.TableName[0])) - err := db.Select(&tbl.Columns, `select * from columns where table_name = ?`, table_name) + err := db.Select(&tbl.Columns, `select * from columns where table_name = ?`, tbl.TableName) if err != nil { panic(err) } @@ -64,11 +62,3 @@ func ParseSchema(db *sqlx.DB) Schema { } return ret } - -func snakeToCamel(s string) string { - parts := strings.Split(s, "_") - for i := 0; i < len(parts); i++ { - parts[i] = strings.ToUpper(string(parts[i][0])) + parts[i][1:] - } - return strings.Join(parts, "") -} diff --git a/pkg/sqlgenerate/schema_parse_test.go b/pkg/schema/parse_test.go similarity index 85% rename from pkg/sqlgenerate/schema_parse_test.go rename to pkg/schema/parse_test.go index 1fda5a8..e4c444f 100644 --- a/pkg/sqlgenerate/schema_parse_test.go +++ b/pkg/schema/parse_test.go @@ -1,4 +1,4 @@ -package sqlgenerate_test +package schema_test import ( "os" @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gas_stack/pkg/sqlgenerate" + "gas_stack/pkg/schema" ) func TestParseSchema(t *testing.T) { @@ -15,8 +15,8 @@ func TestParseSchema(t *testing.T) { schema_sql, err := os.ReadFile("../../sample_data/test_schemas/food.sql") require.NoError(t, err) - db := sqlgenerate.InitDB(string(schema_sql)) - schema := sqlgenerate.ParseSchema(db) + db := schema.InitDB(string(schema_sql)) + schema := schema.SchemaFromDB(db) expected_tbls := []string{"food_types", "foods", "units", "ingredients", "recipes", "iterations", "db_version"} for _, tbl_name := range expected_tbls { _, is_ok := schema[tbl_name] @@ -24,6 +24,10 @@ func TestParseSchema(t *testing.T) { } foods := schema["foods"] + assert.Equal(foods.TableName, "foods") + assert.Equal(foods.TypeName, "Food") + assert.Equal(foods.TypeIDName, "FoodID") + assert.Equal(foods.IsStrict, true) assert.Len(foods.Columns, 20) assert.Equal(foods.Columns[0].Name, "rowid") assert.Equal(foods.Columns[0].Type, "integer") diff --git a/pkg/sqlgenerate/table.go b/pkg/schema/table.go similarity index 93% rename from pkg/sqlgenerate/table.go rename to pkg/schema/table.go index 0baabcc..da1818e 100644 --- a/pkg/sqlgenerate/table.go +++ b/pkg/schema/table.go @@ -1,10 +1,4 @@ -package sqlgenerate - -import ( - _ "embed" - - _ "github.com/mattn/go-sqlite3" -) +package schema // Column represents a single column in a table. type Column struct { diff --git a/pkg/textutils/snake.go b/pkg/textutils/snake.go new file mode 100644 index 0000000..7aee3ab --- /dev/null +++ b/pkg/textutils/snake.go @@ -0,0 +1,11 @@ +package textutils + +import "strings" + +func SnakeToCamel(s string) string { + parts := strings.Split(s, "_") + for i := range len(parts) { + parts[i] = strings.ToUpper(string(parts[i][0])) + parts[i][1:] + } + return strings.Join(parts, "") +}