// RUN: mlir-opt %s --transform-interpreter | FileCheck %s

// Tests for lowering `vector.multi_reduction` with the "innerparallel"
// strategy. The lowering is driven by the transform sequence at the end of
// this file, which applies the `lower_multi_reduction` patterns to every
// `func.func`.
//
// For a reduction over the innermost dimension, the expected output first
// moves the reduced dimension to the front with `vector.transpose`. It then
// unrolls the reduction into a chain of elementwise ops: each leading row of
// the transposed vector is combined with the running result, which starts
// from the accumulator.

func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
  %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
  return %0 : vector<2xf32>
}

// CHECK-LABEL: func @vector_multi_reduction
//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
//       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
//       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
//       CHECK:   %[[RV0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32>
//       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
//       CHECK:   %[[RV01:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32>
//       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
//       CHECK:   %[[RV012:.+]] = arith.mulf %[[V2]], %[[RV01]] : vector<2xf32>
//       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
//       CHECK:   %[[RESULT_VEC:.+]] = arith.mulf %[[V3]], %[[RV012]] : vector<2xf32>
//       CHECK:   return %[[RESULT_VEC]] : vector<2xf32>

// The same unrolled shape for the remaining combining kinds, each mapped to
// its corresponding `arith` op.

func.func @vector_multi_reduction_min(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
  %0 = vector.multi_reduction <minnumf>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
  return %0 : vector<2xf32>
}

// CHECK-LABEL: func @vector_multi_reduction_min
//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
//       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
//       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
//       CHECK:   %[[RV0:.+]] = arith.minnumf %[[V0]], %[[ACC]] : vector<2xf32>
//       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
//       CHECK:   %[[RV01:.+]] = arith.minnumf %[[V1]], %[[RV0]] : vector<2xf32>
//       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
//       CHECK:   %[[RV012:.+]] = arith.minnumf %[[V2]], %[[RV01]] : vector<2xf32>
//       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
//       CHECK:   %[[RESULT_VEC:.+]] = arith.minnumf %[[V3]], %[[RV012]] : vector<2xf32>
//       CHECK:   return %[[RESULT_VEC]] : vector<2xf32>

func.func @vector_multi_reduction_max(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
  %0 = vector.multi_reduction <maxnumf>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
  return %0 : vector<2xf32>
}

// CHECK-LABEL: func @vector_multi_reduction_max
//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
//       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
//       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
//       CHECK:   %[[RV0:.+]] = arith.maxnumf %[[V0]], %[[ACC]] : vector<2xf32>
//       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
//       CHECK:   %[[RV01:.+]] = arith.maxnumf %[[V1]], %[[RV0]] : vector<2xf32>
//       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
//       CHECK:   %[[RV012:.+]] = arith.maxnumf %[[V2]], %[[RV01]] : vector<2xf32>
//       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
//       CHECK:   %[[RESULT_VEC:.+]] = arith.maxnumf %[[V3]], %[[RV012]] : vector<2xf32>
//       CHECK:   return %[[RESULT_VEC]] : vector<2xf32>

func.func @vector_multi_reduction_and(%arg0: vector<2x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
  %0 = vector.multi_reduction <and>, %arg0, %acc [1] : vector<2x4xi32> to vector<2xi32>
  return %0 : vector<2xi32>
}

// CHECK-LABEL: func @vector_multi_reduction_and
//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xi32>, %[[ACC:.*]]: vector<2xi32>
//       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xi32> to vector<4x2xi32>
//       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xi32> from vector<4x2xi32>
//       CHECK:   %[[RV0:.+]] = arith.andi %[[V0]], %[[ACC]] : vector<2xi32>
//       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xi32> from vector<4x2xi32>
//       CHECK:   %[[RV01:.+]] = arith.andi %[[V1]], %[[RV0]] : vector<2xi32>
//       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xi32> from vector<4x2xi32>
//       CHECK:   %[[RV012:.+]] = arith.andi %[[V2]], %[[RV01]] : vector<2xi32>
//       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xi32> from vector<4x2xi32>
//       CHECK:   %[[RESULT_VEC:.+]] = arith.andi %[[V3]], %[[RV012]] : vector<2xi32>
//       CHECK:   return %[[RESULT_VEC]] : vector<2xi32>

func.func @vector_multi_reduction_or(%arg0: vector<2x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
  %0 = vector.multi_reduction <or>, %arg0, %acc [1] : vector<2x4xi32> to vector<2xi32>
  return %0 : vector<2xi32>
}

// CHECK-LABEL: func @vector_multi_reduction_or
//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xi32>, %[[ACC:.*]]: vector<2xi32>
//       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xi32> to vector<4x2xi32>
//       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xi32> from vector<4x2xi32>
//       CHECK:   %[[RV0:.+]] = arith.ori %[[V0]], %[[ACC]] : vector<2xi32>
//       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xi32> from vector<4x2xi32>
//       CHECK:   %[[RV01:.+]] = arith.ori %[[V1]], %[[RV0]] : vector<2xi32>
//       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xi32> from vector<4x2xi32>
//       CHECK:   %[[RV012:.+]] = arith.ori %[[V2]], %[[RV01]] : vector<2xi32>
//       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xi32> from vector<4x2xi32>
//       CHECK:   %[[RESULT_VEC:.+]] = arith.ori %[[V3]], %[[RV012]] : vector<2xi32>
//       CHECK:   return %[[RESULT_VEC]] : vector<2xi32>

