// agola/internal/testutil/db.go
//
// 806 lines
// 20 KiB
// Go

package testutil
import (
"bufio"
"bytes"
stdcmp "cmp"
"context"
stdsql "database/sql"
"encoding/json"
"fmt"
"io"
"math/rand"
"os"
"path/filepath"
"slices"
"strconv"
"testing"
"time"
atlasschema "ariga.io/atlas/sql/schema"
atlassqlclient "ariga.io/atlas/sql/sqlclient"
"github.com/google/go-cmp/cmp/cmpopts"
sq "github.com/huandu/go-sqlbuilder"
"github.com/rs/zerolog"
"github.com/sorintlab/errors"
"gotest.tools/assert"
"gotest.tools/assert/cmp"
"muzzammil.xyz/jsonc"
"agola.io/agola/internal/sqlg"
"agola.io/agola/internal/sqlg/lock"
"agola.io/agola/internal/sqlg/manager"
"agola.io/agola/internal/sqlg/sql"
_ "ariga.io/atlas/sql/postgres"
_ "ariga.io/atlas/sql/postgres/postgrescheck"
_ "ariga.io/atlas/sql/sqlite"
_ "ariga.io/atlas/sql/sqlite/sqlitecheck"
)
// DBType returns the database type selected through the DB_TYPE environment
// variable. An empty value or "sqlite3" selects sqlite, "postgres" selects
// postgres; any other value fails the test.
func DBType(t *testing.T) sql.Type {
	switch name := os.Getenv("DB_TYPE"); name {
	case "", "sqlite3":
		return sql.Sqlite3
	case "postgres":
		return sql.Postgres
	default:
		t.Fatalf("unknown db type")
	}

	// Only reached if Fatalf returns: hand back the zero value.
	var unknown sql.Type
	return unknown
}
// CreateDB creates a test database of the type reported by DBType and
// returns the values produced by CreateDBWithType for that type.
func CreateDB(t *testing.T, log zerolog.Logger, ctx context.Context, dir string) (*sql.DB, lock.LockFactory, string) {
	return CreateDBWithType(t, log, ctx, dir, DBType(t))
}
// CreateDBWithType creates a fresh, randomly named test database of the given
// type and returns the opened database, a lock factory suitable for that
// database type and the connection string used to open it.
//
// For sqlite the database file is placed inside dir. For postgres the
// PG_CONNSTRING environment variable is used as a format string receiving the
// database name: the database is (re)created through a temporary connection
// to the "postgres" database.
//
// Any failure, including an unknown dbType, fails the test.
func CreateDBWithType(t *testing.T, log zerolog.Logger, ctx context.Context, dir string, dbType sql.Type) (*sql.DB, lock.LockFactory, string) {
	var (
		sdb        *sql.DB
		lf         lock.LockFactory
		connString string
	)

	dbName := "testdb" + strconv.FormatUint(uint64(rand.Uint32()), 10)

	switch dbType {
	case sql.Sqlite3:
		connString = filepath.Join(dir, dbName)

		var err error
		sdb, err = sql.NewDB("sqlite3", connString)
		NilError(t, err)

		lf = lock.NewLocalLockFactory(lock.NewLocalLocks())

	case sql.Postgres:
		pgConnString := os.Getenv("PG_CONNSTRING")
		if pgConnString == "" {
			// Without this check the Sprintf calls below would silently
			// produce a garbage connection string.
			t.Fatalf("PG_CONNSTRING environment variable must be set for postgres")
		}
		connString = fmt.Sprintf(pgConnString, dbName)

		pgdb, err := stdsql.Open("postgres", fmt.Sprintf(pgConnString, "postgres"))
		NilError(t, err)
		// The administrative handle is only needed to create the database:
		// release it when returning instead of leaking it for every test.
		defer func() { _ = pgdb.Close() }()

		_, err = pgdb.Exec(fmt.Sprintf("drop database if exists %s", dbName))
		NilError(t, err)
		_, err = pgdb.Exec(fmt.Sprintf("create database %s", dbName))
		NilError(t, err)

		sdb, err = sql.NewDB("postgres", connString)
		NilError(t, err)

		lf = lock.NewPGLockFactory(sdb)

	default:
		t.Fatalf("unknown db type")
	}

	return sdb, lf, connString
}
// DBContext groups everything the test helpers in this file need to work on
// one test database.
type DBContext struct {
// D is the database under test; the helpers query its type and run
// transactions through it.
D manager.DB
// DBM is the manager used to set up, create, migrate, import and export
// the database.
DBM *manager.DBManager
// LF is the lock factory associated with the database.
LF lock.LockFactory
// DBConnString is the connection string of the database; AtlasConnString
// derives the atlas URL from it.
DBConnString string
// Schema describes the tables and columns used by Import and Export to
// resolve column types.
Schema []TableInfo
}
// AtlasConnString returns the connection string in the form passed to the
// atlas sql client: the plain connection string for postgres, the connection
// string prefixed with "sqlite://" for sqlite and an empty string for any
// other database type.
func (c *DBContext) AtlasConnString() string {
	dbType := c.D.DBType()

	if dbType == sql.Postgres {
		return c.DBConnString
	}
	if dbType == sql.Sqlite3 {
		return fmt.Sprintf("sqlite://%s", c.DBConnString)
	}

	return ""
}
// Tables returns the names of all the tables in the schema, in schema order.
// The result is never nil.
func (c *DBContext) Tables() []string {
	names := make([]string, len(c.Schema))
	for i := range c.Schema {
		names[i] = c.Schema[i].Name
	}

	return names
}
// Table looks up the schema entry named tableName. The boolean result
// reports whether the table was found.
func (c *DBContext) Table(tableName string) (TableInfo, bool) {
	for i := range c.Schema {
		if c.Schema[i].Name != tableName {
			continue
		}
		return c.Schema[i], true
	}

	return TableInfo{}, false
}
// Column looks up the column colName of the table tableName in the schema.
// The boolean result is false when either the table or the column is unknown.
func (c *DBContext) Column(tableName, colName string) (ColInfo, bool) {
	table, found := c.Table(tableName)
	if !found {
		return ColInfo{}, false
	}

	for i := range table.Columns {
		if table.Columns[i].Name != colName {
			continue
		}
		return table.Columns[i], true
	}

	return ColInfo{}, false
}
// ColType is the kind of value stored in a column, as resolved from the
// type name declared in the schema.
type ColType int

