// RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s

// Negative tests: every function below is expected to come out of the pass
// with its original vector op still present (see the CHECK / CHECK-NOT pairs).
// "prevent.dce" is an unregistered op that only exists to keep otherwise
// unused results alive.

//===----------------------------------------------------------------------===//
// vector.transfer_read
//===----------------------------------------------------------------------===//

// The result type (vector<[4]x[4]xf64>) is the only property under test here:
// in_bounds is all-true so that this case is not also rejected for being
// out-of-bounds, which @transfer_read_2d__out_of_bounds covers on its own.

// CHECK-LABEL: @transfer_read_2d__bad_type
// CHECK-NOT: arm_sme.tile_load
// CHECK: vector.transfer_read
func.func @transfer_read_2d__bad_type(%src : memref<?x?xf64>) {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0.0 : f64
  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[4]x[4]xf64>
  "prevent.dce"(%0) : (vector<[4]x[4]xf64>) -> ()
  return
}

// -----

// The source is a tensor rather than a memref.

// CHECK-LABEL: @transfer_read_2d__non_memref_type
// CHECK-NOT: arm_sme.tile_load
// CHECK: vector.transfer_read
func.func @transfer_read_2d__non_memref_type(%src : tensor<?x?xf64>) {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0.0 : f64
  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : tensor<?x?xf64>, vector<[2]x[2]xf64>
  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
  return
}

// -----

// The read produces a 1-D vector from a 2-D memref.

// CHECK-LABEL: @transfer_read_2d__bad_transfer_rank
// CHECK-NOT: arm_sme.tile_load
// CHECK: vector.transfer_read
func.func @transfer_read_2d__bad_transfer_rank(%src : memref<?x?xf64>) {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0.0 : f64
  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]} : memref<?x?xf64>, vector<[2]xf64>
  "prevent.dce"(%0) : (vector<[2]xf64>) -> ()
  return
}

// -----

// The permutation map is (d0, d1) -> (d0, 0), i.e. not a transpose.

// CHECK-LABEL: @transfer_read_2d__non_transpose
// CHECK-NOT: arm_sme.tile_load
// CHECK: vector.transfer_read
func.func @transfer_read_2d__non_transpose(%src : memref<?x?xf64>) {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0.0 : f64
  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0, 0)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
  return
}

// -----

// Both dimensions are marked out-of-bounds.

// CHECK-LABEL: @transfer_read_2d__out_of_bounds
// CHECK-NOT: arm_sme.tile_load
// CHECK: vector.transfer_read
func.func @transfer_read_2d__out_of_bounds(%src : memref<?x?xf64>) {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0.0 : f64
  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [false, false]} : memref<?x?xf64>, vector<[2]x[2]xf64>
  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
  return
}

//===----------------------------------------------------------------------===//
// vector.transfer_write
//===----------------------------------------------------------------------===//

// -----

// The following tests check that 'vector.transfer_write' is only lowered for
// vector types of correct rank, shape, element size and number of scalable
// dims; in every other case the op must be left untouched.

// Element type is i4.

// CHECK-LABEL: @transfer_write_2d_zero__bad_type
// CHECK: vector.transfer_write
// CHECK-NOT: arm_sme.tile_store
func.func @transfer_write_2d_zero__bad_type(%arg0 : memref<?x?xi4>) {
  %c0 = arith.constant 0 : index
  %cst = arith.constant dense<0> : vector<[16]x[16]xi4>
  vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi4>, memref<?x?xi4>
  return
}

// -----

// i8 elements with an [8]x[8] shape (the other i8 cases here use [16]x[16]).

// CHECK-LABEL: @transfer_write_2d_zero__bad_shape
// CHECK: vector.transfer_write
// CHECK-NOT: arm_sme.tile_store
func.func @transfer_write_2d_zero__bad_shape(%arg0 : memref<?x?xi8>) {
  %c0 = arith.constant 0 : index
  %cst = arith.constant dense<0> : vector<[8]x[8]xi8>
  vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xi8>, memref<?x?xi8>
  return
}

// -----

// 3-D vector.

// CHECK-LABEL: @transfer_write_2d_zero__bad_rank
// CHECK: vector.transfer_write
// CHECK-NOT: arm_sme.tile_store
func.func @transfer_write_2d_zero__bad_rank(%arg0 : memref<?x?x?xi8>) {
  %c0 = arith.constant 0 : index
  %cst = arith.constant dense<0> : vector<[16]x[16]x[16]xi8>
  vector.transfer_write %cst, %arg0[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<[16]x[16]x[16]xi8>, memref<?x?x?xi8>
  return
}

// -----

// The destination is a tensor rather than a memref.

// CHECK-LABEL: @transfer_write_2d_zero__non_memref_type
// CHECK: vector.transfer_write
// CHECK-NOT: arm_sme.tile_store
func.func @transfer_write_2d_zero__non_memref_type(%arg0 : tensor<?x?xi8>) -> tensor<?x?xi8> {
  %c0 = arith.constant 0 : index
  %cst = arith.constant dense<0> : vector<[16]x[16]xi8>
  %0 = vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, tensor<?x?xi8>
  return %0 : tensor<?x?xi8>
}

// -----

// Fixed-size (non-scalable) vector.

// CHECK-LABEL: @transfer_write_2d__fixed
// CHECK: vector.transfer_write
// CHECK-NOT: arm_sme.tile_store
func.func @transfer_write_2d__fixed(%vector : vector<16x16xi8>, %dest : memref<?x?xi8>) {
  %c0 = arith.constant 0 : index
  vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi8>, memref<?x?xi8>
  return
}

// -----

// Both dimensions are marked out-of-bounds.

// CHECK-LABEL: @transfer_write_2d__out_of_bounds
// CHECK: vector.transfer_write
// CHECK-NOT: arm_sme.tile_store
func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) {
  %c0 = arith.constant 0 : index
  vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [false, false]} : vector<[4]x[4]xf32>, memref<?x?xf32>
  return
}

