Skip to content

Commit 060734b

Browse files
authored
fix(postgres): Apply default snapshot name if no name specified (#2783)
1 parent b60497e commit 060734b

File tree

3 files changed

+83
-58
lines changed

3 files changed

+83
-58
lines changed

modules/postgres/options.go

+2
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@ import (
77
type options struct {
88
// SQLDriverName is the name of the SQL driver to use.
99
SQLDriverName string
10+
Snapshot string
1011
}
1112

1213
func defaultOptions() options {
1314
return options{
1415
SQLDriverName: "postgres",
16+
Snapshot: defaultSnapshotName,
1517
}
1618
}
1719

modules/postgres/postgres.go

+1
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ func Run(ctx context.Context, img string, opts ...testcontainers.ContainerCustom
177177
password: req.Env["POSTGRES_PASSWORD"],
178178
user: req.Env["POSTGRES_USER"],
179179
sqlDriverName: settings.SQLDriverName,
180+
snapshotName: settings.Snapshot,
180181
}
181182
}
182183

modules/postgres/postgres_test.go

+80-58
Original file line numberDiff line numberDiff line change
@@ -203,73 +203,95 @@ func TestWithInitScript(t *testing.T) {
203203
}
204204

205205
func TestSnapshot(t *testing.T) {
206-
// snapshotAndReset {
207-
ctx := context.Background()
208-
209-
// 1. Start the postgres ctr and run any migrations on it
210-
ctr, err := postgres.Run(
211-
ctx,
212-
"docker.io/postgres:16-alpine",
213-
postgres.WithDatabase(dbname),
214-
postgres.WithUsername(user),
215-
postgres.WithPassword(password),
216-
postgres.BasicWaitStrategies(),
217-
postgres.WithSQLDriver("pgx"),
218-
)
219-
testcontainers.CleanupContainer(t, ctr)
220-
require.NoError(t, err)
221-
222-
// Run any migrations on the database
223-
_, _, err = ctr.Exec(ctx, []string{"psql", "-U", user, "-d", dbname, "-c", "CREATE TABLE users (id SERIAL, name TEXT NOT NULL, age INT NOT NULL)"})
224-
require.NoError(t, err)
206+
tests := []struct {
207+
name string
208+
options []postgres.SnapshotOption
209+
}{
210+
{
211+
name: "snapshot/default",
212+
options: nil,
213+
},
225214

226-
// 2. Create a snapshot of the database to restore later
227-
err = ctr.Snapshot(ctx, postgres.WithSnapshotName("test-snapshot"))
228-
require.NoError(t, err)
215+
{
216+
name: "snapshot/custom",
217+
options: []postgres.SnapshotOption{
218+
postgres.WithSnapshotName("custom-snapshot"),
219+
},
220+
},
221+
}
229222

230-
dbURL, err := ctr.ConnectionString(ctx)
231-
require.NoError(t, err)
223+
for _, tt := range tests {
224+
t.Run(tt.name, func(t *testing.T) {
225+
// snapshotAndReset {
226+
ctx := context.Background()
232227

233-
t.Run("Test inserting a user", func(t *testing.T) {
234-
t.Cleanup(func() {
235-
// 3. In each test, reset the DB to its snapshot state.
236-
err = ctr.Restore(ctx)
228+
// 1. Start the postgres ctr and run any migrations on it
229+
ctr, err := postgres.Run(
230+
ctx,
231+
"docker.io/postgres:16-alpine",
232+
postgres.WithDatabase(dbname),
233+
postgres.WithUsername(user),
234+
postgres.WithPassword(password),
235+
postgres.BasicWaitStrategies(),
236+
postgres.WithSQLDriver("pgx"),
237+
)
238+
testcontainers.CleanupContainer(t, ctr)
237239
require.NoError(t, err)
238-
})
239240

240-
conn, err := pgx.Connect(context.Background(), dbURL)
241-
require.NoError(t, err)
242-
defer conn.Close(context.Background())
243-
244-
_, err = conn.Exec(ctx, "INSERT INTO users(name, age) VALUES ($1, $2)", "test", 42)
245-
require.NoError(t, err)
246-
247-
var name string
248-
var age int64
249-
err = conn.QueryRow(context.Background(), "SELECT name, age FROM users LIMIT 1").Scan(&name, &age)
250-
require.NoError(t, err)
251-
252-
require.Equal(t, "test", name)
253-
require.EqualValues(t, 42, age)
254-
})
241+
// Run any migrations on the database
242+
_, _, err = ctr.Exec(ctx, []string{"psql", "-U", user, "-d", dbname, "-c", "CREATE TABLE users (id SERIAL, name TEXT NOT NULL, age INT NOT NULL)"})
243+
require.NoError(t, err)
255244

256-
// 4. Run as many tests as you need, they will each get a clean database
257-
t.Run("Test querying empty DB", func(t *testing.T) {
258-
t.Cleanup(func() {
259-
err = ctr.Restore(ctx)
245+
// 2. Create a snapshot of the database to restore later
246+
// tt.options comes the test case, it can be specified as e.g. `postgres.WithSnapshotName("custom-snapshot")` or omitted, to use default name
247+
err = ctr.Snapshot(ctx, tt.options...)
260248
require.NoError(t, err)
261-
})
262249

263-
conn, err := pgx.Connect(context.Background(), dbURL)
264-
require.NoError(t, err)
265-
defer conn.Close(context.Background())
250+
dbURL, err := ctr.ConnectionString(ctx)
251+
require.NoError(t, err)
266252

267-
var name string
268-
var age int64
269-
err = conn.QueryRow(context.Background(), "SELECT name, age FROM users LIMIT 1").Scan(&name, &age)
270-
require.ErrorIs(t, err, pgx.ErrNoRows)
271-
})
272-
// }
253+
t.Run("Test inserting a user", func(t *testing.T) {
254+
t.Cleanup(func() {
255+
// 3. In each test, reset the DB to its snapshot state.
256+
err = ctr.Restore(ctx)
257+
require.NoError(t, err)
258+
})
259+
260+
conn, err := pgx.Connect(context.Background(), dbURL)
261+
require.NoError(t, err)
262+
defer conn.Close(context.Background())
263+
264+
_, err = conn.Exec(ctx, "INSERT INTO users(name, age) VALUES ($1, $2)", "test", 42)
265+
require.NoError(t, err)
266+
267+
var name string
268+
var age int64
269+
err = conn.QueryRow(context.Background(), "SELECT name, age FROM users LIMIT 1").Scan(&name, &age)
270+
require.NoError(t, err)
271+
272+
require.Equal(t, "test", name)
273+
require.EqualValues(t, 42, age)
274+
})
275+
276+
// 4. Run as many tests as you need, they will each get a clean database
277+
t.Run("Test querying empty DB", func(t *testing.T) {
278+
t.Cleanup(func() {
279+
err = ctr.Restore(ctx)
280+
require.NoError(t, err)
281+
})
282+
283+
conn, err := pgx.Connect(context.Background(), dbURL)
284+
require.NoError(t, err)
285+
defer conn.Close(context.Background())
286+
287+
var name string
288+
var age int64
289+
err = conn.QueryRow(context.Background(), "SELECT name, age FROM users LIMIT 1").Scan(&name, &age)
290+
require.ErrorIs(t, err, pgx.ErrNoRows)
291+
})
292+
// }
293+
})
294+
}
273295
}
274296

275297
func TestSnapshotWithOverrides(t *testing.T) {

0 commit comments

Comments
 (0)