Платформа ЦРНП "Мирокод" для разработки проектов
https://git.mirocod.ru
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
599 lines
15 KiB
599 lines
15 KiB
package testfixtures // import "github.com/go-testfixtures/testfixtures/v3" |
|
|
|
import ( |
|
"bytes" |
|
"database/sql" |
|
"fmt" |
|
"io/ioutil" |
|
"os" |
|
"path" |
|
"path/filepath" |
|
"regexp" |
|
"strings" |
|
"text/template" |
|
"time" |
|
|
|
"gopkg.in/yaml.v2" |
|
) |
|
|
|
// Loader is the responsible to loading fixtures. |
|
type Loader struct { |
|
db *sql.DB |
|
helper helper |
|
fixturesFiles []*fixtureFile |
|
|
|
skipTestDatabaseCheck bool |
|
location *time.Location |
|
|
|
template bool |
|
templateFuncs template.FuncMap |
|
templateLeftDelim string |
|
templateRightDelim string |
|
templateOptions []string |
|
templateData interface{} |
|
} |
|
|
|
type fixtureFile struct { |
|
path string |
|
fileName string |
|
content []byte |
|
insertSQLs []insertSQL |
|
} |
|
|
|
type insertSQL struct { |
|
sql string |
|
params []interface{} |
|
} |
|
|
|
var ( |
|
testDatabaseRegexp = regexp.MustCompile("(?i)test") |
|
|
|
errDatabaseIsRequired = fmt.Errorf("testfixtures: database is required") |
|
errDialectIsRequired = fmt.Errorf("testfixtures: dialect is required") |
|
) |
|
|
|
// New instantiates a new Loader instance. The "Database" and "Driver" |
|
// options are required. |
|
func New(options ...func(*Loader) error) (*Loader, error) { |
|
l := &Loader{ |
|
templateLeftDelim: "{{", |
|
templateRightDelim: "}}", |
|
templateOptions: []string{"missingkey=zero"}, |
|
} |
|
|
|
for _, option := range options { |
|
if err := option(l); err != nil { |
|
return nil, err |
|
} |
|
} |
|
|
|
if l.db == nil { |
|
return nil, errDatabaseIsRequired |
|
} |
|
if l.helper == nil { |
|
return nil, errDialectIsRequired |
|
} |
|
|
|
if err := l.helper.init(l.db); err != nil { |
|
return nil, err |
|
} |
|
if err := l.buildInsertSQLs(); err != nil { |
|
return nil, err |
|
} |
|
|
|
return l, nil |
|
} |
|
|
|
// Database sets an existing sql.DB instant to Loader. |
|
func Database(db *sql.DB) func(*Loader) error { |
|
return func(l *Loader) error { |
|
l.db = db |
|
return nil |
|
} |
|
} |
|
|
|
// Dialect informs Loader about which database dialect you're using. |
|
// |
|
// Possible options are "postgresql", "timescaledb", "mysql", "mariadb", |
|
// "sqlite" and "sqlserver". |
|
func Dialect(dialect string) func(*Loader) error { |
|
return func(l *Loader) error { |
|
h, err := helperForDialect(dialect) |
|
if err != nil { |
|
return err |
|
} |
|
l.helper = h |
|
return nil |
|
} |
|
} |
|
|
|
func helperForDialect(dialect string) (helper, error) { |
|
switch dialect { |
|
case "postgres", "postgresql", "timescaledb": |
|
return &postgreSQL{}, nil |
|
case "mysql", "mariadb": |
|
return &mySQL{}, nil |
|
case "sqlite", "sqlite3": |
|
return &sqlite{}, nil |
|
case "mssql", "sqlserver": |
|
return &sqlserver{}, nil |
|
default: |
|
return nil, fmt.Errorf(`testfixtures: unrecognized dialect "%s"`, dialect) |
|
} |
|
} |
|
|
|
// UseAlterConstraint If true, the contraint disabling will do |
|
// using ALTER CONTRAINT sintax, only allowed in PG >= 9.4. |
|
// If false, the constraint disabling will use DISABLE TRIGGER ALL, |
|
// which requires SUPERUSER privileges. |
|
// |
|
// Only valid for PostgreSQL. Returns an error otherwise. |
|
func UseAlterConstraint() func(*Loader) error { |
|
return func(l *Loader) error { |
|
pgHelper, ok := l.helper.(*postgreSQL) |
|
if !ok { |
|
return fmt.Errorf("testfixtures: UseAlterConstraint is only valid for PostgreSQL databases") |
|
} |
|
pgHelper.useAlterConstraint = true |
|
return nil |
|
} |
|
} |
|
|
|
// SkipResetSequences prevents Loader from reseting sequences after loading |
|
// fixtures. |
|
// |
|
// Only valid for PostgreSQL. Returns an error otherwise. |
|
func SkipResetSequences() func(*Loader) error { |
|
return func(l *Loader) error { |
|
pgHelper, ok := l.helper.(*postgreSQL) |
|
if !ok { |
|
return fmt.Errorf("testfixtures: SkipResetSequences is only valid for PostgreSQL databases") |
|
} |
|
pgHelper.skipResetSequences = true |
|
return nil |
|
} |
|
} |
|
|
|
// ResetSequencesTo sets the value the sequences will be reset to. |
|
// |
|
// Defaults to 10000. |
|
// |
|
// Only valid for PostgreSQL. Returns an error otherwise. |
|
func ResetSequencesTo(value int64) func(*Loader) error { |
|
return func(l *Loader) error { |
|
pgHelper, ok := l.helper.(*postgreSQL) |
|
if !ok { |
|
return fmt.Errorf("testfixtures: ResetSequencesTo is only valid for PostgreSQL databases") |
|
} |
|
pgHelper.resetSequencesTo = value |
|
return nil |
|
} |
|
} |
|
|
|
// DangerousSkipTestDatabaseCheck will make Loader not check if the database |
|
// name contains "test". Use with caution! |
|
func DangerousSkipTestDatabaseCheck() func(*Loader) error { |
|
return func(l *Loader) error { |
|
l.skipTestDatabaseCheck = true |
|
return nil |
|
} |
|
} |
|
|
|
// Directory informs Loader to load YAML files from a given directory. |
|
func Directory(dir string) func(*Loader) error { |
|
return func(l *Loader) error { |
|
fixtures, err := l.fixturesFromDir(dir) |
|
if err != nil { |
|
return err |
|
} |
|
l.fixturesFiles = append(l.fixturesFiles, fixtures...) |
|
return nil |
|
} |
|
} |
|
|
|
// Files informs Loader to load a given set of YAML files. |
|
func Files(files ...string) func(*Loader) error { |
|
return func(l *Loader) error { |
|
fixtures, err := l.fixturesFromFiles(files...) |
|
if err != nil { |
|
return err |
|
} |
|
l.fixturesFiles = append(l.fixturesFiles, fixtures...) |
|
return nil |
|
} |
|
} |
|
|
|
// Paths inform Loader to load a given set of YAML files and directories. |
|
func Paths(paths ...string) func(*Loader) error { |
|
return func(l *Loader) error { |
|
fixtures, err := l.fixturesFromPaths(paths...) |
|
if err != nil { |
|
return err |
|
} |
|
l.fixturesFiles = append(l.fixturesFiles, fixtures...) |
|
return nil |
|
} |
|
} |
|
|
|
// Location makes Loader use the given location by default when parsing |
|
// dates. If not given, by default it uses the value of time.Local. |
|
func Location(location *time.Location) func(*Loader) error { |
|
return func(l *Loader) error { |
|
l.location = location |
|
return nil |
|
} |
|
} |
|
|
|
// Template makes loader process each YAML file as an template using the |
|
// text/template package. |
|
// |
|
// For more information on how templates work in Go please read: |
|
// https://golang.org/pkg/text/template/ |
|
// |
|
// If not given the YAML files are parsed as is. |
|
func Template() func(*Loader) error { |
|
return func(l *Loader) error { |
|
l.template = true |
|
return nil |
|
} |
|
} |
|
|
|
// TemplateFuncs allow choosing which functions will be available |
|
// when processing templates. |
|
// |
|
// For more information see: https://golang.org/pkg/text/template/#Template.Funcs |
|
func TemplateFuncs(funcs template.FuncMap) func(*Loader) error { |
|
return func(l *Loader) error { |
|
if !l.template { |
|
return fmt.Errorf(`testfixtures: the Template() options is required in order to use the TemplateFuns() option`) |
|
} |
|
|
|
l.templateFuncs = funcs |
|
return nil |
|
} |
|
} |
|
|
|
// TemplateDelims allow choosing which delimiters will be used for templating. |
|
// This defaults to "{{" and "}}". |
|
// |
|
// For more information see https://golang.org/pkg/text/template/#Template.Delims |
|
func TemplateDelims(left, right string) func(*Loader) error { |
|
return func(l *Loader) error { |
|
if !l.template { |
|
return fmt.Errorf(`testfixtures: the Template() options is required in order to use the TemplateDelims() option`) |
|
} |
|
|
|
l.templateLeftDelim = left |
|
l.templateRightDelim = right |
|
return nil |
|
} |
|
} |
|
|
|
// TemplateOptions allows you to specific which text/template options will |
|
// be enabled when processing templates. |
|
// |
|
// This defaults to "missingkey=zero". Check the available options here: |
|
// https://golang.org/pkg/text/template/#Template.Option |
|
func TemplateOptions(options ...string) func(*Loader) error { |
|
return func(l *Loader) error { |
|
if !l.template { |
|
return fmt.Errorf(`testfixtures: the Template() options is required in order to use the TemplateOptions() option`) |
|
} |
|
|
|
l.templateOptions = options |
|
return nil |
|
} |
|
} |
|
|
|
// TemplateData allows you to specify which data will be available |
|
// when processing templates. Data is accesible by prefixing it with a "." |
|
// like {{.MyKey}}. |
|
func TemplateData(data interface{}) func(*Loader) error { |
|
return func(l *Loader) error { |
|
if !l.template { |
|
return fmt.Errorf(`testfixtures: the Template() options is required in order to use the TemplateData() option`) |
|
} |
|
|
|
l.templateData = data |
|
return nil |
|
} |
|
} |
|
|
|
// EnsureTestDatabase returns an error if the database name does not contains |
|
// "test". |
|
func (l *Loader) EnsureTestDatabase() error { |
|
dbName, err := l.helper.databaseName(l.db) |
|
if err != nil { |
|
return err |
|
} |
|
if !testDatabaseRegexp.MatchString(dbName) { |
|
return fmt.Errorf(`testfixtures: database "%s" does not appear to be a test database`, dbName) |
|
} |
|
return nil |
|
} |
|
|
|
// Load wipes and after load all fixtures in the database. |
|
// if err := fixtures.Load(); err != nil { |
|
// ... |
|
// } |
|
func (l *Loader) Load() error { |
|
if !l.skipTestDatabaseCheck { |
|
if err := l.EnsureTestDatabase(); err != nil { |
|
return err |
|
} |
|
} |
|
|
|
err := l.helper.disableReferentialIntegrity(l.db, func(tx *sql.Tx) error { |
|
for _, file := range l.fixturesFiles { |
|
modified, err := l.helper.isTableModified(tx, file.fileNameWithoutExtension()) |
|
if err != nil { |
|
return err |
|
} |
|
if !modified { |
|
continue |
|
} |
|
if err := file.delete(tx, l.helper); err != nil { |
|
return err |
|
} |
|
|
|
err = l.helper.whileInsertOnTable(tx, file.fileNameWithoutExtension(), func() error { |
|
for j, i := range file.insertSQLs { |
|
if _, err := tx.Exec(i.sql, i.params...); err != nil { |
|
return &InsertError{ |
|
Err: err, |
|
File: file.fileName, |
|
Index: j, |
|
SQL: i.sql, |
|
Params: i.params, |
|
} |
|
} |
|
} |
|
return nil |
|
}) |
|
if err != nil { |
|
return err |
|
} |
|
} |
|
return nil |
|
}) |
|
if err != nil { |
|
return err |
|
} |
|
return l.helper.afterLoad(l.db) |
|
} |
|
|
|
// InsertError will be returned if any error happens on database while |
|
// inserting the record. |
|
type InsertError struct { |
|
Err error |
|
File string |
|
Index int |
|
SQL string |
|
Params []interface{} |
|
} |
|
|
|
func (e *InsertError) Error() string { |
|
return fmt.Sprintf( |
|
"testfixtures: error inserting record: %v, on file: %s, index: %d, sql: %s, params: %v", |
|
e.Err, |
|
e.File, |
|
e.Index, |
|
e.SQL, |
|
e.Params, |
|
) |
|
} |
|
|
|
func (l *Loader) buildInsertSQLs() error { |
|
for _, f := range l.fixturesFiles { |
|
var records interface{} |
|
if err := yaml.Unmarshal(f.content, &records); err != nil { |
|
return fmt.Errorf("testfixtures: could not unmarshal YAML: %w", err) |
|
} |
|
|
|
switch records := records.(type) { |
|
case []interface{}: |
|
f.insertSQLs = make([]insertSQL, 0, len(records)) |
|
|
|
for _, record := range records { |
|
recordMap, ok := record.(map[interface{}]interface{}) |
|
if !ok { |
|
return fmt.Errorf("testfixtures: could not cast record: not a map[interface{}]interface{}") |
|
} |
|
|
|
sql, values, err := l.buildInsertSQL(f, recordMap) |
|
if err != nil { |
|
return err |
|
} |
|
|
|
f.insertSQLs = append(f.insertSQLs, insertSQL{sql, values}) |
|
} |
|
case map[interface{}]interface{}: |
|
f.insertSQLs = make([]insertSQL, 0, len(records)) |
|
|
|
for _, record := range records { |
|
recordMap, ok := record.(map[interface{}]interface{}) |
|
if !ok { |
|
return fmt.Errorf("testfixtures: could not cast record: not a map[interface{}]interface{}") |
|
} |
|
|
|
sql, values, err := l.buildInsertSQL(f, recordMap) |
|
if err != nil { |
|
return err |
|
} |
|
|
|
f.insertSQLs = append(f.insertSQLs, insertSQL{sql, values}) |
|
} |
|
default: |
|
return fmt.Errorf("testfixtures: fixture is not a slice or map") |
|
} |
|
} |
|
|
|
return nil |
|
} |
|
|
|
func (f *fixtureFile) fileNameWithoutExtension() string { |
|
return strings.Replace(f.fileName, filepath.Ext(f.fileName), "", 1) |
|
} |
|
|
|
func (f *fixtureFile) delete(tx *sql.Tx, h helper) error { |
|
if _, err := tx.Exec(fmt.Sprintf("DELETE FROM %s", h.quoteKeyword(f.fileNameWithoutExtension()))); err != nil { |
|
return fmt.Errorf(`testfixtures: could not clean table "%s": %w`, f.fileNameWithoutExtension(), err) |
|
} |
|
return nil |
|
} |
|
|
|
func (l *Loader) buildInsertSQL(f *fixtureFile, record map[interface{}]interface{}) (sqlStr string, values []interface{}, err error) { |
|
var ( |
|
sqlColumns = make([]string, 0, len(record)) |
|
sqlValues = make([]string, 0, len(record)) |
|
i = 1 |
|
) |
|
for key, value := range record { |
|
keyStr, ok := key.(string) |
|
if !ok { |
|
err = fmt.Errorf("testfixtures: record map key is not a string") |
|
return |
|
} |
|
|
|
sqlColumns = append(sqlColumns, l.helper.quoteKeyword(keyStr)) |
|
|
|
// if string, try convert to SQL or time |
|
// if map or array, convert to json |
|
switch v := value.(type) { |
|
case string: |
|
if strings.HasPrefix(v, "RAW=") { |
|
sqlValues = append(sqlValues, strings.TrimPrefix(v, "RAW=")) |
|
continue |
|
} |
|
|
|
if t, err := l.tryStrToDate(v); err == nil { |
|
value = t |
|
} |
|
case []interface{}, map[interface{}]interface{}: |
|
value = recursiveToJSON(v) |
|
} |
|
|
|
switch l.helper.paramType() { |
|
case paramTypeDollar: |
|
sqlValues = append(sqlValues, fmt.Sprintf("$%d", i)) |
|
case paramTypeQuestion: |
|
sqlValues = append(sqlValues, "?") |
|
case paramTypeAtSign: |
|
sqlValues = append(sqlValues, fmt.Sprintf("@p%d", i)) |
|
} |
|
|
|
values = append(values, value) |
|
i++ |
|
} |
|
|
|
sqlStr = fmt.Sprintf( |
|
"INSERT INTO %s (%s) VALUES (%s)", |
|
l.helper.quoteKeyword(f.fileNameWithoutExtension()), |
|
strings.Join(sqlColumns, ", "), |
|
strings.Join(sqlValues, ", "), |
|
) |
|
return |
|
} |
|
|
|
func (l *Loader) fixturesFromDir(dir string) ([]*fixtureFile, error) { |
|
fileinfos, err := ioutil.ReadDir(dir) |
|
if err != nil { |
|
return nil, fmt.Errorf(`testfixtures: could not stat directory "%s": %w`, dir, err) |
|
} |
|
|
|
files := make([]*fixtureFile, 0, len(fileinfos)) |
|
|
|
for _, fileinfo := range fileinfos { |
|
fileExt := filepath.Ext(fileinfo.Name()) |
|
if !fileinfo.IsDir() && (fileExt == ".yml" || fileExt == ".yaml") { |
|
fixture := &fixtureFile{ |
|
path: path.Join(dir, fileinfo.Name()), |
|
fileName: fileinfo.Name(), |
|
} |
|
fixture.content, err = ioutil.ReadFile(fixture.path) |
|
if err != nil { |
|
return nil, fmt.Errorf(`testfixtures: could not read file "%s": %w`, fixture.path, err) |
|
} |
|
if err := l.processFileTemplate(fixture); err != nil { |
|
return nil, err |
|
} |
|
files = append(files, fixture) |
|
} |
|
} |
|
return files, nil |
|
} |
|
|
|
func (l *Loader) fixturesFromFiles(fileNames ...string) ([]*fixtureFile, error) { |
|
var ( |
|
fixtureFiles = make([]*fixtureFile, 0, len(fileNames)) |
|
err error |
|
) |
|
|
|
for _, f := range fileNames { |
|
fixture := &fixtureFile{ |
|
path: f, |
|
fileName: filepath.Base(f), |
|
} |
|
fixture.content, err = ioutil.ReadFile(fixture.path) |
|
if err != nil { |
|
return nil, fmt.Errorf(`testfixtures: could not read file "%s": %w`, fixture.path, err) |
|
} |
|
if err := l.processFileTemplate(fixture); err != nil { |
|
return nil, err |
|
} |
|
fixtureFiles = append(fixtureFiles, fixture) |
|
} |
|
|
|
return fixtureFiles, nil |
|
} |
|
|
|
func (l *Loader) fixturesFromPaths(paths ...string) ([]*fixtureFile, error) { |
|
fixtureExtractor := func(p string, isDir bool) ([]*fixtureFile, error) { |
|
if isDir { |
|
return l.fixturesFromDir(p) |
|
} |
|
|
|
return l.fixturesFromFiles(p) |
|
} |
|
|
|
var fixtureFiles []*fixtureFile |
|
|
|
for _, p := range paths { |
|
f, err := os.Stat(p) |
|
if err != nil { |
|
return nil, fmt.Errorf(`testfixtures: could not stat path "%s": %w`, p, err) |
|
} |
|
|
|
fixtures, err := fixtureExtractor(p, f.IsDir()) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
fixtureFiles = append(fixtureFiles, fixtures...) |
|
} |
|
|
|
return fixtureFiles, nil |
|
} |
|
|
|
func (l *Loader) processFileTemplate(f *fixtureFile) error { |
|
if !l.template { |
|
return nil |
|
} |
|
|
|
t := template.New(""). |
|
Funcs(l.templateFuncs). |
|
Delims(l.templateLeftDelim, l.templateRightDelim). |
|
Option(l.templateOptions...) |
|
t, err := t.Parse(string(f.content)) |
|
if err != nil { |
|
return fmt.Errorf(`textfixtures: error on parsing template in %s: %w`, f.fileName, err) |
|
} |
|
|
|
var buffer bytes.Buffer |
|
if err := t.Execute(&buffer, l.templateData); err != nil { |
|
return fmt.Errorf(`textfixtures: error on executing template in %s: %w`, f.fileName, err) |
|
} |
|
|
|
f.content = buffer.Bytes() |
|
return nil |
|
}
|
|
|