// RUN: mlir-opt %s -arm-sme-vector-legalization -cse -canonicalize -split-input-file | FileCheck %s

// CHECK-LABEL: @outerproduct_f32_scalable_8x8_no_acc(
// CHECK-SAME: %[[LHS:.*]]: vector<[8]xf32>,
// CHECK-SAME: %[[RHS:.*]]: vector<[8]xf32>)
// CHECK-SAME: -> (vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)
func.func @outerproduct_f32_scalable_8x8_no_acc(%lhs: vector<[8]xf32>, %rhs: vector<[8]xf32>) -> vector<[8]x[8]xf32>
{
  // CHECK-DAG: %[[LHS_0:.*]] = vector.scalable.extract %[[LHS]][0] : vector<[4]xf32> from vector<[8]xf32>
  // CHECK-DAG: %[[RHS_0:.*]] = vector.scalable.extract %[[RHS]][0] : vector<[4]xf32> from vector<[8]xf32>
  // CHECK-DAG: %[[LHS_1:.*]] = vector.scalable.extract %[[LHS]][4] : vector<[4]xf32> from vector<[8]xf32>
  // CHECK-DAG: %[[RHS_1:.*]] = vector.scalable.extract %[[RHS]][4] : vector<[4]xf32> from vector<[8]xf32>
  // CHECK-DAG: %[[TOP_LEFT:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32>
  // CHECK-DAG: %[[TOP_RIGHT:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_1]] : vector<[4]xf32>, vector<[4]xf32>
  // CHECK-DAG: %[[BOTTOM_LEFT:.*]] = vector.outerproduct %[[LHS_1]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32>
  // CHECK-DAG: %[[BOTTOM_RIGHT:.*]] = vector.outerproduct %[[LHS_1]], %[[RHS_1]] : vector<[4]xf32>, vector<[4]xf32>
  // CHECK-NEXT: return %[[TOP_LEFT]], %[[TOP_RIGHT]], %[[BOTTOM_LEFT]], %[[BOTTOM_RIGHT]] : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>
  %0 = vector.outerproduct %lhs, %rhs : vector<[8]xf32>, vector<[8]xf32>
  return %0 : vector<[8]x[8]xf32>
}

// -----

// CHECK-LABEL: @outerproduct_f32_scalable_4x16_acc(
// CHECK-SAME: %[[LHS:.*]]: vector<[4]xf32>,
// CHECK-SAME: %[[RHS:.*]]: vector<[16]xf32>,
// CHECK-SAME: %[[ACC_0:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>,
// CHECK-SAME: %[[ACC_1:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>,
// CHECK-SAME: %[[ACC_2:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>,
// CHECK-SAME: %[[ACC_3:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>)
// CHECK-SAME: -> (vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)
func.func @outerproduct_f32_scalable_4x16_acc(%lhs: vector<[4]xf32>, %rhs: vector<[16]xf32>, %acc: vector<[4]x[16]xf32>) -> vector<[4]x[16]xf32>
{
  // CHECK-DAG: %[[LHS_0:.*]] = vector.scalable.extract %[[LHS]][0] : vector<[4]xf32> from vector<[4]xf32>
  // CHECK-DAG: %[[RHS_0:.*]] = vector.scalable.extract %[[RHS]][0] : vector<[4]xf32> from vector<[16]xf32>
  // CHECK-DAG: %[[RHS_1:.*]] = vector.scalable.extract %[[RHS]][4] : vector<[4]xf32> from vector<[16]xf32>
  // CHECK-DAG: %[[RHS_2:.*]] = vector.scalable.extract %[[RHS]][8] : vector<[4]xf32> from vector<[16]xf32>
  // CHECK-DAG: %[[RHS_3:.*]] = vector.scalable.extract %[[RHS]][12] : vector<[4]xf32> from vector<[16]xf32>
  // CHECK-DAG: %[[RES_0:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_0]], %[[ACC_0]] {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
  // CHECK-DAG: %[[RES_1:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_1]], %[[ACC_1]] {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
  // CHECK-DAG: %[[RES_2:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_2]], %[[ACC_2]] {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
  // CHECK-DAG: %[[RES_3:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_3]], %[[ACC_3]] {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
  // CHECK-NEXT: return %[[RES_0]], %[[RES_1]], %[[RES_2]], %[[RES_3]] : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>
  %0 = vector.outerproduct %lhs, %rhs, %acc : vector<[4]xf32>, vector<[16]xf32>
  return %0 : vector<[4]x[16]xf32>
}

// -----

