// RUN: mlir-opt %s --transform-interpreter | FileCheck %s

/// This tests that shape casts of scalable vectors (with one trailing scalable dim)
/// can be correctly lowered to vector.scalable.insert/extract.

// CHECK-LABEL: i32_3d_to_1d_last_dim_scalable
// CHECK-SAME: %[[arg0:.*]]: vector<2x1x[4]xi32>
func.func @i32_3d_to_1d_last_dim_scalable(%arg0: vector<2x1x[4]xi32>) -> vector<[8]xi32>
{
  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0> : vector<[8]xi32>
  // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
  // CHECK-NEXT: %[[res0:.*]] = vector.scalable.insert %[[subvec0]], %[[cst]][0] : vector<[4]xi32> into vector<[8]xi32>
  // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][1, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
  // CHECK-NEXT: %[[res1:.*]] = vector.scalable.insert %[[subvec1]], %[[res0]][4] : vector<[4]xi32> into vector<[8]xi32>
  %flat = vector.shape_cast %arg0 : vector<2x1x[4]xi32> to vector<[8]xi32>
  // CHECK-NEXT: return %[[res1]] : vector<[8]xi32>
  return %flat : vector<[8]xi32>
}

// -----

// CHECK-LABEL: i32_1d_to_3d_last_dim_scalable
// CHECK-SAME: %[[arg0:.*]]: vector<[8]xi32>
func.func @i32_1d_to_3d_last_dim_scalable(%arg0: vector<[8]xi32>) -> vector<2x1x[4]xi32> {
  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0> : vector<2x1x[4]xi32>
  // CHECK-NEXT: %[[subvec0:.*]] = vector.scalable.extract %[[arg0]][0] : vector<[4]xi32> from vector<[8]xi32>
  // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[cst]] [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
  // CHECK-NEXT: %[[subvec1:.*]] = vector.scalable.extract %[[arg0]][4] : vector<[4]xi32> from vector<[8]xi32>
  // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [1, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
  %unflat = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32>
  // CHECK-NEXT: return %[[res1]] : vector<2x1x[4]xi32>
  return %unflat : vector<2x1x[4]xi32>
}

// -----

// CHECK-LABEL: i8_2d_to_1d_last_dim_scalable
// CHECK-SAME: %[[arg0:.*]]: vector<4x[8]xi8>
func.func @i8_2d_to_1d_last_dim_scalable(%arg0: vector<4x[8]xi8>) -> vector<[32]xi8> {
  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0> : vector<[32]xi8>
  // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0] : vector<[8]xi8> from vector<4x[8]xi8>
  // CHECK-NEXT: %[[res0:.*]] = vector.scalable.insert %[[subvec0]], %[[cst]][0] : vector<[8]xi8> into vector<[32]xi8>
  // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][1] : vector<[8]xi8> from vector<4x[8]xi8>
  // CHECK-NEXT: %[[res1:.*]] = vector.scalable.insert %[[subvec1]], %[[res0]][8] : vector<[8]xi8> into vector<[32]xi8>
  // CHECK-NEXT: %[[subvec2:.*]] = vector.extract %[[arg0]][2] : vector<[8]xi8> from vector<4x[8]xi8>
  // CHECK-NEXT: %[[res2:.*]] = vector.scalable.insert %[[subvec2]], %[[res1]][16] : vector<[8]xi8> into vector<[32]xi8>
  // CHECK-NEXT: %[[subvec3:.*]] = vector.extract %[[arg0]][3] : vector<[8]xi8> from vector<4x[8]xi8>
  // CHECK-NEXT: %[[res3:.*]] = vector.scalable.insert %[[subvec3]], %[[res2]][24] : vector<[8]xi8> into vector<[32]xi8>
  %flat = vector.shape_cast %arg0 : vector<4x[8]xi8> to vector<[32]xi8>
  // CHECK-NEXT: return %[[res3]] : vector<[32]xi8>
  return %flat : vector<[32]xi8>
}

// -----

