600 lines
15 KiB
Go
600 lines
15 KiB
Go
package testfixtures // import "github.com/go-testfixtures/testfixtures/v3"
|
|
|
|
import (
|
|
"bytes"
|
|
"database/sql"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"os"
|
|
"path"
|
|
"path/filepath"
|
|
"regexp"
|
|
"strings"
|
|
"text/template"
|
|
"time"
|
|
|
|
"gopkg.in/yaml.v2"
|
|
)
|
|
|
|
// Loader is responsible for loading fixtures into a database.
//
// A Loader is created with New, which applies functional options (Database,
// Dialect, Directory, Files, Paths, Template, ...) to it; New also sets the
// default template delimiters and template options.
type Loader struct {
	// db is the connection fixtures are loaded into (Database option).
	db *sql.DB
	// helper implements the dialect-specific behavior (Dialect option).
	helper helper
	// fixturesFiles holds every fixture file collected by the Directory,
	// Files and Paths options, in the order they were added.
	fixturesFiles []*fixtureFile

	// skipTestDatabaseCheck disables the "test" database name check done by
	// Load (DangerousSkipTestDatabaseCheck option).
	skipTestDatabaseCheck bool
	// location is set by the Location option; presumably consumed when
	// parsing date strings in fixtures — TODO confirm against tryStrToDate.
	location *time.Location

	// template enables text/template processing of fixture files (Template
	// option); the template* fields below configure that processing.
	template           bool
	templateFuncs      template.FuncMap
	templateLeftDelim  string
	templateRightDelim string
	templateOptions    []string
	templateData       interface{}
}
|
|
|
|
// fixtureFile is one fixture file together with the INSERT statements that
// were built from its records.
type fixtureFile struct {
	// path is the path the file was read from.
	path string
	// fileName is the file's base name; without its extension it is used as
	// the target table name.
	fileName string
	// content is the file content, already processed by text/template when
	// the Template option is enabled.
	content []byte
	// insertSQLs holds one statement per record; filled by buildInsertSQLs.
	insertSQLs []insertSQL
}
|
|
|
|
// insertSQL is a single parameterized INSERT statement and its bind values.
//
// Field order matters: values are built with the unkeyed literal
// insertSQL{sql, values}.
type insertSQL struct {
	// sql is the statement text, using the dialect's placeholder style.
	sql string
	// params are the values bound to the placeholders, in order.
	params []interface{}
}
|
|
|
|
// testDatabaseRegexp matches database names that contain "test" in any
// letter case. EnsureTestDatabase refuses to work on names it rejects.
var testDatabaseRegexp = regexp.MustCompile(`(?i)test`)

