// RUN: mlir-opt %s --transform-interpreter | FileCheck %s

// Tests for lowering vector.multi_reduction with the "innerreduction"
// strategy. The patterns are applied to every func.func by the transform
// sequence at the end of this file.

// Reduction over the innermost dim of a 2-D vector: one vector.reduction per
// row, each result inserted into the result vector.
func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
  %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
  return %0 : vector<2xf32>
}
// CHECK-LABEL: func @vector_multi_reduction
// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>)
// CHECK-DAG: %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<2xf32>
// CHECK: %[[V0:.+]] = vector.extract %[[INPUT]][0]
// CHECK: %[[ACC0:.+]] = vector.extract %[[ACC]][0]
// CHECK: %[[RV0:.+]] = vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<4xf32> into f32
// CHECK: %[[RESULT_VEC_1:.+]] = vector.insert %[[RV0]], %[[RESULT_VEC_0]] [0] : f32 into vector<2xf32>
// CHECK: %[[V1:.+]] = vector.extract %[[INPUT]][1]
// CHECK: %[[ACC1:.+]] = vector.extract %[[ACC]][1]
// CHECK: %[[RV1:.+]] = vector.reduction <mul>, %[[V1]], %[[ACC1]] : vector<4xf32> into f32
// CHECK: %[[RESULT_VEC:.+]] = vector.insert %[[RV1]], %[[RESULT_VEC_1]] [1] : f32 into vector<2xf32>
// CHECK: return %[[RESULT_VEC]]

// Reduction over all dims to a scalar: the input is flattened with
// vector.shape_cast and reduced by a single vector.reduction.
func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>, %acc: f32) -> f32 {
  %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<2x4xf32> to f32
  return %0 : f32
}
// CHECK-LABEL: func @vector_multi_reduction_to_scalar
// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: f32)
// CHECK: %[[CASTED:.*]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32>
// CHECK: %[[REDUCED:.*]] = vector.reduction <mul>, %[[CASTED]], %[[ACC]] : vector<8xf32> into f32
// CHECK: return %[[REDUCED]]