// CHECK-LABEL: i8_1d_to_2d_last_dim_scalable
// CHECK-SAME: %[[arg0:.*]]: vector<[32]xi8>
func.func @i8_1d_to_2d_last_dim_scalable(%arg0: vector<[32]xi8>) -> vector<4x[8]xi8> {
  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0> : vector<4x[8]xi8>
  // CHECK-NEXT: %[[subvec0:.*]] = vector.scalable.extract %[[arg0]][0] : vector<[8]xi8> from vector<[32]xi8>
  // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[cst]] [0] : vector<[8]xi8> into vector<4x[8]xi8>
  // CHECK-NEXT: %[[subvec1:.*]] = vector.scalable.extract %[[arg0]][8] : vector<[8]xi8> from vector<[32]xi8>
  // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [1] : vector<[8]xi8> into vector<4x[8]xi8>
  // CHECK-NEXT: %[[subvec2:.*]] = vector.scalable.extract %[[arg0]][16] : vector<[8]xi8> from vector<[32]xi8>
  // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [2] : vector<[8]xi8> into vector<4x[8]xi8>
  // CHECK-NEXT: %[[subvec3:.*]] = vector.scalable.extract %[[arg0]][24] : vector<[8]xi8> from vector<[32]xi8>
  // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [3] : vector<[8]xi8> into vector<4x[8]xi8>
  %unflat = vector.shape_cast %arg0 : vector<[32]xi8> to vector<4x[8]xi8>
  // CHECK-NEXT: return %[[res3]] : vector<4x[8]xi8>
  return %unflat : vector<4x[8]xi8>
}

// -----

// CHECK-LABEL: f32_permute_leading_non_scalable_dims
// CHECK-SAME: %[[arg0:.*]]: vector<2x3x[4]xf32>
func.func @f32_permute_leading_non_scalable_dims(%arg0: vector<2x3x[4]xf32>) -> vector<3x2x[4]xf32> {
  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<3x2x[4]xf32>
  // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0, 0] : vector<[4]xf32> from vector<2x3x[4]xf32>
  // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[cst]] [0, 0] : vector<[4]xf32> into vector<3x2x[4]xf32>
  // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][0, 1] : vector<[4]xf32> from vector<2x3x[4]xf32>
  // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [0, 1] : vector<[4]xf32> into vector<3x2x[4]xf32>
  // CHECK-NEXT: %[[subvec2:.*]] = vector.extract %[[arg0]][0, 2] : vector<[4]xf32> from vector<2x3x[4]xf32>
  // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [1, 0] : vector<[4]xf32> into vector<3x2x[4]xf32>
  // CHECK-NEXT: %[[subvec3:.*]] = vector.extract %[[arg0]][1, 0] : vector<[4]xf32> from vector<2x3x[4]xf32>
  // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [1, 1] : vector<[4]xf32> into vector<3x2x[4]xf32>
  // CHECK-NEXT: %[[subvec4:.*]] = vector.extract %[[arg0]][1, 1] : vector<[4]xf32> from vector<2x3x[4]xf32>
  // CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [2, 0] : vector<[4]xf32> into vector<3x2x[4]xf32>
  // CHECK-NEXT: %[[subvec5:.*]] = vector.extract %[[arg0]][1, 2] : vector<[4]xf32> from vector<2x3x[4]xf32>
  // CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [2, 1] : vector<[4]xf32> into vector<3x2x[4]xf32>
  %res = vector.shape_cast %arg0: vector<2x3x[4]xf32> to vector<3x2x[4]xf32>
  // CHECK-NEXT: return %[[res5]] : vector<3x2x[4]xf32>
  return %res : vector<3x2x[4]xf32>
}

// -----

// CHECK-LABEL: f64_flatten_leading_non_scalable_dims
// CHECK-SAME: %[[arg0:.*]]: vector<2x2x[2]xf64>
func.func @f64_flatten_leading_non_scalable_dims(%arg0: vector<2x2x[2]xf64>) -> vector<4x[2]xf64>
{
  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<4x[2]xf64>
  // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0, 0] : vector<[2]xf64> from vector<2x2x[2]xf64>
  // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[cst]] [0] : vector<[2]xf64> into vector<4x[2]xf64>
  // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][0, 1] : vector<[2]xf64> from vector<2x2x[2]xf64>
  // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [1] : vector<[2]xf64> into vector<4x[2]xf64>
  // CHECK-NEXT: %[[subvec2:.*]] = vector.extract %[[arg0]][1, 0] : vector<[2]xf64> from vector<2x2x[2]xf64>
  // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [2] : vector<[2]xf64> into vector<4x[2]xf64>
  // CHECK-NEXT: %[[subvec3:.*]] = vector.extract %[[arg0]][1, 1] : vector<[2]xf64> from vector<2x2x[2]xf64>
  // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [3] : vector<[2]xf64> into vector<4x[2]xf64>
  %res = vector.shape_cast %arg0: vector<2x2x[2]xf64> to vector<4x[2]xf64>
  // Match the captured final insert rather than a printer-assigned SSA number,
  // so the check ties the returned value to %[[res3]] and survives renumbering.
  // CHECK-NEXT: return %[[res3]] : vector<4x[2]xf64>
  return %res : vector<4x[2]xf64>
}

