@@ -11,6 +11,7 @@ import (
11
11
"github.com/compose-spec/compose-go/v2/types"
12
12
"github.com/docker/cli/cli/command"
13
13
"github.com/docker/compose/v2/pkg/api"
14
+ dockertypes "github.com/docker/docker/api/types"
14
15
"github.com/docker/docker/api/types/container"
15
16
"github.com/docker/docker/api/types/filters"
16
17
"github.com/docker/docker/client"
@@ -134,6 +135,9 @@ type dockerCompose struct {
134
135
// used in ServiceContainer(...) function to avoid calls to the Docker API
135
136
containers map [string ]* testcontainers.DockerContainer
136
137
138
+ // cache for networks in the compose stack
139
+ networks map [string ]* testcontainers.DockerNetwork
140
+
137
141
// docker/compose API service instance used to control the compose stack
138
142
composeService api.Service
139
143
@@ -147,6 +151,12 @@ type dockerCompose struct {
147
151
// compiled compose project
148
152
// can be nil if the stack wasn't started yet
149
153
project * types.Project
154
+
155
+ // sessionID is used to identify the reaper session
156
+ sessionID string
157
+
158
+ // reaper is used to clean up containers after the stack is stopped
159
+ reaper * testcontainers.Reaper
150
160
}
151
161
152
162
func (d * dockerCompose ) ServiceContainer (ctx context.Context , svcName string ) (* testcontainers.DockerContainer , error ) {
@@ -235,26 +245,89 @@ func (d *dockerCompose) Up(ctx context.Context, opts ...StackUpOption) error {
235
245
return err
236
246
}
237
247
248
+ err = d .lookupNetworks (ctx )
249
+ if err != nil {
250
+ return err
251
+ }
252
+
253
+ if d .reaper != nil {
254
+ for _ , n := range d .networks {
255
+ termSignal , err := d .reaper .Connect ()
256
+ if err != nil {
257
+ return fmt .Errorf ("failed to connect to reaper: %w" , err )
258
+ }
259
+ n .SetTerminationSignal (termSignal )
260
+
261
+ // Cleanup on error, otherwise set termSignal to nil before successful return.
262
+ defer func () {
263
+ if termSignal != nil {
264
+ termSignal <- true
265
+ }
266
+ }()
267
+ }
268
+ }
269
+
270
+ errGrpContainers , errGrpCtx := errgroup .WithContext (ctx )
271
+
272
+ for _ , srv := range d .project .Services {
273
+ // we are going to connect each container to the reaper
274
+ srv := srv
275
+ errGrpContainers .Go (func () error {
276
+ dc , err := d .lookupContainer (errGrpCtx , srv .Name )
277
+ if err != nil {
278
+ return err
279
+ }
280
+
281
+ if d .reaper != nil {
282
+ termSignal , err := d .reaper .Connect ()
283
+ if err != nil {
284
+ return fmt .Errorf ("failed to connect to reaper: %w" , err )
285
+ }
286
+ dc .SetTerminationSignal (termSignal )
287
+
288
+ // Cleanup on error, otherwise set termSignal to nil before successful return.
289
+ defer func () {
290
+ if termSignal != nil {
291
+ termSignal <- true
292
+ }
293
+ }()
294
+ }
295
+
296
+ d .containers [srv .Name ] = dc
297
+
298
+ return nil
299
+ })
300
+ }
301
+
302
+ // wait here for the containers lookup to finish
303
+ if err := errGrpContainers .Wait (); err != nil {
304
+ return err
305
+ }
306
+
238
307
if len (d .waitStrategies ) == 0 {
239
308
return nil
240
309
}
241
310
242
- errGrp , errGrpCtx := errgroup .WithContext (ctx )
311
+ errGrpWait , errGrpCtx := errgroup .WithContext (ctx )
243
312
244
313
for svc , strategy := range d .waitStrategies { // pinning the variables
245
314
svc := svc
246
315
strategy := strategy
247
316
248
- errGrp .Go (func () error {
317
+ errGrpWait .Go (func () error {
249
318
target , err := d .lookupContainer (errGrpCtx , svc )
250
319
if err != nil {
251
320
return err
252
321
}
322
+
323
+ // cache all the containers on compose.up
324
+ d .containers [svc ] = target
325
+
253
326
return strategy .WaitUntilReady (errGrpCtx , target )
254
327
})
255
328
}
256
329
257
- return errGrp .Wait ()
330
+ return errGrpWait .Wait ()
258
331
}
259
332
260
333
func (d * dockerCompose ) WaitForService (s string , strategy wait.Strategy ) ComposeStack {
@@ -327,6 +400,34 @@ func (d *dockerCompose) lookupContainer(ctx context.Context, svcName string) (*t
327
400
return container , nil
328
401
}
329
402
403
+ func (d * dockerCompose ) lookupNetworks (ctx context.Context ) error {
404
+ d .containersLock .Lock ()
405
+ defer d .containersLock .Unlock ()
406
+
407
+ listOptions := dockertypes.NetworkListOptions {
408
+ Filters : filters .NewArgs (
409
+ filters .Arg ("label" , fmt .Sprintf ("%s=%s" , api .ProjectLabel , d .name )),
410
+ ),
411
+ }
412
+
413
+ networks , err := d .dockerClient .NetworkList (ctx , listOptions )
414
+ if err != nil {
415
+ return err
416
+ }
417
+
418
+ for _ , n := range networks {
419
+ dn := & testcontainers.DockerNetwork {
420
+ ID : n .ID ,
421
+ Name : n .Name ,
422
+ Driver : n .Driver ,
423
+ }
424
+
425
+ d .networks [n .ID ] = dn
426
+ }
427
+
428
+ return nil
429
+ }
430
+
330
431
func (d * dockerCompose ) compileProject (ctx context.Context ) (* types.Project , error ) {
331
432
const nameAndDefaultConfigPath = 2
332
433
projectOptions := make ([]cli.ProjectOptionsFn , len (d .projectOptions ), len (d .projectOptions )+ nameAndDefaultConfigPath )
@@ -353,6 +454,11 @@ func (d *dockerCompose) compileProject(ctx context.Context) (*types.Project, err
353
454
api .ConfigFilesLabel : strings .Join (proj .ComposeFiles , "," ),
354
455
api .OneoffLabel : "False" , // default, will be overridden by `run` command
355
456
}
457
+
458
+ for k , label := range testcontainers .GenericLabels () {
459
+ s .CustomLabels [k ] = label
460
+ }
461
+
356
462
for i , envFile := range compiledOptions .EnvFiles {
357
463
// add a label for each env file, indexed by its position
358
464
s .CustomLabels [fmt .Sprintf ("%s.%d" , api .EnvironmentFileLabel , i )] = envFile
@@ -361,6 +467,20 @@ func (d *dockerCompose) compileProject(ctx context.Context) (*types.Project, err
361
467
proj .Services [i ] = s
362
468
}
363
469
470
+ for key , n := range proj .Networks {
471
+ n .Labels = map [string ]string {
472
+ api .ProjectLabel : proj .Name ,
473
+ api .NetworkLabel : n .Name ,
474
+ api .VersionLabel : api .ComposeVersion ,
475
+ }
476
+
477
+ for k , label := range testcontainers .GenericLabels () {
478
+ n .Labels [k ] = label
479
+ }
480
+
481
+ proj .Networks [key ] = n
482
+ }
483
+
364
484
return proj , nil
365
485
}
366
486
0 commit comments