gas-stack/pkg/codegen/modelgenerate/generate_testfile.go
~wispem-wantex 8c29d455ff
All checks were successful
CI / build-docker (push) Successful in 12s
CI / build-docker-bootstrap (push) Has been skipped
CI / release-test (push) Successful in 41s
codegen: fix defining the 'err' variable multiple times in foreign key checking test
2026-03-19 15:43:31 -07:00

416 lines
12 KiB
Go

package modelgenerate
import (
"fmt"
"go/ast"
"go/token"
"github.com/jinzhu/inflection"
pkgschema "git.offline-twitter.com/offline-labs/gas-stack/pkg/schema"
"git.offline-twitter.com/offline-labs/gas-stack/pkg/textutils"
)
// GenerateModelTestAST produces an AST for a starter test file for a given model.
//
// The generated file is in package "db_test" and contains, in order:
//   - an import block;
//   - MakeXyz(), a helper returning a minimal instance of the model;
//   - TestCreateUpdateDeleteXyz, a save / load / update / indexed-lookup / delete round trip;
//   - TestGetAllXyzs, a smoke test of the GetAll query;
//   - TestXyzFkChecking, emitted only if tbl has at least one foreign-key column.
//
// schema is consulted for the single-column unique indexes on tbl; gomodName is the module path
// of the project being generated and is used to import that project's own "pkg/db" package.
func GenerateModelTestAST(tbl pkgschema.Table, schema pkgschema.Schema, gomodName string) *ast.File {
	packageName := "db"
	makeHelperName := ast.NewIdent("Make" + tbl.GoTypeName)

	// Every generated test function shares the signature `func(t *testing.T)`.
	testFuncType := &ast.FuncType{
		Params: &ast.FieldList{
			List: []*ast.Field{{
				Names: []*ast.Ident{ast.NewIdent("t")},
				Type:  &ast.StarExpr{X: &ast.SelectorExpr{X: ast.NewIdent("testing"), Sel: ast.NewIdent("T")}},
			}},
		},
	}

	decls := []ast.Decl{
		modelTestImports(gomodName, packageName),
		modelTestMakeHelper(tbl, makeHelperName),
		modelTestCreateUpdateDelete(tbl, schema, makeHelperName, testFuncType),
		modelTestGetAll(tbl, testFuncType),
	}
	// The FK test is only meaningful (and only emitted) when the table has foreign keys.
	if fkTest, ok := modelTestFkChecking(tbl, testFuncType); ok {
		decls = append(decls, fkTest)
	}
	return &ast.File{
		Name:  ast.NewIdent(packageName + "_test"),
		Decls: decls,
	}
}

// modelTestImports builds the import block of the generated test file. The two gas-stack packages
// and the project's own `<gomodName>/pkg/<packageName>` package are dot-imported.
func modelTestImports(gomodName, packageName string) *ast.GenDecl {
	// importSpec builds one import line; quotedPath must already be a quoted Go string literal.
	importSpec := func(alias, quotedPath string) *ast.ImportSpec {
		spec := &ast.ImportSpec{Path: &ast.BasicLit{Kind: token.STRING, Value: quotedPath}}
		if alias != "" {
			spec.Name = ast.NewIdent(alias)
		}
		return spec
	}
	return &ast.GenDecl{
		Tok: token.IMPORT,
		Specs: []ast.Spec{
			importSpec("", `"testing"`),
			importSpec(".", `"git.offline-twitter.com/offline-labs/gas-stack/pkg/db"`),
			importSpec(".", `"git.offline-twitter.com/offline-labs/gas-stack/pkg/flowutils"`),
			importSpec(".", fmt.Sprintf(`"%s/pkg/%s"`, gomodName, packageName)),
			importSpec("", `"github.com/stretchr/testify/assert"`),
			importSpec("", `"github.com/stretchr/testify/require"`),
		},
	}
}

