diff --git a/pkg/codegen/modelgenerate/ast_helpers.go b/pkg/codegen/modelgenerate/ast_helpers.go new file mode 100644 index 0000000..a7ad809 --- /dev/null +++ b/pkg/codegen/modelgenerate/ast_helpers.go @@ -0,0 +1,11 @@ +package modelgenerate + +import "go/ast" + +// mustCall wraps a call expression in Must(...), producing AST for Must(inner). +func mustCall(inner ast.Expr) *ast.CallExpr { + return &ast.CallExpr{ + Fun: ast.NewIdent("Must"), + Args: []ast.Expr{inner}, + } +} diff --git a/pkg/codegen/modelgenerate/generate_model.go b/pkg/codegen/modelgenerate/generate_model.go index 72b12a6..e837a87 100644 --- a/pkg/codegen/modelgenerate/generate_model.go +++ b/pkg/codegen/modelgenerate/generate_model.go @@ -13,12 +13,18 @@ import ( "git.offline-twitter.com/offline-labs/gas-stack/pkg/textutils" ) -func GenerateIDType(table schema.Table) *ast.GenDecl { - // e.g., `type FoodID int` - return &ast.GenDecl{ - Tok: token.TYPE, - Specs: []ast.Spec{&ast.TypeSpec{Name: ast.NewIdent(table.TypeIDName), Type: ast.NewIdent("int")}}, - } +// --------------- +// Helpers +// --------------- + +var ( + dbRecv = &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent("db")}, Type: ast.NewIdent("DB")}}} + dbDB = &ast.SelectorExpr{X: ast.NewIdent("db"), Sel: ast.NewIdent("DB")} + fmtErrorf = &ast.SelectorExpr{X: ast.NewIdent("fmt"), Sel: ast.NewIdent("Errorf")} +) + +func SQLFieldsConstIdent(tbl schema.Table) *ast.Ident { + return ast.NewIdent(strings.ToLower(tbl.GoTypeName) + "SQLFields") } func fkFieldName(col schema.Column) string { @@ -29,6 +35,19 @@ func fkFieldName(col schema.Column) string { } } +// --------------- +// Generators +// --------------- + +// GenerateIDType produces an AST for the model's ID field. +func GenerateIDType(table schema.Table) *ast.GenDecl { + // e.g., `type FoodID int` + return &ast.GenDecl{ + Tok: token.TYPE, + Specs: []ast.Spec{&ast.TypeSpec{Name: ast.NewIdent(table.TypeIDName), Type: ast.NewIdent("int")}}, + } +} + // GenerateModelAST produces an AST for a struct type corresponding to the model. // TODO: generate the right field types here based on column types. func GenerateModelAST(table schema.Table) *ast.GenDecl { @@ -203,7 +222,7 @@ func GenerateSaveItemFunc(tbl schema.Table) *ast.FuncDecl { Tok: token.ASSIGN, Rhs: []ast.Expr{ &ast.CallExpr{ - Fun: &ast.SelectorExpr{X: &ast.SelectorExpr{X: ast.NewIdent("db"), Sel: ast.NewIdent("DB")}, Sel: ast.NewIdent("Get")}, + Fun: &ast.SelectorExpr{X: dbDB, Sel: ast.NewIdent("Get")}, Args: []ast.Expr{ &ast.CallExpr{Fun: ast.NewIdent("new"), Args: []ast.Expr{ast.NewIdent("int")}}, &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("`select 1 from %s where rowid = ?`", col.ForeignKeyTargetTable)}, @@ -269,17 +288,17 @@ func GenerateSaveItemFunc(tbl schema.Table) *ast.FuncDecl { // Do create List: append( func() []ast.Stmt { - ret := []ast.Stmt{} + ret1 := []ast.Stmt{} if hasCreatedAt { // Auto-timestamps: created_at - ret = append(ret, &ast.AssignStmt{ + ret1 = append(ret1, &ast.AssignStmt{ Lhs: []ast.Expr{&ast.SelectorExpr{X: ast.NewIdent(tbl.VarName), Sel: ast.NewIdent("CreatedAt")}}, Tok: token.ASSIGN, Rhs: []ast.Expr{&ast.CallExpr{Fun: ast.NewIdent("TimestampNow"), Args: []ast.Expr{}}}, }) } namedExecStmt := &ast.CallExpr{ - Fun: &ast.SelectorExpr{X: ast.NewIdent("db.DB"), Sel: ast.NewIdent("NamedExec")}, + Fun: &ast.SelectorExpr{X: dbDB, Sel: ast.NewIdent("NamedExec")}, Args: []ast.Expr{ &ast.BasicLit{Kind: token.STRING, Value: "`" + insertStmt + "`"}, ast.NewIdent(tbl.VarName), @@ -287,17 +306,14 @@ func GenerateSaveItemFunc(tbl schema.Table) *ast.FuncDecl { } if !hasFks { // No foreign key checking needed; just use `Must` for brevity - return append(ret, &ast.AssignStmt{ + return append(ret1, &ast.AssignStmt{ Lhs: []ast.Expr{ast.NewIdent("result")}, Tok: token.DEFINE, - Rhs: []ast.Expr{&ast.CallExpr{ - Fun: ast.NewIdent("Must"), - Args: []ast.Expr{namedExecStmt}, - }}, + Rhs: []ast.Expr{mustCall(namedExecStmt)}, }) } - return append(ret, + return append(ret1, // result, err := db.DB.NamedExec(`...`, u) &ast.AssignStmt{ Lhs: []ast.Expr{ @@ -357,13 +373,10 @@ func GenerateSaveItemFunc(tbl schema.Table) *ast.FuncDecl { Tok: token.ASSIGN, Rhs: []ast.Expr{&ast.CallExpr{ Fun: ast.NewIdent(tbl.TypeIDName), - Args: []ast.Expr{&ast.CallExpr{ - Fun: ast.NewIdent("Must"), - Args: []ast.Expr{&ast.CallExpr{ - Fun: &ast.SelectorExpr{X: ast.NewIdent("result"), Sel: ast.NewIdent("LastInsertId")}, - Args: []ast.Expr{}, - }}, - }}, + Args: []ast.Expr{mustCall(&ast.CallExpr{ + Fun: &ast.SelectorExpr{X: ast.NewIdent("result"), Sel: ast.NewIdent("LastInsertId")}, + Args: []ast.Expr{}, + })}, }}, }, ), @@ -374,28 +387,22 @@ func GenerateSaveItemFunc(tbl schema.Table) *ast.FuncDecl { &ast.AssignStmt{ Lhs: []ast.Expr{ast.NewIdent("result")}, Tok: token.DEFINE, - Rhs: []ast.Expr{&ast.CallExpr{ - Fun: ast.NewIdent("Must"), - Args: []ast.Expr{&ast.CallExpr{ - Fun: &ast.SelectorExpr{X: ast.NewIdent("db.DB"), Sel: ast.NewIdent("NamedExec")}, - Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: "`" + updateStmt + "`"}, ast.NewIdent(tbl.VarName)}, - }}, - }}, + Rhs: []ast.Expr{mustCall(&ast.CallExpr{ + Fun: &ast.SelectorExpr{X: dbDB, Sel: ast.NewIdent("NamedExec")}, + Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: "`" + updateStmt + "`"}, ast.NewIdent(tbl.VarName)}, + })}, }, &ast.IfStmt{ Cond: &ast.BinaryExpr{ - X: &ast.CallExpr{ - Fun: ast.NewIdent("Must"), - Args: []ast.Expr{&ast.CallExpr{ - Fun: &ast.SelectorExpr{X: ast.NewIdent("result"), Sel: ast.NewIdent("RowsAffected")}, - Args: []ast.Expr{}, - }}, - }, + X: mustCall(&ast.CallExpr{ + Fun: &ast.SelectorExpr{X: ast.NewIdent("result"), Sel: ast.NewIdent("RowsAffected")}, + Args: []ast.Expr{}, + }), Op: token.NEQ, Y: &ast.BasicLit{Kind: token.INT, Value: "1"}, }, - Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ExprStmt{X: &ast.CallExpr{Fun: ast.NewIdent("panic"), Args: []ast.Expr{&ast.CallExpr{Fun: ast.NewIdent("fmt.Errorf"), Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"got %s with ID (%%d), so attempted update, but it doesn't exist\"", strings.ToLower(tbl.GoTypeName))}, &ast.SelectorExpr{X: ast.NewIdent(tbl.VarName), Sel: ast.NewIdent("ID")}}}}}}}}, + Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ExprStmt{X: &ast.CallExpr{Fun: ast.NewIdent("panic"), Args: []ast.Expr{&ast.CallExpr{Fun: fmtErrorf, Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"got %s with ID (%%d), so attempted update, but it doesn't exist\"", strings.ToLower(tbl.GoTypeName))}, &ast.SelectorExpr{X: ast.NewIdent(tbl.VarName), Sel: ast.NewIdent("ID")}}}}}}}}, }, }, }, @@ -409,7 +416,7 @@ func GenerateSaveItemFunc(tbl schema.Table) *ast.FuncDecl { } funcDecl := &ast.FuncDecl{ - Recv: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent("db")}, Type: ast.NewIdent("DB")}}}, + Recv: dbRecv, Name: ast.NewIdent("Save" + tbl.GoTypeName), Type: &ast.FuncType{ Params: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent(tbl.VarName)}, Type: &ast.StarExpr{X: ast.NewIdent(tbl.GoTypeName)}}}}, @@ -432,7 +439,6 @@ func getByIDFuncName(tblname string) string { // GenerateGetItemByIDFunc produces an AST for the `GetXyzByID()` function. // E.g., a table with `table.TypeName = "foods"` will produce a "GetFoodByID()" function. func GenerateGetItemByIDFunc(tbl schema.Table) *ast.FuncDecl { - recv := &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent("db")}, Type: ast.NewIdent("DB")}}} arg := &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent("id")}, Type: ast.NewIdent(tbl.TypeIDName)}}} result := &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent("ret")}, Type: ast.NewIdent(tbl.GoTypeName)}, {Names: []*ast.Ident{ast.NewIdent("err")}, Type: ast.NewIdent("error")}}} @@ -452,7 +458,7 @@ func GenerateGetItemByIDFunc(tbl schema.Table) *ast.FuncDecl { &ast.AssignStmt{ Lhs: []ast.Expr{ast.NewIdent("err")}, Tok: token.ASSIGN, - Rhs: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent("db.DB"), Sel: ast.NewIdent("Get")}, Args: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: ast.NewIdent("ret")}, selectExpr, ast.NewIdent("id")}}}, + Rhs: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: dbDB, Sel: ast.NewIdent("Get")}, Args: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: ast.NewIdent("ret")}, selectExpr, ast.NewIdent("id")}}}, }, &ast.IfStmt{ Cond: &ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent("errors"), Sel: ast.NewIdent("Is")}, Args: []ast.Expr{ast.NewIdent("err"), &ast.SelectorExpr{X: ast.NewIdent("sql"), Sel: ast.NewIdent("ErrNoRows")}}}, @@ -463,7 +469,7 @@ func GenerateGetItemByIDFunc(tbl schema.Table) *ast.FuncDecl { } funcDecl := &ast.FuncDecl{ - Recv: recv, + Recv: dbRecv, Name: ast.NewIdent(getByIDFuncName(tbl.TableName)), Type: &ast.FuncType{Params: arg, Results: result}, Body: funcBody, @@ -475,9 +481,6 @@ func GenerateGetItemByIDFunc(tbl schema.Table) *ast.FuncDecl { // E.g., a table with `table.TypeName = "foods"` will produce a "GetAllFoods()" function. func GenerateGetAllItemsFunc(tbl schema.Table) *ast.FuncDecl { funcName := "GetAll" + inflection.Plural(tbl.GoTypeName) - recv := &ast.FieldList{List: []*ast.Field{ - {Names: []*ast.Ident{ast.NewIdent("db")}, Type: ast.NewIdent("DB")}, - }} result := &ast.FieldList{List: []*ast.Field{ {Names: []*ast.Ident{ast.NewIdent("ret")}, Type: &ast.ArrayType{Elt: ast.NewIdent(tbl.GoTypeName)}}, }} @@ -487,7 +490,7 @@ func GenerateGetAllItemsFunc(tbl schema.Table) *ast.FuncDecl { Args: []ast.Expr{ &ast.CallExpr{ Fun: &ast.SelectorExpr{ - X: ast.NewIdent("db.DB"), + X: dbDB, Sel: ast.NewIdent("Select"), }, Args: []ast.Expr{ @@ -514,7 +517,7 @@ func GenerateGetAllItemsFunc(tbl schema.Table) *ast.FuncDecl { } return &ast.FuncDecl{ - Recv: recv, + Recv: dbRecv, Name: ast.NewIdent(funcName), Type: &ast.FuncType{ Params: &ast.FieldList{}, @@ -528,7 +531,6 @@ func GenerateGetAllItemsFunc(tbl schema.Table) *ast.FuncDecl { // E.g., a table with `table.TypeName = "foods"` will produce a "DeleteFood()" function. func GenerateDeleteItemFunc(tbl schema.Table) *ast.FuncDecl { funcName := "Delete" + tbl.GoTypeName - recv := &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent("db")}, Type: ast.NewIdent("DB")}}} arg := &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent(tbl.VarName)}, Type: ast.NewIdent(tbl.GoTypeName)}}} funcBody := &ast.BlockStmt{ @@ -536,25 +538,19 @@ func GenerateDeleteItemFunc(tbl schema.Table) *ast.FuncDecl { &ast.AssignStmt{ Lhs: []ast.Expr{ast.NewIdent("result")}, Tok: token.DEFINE, - Rhs: []ast.Expr{&ast.CallExpr{ - Fun: ast.NewIdent("Must"), - Args: []ast.Expr{&ast.CallExpr{ - Fun: &ast.SelectorExpr{X: ast.NewIdent("db.DB"), Sel: ast.NewIdent("Exec")}, - Args: []ast.Expr{ - &ast.BasicLit{Kind: token.STRING, Value: "`delete from " + tbl.TableName + " where rowid = ?`"}, - &ast.SelectorExpr{X: ast.NewIdent(tbl.VarName), Sel: ast.NewIdent("ID")}, - }, - }}, - }}, + Rhs: []ast.Expr{mustCall(&ast.CallExpr{ + Fun: &ast.SelectorExpr{X: dbDB, Sel: ast.NewIdent("Exec")}, + Args: []ast.Expr{ + &ast.BasicLit{Kind: token.STRING, Value: "`delete from " + tbl.TableName + " where rowid = ?`"}, + &ast.SelectorExpr{X: ast.NewIdent(tbl.VarName), Sel: ast.NewIdent("ID")}, + }, + })}, }, &ast.IfStmt{ Cond: &ast.BinaryExpr{ - X: &ast.CallExpr{ - Fun: ast.NewIdent("Must"), - Args: []ast.Expr{ - &ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent("result"), Sel: ast.NewIdent("RowsAffected")}, Args: []ast.Expr{}}, - }, - }, + X: mustCall( + &ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent("result"), Sel: ast.NewIdent("RowsAffected")}, Args: []ast.Expr{}}, + ), Op: token.NEQ, Y: &ast.BasicLit{Kind: token.INT, Value: "1"}, }, @@ -562,7 +558,7 @@ func GenerateDeleteItemFunc(tbl schema.Table) *ast.FuncDecl { &ast.ExprStmt{X: &ast.CallExpr{ Fun: ast.NewIdent("panic"), Args: []ast.Expr{&ast.CallExpr{ - Fun: ast.NewIdent("fmt.Errorf"), + Fun: fmtErrorf, Args: []ast.Expr{ &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"tried to delete %s with ID (%%d) but it doesn't exist\"", strings.ToLower(tbl.GoTypeName))}, &ast.SelectorExpr{X: ast.NewIdent(tbl.VarName), Sel: ast.NewIdent("ID")}, @@ -575,7 +571,7 @@ func GenerateDeleteItemFunc(tbl schema.Table) *ast.FuncDecl { } funcDecl := &ast.FuncDecl{ - Recv: recv, + Recv: dbRecv, Name: ast.NewIdent(funcName), Type: &ast.FuncType{Params: arg, Results: nil}, Body: funcBody, @@ -606,11 +602,3 @@ func GenerateSQLFieldsConst(tbl schema.Table) *ast.GenDecl { }, } } - -// --------------- -// Helpers -// --------------- - -func SQLFieldsConstIdent(tbl schema.Table) *ast.Ident { - return ast.NewIdent(strings.ToLower(tbl.GoTypeName) + "SQLFields") -} diff --git a/pkg/codegen/modelgenerate/generate_testfile.go b/pkg/codegen/modelgenerate/generate_testfile.go index a003aba..a42e458 100644 --- a/pkg/codegen/modelgenerate/generate_testfile.go +++ b/pkg/codegen/modelgenerate/generate_testfile.go @@ -61,19 +61,16 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File { &ast.AssignStmt{ Lhs: []ast.Expr{ast.NewIdent("db")}, Tok: token.DEFINE, - Rhs: []ast.Expr{&ast.CallExpr{ - Fun: ast.NewIdent("Must"), + Rhs: []ast.Expr{mustCall(&ast.CallExpr{ + Fun: ast.NewIdent("Create"), Args: []ast.Expr{&ast.CallExpr{ - Fun: ast.NewIdent("Create"), - Args: []ast.Expr{&ast.CallExpr{ - Fun: ast.NewIdent("fmt.Sprintf"), - Args: []ast.Expr{ - &ast.BasicLit{Kind: token.STRING, Value: `"file:%s?mode=memory&cache=shared"`}, - ast.NewIdent("dbName"), - }, - }}, + Fun: ast.NewIdent("fmt.Sprintf"), + Args: []ast.Expr{ + &ast.BasicLit{Kind: token.STRING, Value: `"file:%s?mode=memory&cache=shared"`}, + ast.NewIdent("dbName"), + }, }}, - }}, + })}, }, // return db &ast.ReturnStmt{ @@ -176,13 +173,10 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File { &ast.AssignStmt{ Lhs: []ast.Expr{testObj2}, Tok: token.DEFINE, - Rhs: []ast.Expr{&ast.CallExpr{ - Fun: ast.NewIdent("Must"), - Args: []ast.Expr{&ast.CallExpr{ - Fun: ast.NewIdent("TestDB.Get" + tbl.GoTypeName + "ByID"), - Args: []ast.Expr{&ast.SelectorExpr{X: testObj, Sel: ast.NewIdent("ID")}}, - }}, - }}, + Rhs: []ast.Expr{mustCall(&ast.CallExpr{ + Fun: ast.NewIdent("TestDB.Get" + tbl.GoTypeName + "ByID"), + Args: []ast.Expr{&ast.SelectorExpr{X: testObj, Sel: ast.NewIdent("ID")}}, + })}, }, // assert.Equal(t, "an item", item2.Description) @@ -222,13 +216,10 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File { &ast.AssignStmt{ Lhs: []ast.Expr{testObj2}, Tok: token.ASSIGN, - Rhs: []ast.Expr{&ast.CallExpr{ - Fun: ast.NewIdent("Must"), - Args: []ast.Expr{&ast.CallExpr{ - Fun: ast.NewIdent("TestDB.Get" + tbl.GoTypeName + "ByID"), - Args: []ast.Expr{&ast.SelectorExpr{X: testObj, Sel: ast.NewIdent("ID")}}, - }}, - }}, + Rhs: []ast.Expr{mustCall(&ast.CallExpr{ + Fun: ast.NewIdent("TestDB.Get" + tbl.GoTypeName + "ByID"), + Args: []ast.Expr{&ast.SelectorExpr{X: testObj, Sel: ast.NewIdent("ID")}}, + })}, }, // assert.Equal(t, item.Description, item2.Description)