// Column kinds, numbered consecutively starting from zero.
const (
	// ColTypeString is a string column.
	ColTypeString ColType = iota
	// ColTypeBool is a boolean column.
	ColTypeBool
	// ColTypeInt is an integer column.
	ColTypeInt
	// ColTypeFloat is a floating point column.
	ColTypeFloat
	// ColTypeTime is a time.Time column.
	ColTypeTime
	// ColTypeDuration is a time.Duration column.
	ColTypeDuration
	// ColTypeJSON is a JSON column.
	ColTypeJSON
	// ColTypeByteArray is a []byte column.
	ColTypeByteArray
)
// ColumnType resolves the ColType of the column colName of table tableName
// from the type name declared in the schema.
//
// It returns an error when the column is not part of the schema or when its
// declared type name is not one of the supported ones. An unsupported type is
// reported as an error instead of a panic, since callers already handle the
// returned error.
func (c *DBContext) ColumnType(tableName, colName string) (ColType, error) {
	col, ok := c.Column(tableName, colName)
	if !ok {
		return 0, errors.Errorf("unknown column %q.%q", tableName, colName)
	}

	switch col.Type {
	case "string":
		return ColTypeString, nil
	case "bool":
		return ColTypeBool, nil
	case "int", "int8", "int16", "int32", "int64", "uint", "uint8", "uint16", "uint32", "uint64", "byte", "rune":
		return ColTypeInt, nil
	case "float32", "float64":
		return ColTypeFloat, nil
	case "time.Time":
		return ColTypeTime, nil
	case "time.Duration":
		return ColTypeDuration, nil
	case "json":
		return ColTypeJSON, nil
	case "[]byte":
		return ColTypeByteArray, nil
	default:
		return 0, errors.Errorf("unknown col type %q for column %q.%q", col.Type, tableName, colName)
	}
}
type (
	// importData is one document of the import stream: the values of a
	// single row of Table, keyed by column name and kept as raw JSON so that
	// they can be decoded according to the column type.
	importData struct {
		Table  string
		Values map[string]json.RawMessage
	}

	// exportData is one document of the export stream: the values of a
	// single row of Table, keyed by column name.
	exportData struct {
		Table  string
		Values map[string]any
	}
)
// sqFlavor returns the query builder flavor matching the database type:
// SQLite for sqlite and PostgreSQL for postgres and for any other type.
func (c *DBContext) sqFlavor() sq.Flavor {
	if c.D.DBType() == sql.Sqlite3 {
		return sq.SQLite
	}

	return sq.PostgreSQL
}
// exec builds rq with the flavor of the current database and executes the
// resulting statement inside tx.
func (c *DBContext) exec(tx *sql.Tx, rq sq.Builder) (stdsql.Result, error) {
	stmt, stmtArgs := rq.BuildWithFlavor(c.sqFlavor())

	res, err := tx.Exec(stmt, stmtArgs...)

	return res, errors.WithStack(err)
}
// query builds rq with the flavor of the current database and runs the
// resulting query inside tx, returning its rows.
func (c *DBContext) query(tx *sql.Tx, rq sq.Builder) (*stdsql.Rows, error) {
	stmt, stmtArgs := rq.BuildWithFlavor(c.sqFlavor())

	rows, err := tx.Query(stmt, stmtArgs...)

	return rows, errors.WithStack(err)
}
// Import reads a stream of JSON documents from r, each one describing a row
// (table name plus column values), and inserts them in the database inside a
// single transaction. After all the rows have been inserted, the sequences
// listed in createData are initialized from the current maximum value of
// their column.
//
// Every column of the schema table is inserted: columns missing from a
// document get the zero value of their type and the "revision" column is
// always set to 1. Columns or tables not present in c.Schema are rejected.
//
// createData may be nil, in which case no sequence is populated.
func (c *DBContext) Import(ctx context.Context, r io.Reader, createData *CreateData) error {
	dec := json.NewDecoder(bufio.NewReader(r))

	err := c.D.Do(ctx, func(tx *sql.Tx) error {
		for {
			// A fresh value per document: decoding into a reused map would
			// carry values over from the previous row.
			var data importData
			err := dec.Decode(&data)
			if errors.Is(err, io.EOF) {
				break
			}
			if err != nil {
				return errors.WithStack(err)
			}

			if err := c.importRow(tx, data); err != nil {
				return errors.WithStack(err)
			}
		}

		if createData == nil {
			return nil
		}

		// Populate sequences
		for _, seq := range createData.Sequences {
			// Identifiers cannot be bound as query parameters; the names
			// come from the create fixtures.
			var q string
			switch c.D.DBType() {
			case sql.Postgres:
				q = fmt.Sprintf("SELECT setval('%s', (SELECT COALESCE(MAX(%s), 1) FROM %s));", seq.Name, seq.Column, seq.Table)
			case sql.Sqlite3:
				q = fmt.Sprintf("INSERT INTO sequence_t (name, value) VALUES ('%s', (SELECT COALESCE(MAX(%s), 1) FROM %s));", seq.Name, seq.Column, seq.Table)
			default:
				continue
			}

			if _, err := tx.Exec(q); err != nil {
				return errors.Wrapf(err, "failed to update sequence %s", seq.Name)
			}
		}

		return nil
	})
	if err != nil {
		return errors.WithStack(err)
	}

	return nil
}