// -----

// 1-D slice written with permutation map (d0, d1) -> (d0). The CHECK-NOT is
// repeated on both sides of the positive match so that the whole function
// body stays covered.

// CHECK-LABEL: func.func @transfer_write_slice_unsupported_permutation
// CHECK-NOT: arm_sme.store_tile_slice
// CHECK: vector.transfer_write
// CHECK-NOT: arm_sme.store_tile_slice
func.func @transfer_write_slice_unsupported_permutation(%vector: vector<[4]x[4]xf32>, %dest : memref<?x?xf32>, %slice_index: index) {
  %c0 = arith.constant 0 : index
  %slice = vector.extract %vector[%slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32>
  vector.transfer_write %slice, %dest[%slice_index, %c0] { permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true] }: vector<[4]xf32>, memref<?x?xf32>
  return
}


//===----------------------------------------------------------------------===//
// vector.outerproduct
//===----------------------------------------------------------------------===//

// -----

// AXPY form: the right-hand side is a scalar.

// CHECK-LABEL: @vector_outerproduct_unsupported_axpy
// CHECK-NOT: arm_sme.outerproduct
// CHECK: vector.outerproduct
func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f64, %acc : vector<[2]xf64>) -> vector<[2]xf64> {
  %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, f64
  return %0 : vector<[2]xf64>
}

// -----

// Combining kind is <mul>.

// CHECK-LABEL: @vector_outerproduct_unsupported_kind
// CHECK-NOT: arm_sme.outerproduct
// CHECK: vector.outerproduct
func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) {
  %acc = arm_sme.get_tile : vector<[2]x[2]xf64>
  %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, vector<[2]xf64>
  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
  // Explicit terminator, so the body does not end on an unregistered op.
  return
}

// -----

// The mask is a plain function argument.

// CHECK-LABEL: @vector_outerproduct_unknown_mask
// CHECK-NOT: arm_sme.outerproduct
// CHECK: vector.outerproduct
func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %mask : vector<[4]x[4]xi1>) {
  %acc = arm_sme.get_tile : vector<[4]x[4]xf32>
  %0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
  "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
  // Explicit terminator, so the body does not end on an unregistered op.
  return
}

//===----------------------------------------------------------------------===//
// vector.extract --> arm_sve.psel
//===----------------------------------------------------------------------===//

// -----

/// Not SVE predicate-sized.

// CHECK-LABEL: @negative_vector_extract_to_psel_0
func.func @negative_vector_extract_to_psel_0(%a: index, %b: index, %index: index) -> vector<[32]xi1>
{
  // CHECK-NOT: arm_sve.psel
  %mask = vector.create_mask %a, %b : vector<[4]x[32]xi1>
  %slice = vector.extract %mask[%index] : vector<[32]xi1> from vector<[4]x[32]xi1>
  return %slice : vector<[32]xi1>
}

// -----

/// Source not 2-D scalable mask.

// CHECK-LABEL: @negative_vector_extract_to_psel_1
func.func @negative_vector_extract_to_psel_1(%a: index, %b: index, %index: index) -> vector<[8]xi1>
{
  // CHECK-NOT: arm_sve.psel
  %mask = vector.create_mask %a, %b : vector<4x[8]xi1>
  %slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<4x[8]xi1>
  return %slice : vector<[8]xi1>
}

// -----

/// Source not vector.create_mask.

// CHECK-LABEL: @negative_vector_extract_to_psel_2
func.func @negative_vector_extract_to_psel_2(%mask: vector<[4]x[8]xi1>, %index: index) -> vector<[8]xi1> {
  // %mask arrives as a function argument; no vector.create_mask defines it.
  // CHECK-NOT: arm_sve.psel
  %row = vector.extract %mask[%index] : vector<[8]xi1> from vector<[4]x[8]xi1>
  return %row : vector<[8]xi1>
}

// -----

/// The extract yields a single i1 element rather than a 1-D row of the mask.

// CHECK-LABEL: @negative_vector_extract_to_psel_3
func.func @negative_vector_extract_to_psel_3(%a: index, %b: index, %index: index) -> i1 {
  // CHECK-NOT: arm_sve.psel
  %mask2d = vector.create_mask %a, %b : vector<[4]x[8]xi1>
  %bit = vector.extract %mask2d[2, %index] : i1 from vector<[4]x[8]xi1>
  return %bit : i1
}