// -----

// CHECK-LABEL: f32_reduce_trailing_scalable_dim
// CHECK-SAME: %[[arg0:.*]]: vector<3x[4]xf32>
func.func @f32_reduce_trailing_scalable_dim(%arg0: vector<3x[4]xf32>) -> vector<6x[2]xf32>
{
  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<6x[2]xf32>
  // CHECK-NEXT: %[[srcvec0:.*]] = vector.extract %[[arg0]][0] : vector<[4]xf32> from vector<3x[4]xf32>
  // CHECK-NEXT: %[[subvec0:.*]] = vector.scalable.extract %[[srcvec0]][0] : vector<[2]xf32> from vector<[4]xf32>
  // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[cst]] [0] : vector<[2]xf32> into vector<6x[2]xf32>
  // CHECK-NEXT: %[[subvec1:.*]] = vector.scalable.extract %[[srcvec0]][2] : vector<[2]xf32> from vector<[4]xf32>
  // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [1] : vector<[2]xf32> into vector<6x[2]xf32>
  // CHECK-NEXT: %[[srcvec1:.*]] = vector.extract %[[arg0]][1] : vector<[4]xf32> from vector<3x[4]xf32>
  // CHECK-NEXT: %[[subvec2:.*]] = vector.scalable.extract %[[srcvec1]][0] : vector<[2]xf32> from vector<[4]xf32>
  // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [2] : vector<[2]xf32> into vector<6x[2]xf32>
  // CHECK-NEXT: %[[subvec3:.*]] = vector.scalable.extract %[[srcvec1]][2] : vector<[2]xf32> from vector<[4]xf32>
  // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [3] : vector<[2]xf32> into vector<6x[2]xf32>
  // CHECK-NEXT: %[[srcvec2:.*]] = vector.extract %[[arg0]][2] : vector<[4]xf32> from vector<3x[4]xf32>
  // CHECK-NEXT: %[[subvec4:.*]] = vector.scalable.extract %[[srcvec2]][0] : vector<[2]xf32> from vector<[4]xf32>
  // CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [4] : vector<[2]xf32> into vector<6x[2]xf32>
  // CHECK-NEXT: %[[subvec5:.*]] = vector.scalable.extract %[[srcvec2]][2] : vector<[2]xf32> from vector<[4]xf32>
  // CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [5] : vector<[2]xf32> into vector<6x[2]xf32>
  %res = vector.shape_cast %arg0: vector<3x[4]xf32> to vector<6x[2]xf32>
  // CHECK-NEXT: return %[[res5]] : vector<6x[2]xf32>
  return %res: vector<6x[2]xf32>
}

// -----

