diff --git a/pkg/codegen/modelgenerate/generate_model.go b/pkg/codegen/modelgenerate/generate_model.go index 6f21a7b..fdfbb47 100644 --- a/pkg/codegen/modelgenerate/generate_model.go +++ b/pkg/codegen/modelgenerate/generate_model.go @@ -110,33 +110,227 @@ func GenerateSaveItemFunc(tbl schema.Table) *ast.FuncDecl { updatePairs = append(updatePairs, col.Name+"="+val) } + hasFks := false + checkForeignKeyFailuresAssignment := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent("checkForeignKeyFailures")}, + Tok: token.DEFINE, + Rhs: []ast.Expr{ + &ast.FuncLit{ + Type: &ast.FuncType{ + Params: &ast.FieldList{ + List: []*ast.Field{{ + Names: []*ast.Ident{ast.NewIdent("err")}, + Type: ast.NewIdent("error"), + }}, + }, + Results: &ast.FieldList{ + List: []*ast.Field{{Type: ast.NewIdent("error")}}, + }, + }, + Body: &ast.BlockStmt{ + List: func() []ast.Stmt { + ret := []ast.Stmt{} + // if !isSqliteFkError(err) { return nil } + ret = append(ret, &ast.IfStmt{ + Cond: &ast.UnaryExpr{Op: token.NOT, X: &ast.CallExpr{Fun: ast.NewIdent("IsSqliteFkError"), Args: []ast.Expr{ast.NewIdent("err")}}}, + Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ReturnStmt{Results: []ast.Expr{ast.NewIdent("nil")}}}}, + }) + + for _, col := range tbl.Columns { + if !col.IsForeignKey { // Check both "real" foreign keys and code table values + continue + } + hasFks = true + + structFieldName := fkFieldName(col) + structField := &ast.SelectorExpr{X: ast.NewIdent(tbl.VarName), Sel: ast.NewIdent(structFieldName)} + + if col.IsNonCodeTableForeignKey() { + // Real foreign key; look up referent by ID to see if it exists + ret = append(ret, &ast.IfStmt{ + Init: &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent("_"), ast.NewIdent("err")}, + Tok: token.DEFINE, + Rhs: []ast.Expr{ + &ast.CallExpr{ + Fun: &ast.SelectorExpr{X: ast.NewIdent("db"), Sel: ast.NewIdent(getByIDFuncName(col.ForeignKeyTargetTable))}, + Args: []ast.Expr{ + structField, + }, + }, + }, + }, + Cond: &ast.CallExpr{ + Fun: &ast.SelectorExpr{X: ast.NewIdent("errors"), Sel: ast.NewIdent("Is")}, + Args: []ast.Expr{ + ast.NewIdent("err"), + ast.NewIdent("ErrNotInDB"), + }, + }, + Body: &ast.BlockStmt{ + List: []ast.Stmt{ + &ast.ReturnStmt{ + Results: []ast.Expr{ + &ast.CallExpr{ + Fun: ast.NewIdent("NewForeignKeyError"), + Args: []ast.Expr{ + &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("%q", structFieldName)}, + &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("%q", col.ForeignKeyTargetTable)}, + structField, + }, + }, + }, + }, + }, + }, + }) + } else { + // Code table value. Query the table to see if it exists + ret = append(ret, &ast.IfStmt{ + Init: &ast.AssignStmt{ + Lhs: []ast.Expr{ + ast.NewIdent("err"), + }, + Tok: token.ASSIGN, + Rhs: []ast.Expr{ + &ast.CallExpr{ + Fun: &ast.SelectorExpr{X: &ast.SelectorExpr{X: ast.NewIdent("db"), Sel: ast.NewIdent("DB")}, Sel: ast.NewIdent("Get")}, + Args: []ast.Expr{ + &ast.CallExpr{Fun: ast.NewIdent("new"), Args: []ast.Expr{ast.NewIdent("int")}}, + &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("`select 1 from %s where rowid = ?`", col.ForeignKeyTargetTable)}, + structField, + }, + }, + }, + }, + Cond: &ast.CallExpr{ + Fun: &ast.SelectorExpr{X: ast.NewIdent("errors"), Sel: ast.NewIdent("Is")}, + Args: []ast.Expr{ast.NewIdent("err"), &ast.SelectorExpr{X: ast.NewIdent("sql"), Sel: ast.NewIdent("ErrNoRows")}}, + }, + Body: &ast.BlockStmt{ + List: []ast.Stmt{ + &ast.ReturnStmt{ + Results: []ast.Expr{ + &ast.CallExpr{ + Fun: ast.NewIdent("NewForeignKeyError"), + Args: []ast.Expr{ + ast.NewIdent(`"Type"`), + ast.NewIdent(fmt.Sprintf("%q", col.ForeignKeyTargetTable)), + structField, + }, + }, + }, + }, + }, + }, + }) + } + } + // final return nil + ret = append(ret, &ast.ReturnStmt{Results: []ast.Expr{ast.NewIdent("nil")}}) + return ret + }(), + }, + }, + }, + } + insertStmt := fmt.Sprintf("\n\t\t insert into %s (%s)\n\t\t values (%s)\n\t\t", tbl.TableName, strings.Join(insertCols, ", "), strings.Join(insertVals, ", ")) updateStmt := fmt.Sprintf("\n\t\t update %s\n\t\t set %s\n\t\t where rowid = :rowid\n\t\t", tbl.TableName, strings.Join(updatePairs, ",\n\t\t ")) funcBody := &ast.BlockStmt{ - List: []ast.Stmt{ - &ast.IfStmt{ + List: func() []ast.Stmt { + ret := []ast.Stmt{} + if hasFks { + ret = append(ret, checkForeignKeyFailuresAssignment) + } + // if item.ID == 0 {...} else {...} + ret = append(ret, &ast.IfStmt{ Cond: &ast.BinaryExpr{ X: &ast.SelectorExpr{X: ast.NewIdent(tbl.VarName), Sel: ast.NewIdent("ID")}, Op: token.EQL, Y: &ast.BasicLit{Kind: token.INT, Value: "0"}, }, Body: &ast.BlockStmt{ - List: []ast.Stmt{ - &ast.AssignStmt{ - Lhs: []ast.Expr{ast.NewIdent("result")}, - Tok: token.DEFINE, - Rhs: []ast.Expr{&ast.CallExpr{ - Fun: ast.NewIdent("Must"), - Args: []ast.Expr{&ast.CallExpr{ - Fun: &ast.SelectorExpr{X: ast.NewIdent("db.DB"), Sel: ast.NewIdent("NamedExec")}, - Args: []ast.Expr{ - &ast.BasicLit{Kind: token.STRING, Value: "`" + insertStmt + "`"}, - ast.NewIdent(tbl.VarName), + // Do create + List: append( + func() []ast.Stmt { + namedExecStmt := &ast.CallExpr{ + Fun: &ast.SelectorExpr{X: ast.NewIdent("db.DB"), Sel: ast.NewIdent("NamedExec")}, + Args: []ast.Expr{ + &ast.BasicLit{Kind: token.STRING, Value: "`" + insertStmt + "`"}, + ast.NewIdent(tbl.VarName), + }, + } + if !hasFks { + // No foreign key checking needed; just use `Must` for brevity + return []ast.Stmt{ + &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent("result")}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: ast.NewIdent("Must"), + Args: []ast.Expr{namedExecStmt}, + }}, }, - }}, - }}, - }, + } + } + + return []ast.Stmt{ + // result, err := db.DB.NamedExec(`...`, u) + &ast.AssignStmt{ + Lhs: []ast.Expr{ + ast.NewIdent("result"), + ast.NewIdent("err"), + }, + Tok: token.DEFINE, + Rhs: []ast.Expr{namedExecStmt}, + }, + + // if fkErr := checkForeignKeyFailures(err); fkErr != nil { return fkErr } else if err != nil { panic(err) } + &ast.IfStmt{ + Init: &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent("fkErr")}, + Tok: token.DEFINE, + Rhs: []ast.Expr{ + &ast.CallExpr{ + Fun: ast.NewIdent("checkForeignKeyFailures"), + Args: []ast.Expr{ast.NewIdent("err")}, + }, + }, + }, + Cond: &ast.BinaryExpr{ + X: ast.NewIdent("fkErr"), + Op: token.NEQ, + Y: ast.NewIdent("nil"), + }, + Body: &ast.BlockStmt{ + List: []ast.Stmt{ + &ast.ReturnStmt{ + Results: []ast.Expr{ast.NewIdent("fkErr")}, + }, + }, + }, + Else: &ast.IfStmt{ + Cond: &ast.BinaryExpr{ + X: ast.NewIdent("err"), + Op: token.NEQ, + Y: ast.NewIdent("nil"), + }, + Body: &ast.BlockStmt{ + List: []ast.Stmt{ + &ast.ExprStmt{ + X: &ast.CallExpr{ + Fun: ast.NewIdent("panic"), + Args: []ast.Expr{ast.NewIdent("err")}, + }, + }, + }, + }, + }, + }, + } + }(), &ast.AssignStmt{ Lhs: []ast.Expr{&ast.SelectorExpr{X: ast.NewIdent(tbl.VarName), Sel: ast.NewIdent("ID")}}, Tok: token.ASSIGN, @@ -151,9 +345,10 @@ func GenerateSaveItemFunc(tbl schema.Table) *ast.FuncDecl { }}, }}, }, - }, + ), }, Else: &ast.BlockStmt{ + // Do update List: []ast.Stmt{ &ast.AssignStmt{ Lhs: []ast.Expr{ast.NewIdent("result")}, @@ -183,16 +378,26 @@ func GenerateSaveItemFunc(tbl schema.Table) *ast.FuncDecl { }, }, }, - }, - }, + }) + if hasFks { + // If there's foreign key checking, it needs to return an error (or nil) + ret = append(ret, &ast.ReturnStmt{Results: []ast.Expr{ast.NewIdent("nil")}}) + } + return ret + }(), } funcDecl := &ast.FuncDecl{ Recv: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent("db")}, Type: ast.NewIdent("DB")}}}, Name: ast.NewIdent("Save" + tbl.GoTypeName), Type: &ast.FuncType{ - Params: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent(tbl.VarName)}, Type: &ast.StarExpr{X: ast.NewIdent(tbl.GoTypeName)}}}}, - Results: nil, + Params: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent(tbl.VarName)}, Type: &ast.StarExpr{X: ast.NewIdent(tbl.GoTypeName)}}}}, + Results: func() *ast.FieldList { + if hasFks { + return &ast.FieldList{List: []*ast.Field{{Type: ast.NewIdent("error")}}} + } + return nil + }(), }, Body: funcBody, }