// Two inner reduction dims below two outer parallel dims: the input is
// reshaped to 6x20, reduced row by row, and the flat result is reshaped back
// to 2x3.
func.func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>, %acc: vector<2x3xi32>) -> vector<2x3xi32> {
  %0 = vector.multi_reduction <add>, %arg0, %acc [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
  return %0 : vector<2x3xi32>
}
// CHECK-LABEL: func @vector_reduction_inner
// CHECK-SAME: %[[INPUT:.+]]: vector<2x3x4x5xi32>, %[[ACC:.*]]: vector<2x3xi32>
// CHECK-DAG: %[[FLAT_RESULT_VEC_0:.+]] = arith.constant dense<0> : vector<6xi32>
// CHECK: %[[RESHAPED_INPUT:.+]] = vector.shape_cast %[[INPUT]] : vector<2x3x4x5xi32> to vector<6x20xi32>
// CHECK: %[[V0:.+]] = vector.extract %[[RESHAPED_INPUT]][0] : vector<20xi32> from vector<6x20xi32>
// CHECK: %[[ACC0:.+]] = vector.extract %[[ACC]][0, 0] : i32 from vector<2x3xi32>
// CHECK: %[[V0R:.+]] = vector.reduction <add>, %[[V0]], %[[ACC0]] : vector<20xi32> into i32
// CHECK: %[[FLAT_RESULT_VEC_1:.+]] = vector.insert %[[V0R]], %[[FLAT_RESULT_VEC_0]] [0] : i32 into vector<6xi32>
// CHECK: %[[V1:.+]] = vector.extract %[[RESHAPED_INPUT]][1] : vector<20xi32> from vector<6x20xi32>
// CHECK: %[[ACC1:.+]] = vector.extract %[[ACC]][0, 1] : i32 from vector<2x3xi32>
// CHECK: %[[V1R:.+]] = vector.reduction <add>, %[[V1]], %[[ACC1]] : vector<20xi32> into i32
// CHECK: %[[FLAT_RESULT_VEC_2:.+]] = vector.insert %[[V1R]], %[[FLAT_RESULT_VEC_1]] [1] : i32 into vector<6xi32>
// CHECK: %[[V2:.+]] = vector.extract %[[RESHAPED_INPUT]][2] : vector<20xi32> from vector<6x20xi32>
// CHECK: %[[ACC2:.+]] = vector.extract %[[ACC]][0, 2] : i32 from vector<2x3xi32>
// CHECK: %[[V2R:.+]] = vector.reduction <add>, %[[V2]], %[[ACC2]] : vector<20xi32> into i32
// CHECK: %[[FLAT_RESULT_VEC_3:.+]] = vector.insert %[[V2R]], %[[FLAT_RESULT_VEC_2]] [2] : i32 into vector<6xi32>
// CHECK: %[[V3:.+]] = vector.extract %[[RESHAPED_INPUT]][3] : vector<20xi32> from vector<6x20xi32>
// CHECK: %[[ACC3:.+]] = vector.extract %[[ACC]][1, 0] : i32 from vector<2x3xi32>
// CHECK: %[[V3R:.+]] = vector.reduction <add>, %[[V3]], %[[ACC3]] : vector<20xi32> into i32
// CHECK: %[[FLAT_RESULT_VEC_4:.+]] = vector.insert %[[V3R]], %[[FLAT_RESULT_VEC_3]] [3] : i32 into vector<6xi32>
// CHECK: %[[V4:.+]] = vector.extract %[[RESHAPED_INPUT]][4] : vector<20xi32> from vector<6x20xi32>
// CHECK: %[[ACC4:.+]] = vector.extract %[[ACC]][1, 1] : i32 from vector<2x3xi32>
// CHECK: %[[V4R:.+]] = vector.reduction <add>, %[[V4]], %[[ACC4]] : vector<20xi32> into i32
// CHECK: %[[FLAT_RESULT_VEC_5:.+]] = vector.insert %[[V4R]], %[[FLAT_RESULT_VEC_4]] [4] : i32 into vector<6xi32>
// CHECK: %[[V5:.+]] = vector.extract %[[RESHAPED_INPUT]][5] : vector<20xi32> from vector<6x20xi32>
// CHECK: %[[ACC5:.+]] = vector.extract %[[ACC]][1, 2] : i32 from vector<2x3xi32>
// CHECK: %[[V5R:.+]] = vector.reduction <add>, %[[V5]], %[[ACC5]] : vector<20xi32> into i32
// CHECK: %[[FLAT_RESULT_VEC:.+]] = vector.insert %[[V5R]], %[[FLAT_RESULT_VEC_5]] [5] : i32 into vector<6xi32>
// CHECK: %[[RESULT:.+]] = vector.shape_cast %[[FLAT_RESULT_VEC]] : vector<6xi32> to vector<2x3xi32>
// CHECK: return %[[RESULT]]

// Reduction dims [1, 2] are not innermost: a vector.transpose moves them to
// the inner positions before the input is flattened to 2-D.
func.func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>, %acc: vector<2x5xf32>) -> vector<2x5xf32> {
  %0 = vector.multi_reduction <add>, %arg0, %acc [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
  return %0 : vector<2x5xf32>
}

// CHECK-LABEL: func @vector_multi_reduction_transposed
// CHECK-SAME: %[[INPUT:.+]]: vector<2x3x4x5xf32>
// CHECK: %[[TRANSPOSED_INPUT:.+]] = vector.transpose %[[INPUT]], [0, 3, 1, 2] : vector<2x3x4x5xf32> to vector<2x5x3x4xf32>
// CHECK: vector.shape_cast %[[TRANSPOSED_INPUT]] : vector<2x5x3x4xf32> to vector<10x12xf32>
// CHECK: %[[RESULT:.+]] = vector.shape_cast %{{.*}} : vector<10xf32> to vector<2x5xf32>
// CHECK: return %[[RESULT]]

// Reduction over the outermost dim: verifies the order in which slices of the
// transposed input and elements of the accumulator are paired up.
func.func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>, %acc: vector<2x4xf32>) -> vector<2x4xf32> {
  %0 = vector.multi_reduction <mul>, %arg0, %acc [0] : vector<3x2x4xf32> to vector<2x4xf32>
  return %0 : vector<2x4xf32>
}
// CHECK-LABEL: func @vector_multi_reduction_ordering
// CHECK-SAME: %[[INPUT:.+]]: vector<3x2x4xf32>, %[[ACC:.*]]: vector<2x4xf32>)
// CHECK-DAG: %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<8xf32>
// CHECK: %[[TRANSPOSED_INPUT:.+]] = vector.transpose %[[INPUT]], [1, 2, 0] : vector<3x2x4xf32> to vector<2x4x3xf32>
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 0]
// CHECK: %[[ACC0:.+]] = vector.extract %[[ACC]][0, 0] : f32 from vector<2x4xf32>
// CHECK: %[[RV0:.+]] = vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<3xf32> into f32
// CHECK: %[[RESULT_VEC_1:.+]] = vector.insert %[[RV0]], %[[RESULT_VEC_0]] [0] : f32 into vector<8xf32>
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 1]
// CHECK: %[[ACC1:.+]] = vector.extract %[[ACC]][0, 1] : f32 from vector<2x4xf32>
// CHECK: %[[RV1:.+]] = vector.reduction <mul>, %[[V1]], %[[ACC1]] : vector<3xf32> into f32
// CHECK: %[[RESULT_VEC_2:.+]] = vector.insert %[[RV1]], %[[RESULT_VEC_1]] [1] : f32 into vector<8xf32>
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 2]
// CHECK: %[[ACC2:.+]] = vector.extract %[[ACC]][0, 2] : f32 from vector<2x4xf32>
// CHECK: %[[RV2:.+]] = vector.reduction <mul>, %[[V2]], %[[ACC2]] : vector<3xf32> into f32
// CHECK: %[[RESULT_VEC_3:.+]] = vector.insert %[[RV2]], %[[RESULT_VEC_2]] [2] : f32 into vector<8xf32>
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 3]
// CHECK: %[[ACC3:.+]] = vector.extract %[[ACC]][0, 3] : f32 from vector<2x4xf32>
// CHECK: %[[RV3:.+]] = vector.reduction <mul>, %[[V3]], %[[ACC3]] : vector<3xf32> into f32
// CHECK: %[[RESULT_VEC_4:.+]] = vector.insert %[[RV3]], %[[RESULT_VEC_3]] [3] : f32 into vector<8xf32>
// CHECK: %[[V4:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 0]
// CHECK: %[[ACC4:.+]] = vector.extract %[[ACC]][1, 0] : f32 from vector<2x4xf32>
// CHECK: %[[RV4:.+]] = vector.reduction <mul>, %[[V4]], %[[ACC4]] : vector<3xf32> into f32
// CHECK: %[[RESULT_VEC_5:.+]] = vector.insert %[[RV4]], %[[RESULT_VEC_4]] [4] : f32 into vector<8xf32>
// CHECK: %[[V5:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 1]
// CHECK: %[[ACC5:.+]] = vector.extract %[[ACC]][1, 1] : f32 from vector<2x4xf32>
// CHECK: %[[RV5:.+]] = vector.reduction <mul>, %[[V5]], %[[ACC5]] : vector<3xf32> into f32
// CHECK: %[[RESULT_VEC_6:.+]] = vector.insert %[[RV5]], %[[RESULT_VEC_5]] [5] : f32 into vector<8xf32>
// CHECK: %[[V6:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 2]
// CHECK: %[[ACC6:.+]] = vector.extract %[[ACC]][1, 2] : f32 from vector<2x4xf32>
// CHECK: %[[RV6:.+]] = vector.reduction <mul>, %[[V6]], %[[ACC6]] : vector<3xf32> into f32
// CHECK: %[[RESULT_VEC_7:.+]] = vector.insert %[[RV6]], %[[RESULT_VEC_6]] [6] : f32 into vector<8xf32>
// CHECK: %[[V7:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 3]
// CHECK: %[[ACC7:.+]] = vector.extract %[[ACC]][1, 3] : f32 from vector<2x4xf32>
// CHECK: %[[RV7:.+]] = vector.reduction <mul>, %[[V7]], %[[ACC7]] : vector<3xf32> into f32
// CHECK: %[[RESULT_VEC:.+]] = vector.insert %[[RV7]], %[[RESULT_VEC_7]] [7] : f32 into vector<8xf32>
// CHECK: %[[RESHAPED_VEC:.+]] = vector.shape_cast %[[RESULT_VEC]] : vector<8xf32> to vector<2x4xf32>
// CHECK: return %[[RESHAPED_VEC]]

