diff --git a/pkg/codegen/modelgenerate/generate_testfile.go b/pkg/codegen/modelgenerate/generate_testfile.go index 268cbd8..dbc4e9d 100644 --- a/pkg/codegen/modelgenerate/generate_testfile.go +++ b/pkg/codegen/modelgenerate/generate_testfile.go @@ -15,71 +15,6 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File { packageName := "db" testpackageName := packageName + "_test" - testDBDecl := &ast.GenDecl{ - Tok: token.VAR, - Specs: []ast.Spec{ - &ast.ValueSpec{ - Names: []*ast.Ident{ast.NewIdent("TestDB")}, - Type: &ast.StarExpr{X: ast.NewIdent("DB")}, - }, - }, - } - - initFuncDecl := &ast.FuncDecl{ - Name: ast.NewIdent("init"), - Type: &ast.FuncType{Params: &ast.FieldList{}}, - Body: &ast.BlockStmt{ - List: []ast.Stmt{ - &ast.AssignStmt{ - Lhs: []ast.Expr{ast.NewIdent("TestDB")}, - Tok: token.ASSIGN, - Rhs: []ast.Expr{&ast.CallExpr{ - Fun: ast.NewIdent("MakeDB"), - Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: `"tmp"`}}, - }}, - }, - }, - }, - } - - makeDBHelperDecl := &ast.FuncDecl{ - Name: ast.NewIdent("MakeDB"), - Type: &ast.FuncType{ - Params: &ast.FieldList{ - List: []*ast.Field{{ - Names: []*ast.Ident{ast.NewIdent("dbName")}, - Type: ast.NewIdent("string"), - }}, - }, - Results: &ast.FieldList{ - List: []*ast.Field{{Type: &ast.StarExpr{X: ast.NewIdent("DB")}}}, - }, - }, - Body: &ast.BlockStmt{ - List: []ast.Stmt{ - // db := Must(Create(fmt.Sprintf("file:%s?mode=memory&cache=shared", dbName))) - &ast.AssignStmt{ - Lhs: []ast.Expr{ast.NewIdent("db")}, - Tok: token.DEFINE, - Rhs: []ast.Expr{mustCall(&ast.CallExpr{ - Fun: ast.NewIdent("Create"), - Args: []ast.Expr{&ast.CallExpr{ - Fun: &ast.SelectorExpr{X: ast.NewIdent("fmt"), Sel: ast.NewIdent("Sprintf")}, - Args: []ast.Expr{ - &ast.BasicLit{Kind: token.STRING, Value: `"file:%s?mode=memory&cache=shared"`}, - ast.NewIdent("dbName"), - }, - }}, - })}, - }, - // return db - &ast.ReturnStmt{ - Results: []ast.Expr{ast.NewIdent("db")}, - }, - }, - }, - } - // func MakeItem() Item { return Item{} } makeItemFunc := &ast.FuncDecl{ Name: ast.NewIdent("Make" + tbl.GoTypeName), @@ -364,16 +299,7 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File { } testList := []ast.Decl{ - // var TestDB *DB - testDBDecl, makeItemFunc, - - // func init() { TestDB = MakeDB("tmp") } - initFuncDecl, - - // func MakeDB(dbName string) *DB { db := Must(Create(fmt.Sprintf("file:%s?mode=memory&cache=shared", dbName))); return db } - makeDBHelperDecl, - testCreateUpdateDelete, testGetAll, } @@ -386,7 +312,6 @@ func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File { &ast.GenDecl{ Tok: token.IMPORT, Specs: []ast.Spec{ - &ast.ImportSpec{Path: &ast.BasicLit{Kind: token.STRING, Value: `"fmt"`}}, &ast.ImportSpec{Path: &ast.BasicLit{Kind: token.STRING, Value: `"testing"`}}, &ast.ImportSpec{ Path: &ast.BasicLit{Kind: token.STRING, Value: `"git.offline-twitter.com/offline-labs/gas-stack/pkg/db"`}, diff --git a/pkg/codegen/pkg.go b/pkg/codegen/pkg.go index 6ea7eba..b2c5c09 100644 --- a/pkg/codegen/pkg.go +++ b/pkg/codegen/pkg.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "os/exec" + "text/template" . "git.offline-twitter.com/offline-labs/gas-stack/pkg/flowutils" ) @@ -39,6 +40,11 @@ func InitPkg(opts PkgOpts) { PanicIf(os.WriteFile("pkg/db/schema.sql", Must(tpl.ReadFile("tpl/schema.sql")), 0o664)) PanicIf(os.WriteFile("pkg/db/db.go", Must(tpl.ReadFile("tpl/db.go.tpl")), 0o664)) + dbTest := Must(os.Create("pkg/db/db_test.go")) + defer MustClose(dbTest) + t := Must(template.ParseFS(tpl, "tpl/db_test.go.tpl")) + PanicIf(t.Execute(dbTest, opts)) + PanicIf(os.WriteFile("sample_data/mount.sh", Must(tpl.ReadFile("tpl/mount.sh")), 0o775)) PanicIf(os.WriteFile("sample_data/reset.sh", Must(tpl.ReadFile("tpl/reset.sh")), 0o775)) diff --git a/pkg/codegen/tpl/db_test.go.tpl b/pkg/codegen/tpl/db_test.go.tpl new file mode 100644 index 0000000..8f61eab --- /dev/null +++ b/pkg/codegen/tpl/db_test.go.tpl @@ -0,0 +1,19 @@ +package db_test + +import ( + "fmt" + + . "git.offline-twitter.com/offline-labs/gas-stack/pkg/flowutils" + + . "{{ .ModuleName }}/pkg/db" +) + +var TestDB *DB + +func init() { + TestDB = MakeDB("tmp") +} +func MakeDB(dbName string) *DB { + db := Must(Create(fmt.Sprintf("file:%s?mode=memory&cache=shared", dbName))) + return db +} diff --git a/pkg/flowutils/flowutils.go b/pkg/flowutils/flowutils.go index dea3396..600266b 100644 --- a/pkg/flowutils/flowutils.go +++ b/pkg/flowutils/flowutils.go @@ -1,5 +1,7 @@ package flowutils +import "io" + func PanicIf(err error) { if err != nil { panic(err) @@ -10,3 +12,7 @@ func Must[T any](val T, err error) T { PanicIf(err) return val } + +func MustClose(closer io.Closer) { + PanicIf(closer.Close()) +}