Skip to content
joejiong edited this page Dec 21, 2021 · 5 revisions

Welcome to the Paddle_AST_Infrastructure wiki!

// Identity map, used for the bounds of the loops that are not strip-mined.
#map0 = affine_map<(d0) -> (d0)>
// Number of 256-wide strips needed to cover d0 elements (rounded up).
#map1 = affine_map<(d0) -> (d0 ceildiv 256)>
module  {
  // 2-D convolution vectorized in strips of 256 elements along the output width.
  // %arg0 is the image, %arg1 the filter and %arg2 the output; the output is
  // read, updated with a fused multiply-add and written back in place.
  func @conv_2d(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c256 = arith.constant 256 : index
    %cst = arith.constant 0.000000e+00 : f32
    // All-zero vector, used as the pass-through value of the masked loads below.
    %0 = splat %cst : vector<256xf32>
    %1 = memref.dim %arg1, %c0 : memref<?x?xf32> // FH: filter height
    %2 = memref.dim %arg1, %c1 : memref<?x?xf32> // FW: filter width
    %3 = memref.dim %arg2, %c0 : memref<?x?xf32> // OH: output height
    %4 = memref.dim %arg2, %c1 : memref<?x?xf32> // OW: output width
    affine.for %arg3 = #map0(%c0) to #map0(%3) {       // %arg3: output row, 0..OH
      affine.for %arg4 = #map0(%c0) to #map0(%1) {     // %arg4: filter row, 0..FH
        affine.for %arg5 = #map0(%c0) to #map0(%2) {   // %arg5: filter column, 0..FW
          affine.for %arg6 = #map0(%c0) to #map1(%4) { // %arg6: strip index, 0..ceil(OW/256)
            // Load one filter coefficient and replicate it across all 256 lanes.
            %5 = affine.vector_load %arg1[%arg4, %arg5] : memref<?x?xf32>, vector<1xf32>
            %6 = vector.broadcast %5 : vector<1xf32> to vector<256xf32>
            // %7 = first output column of this strip, %8 = output columns left from there.
            %7 = arith.muli %arg6, %c256 : index
            %8 = arith.subi %4, %7 : index
            // A full strip is available when at least 256 columns remain.
            %9 = arith.cmpi sge, %8, %c256 : index
            scf.if %9 {
              // Full strip: unmasked 256-wide loads, FMA into the output, store back.
              %10 = affine.vector_load %arg0[%arg3 + %arg4, %arg5 + %arg6 * 256] : memref<?x?xf32>, vector<256xf32>
              %11 = affine.vector_load %arg2[%arg3, %arg6 * 256] : memref<?x?xf32>, vector<256xf32>
              %12 = vector.fma %10, %6, %11 : vector<256xf32>
              affine.vector_store %12, %arg2[%arg3, %arg6 * 256] : memref<?x?xf32>, vector<256xf32>
            } else {
              // Tail strip: fewer than 256 columns remain, so only the first %8
              // lanes are enabled; disabled lanes read as zero and are not stored.
              %10 = vector.create_mask %8 : vector<256xi1>
              // Same image indices as the full-strip branch, computed explicitly:
              // row = %arg3 + %arg4, column = %arg5 + %arg6 * 256.
              %11 = arith.addi %arg3, %arg4 : index
              %12 = arith.muli %arg6, %c256 : index
              %13 = arith.addi %arg5, %12 : index
              %14 = vector.maskedload %arg0[%11, %13], %10, %0 : memref<?x?xf32>, vector<256xi1>, vector<256xf32> into vector<256xf32>
              %15 = vector.maskedload %arg2[%arg3, %12], %10, %0 : memref<?x?xf32>, vector<256xi1>, vector<256xf32> into vector<256xf32>
              %16 = vector.fma %14, %6, %15 : vector<256xf32>
              vector.maskedstore %arg2[%arg3, %12], %10, %16 : memref<?x?xf32>, vector<256xi1>, vector<256xf32>
            }
          }
        }
      }
    }
    return
  }
}

Annotated version of the conv_2d IR:

