Fix lit tests
peterbell10 committed Mar 20, 2023
1 parent 4326bb8 commit cad9967
Showing 5 changed files with 88 additions and 28 deletions.
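
The tests are updated for the region-based form of tt.reduce: the old op named its combiner through an integer redOp attribute, while the new form carries the combiner as a region terminated by tt.reduce.return. A minimal sketch of the migration for a float sum (the shapes here are illustrative, not taken from any one test):

  // Old attribute-based form: redOp selects the combiner.
  %r = tt.reduce %x {redOp = 1 : i32, axis = 0 : i32} : tensor<8x4xf32> -> tensor<4xf32>

  // New region-based form: the combiner is spelled out as a block
  // whose arguments are the two values being combined.
  %r = "tt.reduce" (%x) ({
  ^bb0(%lhs: f32, %rhs: f32):
    %add = arith.addf %lhs, %rhs : f32
    tt.reduce.return %add : f32
  }) {axis = 0 : i32} : (tensor<8x4xf32>) -> tensor<4xf32>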
6 changes: 5 additions & 1 deletion test/Analysis/test-allocation.mlir
@@ -170,7 +170,11 @@ func.func @alloc(%A : !tt.ptr<f16>) {
func.func @scratch() {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
// CHECK: scratch offset = 0, size = 512
-  %b = tt.reduce %cst0 {redOp = 1 : i32, axis = 0 : i32} : tensor<16x16xf16, #AL> -> tensor<16xf16, #sliceAd0>
+  %b = "tt.reduce" (%cst0) ({
+  ^bb0(%arg0: f16, %arg1: f16):
+    %add = arith.addf %arg0, %arg1 : f16
+    tt.reduce.return %add : f16
+  }) {axis = 0 : i32} : (tensor<16x16xf16, #AL>) -> tensor<16xf16, #sliceAd0>
return
// CHECK-NEXT: size = 512
}
12 changes: 8 additions & 4 deletions test/Analysis/test-membar.mlir
@@ -79,7 +79,11 @@ func.func @scratch() {
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.convert_layout
%1 = triton_gpu.convert_layout %0 : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL>
-  %2 = tt.reduce %1 {redOp = 1 : i32, axis = 0 : i32} : tensor<32x16xf16, #AL> -> tensor<16xf16, #sliceAd0>
+  %2 = "tt.reduce" (%1) ({
+  ^bb0(%arg1: f16, %arg2: f16):
+    %add = arith.addf %arg1, %arg2 : f16
+    tt.reduce.return %add : f16
+  }) {axis = 0 : i32} : (tensor<32x16xf16, #AL>) -> tensor<16xf16, #sliceAd0>
return
}

@@ -417,7 +421,7 @@ func.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK: gpu.barrier
-  %c_blocked = triton_gpu.convert_layout %c_shared_init : (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #AL>
+  %c_blocked = triton_gpu.convert_layout %c_shared_init : (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #AL>

%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
%c_shared_next_next = scf.if %i1 -> tensor<128x32xf16, #A_SHARED> {
@@ -429,13 +433,13 @@
%c_shared_ = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (tensor<128x32xf16, #A_SHARED>) {
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.convert_layout
-      %c_blocked_next = triton_gpu.convert_layout %c_shared_next : (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #AL>
+      %c_blocked_next = triton_gpu.convert_layout %c_shared_next : (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #AL>
scf.yield %c_shared : tensor<128x32xf16, #A_SHARED>
}
scf.yield %c_shared_ : tensor<128x32xf16, #A_SHARED>
}
// CHECK-NOT: gpu.barrier
-    %b_blocked_next = triton_gpu.convert_layout %b_shared: (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #AL>
+    %b_blocked_next = triton_gpu.convert_layout %b_shared: (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #AL>
scf.yield %a_shared, %b_shared, %c_shared_next_next : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
}
return
48 changes: 36 additions & 12 deletions test/Conversion/triton_ops.mlir
@@ -79,18 +79,42 @@ func.func @load_store_ops_scalar(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}
func.func @reduce_ops_infer(%ptr: !tt.ptr<f32>, %v : tensor<1x2x4xf32>) {
// Test if reduce ops infer types correctly

-  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<2x4xf32>
-  %a = tt.reduce %v {redOp = 1 : i32, axis = 0 : i32} : tensor<1x2x4xf32> -> tensor<2x4xf32>
-  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1x4xf32>
-  %b = tt.reduce %v {redOp = 1 : i32, axis = 1 : i32} : tensor<1x2x4xf32> -> tensor<1x4xf32>
-  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1x2xf32>
-  %c = tt.reduce %v {redOp = 1 : i32, axis = 2 : i32} : tensor<1x2x4xf32> -> tensor<1x2xf32>
-  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1xf32>
-  %e = tt.reduce %b {redOp = 1 : i32, axis = 1 : i32} : tensor<1x4xf32> -> tensor<1xf32>
-  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<4xf32>
-  %f = tt.reduce %a {redOp = 1 : i32, axis = 0 : i32} : tensor<2x4xf32> -> tensor<4xf32>
-  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> f32
-  %g = tt.reduce %f {redOp = 1 : i32, axis = 0 : i32} : tensor<4xf32> -> f32
+  // CHECK: }) {axis = 0 : i32} : (tensor<1x2x4xf32>) -> tensor<2x4xf32>
+  %a = "tt.reduce" (%v) ({
+  ^bb0(%arg0: f32, %arg1: f32):
+    %add = arith.addf %arg0, %arg1 : f32
+    tt.reduce.return %add : f32
+  }) {axis = 0 : i32} : (tensor<1x2x4xf32>) -> tensor<2x4xf32>
+  // CHECK: }) {axis = 1 : i32} : (tensor<1x2x4xf32>) -> tensor<1x4xf32>
+  %b = "tt.reduce" (%v) ({
+  ^bb0(%arg0: f32, %arg1: f32):
+    %add = arith.addf %arg0, %arg1 : f32
+    tt.reduce.return %add : f32
+  }) {axis = 1 : i32} : (tensor<1x2x4xf32>) -> tensor<1x4xf32>
+  // CHECK: }) {axis = 2 : i32} : (tensor<1x2x4xf32>) -> tensor<1x2xf32>
+  %c = "tt.reduce" (%v) ({
+  ^bb0(%arg0: f32, %arg1: f32):
+    %add = arith.addf %arg0, %arg1 : f32
+    tt.reduce.return %add : f32
+  }) {axis = 2 : i32} : (tensor<1x2x4xf32>) -> tensor<1x2xf32>
+  // CHECK: }) {axis = 1 : i32} : (tensor<1x4xf32>) -> tensor<1xf32>
+  %e = "tt.reduce" (%b) ({
+  ^bb0(%arg0: f32, %arg1: f32):
+    %add = arith.addf %arg0, %arg1 : f32
+    tt.reduce.return %add : f32
+  }) {axis = 1 : i32} : (tensor<1x4xf32>) -> tensor<1xf32>
+  // CHECK: }) {axis = 0 : i32} : (tensor<2x4xf32>) -> tensor<4xf32>
+  %f = "tt.reduce" (%a) ({
+  ^bb0(%arg0: f32, %arg1: f32):
+    %add = arith.addf %arg0, %arg1 : f32
+    tt.reduce.return %add : f32
+  }) {axis = 0 : i32} : (tensor<2x4xf32>) -> tensor<4xf32>
+  // CHECK: }) {axis = 0 : i32} : (tensor<4xf32>) -> f32
+  %g = "tt.reduce" (%f) ({
+  ^bb0(%arg0: f32, %arg1: f32):
+    %add = arith.addf %arg0, %arg1 : f32
+    tt.reduce.return %add : f32
+  }) {axis = 0 : i32} : (tensor<4xf32>) -> f32

// Avoid optimizations for c, e, and g
%ptr1x2 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<1x2x!tt.ptr<f32>>
32 changes: 24 additions & 8 deletions test/Conversion/triton_to_tritongpu.mlir
@@ -40,14 +40,30 @@ func.func @reduce_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
%c0 = arith.constant dense<1.00e+00> : tensor<4x4xf32>
%c1 = arith.constant dense<2.00e+00> : tensor<8x2xf32>
%c2 = arith.constant dense<3.00e+00> : tensor<16x16xf32>
-  // CHECK: tensor<4x4xf32, #[[blocked0]]> -> tensor<4xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked0]]}>>
-  %c0_ = tt.reduce %c0 {redOp = 1 : i32, axis = 0 : i32} : tensor<4x4xf32> -> tensor<4xf32>
-  // CHECK: tensor<8x2xf32, #[[blocked1]]> -> tensor<2xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked1]]}>
-  %c1_ = tt.reduce %c1 {redOp = 1 : i32, axis = 0 : i32} : tensor<8x2xf32> -> tensor<2xf32>
-  // CHECK: tensor<8x2xf32, #[[blocked1]]> -> tensor<8xf32, #triton_gpu.slice<{dim = 1, parent = #[[blocked1]]}>>
-  %c2_ = tt.reduce %c1 {redOp = 1 : i32, axis = 1 : i32} : tensor<8x2xf32> -> tensor<8xf32>
-  // CHECK: tensor<16x16xf32, #[[blocked2]]> -> tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked2]]}>>
-  %c3_ = tt.reduce %c2 {redOp = 1 : i32, axis = 0 : i32} : tensor<16x16xf32> -> tensor<16xf32>
+  // CHECK: (tensor<4x4xf32, #[[blocked0]]>) -> tensor<4xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked0]]}>>
+  %c0_ = "tt.reduce" (%c0) ({
+  ^bb0(%arg1: f32, %arg2: f32):
+    %add = arith.addf %arg1, %arg2 : f32
+    tt.reduce.return %add : f32
+  }) {axis = 0 : i32} : (tensor<4x4xf32>) -> tensor<4xf32>
+  // CHECK: (tensor<8x2xf32, #[[blocked1]]>) -> tensor<2xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked1]]}>
+  %c1_ = "tt.reduce" (%c1) ({
+  ^bb0(%arg3: f32, %arg4: f32):
+    %add = arith.addf %arg3, %arg4 : f32
+    tt.reduce.return %add : f32
+  }) {axis = 0 : i32} : (tensor<8x2xf32>) -> tensor<2xf32>
+  // CHECK: (tensor<8x2xf32, #[[blocked1]]>) -> tensor<8xf32, #triton_gpu.slice<{dim = 1, parent = #[[blocked1]]}>>
+  %c2_ = "tt.reduce" (%c1) ({
+  ^bb0(%arg5: f32, %arg6: f32):
+    %add = arith.addf %arg5, %arg6 : f32
+    tt.reduce.return %add : f32
+  }) {axis = 1 : i32} : (tensor<8x2xf32>) -> tensor<8xf32>
+  // CHECK: (tensor<16x16xf32, #[[blocked2]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked2]]}>>
+  %c3_ = "tt.reduce" (%c2) ({
+  ^bb0(%arg7: f32, %arg8: f32):
+    %add = arith.addf %arg7, %arg8 : f32
+    tt.reduce.return %add : f32
+  }) {axis = 0 : i32} : (tensor<16x16xf32>) -> tensor<16xf32>

return
}
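
As the CHECK captures in this file expect, conversion to TritonGPU gives each reduce result a slice layout derived from the operand's blocked layout: reducing along axis d of a tensor encoded with #blocked yields #triton_gpu.slice<{dim = d, parent = #blocked}>. A sketch of the converted form, assuming a #blocked0 layout defined in the test prelude (which this hunk does not show):

  // An axis-0 reduce of a 4x4 blocked tensor yields the slice layout for dim 0.
  %s = "tt.reduce" (%t) ({
  ^bb0(%a: f32, %b: f32):
    %add = arith.addf %a, %b : f32
    tt.reduce.return %add : f32
  }) {axis = 0 : i32} : (tensor<4x4xf32, #blocked0>) -> tensor<4xf32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>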
18 changes: 15 additions & 3 deletions test/TritonGPU/combine.mlir
@@ -772,7 +772,11 @@ func.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1:
%27 = "triton_gpu.cmpf"(%cst_2, %26) {predicate = 4 : i64} : (tensor<16x16xf32, #blocked2>, tensor<16x16xf32, #blocked2>) -> tensor<16x16xi1, #blocked2>
%28 = arith.andi %22, %27 : tensor<16x16xi1, #blocked2>
%29 = "triton_gpu.select"(%28, %26, %cst_2) : (tensor<16x16xi1, #blocked2>, tensor<16x16xf32, #blocked2>, tensor<16x16xf32, #blocked2>) -> tensor<16x16xf32, #blocked2>
-  %30 = tt.reduce %29 {axis = 1 : i32, redOp = 12 : i32} : tensor<16x16xf32, #blocked2> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
+  %30 = "tt.reduce" (%29) ({
+  ^bb0(%arg4: f32, %arg5: f32):
+    %max = arith.maxf %arg4, %arg5 : f32
+    tt.reduce.return %max : f32
+  }) {axis = 1 : i32} : (tensor<16x16xf32, #blocked2>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%31 = triton_gpu.convert_layout %30 : (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16xf32, #blocked0>
%32 = triton_gpu.convert_layout %31 : (tensor<16xf32, #blocked0>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%33 = tt.expand_dims %32 {axis = 1 : i32} : (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<16x1xf32, #blocked1>
@@ -788,7 +792,11 @@ func.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1:
%43 = math.exp %42 : tensor<16x16xf32, #blocked2>
%44 = arith.addf %36, %43 : tensor<16x16xf32, #blocked2>
%45 = "triton_gpu.select"(%22, %44, %36) : (tensor<16x16xi1, #blocked2>, tensor<16x16xf32, #blocked2>, tensor<16x16xf32, #blocked2>) -> tensor<16x16xf32, #blocked2>
-  %46 = tt.reduce %45 {axis = 1 : i32, redOp = 2 : i32} : tensor<16x16xf32, #blocked2> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
+  %46 = "tt.reduce" (%45) ({
+  ^bb0(%arg4: f32, %arg5: f32):
+    %add = arith.addf %arg4, %arg5 : f32
+    tt.reduce.return %add : f32
+  }) {axis = 1 : i32} : (tensor<16x16xf32, #blocked2>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%47 = triton_gpu.convert_layout %46 : (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16xf32, #blocked0>
%48 = triton_gpu.convert_layout %47 : (tensor<16xf32, #blocked0>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%49 = tt.expand_dims %48 {axis = 1 : i32} : (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<16x1xf32, #blocked1>
@@ -892,7 +900,11 @@ func.func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !
%74 = "triton_gpu.select"(%54, %73, %arg7) : (tensor<64x64xi1, #blocked2>, tensor<64x64xf32, #blocked2>, tensor<64x64xf32, #blocked2>) -> tensor<64x64xf32, #blocked2>
scf.yield %74 : tensor<64x64xf32, #blocked2>
}
-  %26 = tt.reduce %25 {axis = 1 : i32, redOp = 2 : i32} : tensor<64x64xf32, #blocked2> -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
+  %26 = "tt.reduce" (%25) ({
+  ^bb0(%arg8: f32, %arg9: f32):
+    %add = arith.addf %arg8, %arg9 : f32
+    tt.reduce.return %add : f32
+  }) {axis = 1 : i32} : (tensor<64x64xf32, #blocked2>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%27 = triton_gpu.convert_layout %26 : (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<64xf32, #blocked0>
%28 = triton_gpu.convert_layout %27 : (tensor<64xf32, #blocked0>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%29 = tt.expand_dims %28 {axis = 1 : i32} : (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xf32, #blocked1>