// CHECK-LABEL: @outerproduct_f32_masked_scalable_16x4(
// CHECK-SAME: %[[LHS:.*]]: vector<[16]xf32>,
// CHECK-SAME: %[[RHS:.*]]: vector<[4]xf32>,
// CHECK-SAME: %[[LHS_DIM:.*]]: index,
// CHECK-SAME: %[[RHS_DIM:.*]]: index)
// CHECK-SAME: -> (vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)
func.func @outerproduct_f32_masked_scalable_16x4(%lhs: vector<[16]xf32>, %rhs: vector<[4]xf32>, %lhs_dim: index, %rhs_dim: index) -> vector<[16]x[4]xf32>
{
  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
  // CHECK-DAG: %[[MINUS_4:.*]] = arith.constant -4 : index
  // CHECK-DAG: %[[MINUS_8:.*]] = arith.constant -8 : index
  // CHECK-DAG: %[[MINUS_12:.*]] = arith.constant -12 : index
  // CHECK-DAG: %[[MINUS_4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[MINUS_4]] : index
  // CHECK-DAG: %[[MINUS_8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[MINUS_8]] : index
  // CHECK-DAG: %[[MINUS_12_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[MINUS_12]] : index
  // CHECK-DAG: %[[LHS_0:.*]] = vector.scalable.extract %[[LHS]][0] : vector<[4]xf32> from vector<[16]xf32>
  // CHECK-DAG: %[[LHS_1:.*]] = vector.scalable.extract %[[LHS]][4] : vector<[4]xf32> from vector<[16]xf32>
  // CHECK-DAG: %[[LHS_2:.*]] = vector.scalable.extract %[[LHS]][8] : vector<[4]xf32> from vector<[16]xf32>
  // CHECK-DAG: %[[LHS_3:.*]] = vector.scalable.extract %[[LHS]][12] : vector<[4]xf32> from vector<[16]xf32>
  // CHECK-DAG: %[[RHS_0:.*]] = vector.scalable.extract %[[RHS]][0] : vector<[4]xf32> from vector<[4]xf32>
  // CHECK-DAG: %[[MASK_0:.*]] = vector.create_mask %[[LHS_DIM]], %[[RHS_DIM]] : vector<[4]x[4]xi1>
  // CHECK-DAG: %[[TILE_1_LHS_DIM:.*]] = arith.addi %[[LHS_DIM]], %[[MINUS_4_VSCALE]] : index
  // CHECK-DAG: %[[MASK_1:.*]] = vector.create_mask %[[TILE_1_LHS_DIM]], %[[RHS_DIM]] : vector<[4]x[4]xi1>
  // CHECK-DAG: %[[TILE_2_LHS_DIM:.*]] = arith.addi %[[LHS_DIM]], %[[MINUS_8_VSCALE]] : index
  // CHECK-DAG: %[[MASK_2:.*]] = vector.create_mask %[[TILE_2_LHS_DIM]], %[[RHS_DIM]] : vector<[4]x[4]xi1>
  // CHECK-DAG: %[[TILE_3_LHS_DIM:.*]] = arith.addi %[[LHS_DIM]], %[[MINUS_12_VSCALE]] : index
  // CHECK-DAG: %[[MASK_3:.*]] = vector.create_mask %[[TILE_3_LHS_DIM]], %[[RHS_DIM]] : vector<[4]x[4]xi1>
  // CHECK-DAG: %[[RES_0:.*]] = vector.mask %[[MASK_0]] { vector.outerproduct %[[LHS_0]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
  // CHECK-DAG: %[[RES_1:.*]] = vector.mask %[[MASK_1]] { vector.outerproduct %[[LHS_1]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
  // CHECK-DAG: %[[RES_2:.*]] = vector.mask %[[MASK_2]] { vector.outerproduct %[[LHS_2]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
  // CHECK-DAG: %[[RES_3:.*]] = vector.mask %[[MASK_3]] { vector.outerproduct %[[LHS_3]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
  // CHECK-NEXT: return %[[RES_0]], %[[RES_1]], %[[RES_2]], %[[RES_3]] : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>
  %mask = vector.create_mask %lhs_dim, %rhs_dim : vector<[16]x[4]xi1>
  %0 = vector.mask %mask { vector.outerproduct %lhs, %rhs : vector<[16]xf32>, vector<[4]xf32> } : vector<[16]x[4]xi1> -> vector<[16]x[4]xf32>
  return %0 : vector<[16]x[4]xf32>
}

// -----

/// This demonstrates a rectangular tiling that uses all f64 accumulators.

// CHECK-LABEL: @outerproduct_f64_scalable_8x4_no_acc(
// CHECK-SAME: %[[LHS:.*]]: vector<[8]xf64>,
// CHECK-SAME: %[[RHS:.*]]: vector<[4]xf64>)
// CHECK-SAME: -> (vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>)
func.func @outerproduct_f64_scalable_8x4_no_acc(%lhs: vector<[8]xf64>, %rhs: vector<[4]xf64>) -> vector<[8]x[4]xf64>
{
  // CHECK-DAG: %[[LHS_0:.*]] = vector.scalable.extract %[[LHS]][0] : vector<[2]xf64> from vector<[8]xf64>
  // CHECK-DAG: %[[LHS_1:.*]] = vector.scalable.extract %[[LHS]][2] : vector<[2]xf64> from vector<[8]xf64>
  // CHECK-DAG: %[[LHS_2:.*]] = vector.scalable.extract %[[LHS]][4] : vector<[2]xf64> from vector<[8]xf64>
  // CHECK-DAG: %[[LHS_3:.*]] = vector.scalable.extract %[[LHS]][6] : vector<[2]xf64> from vector<[8]xf64>
  // CHECK-DAG: %[[RHS_0:.*]] = vector.scalable.extract %[[RHS]][0] : vector<[2]xf64> from vector<[4]xf64>
  // CHECK-DAG: %[[RHS_1:.*]] = vector.scalable.extract %[[RHS]][2] : vector<[2]xf64> from vector<[4]xf64>
  // CHECK-DAG: %[[RES_0:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_0]] : vector<[2]xf64>, vector<[2]xf64>
  // CHECK-DAG: %[[RES_1:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_1]] : vector<[2]xf64>, vector<[2]xf64>
  // CHECK-DAG: %[[RES_2:.*]] = vector.outerproduct %[[LHS_1]], %[[RHS_0]] : vector<[2]xf64>, vector<[2]xf64>
  // CHECK-DAG: %[[RES_3:.*]] = vector.outerproduct %[[LHS_1]], %[[RHS_1]] : vector<[2]xf64>, vector<[2]xf64>
  // CHECK-DAG: %[[RES_4:.*]] = vector.outerproduct %[[LHS_2]], %[[RHS_0]] : vector<[2]xf64>, vector<[2]xf64>
  // CHECK-DAG: %[[RES_5:.*]] = vector.outerproduct %[[LHS_2]], %[[RHS_1]] : vector<[2]xf64>, vector<[2]xf64>
  // CHECK-DAG: %[[RES_6:.*]] = vector.outerproduct %[[LHS_3]], %[[RHS_0]] : vector<[2]xf64>, vector<[2]xf64>
  // CHECK-DAG: %[[RES_7:.*]] = vector.outerproduct %[[LHS_3]], %[[RHS_1]] : vector<[2]xf64>, vector<[2]xf64>
  // CHECK-NEXT: return %[[RES_0]], %[[RES_1]], %[[RES_2]], %[[RES_3]], %[[RES_4]], %[[RES_5]], %[[RES_6]], %[[RES_7]] : vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>
  %0 = vector.outerproduct %lhs, %rhs : vector<[8]xf64>, vector<[4]xf64>
  return %0 : vector<[8]x[4]xf64>
}