// CHECK-LABEL: f32_increase_trailing_scalable_dim
// CHECK-SAME: %[[arg0:.*]]: vector<4x[2]xf32>
func.func @f32_increase_trailing_scalable_dim(%arg0: vector<4x[2]xf32>) -> vector<2x[4]xf32>
{
  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<2x[4]xf32>
  // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0] : vector<[2]xf32> from vector<4x[2]xf32>
  // CHECK-NEXT: %[[resvec0:.*]] = vector.extract %[[cst]][0] : vector<[4]xf32> from vector<2x[4]xf32>
  // CHECK-NEXT: %[[resvec1:.*]] = vector.scalable.insert %[[subvec0]], %[[resvec0]][0] : vector<[2]xf32> into vector<[4]xf32>
  // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][1] : vector<[2]xf32> from vector<4x[2]xf32>
  // CHECK-NEXT: %[[resvec2:.*]] = vector.scalable.insert %[[subvec1]], %[[resvec1]][2] : vector<[2]xf32> into vector<[4]xf32>
  // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[resvec2]], %[[cst]] [0] : vector<[4]xf32> into vector<2x[4]xf32>
  // CHECK-NEXT: %[[subvec2:.*]] = vector.extract %[[arg0]][2] : vector<[2]xf32> from vector<4x[2]xf32>
  // CHECK-NEXT: %[[resvec3:.*]] = vector.extract %[[cst]][1] : vector<[4]xf32> from vector<2x[4]xf32>
  // CHECK-NEXT: %[[resvec4:.*]] = vector.scalable.insert %[[subvec2]], %[[resvec3]][0] : vector<[2]xf32> into vector<[4]xf32>
  // CHECK-NEXT: %[[subvec3:.*]] = vector.extract %[[arg0]][3] : vector<[2]xf32> from vector<4x[2]xf32>
  // CHECK-NEXT: %[[resvec5:.*]] = vector.scalable.insert %[[subvec3]], %[[resvec4]][2] : vector<[2]xf32> into vector<[4]xf32>
  // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[resvec5]], %[[res0]] [1] : vector<[4]xf32> into vector<2x[4]xf32>
  %res = vector.shape_cast %arg0: vector<4x[2]xf32> to vector<2x[4]xf32>
  // CHECK-NEXT: return %[[res1]] : vector<2x[4]xf32>
  return %res: vector<2x[4]xf32>
}

// -----

/// The following shape_casts are not supported: the types cannot be
/// represented in LLVM (and likely won't be supported soon), and there are
/// currently no ops that could do the required extracts/inserts. Each test
/// therefore checks that the vector.shape_cast is left unchanged.

// -----

// CHECK-LABEL: cannot_cast_to_non_trailing_scalable_dim
// CHECK-SAME: %[[arg0:.*]]: vector<[4]xf32>
func.func @cannot_cast_to_non_trailing_scalable_dim(%arg0: vector<[4]xf32>) -> vector<[2]x2xf32> {
  // CHECK-NEXT: %[[res:.*]] = vector.shape_cast %[[arg0]] : vector<[4]xf32> to vector<[2]x2xf32>
  %res = vector.shape_cast %arg0 : vector<[4]xf32> to vector<[2]x2xf32>
  // CHECK-NEXT: return %[[res]] : vector<[2]x2xf32>
  return %res: vector<[2]x2xf32>
}

// -----

// CHECK-LABEL: cannot_shape_cast_from_non_trailing_scalable_dim
// CHECK-SAME: %[[arg0:.*]]: vector<[2]x2xf32>
func.func @cannot_shape_cast_from_non_trailing_scalable_dim(%arg0: vector<[2]x2xf32>) -> vector<[4]xf32> {
  // CHECK-NEXT: %[[res:.*]] = vector.shape_cast %[[arg0]] : vector<[2]x2xf32> to vector<[4]xf32>
  %res = vector.shape_cast %arg0 : vector<[2]x2xf32> to vector<[4]xf32>
  // CHECK-NEXT: return %[[res]] : vector<[4]xf32>
  return %res: vector<[4]xf32>
}

// -----

// CHECK-LABEL: cannot_shape_cast_more_than_one_scalable_dim
// CHECK-SAME: %[[arg0:.*]]: vector<[4]x[4]xf32>
func.func @cannot_shape_cast_more_than_one_scalable_dim(%arg0: vector<[4]x[4]xf32>) -> vector<2x[2]x[4]xf32> {
  // CHECK-NEXT: %[[res:.*]] = vector.shape_cast %[[arg0]] : vector<[4]x[4]xf32> to vector<2x[2]x[4]xf32>
  %res = vector.shape_cast %arg0 : vector<[4]x[4]xf32> to vector<2x[2]x[4]xf32>
  // CHECK-NEXT: return %[[res]] : vector<2x[2]x[4]xf32>
  return %res: vector<2x[2]x[4]xf32>
}

/// Transform script: matches every func.func in the module and applies the
/// vector shape_cast lowering patterns to it.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %f = transform.structured.match ops{["func.func"]} in %module_op
      : (!transform.any_op) -> !transform.any_op

    transform.apply_patterns to %f {
      transform.apply_patterns.vector.lower_shape_cast
    } : !transform.any_op
    transform.yield
  }
}