Skip to content

Commit bb861cb

Browse files
belerico, awaelchli, carmocca, Borda
authored
Let TorchCollective work on the torch.distributed WORLD process group by default (#16995)
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
1 parent c886317 commit bb861cb

File tree

2 files changed

+58
-11
lines changed

2 files changed

+58
-11
lines changed

src/lightning/fabric/plugins/collectives/torch_collective.py

+19-9
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ def __init__(self) -> None:
3030
raise RuntimeError("Torch distributed is not available.")
3131
super().__init__()
3232

33+
@property
34+
def group(self) -> CollectibleGroup:
35+
if self._group is None:
36+
self._group = dist.GroupMember.WORLD
37+
return super().group
38+
3339
@property
3440
def rank(self) -> int:
3541
# local rank
@@ -138,17 +144,20 @@ def setup(self, main_address: Optional[str] = None, main_port: Optional[str] = N
138144
return self
139145

140146
def teardown(self) -> Self:
141-
non_group_member = self.group == dist.GroupMember.NON_GROUP_MEMBER
147+
group_member = self.group != dist.GroupMember.NON_GROUP_MEMBER
142148
super().teardown() # will destroy its own group
143149
# try to destroy the default group. this should only be done by a group member to avoid race conditions,
144150
# and only if the class is managing it
145-
if not non_group_member and TorchCollective.manages_default_group:
146-
default_group = dist.GroupMember.WORLD
147-
if default_group is not None: # not destroyed already
148-
group_map = dist.distributed_c10d._pg_map
149-
if len(group_map) == 1 and default_group in group_map: # only the default group is left
150-
self.destroy_group(default_group)
151-
TorchCollective.manages_default_group = False
151+
if (
152+
group_member
153+
and TorchCollective.manages_default_group
154+
and (default_group := dist.GroupMember.WORLD) is not None # not destroyed already
155+
and len(dist.distributed_c10d._pg_map) == 1 # only the default group is left
156+
):
157+
self.destroy_group(default_group)
158+
TorchCollective.manages_default_group = False
159+
elif TorchCollective.manages_default_group and dist.GroupMember.WORLD is None:
160+
TorchCollective.manages_default_group = False
152161
return self
153162

154163
@classmethod
@@ -171,7 +180,8 @@ def new_group(cls, **kwargs: Any) -> CollectibleGroup:
171180
def destroy_group(cls, group: CollectibleGroup) -> None:
172181
# can be called by all processes in the default group, group will be `object()` if they are not part of the
173182
# current group
174-
dist.destroy_process_group(group) # type: ignore[arg-type]
183+
if group in dist.distributed_c10d._pg_map:
184+
dist.destroy_process_group(group) # type: ignore[arg-type]
175185

176186
@classmethod
177187
def _convert_to_native_op(cls, op: Union[str, ReduceOp, RedOpType]) -> Union[ReduceOp, RedOpType]:

tests/tests_fabric/plugins/collectives/test_torch_collective.py

+39-2
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def test_convert_ops():
146146
def test_repeated_create_and_destroy():
147147
collective = TorchCollective()
148148
with mock.patch("torch.distributed.init_process_group"):
149-
collective.setup(main_address="foo", main_port=123)
149+
collective.setup(main_address="foo", main_port="123")
150150

151151
assert not os.environ
152152

@@ -157,7 +157,9 @@ def test_repeated_create_and_destroy():
157157
with pytest.raises(RuntimeError, match="TorchCollective` already owns a group"):
158158
collective.create_group()
159159

160-
with mock.patch("torch.distributed.destroy_process_group") as destroy_mock:
160+
with mock.patch.dict("torch.distributed.distributed_c10d._pg_map", {collective.group: ("", None)}), mock.patch(
161+
"torch.distributed.destroy_process_group"
162+
) as destroy_mock:
161163
collective.teardown()
162164
# this would be called twice if `init_process_group` wasn't patched. once for the group and once for the default
163165
# group
@@ -269,3 +271,38 @@ def _test_two_groups(strategy, left_collective, right_collective):
269271
@pytest.mark.skip(reason="TODO(carmocca): causing hangs in CI")
270272
def test_two_groups():
271273
collective_launch(_test_two_groups, [torch.device("cpu")] * 3, num_groups=2)
274+
275+
276+
def _test_default_process_group(strategy, *collectives):
277+
for collective in collectives:
278+
assert collective.group == torch.distributed.group.WORLD
279+
world_size = strategy.world_size
280+
for c in collectives:
281+
tensor = torch.tensor(world_size)
282+
r = c.all_reduce(tensor)
283+
assert world_size**2 == r
284+
285+
286+
@skip_distributed_unavailable
287+
@RunIf(skip_windows=True)
288+
@mock.patch.dict(os.environ, os.environ.copy(), clear=True) # sets CUDA_MODULE_LOADING in torch==1.13
289+
def test_default_process_group():
290+
collective_launch(_test_default_process_group, [torch.device("cpu")] * 3, num_groups=2)
291+
292+
293+
@skip_distributed_unavailable
294+
@mock.patch.dict(os.environ, {}, clear=True)
295+
def test_collective_manages_default_group():
296+
collective = TorchCollective()
297+
with mock.patch("torch.distributed.init_process_group"):
298+
collective.setup(main_address="foo", main_port="123")
299+
300+
assert TorchCollective.manages_default_group
301+
302+
with mock.patch.object(collective, "_group") as mock_group, mock.patch.dict(
303+
"torch.distributed.distributed_c10d._pg_map", {mock_group: ("", None)}
304+
), mock.patch("torch.distributed.destroy_process_group") as destroy_mock:
305+
collective.teardown()
306+
destroy_mock.assert_called_once_with(mock_group)
307+
308+
assert not TorchCollective.manages_default_group

0 commit comments

Comments (0)