// -----

// CHECK-LABEL: @transfer_read_f32_scalable_8x8(
// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi32>)
// CHECK-SAME: -> (vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>)
func.func @transfer_read_f32_scalable_8x8(%src: memref<?x?xi32>) -> vector<[8]x[8]xi32>
{
  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
  // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
  // CHECK-DAG: %[[C0_I32:.*]] = arith.constant 0 : i32
  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
  // CHECK-DAG: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
  // CHECK-DAG: %[[TOP_LEFT:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C0]]], %[[C0_I32]] {in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
  // CHECK-DAG: %[[TOP_RIGHT:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C4_VSCALE]]], %[[C0_I32]] {in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
  // CHECK-DAG: %[[BOTTOM_LEFT:.*]] = vector.transfer_read %[[SRC]][%[[C4_VSCALE]], %[[C0]]], %[[C0_I32]] {in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
  // CHECK-DAG: %[[BOTTOM_RIGHT:.*]] = vector.transfer_read %[[SRC]][%[[C4_VSCALE]], %[[C4_VSCALE]]], %[[C0_I32]] {in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
  // CHECK-NEXT: return %[[TOP_LEFT]], %[[TOP_RIGHT]], %[[BOTTOM_LEFT]], %[[BOTTOM_RIGHT]] : vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0 : i32
  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi32>, vector<[8]x[8]xi32>
  return %0 : vector<[8]x[8]xi32>
}

// -----

// CHECK-LABEL: @transfer_read_i16_scalable_8x16_masked(
// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi16>,
// CHECK-SAME: %[[DIM0:.*]]: index,
// CHECK-SAME: %[[DIM1:.*]]: index)
// CHECK-SAME: -> (vector<[8]x[8]xi16>, vector<[8]x[8]xi16>)
func.func @transfer_read_i16_scalable_8x16_masked(%src: memref<?x?xi16>, %dim0: index, %dim1: index) -> vector<[8]x[16]xi16>
{
  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
  // CHECK-DAG: %[[MINUS_8:.*]] = arith.constant -8 : index
  // CHECK-DAG: %[[C0_I16:.*]] = arith.constant 0 : i16
  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
  // CHECK-DAG: %[[MINUS_8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[MINUS_8]] : index
  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
  // CHECK-DAG: %[[RIGHT_DIM_1:.*]] = arith.addi %[[DIM1]], %[[MINUS_8_VSCALE]] : index
  // CHECK-DAG: %[[LEFT_MASK:.*]] = vector.create_mask %[[DIM0]], %[[DIM1]] : vector<[8]x[8]xi1>
  // CHECK-DAG: %[[RIGHT_MASK:.*]] = vector.create_mask %[[DIM0]], %[[RIGHT_DIM_1]] : vector<[8]x[8]xi1>
  // CHECK-DAG: %[[LEFT:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C0]]], %[[C0_I16]], %[[LEFT_MASK]] {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
  // CHECK-DAG: %[[RIGHT:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C8_VSCALE]]], %[[C0_I16]], %[[RIGHT_MASK]] {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
  // CHECK-NEXT: return %[[LEFT]], %[[RIGHT]] : vector<[8]x[8]xi16>, vector<[8]x[8]xi16>
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0 : i16
  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[16]xi1>
  %0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[16]xi16>
  return %0 : vector<[8]x[16]xi16>
}

// -----

// CHECK-LABEL: @transfer_write_f16_scalable_16x8(
// CHECK-SAME: %[[DEST:.*]]: memref<?x?xf16>,
// CHECK-SAME: %[[TOP:.*]]: vector<[8]x[8]xf16>,
// CHECK-SAME: %[[BOTTOM:.*]]: vector<[8]x[8]xf16>)
func.func @transfer_write_f16_scalable_16x8(%dest: memref<?x?xf16>, %vec: vector<[16]x[8]xf16>)
{
  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
  // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C8_VSCALE]] step %[[C1]] {
  // CHECK-NEXT: %[[TOP_SLICE:.*]] = vector.extract %[[TOP]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16>
  // CHECK-NEXT: vector.transfer_write %[[TOP_SLICE]], %[[DEST]][%[[I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref<?x?xf16>
  // CHECK-NEXT: %[[BOTTOM_I:.*]] = arith.addi %[[C8_VSCALE]], %[[I]] : index
  // CHECK-NEXT: %[[BOTTOM_SLICE:.*]] = vector.extract %[[BOTTOM]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16>
  // CHECK-NEXT: vector.transfer_write %[[BOTTOM_SLICE]], %[[DEST]][%[[BOTTOM_I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref<?x?xf16>
  // CHECK-NEXT: }
  // CHECK-NEXT: return
  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[8]xf16>, memref<?x?xf16>
  return
}

// -----

/// This is already a legal type. It should be ignored.