// modelTestMakeHelper builds `func MakeItem() Item { return Item{Data: []byte{}, Description: ""} }`.
//
// TODO: the field names "Data" and "Description" are hard-coded; they are not derived from tbl.
func modelTestMakeHelper(tbl pkgschema.Table, name *ast.Ident) *ast.FuncDecl {
	return &ast.FuncDecl{
		Name: name,
		Type: &ast.FuncType{
			Params:  &ast.FieldList{},
			Results: &ast.FieldList{List: []*ast.Field{{Type: ast.NewIdent(tbl.GoTypeName)}}},
		},
		Body: &ast.BlockStmt{List: []ast.Stmt{
			&ast.ReturnStmt{Results: []ast.Expr{
				&ast.CompositeLit{
					Type: ast.NewIdent(tbl.GoTypeName),
					Elts: []ast.Expr{
						&ast.KeyValueExpr{
							Key: ast.NewIdent("Data"),
							Value: &ast.CompositeLit{
								Type: &ast.ArrayType{Elt: ast.NewIdent("byte")},
								Elts: []ast.Expr{},
							},
						},
						&ast.KeyValueExpr{
							Key:   ast.NewIdent("Description"),
							Value: &ast.BasicLit{Kind: token.STRING, Value: `""`},
						},
					},
				},
			}},
		}},
	}
}

