diff --git a/internal/pgtools/quote.go b/internal/pgtools/quote.go index ba2e084..9f8cbec 100644 --- a/internal/pgtools/quote.go +++ b/internal/pgtools/quote.go @@ -102,3 +102,16 @@ func requiresQuoting(identifier string) bool { } return false } + +// ParseTableName splits a table name into schema and table parts. +// +// If tableName is unqualified (no "."), the schema defaults to "public". +// If tableName is qualified, everything before the first "." is returned as +// schema and everything after it is returned as tablename. +func ParseTableName(tableName string) (schema string, tablename string) { + schema, tablename, found := strings.Cut(tableName, ".") + if !found { + return "public", tableName + } + return schema, tablename +} diff --git a/internal/pgtools/quote_test.go b/internal/pgtools/quote_test.go index c73d0d3..d71ff16 100644 --- a/internal/pgtools/quote_test.go +++ b/internal/pgtools/quote_test.go @@ -51,3 +51,22 @@ func TestIdentifierGarbageInputs(t *testing.T) { check.Equal(t, `"""schema"""."""tablename"""`, pgtools.Identifier(`"schema"."tablename"`)) check.Equal(t, `"""schema"."tablename"""`, pgtools.Identifier(`"schema.tablename"`)) } + +func TestParseTableName(t *testing.T) { + t.Parallel() + schema, tablename := pgtools.ParseTableName("users") + check.Equal(t, "public", schema) + check.Equal(t, "users", tablename) + + schema, tablename = pgtools.ParseTableName("custom.users") + check.Equal(t, "custom", schema) + check.Equal(t, "users", tablename) + + schema, tablename = pgtools.ParseTableName(".users") + check.Equal(t, "", schema) + check.Equal(t, "users", tablename) + + schema, tablename = pgtools.ParseTableName("a.b.c") + check.Equal(t, "a", schema) + check.Equal(t, "b.c", tablename) +} diff --git a/migrator.go b/migrator.go index ea9d133..22f8dc9 100644 --- a/migrator.go +++ b/migrator.go @@ -4,7 +4,6 @@ import ( "context" "database/sql" "fmt" - "strings" "time" "github.com/peterldowns/pgmigrate/internal/multierr" @@ -142,6 +141,15 @@ func (m *Migrator) Migrate(ctx context.Context, db *sql.DB) ([]VerificationError // ensureMigrationsTable will create the migrations table if it does not exist. func (m *Migrator) ensureMigrationsTable(ctx context.Context, db Executor) error { m.info(ctx, "ensuring migrations table exists", LogField{Key: "table_name", Value: m.TableName}) + schema, _ := pgtools.ParseTableName(m.TableName) + if schema != "" { + query := fmt.Sprintf(`CREATE SCHEMA IF NOT EXISTS %s`, pgtools.Identifier(schema)) + m.debug(ctx, query) + _, err := db.ExecContext(ctx, query) + if err != nil { + return fmt.Errorf("ensureMigrationsTable/create schema: %w", err) + } + } query := fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id TEXT PRIMARY KEY, @@ -161,15 +169,7 @@ func (m *Migrator) ensureMigrationsTable(ctx context.Context, db Executor) error // hasMigrationsTable returns true if the migrations table exists, false // otherwise. func (m *Migrator) hasMigrationsTable(ctx context.Context, db Executor) (bool, error) { - parts := strings.SplitN(m.TableName, ".", 2) - var schema, tablename string - if len(parts) == 1 { - schema = "public" - tablename = parts[0] - } else { - schema = parts[0] - tablename = parts[1] - } + schema, tablename := pgtools.ParseTableName(m.TableName) query := fmt.Sprintf(` SELECT EXISTS ( SELECT FROM pg_tables diff --git a/migrator_test.go b/migrator_test.go index c2141dc..582b674 100644 --- a/migrator_test.go +++ b/migrator_test.go @@ -31,6 +31,36 @@ func TestApplyNoMigrationsSucceeds(t *testing.T) { assert.Nil(t, err) } +func TestCreateMigrationsTableInMissingSchema(t *testing.T) { + t.Parallel() + ctx := context.Background() + logger := pgmigrate.NewTestLogger(t) + err := withdb.WithDB(ctx, "pgx", func(db *sql.DB) error { + migrations := []pgmigrate.Migration{ + { + ID: "0001_initial", + SQL: "CREATE TABLE users (id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY);", + }, + } + migrator := pgmigrate.NewMigrator(migrations) + migrator.Logger = logger + migrator.TableName = "new_schema.pgmigrate_migrations" + + verrs, err := migrator.Migrate(ctx, db) + assert.Nil(t, err) + assert.Equal(t, nil, verrs) + + tables, err := schema.LoadTables(schema.DumpConfig{ + SchemaNames: []string{"new_schema"}, + }, db) + assert.Nil(t, err) + assert.Equal(t, 1, len(tables)) + check.Equal(t, "pgmigrate_migrations", tables[0].Name) + return nil + }) + assert.Nil(t, err) +} + func TestApplyOneMigrationSucceeds(t *testing.T) { t.Parallel() ctx := context.Background()