func.func @vector_multi_reduction_xor(%arg0: vector<2x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
  %0 = vector.multi_reduction <xor>, %arg0, %acc [1] : vector<2x4xi32> to vector<2xi32>
  return %0 : vector<2xi32>
}

// CHECK-LABEL: func @vector_multi_reduction_xor
//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xi32>, %[[ACC:.*]]: vector<2xi32>
//       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xi32> to vector<4x2xi32>
//       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xi32> from vector<4x2xi32>
//       CHECK:   %[[RV0:.+]] = arith.xori %[[V0]], %[[ACC]] : vector<2xi32>
//       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xi32> from vector<4x2xi32>
//       CHECK:   %[[RV01:.+]] = arith.xori %[[V1]], %[[RV0]] : vector<2xi32>
//       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xi32> from vector<4x2xi32>
//       CHECK:   %[[RV012:.+]] = arith.xori %[[V2]], %[[RV01]] : vector<2xi32>
//       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xi32> from vector<4x2xi32>
//       CHECK:   %[[RESULT_VEC:.+]] = arith.xori %[[V3]], %[[RV012]] : vector<2xi32>
//       CHECK:   return %[[RESULT_VEC]] : vector<2xi32>


// Reduction over the two trailing dimensions. Both reduced dimensions are
// transposed to the front. The input is then flattened with
// `vector.shape_cast` (4x5 -> 20 rows, 2x3 -> 6 lanes), and the accumulator
// is flattened to 6 lanes as well. The reduction unrolls into 20
// `arith.addi` ops, and the result is cast back to 2x3.
func.func @vector_reduction_outer(%arg0: vector<2x3x4x5xi32>, %acc: vector<2x3xi32>) -> vector<2x3xi32> {
  %0 = vector.multi_reduction <add>, %arg0, %acc [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
  return %0 : vector<2x3xi32>
}

// CHECK-LABEL: func @vector_reduction_outer
//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x3x4x5xi32>, %[[ACC:.*]]: vector<2x3xi32>
//       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [2, 3, 0, 1] : vector<2x3x4x5xi32> to vector<4x5x2x3xi32>
//       CHECK:   %[[RESHAPED:.+]] = vector.shape_cast %[[TRANSPOSED]] : vector<4x5x2x3xi32> to vector<20x6xi32>
//       CHECK:   %[[FACC:.+]] = vector.shape_cast %[[ACC]] : vector<2x3xi32> to vector<6xi32>
//       CHECK:   %[[V0:.+]] = vector.extract %[[RESHAPED]][0] : vector<6xi32> from vector<20x6xi32>
//       CHECK:   %[[R:.+]] = arith.addi %[[V0]], %[[FACC]] : vector<6xi32>
//       CHECK:   %[[V1:.+]] = vector.extract %[[RESHAPED]][1] : vector<6xi32> from vector<20x6xi32>
//       CHECK:   %[[R0:.+]] = arith.addi %[[V1]], %[[R]] : vector<6xi32>
//       CHECK:   %[[V2:.+]] = vector.extract %[[RESHAPED]][2] : vector<6xi32> from vector<20x6xi32>
//       CHECK:   %[[R1:.+]] = arith.addi %[[V2]], %[[R0]] : vector<6xi32>
//       CHECK:   %[[V3:.+]] = vector.extract %[[RESHAPED]][3] : vector<6xi32> from vector<20x6xi32>
//       CHECK:   %[[R2:.+]] = arith.addi %[[V3]], %[[R1]] : vector<6xi32>
//       CHECK:   %[[V4:.+]] = vector.extract %[[RESHAPED]][4] : vector<6xi32> from vector<20x6xi32>
//       CHECK:   %[[R3:.+]] = arith.addi %[[V4]], %[[R2]] : vector<6xi32>
//       CHECK:   %[[V5:.+]] = vector.extract %[[RESHAPED]][5] : vector<6xi32> from vector<20x6xi32>
//       CHECK:   %[[R4:.+]] = arith.addi %[[V5]], %[[R3]] : vector<6xi32>
//       CHECK:   %[[V6:.+]] = vector.extract %[[RESHAPED]][6] : vector<6xi32> from vector<20x6xi32>
//       CHECK:   %[[R5:.+]] = arith.addi %[[V6]], %[[R4]] : vector<6xi32>
//       CHECK:   %[[V7:.+]] = vector.extract %[[RESHAPED]][7] : vector<6xi32> from vector<20x6xi32>
//       CHECK:   %[[R6:.+]] = arith.addi %[[V7]], %[[R5]] : vector<6xi32>
//       CHECK:   %[[V8:.+]] = vector.extract %[[RESHAPED]][8] : vector<6xi32> from vector<20x6xi32>
//       CHECK:   %[[R7:.+]] = arith.addi %[[V8]], %[[R6]] : vector<6xi32>
//       CHECK:   %[[V9:.+]] = vector.extract %[[RESHAPED]][9] : vector<6xi32> from vector<20x6xi32>
//       CHECK:   %[[R8:.+]] = arith.addi %[[V9]], %[[R7]] : vector<6xi32>
//       CHECK:   %[[V10:.+]] = vector.extract %[[RESHAPED]][10] : vector<6xi32> from vector<20x6xi32>
//       CHECK:   %[[R9:.+]] = arith.addi %[[V10]], %[[R8]] : vector<6xi32>
//       CHECK:   %[[V11:.+]] = vector.extract %[[RESHAPED]][11] : vector<6xi32> from vector<20x6xi32>
//       CHECK:   %[[R10:.+]] = arith.addi %[[V11]], %[[R9]] : vector<6xi32>
//       CHECK:   %[[V12:.+]] = vector.extract %[[RESHAPED]][12] : vector<6xi32> from vector<20x6xi32>
//       CHECK:   %[[R11:.+]] = arith.addi %[[V12]], %[[R10]] : vector<6xi32>
//       CHECK:   %[[V13:.+]] = vector.extract %[[RESHAPED]][13] : vector<6xi32> from vector<20x6xi32>
//       CHECK:   %[[R12:.+]] = arith.addi %[[V13]], %[[R11]] : vector<6xi32>
//       CHECK:   %[[V14:.+]] = vector.extract %[[RESHAPED]][14] : vector<6xi32> from vector<20x6xi32>
//       CHECK:   %[[R13:.+]] = arith.addi %[[V14]], %[[R12]] : vector<6xi32>
//       CHECK:   %[[V15:.+]] = vector.extract %[[RESHAPED]][15] : vector<6xi32> from vector<20x6xi32>
//       CHECK:   %[[R14:.+]] = arith.addi %[[V15]], %[[R13]] : vector<6xi32>
//       CHECK:   %[[V16:.+]] = vector.extract %[[RESHAPED]][16] : vector<6xi32> from vector<20x6xi32>
//       CHECK:   %[[R15:.+]] = arith.addi %[[V16]], %[[R14]] : vector<6xi32>
//       CHECK:   %[[V17:.+]] = vector.extract %[[RESHAPED]][17] : vector<6xi32> from vector<20x6xi32>
//       CHECK:   %[[R16:.+]] = arith.addi %[[V17]], %[[R15]] : vector<6xi32>
//       CHECK:   %[[V18:.+]] = vector.extract %[[RESHAPED]][18] : vector<6xi32> from vector<20x6xi32>
//       CHECK:   %[[R17:.+]] = arith.addi %[[V18]], %[[R16]] : vector<6xi32>
//       CHECK:   %[[V19:.+]] = vector.extract %[[RESHAPED]][19] : vector<6xi32> from vector<20x6xi32>
//       CHECK:   %[[R18:.+]] = arith.addi %[[V19]], %[[R17]] : vector<6xi32>
//       CHECK:   %[[RESULT_VEC:.+]] = vector.shape_cast %[[R18]] : vector<6xi32> to vector<2x3xi32>
//       CHECK:   return %[[RESULT_VEC]] : vector<2x3xi32>

// Reduction dimensions on both sides of a parallel dimension. Only the
// initial transpose is pinned: it moves the parallel dimension (size 4) to
// the innermost position.
func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
  %0 = vector.multi_reduction <add>, %arg0, %acc [0, 2] : vector<3x4x5xf32> to vector<4xf32>
  return %0 : vector<4xf32>
}

