From 9b95dab3d836259e880f181715640fc23a65fc06 Mon Sep 17 00:00:00 2001 From: Bryan Van de Ven Date: Wed, 6 Sep 2023 12:39:55 -0700 Subject: [PATCH 1/2] auto-detect multi-node based on env vars --- legate/driver/args.py | 50 +++++++++++++++++++++- legate/util/shared_args.py | 3 +- tests/unit/legate/driver/test_args.py | 60 +++++++++++++++++++++++++++ 3 files changed, 109 insertions(+), 4 deletions(-) diff --git a/legate/driver/args.py b/legate/driver/args.py index a42e47b7a..73e1b36c4 100644 --- a/legate/driver/args.py +++ b/legate/driver/args.py @@ -17,6 +17,10 @@ from __future__ import annotations from argparse import REMAINDER, ArgumentDefaultsHelpFormatter, ArgumentParser +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Any from .. import __version__ from ..util.args import InfoAction @@ -42,6 +46,45 @@ __all__ = ("parser",) +def detect_multi_node_defaults() -> tuple[dict[str, Any], dict[str, Any]]: + from os import getenv + + nodes_kw = dict(NODES.kwargs) + ranks_per_node_kw = dict(RANKS_PER_NODE.kwargs) + where = None + + if nodes_env := getenv("OMPI_COMM_WORLD_SIZE"): + if ranks_per_node_env := getenv("OMPI_COMM_WORLD_LOCAL_SIZE"): + nodes, ranks_per_node = int(nodes_env), int(ranks_per_node_env) + where = "OMPI" + + elif nodes_env := getenv("MV2_COMM_WORLD_SIZE"): + if ranks_per_node_env := getenv("MV2_COMM_WORLD_LOCAL_SIZE"): + nodes, ranks_per_node = int(nodes_env), int(ranks_per_node_env) + where = "MV2" + + elif nodes_env := getenv("SLURM_JOB_NUM_NODES"): + if ranks_env := getenv("SLURM_NTASKS"): + nodes, ranks = int(nodes_env), int(ranks_env) + assert ranks % nodes == 0 + ranks_per_node = ranks // nodes + where = "SLURM" + + else: + nodes = defaults.LEGATE_NODES + ranks_per_node = defaults.LEGATE_RANKS_PER_NODE + + nodes_kw["default"] = nodes + ranks_per_node_kw["default"] = ranks_per_node + + if where: + extra = f" [default auto-detected from {where}]" + nodes_kw["help"] += extra + ranks_per_node_kw["help"] += extra + + return nodes_kw, ranks_per_node_kw + + parser = ArgumentParser( description="Legate Driver", allow_abbrev=False, @@ -56,9 +99,12 @@ "NOT used as arguments to legate itself.", ) +nodes_kw, ranks_per_node_kw = detect_multi_node_defaults() + + multi_node = parser.add_argument_group("Multi-node configuration") -multi_node.add_argument(NODES.name, **NODES.kwargs) -multi_node.add_argument(RANKS_PER_NODE.name, **RANKS_PER_NODE.kwargs) +multi_node.add_argument(NODES.name, **nodes_kw) +multi_node.add_argument(RANKS_PER_NODE.name, **ranks_per_node_kw) multi_node.add_argument(NOCR.name, **NOCR.kwargs) multi_node.add_argument(LAUNCHER.name, **LAUNCHER.kwargs) multi_node.add_argument(LAUNCHER_EXTRA.name, **LAUNCHER_EXTRA.kwargs) diff --git a/legate/util/shared_args.py b/legate/util/shared_args.py index ac3a561f6..8194c04b7 100644 --- a/legate/util/shared_args.py +++ b/legate/util/shared_args.py @@ -57,8 +57,7 @@ default=defaults.LEGATE_RANKS_PER_NODE, dest="ranks_per_node", help="Number of ranks (processes running copies of the program) to " - "launch per node. The default (1 rank per node) will typically result " - "in the best performance.", + "launch per node.", ), ) diff --git a/tests/unit/legate/driver/test_args.py b/tests/unit/legate/driver/test_args.py index 20627b247..795d0d012 100644 --- a/tests/unit/legate/driver/test_args.py +++ b/tests/unit/legate/driver/test_args.py @@ -16,6 +16,8 @@ from argparse import SUPPRESS +import pytest + import legate.driver.args as m import legate.driver.defaults as defaults @@ -198,3 +200,61 @@ def test_parser_epilog(self) -> None: def test_parser_description(self) -> None: assert m.parser.description == "Legate Driver" + + +class TestMultiNodeDefaults: + def test_with_no_env(self) -> None: + node_kw, ranks_per_node_kw = m.detect_multi_node_defaults() + + assert node_kw["default"] == defaults.LEGATE_NODES + assert "auto-detected" not in node_kw["help"] + + assert ranks_per_node_kw["default"] == defaults.LEGATE_RANKS_PER_NODE + assert "auto-detected" not in ranks_per_node_kw["help"] + + def test_with_OMPI(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("OMPI_COMM_WORLD_SIZE", "6") + monkeypatch.setenv("OMPI_COMM_WORLD_LOCAL_SIZE", "2") + + node_kw, ranks_per_node_kw = m.detect_multi_node_defaults() + + assert node_kw["default"] == 6 + assert "OMPI" in node_kw["help"] + + assert ranks_per_node_kw["default"] == 2 + assert "OMPI" in ranks_per_node_kw["help"] + + def test_with_MV2(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("MV2_COMM_WORLD_SIZE", "6") + monkeypatch.setenv("MV2_COMM_WORLD_LOCAL_SIZE", "2") + + node_kw, ranks_per_node_kw = m.detect_multi_node_defaults() + + assert node_kw["default"] == 6 + assert "MV2" in node_kw["help"] + + assert ranks_per_node_kw["default"] == 2 + assert "MV2" in ranks_per_node_kw["help"] + + def test_with_SLURM(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("SLURM_NTASKS", "6") + monkeypatch.setenv("SLURM_JOB_NUM_NODES", "2") + + node_kw, ranks_per_node_kw = m.detect_multi_node_defaults() + + assert node_kw["default"] == 2 + assert "SLURM" in node_kw["help"] + + assert ranks_per_node_kw["default"] == 3 + assert "SLURM" in ranks_per_node_kw["help"] + + # test same as no_env -- auto-detect for PMI is unsupported + def test_with_PMI(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("PMI_SIZE", "6") + node_kw, ranks_per_node_kw = m.detect_multi_node_defaults() + + assert node_kw["default"] == defaults.LEGATE_NODES + assert "auto-detected" not in node_kw["help"] + + assert ranks_per_node_kw["default"] == defaults.LEGATE_RANKS_PER_NODE + assert "auto-detected" not in ranks_per_node_kw["help"] From 2975cc7ab7304e8f068a2707096ced9c12e4486e Mon Sep 17 00:00:00 2001 From: Bryan Van de Ven Date: Wed, 6 Sep 2023 17:05:01 -0700 Subject: [PATCH 2/2] review comments --- legate/driver/args.py | 26 +++++++++++++++++++++----- legate/util/shared_args.py | 3 ++- tests/unit/legate/driver/test_args.py | 25 +++++++++++++++++++++++-- 3 files changed, 46 insertions(+), 8 deletions(-) diff --git a/legate/driver/args.py b/legate/driver/args.py index 73e1b36c4..573a9061d 100644 --- a/legate/driver/args.py +++ b/legate/driver/args.py @@ -53,20 +53,36 @@ def detect_multi_node_defaults() -> tuple[dict[str, Any], dict[str, Any]]: ranks_per_node_kw = dict(RANKS_PER_NODE.kwargs) where = None - if nodes_env := getenv("OMPI_COMM_WORLD_SIZE"): + if ranks_env := getenv("OMPI_COMM_WORLD_SIZE"): if ranks_per_node_env := getenv("OMPI_COMM_WORLD_LOCAL_SIZE"): - nodes, ranks_per_node = int(nodes_env), int(ranks_per_node_env) + ranks, ranks_per_node = int(ranks_env), int(ranks_per_node_env) + if ranks % ranks_per_node != 0: + raise ValueError( + "Detected incompatible ranks and ranks-per-node from " + "the environment" + ) + nodes = ranks // ranks_per_node where = "OMPI" - elif nodes_env := getenv("MV2_COMM_WORLD_SIZE"): + elif ranks_env := getenv("MV2_COMM_WORLD_SIZE"): if ranks_per_node_env := getenv("MV2_COMM_WORLD_LOCAL_SIZE"): - nodes, ranks_per_node = int(nodes_env), int(ranks_per_node_env) + ranks, ranks_per_node = int(ranks_env), int(ranks_per_node_env) + if ranks % ranks_per_node != 0: + raise ValueError( + "Detected incompatible ranks and ranks-per-node from " + "the environment" + ) + nodes = ranks // ranks_per_node where = "MV2" elif nodes_env := getenv("SLURM_JOB_NUM_NODES"): if ranks_env := getenv("SLURM_NTASKS"): nodes, ranks = int(nodes_env), int(ranks_env) - assert ranks % nodes == 0 + if ranks % nodes != 0: + raise ValueError( + "Detected incompatible nodes and ranks from the " + "environment" + ) ranks_per_node = ranks // nodes where = "SLURM" diff --git a/legate/util/shared_args.py b/legate/util/shared_args.py index 8194c04b7..84b6d3275 100644 --- a/legate/util/shared_args.py +++ b/legate/util/shared_args.py @@ -57,7 +57,8 @@ default=defaults.LEGATE_RANKS_PER_NODE, dest="ranks_per_node", help="Number of ranks (processes running copies of the program) to " - "launch per node.", + "launch per node. 1 rank per node will typically result in the best " + "performance.", ), ) diff --git a/tests/unit/legate/driver/test_args.py b/tests/unit/legate/driver/test_args.py index 795d0d012..999920e12 100644 --- a/tests/unit/legate/driver/test_args.py +++ b/tests/unit/legate/driver/test_args.py @@ -218,24 +218,38 @@ def test_with_OMPI(self, monkeypatch: pytest.MonkeyPatch) -> None: node_kw, ranks_per_node_kw = m.detect_multi_node_defaults() - assert node_kw["default"] == 6 + assert node_kw["default"] == 3 assert "OMPI" in node_kw["help"] assert ranks_per_node_kw["default"] == 2 assert "OMPI" in ranks_per_node_kw["help"] + def test_with_OMPI_bad(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("OMPI_COMM_WORLD_SIZE", "5") + monkeypatch.setenv("OMPI_COMM_WORLD_LOCAL_SIZE", "3") + + with pytest.raises(ValueError): + m.detect_multi_node_defaults() + def test_with_MV2(self, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("MV2_COMM_WORLD_SIZE", "6") monkeypatch.setenv("MV2_COMM_WORLD_LOCAL_SIZE", "2") node_kw, ranks_per_node_kw = m.detect_multi_node_defaults() - assert node_kw["default"] == 6 + assert node_kw["default"] == 3 assert "MV2" in node_kw["help"] assert ranks_per_node_kw["default"] == 2 assert "MV2" in ranks_per_node_kw["help"] + def test_with_MV2_bad(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("MV2_COMM_WORLD_SIZE", "5") + monkeypatch.setenv("MV2_COMM_WORLD_LOCAL_SIZE", "3") + + with pytest.raises(ValueError): + m.detect_multi_node_defaults() + def test_with_SLURM(self, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("SLURM_NTASKS", "6") monkeypatch.setenv("SLURM_JOB_NUM_NODES", "2") @@ -248,6 +262,13 @@ def test_with_SLURM(self, monkeypatch: pytest.MonkeyPatch) -> None: assert ranks_per_node_kw["default"] == 3 assert "SLURM" in ranks_per_node_kw["help"] + def test_with_SLURM_bad(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("SLURM_NTASKS", "5") + monkeypatch.setenv("SLURM_JOB_NUM_NODES", "3") + + with pytest.raises(ValueError): + m.detect_multi_node_defaults() + # test same as no_env -- auto-detect for PMI is unsupported def test_with_PMI(self, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("PMI_SIZE", "6")