def detect_multi_node_defaults() -> tuple[dict[str, Any], dict[str, Any]]:
    """Build argparse kwargs for the node and ranks-per-node options.

    The defaults are taken from the launcher environment when one can be
    detected, trying Open MPI (``OMPI_*``), then MVAPICH2 (``MV2_*``), then
    SLURM. A launcher is only used when *both* of its variables are set to
    non-empty values; a partially populated environment is ignored so that
    the next launcher (and finally the configured defaults) is consulted
    instead of failing on an unassigned local.

    Returns
    -------
    tuple[dict[str, Any], dict[str, Any]]
        Copies of ``NODES.kwargs`` and ``RANKS_PER_NODE.kwargs`` with
        ``"default"`` replaced and, when a launcher was detected, a note
        appended to ``"help"`` naming the source of the defaults.

    Raises
    ------
    ValueError
        If a detected value is not an integer, if the divisor is not
        positive, or if the rank count does not divide evenly.
    """
    from os import getenv

    # Copy so the shared argument specifications are never mutated.
    nodes_kw = dict(NODES.kwargs)
    ranks_per_node_kw = dict(RANKS_PER_NODE.kwargs)

    # Start from the configured defaults; only a complete, consistent
    # launcher environment overrides them.
    nodes = defaults.LEGATE_NODES
    ranks_per_node = defaults.LEGATE_RANKS_PER_NODE
    where = None

    # The MPI launchers export the total rank count and the per-node count.
    for prefix in ("OMPI", "MV2"):
        ranks_env = getenv(f"{prefix}_COMM_WORLD_SIZE")
        local_env = getenv(f"{prefix}_COMM_WORLD_LOCAL_SIZE")
        if not (ranks_env and local_env):
            continue

        ranks, local_ranks = int(ranks_env), int(local_env)
        # Guard the divisor so a "0" reports the same ValueError as any
        # other inconsistent pair rather than a ZeroDivisionError.
        if local_ranks <= 0 or ranks % local_ranks != 0:
            raise ValueError(
                "Detected incompatible ranks and ranks-per-node from "
                "the environment"
            )
        nodes, ranks_per_node = ranks // local_ranks, local_ranks
        where = prefix
        break
    else:
        # No MPI launcher matched. SLURM exports the node count and the
        # total task count instead.
        nodes_env = getenv("SLURM_JOB_NUM_NODES")
        ranks_env = getenv("SLURM_NTASKS")
        if nodes_env and ranks_env:
            slurm_nodes, ranks = int(nodes_env), int(ranks_env)
            if slurm_nodes <= 0 or ranks % slurm_nodes != 0:
                raise ValueError(
                    "Detected incompatible nodes and ranks from the "
                    "environment"
                )
            nodes, ranks_per_node = slurm_nodes, ranks // slurm_nodes
            where = "SLURM"

    nodes_kw["default"] = nodes
    ranks_per_node_kw["default"] = ranks_per_node

    if where:
        extra = f" [default auto-detected from {where}]"
        nodes_kw["help"] += extra
        ranks_per_node_kw["help"] += extra

    return nodes_kw, ranks_per_node_kw
