// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf,expand-strided-metadata,lower-affine,convert-arith-to-llvm,convert-scf-to-cf),convert-vector-to-llvm,finalize-memref-to-llvm,convert-func-to-llvm,convert-arith-to-llvm,convert-cf-to-llvm,reconcile-unrealized-casts)" | \
// RUN: mlir-runner -e entry -entry-point-result=void \
// RUN:   -shared-libs=%mlir_c_runner_utils | \
// RUN: FileCheck %s

// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf{full-unroll=true},expand-strided-metadata,lower-affine,convert-arith-to-llvm,convert-scf-to-cf),convert-vector-to-llvm,finalize-memref-to-llvm,convert-func-to-llvm,convert-arith-to-llvm,convert-cf-to-llvm,reconcile-unrealized-casts)" | \
// RUN: mlir-runner -e entry -entry-point-result=void \
// RUN:   -shared-libs=%mlir_c_runner_utils | \
// RUN: FileCheck %s

// Test for special cases of 1D vector transfer ops.

memref.global "private" @gv : memref<5x6xf32> =
    dense<[[0. , 1. , 2. , 3. , 4. , 5. ],
           [10., 11., 12., 13., 14., 15.],
           [20., 21., 22., 23., 24., 25.],
           [30., 31., 32., 33., 34., 35.],
           [40., 41., 42., 43., 44., 45.]]>

// Non-contiguous, strided load.
func.func @transfer_read_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
  %fm42 = arith.constant -42.0: f32
  %f = vector.transfer_read %A[%base1, %base2], %fm42
      {permutation_map = affine_map<(d0, d1) -> (d0)>}
      : memref<?x?xf32>, vector<9xf32>
  vector.print %f: vector<9xf32>
  return
}

// Vector load with unit stride only on last dim.
func.func @transfer_read_1d_unit_stride(%A : memref<?x?xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  %c4 = arith.constant 4 : index
  %c5 = arith.constant 5 : index
  %c6 = arith.constant 6 : index
  %fm42 = arith.constant -42.0: f32
  scf.for %arg2 = %c1 to %c5 step %c2 {
    scf.for %arg3 = %c0 to %c6 step %c3 {
      %0 = memref.subview %A[%arg2, %arg3] [1, 2] [1, 1]
          : memref<?x?xf32> to memref<1x2xf32, strided<[?, 1], offset: ?>>
      %1 = vector.transfer_read %0[%c0, %c0], %fm42 {in_bounds=[true]}
          : memref<1x2xf32, strided<[?, 1], offset: ?>>, vector<2xf32>
      vector.print %1 : vector<2xf32>
    }
  }
  return
}

// Vector load with unit stride only on last dim. Strides are not static, so
// codegen must go through VectorToSCF 1D lowering.
func.func @transfer_read_1d_non_static_unit_stride(%A : memref<?x?xf32>) {
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c4 = arith.constant 4 : index
  %c6 = arith.constant 6 : index
  %fm42 = arith.constant -42.0: f32
  %1 = memref.reinterpret_cast %A to offset: [%c6], sizes: [%c4, %c6], strides: [%c6, %c1]
      : memref<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
  %2 = vector.transfer_read %1[%c2, %c1], %fm42 {in_bounds=[true]}
      : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4xf32>
  vector.print %2 : vector<4xf32>
  return
}

// Vector load where last dim has non-unit stride.
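// The memref below is reinterpreted with a stride of 2 on the innermost
// dimension, so the read touches every other element of a row and cannot be
// lowered to a contiguous 1D vector load.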
func.func @transfer_read_1d_non_unit_stride(%A : memref<?x?xf32>) {
  %B = memref.reinterpret_cast %A to offset: [0], sizes: [4, 3], strides: [6, 2]
      : memref<?x?xf32> to memref<4x3xf32, strided<[6, 2]>>
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %fm42 = arith.constant -42.0: f32
  %vec = vector.transfer_read %B[%c2, %c1], %fm42 {in_bounds=[false]} : memref<4x3xf32, strided<[6, 2]>>, vector<3xf32>
  vector.print %vec : vector<3xf32>
  return
}

// Broadcast.
func.func @transfer_read_1d_broadcast(
    %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
  %fm42 = arith.constant -42.0: f32
  %f = vector.transfer_read %A[%base1, %base2], %fm42
      {permutation_map = affine_map<(d0, d1) -> (0)>}
      : memref<?x?xf32>, vector<9xf32>
  vector.print %f: vector<9xf32>
  return
}

// Non-contiguous, strided load. All accesses are in-bounds.
func.func @transfer_read_1d_in_bounds(
    %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
  %fm42 = arith.constant -42.0: f32
  %f = vector.transfer_read %A[%base1, %base2], %fm42
      {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
      : memref<?x?xf32>, vector<3xf32>
  vector.print %f: vector<3xf32>
  return
}

// Non-contiguous, strided, masked load.
func.func @transfer_read_1d_mask(
    %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
  %fm42 = arith.constant -42.0: f32
  %mask = arith.constant dense<[1, 0, 1, 0, 1, 1, 1, 0, 1]> : vector<9xi1>
  %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
      {permutation_map = affine_map<(d0, d1) -> (d0)>}
      : memref<?x?xf32>, vector<9xf32>
  vector.print %f: vector<9xf32>
  return
}

// Non-contiguous, out-of-bounds, strided load.
func.func @transfer_read_1d_out_of_bounds(
    %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
  %fm42 = arith.constant -42.0: f32
  %f = vector.transfer_read %A[%base1, %base2], %fm42
      {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [false]}
      : memref<?x?xf32>, vector<3xf32>
  vector.print %f: vector<3xf32>
  return
}

// Non-contiguous, strided, masked load. All accesses are in-bounds.
func.func @transfer_read_1d_mask_in_bounds(
    %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
  %fm42 = arith.constant -42.0: f32
  %mask = arith.constant dense<[1, 0, 1]> : vector<3xi1>
  %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
      {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
      : memref<?x?xf32>, vector<3xf32>
  vector.print %f: vector<3xf32>
  return
}

// Non-contiguous, strided store.
func.func @transfer_write_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
  %fn1 = arith.constant -1.0 : f32
  %vf0 = vector.splat %fn1 : vector<7xf32>
  vector.transfer_write %vf0, %A[%base1, %base2]
      {permutation_map = affine_map<(d0, d1) -> (d0)>}
      : vector<7xf32>, memref<?x?xf32>
  return
}

// Non-contiguous, strided, masked store.
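// Same pattern as @transfer_write_1d, but the mask disables lanes 1 and 3 of
// the 7-element vector, so those rows are not written.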
func.func @transfer_write_1d_mask(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
  %fn1 = arith.constant -2.0 : f32
  %vf0 = vector.splat %fn1 : vector<7xf32>
  %mask = arith.constant dense<[1, 0, 1, 0, 1, 1, 1]> : vector<7xi1>
  vector.transfer_write %vf0, %A[%base1, %base2], %mask
      {permutation_map = affine_map<(d0, d1) -> (d0)>}
      : vector<7xf32>, memref<?x?xf32>
  return
}

func.func @entry() {
  %c0 = arith.constant 0: index
  %c1 = arith.constant 1: index
  %c2 = arith.constant 2: index
  %c3 = arith.constant 3: index
  %c10 = arith.constant 10 : index
  %0 = memref.get_global @gv : memref<5x6xf32>
  %A = memref.cast %0 : memref<5x6xf32> to memref<?x?xf32>

  // 1. Read from 2D memref on first dimension. Cannot be lowered to an LLVM
  //    vector load. Instead, generates scalar loads.
  call @transfer_read_1d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
  // CHECK: ( 12, 22, 32, 42, -42, -42, -42, -42, -42 )

  // 2.a. Read 1D vector from 2D memref with non-unit stride on first dim.
  call @transfer_read_1d_unit_stride(%A) : (memref<?x?xf32>) -> ()
  // CHECK: ( 10, 11 )
  // CHECK: ( 13, 14 )
  // CHECK: ( 30, 31 )
  // CHECK: ( 33, 34 )

  // 2.b. Read 1D vector from 2D memref with non-unit stride on first dim.
  //      Strides are non-static.
  call @transfer_read_1d_non_static_unit_stride(%A) : (memref<?x?xf32>) -> ()
  // CHECK: ( 31, 32, 33, 34 )

  // 2.c. Read 1D vector from 2D memref with out-of-bounds transfer dim
  //      starting point.
  call @transfer_read_1d_out_of_bounds(%A, %c10, %c1)
      : (memref<?x?xf32>, index, index) -> ()
  // CHECK: ( -42, -42, -42 )

  // 3. Read 1D vector from 2D memref with non-unit stride on second dim.
  call @transfer_read_1d_non_unit_stride(%A) : (memref<?x?xf32>) -> ()
  // CHECK: ( 22, 24, -42 )

  // 4. Write to 2D memref on first dimension. Cannot be lowered to an LLVM
  //    vector store. Instead, generates scalar stores.
  call @transfer_write_1d(%A, %c3, %c2) : (memref<?x?xf32>, index, index) -> ()

  // 5. Same as 1; checks that the write in 4 took effect.
  call @transfer_read_1d(%A, %c0, %c2) : (memref<?x?xf32>, index, index) -> ()
  // CHECK: ( 2, 12, 22, -1, -1, -42, -42, -42, -42 )

  // 6. Read a scalar from a 2D memref and broadcast the value to a 1D vector.
  //    Generates a loop with vector.insertelement.
  call @transfer_read_1d_broadcast(%A, %c1, %c2)
      : (memref<?x?xf32>, index, index) -> ()
  // CHECK: ( 12, 12, 12, 12, 12, 12, 12, 12, 12 )

  // 7. Read from 2D memref on first dimension. Accesses are in-bounds, so no
  //    if-check is generated inside the generated loop.
  call @transfer_read_1d_in_bounds(%A, %c1, %c2)
      : (memref<?x?xf32>, index, index) -> ()
  // CHECK: ( 12, 22, -1 )

  // 8. Optional mask attribute is specified and, in addition, there may be
  //    out-of-bounds accesses.
  call @transfer_read_1d_mask(%A, %c1, %c2)
      : (memref<?x?xf32>, index, index) -> ()
  // CHECK: ( 12, -42, -1, -42, -42, -42, -42, -42, -42 )

  // 9. Same as 8, but accesses are in-bounds.
  call @transfer_read_1d_mask_in_bounds(%A, %c1, %c2)
      : (memref<?x?xf32>, index, index) -> ()
  // CHECK: ( 12, -42, -1 )

  // 10. Write to 2D memref on first dimension with a mask.
  call @transfer_write_1d_mask(%A, %c1, %c0)
      : (memref<?x?xf32>, index, index) -> ()

  // 11. Same as 1; checks that the masked write in 10 took effect.
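  //     With the mask from 10, only rows 1 and 3 of column 0 were overwritten
  //     with -2; the masked-out lanes (rows 2 and 4) and the out-of-bounds
  //     lanes (rows 5-7) were left untouched.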
  call @transfer_read_1d(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
  // CHECK: ( 0, -2, 20, -2, 40, -42, -42, -42, -42 )

  return
}