// Masked reduction over the innermost dim of a dynamically-sized 2-D tensor
// read.
func.func @vectorize_dynamic_reduction(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
  %c0 = arith.constant 0 : index
  %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %c1 = arith.constant 1 : index
  %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  %c0_1 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %0 = vector.create_mask %dim, %dim_0 : vector<4x8xi1>
  %1 = vector.mask %0 { vector.transfer_read %arg0[%c0_1, %c0_1], %cst {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x8xf32> } : vector<4x8xi1> -> vector<4x8xf32>
  %cst_2 = arith.constant 0.000000e+00 : f32
  %2 = vector.create_mask %dim : vector<4xi1>
  %3 = vector.mask %2 { vector.transfer_read %arg1[%c0_1], %cst_2 {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
  %4 = vector.mask %0 { vector.multi_reduction <add>, %1, %3 [1] : vector<4x8xf32> to vector<4xf32> } : vector<4x8xi1> -> vector<4xf32>
  %c0_3 = arith.constant 0 : index
  %5 = vector.mask %2 { vector.transfer_write %4, %arg1[%c0_3] {in_bounds = [true]} : vector<4xf32>, tensor<?xf32> } : vector<4xi1> -> tensor<?xf32>
  return %5 : tensor<?xf32>
}

// Verify that the original 2-D mask is sliced and propagated properly to the
// vector.reduction instances.

// CHECK-LABEL: func.func @vectorize_dynamic_reduction
// CHECK: %[[VAL_8:.*]] = tensor.dim
// CHECK: %[[VAL_9:.*]] = tensor.dim
// CHECK: %[[VAL_10:.*]] = vector.create_mask %[[VAL_8]], %[[VAL_9]] : vector<4x8xi1>

// CHECK: %[[VAL_16:.*]] = vector.extract %[[VAL_10]][0] : vector<8xi1> from vector<4x8xi1>
// CHECK: %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
// CHECK: %[[VAL_18:.*]] = vector.insert

// CHECK: %[[VAL_21:.*]] = vector.extract %[[VAL_10]][1] : vector<8xi1> from vector<4x8xi1>
// CHECK: %[[VAL_22:.*]] = vector.mask %[[VAL_21]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
// CHECK: %[[VAL_23:.*]] = vector.insert

// CHECK: %[[VAL_26:.*]] = vector.extract %[[VAL_10]][2] : vector<8xi1> from vector<4x8xi1>
// CHECK: %[[VAL_27:.*]] = vector.mask %[[VAL_26]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
// CHECK: %[[VAL_28:.*]] = vector.insert

// CHECK: %[[VAL_31:.*]] = vector.extract %[[VAL_10]][3] : vector<8xi1> from vector<4x8xi1>
// CHECK: %[[VAL_32:.*]] = vector.mask %[[VAL_31]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
// CHECK: %[[VAL_33:.*]] = vector.insert

// Masked reduction of a 1-D vector to a scalar.
func.func @vectorize_1d_dynamic_reduction(%arg0: tensor<?xf32>) -> f32 {
  %c0 = arith.constant 0 : index
  %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
  %c0_1 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %0 = vector.create_mask %dim : vector<8xi1>
  %1 = vector.mask %0 { vector.transfer_read %arg0[%c0_1], %cst {in_bounds = [true]} : tensor<?xf32>, vector<8xf32> } : vector<8xi1> -> vector<8xf32>
  %4 = vector.mask %0 { vector.multi_reduction <add>, %1, %cst [0] : vector<8xf32> to f32 } : vector<8xi1> -> f32
  return %4 : f32
}

// Verify that a 1-D vector.multi_reduction is transformed into a vector.reduction.
// This transform expands 1-D vectors into 2-D.

// CHECK-LABEL: func.func @vectorize_1d_dynamic_reduction(
// CHECK: %[[VAL_5:.*]] = vector.create_mask {{.*}} : vector<8xi1>
// CHECK: %[[VAL_7:.*]] = vector.mask %[[VAL_5]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32

// Masked reduction over the outermost dim of a 3-D vector: the mask must be
// transposed the same way as the input before it is sliced.
func.func @vectorize_dynamic_transpose_reduction(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %dim = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
  %c1 = arith.constant 1 : index
  %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
  %c2 = arith.constant 2 : index
  %dim_1 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
  %c0_2 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %0 = vector.create_mask %dim, %dim_0, %dim_1 : vector<4x8x16xi1>
  %1 = vector.mask %0 { vector.transfer_read %arg0[%c0_2, %c0_2, %c0_2], %cst {in_bounds = [true, true, true]} : tensor<?x?x?xf32>, vector<4x8x16xf32> } : vector<4x8x16xi1> -> vector<4x8x16xf32>
  %cst_3 = arith.constant 0.000000e+00 : f32
  %2 = vector.create_mask %dim_1, %dim_0 : vector<16x8xi1>
  %3 = vector.mask %2 { vector.transfer_read %arg1[%c0_2, %c0_2], %cst_3 {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : tensor<?x?xf32>, vector<8x16xf32> } : vector<16x8xi1> -> vector<8x16xf32>
  %4 = vector.mask %0 { vector.multi_reduction <add>, %1, %3 [0] : vector<4x8x16xf32> to vector<8x16xf32> } : vector<4x8x16xi1> -> vector<8x16xf32>
  %c0_4 = arith.constant 0 : index
  %5 = vector.mask %2 { vector.transfer_write %4, %arg1[%c0_4, %c0_4] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : vector<8x16xf32>, tensor<?x?xf32> } : vector<16x8xi1> -> tensor<?x?xf32>
  return %5 : tensor<?x?xf32>
}

// CHECK-LABEL: func.func @vectorize_dynamic_transpose_reduction
// CHECK: %[[VAL_6:.*]] = tensor.dim
// CHECK: %[[VAL_7:.*]] = tensor.dim
// CHECK: %[[VAL_8:.*]] = tensor.dim
// CHECK: %[[VAL_135:.*]] = vector.create_mask %{{.*}}, %{{.*}}, %{{.*}} : vector<4x8x16xi1>
// CHECK: %[[VAL_139:.*]] = vector.transpose %[[VAL_135]], [1, 2, 0] : vector<4x8x16xi1> to vector<8x16x4xi1>

// Just checking a few instances to make sure the vector mask is properly propagated:

// CHECK: %[[VAL_143:.*]] = vector.extract %[[VAL_139]][0, 0] : vector<4xi1> from vector<8x16x4xi1>
// CHECK: %[[VAL_144:.*]] = vector.mask %[[VAL_143]] { vector.reduction <add>
// CHECK: %[[VAL_145:.*]] = vector.insert %[[VAL_144]]

// CHECK: %[[VAL_148:.*]] = vector.extract %[[VAL_139]][0, 1] : vector<4xi1> from vector<8x16x4xi1>
// CHECK: %[[VAL_149:.*]] = vector.mask %[[VAL_148]] { vector.reduction <add>
// CHECK: %[[VAL_150:.*]] = vector.insert %[[VAL_149]]

// CHECK: %[[VAL_153:.*]] = vector.extract %[[VAL_139]][0, 2] : vector<4xi1> from vector<8x16x4xi1>
// CHECK: %[[VAL_154:.*]] = vector.mask %[[VAL_153]] { vector.reduction <add>
// CHECK: %[[VAL_155:.*]] = vector.insert %[[VAL_154]]

// CHECK: %[[VAL_158:.*]] = vector.extract %[[VAL_139]][0, 3] : vector<4xi1> from vector<8x16x4xi1>
// CHECK: %[[VAL_159:.*]] = vector.mask %[[VAL_158]] { vector.reduction <add>
// CHECK: %[[VAL_160:.*]] = vector.insert %[[VAL_159]]

// The single parallel dim sits between the two reduction dims: it is
// transposed to the front.
func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
  %0 = vector.multi_reduction <add>, %arg0, %acc [0, 2] : vector<3x4x5xf32> to vector<4xf32>
  return %0 : vector<4xf32>
}

// CHECK-LABEL: func @vector_multi_reduction_parallel_middle
// CHECK-SAME: %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32>
// CHECK: vector.transpose %[[INPUT]], [1, 0, 2] : vector<3x4x5xf32> to vector<4x3x5xf32>

// The reduced (innermost) dim is fixed-size while one of the parallel dims is
// scalable.
func.func private @vector_multi_reduction_non_scalable_dim(%A : vector<8x[4]x2xf32>, %B: vector<8x[4]xf32>) -> vector<8x[4]xf32> {
  %0 = vector.multi_reduction <add>, %A, %B [2] : vector<8x[4]x2xf32> to vector<8x[4]xf32>
  return %0 : vector<8x[4]xf32>
}
// CHECK-LABEL: func.func private @vector_multi_reduction_non_scalable_dim(
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x[4]x2xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<8x[4]xf32>) -> vector<8x[4]xf32> {
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<[32]xf32>

// CHECK: %[[VAL_35:.*]] = vector.extract %[[VAL_0]][0, 0] : vector<2xf32> from vector<8x[4]x2xf32>
// CHECK: %[[VAL_36:.*]] = vector.extract %[[VAL_1]][0, 0] : f32 from vector<8x[4]xf32>
// CHECK: %[[VAL_37:.*]] = vector.reduction <add>, %[[VAL_35]], %[[VAL_36]] : vector<2xf32> into f32
// CHECK: %[[VAL_38:.*]] = vector.insert %[[VAL_37]], %[[VAL_2]] [0] : f32 into vector<[32]xf32>

// CHECK: %[[VAL_39:.*]] = vector.extract %[[VAL_0]][0, 1] : vector<2xf32> from vector<8x[4]x2xf32>
// CHECK: %[[VAL_40:.*]] = vector.extract %[[VAL_1]][0, 1] : f32 from vector<8x[4]xf32>
// CHECK: %[[VAL_41:.*]] = vector.reduction <add>, %[[VAL_39]], %[[VAL_40]] : vector<2xf32> into f32
// CHECK: %[[VAL_42:.*]] = vector.insert %[[VAL_41]], %[[VAL_38]] [1] : f32 into vector<[32]xf32>

// (...)

// CHECK: %[[VAL_159:.*]] = vector.extract %[[VAL_0]][7, 3] : vector<2xf32> from vector<8x[4]x2xf32>
// CHECK: %[[VAL_160:.*]] = vector.extract %[[VAL_1]][7, 3] : f32 from vector<8x[4]xf32>
// CHECK: %[[VAL_161:.*]] = vector.reduction <add>, %[[VAL_159]], %[[VAL_160]] : vector<2xf32> into f32
// CHECK: %[[VAL_162:.*]] = vector.insert %[[VAL_161]], %{{.*}} [31] : f32 into vector<[32]xf32>

// CHECK: %[[VAL_163:.*]] = vector.shape_cast %[[VAL_162]] : vector<[32]xf32> to vector<8x[4]xf32>
// CHECK: return %[[VAL_163]] : vector<8x[4]xf32>

// Check that OneDimMultiReductionToTwoDim handles scalable dim
func.func @vector_multi_reduction_scalable_dim_1d(%A: vector<[4]xf32>, %B: f32, %C: vector<[4]xi1>) -> f32 {
  %0 = vector.mask %C { vector.multi_reduction <add>, %A, %B [0] : vector<[4]xf32> to f32 } : vector<[4]xi1> -> f32
  return %0 : f32
}

// CHECK-LABEL: func.func @vector_multi_reduction_scalable_dim_1d(
// CHECK-SAME: %[[ARG_0:.*]]: vector<[4]xf32>,
// CHECK-SAME: %[[ARG_1:.*]]: f32,
// CHECK-SAME: %[[ARG_2:.*]]: vector<[4]xi1>) -> f32 {
// CHECK: %[[VAL_2:.*]] = vector.mask %[[ARG_2]] { vector.reduction <add>, %[[ARG_0]], %[[ARG_1]] : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
// CHECK: return %[[VAL_2]] : f32

// Masked reduction over a scalable innermost dim of a 2-D vector: the input,
// accumulator and mask are each sliced per row.
func.func @vector_multi_reduction_scalable_dim_2d(%A: vector<2x[4]xf32>, %B: vector<2xf32>, %C: vector<2x[4]xi1>) -> vector<2xf32> {
  %0 = vector.mask %C { vector.multi_reduction <add>, %A, %B [1] : vector<2x[4]xf32> to vector<2xf32> } : vector<2x[4]xi1> -> vector<2xf32>
  return %0 : vector<2xf32>
}

// CHECK-LABEL: func.func @vector_multi_reduction_scalable_dim_2d(
// CHECK-SAME: %[[ARG_0:.*]]: vector<2x[4]xf32>,
// CHECK-SAME: %[[ARG_1:.*]]: vector<2xf32>,
// CHECK-SAME: %[[ARG_2:.*]]: vector<2x[4]xi1>) -> vector<2xf32> {
// CHECK-DAG: %[[C0_2xf32:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
// CHECK: %[[ARG0_0:.*]] = vector.extract %[[ARG_0]][0] : vector<[4]xf32> from vector<2x[4]xf32>
// CHECK: %[[ARG1_0:.*]] = vector.extract %[[ARG_1]][0] : f32 from vector<2xf32>
// CHECK: %[[ARG2_0:.*]] = vector.extract %[[ARG_2]][0] : vector<[4]xi1> from vector<2x[4]xi1>
// CHECK: %[[REDUCE_0:.*]] = vector.mask %[[ARG2_0]] { vector.reduction <add>, %[[ARG0_0]], %[[ARG1_0]] : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
// CHECK: %[[INSERT_0:.*]] = vector.insert %[[REDUCE_0]], %[[C0_2xf32]] [0] : f32 into vector<2xf32>
// CHECK: %[[ARG0_1:.*]] = vector.extract %[[ARG_0]][1] : vector<[4]xf32> from vector<2x[4]xf32>
// CHECK: %[[ARG1_1:.*]] = vector.extract %[[ARG_1]][1] : f32 from vector<2xf32>
// CHECK: %[[ARG2_1:.*]] = vector.extract %[[ARG_2]][1] : vector<[4]xi1> from vector<2x[4]xi1>
// CHECK: %[[REDUCE_1:.*]] = vector.mask %[[ARG2_1]] { vector.reduction <add>, %[[ARG0_1]], %[[ARG1_1]] : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
// CHECK: %[[INSERT_1:.*]] = vector.insert %[[REDUCE_1]], %[[INSERT_0]] [1] : f32 into vector<2xf32>
// CHECK: return %[[INSERT_1]] : vector<2xf32>

// Transform entry point: matches every func.func under the root and applies
// the multi_reduction lowering patterns with the "innerreduction" strategy.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op {
      transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerreduction"
    } : !transform.op<"func.func">
    transform.yield
  }
}