// RUN: mlir-opt -normalize-memrefs -allow-unregistered-dialect %s | FileCheck %s

// This file tests whether memref types with non-trivial layout maps are
// normalized to trivial (identity) layouts.

// CHECK-DAG: #[[$REDUCE_MAP1:.*]] = affine_map<(d0, d1) -> ((d0 mod 2) * 2 + d1 mod 2 + (d0 floordiv 2) * 4 + (d1 floordiv 2) * 8)>
// CHECK-DAG: #[[$REDUCE_MAP2:.*]] = affine_map<(d0, d1) -> (d0 mod 2 + (d1 mod 2) * 2 + (d0 floordiv 2) * 8 + (d1 floordiv 2) * 4)>
// CHECK-DAG: #[[$REDUCE_MAP3:.*]] = affine_map<(d0, d1) -> (d0 * 4 + d1)>

// CHECK-LABEL: func @permute()
func.func @permute() {
  %A = memref.alloc() : memref<64x256xf32, affine_map<(d0, d1) -> (d1, d0)>>
  affine.for %i = 0 to 64 {
    affine.for %j = 0 to 256 {
      %1 = affine.load %A[%i, %j] : memref<64x256xf32, affine_map<(d0, d1) -> (d1, d0)>>
      "prevent.dce"(%1) : (f32) -> ()
    }
  }
  memref.dealloc %A : memref<64x256xf32, affine_map<(d0, d1) -> (d1, d0)>>
  return
}
// The old memref alloc should disappear.
// CHECK-NOT: memref<64x256xf32>
// CHECK: [[MEM:%[0-9a-zA-Z_]+]] = memref.alloc() : memref<256x64xf32>
// CHECK-NEXT: affine.for %[[I:arg[0-9a-zA-Z_]+]] = 0 to 64 {
// CHECK-NEXT: affine.for %[[J:arg[0-9a-zA-Z_]+]] = 0 to 256 {
// CHECK-NEXT: affine.load [[MEM]][%[[J]], %[[I]]] : memref<256x64xf32>
// CHECK-NEXT: "prevent.dce"
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: memref.dealloc [[MEM]]
// CHECK-NEXT: return

// CHECK-LABEL: func @alloca
func.func @alloca(%idx : index) {
  // CHECK-NEXT: memref.alloca() : memref<65xf32>
  %A = memref.alloca() : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
  // CHECK-NEXT: affine.load %{{.*}}[symbol(%arg0) + 1] : memref<65xf32>
  affine.load %A[%idx] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
  affine.for %i = 0 to 64 {
    %1 = affine.load %A[%i] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
    "prevent.dce"(%1) : (f32) -> ()
    // CHECK: %{{.*}} = affine.load %{{.*}}[%arg{{.*}} + 1] : memref<65xf32>
  }
  return
}

// CHECK-LABEL: func @shift
func.func @shift(%idx : index) {
  // CHECK-NEXT: memref.alloc() : memref<65xf32>
  %A = memref.alloc() : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
  // CHECK-NEXT: affine.load %{{.*}}[symbol(%arg0) + 1] : memref<65xf32>
  affine.load %A[%idx] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
  affine.for %i = 0 to 64 {
    %1 = affine.load %A[%i] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
    "prevent.dce"(%1) : (f32) -> ()
    // CHECK: %{{.*}} = affine.load %{{.*}}[%arg{{.*}} + 1] : memref<65xf32>
  }
  return
}

// CHECK-LABEL: func @high_dim_permute()
func.func @high_dim_permute() {
  // CHECK-NOT: memref<64x128x256xf32,
  %A = memref.alloc() : memref<64x128x256xf32, affine_map<(d0, d1, d2) -> (d2, d0, d1)>>
  // CHECK: %[[I:arg[0-9a-zA-Z_]+]]
  affine.for %i = 0 to 64 {
    // CHECK: %[[J:arg[0-9a-zA-Z_]+]]
    affine.for %j = 0 to 128 {
      // CHECK: %[[K:arg[0-9a-zA-Z_]+]]
      affine.for %k = 0 to 256 {
        %1 = affine.load %A[%i, %j, %k] : memref<64x128x256xf32, affine_map<(d0, d1, d2) -> (d2, d0, d1)>>
        // CHECK: %{{.*}} = affine.load %{{.*}}[%[[K]], %[[I]], %[[J]]] : memref<256x64x128xf32>
        "prevent.dce"(%1) : (f32) -> ()
      }
    }
  }
  return
}

// CHECK-LABEL: func @invalid_map
func.func @invalid_map() {
  %A = memref.alloc() : memref<64x128xf32, affine_map<(d0, d1) -> (d0, -d1 - 10)>>
  // CHECK: %{{.*}} = memref.alloc() : memref<64x128xf32,
  return
}

// A tiled layout.
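// In essence, normalization materializes the layout map into the shape: the
// normalized sizes are the extents of the map's result expressions over the
// original index space. Here, (d0 floordiv 8, d1 floordiv 16, d0 mod 8,
// d1 mod 16) over a 64x512 memref has extents 8, 32, 8, and 16, giving
// memref<8x32x8x16xf32> as verified below.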
// CHECK-LABEL: func @data_tiling
func.func @data_tiling(%idx : index) {
  // CHECK: memref.alloc() : memref<8x32x8x16xf32>
  %A = memref.alloc() : memref<64x512xf32, affine_map<(d0, d1) -> (d0 floordiv 8, d1 floordiv 16, d0 mod 8, d1 mod 16)>>
  // CHECK: affine.load %{{.*}}[symbol(%arg0) floordiv 8, symbol(%arg0) floordiv 16, symbol(%arg0) mod 8, symbol(%arg0) mod 16]
  %1 = affine.load %A[%idx, %idx] : memref<64x512xf32, affine_map<(d0, d1) -> (d0 floordiv 8, d1 floordiv 16, d0 mod 8, d1 mod 16)>>
  "prevent.dce"(%1) : (f32) -> ()
  return
}

// Strides 2 and 4 along the respective dimensions.
// CHECK-LABEL: func @strided
func.func @strided() {
  %A = memref.alloc() : memref<64x128xf32, affine_map<(d0, d1) -> (2*d0, 4*d1)>>
  // CHECK: affine.for %[[IV0:.*]] =
  affine.for %i = 0 to 64 {
    // CHECK: affine.for %[[IV1:.*]] =
    affine.for %j = 0 to 128 {
      // CHECK: affine.load %{{.*}}[%[[IV0]] * 2, %[[IV1]] * 4] : memref<127x509xf32>
      %1 = affine.load %A[%i, %j] : memref<64x128xf32, affine_map<(d0, d1) -> (2*d0, 4*d1)>>
      "prevent.dce"(%1) : (f32) -> ()
    }
  }
  return
}

// Strided, but the strides are in the linearized space.
// CHECK-LABEL: func @strided_cumulative
func.func @strided_cumulative() {
  %A = memref.alloc() : memref<2x5xf32, affine_map<(d0, d1) -> (3*d0 + 17*d1)>>
  // CHECK: affine.for %[[IV0:.*]] =
  affine.for %i = 0 to 2 {
    // CHECK: affine.for %[[IV1:.*]] =
    affine.for %j = 0 to 5 {
      // CHECK: affine.load %{{.*}}[%[[IV0]] * 3 + %[[IV1]] * 17] : memref<72xf32>
      %1 = affine.load %A[%i, %j] : memref<2x5xf32, affine_map<(d0, d1) -> (3*d0 + 17*d1)>>
      "prevent.dce"(%1) : (f32) -> ()
    }
  }
  return
}

// Symbolic operand for the alloc, although unused. Tests
// replaceAllMemRefUsesWith when the index remap has symbols.
// CHECK-LABEL: func @symbolic_operands
func.func @symbolic_operands(%s : index) {
  // CHECK: memref.alloc() : memref<100xf32>
  %A = memref.alloc()[%s] : memref<10x10xf32, affine_map<(d0, d1)[s0] -> (10*d0 + d1)>>
  affine.for %i = 0 to 10 {
    affine.for %j = 0 to 10 {
      // CHECK: affine.load %{{.*}}[%{{.*}} * 10 + %{{.*}}] : memref<100xf32>
      %1 = affine.load %A[%i, %j] : memref<10x10xf32, affine_map<(d0, d1)[s0] -> (10*d0 + d1)>>
      "prevent.dce"(%1) : (f32) -> ()
    }
  }
  return
}

// Semi-affine layout maps: normalization is not implemented for these yet.
// CHECK-LABEL: func @semi_affine_layout_map
func.func @semi_affine_layout_map(%s0: index, %s1: index) {
  %A = memref.alloc()[%s0, %s1] : memref<256x1024xf32, affine_map<(d0, d1)[s0, s1] -> (d0*s0 + d1*s1)>>
  affine.for %i = 0 to 256 {
    affine.for %j = 0 to 1024 {
      // CHECK: memref<256x1024xf32, #map{{[0-9a-zA-Z_]+}}>
      affine.load %A[%i, %j] : memref<256x1024xf32, affine_map<(d0, d1)[s0, s1] -> (d0*s0 + d1*s1)>>
    }
  }
  return
}

// CHECK-LABEL: func @alignment
func.func @alignment() {
  %A = memref.alloc() {alignment = 32 : i64} : memref<64x128x256xf32, affine_map<(d0, d1, d2) -> (d2, d0, d1)>>
  // CHECK-NEXT: memref.alloc() {alignment = 32 : i64} : memref<256x64x128xf32>
  return
}

#tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
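
// Under #tile, a memref<16xf64, #tile> normalizes to memref<4x4xf64>
// (i floordiv 4 and i mod 4 each have extent 4), and an access at index 10
// is remapped to [10 floordiv 4, 10 mod 4] = [2, 2], as the FileCheck lines
// in test case 1 below verify.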

// The following test cases check inter-procedural memref normalization.

// Test case 1: Check normalization for multiple memrefs in a function
// argument list.
// CHECK-LABEL: func @multiple_argument_type
// CHECK-SAME: (%[[A:arg[0-9a-zA-Z_]+]]: memref<4x4xf64>, %[[B:arg[0-9a-zA-Z_]+]]: f64, %[[C:arg[0-9a-zA-Z_]+]]: memref<2x4xf64>, %[[D:arg[0-9a-zA-Z_]+]]: memref<24xf64>) -> f64
func.func @multiple_argument_type(%A: memref<16xf64, #tile>, %B: f64, %C: memref<8xf64, #tile>, %D: memref<24xf64>) -> f64 {
  %a = affine.load %A[0] : memref<16xf64, #tile>
  %p = arith.mulf %a, %a : f64
  affine.store %p, %A[10] : memref<16xf64, #tile>
  call @single_argument_type(%C) : (memref<8xf64, #tile>) -> ()
  return %B : f64
}

// CHECK: %[[a:[0-9a-zA-Z_]+]] = affine.load %[[A]][0, 0] : memref<4x4xf64>
// CHECK: %[[p:[0-9a-zA-Z_]+]] = arith.mulf %[[a]], %[[a]] : f64
// CHECK: affine.store %[[p]], %[[A]][2, 2] : memref<4x4xf64>
// CHECK: call @single_argument_type(%[[C]]) : (memref<2x4xf64>) -> ()
// CHECK: return %[[B]] : f64

// Test case 2: Check normalization for a single memref argument in a function.
// CHECK-LABEL: func @single_argument_type
// CHECK-SAME: (%[[C:arg[0-9a-zA-Z_]+]]: memref<2x4xf64>)
func.func @single_argument_type(%C : memref<8xf64, #tile>) {
  %a = memref.alloc() : memref<8xf64, #tile>
  %b = memref.alloc() : memref<16xf64, #tile>
  %d = arith.constant 23.0 : f64
  %e = memref.alloc() : memref<24xf64>
  call @single_argument_type(%a) : (memref<8xf64, #tile>) -> ()
  call @single_argument_type(%C) : (memref<8xf64, #tile>) -> ()
  call @multiple_argument_type(%b, %d, %a, %e) : (memref<16xf64, #tile>, f64, memref<8xf64, #tile>, memref<24xf64>) -> f64
  return
}

// CHECK: %[[a:[0-9a-zA-Z_]+]] = memref.alloc() : memref<2x4xf64>
// CHECK: %[[b:[0-9a-zA-Z_]+]] = memref.alloc() : memref<4x4xf64>
// CHECK: %cst = arith.constant 2.300000e+01 : f64
// CHECK: %[[e:[0-9a-zA-Z_]+]] = memref.alloc() : memref<24xf64>
// CHECK: call @single_argument_type(%[[a]]) : (memref<2x4xf64>) -> ()
// CHECK: call @single_argument_type(%[[C]]) : (memref<2x4xf64>) -> ()
// CHECK: call @multiple_argument_type(%[[b]], %cst, %[[a]], %[[e]]) : (memref<4x4xf64>, f64, memref<2x4xf64>, memref<24xf64>) -> f64

// Test case 3: Check a function returning a non-memref type.
// CHECK-LABEL: func @non_memref_ret
// CHECK-SAME: (%[[C:arg[0-9a-zA-Z_]+]]: memref<2x4xf64>) -> i1
func.func @non_memref_ret(%A: memref<8xf64, #tile>) -> i1 {
  %d = arith.constant 1 : i1
  return %d : i1
}

// The test cases from here onwards deal with normalization of memrefs in
// function signatures and at call sites.
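// Signatures and call sites have to be rewritten in lockstep: an operand's
// type at a call must match the callee's argument type, so a memref is only
// normalized when every function and call that touches it can be updated too
// (test case set #6 below shows a chain staying un-normalized because of one
// unsupported use).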

// Test case 4: Check successful memref normalization in the case of
// inter/intra-recursive calls.
// CHECK-LABEL: func @ret_multiple_argument_type
// CHECK-SAME: (%[[A:arg[0-9a-zA-Z_]+]]: memref<4x4xf64>, %[[B:arg[0-9a-zA-Z_]+]]: f64, %[[C:arg[0-9a-zA-Z_]+]]: memref<2x4xf64>) -> (memref<2x4xf64>, f64)
func.func @ret_multiple_argument_type(%A: memref<16xf64, #tile>, %B: f64, %C: memref<8xf64, #tile>) -> (memref<8xf64, #tile>, f64) {
  %a = affine.load %A[0] : memref<16xf64, #tile>
  %p = arith.mulf %a, %a : f64
  %cond = arith.constant 1 : i1
  cf.cond_br %cond, ^bb1, ^bb2
^bb1:
  %res1, %res2 = call @ret_single_argument_type(%C) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
  return %res2, %p : memref<8xf64, #tile>, f64
^bb2:
  return %C, %p : memref<8xf64, #tile>, f64
}

// CHECK: %[[a:[0-9a-zA-Z_]+]] = affine.load %[[A]][0, 0] : memref<4x4xf64>
// CHECK: %[[p:[0-9a-zA-Z_]+]] = arith.mulf %[[a]], %[[a]] : f64
// CHECK: %true = arith.constant true
// CHECK: cf.cond_br %true, ^bb1, ^bb2
// CHECK: ^bb1: // pred: ^bb0
// CHECK: %[[res:[0-9a-zA-Z_]+]]:2 = call @ret_single_argument_type(%[[C]]) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
// CHECK: return %[[res]]#1, %[[p]] : memref<2x4xf64>, f64
// CHECK: ^bb2: // pred: ^bb0
// CHECK: return %{{.*}}, %{{.*}} : memref<2x4xf64>, f64

// CHECK-LABEL: func @ret_single_argument_type
// CHECK-SAME: (%[[C:arg[0-9a-zA-Z_]+]]: memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
func.func @ret_single_argument_type(%C: memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>) {
  %a = memref.alloc() : memref<8xf64, #tile>
  %b = memref.alloc() : memref<16xf64, #tile>
  %d = arith.constant 23.0 : f64
  call @ret_single_argument_type(%a) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
  call @ret_single_argument_type(%C) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
  %res1, %res2 = call @ret_multiple_argument_type(%b, %d, %a) : (memref<16xf64, #tile>, f64, memref<8xf64, #tile>) -> (memref<8xf64, #tile>, f64)
  %res3, %res4 = call @ret_single_argument_type(%res1) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
  return %b, %a : memref<16xf64, #tile>, memref<8xf64, #tile>
}

// CHECK: %[[a:[0-9a-zA-Z_]+]] = memref.alloc() : memref<2x4xf64>
// CHECK: %[[b:[0-9a-zA-Z_]+]] = memref.alloc() : memref<4x4xf64>
// CHECK: %cst = arith.constant 2.300000e+01 : f64
// CHECK: %[[resA:[0-9a-zA-Z_]+]]:2 = call @ret_single_argument_type(%[[a]]) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
// CHECK: %[[resB:[0-9a-zA-Z_]+]]:2 = call @ret_single_argument_type(%[[C]]) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
// CHECK: %[[resC:[0-9a-zA-Z_]+]]:2 = call @ret_multiple_argument_type(%[[b]], %cst, %[[a]]) : (memref<4x4xf64>, f64, memref<2x4xf64>) -> (memref<2x4xf64>, f64)
// CHECK: %[[resD:[0-9a-zA-Z_]+]]:2 = call @ret_single_argument_type(%[[resC]]#0) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
// CHECK: return %{{.*}}, %{{.*}} : memref<4x4xf64>, memref<2x4xf64>
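
// Result types are rewritten along with argument types, and each caller's
// uses of the returned memrefs are remapped accordingly (see the resC/resD
// captures above).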

// Test case set #5: Check normalization across a chain of interconnected
// functions.
// CHECK-LABEL: func @func_A
// CHECK-SAME: (%[[A:arg[0-9a-zA-Z_]+]]: memref<2x4xf64>)
func.func @func_A(%A: memref<8xf64, #tile>) {
  call @func_B(%A) : (memref<8xf64, #tile>) -> ()
  return
}
// CHECK: call @func_B(%[[A]]) : (memref<2x4xf64>) -> ()

// CHECK-LABEL: func @func_B
// CHECK-SAME: (%[[A:arg[0-9a-zA-Z_]+]]: memref<2x4xf64>)
func.func @func_B(%A: memref<8xf64, #tile>) {
  call @func_C(%A) : (memref<8xf64, #tile>) -> ()
  return
}
// CHECK: call @func_C(%[[A]]) : (memref<2x4xf64>) -> ()

// CHECK-LABEL: func @func_C
// CHECK-SAME: (%[[A:arg[0-9a-zA-Z_]+]]: memref<2x4xf64>)
func.func @func_C(%A: memref<8xf64, #tile>) {
  return
}

// Test case set #6: Check that no normalization takes place in the chain
// A -> B -> C when B has an unsupported use of the memref.
// CHECK-LABEL: func @some_func_A
// CHECK-SAME: (%[[A:arg[0-9a-zA-Z_]+]]: memref<8xf64, #map{{[0-9a-zA-Z_]+}}>)
func.func @some_func_A(%A: memref<8xf64, #tile>) {
  call @some_func_B(%A) : (memref<8xf64, #tile>) -> ()
  return
}
// CHECK: call @some_func_B(%[[A]]) : (memref<8xf64, #map{{[0-9a-zA-Z_]+}}>) -> ()

// CHECK-LABEL: func @some_func_B
// CHECK-SAME: (%[[A:arg[0-9a-zA-Z_]+]]: memref<8xf64, #map{{[0-9a-zA-Z_]+}}>)
func.func @some_func_B(%A: memref<8xf64, #tile>) {
  "test.test"(%A) : (memref<8xf64, #tile>) -> ()
  call @some_func_C(%A) : (memref<8xf64, #tile>) -> ()
  return
}
// CHECK: call @some_func_C(%[[A]]) : (memref<8xf64, #map{{[0-9a-zA-Z_]+}}>) -> ()

// CHECK-LABEL: func @some_func_C
// CHECK-SAME: (%[[A:arg[0-9a-zA-Z_]+]]: memref<8xf64, #map{{[0-9a-zA-Z_]+}}>)
func.func @some_func_C(%A: memref<8xf64, #tile>) {
  return
}
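
// The unregistered "test.test" op above is an opaque, non-dereferencing use
// of %A, so the memref in @some_func_B cannot be normalized; and because
// signatures must stay consistent across the call chain, @some_func_A and
// @some_func_C retain the #tile layout as well.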

// Test case set #7: Check normalization in the presence of external
// (declaration-only) functions.
// CHECK-LABEL: func private @external_func_A
// CHECK-SAME: (memref<4x4xf64>)
func.func private @external_func_A(memref<16xf64, #tile>) -> ()

// CHECK-LABEL: func private @external_func_B
// CHECK-SAME: (memref<4x4xf64>, f64) -> memref<2x4xf64>
func.func private @external_func_B(memref<16xf64, #tile>, f64) -> (memref<8xf64, #tile>)

// CHECK-LABEL: func @simply_call_external()
func.func @simply_call_external() {
  %a = memref.alloc() : memref<16xf64, #tile>
  call @external_func_A(%a) : (memref<16xf64, #tile>) -> ()
  return
}
// CHECK: %[[a:[0-9a-zA-Z_]+]] = memref.alloc() : memref<4x4xf64>
// CHECK: call @external_func_A(%[[a]]) : (memref<4x4xf64>) -> ()

// CHECK-LABEL: func @use_value_of_external
// CHECK-SAME: (%[[A:arg[0-9a-zA-Z_]+]]: memref<4x4xf64>, %[[B:arg[0-9a-zA-Z_]+]]: f64) -> memref<2x4xf64>
func.func @use_value_of_external(%A: memref<16xf64, #tile>, %B: f64) -> (memref<8xf64, #tile>) {
  %res = call @external_func_B(%A, %B) : (memref<16xf64, #tile>, f64) -> (memref<8xf64, #tile>)
  return %res : memref<8xf64, #tile>
}
// CHECK: %[[res:[0-9a-zA-Z_]+]] = call @external_func_B(%[[A]], %[[B]]) : (memref<4x4xf64>, f64) -> memref<2x4xf64>
// CHECK: return %{{.*}} : memref<2x4xf64>

// CHECK-LABEL: func @affine_parallel_norm
func.func @affine_parallel_norm() -> memref<8xf32, #tile> {
  %c = arith.constant 23.0 : f32
  %a = memref.alloc() : memref<8xf32, #tile>
  // CHECK: affine.parallel (%{{.*}}) = (0) to (8) reduce ("assign") -> (memref<2x4xf32>)
  %1 = affine.parallel (%i) = (0) to (8) reduce ("assign") -> memref<8xf32, #tile> {
    affine.store %c, %a[%i] : memref<8xf32, #tile>
    // CHECK: affine.yield %{{.*}} : memref<2x4xf32>
    affine.yield %a : memref<8xf32, #tile>
  }
  return %1 : memref<8xf32, #tile>
}

#map = affine_map<(d0, d1)[s0] -> (d0 * 3 + s0 + d1)>
// CHECK-LABEL: func.func @map_symbol
func.func @map_symbol() -> memref<2x3xf32, #map> {
  %c1 = arith.constant 1 : index
  // The constant isn't propagated into the layout map here, so the utility
  // can't compute a constant upper bound for the memref dimensions.
  // CHECK: memref.alloc()[%{{.*}}]
  %0 = memref.alloc()[%c1] : memref<2x3xf32, #map>
  return %0 : memref<2x3xf32, #map>
}

#neg = affine_map<(d0, d1) -> (d0, d1 - 100)>
// CHECK-LABEL: func.func @neg_map
func.func @neg_map() -> memref<2x3xf32, #neg> {
  // This map can produce negative indices, so it isn't a valid map for
  // normalization.
  // CHECK: memref.alloc() : memref<2x3xf32, #{{.*}}>
  %0 = memref.alloc() : memref<2x3xf32, #neg>
  return %0 : memref<2x3xf32, #neg>
}
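
// The next test uses an explicit strided<> layout with dynamic strides and
// offset. Such layouts are left untouched by this pass; the test checks that
// the memref.cast and its strided type survive normalization intact.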

// CHECK-LABEL: func @memref_with_strided_offset
func.func @memref_with_strided_offset(%arg0: tensor<128x512xf32>, %arg1: index, %arg2: index) -> tensor<16x512xf32> {
  %c0 = arith.constant 0 : index
  %0 = bufferization.to_memref %arg0 : tensor<128x512xf32> to memref<128x512xf32, strided<[?, ?], offset: ?>>
  %subview = memref.subview %0[%arg2, 0] [%arg1, 512] [1, 1] : memref<128x512xf32, strided<[?, ?], offset: ?>> to memref<?x512xf32, strided<[?, ?], offset: ?>>
  // CHECK: %{{.*}} = memref.cast %{{.*}} : memref<?x512xf32, strided<[?, ?], offset: ?>> to memref<16x512xf32, strided<[?, ?], offset: ?>>
  %cast = memref.cast %subview : memref<?x512xf32, strided<[?, ?], offset: ?>> to memref<16x512xf32, strided<[?, ?], offset: ?>>
  %1 = bufferization.to_tensor %cast : memref<16x512xf32, strided<[?, ?], offset: ?>> to tensor<16x512xf32>
  return %1 : tensor<16x512xf32>
}

#map0 = affine_map<(i, k) -> (2 * (i mod 2) + (k mod 2) + 4 * (i floordiv 2) + 8 * (k floordiv 2))>
#map1 = affine_map<(k, j) -> ((k mod 2) + 2 * (j mod 2) + 8 * (k floordiv 2) + 4 * (j floordiv 2))>
#map2 = affine_map<(i, j) -> (4 * i + j)>
// CHECK-LABEL: func @memref_load_with_reduction_map
func.func @memref_load_with_reduction_map(%arg0 : memref<4x4xf32, #map2>) -> () {
  %0 = memref.alloc() : memref<4x8xf32, #map0>
  %1 = memref.alloc() : memref<8x4xf32, #map1>
  %2 = memref.alloc() : memref<4x4xf32, #map2>
  // CHECK-NOT: memref<4x8xf32>
  // CHECK-NOT: memref<8x4xf32>
  // CHECK-NOT: memref<4x4xf32>
  %cst = arith.constant 3.0 : f32
  %cst0 = arith.constant 0 : index
  affine.for %i = 0 to 4 {
    affine.for %j = 0 to 8 {
      affine.for %k = 0 to 8 {
        // CHECK: %[[INDEX0:.*]] = affine.apply #[[$REDUCE_MAP1]](%{{.*}}, %{{.*}})
        // CHECK: memref.load %alloc[%[[INDEX0]]] : memref<32xf32>
        %a = memref.load %0[%i, %k] : memref<4x8xf32, #map0>
        // CHECK: %[[INDEX1:.*]] = affine.apply #[[$REDUCE_MAP2]](%{{.*}}, %{{.*}})
        // CHECK: memref.load %alloc_0[%[[INDEX1]]] : memref<32xf32>
        %b = memref.load %1[%k, %j] : memref<8x4xf32, #map1>
        // CHECK: %[[INDEX2:.*]] = affine.apply #[[$REDUCE_MAP3]](%{{.*}}, %{{.*}})
        // CHECK: memref.load %alloc_1[%[[INDEX2]]] : memref<16xf32>
        %c = memref.load %2[%i, %j] : memref<4x4xf32, #map2>
        %3 = arith.mulf %a, %b : f32
        %4 = arith.addf %3, %c : f32
        affine.store %4, %arg0[%i, %j] : memref<4x4xf32, #map2>
      }
    }
  }
  return
}
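
// The linearized sizes above follow from the maps' maximum values: over the
// 4x8 space, #map0 peaks at 2*1 + 1 + 4*1 + 8*3 = 31, and over 8x4, #map1
// peaks at 1 + 2*1 + 8*3 + 4*1 = 31, hence memref<32xf32> for both; #map2
// peaks at 4*3 + 3 = 15, hence memref<16xf32>.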