// modelTestCreateUpdateDelete builds `func TestCreateUpdateDeleteItem(t *testing.T)`. The generated
// test saves a new object, reloads and compares it, updates and re-compares it, calls every
// single-column unique-index getter defined on tbl in schema, then deletes the object and asserts
// that loading it again returns ErrNotInDB.
func modelTestCreateUpdateDelete(
	tbl pkgschema.Table,
	schema pkgschema.Schema,
	makeHelperName *ast.Ident,
	testFuncType *ast.FuncType,
) *ast.FuncDecl {
	typeName := tbl.GoTypeName
	// NOTE(review): tbl.GoTypeName is used as the type name, so CamelToPascal presumably returns it
	// unchanged and the generated variable shadows the type (e.g. `Item := MakeItem()`);
	// tbl.VarName may be what was intended — confirm.
	testObj := ast.NewIdent(textutils.CamelToPascal(typeName))
	testObj2 := ast.NewIdent(textutils.CamelToPascal(typeName) + "2")
	const fieldName = "Description" // TODO: pick a real column of tbl instead of a fixed field
	// Both values are already quoted Go string literals and are emitted verbatim.
	const description1 = `"an item"`
	const description2 = `"a big item"`
	testDB := ast.NewIdent("TestDB")
	hasCreatedAt, hasUpdatedAt := tbl.HasAutoTimestamps()

	sel := func(x ast.Expr, name string) *ast.SelectorExpr {
		return &ast.SelectorExpr{X: x, Sel: ast.NewIdent(name)}
	}
	// check emits `pkg.fn(t, args...)`, e.g. `assert.Equal(t, a, b)`.
	check := func(pkg, fn string, args ...ast.Expr) ast.Stmt {
		return &ast.ExprStmt{X: &ast.CallExpr{
			Fun:  sel(ast.NewIdent(pkg), fn),
			Args: append([]ast.Expr{ast.NewIdent("t")}, args...),
		}}
	}
	// setField emits `item.Description = <quotedValue>`.
	setField := func(quotedValue string) ast.Stmt {
		return &ast.AssignStmt{
			Lhs: []ast.Expr{sel(testObj, fieldName)},
			Tok: token.ASSIGN,
			Rhs: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: quotedValue}},
		}
	}
	// save emits `TestDB.SaveItem(&item)`.
	save := func() ast.Stmt {
		return &ast.ExprStmt{X: &ast.CallExpr{
			Fun:  sel(testDB, "Save"+typeName),
			Args: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: testObj}},
		}}
	}
	// getByID emits `TestDB.GetItemByID(item.ID)`.
	getByID := func() *ast.CallExpr {
		return &ast.CallExpr{Fun: sel(testDB, "Get"+typeName+"ByID"), Args: []ast.Expr{sel(testObj, "ID")}}
	}
	// reload emits `item2 := Must(TestDB.GetItemByID(item.ID))` (or `=` when tok is ASSIGN).
	reload := func(tok token.Token) ast.Stmt {
		return &ast.AssignStmt{Lhs: []ast.Expr{testObj2}, Tok: tok, Rhs: []ast.Expr{mustCall(getByID())}}
	}
	// fieldsMatch emits `assert.Equal(t, item.Description, item2.Description)`.
	fieldsMatch := func() ast.Stmt {
		return check("assert", "Equal", sel(testObj, fieldName), sel(testObj2, fieldName))
	}

	stmts := []ast.Stmt{
		Comment("Create"),
		// item := MakeItem()
		&ast.AssignStmt{
			Lhs: []ast.Expr{testObj},
			Tok: token.DEFINE,
			Rhs: []ast.Expr{&ast.CallExpr{Fun: makeHelperName}},
		},
		setField(description1),
		save(),
		check("require", "NotZero", sel(testObj, "ID")),
	}
	// After create: the auto-managed timestamps must have been filled in.
	if hasCreatedAt {
		stmts = append(stmts, check("assert", "NotZero", sel(testObj, "CreatedAt")))
	}
	if hasUpdatedAt {
		stmts = append(stmts, check("assert", "NotZero", sel(testObj, "UpdatedAt")))
	}

	stmts = append(stmts,
		BlankLine(),
		Comment("Load"),
		reload(token.DEFINE),
		fieldsMatch(),

		BlankLine(),
		Comment("Update"),
		setField(description2),
		save(),
		reload(token.ASSIGN),
		fieldsMatch(),
	)

	// assert.Equal(t, item2, Must(TestDB.GetItemByXyz(item2.Xyz))) for each single-column unique
	// index on this table.
	var indexGets []ast.Stmt
	for _, index := range schema.Indexes {
		if index.TableName != tbl.TableName || !index.IsUnique || len(index.Columns) != 1 {
			continue
		}
		col := tbl.GetColumnByName(index.Columns[0])
		getter := "Get" + pkgschema.TypenameFromTablename(tbl.TableName) + "By" + col.GoFieldName()
		indexGets = append(indexGets, check("assert", "Equal",
			testObj2,
			mustCall(&ast.CallExpr{
				Fun:  sel(testDB, getter),
				Args: []ast.Expr{sel(testObj2, col.GoFieldName())},
			}),
		))
	}
	if len(indexGets) > 0 {
		stmts = append(stmts, BlankLine(), Comment("Indexed lookups"))
		stmts = append(stmts, indexGets...)
	}

	stmts = append(stmts,
		BlankLine(),
		Comment("Delete"),
		// TestDB.DeleteItem(item)
		&ast.ExprStmt{X: &ast.CallExpr{
			Fun:  sel(testDB, "Delete"+typeName),
			Args: []ast.Expr{testObj},
		}},
		// _, err := TestDB.GetItemByID(item.ID)
		&ast.AssignStmt{
			Lhs: []ast.Expr{ast.NewIdent("_"), ast.NewIdent("err")},
			Tok: token.DEFINE,
			Rhs: []ast.Expr{getByID()},
		},
		// assert.ErrorIs(t, err, ErrNotInDB)
		check("assert", "ErrorIs", ast.NewIdent("err"), ast.NewIdent("ErrNotInDB")),
	)

	return &ast.FuncDecl{
		Name: ast.NewIdent("TestCreateUpdateDelete" + typeName),
		Type: testFuncType,
		Body: &ast.BlockStmt{List: stmts},
	}
}