class TestMultiNodeDefaults:
    """Tests for ``detect_multi_node_defaults`` environment detection."""

    @pytest.fixture(autouse=True)
    def _clean_launcher_env(self, monkeypatch: pytest.MonkeyPatch) -> None:
        # Remove every variable the detection reads so the outcome does not
        # depend on how the test session itself was launched (for example
        # from inside an mpirun or SLURM allocation). Each test then sets
        # only the variables it needs on the same monkeypatch instance.
        for name in (
            "OMPI_COMM_WORLD_SIZE",
            "OMPI_COMM_WORLD_LOCAL_SIZE",
            "MV2_COMM_WORLD_SIZE",
            "MV2_COMM_WORLD_LOCAL_SIZE",
            "SLURM_JOB_NUM_NODES",
            "SLURM_NTASKS",
        ):
            monkeypatch.delenv(name, raising=False)

    def test_with_no_env(self) -> None:
        node_kw, ranks_per_node_kw = m.detect_multi_node_defaults()

        assert node_kw["default"] == defaults.LEGATE_NODES
        assert "auto-detected" not in node_kw["help"]

        assert ranks_per_node_kw["default"] == defaults.LEGATE_RANKS_PER_NODE
        assert "auto-detected" not in ranks_per_node_kw["help"]

    def test_with_OMPI(self, monkeypatch: pytest.MonkeyPatch) -> None:
        monkeypatch.setenv("OMPI_COMM_WORLD_SIZE", "6")
        monkeypatch.setenv("OMPI_COMM_WORLD_LOCAL_SIZE", "2")

        node_kw, ranks_per_node_kw = m.detect_multi_node_defaults()

        assert node_kw["default"] == 3
        assert "OMPI" in node_kw["help"]

        assert ranks_per_node_kw["default"] == 2
        assert "OMPI" in ranks_per_node_kw["help"]

    def test_with_OMPI_bad(self, monkeypatch: pytest.MonkeyPatch) -> None:
        # 5 ranks cannot be spread evenly over 3 ranks per node
        monkeypatch.setenv("OMPI_COMM_WORLD_SIZE", "5")
        monkeypatch.setenv("OMPI_COMM_WORLD_LOCAL_SIZE", "3")

        with pytest.raises(ValueError):
            m.detect_multi_node_defaults()

    def test_with_MV2(self, monkeypatch: pytest.MonkeyPatch) -> None:
        monkeypatch.setenv("MV2_COMM_WORLD_SIZE", "6")
        monkeypatch.setenv("MV2_COMM_WORLD_LOCAL_SIZE", "2")

        node_kw, ranks_per_node_kw = m.detect_multi_node_defaults()

        assert node_kw["default"] == 3
        assert "MV2" in node_kw["help"]

        assert ranks_per_node_kw["default"] == 2
        assert "MV2" in ranks_per_node_kw["help"]

    def test_with_MV2_bad(self, monkeypatch: pytest.MonkeyPatch) -> None:
        # 5 ranks cannot be spread evenly over 3 ranks per node
        monkeypatch.setenv("MV2_COMM_WORLD_SIZE", "5")
        monkeypatch.setenv("MV2_COMM_WORLD_LOCAL_SIZE", "3")

        with pytest.raises(ValueError):
            m.detect_multi_node_defaults()

    def test_with_SLURM(self, monkeypatch: pytest.MonkeyPatch) -> None:
        monkeypatch.setenv("SLURM_NTASKS", "6")
        monkeypatch.setenv("SLURM_JOB_NUM_NODES", "2")

        node_kw, ranks_per_node_kw = m.detect_multi_node_defaults()

        assert node_kw["default"] == 2
        assert "SLURM" in node_kw["help"]

        assert ranks_per_node_kw["default"] == 3
        assert "SLURM" in ranks_per_node_kw["help"]

    def test_with_SLURM_bad(self, monkeypatch: pytest.MonkeyPatch) -> None:
        # 5 tasks cannot be spread evenly over 3 nodes
        monkeypatch.setenv("SLURM_NTASKS", "5")
        monkeypatch.setenv("SLURM_JOB_NUM_NODES", "3")

        with pytest.raises(ValueError):
            m.detect_multi_node_defaults()

    # test same as no_env -- auto-detect for PMI is unsupported
    def test_with_PMI(self, monkeypatch: pytest.MonkeyPatch) -> None:
        monkeypatch.setenv("PMI_SIZE", "6")
        node_kw, ranks_per_node_kw = m.detect_multi_node_defaults()

        assert node_kw["default"] == defaults.LEGATE_NODES
        assert "auto-detected" not in node_kw["help"]

        assert ranks_per_node_kw["default"] == defaults.LEGATE_RANKS_PER_NODE
        assert "auto-detected" not in ranks_per_node_kw["help"]