// importRow inserts a single decoded import document in its table.
func (c *DBContext) importRow(tx *sql.Tx, data importData) error {
	tableName := data.Table

	table, ok := c.Table(tableName)
	if !ok {
		return errors.Errorf("unknown table %q", tableName)
	}

	// check that every provided column exists in the schema
	for colName := range data.Values {
		if _, err := c.ColumnType(tableName, colName); err != nil {
			return errors.WithStack(err)
		}
	}

	cols := make([]string, 0, len(table.Columns))
	values := make([]any, 0, len(table.Columns))

	for _, col := range table.Columns {
		cols = append(cols, col.Name)

		if col.Name == "revision" {
			values = append(values, 1)
			continue
		}

		colType, err := c.ColumnType(tableName, col.Name)
		if err != nil {
			return errors.WithStack(err)
		}

		raw, hasValue := data.Values[col.Name]
		value, err := importColumnValue(colType, raw, hasValue)
		if err != nil {
			return errors.Wrapf(err, "invalid value for column %q.%q", tableName, col.Name)
		}
		values = append(values, value)
	}

	q := sq.NewInsertBuilder()
	q.InsertInto(tableName).Cols(cols...).Values(values...)
	if _, err := c.exec(tx, q); err != nil {
		return errors.WithStack(err)
	}

	return nil
}

