gas-stack/pkg/codegen/modelgenerate/generate_model.go
~wispem-wantex 1bc7f9111f
Some checks failed
CI / build-docker (push) Successful in 6s
CI / build-docker-bootstrap (push) Has been skipped
CI / release-test (push) Failing after 16s
codegen: implement "without rowid" tables
2026-03-18 10:39:38 -07:00

725 lines
24 KiB
Go

//nolint:lll // This file has lots of long lines lol
package modelgenerate
import (
"fmt"
"go/ast"
"go/token"
"strings"
"github.com/jinzhu/inflection"
"git.offline-twitter.com/offline-labs/gas-stack/pkg/schema"
"git.offline-twitter.com/offline-labs/gas-stack/pkg/textutils"
)
// ---------------
// Helpers
// ---------------
// Shared AST fragments reused by the generators below.
var (
	// dbRecv is the method receiver list `(db DB)`.
	dbRecv = &ast.FieldList{
		List: []*ast.Field{{
			Names: []*ast.Ident{ast.NewIdent("db")},
			Type:  ast.NewIdent("DB"),
		}},
	}

	// dbDB is the selector expression `db.DB`.
	dbDB = &ast.SelectorExpr{Sel: ast.NewIdent("DB"), X: ast.NewIdent("db")}

	// fmtErrorf is the selector expression `fmt.Errorf`.
	fmtErrorf = &ast.SelectorExpr{Sel: ast.NewIdent("Errorf"), X: ast.NewIdent("fmt")}
)
// SQLFieldsConstIdent returns the identifier of the generated constant that lists a table's
// selectable columns: the lower-cased GoTypeName followed by "SQLFields" (e.g. `foodSQLFields`
// for GoTypeName "Food").
func SQLFieldsConstIdent(tbl schema.Table) *ast.Ident {
	prefix := strings.ToLower(tbl.GoTypeName)
	return ast.NewIdent(prefix + "SQLFields")
}
// GoTypeForColumn returns the Go type expression used for a column's struct field or parameter.
//
// Mapping rules:
//   - a foreign key into a non-code table becomes that table's ID type (`<Typename>ID`);
//   - "integer"/"int" becomes `bool` for names starting with "is_" or "has_", `Timestamp` for
//     names ending in "_at", and `int` otherwise (the prefix rule wins if both match);
//   - "text" becomes `string`, "real" becomes `float32`;
//   - "blob" becomes the slice type `[]byte`, which is why the result is an ast.Expr rather
//     than a plain *ast.Ident.
//
// It panics on any other SQLite column type.
func GoTypeForColumn(c schema.Column) ast.Expr {
	if c.IsNonCodeTableForeignKey() {
		return ast.NewIdent(schema.TypenameFromTablename(c.ForeignKeyTargetTable) + "ID")
	}

	var typeName string
	switch c.Type {
	case "integer", "int":
		switch {
		case strings.HasPrefix(c.Name, "is_"), strings.HasPrefix(c.Name, "has_"):
			typeName = "bool"
		case strings.HasSuffix(c.Name, "_at"):
			typeName = "Timestamp"
		default:
			typeName = "int"
		}
	case "text":
		typeName = "string"
	case "real":
		// NOTE(review): SQLite REAL values are 8-byte floats; float32 may lose precision.
		typeName = "float32"
	case "blob":
		// Not a bare identifier, so it can't share the return below.
		return &ast.ArrayType{Elt: ast.NewIdent("byte")}
	default:
		panic("Unrecognized sqlite column type: " + c.Type)
	}
	return ast.NewIdent(typeName)
}
// PanicIfRowsAffected builds the generated statement
//
//	if <mustCall(result.RowsAffected())> != 1 {
//		panic(<tbl.VarName>)
//	}
//
// Despite its name, the generated check panics when the affected-row count is anything OTHER
// than 1, i.e. it guards a statement that must touch exactly one row. The generated code must
// have a `result` variable in scope.
func PanicIfRowsAffected(tbl schema.Table) *ast.IfStmt {
	rowsAffected := mustCall(&ast.CallExpr{
		Fun:  &ast.SelectorExpr{X: ast.NewIdent("result"), Sel: ast.NewIdent("RowsAffected")},
		Args: []ast.Expr{},
	})
	panicWithItem := &ast.ExprStmt{X: &ast.CallExpr{
		Fun:  ast.NewIdent("panic"),
		Args: []ast.Expr{ast.NewIdent(tbl.VarName)},
	}}

	return &ast.IfStmt{
		Cond: &ast.BinaryExpr{
			X:  rowsAffected,
			Op: token.NEQ,
			Y:  &ast.BasicLit{Kind: token.INT, Value: "1"},
		},
		Body: &ast.BlockStmt{List: []ast.Stmt{panicWithItem}},
	}
}
// ---------------
// Generators
// ---------------
// GenerateIDType produces an AST for the model's ID type declaration,
// e.g. `type FoodID int`, named after table.TypeIDName.
func GenerateIDType(table schema.Table) *ast.GenDecl {
	spec := &ast.TypeSpec{
		Name: ast.NewIdent(table.TypeIDName),
		Type: ast.NewIdent("int"),
	}
	return &ast.GenDecl{Tok: token.TYPE, Specs: []ast.Spec{spec}}
}
// GenerateModelAST produces an AST for the struct type corresponding to the model, named
// table.GoTypeName.
//
// Every column becomes one field, in column order, tagged `db:"<col>" json:"<col>"`. The
// "rowid" column is special: it becomes `ID <TypeIDName>` tagged `db:"rowid" json:"id"`.
// Field types come from GoTypeForColumn; foreign keys into non-code tables are named with the
// column's GoFieldName, all other fields with the camel-cased column name.
func GenerateModelAST(table schema.Table) *ast.GenDecl {
	fields := make([]*ast.Field, 0, len(table.Columns))
	for _, col := range table.Columns {
		if col.Name == "rowid" {
			fields = append(fields, &ast.Field{
				Names: []*ast.Ident{ast.NewIdent("ID")},
				Type:  ast.NewIdent(table.TypeIDName),
				Tag:   &ast.BasicLit{Kind: token.STRING, Value: "`db:\"rowid\" json:\"id\"`"},
			})
			continue
		}

		var fieldName string
		if col.IsNonCodeTableForeignKey() {
			fieldName = col.GoFieldName()
		} else {
			fieldName = textutils.SnakeToCamel(col.Name)
		}
		fields = append(fields, &ast.Field{
			Names: []*ast.Ident{ast.NewIdent(fieldName)},
			// For non-code-table FKs this yields the target table's ID type.
			Type: GoTypeForColumn(col),
			Tag:  &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("`db:\"%s\" json:\"%s\"`", col.Name, col.Name)},
		})
	}

	structType := &ast.StructType{Fields: &ast.FieldList{List: fields}}
	return &ast.GenDecl{
		Tok:   token.TYPE,
		Specs: []ast.Spec{&ast.TypeSpec{Name: ast.NewIdent(table.GoTypeName), Type: structType}},
	}
}
// buildFKCheckLambda builds the generated statement
//
//	checkForeignKeyFailures := func(err error) error { ... }
//
// The generated lambda returns nil unless IsSqliteFkError(err) is true. It then checks each
// foreign-key column of tbl (both real FKs and code-table references) in column order, and
// returns NewForeignKeyError(...) for the first one whose referent is missing:
//   - real foreign keys: calls the target's generated GetXyzByID and treats ErrNotInDB as missing;
//   - code-table references: runs `select 1 from <table> where rowid = ?` and treats
//     sql.ErrNoRows as missing.
//
// If every referent is found, it returns nil.
//
// The second return value reports whether tbl has any foreign-key columns at all.
// GenerateSaveItemFunc emits the lambda only when it is true.
func buildFKCheckLambda(tbl schema.Table) (*ast.AssignStmt, bool) {
	// errorsIs builds `errors.Is(err, <target>)`.
	errorsIs := func(target ast.Expr) *ast.CallExpr {
		return &ast.CallExpr{
			Fun:  &ast.SelectorExpr{X: ast.NewIdent("errors"), Sel: ast.NewIdent("Is")},
			Args: []ast.Expr{ast.NewIdent("err"), target},
		}
	}
	// returnFKError builds `{ return NewForeignKeyError("<Field>", "<target_table>", <field>) }`.
	returnFKError := func(fieldName, targetTable string, field ast.Expr) *ast.BlockStmt {
		return &ast.BlockStmt{List: []ast.Stmt{
			&ast.ReturnStmt{Results: []ast.Expr{&ast.CallExpr{
				Fun: ast.NewIdent("NewForeignKeyError"),
				Args: []ast.Expr{
					&ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("%q", fieldName)},
					&ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("%q", targetTable)},
					field,
				},
			}}},
		}}
	}

	// if !IsSqliteFkError(err) { return nil }
	body := []ast.Stmt{&ast.IfStmt{
		Cond: &ast.UnaryExpr{Op: token.NOT, X: &ast.CallExpr{Fun: ast.NewIdent("IsSqliteFkError"), Args: []ast.Expr{ast.NewIdent("err")}}},
		Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ReturnStmt{Results: []ast.Expr{ast.NewIdent("nil")}}}},
	}}

	foundFK := false
	for _, col := range tbl.Columns {
		if !col.IsForeignKey {
			continue
		}
		foundFK = true

		fieldName := col.GoFieldName()
		field := &ast.SelectorExpr{X: ast.NewIdent(tbl.VarName), Sel: ast.NewIdent(fieldName)}

		var check *ast.IfStmt
		if col.IsNonCodeTableForeignKey() {
			// if _, err := db.GetXyzByID(item.Field); errors.Is(err, ErrNotInDB) { ... }
			check = &ast.IfStmt{
				Init: &ast.AssignStmt{
					Lhs: []ast.Expr{ast.NewIdent("_"), ast.NewIdent("err")},
					Tok: token.DEFINE,
					Rhs: []ast.Expr{&ast.CallExpr{
						Fun:  &ast.SelectorExpr{X: ast.NewIdent("db"), Sel: ast.NewIdent(getByIDFuncName(col.ForeignKeyTargetTable))},
						Args: []ast.Expr{field},
					}},
				},
				Cond: errorsIs(ast.NewIdent("ErrNotInDB")),
			}
		} else {
			// if err = db.DB.Get(new(int), `select 1 ...`, item.Field); errors.Is(err, sql.ErrNoRows) { ... }
			check = &ast.IfStmt{
				Init: &ast.AssignStmt{
					Lhs: []ast.Expr{ast.NewIdent("err")},
					Tok: token.ASSIGN,
					Rhs: []ast.Expr{&ast.CallExpr{
						Fun: &ast.SelectorExpr{X: dbDB, Sel: ast.NewIdent("Get")},
						Args: []ast.Expr{
							&ast.CallExpr{Fun: ast.NewIdent("new"), Args: []ast.Expr{ast.NewIdent("int")}},
							&ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("`select 1 from %s where rowid = ?`", col.ForeignKeyTargetTable)},
							field,
						},
					}},
				},
				Cond: errorsIs(&ast.SelectorExpr{X: ast.NewIdent("sql"), Sel: ast.NewIdent("ErrNoRows")}),
			}
		}
		check.Body = returnFKError(fieldName, col.ForeignKeyTargetTable, field)
		body = append(body, check)
	}

	// Nothing was missing
	body = append(body, &ast.ReturnStmt{Results: []ast.Expr{ast.NewIdent("nil")}})

	lambda := &ast.FuncLit{
		Type: &ast.FuncType{
			Params: &ast.FieldList{List: []*ast.Field{{
				Names: []*ast.Ident{ast.NewIdent("err")},
				Type:  ast.NewIdent("error"),
			}}},
			Results: &ast.FieldList{List: []*ast.Field{{Type: ast.NewIdent("error")}}},
		},
		Body: &ast.BlockStmt{List: body},
	}
	assign := &ast.AssignStmt{
		Lhs: []ast.Expr{ast.NewIdent("checkForeignKeyFailures")},
		Tok: token.DEFINE,
		Rhs: []ast.Expr{lambda},
	}
	return assign, foundFK
}
// GenerateSaveItemFunc produces an AST for the SaveXyz() function of the model.
// E.g., a table with `table.TypeName = "foods"` will produce a "SaveFood()" function.
//
// The generated function:
//   - sets UpdatedAt to TimestampNow() when the table has auto-timestamps (HasAutoTimestamps);
//   - for ordinary (rowid) tables, if ID is 0: sets CreatedAt (when auto-managed), inserts the
//     row and stores LastInsertId() into ID. Otherwise it updates the row matching its rowid and
//     panics unless exactly one row changed;
//   - for "without rowid" tables, issues a single upsert that never overwrites primary-key
//     columns or an auto-managed created_at. If that leaves no columns to update, the upsert is
//     `on conflict do nothing` and re-saving an existing row is a no-op (0 rows affected is
//     expected, so it is not checked). Otherwise it panics unless exactly one row changed;
//   - returns an `error` only when the table has foreign-key columns. In that case FK failures
//     are reported via the generated `checkForeignKeyFailures` lambda and any other error
//     panics. Without FKs, the exec call is wrapped with mustCall instead.
func GenerateSaveItemFunc(tbl schema.Table) *ast.FuncDecl {
	insertCols := make([]string, 0, len(tbl.Columns))
	insertVals := make([]string, 0, len(tbl.Columns))
	updatePairs := make([]string, 0, len(tbl.Columns))
	hasCreatedAt, hasUpdatedAt := tbl.HasAutoTimestamps()

	// Assemble data for building SQL "insert" and "update" strings
	for _, col := range tbl.Columns {
		if col.Name == "rowid" {
			continue
		}
		insertCols = append(insertCols, col.Name)
		val := ":" + col.Name
		if col.IsNullableForeignKey() {
			// A zero-valued nullable foreign key is stored as NULL
			val = fmt.Sprintf("nullif(%s, 0)", val)
		}
		insertVals = append(insertVals, val)
		// created_at should not be updated after creation
		if col.Name == "created_at" && hasCreatedAt {
			continue
		}
		if !col.IsPrimaryKey { // Don't try to update primary key columns (mainly for w/o rowid tables)
			updatePairs = append(updatePairs, col.Name+"="+val)
		}
	}

	insertStmt := fmt.Sprintf("\n\t\t insert into %s (%s)\n\t\t values (%s)\n\t\t",
		tbl.TableName,
		strings.Join(insertCols, ", "),
		strings.Join(insertVals, ", "),
	)
	updateStmt := fmt.Sprintf("\n\t\t update %s\n\t\t set %s\n\t\t where rowid = :rowid\n\t\t",
		tbl.TableName,
		strings.Join(updatePairs, ",\n\t\t "),
	)
	upsertStmt := fmt.Sprintf("\n\t insert into %s (%s)\n\t values (%s)\n\t",
		tbl.TableName,
		strings.Join(insertCols, ", "),
		strings.Join(insertVals, ", "),
	)
	// With no updatable columns the upsert can only ignore conflicts. Saving an item that
	// already exists then legitimately affects 0 rows, so it must not be treated as a failure.
	upsertIsNoop := len(updatePairs) == 0
	if upsertIsNoop {
		upsertStmt += " on conflict do nothing\n\t"
	} else {
		upsertStmt += fmt.Sprintf(" on conflict do update\n\t set %s\n\t", strings.Join(updatePairs, ",\n\t "))
	}

	checkForeignKeyFailuresAssignment, hasFks := buildFKCheckLambda(tbl)

	// namedExecStmt builds the statement(s) running `db.DB.NamedExec(<stmt>, item)` and binds its
	// first result to resultName. Pass "_" when the generated code won't inspect the result,
	// because an unused named local would not compile.
	namedExecStmt := func(stmt string, resultName string) []ast.Stmt {
		queryStmt := &ast.CallExpr{
			Fun: &ast.SelectorExpr{X: dbDB, Sel: ast.NewIdent("NamedExec")},
			Args: []ast.Expr{
				&ast.BasicLit{Kind: token.STRING, Value: "`" + stmt + "`"},
				ast.NewIdent(tbl.VarName),
			},
		}
		if !hasFks {
			// No foreign key checking needed; just use `Must` for brevity
			tok := token.DEFINE
			if resultName == "_" {
				tok = token.ASSIGN // `_ := x` doesn't compile; `_ = x` does
			}
			return []ast.Stmt{&ast.AssignStmt{
				Lhs: []ast.Expr{ast.NewIdent(resultName)},
				Tok: tok,
				Rhs: []ast.Expr{mustCall(queryStmt)},
			}}
		}
		// There's foreign keys
		return []ast.Stmt{
			// result, err := db.DB.NamedExec(`...`, u)   (`err` is always new, so `:=` is valid)
			&ast.AssignStmt{
				Lhs: []ast.Expr{
					ast.NewIdent(resultName),
					ast.NewIdent("err"),
				},
				Tok: token.DEFINE,
				Rhs: []ast.Expr{queryStmt},
			},
			// if fkErr := checkForeignKeyFailures(err); fkErr != nil { return fkErr } else if err != nil { panic(err) }
			&ast.IfStmt{
				Init: &ast.AssignStmt{
					Lhs: []ast.Expr{ast.NewIdent("fkErr")},
					Tok: token.DEFINE,
					Rhs: []ast.Expr{
						&ast.CallExpr{
							Fun:  ast.NewIdent("checkForeignKeyFailures"),
							Args: []ast.Expr{ast.NewIdent("err")},
						},
					},
				},
				Cond: &ast.BinaryExpr{
					X:  ast.NewIdent("fkErr"),
					Op: token.NEQ,
					Y:  ast.NewIdent("nil"),
				},
				Body: &ast.BlockStmt{
					List: []ast.Stmt{
						&ast.ReturnStmt{
							Results: []ast.Expr{ast.NewIdent("fkErr")},
						},
					},
				},
				Else: func() *ast.IfStmt {
					panicStmt := &ast.ExprStmt{
						X: &ast.CallExpr{
							Fun:  ast.NewIdent("panic"),
							Args: []ast.Expr{ast.NewIdent("err")},
						},
					}
					TrailingComments[panicStmt] = "not a foreign key error"
					return &ast.IfStmt{
						Cond: &ast.BinaryExpr{
							X:  ast.NewIdent("err"),
							Op: token.NEQ,
							Y:  ast.NewIdent("nil"),
						},
						Body: &ast.BlockStmt{
							List: []ast.Stmt{panicStmt},
						},
					}
				}(),
			},
		}
	}

	body := []ast.Stmt{}
	if hasFks {
		body = append(body, checkForeignKeyFailuresAssignment, BlankLine())
	}
	if hasUpdatedAt {
		// Auto-timestamps: updated_at
		body = append(body, &ast.AssignStmt{
			Lhs: []ast.Expr{&ast.SelectorExpr{X: ast.NewIdent(tbl.VarName), Sel: ast.NewIdent("UpdatedAt")}},
			Tok: token.ASSIGN,
			Rhs: []ast.Expr{&ast.CallExpr{Fun: ast.NewIdent("TimestampNow"), Args: []ast.Expr{}}},
		})
	}
	if tbl.IsWithoutRowid {
		// NOTE(review): CreatedAt is only assigned on the rowid "Do create" path below. For
		// without-rowid tables the inserted created_at is whatever the caller left in the struct.
		if upsertIsNoop {
			body = append(body, namedExecStmt(upsertStmt, "_")...)
		} else {
			body = append(body, namedExecStmt(upsertStmt, "result")...)
			body = append(body, PanicIfRowsAffected(tbl))
		}
	} else {
		// if item.ID == 0 {...} else {...}
		body = append(body, &ast.IfStmt{
			Cond: &ast.BinaryExpr{
				X:  &ast.SelectorExpr{X: ast.NewIdent(tbl.VarName), Sel: ast.NewIdent("ID")},
				Op: token.EQL,
				Y:  &ast.BasicLit{Kind: token.INT, Value: "0"},
			},
			Body: &ast.BlockStmt{
				// Do create
				List: append(
					func() []ast.Stmt {
						create := []ast.Stmt{Comment("Do create")}
						if hasCreatedAt {
							// Auto-timestamps: created_at
							create = append(create, &ast.AssignStmt{
								Lhs: []ast.Expr{&ast.SelectorExpr{X: ast.NewIdent(tbl.VarName), Sel: ast.NewIdent("CreatedAt")}},
								Tok: token.ASSIGN,
								Rhs: []ast.Expr{&ast.CallExpr{Fun: ast.NewIdent("TimestampNow"), Args: []ast.Expr{}}},
							})
						}
						return append(create, namedExecStmt(insertStmt, "result")...)
					}(),
					// item.ID = XyzID(<mustCall(result.LastInsertId())>)
					&ast.AssignStmt{
						Lhs: []ast.Expr{&ast.SelectorExpr{X: ast.NewIdent(tbl.VarName), Sel: ast.NewIdent("ID")}},
						Tok: token.ASSIGN,
						Rhs: []ast.Expr{&ast.CallExpr{
							Fun: ast.NewIdent(tbl.TypeIDName),
							Args: []ast.Expr{mustCall(&ast.CallExpr{
								Fun:  &ast.SelectorExpr{X: ast.NewIdent("result"), Sel: ast.NewIdent("LastInsertId")},
								Args: []ast.Expr{},
							})},
						}},
					},
				),
			},
			Else: &ast.BlockStmt{
				// Do update
				List: append(
					[]ast.Stmt{Comment("Do update")},
					append(
						namedExecStmt(updateStmt, "result"),
						PanicIfRowsAffected(tbl),
					)...,
				),
			},
		})
	}
	if hasFks {
		// If there's foreign key checking, it needs to return an error (or nil)
		body = append(body, &ast.ReturnStmt{Results: []ast.Expr{ast.NewIdent("nil")}})
	}
	funcBody := &ast.BlockStmt{List: body}

	funcDecl := &ast.FuncDecl{
		Doc: &ast.CommentGroup{List: []*ast.Comment{
			{Text: fmt.Sprintf("// Save%s creates or updates a %s in the database.", tbl.GoTypeName, tbl.GoTypeName)},
			{Text: "// If the item doesn't exist (has no ID set), it will create it; otherwise it will do an update."},
		}},
		Recv: dbRecv,
		Name: ast.NewIdent("Save" + tbl.GoTypeName),
		Type: &ast.FuncType{
			Params: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent(tbl.VarName)}, Type: &ast.StarExpr{X: ast.NewIdent(tbl.GoTypeName)}}}},
			Results: func() *ast.FieldList {
				if hasFks {
					return &ast.FieldList{List: []*ast.Field{{Type: ast.NewIdent("error")}}}
				}
				return nil
			}(),
		},
		Body: funcBody,
	}
	return funcDecl
}
// getByIDFuncName returns the name of the generated by-ID getter for the table named tblname:
// "Get" + TypenameFromTablename(tblname) + "ByID".
func getByIDFuncName(tblname string) string {
	return fmt.Sprintf("Get%sByID", schema.TypenameFromTablename(tblname))
}
// GenerateGetItemBy produces an AST for a `Get<Type>By<A>And<B>(...)` function that fetches the
// row of tbl matching every column in cols.
//
// Each column contributes:
//   - a parameter, named by LongGoVarName and typed by GoTypeForColumn;
//   - a `col = ?` condition (conditions are joined with "and");
//   - a function-name segment (its GoFieldName; segments are joined with "And").
//
// The parameters are passed to db.DB.Get positionally, in the same order as the placeholders.
// The generated function returns ErrNotInDB when no row matches, and any other error as-is.
//
// It panics if cols is empty, since the query would otherwise have an empty where clause.
func GenerateGetItemBy(tbl schema.Table, cols []schema.Column) *ast.FuncDecl {
	if len(cols) == 0 {
		panic("GenerateGetItemBy: no columns given for table " + tbl.TableName)
	}

	conditions := make([]string, 0, len(cols))
	funcNameSuffix := make([]string, 0, len(cols))
	funcParams := &ast.FieldList{List: make([]*ast.Field, 0, len(cols))}
	sqlParams := make([]ast.Expr, 0, len(cols))
	for _, col := range cols {
		funcParam := ast.NewIdent(col.LongGoVarName())
		funcParams.List = append(funcParams.List, &ast.Field{Names: []*ast.Ident{funcParam}, Type: GoTypeForColumn(col)})
		// Positional placeholder; sqlParams is built in the same order, so they line up.
		conditions = append(conditions, col.Name+" = ?")
		funcNameSuffix = append(funcNameSuffix, col.GoFieldName())
		sqlParams = append(sqlParams, funcParam)
	}

	// `\n\t select ` + xyzSQLFields + `\n\t from <table>\n\t where a = ? and b = ?\n\t`
	selectExpr := &ast.BinaryExpr{
		X: &ast.BinaryExpr{
			X:  &ast.BasicLit{Kind: token.STRING, Value: "`\n\t select `"},
			Op: token.ADD,
			Y:  SQLFieldsConstIdent(tbl),
		},
		Op: token.ADD,
		Y:  &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("`\n\t from %s\n\t where %s\n\t`", tbl.TableName, strings.Join(conditions, " and "))},
	}

	return &ast.FuncDecl{
		Recv: dbRecv,
		Name: ast.NewIdent(fmt.Sprintf("Get%sBy%s", schema.TypenameFromTablename(tbl.TableName), strings.Join(funcNameSuffix, "And"))),
		Type: &ast.FuncType{
			Params: funcParams,
			Results: &ast.FieldList{List: []*ast.Field{
				{Names: []*ast.Ident{ast.NewIdent("ret")}, Type: ast.NewIdent(tbl.GoTypeName)},
				{Names: []*ast.Ident{ast.NewIdent("err")}, Type: ast.NewIdent("error")},
			}},
		},
		Body: &ast.BlockStmt{
			List: []ast.Stmt{
				// err = db.DB.Get(&ret, <query>, params...)
				&ast.AssignStmt{
					Lhs: []ast.Expr{ast.NewIdent("err")},
					Tok: token.ASSIGN,
					Rhs: []ast.Expr{&ast.CallExpr{
						Fun:  &ast.SelectorExpr{X: dbDB, Sel: ast.NewIdent("Get")},
						Args: append([]ast.Expr{&ast.UnaryExpr{Op: token.AND, X: ast.NewIdent("ret")}, selectExpr}, sqlParams...),
					}},
				},
				// if errors.Is(err, sql.ErrNoRows) { return Xyz{}, ErrNotInDB }
				&ast.IfStmt{
					Cond: &ast.CallExpr{
						Fun:  &ast.SelectorExpr{X: ast.NewIdent("errors"), Sel: ast.NewIdent("Is")},
						Args: []ast.Expr{ast.NewIdent("err"), &ast.SelectorExpr{X: ast.NewIdent("sql"), Sel: ast.NewIdent("ErrNoRows")}},
					},
					Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ReturnStmt{Results: []ast.Expr{&ast.CompositeLit{Type: ast.NewIdent(tbl.GoTypeName)}, ast.NewIdent("ErrNotInDB")}}}},
				},
				&ast.ReturnStmt{},
			},
		},
	}
}
// GenerateGetItemByIDFunc produces an AST for the `GetXyzByID()` function.
// E.g., a table with `table.TypeName = "foods"` will produce a "GetFoodByID()" function.
//
// The generated function takes `id <TypeIDName>`, selects the row whose rowid equals it, and
// returns ErrNotInDB when there is none. Any other error is returned as-is via the named results.
func GenerateGetItemByIDFunc(tbl schema.Table) *ast.FuncDecl {
	// `\n\t select ` + xyzSQLFields + `\n\t from <table>\n\t where rowid = ?\n\t`
	query := &ast.BinaryExpr{
		Op: token.ADD,
		X: &ast.BinaryExpr{
			Op: token.ADD,
			X:  &ast.BasicLit{Kind: token.STRING, Value: "`\n\t select `"},
			Y:  SQLFieldsConstIdent(tbl),
		},
		Y: &ast.BasicLit{Kind: token.STRING, Value: "`\n\t from " + tbl.TableName + "\n\t where rowid = ?\n\t`"},
	}

	// err = db.DB.Get(&ret, query, id)
	doGet := &ast.AssignStmt{
		Lhs: []ast.Expr{ast.NewIdent("err")},
		Tok: token.ASSIGN,
		Rhs: []ast.Expr{&ast.CallExpr{
			Fun: &ast.SelectorExpr{X: dbDB, Sel: ast.NewIdent("Get")},
			Args: []ast.Expr{
				&ast.UnaryExpr{Op: token.AND, X: ast.NewIdent("ret")},
				query,
				ast.NewIdent("id"),
			},
		}},
	}

	// if errors.Is(err, sql.ErrNoRows) { return Xyz{}, ErrNotInDB }
	notFound := &ast.IfStmt{
		Cond: &ast.CallExpr{
			Fun:  &ast.SelectorExpr{X: ast.NewIdent("errors"), Sel: ast.NewIdent("Is")},
			Args: []ast.Expr{ast.NewIdent("err"), &ast.SelectorExpr{X: ast.NewIdent("sql"), Sel: ast.NewIdent("ErrNoRows")}},
		},
		Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ReturnStmt{Results: []ast.Expr{
			&ast.CompositeLit{Type: ast.NewIdent(tbl.GoTypeName)},
			ast.NewIdent("ErrNotInDB"),
		}}}},
	}

	signature := &ast.FuncType{
		Params: &ast.FieldList{List: []*ast.Field{
			{Names: []*ast.Ident{ast.NewIdent("id")}, Type: ast.NewIdent(tbl.TypeIDName)},
		}},
		Results: &ast.FieldList{List: []*ast.Field{
			{Names: []*ast.Ident{ast.NewIdent("ret")}, Type: ast.NewIdent(tbl.GoTypeName)},
			{Names: []*ast.Ident{ast.NewIdent("err")}, Type: ast.NewIdent("error")},
		}},
	}

	return &ast.FuncDecl{
		Recv: dbRecv,
		Name: ast.NewIdent(getByIDFuncName(tbl.TableName)),
		Type: signature,
		Body: &ast.BlockStmt{List: []ast.Stmt{doGet, notFound, &ast.ReturnStmt{}}},
	}
}
// GenerateGetItemByUniqColFunc produces an AST for a `Get<Type>By<Col>()` function that looks up
// an item by the single column col (intended for unique columns). E.g., for the "foods" table and
// a column whose GoFieldName is "Name", it produces "GetFoodByName()".
//
// The generated function takes one parameter (named by col.GoVarName, typed by GoTypeForColumn),
// selects the row where `<col> = ?`, and returns ErrNotInDB when no row matches. Any other error
// is returned as-is via the named results.
func GenerateGetItemByUniqColFunc(tbl schema.Table, col schema.Column) *ast.FuncDecl {
	// Use the xyzSQLFields constant in the select query
	selectExpr := &ast.BinaryExpr{
		X: &ast.BinaryExpr{
			X:  &ast.BasicLit{Kind: token.STRING, Value: "`\n\t select `"},
			Op: token.ADD,
			Y:  SQLFieldsConstIdent(tbl),
		},
		Op: token.ADD,
		Y:  &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("`\n\t from %s\n\t where %s = ?\n\t`", tbl.TableName, col.Name)},
	}
	param := ast.NewIdent(col.GoVarName())
	return &ast.FuncDecl{
		Recv: dbRecv,
		Name: ast.NewIdent("Get" + schema.TypenameFromTablename(tbl.TableName) + "By" + col.GoFieldName()),
		Type: &ast.FuncType{
			Params: &ast.FieldList{List: []*ast.Field{
				{Names: []*ast.Ident{param}, Type: GoTypeForColumn(col)},
			}},
			Results: &ast.FieldList{List: []*ast.Field{
				{Names: []*ast.Ident{ast.NewIdent("ret")}, Type: ast.NewIdent(tbl.GoTypeName)},
				{Names: []*ast.Ident{ast.NewIdent("err")}, Type: ast.NewIdent("error")},
			}},
		},
		Body: &ast.BlockStmt{
			List: []ast.Stmt{
				// err = db.DB.Get(&ret, <query>, <param>)
				&ast.AssignStmt{
					Lhs: []ast.Expr{ast.NewIdent("err")},
					Tok: token.ASSIGN,
					Rhs: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: dbDB, Sel: ast.NewIdent("Get")}, Args: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: ast.NewIdent("ret")}, selectExpr, param}}},
				},
				// if errors.Is(err, sql.ErrNoRows) { return Xyz{}, ErrNotInDB }
				&ast.IfStmt{
					Cond: &ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent("errors"), Sel: ast.NewIdent("Is")}, Args: []ast.Expr{ast.NewIdent("err"), &ast.SelectorExpr{X: ast.NewIdent("sql"), Sel: ast.NewIdent("ErrNoRows")}}},
					Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ReturnStmt{Results: []ast.Expr{&ast.CompositeLit{Type: ast.NewIdent(tbl.GoTypeName)}, ast.NewIdent("ErrNotInDB")}}}},
				},
				&ast.ReturnStmt{},
			},
		},
	}
}
// GenerateGetAllItemsFunc produces an AST for the `GetAllXyzs()` function.
// E.g., a table with `table.TypeName = "foods"` will produce a "GetAllFoods()" function.
//
// The generated function selects every row of the table (no `where` or `order by`) into its
// named `ret []Xyz` result. The db.DB.Select call is wrapped in PanicIf(...).
func GenerateGetAllItemsFunc(tbl schema.Table) *ast.FuncDecl {
	// `select ` + xyzSQLFields + ` from <table>`
	query := &ast.BinaryExpr{
		X: &ast.BinaryExpr{
			X:  &ast.BasicLit{Kind: token.STRING, Value: "`select `"},
			Op: token.ADD,
			Y:  SQLFieldsConstIdent(tbl),
		},
		Op: token.ADD,
		Y:  &ast.BasicLit{Kind: token.STRING, Value: "` from " + tbl.TableName + "`"},
	}

	// PanicIf(db.DB.Select(&ret, query))
	selectAll := &ast.ExprStmt{X: &ast.CallExpr{
		Fun: ast.NewIdent("PanicIf"),
		Args: []ast.Expr{&ast.CallExpr{
			Fun:  &ast.SelectorExpr{X: dbDB, Sel: ast.NewIdent("Select")},
			Args: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: ast.NewIdent("ret")}, query},
		}},
	}}

	return &ast.FuncDecl{
		Recv: dbRecv,
		Name: ast.NewIdent("GetAll" + inflection.Plural(tbl.GoTypeName)),
		Type: &ast.FuncType{
			Params: &ast.FieldList{},
			Results: &ast.FieldList{List: []*ast.Field{{
				Names: []*ast.Ident{ast.NewIdent("ret")},
				Type:  &ast.ArrayType{Elt: ast.NewIdent(tbl.GoTypeName)},
			}}},
		},
		Body: &ast.BlockStmt{List: []ast.Stmt{selectAll, &ast.ReturnStmt{}}},
	}
}
// GenerateDeleteItemFunc produces an AST for the `DeleteXyz()` function.
// E.g., a table with `table.TypeName = "foods"` will produce a "DeleteFood()" function.
//
// The generated function takes the item by value and runs a named-parameter
// `delete from <table> where pk1 = :pk1 and ...` over tbl.PrimaryKeyColumns(). It then panics
// (via PanicIfRowsAffected) unless exactly one row was deleted.
func GenerateDeleteItemFunc(tbl schema.Table) *ast.FuncDecl {
	pkCols := tbl.PrimaryKeyColumns()
	conditions := make([]string, len(pkCols))
	for i, c := range pkCols {
		conditions[i] = fmt.Sprintf("%s = :%s", c.Name, c.Name)
	}
	query := fmt.Sprintf("`delete from %s where %s`", tbl.TableName, strings.Join(conditions, " and "))

	// result := <mustCall(db.DB.NamedExec(query, item))>
	execDelete := &ast.AssignStmt{
		Lhs: []ast.Expr{ast.NewIdent("result")},
		Tok: token.DEFINE,
		Rhs: []ast.Expr{mustCall(&ast.CallExpr{
			Fun: &ast.SelectorExpr{X: dbDB, Sel: ast.NewIdent("NamedExec")},
			Args: []ast.Expr{
				&ast.BasicLit{Kind: token.STRING, Value: query},
				ast.NewIdent(tbl.VarName),
			},
		})},
	}

	return &ast.FuncDecl{
		Recv: dbRecv,
		Name: ast.NewIdent("Delete" + tbl.GoTypeName),
		Type: &ast.FuncType{
			Params: &ast.FieldList{List: []*ast.Field{{
				Names: []*ast.Ident{ast.NewIdent(tbl.VarName)},
				Type:  ast.NewIdent(tbl.GoTypeName),
			}}},
		},
		Body: &ast.BlockStmt{List: []ast.Stmt{execDelete, PanicIfRowsAffected(tbl)}},
	}
}
// GenerateSQLFieldsConst produces an AST for the `const xyzSQLFields = ...` string: the table's
// columns in declaration order, comma-separated, as a raw string literal for select queries.
// Nullable foreign keys are selected as `ifnull(col, 0) col`, so a stored NULL reads back as 0.
func GenerateSQLFieldsConst(tbl schema.Table) *ast.GenDecl {
	var sb strings.Builder
	sb.WriteByte('`')
	for i, col := range tbl.Columns {
		if i > 0 {
			sb.WriteString(", ")
		}
		if col.IsNullableForeignKey() {
			fmt.Fprintf(&sb, "ifnull(%s, 0) %s", col.Name, col.Name)
		} else {
			sb.WriteString(col.Name)
		}
	}
	sb.WriteByte('`')

	spec := &ast.ValueSpec{
		Names:  []*ast.Ident{SQLFieldsConstIdent(tbl)},
		Values: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: sb.String()}},
	}
	return &ast.GenDecl{Tok: token.CONST, Specs: []ast.Spec{spec}}
}