// CHECK-LABEL: @transfer_write_i8_scalable_16x16_masked
func.func @transfer_write_i8_scalable_16x16_masked(%dest: memref<?x?xi8>, %vec: vector<[16]x[16]xi8>, %dim0: index, %dim1: index)
{
  // CHECK: vector.transfer_write {{.*}} : vector<[16]x[16]xi8>, memref<?x?xi8>
  %c0 = arith.constant 0 : index
  %mask = vector.create_mask %dim0, %dim0 : vector<[16]x[16]xi1>
  vector.transfer_write %vec, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
  return
}

// -----

// CHECK-LABEL: @transfer_write_f32_scalable_8x8_masked(
// CHECK-SAME: %[[DEST:[a-z0-9]+]]: memref<?x?xf32>,
// CHECK-SAME: %[[DIM_0:[a-z0-9]+]]: index,
// CHECK-SAME: %[[DIM_1:[a-z0-9]+]]: index,
// CHECK-SAME: %[[TILE_0:[a-z0-9]+]]: vector<[4]x[4]xf32>,
// CHECK-SAME: %[[TILE_1:[a-z0-9]+]]: vector<[4]x[4]xf32>,
// CHECK-SAME: %[[TILE_2:[a-z0-9]+]]: vector<[4]x[4]xf32>,
// CHECK-SAME: %[[TILE_3:[a-z0-9]+]]: vector<[4]x[4]xf32>)
func.func @transfer_write_f32_scalable_8x8_masked(%dest: memref<?x?xf32>, %dim0: index, %dim1: index, %vec: vector<[8]x[8]xf32>)
{
  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
  // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
  // CHECK-DAG: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
  // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<[8]x[8]xi1>
  // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] {
  // CHECK-NEXT: %[[UPPER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[I]]] : vector<[8]xi1> from vector<[8]x[8]xi1>
  // CHECK-NEXT: %[[TILE_0_SLICE_MASK:.*]] = vector.scalable.extract %[[UPPER_SLICE_MASK]][0] : vector<[4]xi1> from vector<[8]xi1>
  // CHECK-NEXT: %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
  // CHECK-NEXT: vector.transfer_write %[[TILE_0_SLICE]], %[[DEST]][%[[I]], %[[C0]]], %[[TILE_0_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
  // CHECK-NEXT: %[[TILE_1_SLICE_MASK:.*]] = vector.scalable.extract %[[UPPER_SLICE_MASK]][4] : vector<[4]xi1> from vector<[8]xi1>
  // CHECK-NEXT: %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
  // CHECK-NEXT: vector.transfer_write %[[TILE_1_SLICE]], %[[DEST]][%[[I]], %[[C4_VSCALE]]], %[[TILE_1_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
  // CHECK-NEXT: %[[LOWER_SLICE_I:.*]] = arith.addi %[[C4_VSCALE]], %[[I]] : index
  // CHECK-NEXT: %[[LOWER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[LOWER_SLICE_I]]] : vector<[8]xi1> from vector<[8]x[8]xi1>
  // CHECK-NEXT: %[[TILE_2_SLICE_MASK:.*]] = vector.scalable.extract %[[LOWER_SLICE_MASK]][0] : vector<[4]xi1> from vector<[8]xi1>
  // CHECK-NEXT: %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
  // CHECK-NEXT: vector.transfer_write %[[TILE_2_SLICE]], %[[DEST]][%[[LOWER_SLICE_I]], %[[C0]]], %[[TILE_2_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
  // CHECK-NEXT: %[[TILE_3_SLICE_MASK:.*]] = vector.scalable.extract %[[LOWER_SLICE_MASK]][4] : vector<[4]xi1> from vector<[8]xi1>
  // CHECK-NEXT: %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
  // CHECK-NEXT: vector.transfer_write %[[TILE_3_SLICE:.*]], %[[DEST]][%[[LOWER_SLICE_I]], %[[C4_VSCALE]]], %[[TILE_3_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
  // CHECK-NEXT: }
  %c0 = arith.constant 0 : index
  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
  vector.transfer_write %vec, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
  return
}

// -----

// Tensor semantics are not supported for the store loop lowering.

// CHECK-LABEL: @negative_transfer_write_f32_scalable_8x8_tensor
// CHECK-NOT: scf.for
func.func @negative_transfer_write_f32_scalable_8x8_tensor(%dest: tensor<?x?xf32>, %vec: vector<[8]x[8]xf32>)
{
  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf32>, tensor<?x?xf32>
  return
}

// -----

#transpose = affine_map<(d0, d1) -> (d1, d0)>

// Transposes are not supported for the store loop lowering.