// importColumnValue converts the raw JSON value of a column to the Go value
// handed to the database driver. When hasValue is false the zero value of
// the column type is returned.
func importColumnValue(colType ColType, raw json.RawMessage, hasValue bool) (any, error) {
	switch colType {
	case ColTypeString:
		if !hasValue {
			return "", nil
		}
		var s string
		if err := json.Unmarshal(raw, &s); err != nil {
			return nil, errors.WithStack(err)
		}
		return s, nil

	case ColTypeInt:
		if !hasValue {
			return int64(0), nil
		}
		var n int64
		if err := json.Unmarshal(raw, &n); err != nil {
			return nil, errors.WithStack(err)
		}
		return n, nil

	case ColTypeFloat:
		if !hasValue {
			return float64(0), nil
		}
		var n float64
		if err := json.Unmarshal(raw, &n); err != nil {
			return nil, errors.WithStack(err)
		}
		return n, nil

	case ColTypeBool:
		if !hasValue {
			return false, nil
		}
		// Decode to a real bool so that the driver stores a boolean and not
		// the raw JSON text.
		var b bool
		if err := json.Unmarshal(raw, &b); err != nil {
			return nil, errors.WithStack(err)
		}
		return b, nil

	case ColTypeTime:
		if !hasValue {
			return time.Time{}, nil
		}
		t := time.Time{}
		if err := t.UnmarshalJSON(raw); err != nil {
			return nil, errors.WithStack(err)
		}
		return t, nil

	case ColTypeDuration:
		if !hasValue {
			return 0, nil
		}
		var d int64
		if err := json.Unmarshal(raw, &d); err != nil {
			return nil, errors.WithStack(err)
		}
		return d, nil

	case ColTypeJSON:
		if !hasValue {
			raw = json.RawMessage("null")
		}
		// Marshaling a RawMessage validates it and yields its compact form.
		vj, err := json.Marshal(raw)
		if err != nil {
			return nil, errors.WithStack(err)
		}
		return vj, nil

	case ColTypeByteArray:
		if !hasValue {
			return []byte{}, nil
		}
		var b []byte
		if err := json.Unmarshal(raw, &b); err != nil {
			return nil, errors.WithStack(err)
		}
		return b, nil

	default:
		return raw, nil
	}
}
// Export writes all the rows of the given tables to w as a stream of JSON
// documents, one per row, each holding the table name and the column values
// keyed by column name. Rows of a table are ordered by their "id" column and
// tables are exported in the given order, inside a single transaction.
//
// Column types are resolved through c.Schema: bool columns are normalized to
// JSON booleans and JSON columns are embedded as decoded JSON values.
func (c *DBContext) Export(ctx context.Context, tables []string, w io.Writer) error {
	bw := bufio.NewWriter(w)
	enc := json.NewEncoder(bw)

	err := c.D.Do(ctx, func(tx *sql.Tx) error {
		for _, table := range tables {
			if err := c.exportTable(tx, enc, table); err != nil {
				return errors.WithStack(err)
			}
		}

		return nil
	})
	if err != nil {
		return errors.WithStack(err)
	}

	return errors.WithStack(bw.Flush())
}