// modelTestGetAll builds `func TestGetAllItems(t *testing.T) { _ = TestDB.GetAllItems() }`.
func modelTestGetAll(tbl pkgschema.Table, testFuncType *ast.FuncType) *ast.FuncDecl {
	plural := inflection.Plural(tbl.GoTypeName)
	return &ast.FuncDecl{
		Name: ast.NewIdent("TestGetAll" + plural),
		Type: testFuncType,
		Body: &ast.BlockStmt{List: []ast.Stmt{
			&ast.AssignStmt{
				Lhs: []ast.Expr{ast.NewIdent("_")},
				Tok: token.ASSIGN,
				Rhs: []ast.Expr{&ast.CallExpr{
					Fun: &ast.SelectorExpr{X: ast.NewIdent("TestDB"), Sel: ast.NewIdent("GetAll" + plural)},
				}},
			},
		}},
	}
}

// modelTestFkChecking builds `func TestPostFkChecking(t *testing.T)`. For each foreign-key column,
// the generated test sets that column to an arbitrary large ID (presumably absent from the test DB),
// saves, and calls AssertForeignKeyError on the result. ok is false when tbl has no FK columns,
// in which case no test should be emitted.
func modelTestFkChecking(tbl pkgschema.Table, testFuncType *ast.FuncType) (decl *ast.FuncDecl, ok bool) {
	testDB := ast.NewIdent("TestDB")
	objIdent := func() *ast.Ident { return ast.NewIdent(tbl.VarName) }
	// newObj emits `MakePost()`.
	newObj := func() ast.Expr { return &ast.CallExpr{Fun: ast.NewIdent("Make" + tbl.GoTypeName)} }

	stmts := []ast.Stmt{
		// post := MakePost()
		&ast.AssignStmt{Lhs: []ast.Expr{objIdent()}, Tok: token.DEFINE, Rhs: []ast.Expr{newObj()}},
	}
	numChecked := 0
	for _, col := range tbl.Columns {
		if !col.IsForeignKey {
			continue
		}
		field := col.GoFieldName()

		// `err` may be declared only once per function: the first check uses `:=`, later ones `=`.
		errTok := token.DEFINE
		if numChecked > 0 {
			errTok = token.ASSIGN
			// post = MakePost()
			// Start every check from a fresh object so that a bogus ID left in an earlier FK column
			// cannot be what makes this save fail (which would mask the column under test).
			stmts = append(stmts, &ast.AssignStmt{
				Lhs: []ast.Expr{objIdent()},
				Tok: token.ASSIGN,
				Rhs: []ast.Expr{newObj()},
			})
		}

		stmts = append(stmts,
			// post.QuotedPostID = 94354538969386985
			&ast.AssignStmt{
				Lhs: []ast.Expr{&ast.SelectorExpr{X: objIdent(), Sel: ast.NewIdent(field)}},
				Tok: token.ASSIGN,
				Rhs: []ast.Expr{&ast.BasicLit{Kind: token.INT, Value: "94354538969386985"}},
			},
			// err := TestDB.SavePost(&post)
			&ast.AssignStmt{
				Lhs: []ast.Expr{ast.NewIdent("err")},
				Tok: errTok,
				Rhs: []ast.Expr{&ast.CallExpr{
					Fun:  &ast.SelectorExpr{X: testDB, Sel: ast.NewIdent("Save" + tbl.GoTypeName)},
					Args: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: objIdent()}},
				}},
			},
			// AssertForeignKeyError(t, err, "QuotedPostID", post.QuotedPostID)
			&ast.ExprStmt{X: &ast.CallExpr{
				Fun: ast.NewIdent("AssertForeignKeyError"),
				Args: []ast.Expr{
					ast.NewIdent("t"),
					ast.NewIdent("err"),
					&ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("%q", field)},
					&ast.SelectorExpr{X: objIdent(), Sel: ast.NewIdent(field)},
				},
			}},
		)
		numChecked++
	}
	if numChecked == 0 {
		return nil, false
	}
	return &ast.FuncDecl{
		Name: ast.NewIdent("Test" + tbl.GoTypeName + "FkChecking"),
		Type: testFuncType,
		Body: &ast.BlockStmt{List: stmts},
	}, true
}