File tree 8 files changed +49
-8
lines changed
8 files changed +49
-8
lines changed Original file line number Diff line number Diff line change @@ -21,14 +21,14 @@ jobs:
21
21
fail-fast : false
22
22
matrix :
23
23
include :
24
- - name: "python3.11-pytorch2.5.1-gpus1"
24
+ - name: "python3.11-pytorch2.6.0-gpus1"
25
25
gpu_num : 1
26
26
python_version : 3.11
27
- container: mosaicml/pytorch:2.5.1_cu124-python3.11-ubuntu20.04
28
- - name: "python3.11-pytorch2.5.1-gpus2"
27
+ container: mosaicml/pytorch:2.6.0_cu124-python3.11-ubuntu22.04
28
+ - name: "python3.11-pytorch2.6.0-gpus2"
29
29
gpu_num : 2
30
30
python_version : 3.11
31
- container: mosaicml/pytorch:2.5.1_cu124-python3.11-ubuntu20.04
31
+ container: mosaicml/pytorch:2.6.0_cu124-python3.11-ubuntu22.04
32
32
steps :
33
33
- name : Run PR GPU tests
34
34
uses : mosaicml/ci-testing/.github/actions/pytest-gpu@v0.1.2
Original file line number Diff line number Diff line change 32
32
additional_dependencies :
33
33
- toml
34
34
- repo : https://github.com/hadialqattan/pycln
35
- rev : v2.1.2
35
+ rev : v2.5.0
36
36
hooks :
37
37
- id : pycln
38
38
args : [. --all]
Original file line number Diff line number Diff line change @@ -73,6 +73,18 @@ class Arguments:
73
73
moe_zloss_in_fp32 : bool = False
74
74
75
75
def __post_init__ (self ):
76
+ # Sparse MLP is not supported with triton >=3.2.0
77
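+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ # NOTE(review): the check below compares triton.__version__ as a string, which is lexicographic ('3.10.0' >= '3.2.0' is False); consider comparing parsed versions (e.g. packaging.version.parse) instead.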
78
+ if self .__getattribute__ ('mlp_impl' ) == 'sparse' :
79
+ try :
80
+ import triton
81
+ if triton .__version__ >= '3.2.0' :
82
+ raise ValueError (
83
+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.' ,
84
+ )
85
+ except ImportError :
86
+ raise ImportError ('Triton is required for sparse MLP implementation' )
87
+
76
88
if self .__getattribute__ ('mlp_impl' ) == 'grouped' :
77
89
grouped_gemm .assert_grouped_gemm_is_available ()
78
90
Original file line number Diff line number Diff line change 3
3
4
4
# build requirements
5
5
[build-system ]
6
- requires = ["setuptools < 70.0.0", "torch >= 2.5.1, < 2.5.2"]
6
+ requires = ["setuptools < 70.0.0", "torch >= 2.6.0, < 2.6.1"]
7
7
build-backend = " setuptools.build_meta"
8
8
9
9
# Pytest
Original file line number Diff line number Diff line change 62
62
install_requires = [
63
63
'numpy>=1.21.5,<2.1.0' ,
64
64
'packaging>=21.3.0,<24.2' ,
65
- 'torch>=2.5.1,<2.5.2',
66
- 'triton>=2.1.0',
65
+ 'torch>=2.6.0,<2.6.1',
66
+ 'triton>=3.2.0,<3.3.0',
67
67
'stanford-stk==0.7.1' ,
68
68
]
69
69
Original file line number Diff line number Diff line change @@ -53,6 +53,16 @@ def construct_moes(
53
53
mlp_impl : str = 'sparse' ,
54
54
moe_zloss_weight : float = 0 ,
55
55
):
56
+ # All tests are skipped if triton >=3.2.0 is installed since sparse is not supported
57
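+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ # NOTE(review): the string comparison on triton.__version__ below is lexicographic ('3.10.0' >= '3.2.0' is False), so the skip would not trigger on triton 3.10+; consider comparing parsed versions.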
58
+ if mlp_impl == 'sparse' :
59
+ try :
60
+ import triton
61
+ if triton .__version__ >= '3.2.0' :
62
+ pytest .skip ('Sparse MLP is not supported with triton >=3.2.0' )
63
+ except ImportError :
64
+ pass
65
+
56
66
init_method = partial (torch .nn .init .normal_ , mean = 0.0 , std = 0.1 )
57
67
args = Arguments (
58
68
hidden_size = hidden_size ,
Original file line number Diff line number Diff line change @@ -23,6 +23,16 @@ def construct_dmoe_glu(
23
23
mlp_impl : str = 'sparse' ,
24
24
memory_optimized_mlp : bool = False ,
25
25
):
26
+ # All tests are skipped if triton >=3.2.0 is installed since sparse is not supported
27
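+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ # NOTE(review): the string comparison on triton.__version__ below is lexicographic ('3.10.0' >= '3.2.0' is False), so the skip would not trigger on triton 3.10+; consider comparing parsed versions.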
28
+ if mlp_impl == 'sparse' :
29
+ try :
30
+ import triton
31
+ if triton .__version__ >= '3.2.0' :
32
+ pytest .skip ('Sparse MLP is not supported with triton >=3.2.0' )
33
+ except ImportError :
34
+ pass
35
+
26
36
init_method = partial (torch .nn .init .normal_ , mean = 0.0 , std = 0.1 )
27
37
args = Arguments (
28
38
hidden_size = hidden_size ,
Original file line number Diff line number Diff line change @@ -41,6 +41,15 @@ def construct_moe(
41
41
moe_top_k : int = 1 ,
42
42
moe_zloss_weight : float = 0 ,
43
43
):
44
+ # All tests are skipped if triton >=3.2.0 is installed since sparse is not supported
45
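+ # TODO: Remove this once sparse is supported with triton >=3.2.0
+ # NOTE(review): the string comparison on triton.__version__ below is lexicographic ('3.10.0' >= '3.2.0' is False), so the skip would not trigger on triton 3.10+; consider comparing parsed versions.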
46
+ try :
47
+ import triton
48
+ if triton .__version__ >= '3.2.0' :
49
+ pytest .skip ('Sparse MLP is not supported with triton >=3.2.0' )
50
+ except ImportError :
51
+ pass
52
+
44
53
init_method = partial (torch .nn .init .normal_ , mean = 0.0 , std = 0.1 )
45
54
args = Arguments (
46
55
hidden_size = hidden_size ,
You can’t perform that action at this time.
0 commit comments