Placeholder
Last active
June 22, 2024 06:05
-
-
Save dhermes/cc150b24cd156b6352259744a2a90645 to your computer and use it in GitHub Desktop.
[2024-06-21] Guts of `sql.Tx` <-> `pgx.Tx`
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module gist.github.com/dhermes/cc150b24cd156b6352259744a2a90645 | |
go 1.22.4 | |
require github.com/jackc/pgx/v5 v5.6.0 | |
require ( | |
github.com/jackc/pgpassfile v1.0.0 // indirect | |
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect | |
github.com/jackc/puddle/v2 v2.2.1 // indirect | |
golang.org/x/crypto v0.17.0 // indirect | |
golang.org/x/sync v0.1.0 // indirect | |
golang.org/x/text v0.14.0 // indirect | |
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= | |
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= | |
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= | |
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= | |
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= | |
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= | |
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= | |
github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY= | |
github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw= | |
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= | |
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= | |
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= | |
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= | |
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= | |
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= | |
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= | |
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= | |
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= | |
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= | |
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= | |
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= | |
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | |
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= | |
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= | |
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= | |
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= | |
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= | |
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"context" | |
"database/sql" | |
"errors" | |
"fmt" | |
"os" | |
"reflect" | |
"unsafe" | |
pgx "github.com/jackc/pgx/v5" | |
"github.com/jackc/pgx/v5/pgxpool" | |
_ "github.com/jackc/pgx/v5/stdlib" | |
) | |
// wrapTx is vendored in from the pgx source:
// https://github.com/jackc/pgx/blob/v5.6.0/stdlib/sql.go#L874-L877
//
// It mirrors the unexported `stdlib.wrapTx` type so that unsafeConvertWrapTx
// can reinterpret a pointer to the real type as a pointer to this one and
// read its unexported fields. Because that conversion goes through
// `unsafe.Pointer`, the field order and field types here must match the
// upstream definition for the pinned pgx version exactly. Do not reorder or
// retype these fields, and re-check them whenever pgx is upgraded.
type wrapTx struct {
	ctx context.Context
	tx  pgx.Tx
}
// initStdlibPool opens a `database/sql` handle backed by the "pgx" driver
// (registered by the blank import of `github.com/jackc/pgx/v5/stdlib`) and
// verifies connectivity with a ping.
//
// On success the caller owns the returned *sql.DB and must Close it. On
// failure nil is returned. If the ping fails, the handle that was already
// opened is closed before returning so it is not leaked, and any Close error
// is joined onto the ping error.
func initStdlibPool(ctx context.Context, connectionURL string) (*sql.DB, error) {
	pool, err := sql.Open("pgx", connectionURL)
	if err != nil {
		return nil, err
	}
	if err := pool.PingContext(ctx); err != nil {
		// sql.Open already allocated the handle; release it since the caller
		// will never see it.
		return nil, errors.Join(err, pool.Close())
	}
	return pool, nil
}
// finalizeStdlibTx rolls back tx (if non-nil) and returns err. If the
// rollback itself fails, its error is joined onto err.
//
// It is meant for deferred cleanup. sql.ErrTxDone (the transaction was
// already committed or rolled back) is not treated as a failure, so calling
// it after a successful Commit is harmless. errors.Is is used so a wrapped
// ErrTxDone is recognized too.
func finalizeStdlibTx(tx *sql.Tx, err error) error {
	if tx == nil {
		return err
	}
	rollbackErr := tx.Rollback()
	if rollbackErr == nil || errors.Is(rollbackErr, sql.ErrTxDone) {
		return err
	}
	return errors.Join(err, rollbackErr)
}
// finalizeStdlibPool closes pool when one was opened and folds any Close
// failure into err via errors.Join. With no pool, err passes through as-is.
func finalizeStdlibPool(pool *sql.DB, err error) error {
	if pool != nil {
		return errors.Join(err, pool.Close())
	}
	return err
}
func finalizePgxPool(pool *pgxpool.Pool, err error) error { | |
if pool != nil { | |
pool.Close() | |
} | |
return err | |
} | |
// copyReflectPointer returns a reflect.Value that aliases the same memory as
// v, which must be addressable. The result is rebuilt from v's address, so it
// does not carry the read-only flag reflection attaches to values reached
// through unexported struct fields. Despite the name, nothing is copied:
// reads and writes through the result go straight to v's storage.
//
// H/T: https://stackoverflow.com/a/43918797/1068170
func copyReflectPointer(v reflect.Value) reflect.Value {
	typ := v.Type()
	addr := unsafe.Pointer(v.UnsafeAddr())
	return reflect.NewAt(typ, addr).Elem()
}
// copyReflectStruct copies the value held by v into a freshly allocated,
// addressable reflect.Value of the same type and returns it. This is useful
// when v itself is not addressable, e.g. it was obtained from an interface.
//
// It returns an error instead of panicking when v is the zero Value, or when
// v is read-only because it was reached through an unexported struct field.
// reflect.Value.Set panics when the source value is read-only.
//
// H/T: https://stackoverflow.com/a/43918797/1068170
func copyReflectStruct(v reflect.Value) (reflect.Value, error) {
	if !v.IsValid() {
		return reflect.Value{}, errors.New("cannot copy invalid reflect.Value")
	}
	vt := v.Type()
	// A value from reflect.New(...).Elem() is always settable, so the only
	// way Set can fail here is a read-only source; check for that up front.
	if !v.CanInterface() {
		return reflect.Value{}, fmt.Errorf("cannot copy read-only struct value; (%s).%s", vt.PkgPath(), vt.Name())
	}
	v2 := reflect.New(vt).Elem()
	v2.Set(v)
	return v2, nil
}
func unsafeConvertWrapTx(wrapTxValue reflect.Value) (*wrapTx, error) { | |
wrapTxValue, err := copyReflectStruct(wrapTxValue) | |
if err != nil { | |
return nil, err | |
} | |
wrapTxType := wrapTxValue.Type() | |
if wrapTxType.PkgPath() != "github.com/jackc/pgx/v5/stdlib" || wrapTxType.Name() != "wrapTx" { | |
return nil, fmt.Errorf("unexpected type; (%s).%s", wrapTxType.PkgPath(), wrapTxType.Name()) | |
} | |
if !wrapTxValue.CanAddr() { | |
return nil, errors.New("cannot address wrapTx") | |
} | |
p := unsafe.Pointer(wrapTxValue.UnsafeAddr()) | |
wt := (*wrapTx)(p) | |
return wt, nil | |
} | |
func dissectSQLTx(tx *sql.Tx) (pgxTX pgx.Tx, err error) { | |
// First get a reflect Value for the underlying `sql.Tx` valie | |
txValue := reflect.ValueOf(tx).Elem() | |
// Then grab the unexported `txi`, ensure it's addressable and copy | |
// it onto a `Value` that we can interface with. | |
txiValue := txValue.FieldByName("txi") | |
if !txiValue.CanAddr() { | |
return nil, fmt.Errorf("cannot address txi; (%s).%s", txiValue.Type().PkgPath(), txiValue.Type().Name()) | |
} | |
txiValue = copyReflectPointer(txiValue) | |
// Resolve the `driver.Tx` interface (`txi` field) to an actual underlying | |
// value. | |
if !txiValue.CanInterface() { | |
return nil, fmt.Errorf("cannot interface txi; (%s).%s", txiValue.Type().PkgPath(), txiValue.Type().Name()) | |
} | |
wrapTxValue := reflect.ValueOf(txiValue.Interface()) | |
wt, err := unsafeConvertWrapTx(wrapTxValue) | |
if err != nil { | |
return nil, err | |
} | |
return wt.tx, nil | |
} | |
func initPgxPool(ctx context.Context, connectionURL string) (*pgxpool.Pool, error) { | |
config, err := pgxpool.ParseConfig(connectionURL) | |
if err != nil { | |
return nil, err | |
} | |
config.ConnConfig.RuntimeParams["search_path"] = "tmp" | |
return pgxpool.NewWithConfig(ctx, config) | |
} | |
func showSearchPath(ctx context.Context, tx pgx.Tx, extra string) error { | |
row := tx.QueryRow(ctx, "SHOW search_path") | |
searchPath := "" | |
err := row.Scan(&searchPath) | |
if err != nil { | |
return err | |
} | |
fmt.Printf("search_path (%s): %s\n", extra, searchPath) | |
return nil | |
} | |
func setSearchPath(ctx context.Context, tx pgx.Tx) error { | |
_, err := tx.Exec(ctx, "SET search_path = 'tmp'") | |
return err | |
} | |
func run() (err error) { | |
var stdlibPool *sql.DB | |
var tx *sql.Tx | |
var pgxPool *pgxpool.Pool | |
defer func() { | |
err = finalizeStdlibTx(tx, err) | |
err = finalizeStdlibPool(stdlibPool, err) | |
err = finalizePgxPool(pgxPool, err) | |
}() | |
ctx := context.Background() | |
connectionURL, ok := os.LookupEnv("CONNECTION_URL") | |
if !ok { | |
return errors.New("missing CONNECTION_URL environment variable") | |
} | |
stdlibPool, err = initStdlibPool(ctx, connectionURL) | |
if err != nil { | |
return err | |
} | |
tx, err = stdlibPool.BeginTx(ctx, nil) | |
if err != nil { | |
return err | |
} | |
pgxTx, err := dissectSQLTx(tx) | |
if err != nil { | |
return err | |
} | |
pgxPool, err = initPgxPool(ctx, connectionURL) | |
if err != nil { | |
return err | |
} | |
err = showSearchPath(ctx, pgxTx, "BEFORE") | |
if err != nil { | |
return err | |
} | |
err = setSearchPath(ctx, pgxTx) | |
if err != nil { | |
return err | |
} | |
// NOTE: `setSearchPath()` is a **SMALL**; this is how e.g. | |
// `RuntimeParams` are set on a connection via `pgx`: | |
// - https://github.com/jackc/pgx/blob/v5.6.0/pgxpool/pool.go#L227 | |
// - https://github.com/jackc/pgx/blob/v5.6.0/conn.go#L135 | |
// - https://github.com/jackc/pgx/blob/v5.6.0/conn.go#L256-L259 | |
// - https://github.com/jackc/pgx/blob/v5.6.0/pgconn/pgconn.go#L156 | |
// - https://github.com/jackc/pgx/blob/v5.6.0/pgconn/pgconn.go#L261 | |
// - https://github.com/jackc/pgx/blob/v5.6.0/pgconn/pgconn.go#L357-L359 | |
err = showSearchPath(ctx, pgxTx, "AFTER") | |
if err != nil { | |
return err | |
} | |
return nil | |
} | |
func main() { | |
err := run() | |
if err != nil { | |
fmt.Fprintf(os.Stderr, "%v\n", err) | |
os.Exit(1) | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment