diff --git a/cmd/subcmd_generate_models.go b/cmd/subcmd_generate_models.go index 2e634b6..b7a1bf6 100644 --- a/cmd/subcmd_generate_models.go +++ b/cmd/subcmd_generate_models.go @@ -25,19 +25,19 @@ var generate_model = &cobra.Command{ RunE: func(cmd *cobra.Command, args []string) error { path := Must(cmd.Flags().GetString("schema")) modname := Must(cmd.Flags().GetString("modname")) - schema_sql, err := os.ReadFile(path) + sql, err := os.ReadFile(path) if err != nil { return fmt.Errorf("reading path %s: %w", path, err) } - db := schema.InitDB(string(schema_sql)) - tables := schema.SchemaFromDB(db).Tables - table, isOk := tables[args[0]] + db := schema.InitDB(string(sql)) + schema := schema.SchemaFromDB(db) + table, isOk := schema.Tables[args[0]] if !isOk { return ErrNoSuchTable } if Must(cmd.Flags().GetBool("test")) { - file2 := modelgenerate.GenerateModelTestAST(table, modname) + file2 := modelgenerate.GenerateModelTestAST(table, schema, modname) PanicIf(modelgenerate.FprintWithComments(os.Stdout, file2)) } else { decls := []ast.Decl{ @@ -70,8 +70,17 @@ var generate_model = &cobra.Command{ modelgenerate.GenerateSaveItemFunc(table), modelgenerate.GenerateDeleteItemFunc(table), modelgenerate.GenerateGetItemByIDFunc(table), - modelgenerate.GenerateGetAllItemsFunc(table), ) + for _, index := range schema.Indexes { + if index.TableName != table.TableName { + // Skip indexes on other tables + continue + } + if index.IsUnique && len(index.Columns) == 1 { + decls = append(decls, modelgenerate.GenerateGetItemByUniqColFunc(table, table.GetColumnByName(index.Columns[0]))) + } + } + decls = append(decls, modelgenerate.GenerateGetAllItemsFunc(table)) file := &ast.File{ Name: ast.NewIdent("db"), // TODO: parameterize diff --git a/ops/gas_init_test.sh b/ops/gas_init_test.sh index 483a1c1..50d8c90 100755 --- a/ops/gas_init_test.sh +++ b/ops/gas_init_test.sh @@ -41,6 +41,7 @@ create table items ( rowid integer primary key, description text not null default '', flavor integer references item_flavor(rowid), + thing text not null unique, created_at integer not null, updated_at integer not null ) strict; diff --git a/pkg/codegen/modelgenerate/generate_model.go b/pkg/codegen/modelgenerate/generate_model.go index bb96ee3..166b72c 100644 --- a/pkg/codegen/modelgenerate/generate_model.go +++ b/pkg/codegen/modelgenerate/generate_model.go @@ -63,7 +63,7 @@ func GenerateModelAST(table schema.Table) *ast.GenDecl { Tag: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("`db:\"%s\" json:\"%s\"`", col.Name, col.Name)}, }) } else { - typeName := col.GoType() + typeName := col.GoTypeName() fields = append(fields, &ast.Field{ Names: []*ast.Ident{ast.NewIdent(textutils.SnakeToCamel(col.Name))}, Type: ast.NewIdent(typeName), @@ -457,6 +457,51 @@ func GenerateGetItemByIDFunc(tbl schema.Table) *ast.FuncDecl { return funcDecl } +// GenerateGetItemByIDFunc produces an AST for the `GetXyzByID()` function. +// E.g., a table with `table.TypeName = "foods"` will produce a "GetFoodByID()" function. +func GenerateGetItemByUniqColFunc(tbl schema.Table, col schema.Column) *ast.FuncDecl { + // Use the xyzSQLFields constant in the select query + selectExpr := &ast.BinaryExpr{ + X: &ast.BinaryExpr{ + X: &ast.BasicLit{Kind: token.STRING, Value: "`\n\t select `"}, + Op: token.ADD, + Y: SQLFieldsConstIdent(tbl), + }, + Op: token.ADD, + Y: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("`\n\t from %s\n\t where %s = ?\n\t`", tbl.TableName, col.Name)}, + } + + param := ast.NewIdent(col.GoVarName()) + + return &ast.FuncDecl{ + Recv: dbRecv, + Name: ast.NewIdent("Get" + schema.TypenameFromTablename(tbl.TableName) + "By" + col.GoFieldName()), + Type: &ast.FuncType{ + Params: &ast.FieldList{List: []*ast.Field{ + {Names: []*ast.Ident{param}, Type: ast.NewIdent(col.GoTypeName())}, + }}, + Results: &ast.FieldList{List: []*ast.Field{ + {Names: []*ast.Ident{ast.NewIdent("ret")}, Type: ast.NewIdent(tbl.GoTypeName)}, + {Names: []*ast.Ident{ast.NewIdent("err")}, Type: ast.NewIdent("error")}, + }}, + }, + Body: &ast.BlockStmt{ + List: []ast.Stmt{ + &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent("err")}, + Tok: token.ASSIGN, + Rhs: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: dbDB, Sel: ast.NewIdent("Get")}, Args: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: ast.NewIdent("ret")}, selectExpr, param}}}, + }, + &ast.IfStmt{ + Cond: &ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent("errors"), Sel: ast.NewIdent("Is")}, Args: []ast.Expr{ast.NewIdent("err"), &ast.SelectorExpr{X: ast.NewIdent("sql"), Sel: ast.NewIdent("ErrNoRows")}}}, + Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ReturnStmt{Results: []ast.Expr{&ast.CompositeLit{Type: ast.NewIdent(tbl.GoTypeName)}, ast.NewIdent("ErrNotInDB")}}}}, + }, + &ast.ReturnStmt{}, + }, + }, + } +} + // GenerateGetAllItemsFunc produces an AST for the `GetAllXyzs()` function. // E.g., a table with `table.TypeName = "foods"` will produce a "GetAllFoods()" function. func GenerateGetAllItemsFunc(tbl schema.Table) *ast.FuncDecl { diff --git a/pkg/codegen/modelgenerate/generate_testfile.go b/pkg/codegen/modelgenerate/generate_testfile.go index dbc4e9d..1ca760b 100644 --- a/pkg/codegen/modelgenerate/generate_testfile.go +++ b/pkg/codegen/modelgenerate/generate_testfile.go @@ -7,11 +7,11 @@ import ( "github.com/jinzhu/inflection" - "git.offline-twitter.com/offline-labs/gas-stack/pkg/schema" + pkgschema "git.offline-twitter.com/offline-labs/gas-stack/pkg/schema" ) // GenerateModelTestAST produces an AST for a starter test file for a given model. -func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File { +func GenerateModelTestAST(tbl pkgschema.Table, schema pkgschema.Schema, gomodName string) *ast.File { packageName := "db" testpackageName := packageName + "_test" @@ -44,6 +44,7 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File { fieldName := ast.NewIdent("Description") description1 := `"an item"` description2 := `"a big item"` + testDB := ast.NewIdent("TestDB") hasCreatedAt, hasUpdatedAt := tbl.HasAutoTimestamps() @@ -69,6 +70,8 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File { } stmts := []ast.Stmt{ + Comment("Create"), + // item := Item{Description: "an item"} &ast.AssignStmt{ Lhs: []ast.Expr{testObj}, @@ -86,7 +89,7 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File { // TestDB.SaveItem(&item) &ast.ExprStmt{X: &ast.CallExpr{ - Fun: &ast.SelectorExpr{X: ast.NewIdent("TestDB"), Sel: ast.NewIdent("Save" + tbl.GoTypeName)}, + Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Save" + tbl.GoTypeName)}, Args: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: testObj}}, }}, @@ -106,12 +109,15 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File { } stmts = append(stmts, + BlankLine(), + Comment("Load"), + // item2 := Must(TestDB.GetItemByID(item.ID)) &ast.AssignStmt{ Lhs: []ast.Expr{testObj2}, Tok: token.DEFINE, Rhs: []ast.Expr{mustCall(&ast.CallExpr{ - Fun: &ast.SelectorExpr{X: ast.NewIdent("TestDB"), Sel: ast.NewIdent("Get" + tbl.GoTypeName + "ByID")}, + Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Get" + tbl.GoTypeName + "ByID")}, Args: []ast.Expr{&ast.SelectorExpr{X: testObj, Sel: ast.NewIdent("ID")}}, })}, }, @@ -125,6 +131,11 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File { &ast.SelectorExpr{X: testObj2, Sel: fieldName}, }, }}, + ) + + stmts = append(stmts, + BlankLine(), + Comment("Update"), // item.Description = "a big item" &ast.AssignStmt{ @@ -135,18 +146,16 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File { // TestDB.SaveItem(&item) &ast.ExprStmt{X: &ast.CallExpr{ - Fun: &ast.SelectorExpr{X: ast.NewIdent("TestDB"), Sel: ast.NewIdent("Save" + tbl.GoTypeName)}, + Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Save" + tbl.GoTypeName)}, Args: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: testObj}}, }}, - ) - stmts = append(stmts, // item2 = Must(TestDB.GetItemByID(item.ID)) &ast.AssignStmt{ Lhs: []ast.Expr{testObj2}, Tok: token.ASSIGN, Rhs: []ast.Expr{mustCall(&ast.CallExpr{ - Fun: &ast.SelectorExpr{X: ast.NewIdent("TestDB"), Sel: ast.NewIdent("Get" + tbl.GoTypeName + "ByID")}, + Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Get" + tbl.GoTypeName + "ByID")}, Args: []ast.Expr{&ast.SelectorExpr{X: testObj, Sel: ast.NewIdent("ID")}}, })}, }, @@ -160,10 +169,50 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File { &ast.SelectorExpr{X: testObj2, Sel: fieldName}, }, }}, + ) + + indexGets, hasIndexedGets := []ast.Stmt{ + BlankLine(), + Comment("Indexed lookups"), + }, false + + for _, index := range schema.Indexes { + if index.TableName != tbl.TableName { + // Skip indexes on other tables + continue + } + if index.IsUnique && len(index.Columns) == 1 { + col := tbl.GetColumnByName(index.Columns[0]) + indexGets = append(indexGets, []ast.Stmt{ + // assert.Equal(t, item2, TestDB.GetItemByXYZ(...)) + &ast.ExprStmt{X: &ast.CallExpr{ // TODO: what if just delete the "ExprStmt" wrapper? + Fun: &ast.SelectorExpr{X: ast.NewIdent("assert"), Sel: ast.NewIdent("Equal")}, + Args: []ast.Expr{ + ast.NewIdent("t"), + testObj2, + mustCall(&ast.CallExpr{ + Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Get" + pkgschema.TypenameFromTablename(tbl.TableName) + "By" + col.GoFieldName())}, + Args: []ast.Expr{&ast.SelectorExpr{X: testObj2, Sel: ast.NewIdent(col.GoFieldName())}}, + }), + }, + }}, + }...) + // decls = append(decls, modelgenerate.GenerateGetItemByUniqColFunc(table, table.GetColumnByName(index.Columns[0]))) + hasIndexedGets = true + } + } + + if hasIndexedGets { + stmts = append(stmts, indexGets...) + } + + stmts = append(stmts, + BlankLine(), + Comment("Delete"), // TestDB.DeleteItem(item) &ast.ExprStmt{X: &ast.CallExpr{ - Fun: &ast.SelectorExpr{X: ast.NewIdent("TestDB"), Sel: ast.NewIdent("Delete" + tbl.GoTypeName)}, + Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Delete" + tbl.GoTypeName)}, Args: []ast.Expr{testObj}, }}, @@ -172,7 +221,7 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File { Lhs: []ast.Expr{ast.NewIdent("_"), ast.NewIdent("err")}, Tok: token.DEFINE, Rhs: []ast.Expr{&ast.CallExpr{ - Fun: &ast.SelectorExpr{X: ast.NewIdent("TestDB"), Sel: ast.NewIdent("Get" + tbl.GoTypeName + "ByID")}, + Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Get" + tbl.GoTypeName + "ByID")}, Args: []ast.Expr{&ast.SelectorExpr{X: testObj, Sel: ast.NewIdent("ID")}}, }}, }, @@ -203,7 +252,7 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File { Tok: token.ASSIGN, Rhs: []ast.Expr{&ast.CallExpr{ Fun: &ast.SelectorExpr{ - X: ast.NewIdent("TestDB"), + X: testDB, Sel: ast.NewIdent("GetAll" + inflection.Plural(tbl.GoTypeName)), }, }}, @@ -259,7 +308,7 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File { Rhs: []ast.Expr{ &ast.CallExpr{ Fun: &ast.SelectorExpr{ - X: ast.NewIdent("TestDB"), + X: testDB, Sel: ast.NewIdent("Save" + tbl.GoTypeName), }, Args: []ast.Expr{ diff --git a/pkg/schema/table.go b/pkg/schema/table.go index 248d1ac..5e3a988 100644 --- a/pkg/schema/table.go +++ b/pkg/schema/table.go @@ -41,7 +41,23 @@ func (c Column) GoFieldName() string { return textutils.SnakeToCamel(c.Name) } -func (c Column) GoType() string { +// GoVarName returns the name of a local variable for this column, e.g., when used as a function parameter. +func (c Column) GoVarName() string { + if c.Name == "rowid" { + return strings.ToLower(c.TableName)[0:1] + "ID" + // TODO: Or should it just be "id"?? + } + // For foreign keys, use first letter of the target type and "ID". "UserID" => "uID" + if c.IsNonCodeTableForeignKey() { + return strings.ToLower(c.ForeignKeyTargetTable)[0:1] + "ID" + } + + // Otherwise, just lowercase the field name + fieldname := c.GoFieldName() + return strings.ToLower(fieldname)[0:1] + fieldname[1:] +} + +func (c Column) GoTypeName() string { if c.IsNonCodeTableForeignKey() { return TypenameFromTablename(c.ForeignKeyTargetTable) + "ID" } @@ -101,6 +117,15 @@ func (t Table) PrimaryKeyColumns() []Column { return pks } +func (t Table) GetColumnByName(name string) Column { + for _, c := range t.Columns { + if c.Name == name { + return c + } + } + panic("no such column: " + name) +} + func (t Table) HasAutoTimestamps() (hasCreatedAt bool, hasUpdatedAt bool) { for _, c := range t.Columns { if c.Name == "created_at" && c.Type == "integer" {