Add model generator based on AST
This commit is contained in:
parent
2146e1de77
commit
208c1eb26c
@ -21,10 +21,9 @@ func main() {
|
||||
}
|
||||
root_cmd.AddCommand(sqlite_lint)
|
||||
root_cmd.AddCommand(cmd_init)
|
||||
root_cmd.AddCommand(generate_model)
|
||||
if err := root_cmd.Execute(); err != nil {
|
||||
fmt.Println(RED + err.Error() + RESET)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// Subcommand "generate_models"
|
||||
|
||||
81
cmd/subcmd_generate_models.go
Normal file
81
cmd/subcmd_generate_models.go
Normal file
@ -0,0 +1,81 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/printer"
|
||||
"go/token"
|
||||
"os"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"git.offline-twitter.com/offline-labs/gas-stack/pkg/codegen/modelgenerate"
|
||||
. "git.offline-twitter.com/offline-labs/gas-stack/pkg/flowutils"
|
||||
"git.offline-twitter.com/offline-labs/gas-stack/pkg/schema"
|
||||
)
|
||||
|
||||
var ErrNoSuchTable = errors.New("no such table")
|
||||
|
||||
var generate_model = &cobra.Command{
|
||||
Use: "generate <model_name>",
|
||||
Short: "Generate a model type",
|
||||
|
||||
Args: cobra.ExactArgs(1),
|
||||
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
path := Must(cmd.Flags().GetString("schema"))
|
||||
modname := Must(cmd.Flags().GetString("modname"))
|
||||
schema_sql, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading path %s: %w", path, err)
|
||||
}
|
||||
db := schema.InitDB(string(schema_sql))
|
||||
tables := schema.SchemaFromDB(db).Tables
|
||||
table, isOk := tables[args[0]]
|
||||
if !isOk {
|
||||
return ErrNoSuchTable
|
||||
}
|
||||
|
||||
fset := token.NewFileSet()
|
||||
|
||||
if Must(cmd.Flags().GetBool("test")) {
|
||||
file2 := modelgenerate.GenerateModelTestAST(table, modname)
|
||||
PanicIf(printer.Fprint(os.Stdout, fset, file2))
|
||||
} else {
|
||||
file := &ast.File{
|
||||
Name: ast.NewIdent("db"), // TODO: parameterize
|
||||
|
||||
Decls: []ast.Decl{
|
||||
&ast.GenDecl{
|
||||
Tok: token.IMPORT,
|
||||
Specs: []ast.Spec{
|
||||
&ast.ImportSpec{Path: &ast.BasicLit{Kind: token.STRING, Value: `"database/sql"`}},
|
||||
&ast.ImportSpec{Path: &ast.BasicLit{Kind: token.STRING, Value: `"errors"`}},
|
||||
&ast.ImportSpec{Path: &ast.BasicLit{Kind: token.STRING, Value: `"fmt"`}},
|
||||
&ast.ImportSpec{
|
||||
Name: ast.NewIdent("."),
|
||||
Path: &ast.BasicLit{Kind: token.STRING, Value: `"git.offline-twitter.com/offline-labs/gas-stack/pkg/db"`},
|
||||
},
|
||||
},
|
||||
},
|
||||
modelgenerate.GenerateIDType(table),
|
||||
modelgenerate.GenerateModelAST(table),
|
||||
modelgenerate.GenerateSaveItemFunc(table),
|
||||
modelgenerate.GenerateGetItemByIDFunc(table),
|
||||
modelgenerate.GenerateDeleteItemFunc(table),
|
||||
},
|
||||
}
|
||||
|
||||
PanicIf(printer.Fprint(os.Stdout, fset, file))
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
generate_model.Flags().String("schema", "pkg/db/schema.sql", "Path to SQL schema file")
|
||||
generate_model.Flags().String("modname", "mymodule", "Name of project's Go module (TODO: detect automatically)")
|
||||
generate_model.Flags().Bool("test", false, "Generate test file instead of regular file")
|
||||
}
|
||||
@ -14,3 +14,5 @@ TODO: foreign-key
|
||||
TODO: modified-timestamps
|
||||
- set updated_at and created_at in SaveXYZ
|
||||
- soft delete option
|
||||
|
||||
TODO: the `db_meta` table doesn't pass sqlite_lint
|
||||
|
||||
@ -26,6 +26,23 @@ mydb.db
|
||||
prog
|
||||
EOF
|
||||
|
||||
cd $test_project
|
||||
|
||||
# Create a new table in the schema
|
||||
cat >> pkg/db/schema.sql <<EOF
|
||||
create table items (
|
||||
rowid integer primary key,
|
||||
description text not null default ''
|
||||
);
|
||||
EOF
|
||||
|
||||
# Generate an item model and test file
|
||||
$gas generate items > pkg/db/item.go
|
||||
$gas generate items --test > pkg/db/item_test.go
|
||||
go mod tidy
|
||||
|
||||
# Run the tests
|
||||
go test ./...
|
||||
|
||||
# Notify success in green
|
||||
echo -e "\033[32mAll tests passed. Finished successfully.\033[0m"
|
||||
|
||||
243
pkg/codegen/modelgenerate/generate_model.go
Normal file
243
pkg/codegen/modelgenerate/generate_model.go
Normal file
@ -0,0 +1,243 @@
|
||||
//nolint:lll // This file has lots of long lines lol
|
||||
package modelgenerate
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/token"
|
||||
"strings"
|
||||
|
||||
"github.com/jinzhu/inflection"
|
||||
|
||||
"git.offline-twitter.com/offline-labs/gas-stack/pkg/schema"
|
||||
"git.offline-twitter.com/offline-labs/gas-stack/pkg/textutils"
|
||||
)
|
||||
|
||||
func GenerateIDType(table schema.Table) *ast.GenDecl {
|
||||
// e.g., `type FoodID uint64`
|
||||
return &ast.GenDecl{
|
||||
Tok: token.TYPE,
|
||||
Specs: []ast.Spec{&ast.TypeSpec{Name: ast.NewIdent(table.TypeIDName), Type: ast.NewIdent("uint64")}},
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateModelAST produces an AST for a struct type corresponding to the model.
|
||||
// TODO: generate the right field types here based on column types.
|
||||
func GenerateModelAST(table schema.Table) *ast.GenDecl {
|
||||
// Fields for the struct
|
||||
fields := []*ast.Field{}
|
||||
|
||||
// Other fields (just strings for now)
|
||||
for _, col := range table.Columns {
|
||||
switch col.Name {
|
||||
case "rowid":
|
||||
fields = append(fields, &ast.Field{
|
||||
Names: []*ast.Ident{ast.NewIdent("ID")},
|
||||
Type: ast.NewIdent(table.TypeIDName),
|
||||
Tag: &ast.BasicLit{Kind: token.STRING, Value: "`db:\"rowid\" json:\"id\"`"},
|
||||
})
|
||||
default:
|
||||
if col.IsForeignKey && strings.HasSuffix(col.Name, "_id") {
|
||||
fields = append(fields, &ast.Field{
|
||||
Names: []*ast.Ident{ast.NewIdent(textutils.SnakeToCamel(strings.TrimSuffix(col.Name, "_id")) + "ID")},
|
||||
Type: ast.NewIdent(textutils.SnakeToCamel(inflection.Singular(col.ForeignKeyTargetTable) + "ID")),
|
||||
Tag: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("`db:\"%s\" json:\"%s\"`", col.Name, col.Name)},
|
||||
})
|
||||
} else {
|
||||
typeName := "string"
|
||||
switch col.Type {
|
||||
case "integer", "int":
|
||||
if strings.HasPrefix(col.Name, "is_") || strings.HasPrefix(col.Name, "has_") {
|
||||
typeName = "bool"
|
||||
} else if strings.HasSuffix(col.Name, "_at") {
|
||||
typeName = "Timestamp"
|
||||
} else {
|
||||
typeName = "int64"
|
||||
}
|
||||
case "text":
|
||||
typeName = "string"
|
||||
case "real":
|
||||
typeName = "float32"
|
||||
case "blob":
|
||||
typeName = "[]byte"
|
||||
default:
|
||||
panic("Unrecognized sqlite column type: " + col.Type)
|
||||
}
|
||||
|
||||
fields = append(fields, &ast.Field{
|
||||
Names: []*ast.Ident{ast.NewIdent(textutils.SnakeToCamel(col.Name))},
|
||||
Type: ast.NewIdent(typeName),
|
||||
Tag: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("`db:\"%s\" json:\"%s\"`", col.Name, col.Name)},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &ast.GenDecl{
|
||||
Tok: token.TYPE,
|
||||
Specs: []ast.Spec{&ast.TypeSpec{
|
||||
Name: ast.NewIdent(table.TypeName),
|
||||
Type: &ast.StructType{Fields: &ast.FieldList{List: fields}},
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateSaveItemFunc produces an AST for the SaveXyz() function of the model.
|
||||
// E.g., a table with `table.TypeName = "foods"` will produce a "SaveFood()" function.
|
||||
func GenerateSaveItemFunc(tbl schema.Table) *ast.FuncDecl {
|
||||
insertCols := make([]string, 0, len(tbl.Columns))
|
||||
insertVals := make([]string, 0, len(tbl.Columns))
|
||||
updatePairs := make([]string, 0, len(tbl.Columns))
|
||||
|
||||
for _, col := range tbl.Columns {
|
||||
if col.Name == "rowid" {
|
||||
continue
|
||||
}
|
||||
insertCols = append(insertCols, col.Name)
|
||||
val := ":" + col.Name
|
||||
if col.IsNullableForeignKey() {
|
||||
val = fmt.Sprintf("nullif(%s, 0)", val)
|
||||
}
|
||||
insertVals = append(insertVals, val)
|
||||
updatePairs = append(updatePairs, col.Name+"="+val)
|
||||
}
|
||||
|
||||
insertStmt := fmt.Sprintf("\n\t\t insert into %s (%s)\n\t\t values (%s)\n\t\t", tbl.TableName, strings.Join(insertCols, ", "), strings.Join(insertVals, ", "))
|
||||
updateStmt := fmt.Sprintf("\n\t\t update %s\n\t\t set %s\n\t\t where rowid = :rowid\n\t\t", tbl.TableName, strings.Join(updatePairs, ",\n\t\t "))
|
||||
|
||||
funcBody := &ast.BlockStmt{
|
||||
List: []ast.Stmt{
|
||||
&ast.IfStmt{
|
||||
Cond: &ast.BinaryExpr{
|
||||
X: &ast.SelectorExpr{X: ast.NewIdent(tbl.VarName), Sel: ast.NewIdent("ID")},
|
||||
Op: token.EQL,
|
||||
Y: &ast.BasicLit{Kind: token.INT, Value: "0"},
|
||||
},
|
||||
Body: &ast.BlockStmt{
|
||||
List: []ast.Stmt{
|
||||
&ast.AssignStmt{
|
||||
Lhs: []ast.Expr{ast.NewIdent("result"), ast.NewIdent("err")},
|
||||
Tok: token.DEFINE,
|
||||
Rhs: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent("db.DB"), Sel: ast.NewIdent("NamedExec")}, Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: "`" + insertStmt + "`"}, ast.NewIdent(tbl.VarName)}}},
|
||||
},
|
||||
&ast.IfStmt{Cond: &ast.BinaryExpr{X: ast.NewIdent("err"), Op: token.NEQ, Y: ast.NewIdent("nil")}, Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ExprStmt{X: &ast.CallExpr{Fun: ast.NewIdent("panic"), Args: []ast.Expr{ast.NewIdent("err")}}}}}},
|
||||
&ast.AssignStmt{
|
||||
Lhs: []ast.Expr{ast.NewIdent("id"), ast.NewIdent("err")},
|
||||
Tok: token.DEFINE,
|
||||
Rhs: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent("result"), Sel: ast.NewIdent("LastInsertId")}, Args: []ast.Expr{}}},
|
||||
},
|
||||
&ast.IfStmt{Cond: &ast.BinaryExpr{X: ast.NewIdent("err"), Op: token.NEQ, Y: ast.NewIdent("nil")}, Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ExprStmt{X: &ast.CallExpr{Fun: ast.NewIdent("panic"), Args: []ast.Expr{ast.NewIdent("err")}}}}}},
|
||||
&ast.AssignStmt{Lhs: []ast.Expr{&ast.SelectorExpr{X: ast.NewIdent(tbl.VarName), Sel: ast.NewIdent("ID")}}, Tok: token.ASSIGN, Rhs: []ast.Expr{&ast.CallExpr{Fun: ast.NewIdent(tbl.TypeIDName), Args: []ast.Expr{ast.NewIdent("id")}}}},
|
||||
},
|
||||
},
|
||||
Else: &ast.BlockStmt{
|
||||
List: []ast.Stmt{
|
||||
&ast.AssignStmt{
|
||||
Lhs: []ast.Expr{ast.NewIdent("result"), ast.NewIdent("err")},
|
||||
Tok: token.DEFINE,
|
||||
Rhs: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent("db.DB"), Sel: ast.NewIdent("NamedExec")}, Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: "`" + updateStmt + "`"}, ast.NewIdent(tbl.VarName)}}},
|
||||
},
|
||||
&ast.IfStmt{Cond: &ast.BinaryExpr{X: ast.NewIdent("err"), Op: token.NEQ, Y: ast.NewIdent("nil")}, Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ExprStmt{X: &ast.CallExpr{Fun: ast.NewIdent("panic"), Args: []ast.Expr{ast.NewIdent("err")}}}}}},
|
||||
&ast.AssignStmt{
|
||||
Lhs: []ast.Expr{ast.NewIdent("count"), ast.NewIdent("err")},
|
||||
Tok: token.DEFINE,
|
||||
Rhs: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent("result"), Sel: ast.NewIdent("RowsAffected")}, Args: []ast.Expr{}}},
|
||||
},
|
||||
&ast.IfStmt{Cond: &ast.BinaryExpr{X: ast.NewIdent("err"), Op: token.NEQ, Y: ast.NewIdent("nil")}, Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ExprStmt{X: &ast.CallExpr{Fun: ast.NewIdent("panic"), Args: []ast.Expr{ast.NewIdent("err")}}}}}},
|
||||
&ast.IfStmt{
|
||||
Cond: &ast.BinaryExpr{X: ast.NewIdent("count"), Op: token.NEQ, Y: &ast.BasicLit{Kind: token.INT, Value: "1"}},
|
||||
Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ExprStmt{X: &ast.CallExpr{Fun: ast.NewIdent("panic"), Args: []ast.Expr{&ast.CallExpr{Fun: ast.NewIdent("fmt.Errorf"), Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"got %s with ID (%%d), so attempted update, but it doesn't exist\"", strings.ToLower(tbl.TypeName))}, &ast.SelectorExpr{X: ast.NewIdent(tbl.VarName), Sel: ast.NewIdent("ID")}}}}}}}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
funcDecl := &ast.FuncDecl{
|
||||
Recv: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent("db")}, Type: ast.NewIdent("DB")}}},
|
||||
Name: ast.NewIdent("Save" + tbl.TypeName),
|
||||
Type: &ast.FuncType{
|
||||
Params: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent(tbl.VarName)}, Type: &ast.StarExpr{X: ast.NewIdent(tbl.TypeName)}}}},
|
||||
Results: nil,
|
||||
},
|
||||
Body: funcBody,
|
||||
}
|
||||
return funcDecl
|
||||
}
|
||||
|
||||
// GenerateGetItemByIDFunc produces an AST for the `GetXyzByID()` function.
|
||||
// E.g., a table with `table.TypeName = "foods"` will produce a "GetFoodByID()" function.
|
||||
func GenerateGetItemByIDFunc(tbl schema.Table) *ast.FuncDecl {
|
||||
funcName := "Get" + tbl.TypeName + "ByID"
|
||||
|
||||
recv := &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent("db")}, Type: ast.NewIdent("DB")}}}
|
||||
arg := &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent("id")}, Type: ast.NewIdent(tbl.TypeIDName)}}}
|
||||
result := &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent("ret")}, Type: ast.NewIdent(tbl.TypeName)}, {Names: []*ast.Ident{ast.NewIdent("err")}, Type: ast.NewIdent("error")}}}
|
||||
|
||||
selectCols := make([]string, 0, len(tbl.Columns))
|
||||
for _, col := range tbl.Columns {
|
||||
selectCols = append(selectCols, col.Name)
|
||||
}
|
||||
selectStmt := fmt.Sprintf("\n\t select %s\n\t from %s\n\t where rowid = ?\n\t", strings.Join(selectCols, ", "), tbl.TableName)
|
||||
|
||||
funcBody := &ast.BlockStmt{
|
||||
List: []ast.Stmt{
|
||||
&ast.AssignStmt{
|
||||
Lhs: []ast.Expr{ast.NewIdent("err")},
|
||||
Tok: token.ASSIGN,
|
||||
Rhs: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent("db.DB"), Sel: ast.NewIdent("Get")}, Args: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: ast.NewIdent("ret")}, &ast.BasicLit{Kind: token.STRING, Value: "`" + selectStmt + "`"}, ast.NewIdent("id")}}},
|
||||
},
|
||||
&ast.IfStmt{
|
||||
Cond: &ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent("errors"), Sel: ast.NewIdent("Is")}, Args: []ast.Expr{ast.NewIdent("err"), &ast.SelectorExpr{X: ast.NewIdent("sql"), Sel: ast.NewIdent("ErrNoRows")}}},
|
||||
Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ReturnStmt{Results: []ast.Expr{&ast.CompositeLit{Type: ast.NewIdent(tbl.TypeName)}, ast.NewIdent("ErrNotInDB")}}}},
|
||||
},
|
||||
&ast.ReturnStmt{},
|
||||
},
|
||||
}
|
||||
|
||||
funcDecl := &ast.FuncDecl{
|
||||
Recv: recv,
|
||||
Name: ast.NewIdent(funcName),
|
||||
Type: &ast.FuncType{Params: arg, Results: result},
|
||||
Body: funcBody,
|
||||
}
|
||||
return funcDecl
|
||||
}
|
||||
|
||||
// GenerateDeleteItemFunc produces an AST for the `DeleteXyz()` function.
|
||||
// E.g., a table with `table.TypeName = "foods"` will produce a "DeleteFood()" function.
|
||||
func GenerateDeleteItemFunc(tbl schema.Table) *ast.FuncDecl {
|
||||
funcName := "Delete" + tbl.TypeName
|
||||
recv := &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent("db")}, Type: ast.NewIdent("DB")}}}
|
||||
arg := &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent(tbl.VarName)}, Type: ast.NewIdent(tbl.TypeName)}}}
|
||||
|
||||
funcBody := &ast.BlockStmt{
|
||||
List: []ast.Stmt{
|
||||
&ast.AssignStmt{
|
||||
Lhs: []ast.Expr{ast.NewIdent("result"), ast.NewIdent("err")},
|
||||
Tok: token.DEFINE,
|
||||
Rhs: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent("db.DB"), Sel: ast.NewIdent("Exec")}, Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: "`delete from " + tbl.TableName + " where rowid = ?`"}, &ast.SelectorExpr{X: ast.NewIdent(tbl.VarName), Sel: ast.NewIdent("ID")}}}},
|
||||
},
|
||||
&ast.IfStmt{Cond: &ast.BinaryExpr{X: ast.NewIdent("err"), Op: token.NEQ, Y: ast.NewIdent("nil")}, Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ExprStmt{X: &ast.CallExpr{Fun: ast.NewIdent("panic"), Args: []ast.Expr{ast.NewIdent("err")}}}}}},
|
||||
&ast.AssignStmt{
|
||||
Lhs: []ast.Expr{ast.NewIdent("count"), ast.NewIdent("err")},
|
||||
Tok: token.DEFINE,
|
||||
Rhs: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent("result"), Sel: ast.NewIdent("RowsAffected")}, Args: []ast.Expr{}}},
|
||||
},
|
||||
&ast.IfStmt{Cond: &ast.BinaryExpr{X: ast.NewIdent("err"), Op: token.NEQ, Y: ast.NewIdent("nil")}, Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ExprStmt{X: &ast.CallExpr{Fun: ast.NewIdent("panic"), Args: []ast.Expr{ast.NewIdent("err")}}}}}},
|
||||
&ast.IfStmt{
|
||||
Cond: &ast.BinaryExpr{X: ast.NewIdent("count"), Op: token.NEQ, Y: &ast.BasicLit{Kind: token.INT, Value: "1"}},
|
||||
Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ExprStmt{X: &ast.CallExpr{Fun: ast.NewIdent("panic"), Args: []ast.Expr{&ast.CallExpr{Fun: ast.NewIdent("fmt.Errorf"), Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"tried to delete %s with ID (%%d) but it doesn't exist\"", strings.ToLower(tbl.TypeName))}, &ast.SelectorExpr{X: ast.NewIdent(tbl.VarName), Sel: ast.NewIdent("ID")}}}}}}}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
funcDecl := &ast.FuncDecl{
|
||||
Recv: recv,
|
||||
Name: ast.NewIdent(funcName),
|
||||
Type: &ast.FuncType{Params: arg, Results: nil},
|
||||
Body: funcBody,
|
||||
}
|
||||
return funcDecl
|
||||
}
|
||||
251
pkg/codegen/modelgenerate/generate_testfile.go
Normal file
251
pkg/codegen/modelgenerate/generate_testfile.go
Normal file
@ -0,0 +1,251 @@
|
||||
package modelgenerate
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/token"
|
||||
|
||||
"git.offline-twitter.com/offline-labs/gas-stack/pkg/schema"
|
||||
)
|
||||
|
||||
// GenerateModelTestAST produces an AST for a starter test file for a given model.
|
||||
func GenerateModelTestAST(tbl schema.Table, gomodName string) *ast.File {
|
||||
packageName := "db"
|
||||
testpackageName := packageName + "_test"
|
||||
|
||||
testDBDecl := &ast.GenDecl{
|
||||
Tok: token.VAR,
|
||||
Specs: []ast.Spec{
|
||||
&ast.ValueSpec{
|
||||
Names: []*ast.Ident{ast.NewIdent("TestDB")},
|
||||
Type: &ast.StarExpr{X: ast.NewIdent("DB")},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
initFuncDecl := &ast.FuncDecl{
|
||||
Name: ast.NewIdent("init"),
|
||||
Type: &ast.FuncType{Params: &ast.FieldList{}},
|
||||
Body: &ast.BlockStmt{
|
||||
List: []ast.Stmt{
|
||||
&ast.AssignStmt{
|
||||
Lhs: []ast.Expr{ast.NewIdent("TestDB")},
|
||||
Tok: token.ASSIGN,
|
||||
Rhs: []ast.Expr{&ast.CallExpr{
|
||||
Fun: ast.NewIdent("MakeDB"),
|
||||
Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: `"tmp"`}},
|
||||
}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
makeDBHelperDecl := &ast.FuncDecl{
|
||||
Name: ast.NewIdent("MakeDB"),
|
||||
Type: &ast.FuncType{
|
||||
Params: &ast.FieldList{
|
||||
List: []*ast.Field{{
|
||||
Names: []*ast.Ident{ast.NewIdent("dbName")},
|
||||
Type: ast.NewIdent("string"),
|
||||
}},
|
||||
},
|
||||
Results: &ast.FieldList{
|
||||
List: []*ast.Field{{Type: &ast.StarExpr{X: ast.NewIdent("DB")}}},
|
||||
},
|
||||
},
|
||||
Body: &ast.BlockStmt{
|
||||
List: []ast.Stmt{
|
||||
// db := Must(Create(fmt.Sprintf("file:%s?mode=memory&cache=shared", dbName)))
|
||||
&ast.AssignStmt{
|
||||
Lhs: []ast.Expr{ast.NewIdent("db")},
|
||||
Tok: token.DEFINE,
|
||||
Rhs: []ast.Expr{&ast.CallExpr{
|
||||
Fun: ast.NewIdent("Must"),
|
||||
Args: []ast.Expr{&ast.CallExpr{
|
||||
Fun: ast.NewIdent("Create"),
|
||||
Args: []ast.Expr{&ast.CallExpr{
|
||||
Fun: ast.NewIdent("fmt.Sprintf"),
|
||||
Args: []ast.Expr{
|
||||
&ast.BasicLit{Kind: token.STRING, Value: `"file:%s?mode=memory&cache=shared"`},
|
||||
ast.NewIdent("dbName"),
|
||||
},
|
||||
}},
|
||||
}},
|
||||
}},
|
||||
},
|
||||
// return db
|
||||
&ast.ReturnStmt{
|
||||
Results: []ast.Expr{ast.NewIdent("db")},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
testObj := ast.NewIdent("item")
|
||||
testObj2 := ast.NewIdent("item2")
|
||||
fieldName := ast.NewIdent("Description")
|
||||
description1 := `"an item"`
|
||||
description2 := `"a big item"`
|
||||
|
||||
return &ast.File{
|
||||
Name: ast.NewIdent(testpackageName),
|
||||
Decls: []ast.Decl{
|
||||
&ast.GenDecl{
|
||||
Tok: token.IMPORT,
|
||||
Specs: []ast.Spec{
|
||||
&ast.ImportSpec{Path: &ast.BasicLit{Kind: token.STRING, Value: `"fmt"`}},
|
||||
&ast.ImportSpec{Path: &ast.BasicLit{Kind: token.STRING, Value: `"testing"`}},
|
||||
&ast.ImportSpec{
|
||||
Path: &ast.BasicLit{Kind: token.STRING, Value: `"git.offline-twitter.com/offline-labs/gas-stack/pkg/db"`},
|
||||
Name: ast.NewIdent("db"),
|
||||
},
|
||||
&ast.ImportSpec{
|
||||
Path: &ast.BasicLit{Kind: token.STRING, Value: `"git.offline-twitter.com/offline-labs/gas-stack/pkg/flowutils"`},
|
||||
Name: ast.NewIdent("."),
|
||||
},
|
||||
&ast.ImportSpec{
|
||||
Path: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf(`"%s/pkg/%s"`, gomodName, packageName)},
|
||||
Name: ast.NewIdent("."),
|
||||
},
|
||||
&ast.ImportSpec{Path: &ast.BasicLit{Kind: token.STRING, Value: `"github.com/stretchr/testify/assert"`}},
|
||||
&ast.ImportSpec{Path: &ast.BasicLit{Kind: token.STRING, Value: `"github.com/stretchr/testify/require"`}},
|
||||
},
|
||||
},
|
||||
// var TestDB *DB
|
||||
testDBDecl,
|
||||
|
||||
// func init() { TestDB = MakeDB("tmp") }
|
||||
initFuncDecl,
|
||||
|
||||
// func MakeDB(dbName string) *DB { db := Must(Create(fmt.Sprintf("file:%s?mode=memory&cache=shared", dbName))); return db }
|
||||
makeDBHelperDecl,
|
||||
|
||||
&ast.FuncDecl{
|
||||
Name: ast.NewIdent("TestCreateUpdateDelete" + tbl.TypeName),
|
||||
Type: &ast.FuncType{
|
||||
Params: &ast.FieldList{
|
||||
List: []*ast.Field{{
|
||||
Names: []*ast.Ident{ast.NewIdent("t")},
|
||||
Type: ast.NewIdent("*testing.T"),
|
||||
}},
|
||||
},
|
||||
},
|
||||
Body: &ast.BlockStmt{
|
||||
List: []ast.Stmt{
|
||||
// item := Item{Description: "an item"}
|
||||
&ast.AssignStmt{
|
||||
Lhs: []ast.Expr{testObj},
|
||||
Tok: token.DEFINE,
|
||||
Rhs: []ast.Expr{&ast.CompositeLit{
|
||||
Type: ast.NewIdent(tbl.TypeName),
|
||||
Elts: []ast.Expr{
|
||||
&ast.KeyValueExpr{
|
||||
Key: fieldName,
|
||||
Value: &ast.BasicLit{Kind: token.STRING, Value: description1},
|
||||
},
|
||||
},
|
||||
}},
|
||||
},
|
||||
|
||||
// TestDB.SaveItem(&item)
|
||||
&ast.ExprStmt{X: &ast.CallExpr{
|
||||
Fun: ast.NewIdent("TestDB.Save" + tbl.TypeName),
|
||||
Args: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: testObj}},
|
||||
}},
|
||||
|
||||
// require.NotZero(t, item.ID)
|
||||
&ast.ExprStmt{X: &ast.CallExpr{
|
||||
Fun: ast.NewIdent("require.NotZero"),
|
||||
Args: []ast.Expr{ast.NewIdent("t"), &ast.SelectorExpr{X: testObj, Sel: ast.NewIdent("ID")}},
|
||||
}},
|
||||
|
||||
// item2 := Must(TestDB.GetItemByID(item.ID))
|
||||
&ast.AssignStmt{
|
||||
Lhs: []ast.Expr{testObj2},
|
||||
Tok: token.DEFINE,
|
||||
Rhs: []ast.Expr{&ast.CallExpr{
|
||||
Fun: ast.NewIdent("Must"),
|
||||
Args: []ast.Expr{&ast.CallExpr{
|
||||
Fun: ast.NewIdent("TestDB.Get" + tbl.TypeName + "ByID"),
|
||||
Args: []ast.Expr{&ast.SelectorExpr{X: testObj, Sel: ast.NewIdent("ID")}},
|
||||
}},
|
||||
}},
|
||||
},
|
||||
|
||||
// assert.Equal(t, "an item", item2.Description)
|
||||
&ast.ExprStmt{X: &ast.CallExpr{
|
||||
Fun: ast.NewIdent("assert.Equal"),
|
||||
Args: []ast.Expr{
|
||||
ast.NewIdent("t"),
|
||||
&ast.BasicLit{Kind: token.STRING, Value: description1},
|
||||
&ast.SelectorExpr{X: testObj2, Sel: fieldName},
|
||||
},
|
||||
}},
|
||||
|
||||
// item.Description = "a big item"
|
||||
&ast.AssignStmt{
|
||||
Lhs: []ast.Expr{&ast.SelectorExpr{X: testObj, Sel: fieldName}},
|
||||
Tok: token.ASSIGN,
|
||||
Rhs: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: description2}},
|
||||
},
|
||||
|
||||
// TestDB.SaveItem(&item)
|
||||
&ast.ExprStmt{X: &ast.CallExpr{
|
||||
Fun: ast.NewIdent("TestDB.Save" + tbl.TypeName),
|
||||
Args: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: testObj}},
|
||||
}},
|
||||
|
||||
// item2 = Must(TestDB.GetItemByID(item.ID))
|
||||
&ast.AssignStmt{
|
||||
Lhs: []ast.Expr{testObj2},
|
||||
Tok: token.ASSIGN,
|
||||
Rhs: []ast.Expr{&ast.CallExpr{
|
||||
Fun: ast.NewIdent("Must"),
|
||||
Args: []ast.Expr{&ast.CallExpr{
|
||||
Fun: ast.NewIdent("TestDB.Get" + tbl.TypeName + "ByID"),
|
||||
Args: []ast.Expr{&ast.SelectorExpr{X: testObj, Sel: ast.NewIdent("ID")}},
|
||||
}},
|
||||
}},
|
||||
},
|
||||
|
||||
// assert.Equal(t, item.Description, item2.Description)
|
||||
&ast.ExprStmt{X: &ast.CallExpr{
|
||||
Fun: ast.NewIdent("assert.Equal"),
|
||||
Args: []ast.Expr{
|
||||
ast.NewIdent("t"),
|
||||
&ast.SelectorExpr{X: testObj, Sel: fieldName},
|
||||
&ast.SelectorExpr{X: testObj2, Sel: fieldName},
|
||||
},
|
||||
}},
|
||||
|
||||
// TestDB.DeleteItem(item)
|
||||
&ast.ExprStmt{X: &ast.CallExpr{
|
||||
Fun: ast.NewIdent("TestDB.Delete" + tbl.TypeName),
|
||||
Args: []ast.Expr{testObj},
|
||||
}},
|
||||
|
||||
// _, err := TestDB.GetItemByID(item.ID)
|
||||
&ast.AssignStmt{
|
||||
Lhs: []ast.Expr{ast.NewIdent("_"), ast.NewIdent("err")},
|
||||
Tok: token.DEFINE,
|
||||
Rhs: []ast.Expr{&ast.CallExpr{
|
||||
Fun: ast.NewIdent("TestDB.Get" + tbl.TypeName + "ByID"),
|
||||
Args: []ast.Expr{&ast.SelectorExpr{X: testObj, Sel: ast.NewIdent("ID")}},
|
||||
}},
|
||||
},
|
||||
|
||||
// assert.ErrorIs(t, err, db.ErrNotInDB)
|
||||
&ast.ExprStmt{X: &ast.CallExpr{
|
||||
Fun: ast.NewIdent("assert.ErrorIs"),
|
||||
Args: []ast.Expr{
|
||||
ast.NewIdent("t"),
|
||||
ast.NewIdent("err"),
|
||||
&ast.SelectorExpr{X: ast.NewIdent("db"), Sel: ast.NewIdent("ErrNotInDB")},
|
||||
},
|
||||
}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
@ -37,10 +37,10 @@ func InitPkg(opts PkgOpts) {
|
||||
PanicIf(os.MkdirAll("sample_data", 0o755))
|
||||
|
||||
PanicIf(os.WriteFile("pkg/db/schema.sql", Must(tpl.ReadFile("tpl/schema.sql")), 0o664))
|
||||
PanicIf(os.WriteFile("pkg/db/db.go", Must(tpl.ReadFile("tpl/db.go")), 0o664))
|
||||
|
||||
PanicIf(os.WriteFile("sample_data/mount.sh", Must(tpl.ReadFile("tpl/mount.sh")), 0o775))
|
||||
PanicIf(os.WriteFile("sample_data/reset.sh", Must(tpl.ReadFile("tpl/reset.sh")), 0o775))
|
||||
PanicIf(os.WriteFile("pkg/db/schema.sql", Must(tpl.ReadFile("tpl/schema.sql")), 0o664))
|
||||
|
||||
// TODO:
|
||||
// - create `pkg/db/errors.go`
|
||||
|
||||
38
pkg/codegen/tpl/db.go.tpl
Normal file
38
pkg/codegen/tpl/db.go.tpl
Normal file
@ -0,0 +1,38 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"fmt"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
|
||||
"git.offline-twitter.com/offline-labs/gas-stack/pkg/db"
|
||||
)
|
||||
|
||||
//go:embed schema.sql
|
||||
var sql_schema string
|
||||
|
||||
// Database starts at version 0. First migration brings us to version 1
|
||||
var MIGRATIONS = []string{}
|
||||
|
||||
type DB struct {
|
||||
DB *sqlx.DB
|
||||
}
|
||||
|
||||
func Create(path string) (*DB, error) {
|
||||
conf := db.Init(&sql_schema, &MIGRATIONS)
|
||||
dbHandle, err := conf.Create(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating db: %w", err)
|
||||
}
|
||||
return &DB{dbHandle}, nil
|
||||
}
|
||||
|
||||
func Connect(path string) (*DB, error) {
|
||||
conf := db.Init(&sql_schema, &MIGRATIONS)
|
||||
dbHandle, err := conf.Connect(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating db: %w", err)
|
||||
}
|
||||
return &DB{dbHandle}, nil
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user