codegen: add unique index lookup funcs

This commit is contained in:
wispem-wantex 2026-02-19 21:40:02 -08:00
parent d6426bba14
commit 3e0a85fcfa
5 changed files with 149 additions and 20 deletions

View File

@ -25,19 +25,19 @@ var generate_model = &cobra.Command{
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
path := Must(cmd.Flags().GetString("schema")) path := Must(cmd.Flags().GetString("schema"))
modname := Must(cmd.Flags().GetString("modname")) modname := Must(cmd.Flags().GetString("modname"))
schema_sql, err := os.ReadFile(path) sql, err := os.ReadFile(path)
if err != nil { if err != nil {
return fmt.Errorf("reading path %s: %w", path, err) return fmt.Errorf("reading path %s: %w", path, err)
} }
db := schema.InitDB(string(schema_sql)) db := schema.InitDB(string(sql))
tables := schema.SchemaFromDB(db).Tables schema := schema.SchemaFromDB(db)
table, isOk := tables[args[0]] table, isOk := schema.Tables[args[0]]
if !isOk { if !isOk {
return ErrNoSuchTable return ErrNoSuchTable
} }
if Must(cmd.Flags().GetBool("test")) { if Must(cmd.Flags().GetBool("test")) {
file2 := modelgenerate.GenerateModelTestAST(table, modname) file2 := modelgenerate.GenerateModelTestAST(table, schema, modname)
PanicIf(modelgenerate.FprintWithComments(os.Stdout, file2)) PanicIf(modelgenerate.FprintWithComments(os.Stdout, file2))
} else { } else {
decls := []ast.Decl{ decls := []ast.Decl{
@ -70,8 +70,17 @@ var generate_model = &cobra.Command{
modelgenerate.GenerateSaveItemFunc(table), modelgenerate.GenerateSaveItemFunc(table),
modelgenerate.GenerateDeleteItemFunc(table), modelgenerate.GenerateDeleteItemFunc(table),
modelgenerate.GenerateGetItemByIDFunc(table), modelgenerate.GenerateGetItemByIDFunc(table),
modelgenerate.GenerateGetAllItemsFunc(table),
) )
for _, index := range schema.Indexes {
if index.TableName != table.TableName {
// Skip indexes on other tables
continue
}
if index.IsUnique && len(index.Columns) == 1 {
decls = append(decls, modelgenerate.GenerateGetItemByUniqColFunc(table, table.GetColumnByName(index.Columns[0])))
}
}
decls = append(decls, modelgenerate.GenerateGetAllItemsFunc(table))
file := &ast.File{ file := &ast.File{
Name: ast.NewIdent("db"), // TODO: parameterize Name: ast.NewIdent("db"), // TODO: parameterize

View File

@ -41,6 +41,7 @@ create table items (
rowid integer primary key, rowid integer primary key,
description text not null default '', description text not null default '',
flavor integer references item_flavor(rowid), flavor integer references item_flavor(rowid),
thing text not null unique,
created_at integer not null, created_at integer not null,
updated_at integer not null updated_at integer not null
) strict; ) strict;

View File

@ -63,7 +63,7 @@ func GenerateModelAST(table schema.Table) *ast.GenDecl {
Tag: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("`db:\"%s\" json:\"%s\"`", col.Name, col.Name)}, Tag: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("`db:\"%s\" json:\"%s\"`", col.Name, col.Name)},
}) })
} else { } else {
typeName := col.GoType() typeName := col.GoTypeName()
fields = append(fields, &ast.Field{ fields = append(fields, &ast.Field{
Names: []*ast.Ident{ast.NewIdent(textutils.SnakeToCamel(col.Name))}, Names: []*ast.Ident{ast.NewIdent(textutils.SnakeToCamel(col.Name))},
Type: ast.NewIdent(typeName), Type: ast.NewIdent(typeName),
@ -457,6 +457,51 @@ func GenerateGetItemByIDFunc(tbl schema.Table) *ast.FuncDecl {
return funcDecl return funcDecl
} }
// GenerateGetItemByIDFunc produces an AST for the `GetXyzByID()` function.
// E.g., a table with `table.TypeName = "foods"` will produce a "GetFoodByID()" function.
func GenerateGetItemByUniqColFunc(tbl schema.Table, col schema.Column) *ast.FuncDecl {
// Use the xyzSQLFields constant in the select query
selectExpr := &ast.BinaryExpr{
X: &ast.BinaryExpr{
X: &ast.BasicLit{Kind: token.STRING, Value: "`\n\t select `"},
Op: token.ADD,
Y: SQLFieldsConstIdent(tbl),
},
Op: token.ADD,
Y: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("`\n\t from %s\n\t where %s = ?\n\t`", tbl.TableName, col.Name)},
}
param := ast.NewIdent(col.GoVarName())
return &ast.FuncDecl{
Recv: dbRecv,
Name: ast.NewIdent("Get" + schema.TypenameFromTablename(tbl.TableName) + "By" + col.GoFieldName()),
Type: &ast.FuncType{
Params: &ast.FieldList{List: []*ast.Field{
{Names: []*ast.Ident{param}, Type: ast.NewIdent(col.GoTypeName())},
}},
Results: &ast.FieldList{List: []*ast.Field{
{Names: []*ast.Ident{ast.NewIdent("ret")}, Type: ast.NewIdent(tbl.GoTypeName)},
{Names: []*ast.Ident{ast.NewIdent("err")}, Type: ast.NewIdent("error")},
}},
},
Body: &ast.BlockStmt{
List: []ast.Stmt{
&ast.AssignStmt{
Lhs: []ast.Expr{ast.NewIdent("err")},
Tok: token.ASSIGN,
Rhs: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: dbDB, Sel: ast.NewIdent("Get")}, Args: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: ast.NewIdent("ret")}, selectExpr, param}}},
},
&ast.IfStmt{
Cond: &ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent("errors"), Sel: ast.NewIdent("Is")}, Args: []ast.Expr{ast.NewIdent("err"), &ast.SelectorExpr{X: ast.NewIdent("sql"), Sel: ast.NewIdent("ErrNoRows")}}},
Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ReturnStmt{Results: []ast.Expr{&ast.CompositeLit{Type: ast.NewIdent(tbl.GoTypeName)}, ast.NewIdent("ErrNotInDB")}}}},
},
&ast.ReturnStmt{},
},
},
}
}
// GenerateGetAllItemsFunc produces an AST for the `GetAllXyzs()` function. // GenerateGetAllItemsFunc produces an AST for the `GetAllXyzs()` function.
// E.g., a table with `table.TypeName = "foods"` will produce a "GetAllFoods()" function. // E.g., a table with `table.TypeName = "foods"` will produce a "GetAllFoods()" function.
func GenerateGetAllItemsFunc(tbl schema.Table) *ast.FuncDecl { func GenerateGetAllItemsFunc(tbl schema.Table) *ast.FuncDecl {

View File

@ -7,11 +7,11 @@ import (
"github.com/jinzhu/inflection" "github.com/jinzhu/inflection"
"git.offline-twitter.com/offline-labs/gas-stack/pkg/schema" pkgschema "git.offline-twitter.com/offline-labs/gas-stack/pkg/schema"
) )
// GenerateModelTestAST produces an AST for a starter test file for a given model. // GenerateModelTestAST produces an AST for a starter test file for a given model.
func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File { func GenerateModelTestAST(tbl pkgschema.Table, schema pkgschema.Schema, gomodName string) *ast.File {
packageName := "db" packageName := "db"
testpackageName := packageName + "_test" testpackageName := packageName + "_test"
@ -44,6 +44,7 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File {
fieldName := ast.NewIdent("Description") fieldName := ast.NewIdent("Description")
description1 := `"an item"` description1 := `"an item"`
description2 := `"a big item"` description2 := `"a big item"`
testDB := ast.NewIdent("TestDB")
hasCreatedAt, hasUpdatedAt := tbl.HasAutoTimestamps() hasCreatedAt, hasUpdatedAt := tbl.HasAutoTimestamps()
@ -69,6 +70,8 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File {
} }
stmts := []ast.Stmt{ stmts := []ast.Stmt{
Comment("Create"),
// item := Item{Description: "an item"} // item := Item{Description: "an item"}
&ast.AssignStmt{ &ast.AssignStmt{
Lhs: []ast.Expr{testObj}, Lhs: []ast.Expr{testObj},
@ -86,7 +89,7 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File {
// TestDB.SaveItem(&item) // TestDB.SaveItem(&item)
&ast.ExprStmt{X: &ast.CallExpr{ &ast.ExprStmt{X: &ast.CallExpr{
Fun: &ast.SelectorExpr{X: ast.NewIdent("TestDB"), Sel: ast.NewIdent("Save" + tbl.GoTypeName)}, Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Save" + tbl.GoTypeName)},
Args: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: testObj}}, Args: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: testObj}},
}}, }},
@ -106,12 +109,15 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File {
} }
stmts = append(stmts, stmts = append(stmts,
BlankLine(),
Comment("Load"),
// item2 := Must(TestDB.GetItemByID(item.ID)) // item2 := Must(TestDB.GetItemByID(item.ID))
&ast.AssignStmt{ &ast.AssignStmt{
Lhs: []ast.Expr{testObj2}, Lhs: []ast.Expr{testObj2},
Tok: token.DEFINE, Tok: token.DEFINE,
Rhs: []ast.Expr{mustCall(&ast.CallExpr{ Rhs: []ast.Expr{mustCall(&ast.CallExpr{
Fun: &ast.SelectorExpr{X: ast.NewIdent("TestDB"), Sel: ast.NewIdent("Get" + tbl.GoTypeName + "ByID")}, Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Get" + tbl.GoTypeName + "ByID")},
Args: []ast.Expr{&ast.SelectorExpr{X: testObj, Sel: ast.NewIdent("ID")}}, Args: []ast.Expr{&ast.SelectorExpr{X: testObj, Sel: ast.NewIdent("ID")}},
})}, })},
}, },
@ -125,6 +131,11 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File {
&ast.SelectorExpr{X: testObj2, Sel: fieldName}, &ast.SelectorExpr{X: testObj2, Sel: fieldName},
}, },
}}, }},
)
stmts = append(stmts,
BlankLine(),
Comment("Update"),
// item.Description = "a big item" // item.Description = "a big item"
&ast.AssignStmt{ &ast.AssignStmt{
@ -135,18 +146,16 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File {
// TestDB.SaveItem(&item) // TestDB.SaveItem(&item)
&ast.ExprStmt{X: &ast.CallExpr{ &ast.ExprStmt{X: &ast.CallExpr{
Fun: &ast.SelectorExpr{X: ast.NewIdent("TestDB"), Sel: ast.NewIdent("Save" + tbl.GoTypeName)}, Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Save" + tbl.GoTypeName)},
Args: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: testObj}}, Args: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: testObj}},
}}, }},
)
stmts = append(stmts,
// item2 = Must(TestDB.GetItemByID(item.ID)) // item2 = Must(TestDB.GetItemByID(item.ID))
&ast.AssignStmt{ &ast.AssignStmt{
Lhs: []ast.Expr{testObj2}, Lhs: []ast.Expr{testObj2},
Tok: token.ASSIGN, Tok: token.ASSIGN,
Rhs: []ast.Expr{mustCall(&ast.CallExpr{ Rhs: []ast.Expr{mustCall(&ast.CallExpr{
Fun: &ast.SelectorExpr{X: ast.NewIdent("TestDB"), Sel: ast.NewIdent("Get" + tbl.GoTypeName + "ByID")}, Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Get" + tbl.GoTypeName + "ByID")},
Args: []ast.Expr{&ast.SelectorExpr{X: testObj, Sel: ast.NewIdent("ID")}}, Args: []ast.Expr{&ast.SelectorExpr{X: testObj, Sel: ast.NewIdent("ID")}},
})}, })},
}, },
@ -160,10 +169,50 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File {
&ast.SelectorExpr{X: testObj2, Sel: fieldName}, &ast.SelectorExpr{X: testObj2, Sel: fieldName},
}, },
}}, }},
)
indexGets, hasIndexedGets := []ast.Stmt{
BlankLine(),
Comment("Indexed lookups"),
}, false
for _, index := range schema.Indexes {
if index.TableName != tbl.TableName {
// Skip indexes on other tables
continue
}
if index.IsUnique && len(index.Columns) == 1 {
col := tbl.GetColumnByName(index.Columns[0])
indexGets = append(indexGets, []ast.Stmt{
// assert.Equal(t, item2, TestDB.GetItemByXYZ(...))
&ast.ExprStmt{X: &ast.CallExpr{ // TODO: what if just delete the "ExprStmt" wrapper?
Fun: &ast.SelectorExpr{X: ast.NewIdent("assert"), Sel: ast.NewIdent("Equal")},
Args: []ast.Expr{
ast.NewIdent("t"),
testObj2,
mustCall(&ast.CallExpr{
Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Get" + pkgschema.TypenameFromTablename(tbl.TableName) + "By" + col.GoFieldName())},
Args: []ast.Expr{&ast.SelectorExpr{X: testObj2, Sel: ast.NewIdent(col.GoFieldName())}},
}),
},
}},
}...)
// decls = append(decls, modelgenerate.GenerateGetItemByUniqColFunc(table, table.GetColumnByName(index.Columns[0])))
hasIndexedGets = true
}
}
if hasIndexedGets {
stmts = append(stmts, indexGets...)
}
stmts = append(stmts,
BlankLine(),
Comment("Delete"),
// TestDB.DeleteItem(item) // TestDB.DeleteItem(item)
&ast.ExprStmt{X: &ast.CallExpr{ &ast.ExprStmt{X: &ast.CallExpr{
Fun: &ast.SelectorExpr{X: ast.NewIdent("TestDB"), Sel: ast.NewIdent("Delete" + tbl.GoTypeName)}, Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Delete" + tbl.GoTypeName)},
Args: []ast.Expr{testObj}, Args: []ast.Expr{testObj},
}}, }},
@ -172,7 +221,7 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File {
Lhs: []ast.Expr{ast.NewIdent("_"), ast.NewIdent("err")}, Lhs: []ast.Expr{ast.NewIdent("_"), ast.NewIdent("err")},
Tok: token.DEFINE, Tok: token.DEFINE,
Rhs: []ast.Expr{&ast.CallExpr{ Rhs: []ast.Expr{&ast.CallExpr{
Fun: &ast.SelectorExpr{X: ast.NewIdent("TestDB"), Sel: ast.NewIdent("Get" + tbl.GoTypeName + "ByID")}, Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Get" + tbl.GoTypeName + "ByID")},
Args: []ast.Expr{&ast.SelectorExpr{X: testObj, Sel: ast.NewIdent("ID")}}, Args: []ast.Expr{&ast.SelectorExpr{X: testObj, Sel: ast.NewIdent("ID")}},
}}, }},
}, },
@ -203,7 +252,7 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File {
Tok: token.ASSIGN, Tok: token.ASSIGN,
Rhs: []ast.Expr{&ast.CallExpr{ Rhs: []ast.Expr{&ast.CallExpr{
Fun: &ast.SelectorExpr{ Fun: &ast.SelectorExpr{
X: ast.NewIdent("TestDB"), X: testDB,
Sel: ast.NewIdent("GetAll" + inflection.Plural(tbl.GoTypeName)), Sel: ast.NewIdent("GetAll" + inflection.Plural(tbl.GoTypeName)),
}, },
}}, }},
@ -259,7 +308,7 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File {
Rhs: []ast.Expr{ Rhs: []ast.Expr{
&ast.CallExpr{ &ast.CallExpr{
Fun: &ast.SelectorExpr{ Fun: &ast.SelectorExpr{
X: ast.NewIdent("TestDB"), X: testDB,
Sel: ast.NewIdent("Save" + tbl.GoTypeName), Sel: ast.NewIdent("Save" + tbl.GoTypeName),
}, },
Args: []ast.Expr{ Args: []ast.Expr{

View File

@ -41,7 +41,23 @@ func (c Column) GoFieldName() string {
return textutils.SnakeToCamel(c.Name) return textutils.SnakeToCamel(c.Name)
} }
func (c Column) GoType() string { // GoVarName returns the name of a local variable for this column, e.g., when used as a function parameter.
func (c Column) GoVarName() string {
if c.Name == "rowid" {
return strings.ToLower(c.TableName)[0:1] + "ID"
// TODO: Or should it just be "id"??
}
// For foreign keys, use first letter of the target type and "ID". "UserID" => "uID"
if c.IsNonCodeTableForeignKey() {
return strings.ToLower(c.ForeignKeyTargetTable)[0:1] + "ID"
}
// Otherwise, just lowercase the field name
fieldname := c.GoFieldName()
return strings.ToLower(fieldname)[0:1] + fieldname[1:]
}
func (c Column) GoTypeName() string {
if c.IsNonCodeTableForeignKey() { if c.IsNonCodeTableForeignKey() {
return TypenameFromTablename(c.ForeignKeyTargetTable) + "ID" return TypenameFromTablename(c.ForeignKeyTargetTable) + "ID"
} }
@ -101,6 +117,15 @@ func (t Table) PrimaryKeyColumns() []Column {
return pks return pks
} }
func (t Table) GetColumnByName(name string) Column {
for _, c := range t.Columns {
if c.Name == name {
return c
}
}
panic("no such column: " + name)
}
func (t Table) HasAutoTimestamps() (hasCreatedAt bool, hasUpdatedAt bool) { func (t Table) HasAutoTimestamps() (hasCreatedAt bool, hasUpdatedAt bool) {
for _, c := range t.Columns { for _, c := range t.Columns {
if c.Name == "created_at" && c.Type == "integer" { if c.Name == "created_at" && c.Type == "integer" {