diff --git a/cmd/main.go b/cmd/main.go index d35f667..d62dbe8 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -21,10 +21,9 @@ func main() { } root_cmd.AddCommand(sqlite_lint) root_cmd.AddCommand(cmd_init) + root_cmd.AddCommand(generate_model) if err := root_cmd.Execute(); err != nil { fmt.Println(RED + err.Error() + RESET) os.Exit(1) } } - -// Subcommand "generate_models" diff --git a/cmd/subcmd_generate_models.go b/cmd/subcmd_generate_models.go new file mode 100644 index 0000000..858fdbf --- /dev/null +++ b/cmd/subcmd_generate_models.go @@ -0,0 +1,81 @@ +package main + +import ( + "errors" + "fmt" + "go/ast" + "go/printer" + "go/token" + "os" + + "github.com/spf13/cobra" + + "git.offline-twitter.com/offline-labs/gas-stack/pkg/codegen/modelgenerate" + . "git.offline-twitter.com/offline-labs/gas-stack/pkg/flowutils" + "git.offline-twitter.com/offline-labs/gas-stack/pkg/schema" +) + +var ErrNoSuchTable = errors.New("no such table") + +var generate_model = &cobra.Command{ + Use: "generate ", + Short: "Generate a model type", + + Args: cobra.ExactArgs(1), + + RunE: func(cmd *cobra.Command, args []string) error { + path := Must(cmd.Flags().GetString("schema")) + modname := Must(cmd.Flags().GetString("modname")) + schema_sql, err := os.ReadFile(path) + if err != nil { + return fmt.Errorf("reading path %s: %w", path, err) + } + db := schema.InitDB(string(schema_sql)) + tables := schema.SchemaFromDB(db).Tables + table, isOk := tables[args[0]] + if !isOk { + return ErrNoSuchTable + } + + fset := token.NewFileSet() + + if Must(cmd.Flags().GetBool("test")) { + file2 := modelgenerate.GenerateModelTestAST(table, modname) + PanicIf(printer.Fprint(os.Stdout, fset, file2)) + } else { + file := &ast.File{ + Name: ast.NewIdent("db"), // TODO: parameterize + + Decls: []ast.Decl{ + &ast.GenDecl{ + Tok: token.IMPORT, + Specs: []ast.Spec{ + &ast.ImportSpec{Path: &ast.BasicLit{Kind: token.STRING, Value: `"database/sql"`}}, + &ast.ImportSpec{Path: &ast.BasicLit{Kind: token.STRING, Value: `"errors"`}}, + &ast.ImportSpec{Path: &ast.BasicLit{Kind: token.STRING, Value: `"fmt"`}}, + &ast.ImportSpec{ + Name: ast.NewIdent("."), + Path: &ast.BasicLit{Kind: token.STRING, Value: `"git.offline-twitter.com/offline-labs/gas-stack/pkg/db"`}, + }, + }, + }, + modelgenerate.GenerateIDType(table), + modelgenerate.GenerateModelAST(table), + modelgenerate.GenerateSaveItemFunc(table), + modelgenerate.GenerateGetItemByIDFunc(table), + modelgenerate.GenerateDeleteItemFunc(table), + }, + } + + PanicIf(printer.Fprint(os.Stdout, fset, file)) + } + + return nil + }, +} + +func init() { + generate_model.Flags().String("schema", "pkg/db/schema.sql", "Path to SQL schema file") + generate_model.Flags().String("modname", "mymodule", "Name of project's Go module (TODO: detect automatically)") + generate_model.Flags().Bool("test", false, "Generate test file instead of regular file") +} diff --git a/doc/TODO.txt b/doc/TODO.txt index 5789c9d..428a500 100644 --- a/doc/TODO.txt +++ b/doc/TODO.txt @@ -14,3 +14,5 @@ TODO: foreign-key TODO: modified-timestamps - set updated_at and created_at in SaveXYZ - soft delete option + +TODO: the `db_meta` table doesn't pass sqlite_lint diff --git a/ops/gas_init_test.sh b/ops/gas_init_test.sh index 5ea4328..fb36d87 100755 --- a/ops/gas_init_test.sh +++ b/ops/gas_init_test.sh @@ -26,6 +26,23 @@ mydb.db prog EOF +cd $test_project + +# Create a new table in the schema +cat >> pkg/db/schema.sql < pkg/db/item.go +$gas generate items --test > pkg/db/item_test.go +go mod tidy + +# Run the tests +go test ./... # Notify success in green echo -e "\033[32mAll tests passed. Finished successfully.\033[0m" diff --git a/pkg/codegen/modelgenerate/generate_model.go b/pkg/codegen/modelgenerate/generate_model.go new file mode 100644 index 0000000..fdde9b9 --- /dev/null +++ b/pkg/codegen/modelgenerate/generate_model.go @@ -0,0 +1,243 @@ +//nolint:lll // This file has lots of long lines lol +package modelgenerate + +import ( + "fmt" + "go/ast" + "go/token" + "strings" + + "github.com/jinzhu/inflection" + + "git.offline-twitter.com/offline-labs/gas-stack/pkg/schema" + "git.offline-twitter.com/offline-labs/gas-stack/pkg/textutils" +) + +func GenerateIDType(table schema.Table) *ast.GenDecl { + // e.g., `type FoodID uint64` + return &ast.GenDecl{ + Tok: token.TYPE, + Specs: []ast.Spec{&ast.TypeSpec{Name: ast.NewIdent(table.TypeIDName), Type: ast.NewIdent("uint64")}}, + } +} + +// GenerateModelAST produces an AST for a struct type corresponding to the model. +// TODO: generate the right field types here based on column types. +func GenerateModelAST(table schema.Table) *ast.GenDecl { + // Fields for the struct + fields := []*ast.Field{} + + // Other fields (just strings for now) + for _, col := range table.Columns { + switch col.Name { + case "rowid": + fields = append(fields, &ast.Field{ + Names: []*ast.Ident{ast.NewIdent("ID")}, + Type: ast.NewIdent(table.TypeIDName), + Tag: &ast.BasicLit{Kind: token.STRING, Value: "`db:\"rowid\" json:\"id\"`"}, + }) + default: + if col.IsForeignKey && strings.HasSuffix(col.Name, "_id") { + fields = append(fields, &ast.Field{ + Names: []*ast.Ident{ast.NewIdent(textutils.SnakeToCamel(strings.TrimSuffix(col.Name, "_id")) + "ID")}, + Type: ast.NewIdent(textutils.SnakeToCamel(inflection.Singular(col.ForeignKeyTargetTable) + "ID")), + Tag: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("`db:\"%s\" json:\"%s\"`", col.Name, col.Name)}, + }) + } else { + typeName := "string" + switch col.Type { + case "integer", "int": + if strings.HasPrefix(col.Name, "is_") || strings.HasPrefix(col.Name, "has_") { + typeName = "bool" + } else if strings.HasSuffix(col.Name, "_at") { + typeName = "Timestamp" + } else { + typeName = "int64" + } + case "text": + typeName = "string" + case "real": + typeName = "float32" + case "blob": + typeName = "[]byte" + default: + panic("Unrecognized sqlite column type: " + col.Type) + } + + fields = append(fields, &ast.Field{ + Names: []*ast.Ident{ast.NewIdent(textutils.SnakeToCamel(col.Name))}, + Type: ast.NewIdent(typeName), + Tag: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("`db:\"%s\" json:\"%s\"`", col.Name, col.Name)}, + }) + } + } + } + + return &ast.GenDecl{ + Tok: token.TYPE, + Specs: []ast.Spec{&ast.TypeSpec{ + Name: ast.NewIdent(table.TypeName), + Type: &ast.StructType{Fields: &ast.FieldList{List: fields}}, + }}, + } +} + +// GenerateSaveItemFunc produces an AST for the SaveXyz() function of the model. +// E.g., a table with `table.TypeName = "foods"` will produce a "SaveFood()" function. +func GenerateSaveItemFunc(tbl schema.Table) *ast.FuncDecl { + insertCols := make([]string, 0, len(tbl.Columns)) + insertVals := make([]string, 0, len(tbl.Columns)) + updatePairs := make([]string, 0, len(tbl.Columns)) + + for _, col := range tbl.Columns { + if col.Name == "rowid" { + continue + } + insertCols = append(insertCols, col.Name) + val := ":" + col.Name + if col.IsNullableForeignKey() { + val = fmt.Sprintf("nullif(%s, 0)", val) + } + insertVals = append(insertVals, val) + updatePairs = append(updatePairs, col.Name+"="+val) + } + + insertStmt := fmt.Sprintf("\n\t\t insert into %s (%s)\n\t\t values (%s)\n\t\t", tbl.TableName, strings.Join(insertCols, ", "), strings.Join(insertVals, ", ")) + updateStmt := fmt.Sprintf("\n\t\t update %s\n\t\t set %s\n\t\t where rowid = :rowid\n\t\t", tbl.TableName, strings.Join(updatePairs, ",\n\t\t ")) + + funcBody := &ast.BlockStmt{ + List: []ast.Stmt{ + &ast.IfStmt{ + Cond: &ast.BinaryExpr{ + X: &ast.SelectorExpr{X: ast.NewIdent(tbl.VarName), Sel: ast.NewIdent("ID")}, + Op: token.EQL, + Y: &ast.BasicLit{Kind: token.INT, Value: "0"}, + }, + Body: &ast.BlockStmt{ + List: []ast.Stmt{ + &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent("result"), ast.NewIdent("err")}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent("db.DB"), Sel: ast.NewIdent("NamedExec")}, Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: "`" + insertStmt + "`"}, ast.NewIdent(tbl.VarName)}}}, + }, + &ast.IfStmt{Cond: &ast.BinaryExpr{X: ast.NewIdent("err"), Op: token.NEQ, Y: ast.NewIdent("nil")}, Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ExprStmt{X: &ast.CallExpr{Fun: ast.NewIdent("panic"), Args: []ast.Expr{ast.NewIdent("err")}}}}}}, + &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent("id"), ast.NewIdent("err")}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent("result"), Sel: ast.NewIdent("LastInsertId")}, Args: []ast.Expr{}}}, + }, + &ast.IfStmt{Cond: &ast.BinaryExpr{X: ast.NewIdent("err"), Op: token.NEQ, Y: ast.NewIdent("nil")}, Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ExprStmt{X: &ast.CallExpr{Fun: ast.NewIdent("panic"), Args: []ast.Expr{ast.NewIdent("err")}}}}}}, + &ast.AssignStmt{Lhs: []ast.Expr{&ast.SelectorExpr{X: ast.NewIdent(tbl.VarName), Sel: ast.NewIdent("ID")}}, Tok: token.ASSIGN, Rhs: []ast.Expr{&ast.CallExpr{Fun: ast.NewIdent(tbl.TypeIDName), Args: []ast.Expr{ast.NewIdent("id")}}}}, + }, + }, + Else: &ast.BlockStmt{ + List: []ast.Stmt{ + &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent("result"), ast.NewIdent("err")}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent("db.DB"), Sel: ast.NewIdent("NamedExec")}, Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: "`" + updateStmt + "`"}, ast.NewIdent(tbl.VarName)}}}, + }, + &ast.IfStmt{Cond: &ast.BinaryExpr{X: ast.NewIdent("err"), Op: token.NEQ, Y: ast.NewIdent("nil")}, Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ExprStmt{X: &ast.CallExpr{Fun: ast.NewIdent("panic"), Args: []ast.Expr{ast.NewIdent("err")}}}}}}, + &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent("count"), ast.NewIdent("err")}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent("result"), Sel: ast.NewIdent("RowsAffected")}, Args: []ast.Expr{}}}, + }, + &ast.IfStmt{Cond: &ast.BinaryExpr{X: ast.NewIdent("err"), Op: token.NEQ, Y: ast.NewIdent("nil")}, Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ExprStmt{X: &ast.CallExpr{Fun: ast.NewIdent("panic"), Args: []ast.Expr{ast.NewIdent("err")}}}}}}, + &ast.IfStmt{ + Cond: &ast.BinaryExpr{X: ast.NewIdent("count"), Op: token.NEQ, Y: &ast.BasicLit{Kind: token.INT, Value: "1"}}, + Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ExprStmt{X: &ast.CallExpr{Fun: ast.NewIdent("panic"), Args: []ast.Expr{&ast.CallExpr{Fun: ast.NewIdent("fmt.Errorf"), Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"got %s with ID (%%d), so attempted update, but it doesn't exist\"", strings.ToLower(tbl.TypeName))}, &ast.SelectorExpr{X: ast.NewIdent(tbl.VarName), Sel: ast.NewIdent("ID")}}}}}}}}, + }, + }, + }, + }, + }, + } + + funcDecl := &ast.FuncDecl{ + Recv: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent("db")}, Type: ast.NewIdent("DB")}}}, + Name: ast.NewIdent("Save" + tbl.TypeName), + Type: &ast.FuncType{ + Params: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent(tbl.VarName)}, Type: &ast.StarExpr{X: ast.NewIdent(tbl.TypeName)}}}}, + Results: nil, + }, + Body: funcBody, + } + return funcDecl +} + +// GenerateGetItemByIDFunc produces an AST for the `GetXyzByID()` function. +// E.g., a table with `table.TypeName = "foods"` will produce a "GetFoodByID()" function. +func GenerateGetItemByIDFunc(tbl schema.Table) *ast.FuncDecl { + funcName := "Get" + tbl.TypeName + "ByID" + + recv := &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent("db")}, Type: ast.NewIdent("DB")}}} + arg := &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent("id")}, Type: ast.NewIdent(tbl.TypeIDName)}}} + result := &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent("ret")}, Type: ast.NewIdent(tbl.TypeName)}, {Names: []*ast.Ident{ast.NewIdent("err")}, Type: ast.NewIdent("error")}}} + + selectCols := make([]string, 0, len(tbl.Columns)) + for _, col := range tbl.Columns { + selectCols = append(selectCols, col.Name) + } + selectStmt := fmt.Sprintf("\n\t select %s\n\t from %s\n\t where rowid = ?\n\t", strings.Join(selectCols, ", "), tbl.TableName) + + funcBody := &ast.BlockStmt{ + List: []ast.Stmt{ + &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent("err")}, + Tok: token.ASSIGN, + Rhs: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent("db.DB"), Sel: ast.NewIdent("Get")}, Args: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: ast.NewIdent("ret")}, &ast.BasicLit{Kind: token.STRING, Value: "`" + selectStmt + "`"}, ast.NewIdent("id")}}}, + }, + &ast.IfStmt{ + Cond: &ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent("errors"), Sel: ast.NewIdent("Is")}, Args: []ast.Expr{ast.NewIdent("err"), &ast.SelectorExpr{X: ast.NewIdent("sql"), Sel: ast.NewIdent("ErrNoRows")}}}, + Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ReturnStmt{Results: []ast.Expr{&ast.CompositeLit{Type: ast.NewIdent(tbl.TypeName)}, ast.NewIdent("ErrNotInDB")}}}}, + }, + &ast.ReturnStmt{}, + }, + } + + funcDecl := &ast.FuncDecl{ + Recv: recv, + Name: ast.NewIdent(funcName), + Type: &ast.FuncType{Params: arg, Results: result}, + Body: funcBody, + } + return funcDecl +} + +// GenerateDeleteItemFunc produces an AST for the `DeleteXyz()` function. +// E.g., a table with `table.TypeName = "foods"` will produce a "DeleteFood()" function. +func GenerateDeleteItemFunc(tbl schema.Table) *ast.FuncDecl { + funcName := "Delete" + tbl.TypeName + recv := &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent("db")}, Type: ast.NewIdent("DB")}}} + arg := &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent(tbl.VarName)}, Type: ast.NewIdent(tbl.TypeName)}}} + + funcBody := &ast.BlockStmt{ + List: []ast.Stmt{ + &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent("result"), ast.NewIdent("err")}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent("db.DB"), Sel: ast.NewIdent("Exec")}, Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: "`delete from " + tbl.TableName + " where rowid = ?`"}, &ast.SelectorExpr{X: ast.NewIdent(tbl.VarName), Sel: ast.NewIdent("ID")}}}}, + }, + &ast.IfStmt{Cond: &ast.BinaryExpr{X: ast.NewIdent("err"), Op: token.NEQ, Y: ast.NewIdent("nil")}, Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ExprStmt{X: &ast.CallExpr{Fun: ast.NewIdent("panic"), Args: []ast.Expr{ast.NewIdent("err")}}}}}}, + &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent("count"), ast.NewIdent("err")}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent("result"), Sel: ast.NewIdent("RowsAffected")}, Args: []ast.Expr{}}}, + }, + &ast.IfStmt{Cond: &ast.BinaryExpr{X: ast.NewIdent("err"), Op: token.NEQ, Y: ast.NewIdent("nil")}, Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ExprStmt{X: &ast.CallExpr{Fun: ast.NewIdent("panic"), Args: []ast.Expr{ast.NewIdent("err")}}}}}}, + &ast.IfStmt{ + Cond: &ast.BinaryExpr{X: ast.NewIdent("count"), Op: token.NEQ, Y: &ast.BasicLit{Kind: token.INT, Value: "1"}}, + Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ExprStmt{X: &ast.CallExpr{Fun: ast.NewIdent("panic"), Args: []ast.Expr{&ast.CallExpr{Fun: ast.NewIdent("fmt.Errorf"), Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"tried to delete %s with ID (%%d) but it doesn't exist\"", strings.ToLower(tbl.TypeName))}, &ast.SelectorExpr{X: ast.NewIdent(tbl.VarName), Sel: ast.NewIdent("ID")}}}}}}}}, + }, + }, + } + + funcDecl := &ast.FuncDecl{ + Recv: recv, + Name: ast.NewIdent(funcName), + Type: &ast.FuncType{Params: arg, Results: nil}, + Body: funcBody, + } + return funcDecl +} diff --git a/pkg/codegen/modelgenerate/generate_testfile.go b/pkg/codegen/modelgenerate/generate_testfile.go new file mode 100644 index 0000000..1a86c59 --- /dev/null +++ b/pkg/codegen/modelgenerate/generate_testfile.go @@ -0,0 +1,251 @@ +package modelgenerate + +import ( + "fmt" + "go/ast" + "go/token" + + "git.offline-twitter.com/offline-labs/gas-stack/pkg/schema" +) + +// GenerateModelTestAST produces an AST for a starter test file for a given model. +func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File { + packageName := "db" + testpackageName := packageName + "_test" + + testDBDecl := &ast.GenDecl{ + Tok: token.VAR, + Specs: []ast.Spec{ + &ast.ValueSpec{ + Names: []*ast.Ident{ast.NewIdent("TestDB")}, + Type: &ast.StarExpr{X: ast.NewIdent("DB")}, + }, + }, + } + + initFuncDecl := &ast.FuncDecl{ + Name: ast.NewIdent("init"), + Type: &ast.FuncType{Params: &ast.FieldList{}}, + Body: &ast.BlockStmt{ + List: []ast.Stmt{ + &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent("TestDB")}, + Tok: token.ASSIGN, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: ast.NewIdent("MakeDB"), + Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: `"tmp"`}}, + }}, + }, + }, + }, + } + + makeDBHelperDecl := &ast.FuncDecl{ + Name: ast.NewIdent("MakeDB"), + Type: &ast.FuncType{ + Params: &ast.FieldList{ + List: []*ast.Field{{ + Names: []*ast.Ident{ast.NewIdent("dbName")}, + Type: ast.NewIdent("string"), + }}, + }, + Results: &ast.FieldList{ + List: []*ast.Field{{Type: &ast.StarExpr{X: ast.NewIdent("DB")}}}, + }, + }, + Body: &ast.BlockStmt{ + List: []ast.Stmt{ + // db := Must(Create(fmt.Sprintf("file:%s?mode=memory&cache=shared", dbName))) + &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent("db")}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: ast.NewIdent("Must"), + Args: []ast.Expr{&ast.CallExpr{ + Fun: ast.NewIdent("Create"), + Args: []ast.Expr{&ast.CallExpr{ + Fun: ast.NewIdent("fmt.Sprintf"), + Args: []ast.Expr{ + &ast.BasicLit{Kind: token.STRING, Value: `"file:%s?mode=memory&cache=shared"`}, + ast.NewIdent("dbName"), + }, + }}, + }}, + }}, + }, + // return db + &ast.ReturnStmt{ + Results: []ast.Expr{ast.NewIdent("db")}, + }, + }, + }, + } + + testObj := ast.NewIdent("item") + testObj2 := ast.NewIdent("item2") + fieldName := ast.NewIdent("Description") + description1 := `"an item"` + description2 := `"a big item"` + + return &ast.File{ + Name: ast.NewIdent(testpackageName), + Decls: []ast.Decl{ + &ast.GenDecl{ + Tok: token.IMPORT, + Specs: []ast.Spec{ + &ast.ImportSpec{Path: &ast.BasicLit{Kind: token.STRING, Value: `"fmt"`}}, + &ast.ImportSpec{Path: &ast.BasicLit{Kind: token.STRING, Value: `"testing"`}}, + &ast.ImportSpec{ + Path: &ast.BasicLit{Kind: token.STRING, Value: `"git.offline-twitter.com/offline-labs/gas-stack/pkg/db"`}, + Name: ast.NewIdent("db"), + }, + &ast.ImportSpec{ + Path: &ast.BasicLit{Kind: token.STRING, Value: `"git.offline-twitter.com/offline-labs/gas-stack/pkg/flowutils"`}, + Name: ast.NewIdent("."), + }, + &ast.ImportSpec{ + Path: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf(`"%s/pkg/%s"`, gomodName, packageName)}, + Name: ast.NewIdent("."), + }, + &ast.ImportSpec{Path: &ast.BasicLit{Kind: token.STRING, Value: `"github.com/stretchr/testify/assert"`}}, + &ast.ImportSpec{Path: &ast.BasicLit{Kind: token.STRING, Value: `"github.com/stretchr/testify/require"`}}, + }, + }, + // var TestDB *DB + testDBDecl, + + // func init() { TestDB = MakeDB("tmp") } + initFuncDecl, + + // func MakeDB(dbName string) *DB { db := Must(Create(fmt.Sprintf("file:%s?mode=memory&cache=shared", dbName))); return db } + makeDBHelperDecl, + + &ast.FuncDecl{ + Name: ast.NewIdent("TestCreateUpdateDelete" + tbl.TypeName), + Type: &ast.FuncType{ + Params: &ast.FieldList{ + List: []*ast.Field{{ + Names: []*ast.Ident{ast.NewIdent("t")}, + Type: ast.NewIdent("*testing.T"), + }}, + }, + }, + Body: &ast.BlockStmt{ + List: []ast.Stmt{ + // item := Item{Description: "an item"} + &ast.AssignStmt{ + Lhs: []ast.Expr{testObj}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CompositeLit{ + Type: ast.NewIdent(tbl.TypeName), + Elts: []ast.Expr{ + &ast.KeyValueExpr{ + Key: fieldName, + Value: &ast.BasicLit{Kind: token.STRING, Value: description1}, + }, + }, + }}, + }, + + // TestDB.SaveItem(&item) + &ast.ExprStmt{X: &ast.CallExpr{ + Fun: ast.NewIdent("TestDB.Save" + tbl.TypeName), + Args: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: testObj}}, + }}, + + // require.NotZero(t, item.ID) + &ast.ExprStmt{X: &ast.CallExpr{ + Fun: ast.NewIdent("require.NotZero"), + Args: []ast.Expr{ast.NewIdent("t"), &ast.SelectorExpr{X: testObj, Sel: ast.NewIdent("ID")}}, + }}, + + // item2 := Must(TestDB.GetItemByID(item.ID)) + &ast.AssignStmt{ + Lhs: []ast.Expr{testObj2}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: ast.NewIdent("Must"), + Args: []ast.Expr{&ast.CallExpr{ + Fun: ast.NewIdent("TestDB.Get" + tbl.TypeName + "ByID"), + Args: []ast.Expr{&ast.SelectorExpr{X: testObj, Sel: ast.NewIdent("ID")}}, + }}, + }}, + }, + + // assert.Equal(t, "an item", item2.Description) + &ast.ExprStmt{X: &ast.CallExpr{ + Fun: ast.NewIdent("assert.Equal"), + Args: []ast.Expr{ + ast.NewIdent("t"), + &ast.BasicLit{Kind: token.STRING, Value: description1}, + &ast.SelectorExpr{X: testObj2, Sel: fieldName}, + }, + }}, + + // item.Description = "a big item" + &ast.AssignStmt{ + Lhs: []ast.Expr{&ast.SelectorExpr{X: testObj, Sel: fieldName}}, + Tok: token.ASSIGN, + Rhs: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: description2}}, + }, + + // TestDB.SaveItem(&item) + &ast.ExprStmt{X: &ast.CallExpr{ + Fun: ast.NewIdent("TestDB.Save" + tbl.TypeName), + Args: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: testObj}}, + }}, + + // item2 = Must(TestDB.GetItemByID(item.ID)) + &ast.AssignStmt{ + Lhs: []ast.Expr{testObj2}, + Tok: token.ASSIGN, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: ast.NewIdent("Must"), + Args: []ast.Expr{&ast.CallExpr{ + Fun: ast.NewIdent("TestDB.Get" + tbl.TypeName + "ByID"), + Args: []ast.Expr{&ast.SelectorExpr{X: testObj, Sel: ast.NewIdent("ID")}}, + }}, + }}, + }, + + // assert.Equal(t, item.Description, item2.Description) + &ast.ExprStmt{X: &ast.CallExpr{ + Fun: ast.NewIdent("assert.Equal"), + Args: []ast.Expr{ + ast.NewIdent("t"), + &ast.SelectorExpr{X: testObj, Sel: fieldName}, + &ast.SelectorExpr{X: testObj2, Sel: fieldName}, + }, + }}, + + // TestDB.DeleteItem(item) + &ast.ExprStmt{X: &ast.CallExpr{ + Fun: ast.NewIdent("TestDB.Delete" + tbl.TypeName), + Args: []ast.Expr{testObj}, + }}, + + // _, err := TestDB.GetItemByID(item.ID) + &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent("_"), ast.NewIdent("err")}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: ast.NewIdent("TestDB.Get" + tbl.TypeName + "ByID"), + Args: []ast.Expr{&ast.SelectorExpr{X: testObj, Sel: ast.NewIdent("ID")}}, + }}, + }, + + // assert.ErrorIs(t, err, db.ErrNotInDB) + &ast.ExprStmt{X: &ast.CallExpr{ + Fun: ast.NewIdent("assert.ErrorIs"), + Args: []ast.Expr{ + ast.NewIdent("t"), + ast.NewIdent("err"), + &ast.SelectorExpr{X: ast.NewIdent("db"), Sel: ast.NewIdent("ErrNotInDB")}, + }, + }}, + }, + }, + }, + }, + } +} diff --git a/pkg/codegen/pkg.go b/pkg/codegen/pkg.go index c36347b..46bc007 100644 --- a/pkg/codegen/pkg.go +++ b/pkg/codegen/pkg.go @@ -37,10 +37,10 @@ func InitPkg(opts PkgOpts) { PanicIf(os.MkdirAll("sample_data", 0o755)) PanicIf(os.WriteFile("pkg/db/schema.sql", Must(tpl.ReadFile("tpl/schema.sql")), 0o664)) + PanicIf(os.WriteFile("pkg/db/db.go", Must(tpl.ReadFile("tpl/db.go")), 0o664)) PanicIf(os.WriteFile("sample_data/mount.sh", Must(tpl.ReadFile("tpl/mount.sh")), 0o775)) PanicIf(os.WriteFile("sample_data/reset.sh", Must(tpl.ReadFile("tpl/reset.sh")), 0o775)) - PanicIf(os.WriteFile("pkg/db/schema.sql", Must(tpl.ReadFile("tpl/schema.sql")), 0o664)) // TODO: // - create `pkg/db/errors.go` diff --git a/pkg/codegen/tpl/db.go.tpl b/pkg/codegen/tpl/db.go.tpl new file mode 100644 index 0000000..1619d12 --- /dev/null +++ b/pkg/codegen/tpl/db.go.tpl @@ -0,0 +1,38 @@ +package db + +import ( + _ "embed" + "fmt" + + "github.com/jmoiron/sqlx" + + "git.offline-twitter.com/offline-labs/gas-stack/pkg/db" +) + +//go:embed schema.sql +var sql_schema string + +// Database starts at version 0. First migration brings us to version 1 +var MIGRATIONS = []string{} + +type DB struct { + DB *sqlx.DB +} + +func Create(path string) (*DB, error) { + conf := db.Init(&sql_schema, &MIGRATIONS) + dbHandle, err := conf.Create(path) + if err != nil { + return nil, fmt.Errorf("creating db: %w", err) + } + return &DB{dbHandle}, nil +} + +func Connect(path string) (*DB, error) { + conf := db.Init(&sql_schema, &MIGRATIONS) + dbHandle, err := conf.Connect(path) + if err != nil { + return nil, fmt.Errorf("creating db: %w", err) + } + return &DB{dbHandle}, nil +}