Skip to content

Commit cf059d9

Browse files
authored
Updated pytorch and disabled sparse tests (#168)
* Updated pytorch and disabled sparse tests * added pull request target * udpated pull_request_target * made triton import inline and added try/catch * formatted * added note to use grouped instead
1 parent 75a2560 commit cf059d9

File tree

8 files changed

+49
-8
lines changed

8 files changed

+49
-8
lines changed

.github/workflows/pr-gpu.yaml

+4-4
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@ jobs:
2121
fail-fast: false
2222
matrix:
2323
include:
24-
- name: "python3.11-pytorch2.5.1-gpus1"
24+
- name: "python3.11-pytorch2.6.0-gpus1"
2525
gpu_num: 1
2626
python_version: 3.11
27-
container: mosaicml/pytorch:2.5.1_cu124-python3.11-ubuntu20.04
28-
- name: "python3.11-pytorch2.5.1-gpus2"
27+
container: mosaicml/pytorch:2.6.0_cu124-python3.11-ubuntu22.04
28+
- name: "python3.11-pytorch2.6.0-gpus2"
2929
gpu_num: 2
3030
python_version: 3.11
31-
container: mosaicml/pytorch:2.5.1_cu124-python3.11-ubuntu20.04
31+
container: mosaicml/pytorch:2.6.0_cu124-python3.11-ubuntu22.04
3232
steps:
3333
- name: Run PR GPU tests
3434
uses: mosaicml/ci-testing/.github/actions/pytest-gpu@v0.1.2

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ repos:
3232
additional_dependencies:
3333
- toml
3434
- repo: https://github.com/hadialqattan/pycln
35-
rev: v2.1.2
35+
rev: v2.5.0
3636
hooks:
3737
- id: pycln
3838
args: [. --all]

megablocks/layers/arguments.py

+12
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,18 @@ class Arguments:
7373
moe_zloss_in_fp32: bool = False
7474

7575
def __post_init__(self):
76+
# Sparse MLP is not supported with triton >=3.2.0
77+
# TODO: Remove this once sparse is supported with triton >=3.2.0
78+
if self.__getattribute__('mlp_impl') == 'sparse':
79+
try:
80+
import triton
81+
if triton.__version__ >= '3.2.0':
82+
raise ValueError(
83+
'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
84+
)
85+
except ImportError:
86+
raise ImportError('Triton is required for sparse MLP implementation')
87+
7688
if self.__getattribute__('mlp_impl') == 'grouped':
7789
grouped_gemm.assert_grouped_gemm_is_available()
7890

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
# build requirements
55
[build-system]
6-
requires = ["setuptools < 70.0.0", "torch >= 2.5.1, < 2.5.2"]
6+
requires = ["setuptools < 70.0.0", "torch >= 2.6.0, < 2.6.1"]
77
build-backend = "setuptools.build_meta"
88

99
# Pytest

setup.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@
6262
install_requires = [
6363
'numpy>=1.21.5,<2.1.0',
6464
'packaging>=21.3.0,<24.2',
65-
'torch>=2.5.1,<2.5.2',
66-
'triton>=2.1.0',
65+
'torch>=2.6.0,<2.6.1',
66+
'triton>=3.2.0,<3.3.0',
6767
'stanford-stk==0.7.1',
6868
]
6969

tests/layers/dmoe_test.py

+10
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,16 @@ def construct_moes(
5353
mlp_impl: str = 'sparse',
5454
moe_zloss_weight: float = 0,
5555
):
56+
# All tests are skipped if triton >=3.2.0 is installed since sparse is not supported
57+
# TODO: Remove this once sparse is supported with triton >=3.2.0
58+
if mlp_impl == 'sparse':
59+
try:
60+
import triton
61+
if triton.__version__ >= '3.2.0':
62+
pytest.skip('Sparse MLP is not supported with triton >=3.2.0')
63+
except ImportError:
64+
pass
65+
5666
init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1)
5767
args = Arguments(
5868
hidden_size=hidden_size,

tests/layers/glu_test.py

+10
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,16 @@ def construct_dmoe_glu(
2323
mlp_impl: str = 'sparse',
2424
memory_optimized_mlp: bool = False,
2525
):
26+
# All tests are skipped if triton >=3.2.0 is installed since sparse is not supported
27+
# TODO: Remove this once sparse is supported with triton >=3.2.0
28+
if mlp_impl == 'sparse':
29+
try:
30+
import triton
31+
if triton.__version__ >= '3.2.0':
32+
pytest.skip('Sparse MLP is not supported with triton >=3.2.0')
33+
except ImportError:
34+
pass
35+
2636
init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1)
2737
args = Arguments(
2838
hidden_size=hidden_size,

tests/layers/moe_test.py

+9
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,15 @@ def construct_moe(
4141
moe_top_k: int = 1,
4242
moe_zloss_weight: float = 0,
4343
):
44+
# All tests are skipped if triton >=3.2.0 is installed since sparse is not supported
45+
# TODO: Remove this once sparse is supported with triton >=3.2.0
46+
try:
47+
import triton
48+
if triton.__version__ >= '3.2.0':
49+
pytest.skip('Sparse MLP is not supported with triton >=3.2.0')
50+
except ImportError:
51+
pass
52+
4453
init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1)
4554
args = Arguments(
4655
hidden_size=hidden_size,

0 commit comments

Comments
 (0)