// CHECK-LABEL: func @vector_multi_reduction_parallel_middle
//  CHECK-SAME:   %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32>
//       CHECK:   vector.transpose %[[INPUT]], [0, 2, 1] : vector<3x4x5xf32> to vector<3x5x4xf32>

// This test mainly guards against a bug where running
// `InnerOuterDimReductionConversion` on this function looped forever, so it
// only verifies that some value is returned.
func.func @vector_reduction_1D(%arg0 : vector<2xf32>, %acc: f32) -> f32 {
  %0 = vector.multi_reduction #vector.kind<maxnumf>, %arg0, %acc [0] : vector<2xf32> to f32
  return %0 : f32
}
// CHECK-LABEL: func @vector_reduction_1D
//       CHECK:   return %{{.+}}

// Reduction over every dimension down to a scalar. Like the 1-D case above,
// this only verifies that the lowered function still returns a value; the
// exact lowered form is not pinned here.
func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x3xf32>, %acc: f32) -> f32 {
  %0 = vector.multi_reduction <add>, %arg0, %acc [0, 1] : vector<2x3xf32> to f32
  return %0 : f32
}
// CHECK-LABEL: func @vector_multi_reduction_to_scalar
//       CHECK:   return %{{.+}}

// Transform script run by `--transform-interpreter`: applies the
// multi-reduction lowering patterns with the "innerparallel" strategy to
// every `func.func` in the module.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op {
      transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel"
    } : !transform.op<"func.func">
    transform.yield
  }
}