// Errors returned by New when one of its mandatory options is missing.
var (
	errDatabaseIsRequired = fmt.Errorf("testfixtures: database is required")
	errDialectIsRequired  = fmt.Errorf("testfixtures: dialect is required")
)
|
|
|
|
// New instantiates a new Loader instance. The "Database" and "Driver"
|
|
// options are required.
|
|
func New(options ...func(*Loader) error) (*Loader, error) {
|
|
l := &Loader{
|
|
templateLeftDelim: "{{",
|
|
templateRightDelim: "}}",
|
|
templateOptions: []string{"missingkey=zero"},
|
|
}
|
|
|
|
for _, option := range options {
|
|
if err := option(l); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
if l.db == nil {
|
|
return nil, errDatabaseIsRequired
|
|
}
|
|
if l.helper == nil {
|
|
return nil, errDialectIsRequired
|
|
}
|
|
|
|
if err := l.helper.init(l.db); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := l.buildInsertSQLs(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return l, nil
|
|
}
|
|
|
|
// Database sets an existing sql.DB instant to Loader.
|
|
func Database(db *sql.DB) func(*Loader) error {
|
|
return func(l *Loader) error {
|
|
l.db = db
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// Dialect informs Loader about which database dialect you're using.
|
|
//
|
|
// Possible options are "postgresql", "timescaledb", "mysql", "mariadb",
|
|
// "sqlite" and "sqlserver".
|
|
func Dialect(dialect string) func(*Loader) error {
|
|
return func(l *Loader) error {
|
|
h, err := helperForDialect(dialect)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
l.helper = h
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func helperForDialect(dialect string) (helper, error) {
|
|
switch dialect {
|
|
case "postgres", "postgresql", "timescaledb":
|
|
return &postgreSQL{}, nil
|
|
case "mysql", "mariadb":
|
|
return &mySQL{}, nil
|
|
case "sqlite", "sqlite3":
|
|
return &sqlite{}, nil
|
|
case "mssql", "sqlserver":
|
|
return &sqlserver{}, nil
|
|
default:
|
|
return nil, fmt.Errorf(`testfixtures: unrecognized dialect "%s"`, dialect)
|
|
}
|
|
}
|
|
|
|
// UseAlterConstraint If true, the contraint disabling will do
|
|
// using ALTER CONTRAINT sintax, only allowed in PG >= 9.4.
|
|
// If false, the constraint disabling will use DISABLE TRIGGER ALL,
|
|
// which requires SUPERUSER privileges.
|
|
//
|
|
// Only valid for PostgreSQL. Returns an error otherwise.
|
|
func UseAlterConstraint() func(*Loader) error {
|
|
return func(l *Loader) error {
|
|
pgHelper, ok := l.helper.(*postgreSQL)
|
|
if !ok {
|
|
return fmt.Errorf("testfixtures: UseAlterConstraint is only valid for PostgreSQL databases")
|
|
}
|
|
pgHelper.useAlterConstraint = true
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// SkipResetSequences prevents Loader from reseting sequences after loading
|
|
// fixtures.
|
|
//
|
|
// Only valid for PostgreSQL. Returns an error otherwise.
|
|
func SkipResetSequences() func(*Loader) error {
|
|
return func(l *Loader) error {
|
|
pgHelper, ok := l.helper.(*postgreSQL)
|
|
if !ok {
|
|
return fmt.Errorf("testfixtures: SkipResetSequences is only valid for PostgreSQL databases")
|
|
}
|
|
pgHelper.skipResetSequences = true
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// ResetSequencesTo sets the value the sequences will be reset to.
|
|
//
|
|
// Defaults to 10000.
|
|
//
|
|
// Only valid for PostgreSQL. Returns an error otherwise.
|
|
func ResetSequencesTo(value int64) func(*Loader) error {
|
|
return func(l *Loader) error {
|
|
pgHelper, ok := l.helper.(*postgreSQL)
|
|
if !ok {
|
|
return fmt.Errorf("testfixtures: ResetSequencesTo is only valid for PostgreSQL databases")
|
|
}
|
|
pgHelper.resetSequencesTo = value
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// DangerousSkipTestDatabaseCheck will make Loader not check if the database
|
|
// name contains "test". Use with caution!
|
|
func DangerousSkipTestDatabaseCheck() func(*Loader) error {
|
|
return func(l *Loader) error {
|
|
l.skipTestDatabaseCheck = true
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// Directory informs Loader to load YAML files from a given directory.
|
|
func Directory(dir string) func(*Loader) error {
|
|
return func(l *Loader) error {
|
|
fixtures, err := l.fixturesFromDir(dir)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
l.fixturesFiles = append(l.fixturesFiles, fixtures...)
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// Files informs Loader to load a given set of YAML files.
|
|
func Files(files ...string) func(*Loader) error {
|
|
return func(l *Loader) error {
|
|
fixtures, err := l.fixturesFromFiles(files...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
l.fixturesFiles = append(l.fixturesFiles, fixtures...)
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// Paths inform Loader to load a given set of YAML files and directories.
|
|
func Paths(paths ...string) func(*Loader) error {
|
|
return func(l *Loader) error {
|
|
fixtures, err := l.fixturesFromPaths(paths...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
l.fixturesFiles = append(l.fixturesFiles, fixtures...)
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// Location makes Loader use the given location by default when parsing
|
|
// dates. If not given, by default it uses the value of time.Local.
|
|
func Location(location *time.Location) func(*Loader) error {
|
|
return func(l *Loader) error {
|
|
l.location = location
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// Template makes loader process each YAML file as an template using the
|
|
// text/template package.
|
|
//
|
|
// For more information on how templates work in Go please read:
|
|
// https://golang.org/pkg/text/template/
|
|
//
|
|
// If not given the YAML files are parsed as is.
|
|
func Template() func(*Loader) error {
|
|
return func(l *Loader) error {
|
|
l.template = true
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// TemplateFuncs allow choosing which functions will be available
|
|
// when processing templates.
|
|
//
|
|
// For more information see: https://golang.org/pkg/text/template/#Template.Funcs
|
|
func TemplateFuncs(funcs template.FuncMap) func(*Loader) error {
|
|
return func(l *Loader) error {
|
|
if !l.template {
|
|
return fmt.Errorf(`testfixtures: the Template() options is required in order to use the TemplateFuns() option`)
|
|
}
|
|
|
|
l.templateFuncs = funcs
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// TemplateDelims allow choosing which delimiters will be used for templating.
|
|
// This defaults to "{{" and "}}".
|
|
//
|
|
// For more information see https://golang.org/pkg/text/template/#Template.Delims
|
|
func TemplateDelims(left, right string) func(*Loader) error {
|
|
return func(l *Loader) error {
|
|
if !l.template {
|
|
return fmt.Errorf(`testfixtures: the Template() options is required in order to use the TemplateDelims() option`)
|
|
}
|
|
|
|
l.templateLeftDelim = left
|
|
l.templateRightDelim = right
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// TemplateOptions allows you to specific which text/template options will
|
|
// be enabled when processing templates.
|
|
//
|
|
// This defaults to "missingkey=zero". Check the available options here:
|
|
// https://golang.org/pkg/text/template/#Template.Option
|
|
func TemplateOptions(options ...string) func(*Loader) error {
|
|
return func(l *Loader) error {
|
|
if !l.template {
|
|
return fmt.Errorf(`testfixtures: the Template() options is required in order to use the TemplateOptions() option`)
|
|
}
|
|
|
|
l.templateOptions = options
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// TemplateData allows you to specify which data will be available
|
|
// when processing templates. Data is accesible by prefixing it with a "."
|
|
// like {{.MyKey}}.
|
|
func TemplateData(data interface{}) func(*Loader) error {
|
|
return func(l *Loader) error {
|
|
if !l.template {
|
|
return fmt.Errorf(`testfixtures: the Template() options is required in order to use the TemplateData() option`)
|
|
}
|
|
|
|
l.templateData = data
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// EnsureTestDatabase returns an error if the database name does not contains
|
|
// "test".
|
|
func (l *Loader) EnsureTestDatabase() error {
|
|
dbName, err := l.helper.databaseName(l.db)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !testDatabaseRegexp.MatchString(dbName) {
|
|
return fmt.Errorf(`testfixtures: database "%s" does not appear to be a test database`, dbName)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Load wipes and after load all fixtures in the database.
|
|
// if err := fixtures.Load(); err != nil {
|
|
// ...
|
|
// }
|
|
func (l *Loader) Load() error {
|
|
if !l.skipTestDatabaseCheck {
|
|
if err := l.EnsureTestDatabase(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
err := l.helper.disableReferentialIntegrity(l.db, func(tx *sql.Tx) error {
|
|
for _, file := range l.fixturesFiles {
|
|
modified, err := l.helper.isTableModified(tx, file.fileNameWithoutExtension())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !modified {
|
|
continue
|
|
}
|
|
if err := file.delete(tx, l.helper); err != nil {
|
|
return err
|
|
}
|
|
|
|
err = l.helper.whileInsertOnTable(tx, file.fileNameWithoutExtension(), func() error {
|
|
for j, i := range file.insertSQLs {
|
|
if _, err := tx.Exec(i.sql, i.params...); err != nil {
|
|
return &InsertError{
|
|
Err: err,
|
|
File: file.fileName,
|
|
Index: j,
|
|
SQL: i.sql,
|
|
Params: i.params,
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return l.helper.afterLoad(l.db)
|
|
}
|
|
|
|
// InsertError is returned by Load when the database rejects a fixture record
// while inserting it.
type InsertError struct {
	// Err is the error returned by the database for the INSERT.
	Err error
	// File is the name of the fixture file the record came from.
	File string
	// Index is the zero-based position of the failing statement among the
	// INSERT statements generated for File.
	Index int
	// SQL is the statement that failed.
	SQL string
	// Params are the values that were bound to SQL.
	Params []interface{}
}
|
|
|
|
func (e *InsertError) Error() string {
|
|
return fmt.Sprintf(
|
|
"testfixtures: error inserting record: %v, on file: %s, index: %d, sql: %s, params: %v",
|
|
e.Err,
|
|
e.File,
|
|
e.Index,
|
|
e.SQL,
|
|
e.Params,
|
|
)
|
|
}
|
|
|
|
func (l *Loader) buildInsertSQLs() error {
|
|
for _, f := range l.fixturesFiles {
|
|
var records interface{}
|
|
if err := yaml.Unmarshal(f.content, &records); err != nil {
|
|
return fmt.Errorf("testfixtures: could not unmarshal YAML: %w", err)
|
|
}
|
|
|
|
switch records := records.(type) {
|
|
case []interface{}:
|
|
f.insertSQLs = make([]insertSQL, 0, len(records))
|
|
|
|
for _, record := range records {
|
|
recordMap, ok := record.(map[interface{}]interface{})
|
|
if !ok {
|
|
return fmt.Errorf("testfixtures: could not cast record: not a map[interface{}]interface{}")
|
|
}
|
|
|
|
sql, values, err := l.buildInsertSQL(f, recordMap)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
f.insertSQLs = append(f.insertSQLs, insertSQL{sql, values})
|
|
}
|
|
case map[interface{}]interface{}:
|
|
f.insertSQLs = make([]insertSQL, 0, len(records))
|
|
|
|
for _, record := range records {
|
|
recordMap, ok := record.(map[interface{}]interface{})
|
|
if !ok {
|
|
return fmt.Errorf("testfixtures: could not cast record: not a map[interface{}]interface{}")
|
|
}
|
|
|
|
sql, values, err := l.buildInsertSQL(f, recordMap)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
f.insertSQLs = append(f.insertSQLs, insertSQL{sql, values})
|
|
}
|
|
default:
|
|
return fmt.Errorf("testfixtures: fixture is not a slice or map")
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (f *fixtureFile) fileNameWithoutExtension() string {
|
|
return strings.Replace(f.fileName, filepath.Ext(f.fileName), "", 1)
|
|
}
|
|
|
|
func (f *fixtureFile) delete(tx *sql.Tx, h helper) error {
|
|
if _, err := tx.Exec(fmt.Sprintf("DELETE FROM %s", h.quoteKeyword(f.fileNameWithoutExtension()))); err != nil {
|
|
return fmt.Errorf(`testfixtures: could not clean table "%s": %w`, f.fileNameWithoutExtension(), err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (l *Loader) buildInsertSQL(f *fixtureFile, record map[interface{}]interface{}) (sqlStr string, values []interface{}, err error) {
|
|
var (
|
|
sqlColumns = make([]string, 0, len(record))
|
|
sqlValues = make([]string, 0, len(record))
|
|
i = 1
|
|
)
|
|
for key, value := range record {
|
|
keyStr, ok := key.(string)
|
|
if !ok {
|
|
err = fmt.Errorf("testfixtures: record map key is not a string")
|
|
return
|
|
}
|
|
|
|
sqlColumns = append(sqlColumns, l.helper.quoteKeyword(keyStr))
|
|
|
|
// if string, try convert to SQL or time
|
|
// if map or array, convert to json
|
|
switch v := value.(type) {
|
|
case string:
|
|
if strings.HasPrefix(v, "RAW=") {
|
|
sqlValues = append(sqlValues, strings.TrimPrefix(v, "RAW="))
|
|
continue
|
|
}
|
|
|
|
if t, err := l.tryStrToDate(v); err == nil {
|
|
value = t
|
|
}
|
|
case []interface{}, map[interface{}]interface{}:
|
|
value = recursiveToJSON(v)
|
|
}
|
|
|
|
switch l.helper.paramType() {
|
|
case paramTypeDollar:
|
|
sqlValues = append(sqlValues, fmt.Sprintf("$%d", i))
|
|
case paramTypeQuestion:
|
|
sqlValues = append(sqlValues, "?")
|
|
case paramTypeAtSign:
|
|
sqlValues = append(sqlValues, fmt.Sprintf("@p%d", i))
|
|
}
|
|
|
|
values = append(values, value)
|
|
i++
|
|
}
|
|
|
|
sqlStr = fmt.Sprintf(
|
|
"INSERT INTO %s (%s) VALUES (%s)",
|
|
l.helper.quoteKeyword(f.fileNameWithoutExtension()),
|
|
strings.Join(sqlColumns, ", "),
|
|
strings.Join(sqlValues, ", "),
|
|
)
|
|
return
|
|
}
|
|
|
|
func (l *Loader) fixturesFromDir(dir string) ([]*fixtureFile, error) {
|
|
fileinfos, err := ioutil.ReadDir(dir)
|
|
if err != nil {
|
|
return nil, fmt.Errorf(`testfixtures: could not stat directory "%s": %w`, dir, err)
|
|
}
|
|
|
|
files := make([]*fixtureFile, 0, len(fileinfos))
|
|
|
|
for _, fileinfo := range fileinfos {
|
|
fileExt := filepath.Ext(fileinfo.Name())
|
|
if !fileinfo.IsDir() && (fileExt == ".yml" || fileExt == ".yaml") {
|
|
fixture := &fixtureFile{
|
|
path: path.Join(dir, fileinfo.Name()),
|
|
fileName: fileinfo.Name(),
|
|
}
|
|
fixture.content, err = ioutil.ReadFile(fixture.path)
|
|
if err != nil {
|
|
return nil, fmt.Errorf(`testfixtures: could not read file "%s": %w`, fixture.path, err)
|
|
}
|
|
if err := l.processFileTemplate(fixture); err != nil {
|
|
return nil, err
|
|
}
|
|
files = append(files, fixture)
|
|
}
|
|
}
|
|
return files, nil
|
|
}
|
|
|
|
func (l *Loader) fixturesFromFiles(fileNames ...string) ([]*fixtureFile, error) {
|
|
var (
|
|
fixtureFiles = make([]*fixtureFile, 0, len(fileNames))
|
|
err error
|
|
)
|
|
|
|
for _, f := range fileNames {
|
|
fixture := &fixtureFile{
|
|
path: f,
|
|
fileName: filepath.Base(f),
|
|
}
|
|
fixture.content, err = ioutil.ReadFile(fixture.path)
|
|
if err != nil {
|
|
return nil, fmt.Errorf(`testfixtures: could not read file "%s": %w`, fixture.path, err)
|
|
}
|
|
if err := l.processFileTemplate(fixture); err != nil {
|
|
return nil, err
|
|
}
|
|
fixtureFiles = append(fixtureFiles, fixture)
|
|
}
|
|
|
|
return fixtureFiles, nil
|
|
}
|
|
|
|
func (l *Loader) fixturesFromPaths(paths ...string) ([]*fixtureFile, error) {
|
|
fixtureExtractor := func(p string, isDir bool) ([]*fixtureFile, error) {
|
|
if isDir {
|
|
return l.fixturesFromDir(p)
|
|
}
|
|
|
|
return l.fixturesFromFiles(p)
|
|
}
|
|
|
|
var fixtureFiles []*fixtureFile
|
|
|
|
for _, p := range paths {
|
|
f, err := os.Stat(p)
|
|
if err != nil {
|
|
return nil, fmt.Errorf(`testfixtures: could not stat path "%s": %w`, p, err)
|
|
}
|
|
|
|
fixtures, err := fixtureExtractor(p, f.IsDir())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
fixtureFiles = append(fixtureFiles, fixtures...)
|
|
}
|
|
|
|
return fixtureFiles, nil
|
|
}
|
|
|
|
func (l *Loader) processFileTemplate(f *fixtureFile) error {
|
|
if !l.template {
|
|
return nil
|
|
}
|
|
|
|
t := template.New("").
|
|
Funcs(l.templateFuncs).
|
|
Delims(l.templateLeftDelim, l.templateRightDelim).
|
|
Option(l.templateOptions...)
|
|
t, err := t.Parse(string(f.content))
|
|
if err != nil {
|
|
return fmt.Errorf(`textfixtures: error on parsing template in %s: %w`, f.fileName, err)
|
|
}
|
|
|
|
var buffer bytes.Buffer
|
|
if err := t.Execute(&buffer, l.templateData); err != nil {
|
|
return fmt.Errorf(`textfixtures: error on executing template in %s: %w`, f.fileName, err)
|
|
}
|
|
|
|
f.content = buffer.Bytes()
|
|
return nil
|
|
}
|