@@ -146,7 +146,7 @@ def test_convert_ops():
 def test_repeated_create_and_destroy():
     collective = TorchCollective()
     with mock.patch("torch.distributed.init_process_group"):
-        collective.setup(main_address="foo", main_port=123)
+        collective.setup(main_address="foo", main_port="123")

     assert not os.environ

@@ -157,7 +157,9 @@ def test_repeated_create_and_destroy():
     with pytest.raises(RuntimeError, match="TorchCollective` already owns a group"):
         collective.create_group()

-    with mock.patch("torch.distributed.destroy_process_group") as destroy_mock:
+    with mock.patch.dict("torch.distributed.distributed_c10d._pg_map", {collective.group: ("", None)}), mock.patch(
+        "torch.distributed.destroy_process_group"
+    ) as destroy_mock:
         collective.teardown()
     # this would be called twice if `init_process_group` wasn't patched. once for the group and once for the default
     # group
@@ -269,3 +271,38 @@ def _test_two_groups(strategy, left_collective, right_collective):
 @pytest.mark.skip(reason="TODO(carmocca): causing hangs in CI")
 def test_two_groups():
     collective_launch(_test_two_groups, [torch.device("cpu")] * 3, num_groups=2)
+
+
+def _test_default_process_group(strategy, *collectives):
+    for collective in collectives:
+        assert collective.group == torch.distributed.group.WORLD
+    world_size = strategy.world_size
+    for c in collectives:
+        tensor = torch.tensor(world_size)
+        r = c.all_reduce(tensor)
+        assert world_size**2 == r
+
+
+@skip_distributed_unavailable
+@RunIf(skip_windows=True)
+@mock.patch.dict(os.environ, os.environ.copy(), clear=True)  # sets CUDA_MODULE_LOADING in torch==1.13
+def test_default_process_group():
+    collective_launch(_test_default_process_group, [torch.device("cpu")] * 3, num_groups=2)
+
+
+@skip_distributed_unavailable
+@mock.patch.dict(os.environ, {}, clear=True)
+def test_collective_manages_default_group():
+    collective = TorchCollective()
+    with mock.patch("torch.distributed.init_process_group"):
+        collective.setup(main_address="foo", main_port="123")
+
+    assert TorchCollective.manages_default_group
+
+    with mock.patch.object(collective, "_group") as mock_group, mock.patch.dict(
+        "torch.distributed.distributed_c10d._pg_map", {mock_group: ("", None)}
+    ), mock.patch("torch.distributed.destroy_process_group") as destroy_mock:
+        collective.teardown()
+    destroy_mock.assert_called_once_with(mock_group)
+
+    assert not TorchCollective.manages_default_group
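
Taken together, the new tests pin down an ownership rule: a collective that had to initialize the default process group is also the one that tears it down, tracked via the class-level `manages_default_group` flag. Below is a minimal sketch of that rule; it is not the Lightning implementation. Only the `manages_default_group`, `setup`, and `teardown` names come from the diff above, while the `SketchCollective` class, the gloo backend, and the env-var handling are assumptions made for illustration.

import os

import torch.distributed as dist


class SketchCollective:
    # Class-level flag, mirroring how the test reads
    # `TorchCollective.manages_default_group` off the class, not the instance.
    manages_default_group = False

    def setup(self, main_address: str, main_port: str) -> None:
        # Rendezvous env vars must be strings, which is why the diff changes
        # `main_port=123` to `main_port="123"`.
        os.environ["MASTER_ADDR"] = main_address
        os.environ["MASTER_PORT"] = main_port
        try:
            if not dist.is_initialized():
                # We created the default group, so we own its lifetime.
                dist.init_process_group(backend="gloo")
                SketchCollective.manages_default_group = True
        finally:
            # `test_repeated_create_and_destroy` asserts `not os.environ`
            # after setup, so treat the variables as rendezvous-only
            # (an assumption in this sketch) and remove them again.
            os.environ.pop("MASTER_ADDR", None)
            os.environ.pop("MASTER_PORT", None)

    def teardown(self) -> None:
        if SketchCollective.manages_default_group and dist.is_initialized():
            # Destroy only what we created; a default group initialized by
            # someone else is not ours to tear down.
            dist.destroy_process_group()
            SketchCollective.manages_default_group = False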