Платформа ЦРНП "Мирокод" для разработки проектов
https://git.mirocod.ru
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
211 lines
4.5 KiB
211 lines
4.5 KiB
package testfixtures |
|
|
|
import ( |
|
"database/sql" |
|
"fmt" |
|
) |
|
|
|
// PostgreSQL is the PG helper for this package |
|
type PostgreSQL struct { |
|
baseHelper |
|
|
|
// UseAlterConstraint If true, the contraint disabling will do |
|
// using ALTER CONTRAINT sintax, only allowed in PG >= 9.4. |
|
// If false, the constraint disabling will use DISABLE TRIGGER ALL, |
|
// which requires SUPERUSER privileges. |
|
UseAlterConstraint bool |
|
|
|
tables []string |
|
sequences []string |
|
nonDeferrableConstraints []pgConstraint |
|
} |
|
|
|
// pgConstraint names a single table constraint: the table it is declared on
// and the constraint's own name, as read from
// information_schema.table_constraints.
type pgConstraint struct {
	tableName, constraintName string
}
|
|
|
func (h *PostgreSQL) init(db *sql.DB) error { |
|
var err error |
|
|
|
h.tables, err = h.getTables(db) |
|
if err != nil { |
|
return err |
|
} |
|
|
|
h.sequences, err = h.getSequences(db) |
|
if err != nil { |
|
return err |
|
} |
|
|
|
h.nonDeferrableConstraints, err = h.getNonDeferrableConstraints(db) |
|
if err != nil { |
|
return err |
|
} |
|
|
|
return nil |
|
} |
|
|
|
func (*PostgreSQL) paramType() int { |
|
return paramTypeDollar |
|
} |
|
|
|
func (*PostgreSQL) databaseName(db *sql.DB) (dbName string) { |
|
db.QueryRow("SELECT current_database()").Scan(&dbName) |
|
return |
|
} |
|
|
|
func (h *PostgreSQL) getTables(db *sql.DB) ([]string, error) { |
|
var tables []string |
|
|
|
sql := ` |
|
SELECT table_name |
|
FROM information_schema.tables |
|
WHERE table_schema = 'public' |
|
AND table_type = 'BASE TABLE'; |
|
` |
|
rows, err := db.Query(sql) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
defer rows.Close() |
|
for rows.Next() { |
|
var table string |
|
rows.Scan(&table) |
|
tables = append(tables, table) |
|
} |
|
return tables, nil |
|
} |
|
|
|
func (h *PostgreSQL) getSequences(db *sql.DB) ([]string, error) { |
|
var sequences []string |
|
|
|
sql := "SELECT relname FROM pg_class WHERE relkind = 'S'" |
|
rows, err := db.Query(sql) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
defer rows.Close() |
|
for rows.Next() { |
|
var sequence string |
|
if err = rows.Scan(&sequence); err != nil { |
|
return nil, err |
|
} |
|
sequences = append(sequences, sequence) |
|
} |
|
return sequences, nil |
|
} |
|
|
|
func (*PostgreSQL) getNonDeferrableConstraints(db *sql.DB) ([]pgConstraint, error) { |
|
var constraints []pgConstraint |
|
|
|
sql := ` |
|
SELECT table_name, constraint_name |
|
FROM information_schema.table_constraints |
|
WHERE constraint_type = 'FOREIGN KEY' |
|
AND is_deferrable = 'NO'` |
|
rows, err := db.Query(sql) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
defer rows.Close() |
|
for rows.Next() { |
|
var constraint pgConstraint |
|
err = rows.Scan(&constraint.tableName, &constraint.constraintName) |
|
if err != nil { |
|
return nil, err |
|
} |
|
constraints = append(constraints, constraint) |
|
} |
|
return constraints, nil |
|
} |
|
|
|
func (h *PostgreSQL) disableTriggers(db *sql.DB, loadFn loadFunction) error { |
|
defer func() { |
|
// re-enable triggers after load |
|
var sql string |
|
for _, table := range h.tables { |
|
sql += fmt.Sprintf("ALTER TABLE %s ENABLE TRIGGER ALL;", h.quoteKeyword(table)) |
|
} |
|
db.Exec(sql) |
|
}() |
|
|
|
tx, err := db.Begin() |
|
if err != nil { |
|
return err |
|
} |
|
|
|
var sql string |
|
for _, table := range h.tables { |
|
sql += fmt.Sprintf("ALTER TABLE %s DISABLE TRIGGER ALL;", h.quoteKeyword(table)) |
|
} |
|
if _, err = tx.Exec(sql); err != nil { |
|
return err |
|
} |
|
|
|
if err = loadFn(tx); err != nil { |
|
tx.Rollback() |
|
return err |
|
} |
|
|
|
return tx.Commit() |
|
} |
|
|
|
func (h *PostgreSQL) makeConstraintsDeferrable(db *sql.DB, loadFn loadFunction) error { |
|
defer func() { |
|
// ensure constraint being not deferrable again after load |
|
var sql string |
|
for _, constraint := range h.nonDeferrableConstraints { |
|
sql += fmt.Sprintf("ALTER TABLE %s ALTER CONSTRAINT %s NOT DEFERRABLE;", h.quoteKeyword(constraint.tableName), h.quoteKeyword(constraint.constraintName)) |
|
} |
|
db.Exec(sql) |
|
}() |
|
|
|
var sql string |
|
for _, constraint := range h.nonDeferrableConstraints { |
|
sql += fmt.Sprintf("ALTER TABLE %s ALTER CONSTRAINT %s DEFERRABLE;", h.quoteKeyword(constraint.tableName), h.quoteKeyword(constraint.constraintName)) |
|
} |
|
if _, err := db.Exec(sql); err != nil { |
|
return err |
|
} |
|
|
|
tx, err := db.Begin() |
|
if err != nil { |
|
return err |
|
} |
|
|
|
if _, err = tx.Exec("SET CONSTRAINTS ALL DEFERRED"); err != nil { |
|
return nil |
|
} |
|
|
|
if err = loadFn(tx); err != nil { |
|
tx.Rollback() |
|
return err |
|
} |
|
|
|
return tx.Commit() |
|
} |
|
|
|
func (h *PostgreSQL) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) error { |
|
// ensure sequences being reset after load |
|
defer h.resetSequences(db) |
|
|
|
if h.UseAlterConstraint { |
|
return h.makeConstraintsDeferrable(db, loadFn) |
|
} else { |
|
return h.disableTriggers(db, loadFn) |
|
} |
|
} |
|
|
|
func (h *PostgreSQL) resetSequences(db *sql.DB) error { |
|
for _, sequence := range h.sequences { |
|
_, err := db.Exec(fmt.Sprintf("SELECT SETVAL('%s', %d)", sequence, resetSequencesTo)) |
|
if err != nil { |
|
return err |
|
} |
|
} |
|
return nil |
|
}
|
|
|