//nolint:lll // This file has lots of long lines lol package modelgenerate import ( "fmt" "go/ast" "go/token" "strings" "github.com/jinzhu/inflection" "git.offline-twitter.com/offline-labs/gas-stack/pkg/schema" "git.offline-twitter.com/offline-labs/gas-stack/pkg/textutils" ) func GenerateIDType(table schema.Table) *ast.GenDecl { // e.g., `type FoodID int` return &ast.GenDecl{ Tok: token.TYPE, Specs: []ast.Spec{&ast.TypeSpec{Name: ast.NewIdent(table.TypeIDName), Type: ast.NewIdent("int")}}, } } func fkFieldName(col schema.Column) string { if col.IsNonCodeTableForeignKey() { return textutils.SnakeToCamel(strings.TrimSuffix(col.Name, "_id")) + "ID" } else { return textutils.SnakeToCamel(col.Name) } } // GenerateModelAST produces an AST for a struct type corresponding to the model. // TODO: generate the right field types here based on column types. func GenerateModelAST(table schema.Table) *ast.GenDecl { // Fields for the struct fields := []*ast.Field{} // Other fields (just strings for now) for _, col := range table.Columns { switch col.Name { case "rowid": fields = append(fields, &ast.Field{ Names: []*ast.Ident{ast.NewIdent("ID")}, Type: ast.NewIdent(table.TypeIDName), Tag: &ast.BasicLit{Kind: token.STRING, Value: "`db:\"rowid\" json:\"id\"`"}, }) default: if col.IsNonCodeTableForeignKey() { fields = append(fields, &ast.Field{ Names: []*ast.Ident{ast.NewIdent(fkFieldName(col))}, Type: ast.NewIdent(schema.TypenameFromTablename(col.ForeignKeyTargetTable) + "ID"), Tag: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("`db:\"%s\" json:\"%s\"`", col.Name, col.Name)}, }) } else { typeName := "string" switch col.Type { case "integer", "int": if strings.HasPrefix(col.Name, "is_") || strings.HasPrefix(col.Name, "has_") { typeName = "bool" } else if strings.HasSuffix(col.Name, "_at") { typeName = "Timestamp" } else { typeName = "int" } case "text": typeName = "string" case "real": typeName = "float32" case "blob": typeName = "[]byte" default: panic("Unrecognized sqlite column type: " + col.Type) } fields = append(fields, &ast.Field{ Names: []*ast.Ident{ast.NewIdent(textutils.SnakeToCamel(col.Name))}, Type: ast.NewIdent(typeName), Tag: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("`db:\"%s\" json:\"%s\"`", col.Name, col.Name)}, }) } } } return &ast.GenDecl{ Tok: token.TYPE, Specs: []ast.Spec{&ast.TypeSpec{ Name: ast.NewIdent(table.GoTypeName), Type: &ast.StructType{Fields: &ast.FieldList{List: fields}}, }}, } } // GenerateSaveItemFunc produces an AST for the SaveXyz() function of the model. // E.g., a table with `table.TypeName = "foods"` will produce a "SaveFood()" function. func GenerateSaveItemFunc(tbl schema.Table) *ast.FuncDecl { insertCols := make([]string, 0, len(tbl.Columns)) insertVals := make([]string, 0, len(tbl.Columns)) updatePairs := make([]string, 0, len(tbl.Columns)) for _, col := range tbl.Columns { if col.Name == "rowid" { continue } insertCols = append(insertCols, col.Name) val := ":" + col.Name if col.IsNullableForeignKey() { val = fmt.Sprintf("nullif(%s, 0)", val) } insertVals = append(insertVals, val) updatePairs = append(updatePairs, col.Name+"="+val) } hasFks := false checkForeignKeyFailuresAssignment := &ast.AssignStmt{ Lhs: []ast.Expr{ast.NewIdent("checkForeignKeyFailures")}, Tok: token.DEFINE, Rhs: []ast.Expr{ &ast.FuncLit{ Type: &ast.FuncType{ Params: &ast.FieldList{ List: []*ast.Field{{ Names: []*ast.Ident{ast.NewIdent("err")}, Type: ast.NewIdent("error"), }}, }, Results: &ast.FieldList{ List: []*ast.Field{{Type: ast.NewIdent("error")}}, }, }, Body: &ast.BlockStmt{ List: func() []ast.Stmt { ret := []ast.Stmt{} // if !isSqliteFkError(err) { return nil } ret = append(ret, &ast.IfStmt{ Cond: &ast.UnaryExpr{Op: token.NOT, X: &ast.CallExpr{Fun: ast.NewIdent("IsSqliteFkError"), Args: []ast.Expr{ast.NewIdent("err")}}}, Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ReturnStmt{Results: []ast.Expr{ast.NewIdent("nil")}}}}, }) for _, col := range tbl.Columns { if !col.IsForeignKey { // Check both "real" foreign keys and code table values continue } hasFks = true structFieldName := fkFieldName(col) structField := &ast.SelectorExpr{X: ast.NewIdent(tbl.VarName), Sel: ast.NewIdent(structFieldName)} if col.IsNonCodeTableForeignKey() { // Real foreign key; look up referent by ID to see if it exists ret = append(ret, &ast.IfStmt{ Init: &ast.AssignStmt{ Lhs: []ast.Expr{ast.NewIdent("_"), ast.NewIdent("err")}, Tok: token.DEFINE, Rhs: []ast.Expr{ &ast.CallExpr{ Fun: &ast.SelectorExpr{X: ast.NewIdent("db"), Sel: ast.NewIdent(getByIDFuncName(col.ForeignKeyTargetTable))}, Args: []ast.Expr{ structField, }, }, }, }, Cond: &ast.CallExpr{ Fun: &ast.SelectorExpr{X: ast.NewIdent("errors"), Sel: ast.NewIdent("Is")}, Args: []ast.Expr{ ast.NewIdent("err"), ast.NewIdent("ErrNotInDB"), }, }, Body: &ast.BlockStmt{ List: []ast.Stmt{ &ast.ReturnStmt{ Results: []ast.Expr{ &ast.CallExpr{ Fun: ast.NewIdent("NewForeignKeyError"), Args: []ast.Expr{ &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("%q", structFieldName)}, &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("%q", col.ForeignKeyTargetTable)}, structField, }, }, }, }, }, }, }) } else { // Code table value. Query the table to see if it exists ret = append(ret, &ast.IfStmt{ Init: &ast.AssignStmt{ Lhs: []ast.Expr{ ast.NewIdent("err"), }, Tok: token.ASSIGN, Rhs: []ast.Expr{ &ast.CallExpr{ Fun: &ast.SelectorExpr{X: &ast.SelectorExpr{X: ast.NewIdent("db"), Sel: ast.NewIdent("DB")}, Sel: ast.NewIdent("Get")}, Args: []ast.Expr{ &ast.CallExpr{Fun: ast.NewIdent("new"), Args: []ast.Expr{ast.NewIdent("int")}}, &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("`select 1 from %s where rowid = ?`", col.ForeignKeyTargetTable)}, structField, }, }, }, }, Cond: &ast.CallExpr{ Fun: &ast.SelectorExpr{X: ast.NewIdent("errors"), Sel: ast.NewIdent("Is")}, Args: []ast.Expr{ast.NewIdent("err"), &ast.SelectorExpr{X: ast.NewIdent("sql"), Sel: ast.NewIdent("ErrNoRows")}}, }, Body: &ast.BlockStmt{ List: []ast.Stmt{ &ast.ReturnStmt{ Results: []ast.Expr{ &ast.CallExpr{ Fun: ast.NewIdent("NewForeignKeyError"), Args: []ast.Expr{ &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("%q", structFieldName)}, ast.NewIdent(fmt.Sprintf("%q", col.ForeignKeyTargetTable)), structField, }, }, }, }, }, }, }) } } // final return nil ret = append(ret, &ast.ReturnStmt{Results: []ast.Expr{ast.NewIdent("nil")}}) return ret }(), }, }, }, } insertStmt := fmt.Sprintf("\n\t\t insert into %s (%s)\n\t\t values (%s)\n\t\t", tbl.TableName, strings.Join(insertCols, ", "), strings.Join(insertVals, ", ")) updateStmt := fmt.Sprintf("\n\t\t update %s\n\t\t set %s\n\t\t where rowid = :rowid\n\t\t", tbl.TableName, strings.Join(updatePairs, ",\n\t\t ")) funcBody := &ast.BlockStmt{ List: func() []ast.Stmt { ret := []ast.Stmt{} if hasFks { ret = append(ret, checkForeignKeyFailuresAssignment) } // if item.ID == 0 {...} else {...} ret = append(ret, &ast.IfStmt{ Cond: &ast.BinaryExpr{ X: &ast.SelectorExpr{X: ast.NewIdent(tbl.VarName), Sel: ast.NewIdent("ID")}, Op: token.EQL, Y: &ast.BasicLit{Kind: token.INT, Value: "0"}, }, Body: &ast.BlockStmt{ // Do create List: append( func() []ast.Stmt { namedExecStmt := &ast.CallExpr{ Fun: &ast.SelectorExpr{X: ast.NewIdent("db.DB"), Sel: ast.NewIdent("NamedExec")}, Args: []ast.Expr{ &ast.BasicLit{Kind: token.STRING, Value: "`" + insertStmt + "`"}, ast.NewIdent(tbl.VarName), }, } if !hasFks { // No foreign key checking needed; just use `Must` for brevity return []ast.Stmt{ &ast.AssignStmt{ Lhs: []ast.Expr{ast.NewIdent("result")}, Tok: token.DEFINE, Rhs: []ast.Expr{&ast.CallExpr{ Fun: ast.NewIdent("Must"), Args: []ast.Expr{namedExecStmt}, }}, }, } } return []ast.Stmt{ // result, err := db.DB.NamedExec(`...`, u) &ast.AssignStmt{ Lhs: []ast.Expr{ ast.NewIdent("result"), ast.NewIdent("err"), }, Tok: token.DEFINE, Rhs: []ast.Expr{namedExecStmt}, }, // if fkErr := checkForeignKeyFailures(err); fkErr != nil { return fkErr } else if err != nil { panic(err) } &ast.IfStmt{ Init: &ast.AssignStmt{ Lhs: []ast.Expr{ast.NewIdent("fkErr")}, Tok: token.DEFINE, Rhs: []ast.Expr{ &ast.CallExpr{ Fun: ast.NewIdent("checkForeignKeyFailures"), Args: []ast.Expr{ast.NewIdent("err")}, }, }, }, Cond: &ast.BinaryExpr{ X: ast.NewIdent("fkErr"), Op: token.NEQ, Y: ast.NewIdent("nil"), }, Body: &ast.BlockStmt{ List: []ast.Stmt{ &ast.ReturnStmt{ Results: []ast.Expr{ast.NewIdent("fkErr")}, }, }, }, Else: &ast.IfStmt{ Cond: &ast.BinaryExpr{ X: ast.NewIdent("err"), Op: token.NEQ, Y: ast.NewIdent("nil"), }, Body: &ast.BlockStmt{ List: []ast.Stmt{ &ast.ExprStmt{ X: &ast.CallExpr{ Fun: ast.NewIdent("panic"), Args: []ast.Expr{ast.NewIdent("err")}, }, }, }, }, }, }, } }(), &ast.AssignStmt{ Lhs: []ast.Expr{&ast.SelectorExpr{X: ast.NewIdent(tbl.VarName), Sel: ast.NewIdent("ID")}}, Tok: token.ASSIGN, Rhs: []ast.Expr{&ast.CallExpr{ Fun: ast.NewIdent(tbl.TypeIDName), Args: []ast.Expr{&ast.CallExpr{ Fun: ast.NewIdent("Must"), Args: []ast.Expr{&ast.CallExpr{ Fun: &ast.SelectorExpr{X: ast.NewIdent("result"), Sel: ast.NewIdent("LastInsertId")}, Args: []ast.Expr{}, }}, }}, }}, }, ), }, Else: &ast.BlockStmt{ // Do update List: []ast.Stmt{ &ast.AssignStmt{ Lhs: []ast.Expr{ast.NewIdent("result")}, Tok: token.DEFINE, Rhs: []ast.Expr{&ast.CallExpr{ Fun: ast.NewIdent("Must"), Args: []ast.Expr{&ast.CallExpr{ Fun: &ast.SelectorExpr{X: ast.NewIdent("db.DB"), Sel: ast.NewIdent("NamedExec")}, Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: "`" + updateStmt + "`"}, ast.NewIdent(tbl.VarName)}, }}, }}, }, &ast.IfStmt{ Cond: &ast.BinaryExpr{ X: &ast.CallExpr{ Fun: ast.NewIdent("Must"), Args: []ast.Expr{&ast.CallExpr{ Fun: &ast.SelectorExpr{X: ast.NewIdent("result"), Sel: ast.NewIdent("RowsAffected")}, Args: []ast.Expr{}, }}, }, Op: token.NEQ, Y: &ast.BasicLit{Kind: token.INT, Value: "1"}, }, Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ExprStmt{X: &ast.CallExpr{Fun: ast.NewIdent("panic"), Args: []ast.Expr{&ast.CallExpr{Fun: ast.NewIdent("fmt.Errorf"), Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"got %s with ID (%%d), so attempted update, but it doesn't exist\"", strings.ToLower(tbl.GoTypeName))}, &ast.SelectorExpr{X: ast.NewIdent(tbl.VarName), Sel: ast.NewIdent("ID")}}}}}}}}, }, }, }, }) if hasFks { // If there's foreign key checking, it needs to return an error (or nil) ret = append(ret, &ast.ReturnStmt{Results: []ast.Expr{ast.NewIdent("nil")}}) } return ret }(), } funcDecl := &ast.FuncDecl{ Recv: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent("db")}, Type: ast.NewIdent("DB")}}}, Name: ast.NewIdent("Save" + tbl.GoTypeName), Type: &ast.FuncType{ Params: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent(tbl.VarName)}, Type: &ast.StarExpr{X: ast.NewIdent(tbl.GoTypeName)}}}}, Results: func() *ast.FieldList { if hasFks { return &ast.FieldList{List: []*ast.Field{{Type: ast.NewIdent("error")}}} } return nil }(), }, Body: funcBody, } return funcDecl } func getByIDFuncName(tblname string) string { return "Get" + schema.TypenameFromTablename(tblname) + "ByID" } // GenerateGetItemByIDFunc produces an AST for the `GetXyzByID()` function. // E.g., a table with `table.TypeName = "foods"` will produce a "GetFoodByID()" function. func GenerateGetItemByIDFunc(tbl schema.Table) *ast.FuncDecl { recv := &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent("db")}, Type: ast.NewIdent("DB")}}} arg := &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent("id")}, Type: ast.NewIdent(tbl.TypeIDName)}}} result := &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent("ret")}, Type: ast.NewIdent(tbl.GoTypeName)}, {Names: []*ast.Ident{ast.NewIdent("err")}, Type: ast.NewIdent("error")}}} // Use the xyzSQLFields constant in the select query selectExpr := &ast.BinaryExpr{ X: &ast.BinaryExpr{ X: &ast.BasicLit{Kind: token.STRING, Value: "`\n\t select `"}, Op: token.ADD, Y: SQLFieldsConstIdent(tbl), }, Op: token.ADD, Y: &ast.BasicLit{Kind: token.STRING, Value: "`\n\t from " + tbl.TableName + "\n\t where rowid = ?\n\t`"}, } funcBody := &ast.BlockStmt{ List: []ast.Stmt{ &ast.AssignStmt{ Lhs: []ast.Expr{ast.NewIdent("err")}, Tok: token.ASSIGN, Rhs: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent("db.DB"), Sel: ast.NewIdent("Get")}, Args: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: ast.NewIdent("ret")}, selectExpr, ast.NewIdent("id")}}}, }, &ast.IfStmt{ Cond: &ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent("errors"), Sel: ast.NewIdent("Is")}, Args: []ast.Expr{ast.NewIdent("err"), &ast.SelectorExpr{X: ast.NewIdent("sql"), Sel: ast.NewIdent("ErrNoRows")}}}, Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ReturnStmt{Results: []ast.Expr{&ast.CompositeLit{Type: ast.NewIdent(tbl.GoTypeName)}, ast.NewIdent("ErrNotInDB")}}}}, }, &ast.ReturnStmt{}, }, } funcDecl := &ast.FuncDecl{ Recv: recv, Name: ast.NewIdent(getByIDFuncName(tbl.TableName)), Type: &ast.FuncType{Params: arg, Results: result}, Body: funcBody, } return funcDecl } // GenerateGetAllItemsFunc produces an AST for the `GetAllXyzs()` function. // E.g., a table with `table.TypeName = "foods"` will produce a "GetAllFoods()" function. func GenerateGetAllItemsFunc(tbl schema.Table) *ast.FuncDecl { funcName := "GetAll" + inflection.Plural(tbl.GoTypeName) recv := &ast.FieldList{List: []*ast.Field{ {Names: []*ast.Ident{ast.NewIdent("db")}, Type: ast.NewIdent("DB")}, }} result := &ast.FieldList{List: []*ast.Field{ {Names: []*ast.Ident{ast.NewIdent("ret")}, Type: &ast.ArrayType{Elt: ast.NewIdent(tbl.GoTypeName)}}, }} selectCall := &ast.CallExpr{ Fun: ast.NewIdent("PanicIf"), Args: []ast.Expr{ &ast.CallExpr{ Fun: &ast.SelectorExpr{ X: ast.NewIdent("db.DB"), Sel: ast.NewIdent("Select"), }, Args: []ast.Expr{ &ast.UnaryExpr{Op: token.AND, X: ast.NewIdent("ret")}, &ast.BinaryExpr{ X: &ast.BinaryExpr{ X: &ast.BasicLit{Kind: token.STRING, Value: "`select `"}, Op: token.ADD, Y: SQLFieldsConstIdent(tbl), }, Op: token.ADD, Y: &ast.BasicLit{Kind: token.STRING, Value: "` from " + tbl.TableName + "`"}, }, }, }, }, } funcBody := &ast.BlockStmt{ List: []ast.Stmt{ &ast.ExprStmt{X: selectCall}, &ast.ReturnStmt{}, }, } return &ast.FuncDecl{ Recv: recv, Name: ast.NewIdent(funcName), Type: &ast.FuncType{ Params: &ast.FieldList{}, Results: result, }, Body: funcBody, } } // GenerateDeleteItemFunc produces an AST for the `DeleteXyz()` function. // E.g., a table with `table.TypeName = "foods"` will produce a "DeleteFood()" function. func GenerateDeleteItemFunc(tbl schema.Table) *ast.FuncDecl { funcName := "Delete" + tbl.GoTypeName recv := &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent("db")}, Type: ast.NewIdent("DB")}}} arg := &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent(tbl.VarName)}, Type: ast.NewIdent(tbl.GoTypeName)}}} funcBody := &ast.BlockStmt{ List: []ast.Stmt{ &ast.AssignStmt{ Lhs: []ast.Expr{ast.NewIdent("result")}, Tok: token.DEFINE, Rhs: []ast.Expr{&ast.CallExpr{ Fun: ast.NewIdent("Must"), Args: []ast.Expr{&ast.CallExpr{ Fun: &ast.SelectorExpr{X: ast.NewIdent("db.DB"), Sel: ast.NewIdent("Exec")}, Args: []ast.Expr{ &ast.BasicLit{Kind: token.STRING, Value: "`delete from " + tbl.TableName + " where rowid = ?`"}, &ast.SelectorExpr{X: ast.NewIdent(tbl.VarName), Sel: ast.NewIdent("ID")}, }, }}, }}, }, &ast.IfStmt{ Cond: &ast.BinaryExpr{ X: &ast.CallExpr{ Fun: ast.NewIdent("Must"), Args: []ast.Expr{ &ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent("result"), Sel: ast.NewIdent("RowsAffected")}, Args: []ast.Expr{}}, }, }, Op: token.NEQ, Y: &ast.BasicLit{Kind: token.INT, Value: "1"}, }, Body: &ast.BlockStmt{List: []ast.Stmt{ &ast.ExprStmt{X: &ast.CallExpr{ Fun: ast.NewIdent("panic"), Args: []ast.Expr{&ast.CallExpr{ Fun: ast.NewIdent("fmt.Errorf"), Args: []ast.Expr{ &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"tried to delete %s with ID (%%d) but it doesn't exist\"", strings.ToLower(tbl.GoTypeName))}, &ast.SelectorExpr{X: ast.NewIdent(tbl.VarName), Sel: ast.NewIdent("ID")}, }, }}, }}, }}, }, }, } funcDecl := &ast.FuncDecl{ Recv: recv, Name: ast.NewIdent(funcName), Type: &ast.FuncType{Params: arg, Results: nil}, Body: funcBody, } return funcDecl } // GenerateSQLFieldsConst produces an AST for the `const xyzSQLFields = ...` string. func GenerateSQLFieldsConst(tbl schema.Table) *ast.GenDecl { columns := make([]string, 0, len(tbl.Columns)) for _, col := range tbl.Columns { if col.IsNullableForeignKey() { columns = append(columns, fmt.Sprintf("ifnull(%s, 0) %s", col.Name, col.Name)) } else { columns = append(columns, col.Name) } } // Join with comma and space value := "`" + strings.Join(columns, ", ") + "`" return &ast.GenDecl{ Tok: token.CONST, Specs: []ast.Spec{ &ast.ValueSpec{ Names: []*ast.Ident{SQLFieldsConstIdent(tbl)}, Values: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: value}}, }, }, } } // --------------- // Helpers // --------------- func SQLFieldsConstIdent(tbl schema.Table) *ast.Ident { return ast.NewIdent(strings.ToLower(tbl.GoTypeName) + "SQLFields") }