// exportTable encodes every row of table to enc. It lives in its own
// function so that the rows are closed on every return path before the next
// table is queried.
func (c *DBContext) exportTable(tx *sql.Tx, enc *json.Encoder, table string) error {
	q := sq.NewSelectBuilder()
	q.Select("*")
	q.From(table)
	q.OrderBy("id")

	rows, err := c.query(tx, q)
	if err != nil {
		return errors.WithStack(err)
	}
	defer rows.Close()

	columns, err := rows.Columns()
	if err != nil {
		return errors.WithStack(err)
	}

	// Scan every column in a generic value; colsPtr holds the pointers
	// passed to Scan.
	cols := make([]any, len(columns))
	colsPtr := make([]any, len(columns))
	for i := range cols {
		colsPtr[i] = &cols[i]
	}

	for rows.Next() {
		if err := rows.Scan(colsPtr...); err != nil {
			return errors.WithStack(err)
		}

		data := exportData{
			Table:  table,
			Values: make(map[string]any, len(columns)),
		}

		for i, col := range columns {
			colType, err := c.ColumnType(table, col)
			if err != nil {
				return errors.WithStack(err)
			}

			value, err := exportColumnValue(colType, cols[i])
			if err != nil {
				return errors.Wrapf(err, "invalid value for column %q.%q", table, col)
			}
			data.Values[col] = value
		}

		if err := enc.Encode(data); err != nil {
			return errors.WithStack(err)
		}
	}

	return errors.WithStack(rows.Err())
}

