Last active
July 28, 2024 13:03
-
-
Save mortenson/c3c1e7f2a1b10c5c3674f8c91123c3e0 to your computer and use it in GitHub Desktop.
Create sqlc migrations automatically using pg-schema-diff
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"context" | |
"database/sql" | |
"flag" | |
"fmt" | |
"os" | |
"path/filepath" | |
"strings" | |
"time" | |
"github.com/jackc/pgx/v4" | |
"github.com/jackc/pgx/v4/stdlib" | |
_ "github.com/lib/pq" | |
"github.com/stripe/pg-schema-diff/pkg/diff" | |
"github.com/stripe/pg-schema-diff/pkg/tempdb" | |
) | |
func main() { | |
namePtr := flag.String("name", "", "The migration name") | |
dsnPtr := flag.String("dsn", "", "The connection string") | |
schemaDirPtr := flag.String("schemaDir", "", "The schema directory") | |
migrationDirPtr := flag.String("migrationDir", "", "The migration directory") | |
flag.Parse() | |
name := *namePtr | |
dsn := *dsnPtr | |
schemaDir := *schemaDirPtr | |
migrationDir := *migrationDirPtr | |
if name == "" || dsn == "" || schemaDir == "" || migrationDir == "" { | |
fmt.Println("Example usage: go run create_migration.go -dsn \"postgres://...\" -schemaDir schema -migrationDir migrations -name create_users") | |
return | |
} | |
ctx := context.Background() | |
connConfig, err := pgx.ParseConfig(dsn) | |
if err != nil { | |
panic(err) | |
} | |
tempDbFactory, err := tempdb.NewOnInstanceFactory(ctx, func(ctx context.Context, dbName string) (*sql.DB, error) { | |
copiedConfig := connConfig.Copy() | |
copiedConfig.Database = dbName | |
return openDbWithPgxConfig(copiedConfig) | |
}) | |
if err != nil { | |
panic(err) | |
} | |
defer tempDbFactory.Close() | |
ddl, err := getDDLFromPath(schemaDir) | |
if err != nil { | |
panic(err) | |
} | |
connPool, err := openDbWithPgxConfig(connConfig) | |
if err != nil { | |
panic(err) | |
} | |
defer connPool.Close() | |
conn, err := connPool.Conn(ctx) | |
if err != nil { | |
panic(err) | |
} | |
defer conn.Close() | |
plan, err := diff.GeneratePlan(ctx, conn, tempDbFactory, ddl, diff.WithDoNotValidatePlan()) | |
if err != nil { | |
panic(err) | |
} | |
statements := [][]string{} | |
curr_index := 0 | |
for _, statement := range plan.Statements { | |
statementStr := statement.ToSQL() | |
if strings.Contains(statementStr, "goose") { | |
continue | |
} | |
if len(statements) == curr_index { | |
statements = append(statements, []string{}) | |
} | |
if strings.Contains(statementStr, "CONCURRENTLY") { | |
if len(statements[curr_index]) == 0 { | |
statements[curr_index] = []string{statementStr} | |
curr_index += 1 | |
} else { | |
statements = append(statements, []string{statementStr}) | |
curr_index += 2 | |
} | |
} else { | |
statements[curr_index] = append(statements[curr_index], statementStr) | |
} | |
fmt.Printf("[STATEMENT] %s\n", statementStr) | |
for _, hazard := range statement.Hazards { | |
fmt.Printf("\033[31m[WARNING] %s\033[0m\n", hazard.String()) | |
} | |
} | |
now := time.Now() | |
for i, statementBlock := range statements { | |
contents := strings.Join(statementBlock, "\n\n") | |
migrationStr := fmt.Sprintf("-- +goose Up\n-- +goose StatementBegin\n%s\n-- +goose StatementEnd\n", contents) | |
if strings.Contains(contents, "CONCURRENTLY") { | |
migrationStr = "-- +goose NO TRANSACTION\n" + migrationStr | |
} | |
var filename string | |
if len(statements) > 1 { | |
now = now.Add(time.Second) | |
filename = fmt.Sprintf("%s_%s_%02d", now.Format("20060102150405"), name, i+1) | |
} else { | |
filename = fmt.Sprintf("%s_%s", now.Format("20060102150405"), name) | |
} | |
filename += ".sql" | |
filePath := filepath.Join(migrationDir, filename) | |
err = os.WriteFile(filePath, []byte(migrationStr), 0644) | |
if err != nil { | |
panic(err) | |
} | |
fmt.Printf("\033[32mCreated %s\033[0m\n", filePath) | |
} | |
} | |
// getDDLFromPath returns the contents of every ".sql" file located directly
// inside the directory path, one string per file.
//
// Entries are visited in the order os.ReadDir yields them (sorted by
// filename). Sub-directories are never descended into, and a directory whose
// name happens to end in ".sql" is skipped rather than read. Entries with any
// other extension are ignored. The result is nil when no file matches.
//
// An error is returned if the directory cannot be listed or any matching file
// cannot be read; the underlying error stays reachable through errors.Is/As.
func getDDLFromPath(path string) ([]string, error) {
	entries, err := os.ReadDir(path)
	if err != nil {
		return nil, fmt.Errorf("reading schema directory %q: %w", path, err)
	}

	var ddl []string
	for _, entry := range entries {
		if entry.IsDir() || filepath.Ext(entry.Name()) != ".sql" {
			continue
		}
		contents, err := os.ReadFile(filepath.Join(path, entry.Name()))
		if err != nil {
			return nil, fmt.Errorf("reading schema file %q: %w", entry.Name(), err)
		}
		ddl = append(ddl, string(contents))
	}
	return ddl, nil
}
func openDbWithPgxConfig(config *pgx.ConnConfig) (*sql.DB, error) { | |
connPool := stdlib.OpenDB(*config) | |
if err := connPool.Ping(); err != nil { | |
connPool.Close() | |
return nil, err | |
} | |
return connPool, nil | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment