codegen: add unique index lookup funcs

This commit is contained in:
wispem-wantex 2026-02-19 21:40:02 -08:00
parent d6426bba14
commit 3e0a85fcfa
5 changed files with 149 additions and 20 deletions

View File

@ -25,19 +25,19 @@ var generate_model = &cobra.Command{
RunE: func(cmd *cobra.Command, args []string) error {
path := Must(cmd.Flags().GetString("schema"))
modname := Must(cmd.Flags().GetString("modname"))
schema_sql, err := os.ReadFile(path)
sql, err := os.ReadFile(path)
if err != nil {
return fmt.Errorf("reading path %s: %w", path, err)
}
db := schema.InitDB(string(schema_sql))
tables := schema.SchemaFromDB(db).Tables
table, isOk := tables[args[0]]
db := schema.InitDB(string(sql))
schema := schema.SchemaFromDB(db)
table, isOk := schema.Tables[args[0]]
if !isOk {
return ErrNoSuchTable
}
if Must(cmd.Flags().GetBool("test")) {
file2 := modelgenerate.GenerateModelTestAST(table, modname)
file2 := modelgenerate.GenerateModelTestAST(table, schema, modname)
PanicIf(modelgenerate.FprintWithComments(os.Stdout, file2))
} else {
decls := []ast.Decl{
@ -70,8 +70,17 @@ var generate_model = &cobra.Command{
modelgenerate.GenerateSaveItemFunc(table),
modelgenerate.GenerateDeleteItemFunc(table),
modelgenerate.GenerateGetItemByIDFunc(table),
modelgenerate.GenerateGetAllItemsFunc(table),
)
for _, index := range schema.Indexes {
if index.TableName != table.TableName {
// Skip indexes on other tables
continue
}
if index.IsUnique && len(index.Columns) == 1 {
decls = append(decls, modelgenerate.GenerateGetItemByUniqColFunc(table, table.GetColumnByName(index.Columns[0])))
}
}
decls = append(decls, modelgenerate.GenerateGetAllItemsFunc(table))
file := &ast.File{
Name: ast.NewIdent("db"), // TODO: parameterize

View File

@ -41,6 +41,7 @@ create table items (
rowid integer primary key,
description text not null default '',
flavor integer references item_flavor(rowid),
thing text not null unique,
created_at integer not null,
updated_at integer not null
) strict;

View File

@ -63,7 +63,7 @@ func GenerateModelAST(table schema.Table) *ast.GenDecl {
Tag: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("`db:\"%s\" json:\"%s\"`", col.Name, col.Name)},
})
} else {
typeName := col.GoType()
typeName := col.GoTypeName()
fields = append(fields, &ast.Field{
Names: []*ast.Ident{ast.NewIdent(textutils.SnakeToCamel(col.Name))},
Type: ast.NewIdent(typeName),
@ -457,6 +457,51 @@ func GenerateGetItemByIDFunc(tbl schema.Table) *ast.FuncDecl {
return funcDecl
}
// GenerateGetItemByIDFunc produces an AST for the `GetXyzByID()` function.
// E.g., a table with `table.TypeName = "foods"` will produce a "GetFoodByID()" function.
func GenerateGetItemByUniqColFunc(tbl schema.Table, col schema.Column) *ast.FuncDecl {
// Use the xyzSQLFields constant in the select query
selectExpr := &ast.BinaryExpr{
X: &ast.BinaryExpr{
X: &ast.BasicLit{Kind: token.STRING, Value: "`\n\t select `"},
Op: token.ADD,
Y: SQLFieldsConstIdent(tbl),
},
Op: token.ADD,
Y: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("`\n\t from %s\n\t where %s = ?\n\t`", tbl.TableName, col.Name)},
}
param := ast.NewIdent(col.GoVarName())
return &ast.FuncDecl{
Recv: dbRecv,
Name: ast.NewIdent("Get" + schema.TypenameFromTablename(tbl.TableName) + "By" + col.GoFieldName()),
Type: &ast.FuncType{
Params: &ast.FieldList{List: []*ast.Field{
{Names: []*ast.Ident{param}, Type: ast.NewIdent(col.GoTypeName())},
}},
Results: &ast.FieldList{List: []*ast.Field{
{Names: []*ast.Ident{ast.NewIdent("ret")}, Type: ast.NewIdent(tbl.GoTypeName)},
{Names: []*ast.Ident{ast.NewIdent("err")}, Type: ast.NewIdent("error")},
}},
},
Body: &ast.BlockStmt{
List: []ast.Stmt{
&ast.AssignStmt{
Lhs: []ast.Expr{ast.NewIdent("err")},
Tok: token.ASSIGN,
Rhs: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: dbDB, Sel: ast.NewIdent("Get")}, Args: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: ast.NewIdent("ret")}, selectExpr, param}}},
},
&ast.IfStmt{
Cond: &ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent("errors"), Sel: ast.NewIdent("Is")}, Args: []ast.Expr{ast.NewIdent("err"), &ast.SelectorExpr{X: ast.NewIdent("sql"), Sel: ast.NewIdent("ErrNoRows")}}},
Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ReturnStmt{Results: []ast.Expr{&ast.CompositeLit{Type: ast.NewIdent(tbl.GoTypeName)}, ast.NewIdent("ErrNotInDB")}}}},
},
&ast.ReturnStmt{},
},
},
}
}
// GenerateGetAllItemsFunc produces an AST for the `GetAllXyzs()` function.
// E.g., a table with `table.TypeName = "foods"` will produce a "GetAllFoods()" function.
func GenerateGetAllItemsFunc(tbl schema.Table) *ast.FuncDecl {

View File

@ -7,11 +7,11 @@ import (
"github.com/jinzhu/inflection"
"git.offline-twitter.com/offline-labs/gas-stack/pkg/schema"
pkgschema "git.offline-twitter.com/offline-labs/gas-stack/pkg/schema"
)
// GenerateModelTestAST produces an AST for a starter test file for a given model.
func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File {
func GenerateModelTestAST(tbl pkgschema.Table, schema pkgschema.Schema, gomodName string) *ast.File {
packageName := "db"
testpackageName := packageName + "_test"
@ -44,6 +44,7 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File {
fieldName := ast.NewIdent("Description")
description1 := `"an item"`
description2 := `"a big item"`
testDB := ast.NewIdent("TestDB")
hasCreatedAt, hasUpdatedAt := tbl.HasAutoTimestamps()
@ -69,6 +70,8 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File {
}
stmts := []ast.Stmt{
Comment("Create"),
// item := Item{Description: "an item"}
&ast.AssignStmt{
Lhs: []ast.Expr{testObj},
@ -86,7 +89,7 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File {
// TestDB.SaveItem(&item)
&ast.ExprStmt{X: &ast.CallExpr{
Fun: &ast.SelectorExpr{X: ast.NewIdent("TestDB"), Sel: ast.NewIdent("Save" + tbl.GoTypeName)},
Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Save" + tbl.GoTypeName)},
Args: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: testObj}},
}},
@ -106,12 +109,15 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File {
}
stmts = append(stmts,
BlankLine(),
Comment("Load"),
// item2 := Must(TestDB.GetItemByID(item.ID))
&ast.AssignStmt{
Lhs: []ast.Expr{testObj2},
Tok: token.DEFINE,
Rhs: []ast.Expr{mustCall(&ast.CallExpr{
Fun: &ast.SelectorExpr{X: ast.NewIdent("TestDB"), Sel: ast.NewIdent("Get" + tbl.GoTypeName + "ByID")},
Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Get" + tbl.GoTypeName + "ByID")},
Args: []ast.Expr{&ast.SelectorExpr{X: testObj, Sel: ast.NewIdent("ID")}},
})},
},
@ -125,6 +131,11 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File {
&ast.SelectorExpr{X: testObj2, Sel: fieldName},
},
}},
)
stmts = append(stmts,
BlankLine(),
Comment("Update"),
// item.Description = "a big item"
&ast.AssignStmt{
@ -135,18 +146,16 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File {
// TestDB.SaveItem(&item)
&ast.ExprStmt{X: &ast.CallExpr{
Fun: &ast.SelectorExpr{X: ast.NewIdent("TestDB"), Sel: ast.NewIdent("Save" + tbl.GoTypeName)},
Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Save" + tbl.GoTypeName)},
Args: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: testObj}},
}},
)
stmts = append(stmts,
// item2 = Must(TestDB.GetItemByID(item.ID))
&ast.AssignStmt{
Lhs: []ast.Expr{testObj2},
Tok: token.ASSIGN,
Rhs: []ast.Expr{mustCall(&ast.CallExpr{
Fun: &ast.SelectorExpr{X: ast.NewIdent("TestDB"), Sel: ast.NewIdent("Get" + tbl.GoTypeName + "ByID")},
Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Get" + tbl.GoTypeName + "ByID")},
Args: []ast.Expr{&ast.SelectorExpr{X: testObj, Sel: ast.NewIdent("ID")}},
})},
},
@ -160,10 +169,50 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File {
&ast.SelectorExpr{X: testObj2, Sel: fieldName},
},
}},
)
indexGets, hasIndexedGets := []ast.Stmt{
BlankLine(),
Comment("Indexed lookups"),
}, false
for _, index := range schema.Indexes {
if index.TableName != tbl.TableName {
// Skip indexes on other tables
continue
}
if index.IsUnique && len(index.Columns) == 1 {
col := tbl.GetColumnByName(index.Columns[0])
indexGets = append(indexGets, []ast.Stmt{
// assert.Equal(t, item2, TestDB.GetItemByXYZ(...))
&ast.ExprStmt{X: &ast.CallExpr{ // TODO: what if just delete the "ExprStmt" wrapper?
Fun: &ast.SelectorExpr{X: ast.NewIdent("assert"), Sel: ast.NewIdent("Equal")},
Args: []ast.Expr{
ast.NewIdent("t"),
testObj2,
mustCall(&ast.CallExpr{
Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Get" + pkgschema.TypenameFromTablename(tbl.TableName) + "By" + col.GoFieldName())},
Args: []ast.Expr{&ast.SelectorExpr{X: testObj2, Sel: ast.NewIdent(col.GoFieldName())}},
}),
},
}},
}...)
// decls = append(decls, modelgenerate.GenerateGetItemByUniqColFunc(table, table.GetColumnByName(index.Columns[0])))
hasIndexedGets = true
}
}
if hasIndexedGets {
stmts = append(stmts, indexGets...)
}
stmts = append(stmts,
BlankLine(),
Comment("Delete"),
// TestDB.DeleteItem(item)
&ast.ExprStmt{X: &ast.CallExpr{
Fun: &ast.SelectorExpr{X: ast.NewIdent("TestDB"), Sel: ast.NewIdent("Delete" + tbl.GoTypeName)},
Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Delete" + tbl.GoTypeName)},
Args: []ast.Expr{testObj},
}},
@ -172,7 +221,7 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File {
Lhs: []ast.Expr{ast.NewIdent("_"), ast.NewIdent("err")},
Tok: token.DEFINE,
Rhs: []ast.Expr{&ast.CallExpr{
Fun: &ast.SelectorExpr{X: ast.NewIdent("TestDB"), Sel: ast.NewIdent("Get" + tbl.GoTypeName + "ByID")},
Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Get" + tbl.GoTypeName + "ByID")},
Args: []ast.Expr{&ast.SelectorExpr{X: testObj, Sel: ast.NewIdent("ID")}},
}},
},
@ -203,7 +252,7 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File {
Tok: token.ASSIGN,
Rhs: []ast.Expr{&ast.CallExpr{
Fun: &ast.SelectorExpr{
X: ast.NewIdent("TestDB"),
X: testDB,
Sel: ast.NewIdent("GetAll" + inflection.Plural(tbl.GoTypeName)),
},
}},
@ -259,7 +308,7 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File {
Rhs: []ast.Expr{
&ast.CallExpr{
Fun: &ast.SelectorExpr{
X: ast.NewIdent("TestDB"),
X: testDB,
Sel: ast.NewIdent("Save" + tbl.GoTypeName),
},
Args: []ast.Expr{

View File

@ -41,7 +41,23 @@ func (c Column) GoFieldName() string {
return textutils.SnakeToCamel(c.Name)
}
func (c Column) GoType() string {
// GoVarName returns the name of a local variable for this column, e.g., when used as a function parameter.
func (c Column) GoVarName() string {
if c.Name == "rowid" {
return strings.ToLower(c.TableName)[0:1] + "ID"
// TODO: Or should it just be "id"??
}
// For foreign keys, use first letter of the target type and "ID". "UserID" => "uID"
if c.IsNonCodeTableForeignKey() {
return strings.ToLower(c.ForeignKeyTargetTable)[0:1] + "ID"
}
// Otherwise, just lowercase the field name
fieldname := c.GoFieldName()
return strings.ToLower(fieldname)[0:1] + fieldname[1:]
}
func (c Column) GoTypeName() string {
if c.IsNonCodeTableForeignKey() {
return TypenameFromTablename(c.ForeignKeyTargetTable) + "ID"
}
@ -101,6 +117,15 @@ func (t Table) PrimaryKeyColumns() []Column {
return pks
}
func (t Table) GetColumnByName(name string) Column {
for _, c := range t.Columns {
if c.Name == name {
return c
}
}
panic("no such column: " + name)
}
func (t Table) HasAutoTimestamps() (hasCreatedAt bool, hasUpdatedAt bool) {
for _, c := range t.Columns {
if c.Name == "created_at" && c.Type == "integer" {