codegen: add unique index lookup funcs
This commit is contained in:
parent
d6426bba14
commit
3e0a85fcfa
@ -25,19 +25,19 @@ var generate_model = &cobra.Command{
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
path := Must(cmd.Flags().GetString("schema"))
|
||||
modname := Must(cmd.Flags().GetString("modname"))
|
||||
schema_sql, err := os.ReadFile(path)
|
||||
sql, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading path %s: %w", path, err)
|
||||
}
|
||||
db := schema.InitDB(string(schema_sql))
|
||||
tables := schema.SchemaFromDB(db).Tables
|
||||
table, isOk := tables[args[0]]
|
||||
db := schema.InitDB(string(sql))
|
||||
schema := schema.SchemaFromDB(db)
|
||||
table, isOk := schema.Tables[args[0]]
|
||||
if !isOk {
|
||||
return ErrNoSuchTable
|
||||
}
|
||||
|
||||
if Must(cmd.Flags().GetBool("test")) {
|
||||
file2 := modelgenerate.GenerateModelTestAST(table, modname)
|
||||
file2 := modelgenerate.GenerateModelTestAST(table, schema, modname)
|
||||
PanicIf(modelgenerate.FprintWithComments(os.Stdout, file2))
|
||||
} else {
|
||||
decls := []ast.Decl{
|
||||
@ -70,8 +70,17 @@ var generate_model = &cobra.Command{
|
||||
modelgenerate.GenerateSaveItemFunc(table),
|
||||
modelgenerate.GenerateDeleteItemFunc(table),
|
||||
modelgenerate.GenerateGetItemByIDFunc(table),
|
||||
modelgenerate.GenerateGetAllItemsFunc(table),
|
||||
)
|
||||
for _, index := range schema.Indexes {
|
||||
if index.TableName != table.TableName {
|
||||
// Skip indexes on other tables
|
||||
continue
|
||||
}
|
||||
if index.IsUnique && len(index.Columns) == 1 {
|
||||
decls = append(decls, modelgenerate.GenerateGetItemByUniqColFunc(table, table.GetColumnByName(index.Columns[0])))
|
||||
}
|
||||
}
|
||||
decls = append(decls, modelgenerate.GenerateGetAllItemsFunc(table))
|
||||
|
||||
file := &ast.File{
|
||||
Name: ast.NewIdent("db"), // TODO: parameterize
|
||||
|
||||
@ -41,6 +41,7 @@ create table items (
|
||||
rowid integer primary key,
|
||||
description text not null default '',
|
||||
flavor integer references item_flavor(rowid),
|
||||
thing text not null unique,
|
||||
created_at integer not null,
|
||||
updated_at integer not null
|
||||
) strict;
|
||||
|
||||
@ -63,7 +63,7 @@ func GenerateModelAST(table schema.Table) *ast.GenDecl {
|
||||
Tag: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("`db:\"%s\" json:\"%s\"`", col.Name, col.Name)},
|
||||
})
|
||||
} else {
|
||||
typeName := col.GoType()
|
||||
typeName := col.GoTypeName()
|
||||
fields = append(fields, &ast.Field{
|
||||
Names: []*ast.Ident{ast.NewIdent(textutils.SnakeToCamel(col.Name))},
|
||||
Type: ast.NewIdent(typeName),
|
||||
@ -457,6 +457,51 @@ func GenerateGetItemByIDFunc(tbl schema.Table) *ast.FuncDecl {
|
||||
return funcDecl
|
||||
}
|
||||
|
||||
// GenerateGetItemByIDFunc produces an AST for the `GetXyzByID()` function.
|
||||
// E.g., a table with `table.TypeName = "foods"` will produce a "GetFoodByID()" function.
|
||||
func GenerateGetItemByUniqColFunc(tbl schema.Table, col schema.Column) *ast.FuncDecl {
|
||||
// Use the xyzSQLFields constant in the select query
|
||||
selectExpr := &ast.BinaryExpr{
|
||||
X: &ast.BinaryExpr{
|
||||
X: &ast.BasicLit{Kind: token.STRING, Value: "`\n\t select `"},
|
||||
Op: token.ADD,
|
||||
Y: SQLFieldsConstIdent(tbl),
|
||||
},
|
||||
Op: token.ADD,
|
||||
Y: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("`\n\t from %s\n\t where %s = ?\n\t`", tbl.TableName, col.Name)},
|
||||
}
|
||||
|
||||
param := ast.NewIdent(col.GoVarName())
|
||||
|
||||
return &ast.FuncDecl{
|
||||
Recv: dbRecv,
|
||||
Name: ast.NewIdent("Get" + schema.TypenameFromTablename(tbl.TableName) + "By" + col.GoFieldName()),
|
||||
Type: &ast.FuncType{
|
||||
Params: &ast.FieldList{List: []*ast.Field{
|
||||
{Names: []*ast.Ident{param}, Type: ast.NewIdent(col.GoTypeName())},
|
||||
}},
|
||||
Results: &ast.FieldList{List: []*ast.Field{
|
||||
{Names: []*ast.Ident{ast.NewIdent("ret")}, Type: ast.NewIdent(tbl.GoTypeName)},
|
||||
{Names: []*ast.Ident{ast.NewIdent("err")}, Type: ast.NewIdent("error")},
|
||||
}},
|
||||
},
|
||||
Body: &ast.BlockStmt{
|
||||
List: []ast.Stmt{
|
||||
&ast.AssignStmt{
|
||||
Lhs: []ast.Expr{ast.NewIdent("err")},
|
||||
Tok: token.ASSIGN,
|
||||
Rhs: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: dbDB, Sel: ast.NewIdent("Get")}, Args: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: ast.NewIdent("ret")}, selectExpr, param}}},
|
||||
},
|
||||
&ast.IfStmt{
|
||||
Cond: &ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent("errors"), Sel: ast.NewIdent("Is")}, Args: []ast.Expr{ast.NewIdent("err"), &ast.SelectorExpr{X: ast.NewIdent("sql"), Sel: ast.NewIdent("ErrNoRows")}}},
|
||||
Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ReturnStmt{Results: []ast.Expr{&ast.CompositeLit{Type: ast.NewIdent(tbl.GoTypeName)}, ast.NewIdent("ErrNotInDB")}}}},
|
||||
},
|
||||
&ast.ReturnStmt{},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateGetAllItemsFunc produces an AST for the `GetAllXyzs()` function.
|
||||
// E.g., a table with `table.TypeName = "foods"` will produce a "GetAllFoods()" function.
|
||||
func GenerateGetAllItemsFunc(tbl schema.Table) *ast.FuncDecl {
|
||||
|
||||
@ -7,11 +7,11 @@ import (
|
||||
|
||||
"github.com/jinzhu/inflection"
|
||||
|
||||
"git.offline-twitter.com/offline-labs/gas-stack/pkg/schema"
|
||||
pkgschema "git.offline-twitter.com/offline-labs/gas-stack/pkg/schema"
|
||||
)
|
||||
|
||||
// GenerateModelTestAST produces an AST for a starter test file for a given model.
|
||||
func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File {
|
||||
func GenerateModelTestAST(tbl pkgschema.Table, schema pkgschema.Schema, gomodName string) *ast.File {
|
||||
packageName := "db"
|
||||
testpackageName := packageName + "_test"
|
||||
|
||||
@ -44,6 +44,7 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File {
|
||||
fieldName := ast.NewIdent("Description")
|
||||
description1 := `"an item"`
|
||||
description2 := `"a big item"`
|
||||
testDB := ast.NewIdent("TestDB")
|
||||
|
||||
hasCreatedAt, hasUpdatedAt := tbl.HasAutoTimestamps()
|
||||
|
||||
@ -69,6 +70,8 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File {
|
||||
}
|
||||
|
||||
stmts := []ast.Stmt{
|
||||
Comment("Create"),
|
||||
|
||||
// item := Item{Description: "an item"}
|
||||
&ast.AssignStmt{
|
||||
Lhs: []ast.Expr{testObj},
|
||||
@ -86,7 +89,7 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File {
|
||||
|
||||
// TestDB.SaveItem(&item)
|
||||
&ast.ExprStmt{X: &ast.CallExpr{
|
||||
Fun: &ast.SelectorExpr{X: ast.NewIdent("TestDB"), Sel: ast.NewIdent("Save" + tbl.GoTypeName)},
|
||||
Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Save" + tbl.GoTypeName)},
|
||||
Args: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: testObj}},
|
||||
}},
|
||||
|
||||
@ -106,12 +109,15 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File {
|
||||
}
|
||||
|
||||
stmts = append(stmts,
|
||||
BlankLine(),
|
||||
Comment("Load"),
|
||||
|
||||
// item2 := Must(TestDB.GetItemByID(item.ID))
|
||||
&ast.AssignStmt{
|
||||
Lhs: []ast.Expr{testObj2},
|
||||
Tok: token.DEFINE,
|
||||
Rhs: []ast.Expr{mustCall(&ast.CallExpr{
|
||||
Fun: &ast.SelectorExpr{X: ast.NewIdent("TestDB"), Sel: ast.NewIdent("Get" + tbl.GoTypeName + "ByID")},
|
||||
Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Get" + tbl.GoTypeName + "ByID")},
|
||||
Args: []ast.Expr{&ast.SelectorExpr{X: testObj, Sel: ast.NewIdent("ID")}},
|
||||
})},
|
||||
},
|
||||
@ -125,6 +131,11 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File {
|
||||
&ast.SelectorExpr{X: testObj2, Sel: fieldName},
|
||||
},
|
||||
}},
|
||||
)
|
||||
|
||||
stmts = append(stmts,
|
||||
BlankLine(),
|
||||
Comment("Update"),
|
||||
|
||||
// item.Description = "a big item"
|
||||
&ast.AssignStmt{
|
||||
@ -135,18 +146,16 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File {
|
||||
|
||||
// TestDB.SaveItem(&item)
|
||||
&ast.ExprStmt{X: &ast.CallExpr{
|
||||
Fun: &ast.SelectorExpr{X: ast.NewIdent("TestDB"), Sel: ast.NewIdent("Save" + tbl.GoTypeName)},
|
||||
Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Save" + tbl.GoTypeName)},
|
||||
Args: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: testObj}},
|
||||
}},
|
||||
)
|
||||
|
||||
stmts = append(stmts,
|
||||
// item2 = Must(TestDB.GetItemByID(item.ID))
|
||||
&ast.AssignStmt{
|
||||
Lhs: []ast.Expr{testObj2},
|
||||
Tok: token.ASSIGN,
|
||||
Rhs: []ast.Expr{mustCall(&ast.CallExpr{
|
||||
Fun: &ast.SelectorExpr{X: ast.NewIdent("TestDB"), Sel: ast.NewIdent("Get" + tbl.GoTypeName + "ByID")},
|
||||
Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Get" + tbl.GoTypeName + "ByID")},
|
||||
Args: []ast.Expr{&ast.SelectorExpr{X: testObj, Sel: ast.NewIdent("ID")}},
|
||||
})},
|
||||
},
|
||||
@ -160,10 +169,50 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File {
|
||||
&ast.SelectorExpr{X: testObj2, Sel: fieldName},
|
||||
},
|
||||
}},
|
||||
)
|
||||
|
||||
indexGets, hasIndexedGets := []ast.Stmt{
|
||||
BlankLine(),
|
||||
Comment("Indexed lookups"),
|
||||
}, false
|
||||
|
||||
for _, index := range schema.Indexes {
|
||||
if index.TableName != tbl.TableName {
|
||||
// Skip indexes on other tables
|
||||
continue
|
||||
}
|
||||
if index.IsUnique && len(index.Columns) == 1 {
|
||||
col := tbl.GetColumnByName(index.Columns[0])
|
||||
indexGets = append(indexGets, []ast.Stmt{
|
||||
// assert.Equal(t, item2, TestDB.GetItemByXYZ(...))
|
||||
&ast.ExprStmt{X: &ast.CallExpr{ // TODO: what if just delete the "ExprStmt" wrapper?
|
||||
Fun: &ast.SelectorExpr{X: ast.NewIdent("assert"), Sel: ast.NewIdent("Equal")},
|
||||
Args: []ast.Expr{
|
||||
ast.NewIdent("t"),
|
||||
testObj2,
|
||||
mustCall(&ast.CallExpr{
|
||||
Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Get" + pkgschema.TypenameFromTablename(tbl.TableName) + "By" + col.GoFieldName())},
|
||||
Args: []ast.Expr{&ast.SelectorExpr{X: testObj2, Sel: ast.NewIdent(col.GoFieldName())}},
|
||||
}),
|
||||
},
|
||||
}},
|
||||
}...)
|
||||
// decls = append(decls, modelgenerate.GenerateGetItemByUniqColFunc(table, table.GetColumnByName(index.Columns[0])))
|
||||
hasIndexedGets = true
|
||||
}
|
||||
}
|
||||
|
||||
if hasIndexedGets {
|
||||
stmts = append(stmts, indexGets...)
|
||||
}
|
||||
|
||||
stmts = append(stmts,
|
||||
BlankLine(),
|
||||
Comment("Delete"),
|
||||
|
||||
// TestDB.DeleteItem(item)
|
||||
&ast.ExprStmt{X: &ast.CallExpr{
|
||||
Fun: &ast.SelectorExpr{X: ast.NewIdent("TestDB"), Sel: ast.NewIdent("Delete" + tbl.GoTypeName)},
|
||||
Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Delete" + tbl.GoTypeName)},
|
||||
Args: []ast.Expr{testObj},
|
||||
}},
|
||||
|
||||
@ -172,7 +221,7 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File {
|
||||
Lhs: []ast.Expr{ast.NewIdent("_"), ast.NewIdent("err")},
|
||||
Tok: token.DEFINE,
|
||||
Rhs: []ast.Expr{&ast.CallExpr{
|
||||
Fun: &ast.SelectorExpr{X: ast.NewIdent("TestDB"), Sel: ast.NewIdent("Get" + tbl.GoTypeName + "ByID")},
|
||||
Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Get" + tbl.GoTypeName + "ByID")},
|
||||
Args: []ast.Expr{&ast.SelectorExpr{X: testObj, Sel: ast.NewIdent("ID")}},
|
||||
}},
|
||||
},
|
||||
@ -203,7 +252,7 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File {
|
||||
Tok: token.ASSIGN,
|
||||
Rhs: []ast.Expr{&ast.CallExpr{
|
||||
Fun: &ast.SelectorExpr{
|
||||
X: ast.NewIdent("TestDB"),
|
||||
X: testDB,
|
||||
Sel: ast.NewIdent("GetAll" + inflection.Plural(tbl.GoTypeName)),
|
||||
},
|
||||
}},
|
||||
@ -259,7 +308,7 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File {
|
||||
Rhs: []ast.Expr{
|
||||
&ast.CallExpr{
|
||||
Fun: &ast.SelectorExpr{
|
||||
X: ast.NewIdent("TestDB"),
|
||||
X: testDB,
|
||||
Sel: ast.NewIdent("Save" + tbl.GoTypeName),
|
||||
},
|
||||
Args: []ast.Expr{
|
||||
|
||||
@ -41,7 +41,23 @@ func (c Column) GoFieldName() string {
|
||||
return textutils.SnakeToCamel(c.Name)
|
||||
}
|
||||
|
||||
func (c Column) GoType() string {
|
||||
// GoVarName returns the name of a local variable for this column, e.g., when used as a function parameter.
|
||||
func (c Column) GoVarName() string {
|
||||
if c.Name == "rowid" {
|
||||
return strings.ToLower(c.TableName)[0:1] + "ID"
|
||||
// TODO: Or should it just be "id"??
|
||||
}
|
||||
// For foreign keys, use first letter of the target type and "ID". "UserID" => "uID"
|
||||
if c.IsNonCodeTableForeignKey() {
|
||||
return strings.ToLower(c.ForeignKeyTargetTable)[0:1] + "ID"
|
||||
}
|
||||
|
||||
// Otherwise, just lowercase the field name
|
||||
fieldname := c.GoFieldName()
|
||||
return strings.ToLower(fieldname)[0:1] + fieldname[1:]
|
||||
}
|
||||
|
||||
func (c Column) GoTypeName() string {
|
||||
if c.IsNonCodeTableForeignKey() {
|
||||
return TypenameFromTablename(c.ForeignKeyTargetTable) + "ID"
|
||||
}
|
||||
@ -101,6 +117,15 @@ func (t Table) PrimaryKeyColumns() []Column {
|
||||
return pks
|
||||
}
|
||||
|
||||
func (t Table) GetColumnByName(name string) Column {
|
||||
for _, c := range t.Columns {
|
||||
if c.Name == name {
|
||||
return c
|
||||
}
|
||||
}
|
||||
panic("no such column: " + name)
|
||||
}
|
||||
|
||||
func (t Table) HasAutoTimestamps() (hasCreatedAt bool, hasUpdatedAt bool) {
|
||||
for _, c := range t.Columns {
|
||||
if c.Name == "created_at" && c.Type == "integer" {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user