// CHECK-LABEL: @negative_transfer_write_f32_scalable_8x8_tensor
// CHECK-NOT: scf.for
func.func @negative_transfer_write_f32_scalable_8x8_tensor(%dest: tensor<?x?xf32>, %dim0: index, %dim1: index, %vec: vector<[8]x[8]xf32>)
{
  %c0 = arith.constant 0 : index
  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
  vector.transfer_write %vec, %dest[%c0, %c0], %mask {permutation_map = #transpose, in_bounds = [true, true]} : vector<[8]x[8]xf32>, tensor<?x?xf32>
  return
}

// -----

// Masked writes where any dimension of the mask is > 16 are not supported for the store loop lowering.

// CHECK-LABEL: @negative_transfer_write_f32_scalable_32x32
// CHECK-NOT: scf.for
func.func @negative_transfer_write_f32_scalable_32x32(%dest: memref<?x?xf32>, %dim0: index, %dim1: index, %vec: vector<[32]x[32]xf32>)
{
  %c0 = arith.constant 0 : index
  %mask = vector.create_mask %dim0, %dim1 : vector<[32]x[32]xi1>
  vector.transfer_write %vec, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[32]x[32]xf32>, memref<?x?xf32>
  return
}

// -----

#transpose = affine_map<(d0, d1) -> (d1, d0)>

// CHECK-LABEL: @transpose_f32_scalable_4x16_via_read(
// CHECK-SAME: %[[SRC:.*]]: memref<?x?xf32>,
// CHECK-SAME: %[[DEST:.*]]: memref<?x?xf32>)
func.func @transpose_f32_scalable_4x16_via_read(%src: memref<?x?xf32>, %dest: memref<?x?xf32>)
{
  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
  // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
  // CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
  // CHECK-DAG: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
  // CHECK-DAG: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
  // CHECK-DAG: %[[C12_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C12]] : index
  // CHECK-DAG: %[[TILE_0:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
  // CHECK-DAG: %[[TILE_1:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C4_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
  // CHECK-DAG: %[[TILE_2:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C8_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
  // CHECK-DAG: %[[TILE_3:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C12_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
  // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] {
  // CHECK-NEXT: %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
  // CHECK-NEXT: vector.transfer_write %[[TILE_0_SLICE]], %[[DEST]][%[[I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
  // CHECK-NEXT: %[[TILE_1_I:.*]] = arith.addi %[[C4_VSCALE]], %[[I]] : index
  // CHECK-NEXT: %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
  // CHECK-NEXT: vector.transfer_write %[[TILE_1_SLICE]], %[[DEST]][%[[TILE_1_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
  // CHECK-NEXT: %[[TILE_2_I:.*]] = arith.addi %[[C8_VSCALE]], %[[I]] : index
  // CHECK-NEXT: %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
  // CHECK-NEXT: vector.transfer_write %[[TILE_2_SLICE]], %[[DEST]][%[[TILE_2_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
  // CHECK-NEXT: %[[TILE_3_I:.*]] = arith.addi %[[C12_VSCALE]], %[[I]] : index
  // CHECK-NEXT: %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
  // CHECK-NEXT: vector.transfer_write %[[TILE_3_SLICE]], %[[DEST]][%[[TILE_3_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
  // CHECK-NEXT: }
  // CHECK-NEXT: return
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0.0 : f32
  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = #transpose, in_bounds = [true, true]} : memref<?x?xf32>, vector<[16]x[4]xf32>
  vector.transfer_write %0, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[4]xf32>, memref<?x?xf32>
  return
}

// -----

#transpose = affine_map<(d0, d1) -> (d1, d0)>

// CHECK-LABEL: @transpose_f32_scalable_4x16_via_write(
// CHECK-SAME: %[[SRC:.*]]: memref<?x?xf32>,
// CHECK-SAME: %[[DEST:.*]]: memref<?x?xf32>)
func.func @transpose_f32_scalable_4x16_via_write(%src: memref<?x?xf32>, %dest: memref<?x?xf32>)
{
  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
  // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
  // CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
  // CHECK-DAG: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
  // CHECK-DAG: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
  // CHECK-DAG: %[[C12_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C12]] : index
  // CHECK-DAG: %[[TILE_0:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
  // CHECK-DAG: %[[TILE_1:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C4_VSCALE]]], %[[PAD]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
  // CHECK-DAG: %[[TILE_2:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C8_VSCALE]]], %[[PAD]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
  // CHECK-DAG: %[[TILE_3:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C12_VSCALE]]], %[[PAD]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
  // CHECK-DAG: vector.transfer_write %[[TILE_0]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true], permutation_map = #{{.*}}} : vector<[4]x[4]xf32>, memref<?x?xf32>
  // CHECK-DAG: vector.transfer_write %[[TILE_1]], %[[DEST]][%[[C4_VSCALE]], %[[C0]]] {in_bounds = [true, true], permutation_map = #{{.*}}} : vector<[4]x[4]xf32>, memref<?x?xf32>
  // CHECK-DAG: vector.transfer_write %[[TILE_2]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true], permutation_map = #{{.*}}} : vector<[4]x[4]xf32>, memref<?x?xf32>
  // CHECK-DAG: vector.transfer_write %[[TILE_3]], %[[DEST]][%[[C12_VSCALE]], %[[C0]]] {in_bounds = [true, true], permutation_map = #{{.*}}} : vector<[4]x[4]xf32>, memref<?x?xf32>
  // CHECK-NEXT: return
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0.0 : f32
  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[16]xf32>
  vector.transfer_write %0, %dest[%c0, %c0] {permutation_map = #transpose, in_bounds = [true, true]} : vector<[4]x[16]xf32>, memref<?x?xf32>
  return
}

// -----

// CHECK-LABEL: @extract_from_vector_create_mask_non_constant_dim(
// CHECK-SAME: %[[DIM0:[a-z0-9]+]]: index,
// CHECK-SAME: %[[DIM1:[a-z0-9]+]]: index,
// CHECK-SAME: %[[DIM2:[a-z0-9]+]]: index)
func.func @extract_from_vector_create_mask_non_constant_dim(%dim0: index, %dim1: index, %dim2: index) -> vector<[4]x[4]xi1> {
  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
  // CHECK-NEXT: %[[DIM0_CMP:.*]] = arith.cmpi sgt, %[[DIM0]], %[[C2]] : index
  // CHECK-NEXT: %[[NEW_DIM0:.*]] = arith.select %[[DIM0_CMP]], %[[DIM1]], %[[C0]] : index
  // CHECK-NEXT: %[[EXTRACT:.*]] = vector.create_mask %[[NEW_DIM0]], %[[DIM2]] : vector<[4]x[4]xi1>
  // CHECK-NEXT: return %[[EXTRACT]]
  %mask = vector.create_mask %dim0, %dim1, %dim2 : vector<4x[4]x[4]xi1>
  %extract = vector.extract %mask[2] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
  return %extract : vector<[4]x[4]xi1>
}

// -----

// CHECK-LABEL: @non_constant_extract_from_vector_create_mask_non_constant(
// CHECK-SAME: %[[INDEX:[a-z0-9]+]]: index,
// CHECK-SAME: %[[DIM0:[a-z0-9]+]]: index,
// CHECK-SAME: %[[DIM1:[a-z0-9]+]]: index,
// CHECK-SAME: %[[DIM2:[a-z0-9]+]]: index)
func.func @non_constant_extract_from_vector_create_mask_non_constant(%index: index, %dim0: index, %dim1: index, %dim2: index) -> vector<[4]x[4]xi1> {
  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
  // CHECK-NEXT: %[[DIM0_CMP:.*]] = arith.cmpi slt, %[[INDEX]], %[[DIM0]] : index
  // CHECK-NEXT: %[[NEW_DIM0:.*]] = arith.select %[[DIM0_CMP]], %[[DIM1]], %[[C0]] : index
  // CHECK-NEXT: %[[EXTRACT:.*]] = vector.create_mask %[[NEW_DIM0]], %[[DIM2]] : vector<[4]x[4]xi1>
  // CHECK-NEXT: return %[[EXTRACT]]
  %mask = vector.create_mask %dim0, %dim1, %dim2 : vector<4x[4]x[4]xi1>
  %extract = vector.extract %mask[%index] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
  return %extract : vector<[4]x[4]xi1>
}

// -----

// CHECK-LABEL: @lift_illegal_transpose_to_memory(
// CHECK-SAME: %[[INDEXA:[a-z0-9]+]]: index,
// CHECK-SAME: %[[INDEXB:[a-z0-9]+]]: index,
// CHECK-SAME: %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>)
func.func @lift_illegal_transpose_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
  // CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
  // CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xf32> to memref<?x4xf32, strided<[?, 1], offset: ?>>
  // CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
  // CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
  // CHECK-NEXT: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]][%c0, %c0], %[[C0_F32]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
  // CHECK-NEXT: return %[[LEGAL_READ]]
  %pad = arith.constant 0.0 : f32
  %illegalRead = vector.transfer_read %memref[%a, %b], %pad : memref<?x?xf32>, vector<[8]x4xf32>
  %legalType = vector.transpose %illegalRead, [1, 0] : vector<[8]x4xf32> to vector<4x[8]xf32>
  return %legalType : vector<4x[8]xf32>
}

// -----

// CHECK-LABEL: @lift_illegal_transpose_to_memory_with_mask(
// CHECK-SAME: %[[DIM0:[a-z0-9]+]]: index,
// CHECK-SAME: %[[DIM1:[a-z0-9]+]]: index,
// CHECK-SAME: %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>
func.func @lift_illegal_transpose_to_memory_with_mask(%dim0: index, %dim1: index, %memref: memref<?x?xf32>, %a: index, %b: index) -> vector<4x[8]xf32> {
  // CHECK-DAG: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]]
  // CHECK-DAG: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]]
  // CHECK-DAG: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]]
  // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[DIM1]], %[[DIM0]] : vector<4x[8]xi1>
  // CHECK: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]]
  // CHECK-SAME: %[[MASK]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
  // CHECK-NEXT: return %[[LEGAL_READ]]
  %pad = arith.constant 0.0 : f32
  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x4xi1>
  %illegalRead = vector.transfer_read %memref[%a, %b], %pad, %mask : memref<?x?xf32>, vector<[8]x4xf32>
  %legalType = vector.transpose %illegalRead, [1, 0] : vector<[8]x4xf32> to vector<4x[8]xf32>
  return %legalType : vector<4x[8]xf32>
}

// -----

// CHECK-LABEL: @lift_illegal_transpose_to_memory_with_arith_extop(
// CHECK-SAME: %[[MEMREF:[a-z0-9]+]]: memref<?x?xi8>
func.func @lift_illegal_transpose_to_memory_with_arith_extop(%a: index, %b: index, %memref: memref<?x?xi8>) -> vector<4x[8]xi32> {
  // CHECK-DAG: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]]
  // CHECK-DAG: 
%[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] 457 // CHECK-DAG: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] 458 // CHECK: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]] 459 // CHECK-NEXT: %[[EXT_TYPE:.*]] = arith.extsi %[[LEGAL_READ]] : vector<4x[8]xi8> to vector<4x[8]xi32> 460 // CHECK-NEXT: return %[[EXT_TYPE]] 461 %pad = arith.constant 0 : i8 462 %illegalRead = vector.transfer_read %memref[%a, %b], %pad : memref<?x?xi8>, vector<[8]x4xi8> 463 %extRead = arith.extsi %illegalRead : vector<[8]x4xi8> to vector<[8]x4xi32> 464 %legalType = vector.transpose %extRead, [1, 0] : vector<[8]x4xi32> to vector<4x[8]xi32> 465 return %legalType : vector<4x[8]xi32> 466} 467 468// ----- 469 470// CHECK-LABEL: @lift_illegal_transpose_to_memory_with_in_bounds_attr 471func.func @lift_illegal_transpose_to_memory_with_in_bounds_attr(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> { 472 // CHECK: vector.transfer_read 473 // CHECK-SAME: in_bounds = [true, false] 474 // CHECK-NOT: in_bounds = [false, true] 475 %pad = arith.constant 0.0 : f32 476 %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[8]x4xf32> 477 %legalType = vector.transpose %illegalRead, [1, 0] : vector<[8]x4xf32> to vector<4x[8]xf32> 478 return %legalType : vector<4x[8]xf32> 479} 480 481// ----- 482 483// The pass should do nothing (and not crash). 
484// CHECK-LABEL: @illegal_transpose_no_defining_source_op 485func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32> 486{ 487 // CHECK: vector.transpose 488 %0 = vector.transpose %vec, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32> 489 return %0 : vector<1x[4]xf32> 490} 491 492// ----- 493 494// CHECK-LABEL: @illegal_shape_cast_to_transpose_2d( 495// CHECK-SAME: %[[VEC:.*]]: vector<[4]x1xf32>) 496func.func @illegal_shape_cast_to_transpose_2d(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32> { 497 // CHECK: vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32> 498 %0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<1x[4]xf32> 499 return %0 : vector<1x[4]xf32> 500} 501 502// ----- 503 504// CHECK-LABEL: @illegal_shape_cast_to_transpose_1d( 505// CHECK-SAME: %[[VEC:.*]]: vector<[4]x1xf32>) 506func.func @illegal_shape_cast_to_transpose_1d(%vec: vector<[4]x1xf32>) -> vector<[4]xf32> { 507 // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32> 508 // CHECK: vector.shape_cast %[[TRANSPOSE]] : vector<1x[4]xf32> to vector<[4]xf32> 509 %0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<[4]xf32> 510 return %0 : vector<[4]xf32> 511} 512 513// ----- 514 515// CHECK-LABEL: @lift_illegal_2d_shape_cast_to_memory 516func.func @lift_illegal_2d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<1x[4]xf32> { 517 // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32> 518 // CHECK-NOT: vector.shape_cast 519 %pad = arith.constant 0.0 : f32 520 %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32> 521 %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<1x[4]xf32> 522 return %cast : vector<1x[4]xf32> 523} 524 525// ----- 526 527// CHECK-LABEL: @lift_illegal_1d_shape_cast_to_memory 528func.func 
@lift_illegal_1d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<[4]xf32> { 529 // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32> 530 // CHECK-NOT: vector.shape_cast {{.*}} : vector<[4]x1xf32> to vector<[4]xf32> 531 %pad = arith.constant 0.0 : f32 532 %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32> 533 %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<[4]xf32> 534 return %cast : vector<[4]xf32> 535} 536 537// ----- 538 539// CHECK-LABEL: @multi_tile_splat 540func.func @multi_tile_splat() -> vector<[8]x[8]xi32> 541{ 542 // CHECK: %[[SPLAT:.*]] = arith.constant dense<42> : vector<[4]x[4]xi32> 543 // CHECK-NEXT: return %[[SPLAT]], %[[SPLAT]], %[[SPLAT]], %[[SPLAT]] : vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32> 544 %0 = arith.constant dense<42> : vector<[8]x[8]xi32> 545 return %0 : vector<[8]x[8]xi32> 546} 547 548// ----- 549 550// CHECK: #[[$TRANSPOSE_MAP_0:.*]] = affine_map<(d0, d1) -> (d1, d0)> 551 552// CHECK-LABEL: @transpose_store_scalable_via_za( 553// CHECK-SAME: %[[VEC:.*]]: vector<2x[4]xf32> 554// CHECK-SAME: %[[DEST:.*]]: memref<?x?xf32>, 555// CHECK-SAME: %[[I:.*]]: index, 556// CHECK-SAME: %[[J:.*]]: index) 557func.func @transpose_store_scalable_via_za(%vec: vector<2x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) { 558 // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index 559 // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index 560 // CHECK-NEXT: %[[INIT:.*]] = arm_sme.get_tile : vector<[4]x[4]xf32> 561 // CHECK-NEXT: %[[V0:.*]] = vector.extract %[[VEC]][0] : vector<[4]xf32> from vector<2x[4]xf32> 562 // CHECK-NEXT: %[[R0:.*]] = vector.insert %[[V0]], %[[INIT]] [0] : vector<[4]xf32> into vector<[4]x[4]xf32> 563 // CHECK-NEXT: %[[V1:.*]] = vector.extract %[[VEC]][1] : vector<[4]xf32> from vector<2x[4]xf32> 564 // CHECK-NEXT: %[[RES:.*]] = vector.insert 
%[[V1]], %[[R0]] [1] : vector<[4]xf32> into vector<[4]x[4]xf32> 565 // CHECK-NEXT: %[[VSCALE:.*]] = vector.vscale 566 // CHECK-NEXT: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index 567 // CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[C4_VSCALE]], %[[C2]] : vector<[4]x[4]xi1> 568 // CHECK-NEXT: vector.transfer_write %[[RES]], %[[DEST]][%[[I]], %[[J]]], %[[MASK]] {in_bounds = [true, true], permutation_map = #[[$TRANSPOSE_MAP_0]]} : vector<[4]x[4]xf32>, memref<?x?xf32> 569 %tr = vector.transpose %vec, [1, 0] : vector<2x[4]xf32> to vector<[4]x2xf32> 570 vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[4]x2xf32>, memref<?x?xf32> 571 return 572} 573 574// ----- 575 576// CHECK-LABEL: @transpose_store_scalable_via_za_masked( 577// CHECK-SAME: %[[A:[a-z0-9]+]]: index, 578// CHECK-SAME: %[[B:[a-z0-9]+]]: index) 579func.func @transpose_store_scalable_via_za_masked(%vec: vector<2x[4]xf32>, %dest: memref<?x?xf32>, %a: index, %b: index) { 580 // CHECK: %[[C2:.*]] = arith.constant 2 : index 581 // CHECK: %[[MIN:.*]] = index.mins %[[B]], %[[C2]] 582 // CHECK: %[[MASK:.*]] = vector.create_mask %[[A]], %[[MIN]] : vector<[4]x[4]xi1> 583 // CHECK: vector.transfer_write {{.*}} %[[MASK]] {{.*}} : vector<[4]x[4]xf32>, memref<?x?xf32> 584 %c0 = arith.constant 0 : index 585 %mask = vector.create_mask %a, %b : vector<[4]x2xi1> 586 %tr = vector.transpose %vec, [1, 0] : vector<2x[4]xf32> to vector<[4]x2xf32> 587 vector.transfer_write %tr, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[4]x2xf32>, memref<?x?xf32> 588 return 589} 590 591// ----- 592 593// CHECK-LABEL: @transpose_store_scalable_via_za_multi_tile( 594// CHECK-SAME: %[[VEC:.*]]: vector<8x[4]xf32> 595// CHECK-SAME: %[[DEST:.*]]: memref<?x?xf32>, 596// CHECK-SAME: %[[I:.*]]: index, 597// CHECK-SAME: %[[J:.*]]: index) 598func.func @transpose_store_scalable_via_za_multi_tile(%vec: vector<8x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) { 599 // CHECK: %[[C4:.*]] = 
arith.constant 4 : index 600 601 // <skip 3x other extract+insert chain> 602 // CHECK: %[[V3:.*]] = vector.extract %[[VEC]][3] : vector<[4]xf32> from vector<8x[4]xf32> 603 // CHECK: %[[TILE_0:.*]] = vector.insert %[[V3]], %{{.*}} [3] : vector<[4]xf32> into vector<[4]x[4]xf32> 604 // CHECK: %[[VSCALE:.*]] = vector.vscale 605 // CHECK: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index 606 // CHECK: %[[MASK:.*]] = vector.create_mask %c4_vscale, %c4 : vector<[4]x[4]xi1> 607 // CHECK: vector.transfer_write %[[TILE_0]], %[[DEST]][%[[I]], %[[J]]], %[[MASK]] {{.*}} : vector<[4]x[4]xf32>, memref<?x?xf32> 608 609 // <skip 3x other extract+insert chain> 610 // CHECK: %[[V7:.*]] = vector.extract %arg0[7] : vector<[4]xf32> from vector<8x[4]xf32> 611 // CHECK: %[[TILE_1:.*]] = vector.insert %[[V7]], %{{.*}} [3] : vector<[4]xf32> into vector<[4]x[4]xf32> 612 // CHECK: %[[J_OFFSET:.*]] = arith.addi %[[J]], %[[C4]] : index 613 // CHECK: vector.transfer_write %[[TILE_1]], %[[DEST]][%[[I]], %[[J_OFFSET]]], %[[MASK]] {{.*}} : vector<[4]x[4]xf32>, memref<?x?xf32> 614 %tr = vector.transpose %vec, [1, 0] : vector<8x[4]xf32> to vector<[4]x8xf32> 615 vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[4]x8xf32>, memref<?x?xf32> 616 return 617} 618 619// ----- 620 621// CHECK-LABEL: @transpose_store_scalable_via_za_multi_tile_wide 622func.func @transpose_store_scalable_via_za_multi_tile_wide(%vec: vector<2x[8]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) { 623 // <check extracts from lower 4 x vscale of %vec> 624 // CHECK: vector.scalable.extract 625 // CHECK: %[[ROW_2_LOWER:.*]] = vector.scalable.extract %{{.*}}[0] : vector<[4]xf32> from vector<[8]xf32> 626 // CHECK: %[[TILE_0:.*]] = vector.insert %[[ROW_2_LOWER]], %{{.*}}[1] : vector<[4]xf32> into vector<[4]x[4]xf32> 627 // CHECK: vector.transfer_write %[[TILE_0]], %{{.*}}[%[[I:.[a-z0-9]+]], %[[J:[a-z0-9]+]]] 628 629 // <check extracts from upper 4 x vscale of %vec> 630 // CHECK: 
vector.scalable.extract 631 // CHECK: %[[ROW_2_UPPER:.*]] = vector.scalable.extract %{{.*}}[4] : vector<[4]xf32> from vector<[8]xf32> 632 // CHECK: %[[TILE_0:.*]] = vector.insert %[[ROW_2_UPPER]], %{{.*}}[1] : vector<[4]xf32> into vector<[4]x[4]xf32> 633 // CHECK: %[[I_OFFSET:.*]] = arith.addi %c4_vscale, %[[I]] : index 634 // CHECK: vector.transfer_write %[[TILE_0]], %{{.*}}[%[[I_OFFSET]], %[[J]]] 635 %tr = vector.transpose %vec, [1, 0] : vector<2x[8]xf32> to vector<[8]x2xf32> 636 vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[8]x2xf32>, memref<?x?xf32> 637 return 638} 639 640// ----- 641 642// CHECK-LABEL: @negative_transpose_store_scalable_via_za__bad_source_shape 643// CHECK-NOT: arm_sme.get_tile 644func.func @negative_transpose_store_scalable_via_za__bad_source_shape(%vec: vector<2x[7]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) { 645 %tr = vector.transpose %vec, [1, 0] : vector<2x[7]xf32> to vector<[7]x2xf32> 646 vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[7]x2xf32>, memref<?x?xf32> 647 return 648} 649 650// ----- 651 652// From: https://github.com/llvm/llvm-project/issues/118449. 653// Check -arm-sme-vector-legalization does not crash when it encounters a `vector.mask` that 654// does not contain a maskable op. 655func.func @vector_mask_without_maskable_op(%mask: vector<16x2xi1>, %vec: vector<16x16xf32>) -> vector<16x16xf32> { 656 %0 = vector.mask %mask { vector.yield %vec : vector<16x16xf32> } : vector<16x2xi1> -> vector<16x16xf32> 657 return %0 : vector<16x16xf32> 658} 659