./conv-opt ../../examples/conv-opt/conv2d.mlir -conv-vectorization="strip-mining=256" --print-ir-before-all --print-ir-after-change --color
#map0 = affine_map<(d0) -> (d0)>
#map1 = affine_map<(d0) -> (d0 ceildiv 256)>
module  {
  // 2-D convolution, strip-mined by 256 along the output width.
  //   %arg0 = image, %arg1 = filter, %arg2 = output (updated in place).
  // For every output row, filter tap and 256-wide strip:
  //   out[row, strip*256 ..] += img[row + frow, fcol + strip*256 ..] * filter[frow, fcol]
  // Every load and store in this listing is a full 256 lanes wide; there is no
  // special case for a last strip that has fewer than 256 output columns left.
  func @conv_2d(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index

    // Filter extents.
    %fh = memref.dim %arg1, %c0 : memref<?x?xf32> // FH
    %fw = memref.dim %arg1, %c1 : memref<?x?xf32> // FW

    // Output extents.
    %oh = memref.dim %arg2, %c0 : memref<?x?xf32> // OH
    %ow = memref.dim %arg2, %c1 : memref<?x?xf32> // OW

    affine.for %row = #map0(%c0) to #map0(%oh) {         // output row:    0 .. OH
      affine.for %frow = #map0(%c0) to #map0(%fh) {      // filter row:    0 .. FH
        affine.for %fcol = #map0(%c0) to #map0(%fw) {    // filter column: 0 .. FW
          affine.for %strip = #map0(%c0) to #map1(%ow) { // strip index:   0 .. ceil(OW / 256)

            // One filter coefficient, filter[frow, fcol] ...
            %tap = affine.vector_load %arg1[%frow, %fcol] : memref<?x?xf32>, vector<1xf32>

            // ... replicated into all 256 lanes.
            %tap_bcast = vector.broadcast %tap : vector<1xf32> to vector<256xf32>

            // 256 image elements: img[row + frow, fcol + strip*256 ..].
            %pixels = affine.vector_load %arg0[%row + %frow, %fcol + %strip * 256] : memref<?x?xf32>, vector<256xf32>

            // Current partial sums: out[row, strip*256 ..].
            %acc = affine.vector_load %arg2[%row, %strip * 256] : memref<?x?xf32>, vector<256xf32>

            // pixels * tap + acc, lane by lane.
            %sum = vector.fma %pixels, %tap_bcast, %acc : vector<256xf32>

            // Write the updated partial sums back to out[row, strip*256 ..].
            affine.vector_store %sum, %arg2[%row, %strip * 256] : memref<?x?xf32>, vector<256xf32>

          }
        }
      }
    }
    return
  }
}

Changes made for the pointwise_conv_2d_nhwc_hwcf_1x1 kernel:

#map0 = affine_map<(d0) -> (d0)>
#map1 = affine_map<(d0) -> (d0 ceildiv 256)>
module  {
  // Sketch of the pointwise (1x1) variant: the conv_2d loop nest wrapped in an
  // additional outer loop over the output-feature extent (OF).
  //   %arg0 = image, %arg1 = filter, %arg2 = output (updated in place).
  func @pointwise_conv_2d_nhwc_hwcf(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    // Index of the output-feature dimension queried below.
    %c2 = arith.constant 2 : index

    %0 = memref.dim %arg1, %c0 : memref<?x?xf32> // FH
    %1 = memref.dim %arg1, %c1 : memref<?x?xf32> // FW

    %2 = memref.dim %arg2, %c0 : memref<?x?xf32> // OH
    %3 = memref.dim %arg2, %c1 : memref<?x?xf32> // OW

    // NOTE(review): %arg2 is declared as a rank-2 memref, so it has no
    // dimension 2. The operand types need an extra (feature) dimension before
    // this query is meaningful -- TODO confirm the intended layout.
    %4 = memref.dim %arg2, %c2 : memref<?x?xf32> // OF

    // NOTE(review): %fc is not used by any index expression in the body, so
    // each iteration repeats the same accumulation into %arg2. The loads and
    // stores presumably need to be indexed by %fc once the operands carry a
    // feature dimension -- TODO confirm.
    affine.for %fc = #map0(%c0) to #map0(%4) {          // fc : 0-of
        affine.for %arg3 = #map0(%c0) to #map0(%2) {       // a3 : 0-oh
            affine.for %arg4 = #map0(%c0) to #map0(%0) {     // a4 : 0-fh
                affine.for %arg5 = #map0(%c0) to #map0(%1) {   // a5 : 0-fw
                    affine.for %arg6 = #map0(%c0) to #map1(%3) { // a6 : 0-up[ow/256]
                        // %5 = vector.load(filter[a4, a5]): one filter coefficient
                        %5 = affine.vector_load %arg1[%arg4, %arg5] : memref<?x?xf32>, vector<1xf32>

                        // %6 = vec.bcast(%5): the coefficient replicated 1 -> 256 lanes
                        %6 = vector.broadcast %5 : vector<1xf32> to vector<256xf32>

                        // %7 = vec.load256(img[a3 + a4, a5 + a6*256])
                        %7 = affine.vector_load %arg0[%arg3 + %arg4, %arg5 + %arg6 * 256] : memref<?x?xf32>, vector<256xf32>

                        // %8 = vec.load256(out[a3, a6*256]): current partial sums
                        %8 = affine.vector_load %arg2[%arg3, %arg6 * 256] : memref<?x?xf32>, vector<256xf32>

                        // %9 = %7 * %6 + %8, lane by lane
                        %9 = vector.fma %7, %6, %8 : vector<256xf32>

                        // out[a3, a6*256] = %9
                        affine.vector_store %9, %arg2[%arg3, %arg6 * 256] : memref<?x?xf32>, vector<256xf32>
                    }
                }
            }
        }
    }
    return
  }
}
Clone this wiki locally