// exportColumnValue converts a value scanned from the database to the value
// written in the export document.
func exportColumnValue(colType ColType, v any) (any, error) {
	switch colType {
	case ColTypeBool:
		// Drivers report booleans in different forms.
		switch vt := v.(type) {
		case bool:
			return vt, nil
		case int64:
			if vt != 0 && vt != 1 {
				return nil, errors.Errorf("unknown type int64 value %d for bool column type", vt)
			}
			return vt != 0, nil
		case []uint8:
			bv, err := strconv.ParseBool(string(vt))
			if err != nil {
				return nil, errors.WithStack(err)
			}
			return bv, nil
		default:
			return nil, errors.Errorf("unknown type %T for bool column type", v)
		}

	case ColTypeJSON:
		var raw []byte
		switch vt := v.(type) {
		case nil:
			// SQL NULL is exported as JSON null.
			return nil, nil
		case []byte:
			raw = vt
		case string:
			raw = []byte(vt)
		default:
			return nil, errors.Errorf("unknown type %T for json column type", v)
		}

		var vj any
		if err := json.Unmarshal(raw, &vj); err != nil {
			return nil, errors.WithStack(err)
		}
		return vj, nil

	default:
		return v, nil
	}
}
// SetupDBFn is the function provided by callers of the Test* helpers to build
// the DBContext of a test database. The helpers in this file call it with a
// cancellable context, the running test and a temporary directory.
type SetupDBFn func(ctx context.Context, t *testing.T, dir string) *DBContext
type (
	// CreateData is the content of a create fixture: the statements needed
	// to create the database, its sequences and the description of its
	// tables.
	CreateData struct {
		DDL       DDL         `json:"ddl"`
		Sequences []Sequence  `json:"sequences"`
		Tables    []TableInfo `json:"tables"`
	}

	// DDL holds the creation statements, one list per database type.
	DDL struct {
		Postgres []string `json:"postgres"`
		Sqlite3  []string `json:"sqlite3"`
	}

	// TableInfo describes a table and its columns.
	TableInfo struct {
		Name    string    `json:"name"`
		Columns []ColInfo `json:"columns"`
	}

	// ColInfo describes a column: its name, the name of its type and
	// whether it is nullable.
	ColInfo struct {
		Name     string `json:"name"`
		Type     string `json:"type"`
		Nullable bool   `json:"nullable"`
	}

	// Sequence describes a sequence and the table column it is bound to.
	Sequence struct {
		Name   string `json:"name"`
		Table  string `json:"table"`
		Column string `json:"column"`
	}

	// DataFixtures maps a database version to the name of the data fixture
	// file for that version.
	DataFixtures map[uint]string
)
// TestCreate checks that, for every version from 1 to lastVersion, a database
// can be created from the create fixture of that version
// (fixtures/create/v<version>.json) and populated with the data fixture
// registered for that version in dataFixtures (read from fixtures/migrate).
// Each version runs in its own subtest on a database built by setupDBFn.
func TestCreate(t *testing.T, lastVersion uint, dataFixtures DataFixtures, setupDBFn SetupDBFn) {
	for version := uint(1); version <= lastVersion; version++ {
		name := fmt.Sprintf("create db at version %d", version)

		t.Run(name, func(t *testing.T) {
			tmpDir := t.TempDir()

			ctx, cancel := context.WithCancel(context.Background())
			defer cancel()

			dc := setupDBFn(ctx, t, tmpDir)

			// Row data to import; the fixture may contain comments.
			fixtureName, found := dataFixtures[version]
			if !found {
				t.Fatalf("missing fixture for db version %d", version)
			}
			rowsJSON, err := os.ReadFile(filepath.Join("fixtures", "migrate", fixtureName))
			NilError(t, err)
			rowsJSON = jsonc.ToJSON(rowsJSON)

			// Schema description and creation statements.
			rawCreate, err := os.ReadFile(filepath.Join("fixtures", "create", fmt.Sprintf("v%d.json", version)))
			NilError(t, err)

			var createData *CreateData
			NilError(t, json.Unmarshal(rawCreate, &createData))

			dc.Schema = createData.Tables

			var stmts []string
			switch dc.D.DBType() {
			case sql.Postgres:
				stmts = createData.DDL.Postgres
			case sql.Sqlite3:
				stmts = createData.DDL.Sqlite3
			}

			NilError(t, dc.DBM.Setup(ctx), "setup db error")
			NilError(t, dc.DBM.Create(ctx, stmts, version))
			NilError(t, dc.Import(ctx, bytes.NewBuffer(rowsJSON), createData))
		})
	}
}
// TestMigrate checks every migration path between version 1 and lastVersion.
//
// For each pair (createVersion, migrateVersion) with createVersion <
// migrateVersion <= lastVersion it runs a subtest that:
//   - creates a reference database directly at migrateVersion and imports the
//     data fixture of that version;
//   - creates a second database at createVersion, imports the data fixture of
//     that version and migrates it to migrateVersion;
//   - asserts that the schemas of the two databases, as inspected by atlas,
//     have no differences;
//   - asserts that the exports of the two databases are equal.
//
// Create fixtures are read from fixtures/create/v<version>.json and data
// fixtures from fixtures/migrate/<dataFixtures[version]>. Databases are built
// by setupDBFn.
func TestMigrate(t *testing.T, lastVersion uint, dataFixtures DataFixtures, setupDBFn SetupDBFn) {
	startVersion := uint(1)

	// check all versions are available. When at least one migration is going
	// to run, the fixture of lastVersion is needed too since it's used to
	// build the reference databases.
	if lastVersion > startVersion {
		for version := startVersion; version <= lastVersion; version++ {
			if _, ok := dataFixtures[version]; !ok {
				t.Fatalf("missing test import fixtures for version %d", version)
			}
		}
	}

	for createVersion := startVersion; createVersion < lastVersion; createVersion++ {
		for migrateVersion := createVersion + 1; migrateVersion <= lastVersion; migrateVersion++ {
			t.Run(fmt.Sprintf("migrate db from version %d to version %d", createVersion, migrateVersion), func(t *testing.T) {
				dir := t.TempDir()
				ctx, cancel := context.WithCancel(context.Background())
				defer cancel()

				// create db at migrate version. For diff from migrated version.
				createDC := setupDBFn(ctx, t, dir)

				dataFixtureFileCreate, ok := dataFixtures[migrateVersion]
				if !ok {
					t.Fatalf("missing data fixture for db version %d", migrateVersion)
				}
				dataFixtureCreate, err := os.ReadFile(filepath.Join("fixtures", "migrate", dataFixtureFileCreate))
				NilError(t, err)
				dataFixtureCreate = jsonc.ToJSON(dataFixtureCreate)

				createFixtureFileCreate := fmt.Sprintf("v%d.json", migrateVersion)
				createFixtureCreate, err := os.ReadFile(filepath.Join("fixtures", "create", createFixtureFileCreate))
				NilError(t, err)

				var createDataCreate *CreateData
				err = json.Unmarshal(createFixtureCreate, &createDataCreate)
				NilError(t, err)

				createDC.Schema = createDataCreate.Tables

				createDDL := createDataCreate.DDL
				var createStmts []string
				switch createDC.D.DBType() {
				case sql.Postgres:
					createStmts = createDDL.Postgres
				case sql.Sqlite3:
					createStmts = createDDL.Sqlite3
				}

				err = createDC.DBM.Setup(ctx)
				NilError(t, err, "setup db error")

				err = createDC.DBM.Create(ctx, createStmts, migrateVersion)
				NilError(t, err)

				err = createDC.Import(ctx, bytes.NewBuffer(dataFixtureCreate), createDataCreate)
				NilError(t, err)

				// create db at create version to be migrated.
				dc := setupDBFn(ctx, t, dir)

				dataFixtureFile, ok := dataFixtures[createVersion]
				if !ok {
					t.Fatalf("missing fixture for db version %d", createVersion)
				}
				dataFixture, err := os.ReadFile(filepath.Join("fixtures", "migrate", dataFixtureFile))
				NilError(t, err)
				dataFixture = jsonc.ToJSON(dataFixture)

				createFixtureFile := fmt.Sprintf("v%d.json", createVersion)
				createFixture, err := os.ReadFile(filepath.Join("fixtures", "create", createFixtureFile))
				NilError(t, err)

				var createData *CreateData
				err = json.Unmarshal(createFixture, &createData)
				NilError(t, err)

				dc.Schema = createData.Tables

				ddl := createData.DDL
				var stmts []string
				switch dc.D.DBType() {
				case sql.Postgres:
					stmts = ddl.Postgres
				case sql.Sqlite3:
					stmts = ddl.Sqlite3
				}

				err = dc.DBM.Setup(ctx)
				NilError(t, err, "setup db error")

				err = dc.DBM.Create(ctx, stmts, createVersion)
				NilError(t, err)

				err = dc.Import(ctx, bytes.NewBuffer(dataFixture), createData)
				NilError(t, err)

				err = dc.DBM.MigrateToVersion(ctx, migrateVersion)
				NilError(t, err)

				// Diff created and migrated schema
				createAtlasClient, err := atlassqlclient.Open(ctx, createDC.AtlasConnString())
				NilError(t, err)
				// Release the inspection connections when the subtest ends.
				defer func() { _ = createAtlasClient.Close() }()

				atlasClient, err := atlassqlclient.Open(ctx, dc.AtlasConnString())
				NilError(t, err)
				defer func() { _ = atlasClient.Close() }()

				createSchema, err := createAtlasClient.InspectSchema(ctx, "", nil)
				NilError(t, err)
				schema, err := atlasClient.InspectSchema(ctx, "", nil)
				NilError(t, err)

				diff, err := atlasClient.SchemaDiff(createSchema, schema)
				NilError(t, err)

				assert.Assert(t, cmp.Len(diff, 0), "schema of db created at version %d and db migrated from version %d to version %d is different:\n %s", migrateVersion, createVersion, migrateVersion, diff)

				// set the db schema at the migrated version.
				dc.Schema = createDataCreate.Tables

				createExport := &bytes.Buffer{}
				export := &bytes.Buffer{}

				// tableNames returns the sorted names of the data tables of
				// a schema, skipping the dbversion and sequence_t tables.
				tableNames := func(schema *atlasschema.Schema) []string {
					tableNames := []string{}
					for _, table := range schema.Tables {
						if table.Name == "dbversion" || table.Name == "sequence_t" {
							continue
						}
						tableNames = append(tableNames, table.Name)
					}
					slices.Sort(tableNames)

					return tableNames
				}

				err = createDC.Export(ctx, tableNames(createSchema), createExport)
				NilError(t, err)

				err = dc.Export(ctx, tableNames(schema), export)
				NilError(t, err)

				// Diff database data
				// Since postgres has microsecond time precision while go has nanosecond time precision we should check times with a microsecond margin
				// NOTE(review): both operands are raw byte slices, so the
				// approximate time option presumably never gets a time.Time
				// value to compare here. Confirm whether this is intended.
				assert.DeepEqual(t, createExport.Bytes(), export.Bytes(), cmpopts.EquateApproxTime(1*time.Microsecond))
			})
		}
	}
}
// decodeExport decodes a stream of exported JSON documents into the objects
// produced by d.UnmarshalExportObject and returns them sorted by id. Any
// decoding error fails the test.
func decodeExport(t *testing.T, d manager.DB, export []byte) []any {
	decoder := json.NewDecoder(bytes.NewReader(export))

	decoded := []any{}
	for {
		var raw json.RawMessage
		if err := decoder.Decode(&raw); err != nil {
			if errors.Is(err, io.EOF) {
				break
			}
			NilError(t, err)
		}

		obj, err := d.UnmarshalExportObject(raw)
		NilError(t, err)

		decoded = append(decoded, obj)
	}

	// sort objects by id
	byID := func(a, b any) int {
		return stdcmp.Compare(a.(sqlg.Object).GetID(), b.(sqlg.Object).GetID())
	}
	slices.SortFunc(decoded, byID)

	return decoded
}
// TestImportExport checks the import/export round trip of the database
// manager: it creates a database at its current version, imports the fixture
// fixtures/import/<importFixtureFile>, verifies that the resulting sequence
// values match seqs and that exporting all the objects yields the same
// objects contained in the fixture.
func TestImportExport(t *testing.T, importFixtureFile string, setupDBFn SetupDBFn, seqs map[string]uint64) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	dc := setupDBFn(ctx, t, t.TempDir())

	fixturePath := filepath.Join("fixtures", "import", importFixtureFile)
	fixture, err := os.ReadFile(fixturePath)
	NilError(t, err)

	NilError(t, dc.DBM.Create(ctx, dc.D.DDL(), dc.D.Version()))
	NilError(t, dc.DBM.Import(ctx, bytes.NewBuffer(fixture)))

	// check sequences
	gotSeqs := make(map[string]uint64)
	err = dc.D.Do(ctx, func(tx *sql.Tx) error {
		for _, seq := range dc.D.Sequences() {
			value, err := dc.D.GetSequence(tx, seq.Name)
			if err != nil {
				return errors.WithStack(err)
			}
			gotSeqs[seq.Name] = value
		}

		return nil
	})
	NilError(t, err)
	assert.DeepEqual(t, gotSeqs, seqs)

	var exported bytes.Buffer
	NilError(t, dc.DBM.Export(ctx, sqlg.ObjectNames(dc.D.ObjectsInfo()), &exported))

	exportedObjs := decodeExport(t, dc.D, exported.Bytes())
	fixtureObjs := decodeExport(t, dc.D, fixture)

	// Since postgres has microsecond time precision while go has nanosecond time precision we should check times with a microsecond margin
	assert.DeepEqual(t, fixtureObjs, exportedObjs, cmpopts.EquateApproxTime(1*time.Microsecond))
}