diff --git a/cmd/main.go b/cmd/main.go index d62dbe8..d361510 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -22,6 +22,7 @@ func main() { root_cmd.AddCommand(sqlite_lint) root_cmd.AddCommand(cmd_init) root_cmd.AddCommand(generate_model) + root_cmd.AddCommand(generate_codetable_type) if err := root_cmd.Execute(); err != nil { fmt.Println(RED + err.Error() + RESET) os.Exit(1) diff --git a/cmd/subcmd_generate_codetable_type.go b/cmd/subcmd_generate_codetable_type.go new file mode 100644 index 0000000..67c0331 --- /dev/null +++ b/cmd/subcmd_generate_codetable_type.go @@ -0,0 +1,55 @@ +package main + +import ( + "fmt" + "go/ast" + "os" + + "git.offline-twitter.com/offline-labs/gas-stack/pkg/codegen/modelgenerate" + . "git.offline-twitter.com/offline-labs/gas-stack/pkg/flowutils" + "git.offline-twitter.com/offline-labs/gas-stack/pkg/schema" + "github.com/spf13/cobra" +) + +var generate_codetable_type = &cobra.Command{ + Use: "generate_codetable ", + Short: "Generate a code-table enum type", + + Args: cobra.ExactArgs(1), + + RunE: func(cmd *cobra.Command, args []string) error { + path := Must(cmd.Flags().GetString("schema")) + sql, err := os.ReadFile(path) + if err != nil { + return fmt.Errorf("reading path %s: %w", path, err) + } + db := schema.InitDB(string(sql)) + schema := schema.SchemaFromDB(db) + table, isOk := schema.Tables[args[0]] + if !isOk { + return ErrNoSuchTable + } + + vals := table.GetCodeTableValues(db) + decls := []ast.Decl{ + modelgenerate.GenerateCodetableType(table), + modelgenerate.GenerateCodetableEnum(table, vals), + modelgenerate.GenerateCodetableStringerFunc(table, vals), + } + + file := &ast.File{ + Name: ast.NewIdent("db"), // TODO: parameterize + Decls: decls, + } + + PanicIf(modelgenerate.FprintWithComments(os.Stdout, file)) + return nil + }, +} + +// DUPE: generate-flags +func init() { + generate_codetable_type.Flags().String("schema", "pkg/db/schema.sql", "Path to SQL schema file") + generate_codetable_type.Flags().String("modname", "mymodule", "Name of project's Go module (TODO: detect automatically)") + generate_codetable_type.Flags().Bool("test", false, "Generate test file instead of regular file") +} diff --git a/cmd/subcmd_generate_models.go b/cmd/subcmd_generate_models.go index f11dcba..7f1482c 100644 --- a/cmd/subcmd_generate_models.go +++ b/cmd/subcmd_generate_models.go @@ -101,6 +101,7 @@ var generate_model = &cobra.Command{ }, } +// DUPE: generate-flags func init() { generate_model.Flags().String("schema", "pkg/db/schema.sql", "Path to SQL schema file") generate_model.Flags().String("modname", "mymodule", "Name of project's Go module (TODO: detect automatically)") diff --git a/pkg/codegen/modelgenerate/generate_codetable_type.go b/pkg/codegen/modelgenerate/generate_codetable_type.go new file mode 100644 index 0000000..c3f43d3 --- /dev/null +++ b/pkg/codegen/modelgenerate/generate_codetable_type.go @@ -0,0 +1,111 @@ +package modelgenerate + +import ( + "fmt" + "go/ast" + "go/token" + + "git.offline-twitter.com/offline-labs/gas-stack/pkg/schema" + "git.offline-twitter.com/offline-labs/gas-stack/pkg/textutils" +) + +func GenerateCodetableType(table schema.Table) *ast.GenDecl { + return &ast.GenDecl{ + Tok: token.TYPE, + Specs: []ast.Spec{&ast.TypeSpec{Name: &ast.Ident{Name: table.GoTypeName}, Type: &ast.Ident{Name: "int"}}}, + } +} + +func GenerateCodetableEnum(table schema.Table, vals []string) *ast.GenDecl { + getConstName := func(s string) string { + return table.GoTypeName + textutils.KebabToPascal(s) + } + + constSpecs := []ast.Spec{} + for i, val := range vals { + spec := &ast.ValueSpec{ + Names: []*ast.Ident{{Name: getConstName(val)}}, + } + // Only the first one needs `iota` + if i == 0 { + spec.Type = &ast.Ident{Name: table.GoTypeName} + spec.Values = []ast.Expr{&ast.BinaryExpr{ + X: &ast.Ident{Name: "iota"}, + Op: token.ADD, + Y: &ast.BasicLit{Kind: token.INT, Value: "1"}, + }} + } + constSpecs = append(constSpecs, spec) + } + + return &ast.GenDecl{Tok: token.CONST, Specs: constSpecs} +} + +// GenerateCodetableStringerFunc implements the `Stringer` interface by defining a `String() string` function. +func GenerateCodetableStringerFunc(table schema.Table, vals []string) *ast.FuncDecl { + objIdent := ast.NewIdent(table.VarName) + return &ast.FuncDecl{ + Recv: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{objIdent}, Type: &ast.Ident{Name: table.GoTypeName}}}}, + Name: &ast.Ident{Name: "String"}, + Type: &ast.FuncType{Params: &ast.FieldList{}, Results: &ast.FieldList{List: []*ast.Field{{Type: &ast.Ident{Name: "string"}}}}}, + Body: &ast.BlockStmt{ + List: []ast.Stmt{ + // names := []string{ ... } + &ast.AssignStmt{ + Lhs: []ast.Expr{&ast.Ident{Name: "names"}}, + Tok: token.DEFINE, + Rhs: []ast.Expr{ + &ast.CompositeLit{ + Type: &ast.ArrayType{Elt: &ast.Ident{Name: "string"}}, + Elts: func() []ast.Expr { + ret := []ast.Expr{ + &ast.BasicLit{Kind: token.STRING, Value: `""`}, + } + for _, val := range vals { + ret = append(ret, &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("%q", val)}) + } + return ret + }(), + }, + }, + }, + + // if int(c) < 1 || int(c) >= len(names) { return "invalid" } + &ast.IfStmt{ + Cond: &ast.BinaryExpr{ + X: &ast.BinaryExpr{ + X: &ast.CallExpr{Fun: &ast.Ident{Name: "int"}, Args: []ast.Expr{objIdent}}, + Op: token.LSS, + Y: &ast.BasicLit{Kind: token.INT, Value: "1"}, + }, + Op: token.LOR, + Y: &ast.BinaryExpr{ + X: &ast.CallExpr{Fun: &ast.Ident{Name: "int"}, Args: []ast.Expr{objIdent}}, + Op: token.GEQ, + Y: &ast.CallExpr{Fun: &ast.Ident{Name: "len"}, Args: []ast.Expr{&ast.Ident{Name: "names"}}}, + }, + }, + Body: &ast.BlockStmt{ + List: []ast.Stmt{ + &ast.ReturnStmt{Results: []ast.Expr{ + &ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent("fmt"), Sel: ast.NewIdent("Sprintf")}, Args: []ast.Expr{ + &ast.BasicLit{Kind: token.STRING, Value: `"<%d=invalid>"`}, + objIdent, + }}, + }}, + }, + }, + }, + // return names[c] + &ast.ReturnStmt{ + Results: []ast.Expr{ + &ast.IndexExpr{ + X: &ast.Ident{Name: "names"}, + Index: objIdent, + }, + }, + }, + }, + }, + } +} diff --git a/pkg/schema/column.go b/pkg/schema/column.go index e614465..b94685d 100644 --- a/pkg/schema/column.go +++ b/pkg/schema/column.go @@ -8,9 +8,15 @@ import ( // Column represents a single column in a table. type Column struct { - TableName string `db:"table_name"` - Name string `db:"column_name"` - Type string `db:"column_type"` + // TableName is the name of the SQLite table this column belongs to. + TableName string `db:"table_name"` + + // Name is the SQLite column name. + Name string `db:"column_name"` + + // Type is the SQLite type this column contains. + Type string `db:"column_type"` + IsNotNull bool `db:"notnull"` HasDefaultValue bool `db:"has_default_value"` DefaultValue string `db:"dflt_value"` diff --git a/pkg/schema/table.go b/pkg/schema/table.go index 88c9722..5b41361 100644 --- a/pkg/schema/table.go +++ b/pkg/schema/table.go @@ -1,7 +1,12 @@ package schema import ( + "fmt" + "slices" "sort" + + "git.offline-twitter.com/offline-labs/gas-stack/pkg/flowutils" + "github.com/jmoiron/sqlx" ) // Table is a single SQLite table. @@ -61,3 +66,11 @@ func (t Table) HasAutoTimestamps() (hasCreatedAt bool, hasUpdatedAt bool) { } return } + +func (t Table) GetCodeTableValues(db *sqlx.DB) (ret []string) { + if !slices.ContainsFunc(t.Columns, func(c Column) bool { return c.Name == "name" }) { + panic("not a code table") + } + flowutils.PanicIf(db.Select(&ret, fmt.Sprintf("select name from %s", t.TableName))) + return +} diff --git a/pkg/textutils/snake.go b/pkg/textutils/snake.go index 5b6187b..c113202 100644 --- a/pkg/textutils/snake.go +++ b/pkg/textutils/snake.go @@ -10,6 +10,16 @@ func SnakeToCamel(s string) string { return strings.Join(parts, "") } +func KebabToPascal(s string) string { + parts := strings.Split(s, "-") + for i, part := range parts { + if len(part) > 0 { + parts[i] = strings.ToUpper(part[:1]) + part[1:] + } + } + return strings.Join(parts, "") +} + func CamelToPascal(s string) string { return strings.ToLower(s)[0:1] + s[1:] } diff --git a/sample_data/test_schemas/codetables.sql b/sample_data/test_schemas/codetables.sql new file mode 100644 index 0000000..26c40c3 --- /dev/null +++ b/sample_data/test_schemas/codetables.sql @@ -0,0 +1,8 @@ +create table item_types ( + rowid integer primary key, + name text not null unique +) strict; +insert into item_types(rowid, name) values + (1, 'first-type'), + (2, 'second-type'), + (3, 'third-type');