package modelgenerate import ( "fmt" "go/ast" "go/token" "github.com/jinzhu/inflection" pkgschema "git.offline-twitter.com/offline-labs/gas-stack/pkg/schema" ) // GenerateModelTestAST produces an AST for a starter test file for a given model. func GenerateModelTestAST(tbl pkgschema.Table, schema pkgschema.Schema, gomodName string) *ast.File { packageName := "db" testpackageName := packageName + "_test" // func MakeItem() Item { return Item{} } makeItemFunc := &ast.FuncDecl{ Name: ast.NewIdent("Make" + tbl.GoTypeName), Type: &ast.FuncType{ Params: &ast.FieldList{}, Results: &ast.FieldList{ List: []*ast.Field{ {Type: ast.NewIdent(tbl.GoTypeName)}, }, }, }, Body: &ast.BlockStmt{ List: []ast.Stmt{ &ast.ReturnStmt{ Results: []ast.Expr{ &ast.CompositeLit{ Type: ast.NewIdent(tbl.GoTypeName), }, }, }, }, }, } testObj := ast.NewIdent("item") testObj2 := ast.NewIdent("item2") fieldName := ast.NewIdent("Description") description1 := `"an item"` description2 := `"a big item"` testDB := ast.NewIdent("TestDB") hasCreatedAt, hasUpdatedAt := tbl.HasAutoTimestamps() testFuncType := &ast.FuncType{ Params: &ast.FieldList{ List: []*ast.Field{{ Names: []*ast.Ident{ast.NewIdent("t")}, Type: &ast.StarExpr{X: &ast.SelectorExpr{X: ast.NewIdent("testing"), Sel: ast.NewIdent("T")}}, }}, }, } testCreateUpdateDelete := &ast.FuncDecl{ Name: ast.NewIdent("TestCreateUpdateDelete" + tbl.GoTypeName), Type: testFuncType, Body: &ast.BlockStmt{ List: func() []ast.Stmt { assertNotZero := func(obj *ast.Ident, field string) *ast.ExprStmt { return &ast.ExprStmt{X: &ast.CallExpr{ Fun: &ast.SelectorExpr{X: ast.NewIdent("assert"), Sel: ast.NewIdent("NotZero")}, Args: []ast.Expr{ast.NewIdent("t"), &ast.SelectorExpr{X: obj, Sel: ast.NewIdent(field)}}, }} } stmts := []ast.Stmt{ Comment("Create"), // item := Item{Description: "an item"} &ast.AssignStmt{ Lhs: []ast.Expr{testObj}, Tok: token.DEFINE, Rhs: []ast.Expr{&ast.CompositeLit{ Type: ast.NewIdent(tbl.GoTypeName), Elts: []ast.Expr{ &ast.KeyValueExpr{ Key: fieldName, Value: &ast.BasicLit{Kind: token.STRING, Value: description1}, }, }, }}, }, // TestDB.SaveItem(&item) &ast.ExprStmt{X: &ast.CallExpr{ Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Save" + tbl.GoTypeName)}, Args: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: testObj}}, }}, // require.NotZero(t, item.ID) &ast.ExprStmt{X: &ast.CallExpr{ Fun: &ast.SelectorExpr{X: ast.NewIdent("require"), Sel: ast.NewIdent("NotZero")}, Args: []ast.Expr{ast.NewIdent("t"), &ast.SelectorExpr{X: testObj, Sel: ast.NewIdent("ID")}}, }}, } // After create: assert timestamps are set if hasCreatedAt { stmts = append(stmts, assertNotZero(testObj, "CreatedAt")) } if hasUpdatedAt { stmts = append(stmts, assertNotZero(testObj, "UpdatedAt")) } stmts = append(stmts, BlankLine(), Comment("Load"), // item2 := Must(TestDB.GetItemByID(item.ID)) &ast.AssignStmt{ Lhs: []ast.Expr{testObj2}, Tok: token.DEFINE, Rhs: []ast.Expr{mustCall(&ast.CallExpr{ Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Get" + tbl.GoTypeName + "ByID")}, Args: []ast.Expr{&ast.SelectorExpr{X: testObj, Sel: ast.NewIdent("ID")}}, })}, }, // assert.Equal(t, "an item", item2.Description) &ast.ExprStmt{X: &ast.CallExpr{ Fun: &ast.SelectorExpr{X: ast.NewIdent("assert"), Sel: ast.NewIdent("Equal")}, Args: []ast.Expr{ ast.NewIdent("t"), &ast.BasicLit{Kind: token.STRING, Value: description1}, &ast.SelectorExpr{X: testObj2, Sel: fieldName}, }, }}, ) stmts = append(stmts, BlankLine(), Comment("Update"), // item.Description = "a big item" &ast.AssignStmt{ Lhs: []ast.Expr{&ast.SelectorExpr{X: testObj, Sel: fieldName}}, Tok: token.ASSIGN, Rhs: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: description2}}, }, // TestDB.SaveItem(&item) &ast.ExprStmt{X: &ast.CallExpr{ Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Save" + tbl.GoTypeName)}, Args: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: testObj}}, }}, // item2 = Must(TestDB.GetItemByID(item.ID)) &ast.AssignStmt{ Lhs: []ast.Expr{testObj2}, Tok: token.ASSIGN, Rhs: []ast.Expr{mustCall(&ast.CallExpr{ Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Get" + tbl.GoTypeName + "ByID")}, Args: []ast.Expr{&ast.SelectorExpr{X: testObj, Sel: ast.NewIdent("ID")}}, })}, }, // assert.Equal(t, item.Description, item2.Description) &ast.ExprStmt{X: &ast.CallExpr{ Fun: &ast.SelectorExpr{X: ast.NewIdent("assert"), Sel: ast.NewIdent("Equal")}, Args: []ast.Expr{ ast.NewIdent("t"), &ast.SelectorExpr{X: testObj, Sel: fieldName}, &ast.SelectorExpr{X: testObj2, Sel: fieldName}, }, }}, ) indexGets, hasIndexedGets := []ast.Stmt{ BlankLine(), Comment("Indexed lookups"), }, false for _, index := range schema.Indexes { if index.TableName != tbl.TableName { // Skip indexes on other tables continue } if index.IsUnique && len(index.Columns) == 1 { col := tbl.GetColumnByName(index.Columns[0]) indexGets = append(indexGets, []ast.Stmt{ // assert.Equal(t, item2, TestDB.GetItemByXYZ(...)) &ast.ExprStmt{X: &ast.CallExpr{ // TODO: what if just delete the "ExprStmt" wrapper? Fun: &ast.SelectorExpr{X: ast.NewIdent("assert"), Sel: ast.NewIdent("Equal")}, Args: []ast.Expr{ ast.NewIdent("t"), testObj2, mustCall(&ast.CallExpr{ Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent( "Get" + pkgschema.TypenameFromTablename(tbl.TableName) + "By" + col.GoFieldName(), )}, Args: []ast.Expr{&ast.SelectorExpr{X: testObj2, Sel: ast.NewIdent(col.GoFieldName())}}, }), }, }}, }...) // decls = append(decls, modelgenerate.GenerateGetItemByUniqColFunc(table, table.GetColumnByName(index.Columns[0]))) hasIndexedGets = true } } if hasIndexedGets { stmts = append(stmts, indexGets...) } stmts = append(stmts, BlankLine(), Comment("Delete"), // TestDB.DeleteItem(item) &ast.ExprStmt{X: &ast.CallExpr{ Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Delete" + tbl.GoTypeName)}, Args: []ast.Expr{testObj}, }}, // _, err := TestDB.GetItemByID(item.ID) &ast.AssignStmt{ Lhs: []ast.Expr{ast.NewIdent("_"), ast.NewIdent("err")}, Tok: token.DEFINE, Rhs: []ast.Expr{&ast.CallExpr{ Fun: &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Get" + tbl.GoTypeName + "ByID")}, Args: []ast.Expr{&ast.SelectorExpr{X: testObj, Sel: ast.NewIdent("ID")}}, }}, }, // assert.ErrorIs(t, err, db.ErrNotInDB) &ast.ExprStmt{X: &ast.CallExpr{ Fun: &ast.SelectorExpr{X: ast.NewIdent("assert"), Sel: ast.NewIdent("ErrorIs")}, Args: []ast.Expr{ ast.NewIdent("t"), ast.NewIdent("err"), ast.NewIdent("ErrNotInDB"), }, }}, ) return stmts }(), }, } testGetAll := &ast.FuncDecl{ Name: ast.NewIdent("TestGetAll" + inflection.Plural(tbl.GoTypeName)), Type: testFuncType, Body: &ast.BlockStmt{ List: []ast.Stmt{ &ast.AssignStmt{ Lhs: []ast.Expr{ast.NewIdent("_")}, Tok: token.ASSIGN, Rhs: []ast.Expr{&ast.CallExpr{ Fun: &ast.SelectorExpr{ X: testDB, Sel: ast.NewIdent("GetAll" + inflection.Plural(tbl.GoTypeName)), }, }}, }, }, }, } shouldIncludeTestFkCheck := false testFkChecking := &ast.FuncDecl{ Name: ast.NewIdent("Test" + tbl.GoTypeName + "FkChecking"), Type: testFuncType, Body: &ast.BlockStmt{ List: func() []ast.Stmt { // post := MakePost() stmts := []ast.Stmt{ &ast.AssignStmt{ Lhs: []ast.Expr{ast.NewIdent(tbl.VarName)}, Tok: token.DEFINE, Rhs: []ast.Expr{ &ast.CallExpr{ Fun: ast.NewIdent("Make" + tbl.GoTypeName), }, }, }, } for _, col := range tbl.Columns { if col.IsForeignKey { shouldIncludeTestFkCheck = true stmts = append(stmts, []ast.Stmt{ // post.QuotedPostID = 94354538969386985 &ast.AssignStmt{ Lhs: []ast.Expr{ &ast.SelectorExpr{ X: ast.NewIdent(tbl.VarName), Sel: ast.NewIdent(col.GoFieldName()), }, }, Tok: token.ASSIGN, Rhs: []ast.Expr{ &ast.BasicLit{ Kind: token.INT, Value: "94354538969386985", }, }, }, // err := db.SavePost(&post) &ast.AssignStmt{ Lhs: []ast.Expr{ast.NewIdent("err")}, Tok: token.DEFINE, Rhs: []ast.Expr{ &ast.CallExpr{ Fun: &ast.SelectorExpr{ X: testDB, Sel: ast.NewIdent("Save" + tbl.GoTypeName), }, Args: []ast.Expr{ &ast.UnaryExpr{ Op: token.AND, X: ast.NewIdent(tbl.VarName), }, }, }, }, }, // assertForeignKeyError(t, err, "QuotedPostID", post.QuotedPostID) &ast.ExprStmt{ X: &ast.CallExpr{ Fun: ast.NewIdent("AssertForeignKeyError"), Args: []ast.Expr{ ast.NewIdent("t"), ast.NewIdent("err"), &ast.BasicLit{ Kind: token.STRING, Value: fmt.Sprintf("%q", col.GoFieldName()), }, &ast.SelectorExpr{ X: ast.NewIdent(tbl.VarName), Sel: ast.NewIdent(col.GoFieldName()), }, }, }, }, }...) } } return stmts }(), }, } testList := []ast.Decl{ makeItemFunc, testCreateUpdateDelete, testGetAll, } if shouldIncludeTestFkCheck { testList = append(testList, testFkChecking) } return &ast.File{ Name: ast.NewIdent(testpackageName), Decls: append([]ast.Decl{ &ast.GenDecl{ Tok: token.IMPORT, Specs: []ast.Spec{ &ast.ImportSpec{Path: &ast.BasicLit{Kind: token.STRING, Value: `"testing"`}}, &ast.ImportSpec{ Path: &ast.BasicLit{Kind: token.STRING, Value: `"git.offline-twitter.com/offline-labs/gas-stack/pkg/db"`}, Name: ast.NewIdent("."), }, &ast.ImportSpec{ Path: &ast.BasicLit{Kind: token.STRING, Value: `"git.offline-twitter.com/offline-labs/gas-stack/pkg/flowutils"`}, Name: ast.NewIdent("."), }, &ast.ImportSpec{ Path: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf(`"%s/pkg/%s"`, gomodName, packageName)}, Name: ast.NewIdent("."), }, &ast.ImportSpec{Path: &ast.BasicLit{Kind: token.STRING, Value: `"github.com/stretchr/testify/assert"`}}, &ast.ImportSpec{Path: &ast.BasicLit{Kind: token.STRING, Value: `"github.com/stretchr/testify/require"`}}, }, }, }, testList...), } }