// RUN: mlir-opt %s -allow-unregistered-dialect -affine-parallelize | FileCheck %s
// RUN: mlir-opt %s -allow-unregistered-dialect -affine-parallelize='max-nested=1' | FileCheck --check-prefix=MAX-NESTED %s
// RUN: mlir-opt %s -allow-unregistered-dialect -affine-parallelize='parallel-reductions=1' | FileCheck --check-prefix=REDUCE %s

// CHECK-LABEL: func @reduce_window_max() {
func.func @reduce_window_max() {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = memref.alloc() : memref<1x8x8x64xf32>
  %1 = memref.alloc() : memref<1x18x18x64xf32>
  affine.for %arg0 = 0 to 1 {
    affine.for %arg1 = 0 to 8 {
      affine.for %arg2 = 0 to 8 {
        affine.for %arg3 = 0 to 64 {
          affine.store %cst, %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32>
        }
      }
    }
  }
  affine.for %arg0 = 0 to 1 {
    affine.for %arg1 = 0 to 8 {
      affine.for %arg2 = 0 to 8 {
        affine.for %arg3 = 0 to 64 {
          affine.for %arg4 = 0 to 1 {
            affine.for %arg5 = 0 to 3 {
              affine.for %arg6 = 0 to 3 {
                affine.for %arg7 = 0 to 1 {
                  %2 = affine.load %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32>
                  %3 = affine.load %1[%arg0 + %arg4, %arg1 * 2 + %arg5, %arg2 * 2 + %arg6, %arg3 + %arg7] : memref<1x18x18x64xf32>
                  %4 = arith.cmpf ogt, %2, %3 : f32
                  %5 = arith.select %4, %2, %3 : f32
                  affine.store %5, %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32>
                }
              }
            }
          }
        }
      }
    }
  }
  return
}

// CHECK: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[v0:.*]] = memref.alloc() : memref<1x8x8x64xf32>
// CHECK: %[[v1:.*]] = memref.alloc() : memref<1x18x18x64xf32>
// CHECK: affine.parallel (%[[arg0:.*]]) = (0) to (1) {
// CHECK:   affine.parallel (%[[arg1:.*]]) = (0) to (8) {
// CHECK:     affine.parallel (%[[arg2:.*]]) = (0) to (8) {
// CHECK:       affine.parallel (%[[arg3:.*]]) = (0) to (64) {
// CHECK:         affine.store %[[cst]], %[[v0]][%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]] : memref<1x8x8x64xf32>
// CHECK:       }
// CHECK:     }
// CHECK:   }
// CHECK: }
// CHECK: affine.parallel (%[[a0:.*]]) = (0) to (1) {
// CHECK:   affine.parallel (%[[a1:.*]]) = (0) to (8) {
// CHECK:     affine.parallel (%[[a2:.*]]) = (0) to (8) {
// CHECK:       affine.parallel (%[[a3:.*]]) = (0) to (64) {
// CHECK:         affine.parallel (%[[a4:.*]]) = (0) to (1) {
// CHECK:           affine.for %[[a5:.*]] = 0 to 3 {
// CHECK:             affine.for %[[a6:.*]] = 0 to 3 {
// CHECK:               affine.parallel (%[[a7:.*]]) = (0) to (1) {
// CHECK:                 %[[lhs:.*]] = affine.load %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] : memref<1x8x8x64xf32>
// CHECK:                 %[[rhs:.*]] = affine.load %[[v1]][%[[a0]] + %[[a4]], %[[a1]] * 2 + %[[a5]], %[[a2]] * 2 + %[[a6]], %[[a3]] + %[[a7]]] : memref<1x18x18x64xf32>
// CHECK:                 %[[res:.*]] = arith.cmpf ogt, %[[lhs]], %[[rhs]] : f32
// CHECK:                 %[[sel:.*]] = arith.select %[[res]], %[[lhs]], %[[rhs]] : f32
// CHECK:                 affine.store %[[sel]], %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] : memref<1x8x8x64xf32>
// CHECK:               }
// CHECK:             }
// CHECK:           }
// CHECK:         }
// CHECK:       }
// CHECK:     }
// CHECK:   }
// CHECK: }
// CHECK: }
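
// Only the outer two loops of this matmul-like nest are parallel: every
// iteration of the innermost loop stores to the same element of %2, so the
// k-loop carries an output dependence and must stay a sequential affine.for.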
func.func @loop_nest_3d_outer_two_parallel(%N : index) {
  %0 = memref.alloc() : memref<1024 x 1024 x vector<64xf32>>
  %1 = memref.alloc() : memref<1024 x 1024 x vector<64xf32>>
  %2 = memref.alloc() : memref<1024 x 1024 x vector<64xf32>>
  affine.for %i = 0 to %N {
    affine.for %j = 0 to %N {
      %7 = affine.load %2[%i, %j] : memref<1024x1024xvector<64xf32>>
      affine.for %k = 0 to %N {
        %5 = affine.load %0[%i, %k] : memref<1024x1024xvector<64xf32>>
        %6 = affine.load %1[%k, %j] : memref<1024x1024xvector<64xf32>>
        %8 = arith.mulf %5, %6 : vector<64xf32>
        %9 = arith.addf %7, %8 : vector<64xf32>
        affine.store %9, %2[%i, %j] : memref<1024x1024xvector<64xf32>>
      }
    }
  }
  return
}

// CHECK:      affine.parallel (%[[arg1:.*]]) = (0) to (symbol(%arg0)) {
// CHECK-NEXT:   affine.parallel (%[[arg2:.*]]) = (0) to (symbol(%arg0)) {
// CHECK:          affine.for %[[arg3:.*]] = 0 to %arg0 {

// CHECK-LABEL: unknown_op_conservative
func.func @unknown_op_conservative() {
  affine.for %i = 0 to 10 {
// CHECK:  affine.for %[[arg1:.*]] = 0 to 10 {
    "unknown"() : () -> ()
  }
  return
}

// CHECK-LABEL: non_affine_load
func.func @non_affine_load() {
  %0 = memref.alloc() : memref<100 x f32>
  affine.for %i = 0 to 100 {
// CHECK:  affine.for %{{.*}} = 0 to 100 {
    memref.load %0[%i] : memref<100 x f32>
  }
  return
}

// CHECK-LABEL: for_with_minmax
func.func @for_with_minmax(%m: memref<?xf32>, %lb0: index, %lb1: index,
                           %ub0: index, %ub1: index) {
  // CHECK: affine.parallel (%{{.*}}) = (max(%{{.*}}, %{{.*}})) to (min(%{{.*}}, %{{.*}}))
  affine.for %i = max affine_map<(d0, d1) -> (d0, d1)>(%lb0, %lb1)
          to min affine_map<(d0, d1) -> (d0, d1)>(%ub0, %ub1) {
    affine.load %m[%i] : memref<?xf32>
  }
  return
}

// CHECK-LABEL: nested_for_with_minmax
func.func @nested_for_with_minmax(%m: memref<?xf32>, %lb0: index,
                                  %ub0: index, %ub1: index) {
  // CHECK: affine.parallel (%[[I:.*]]) =
  affine.for %j = 0 to 10 {
    // CHECK: affine.parallel (%{{.*}}) = (max(%{{.*}}, %[[I]])) to (min(%{{.*}}, %{{.*}}))
    affine.for %i = max affine_map<(d0, d1) -> (d0, d1)>(%lb0, %j)
            to min affine_map<(d0, d1) -> (d0, d1)>(%ub0, %ub1) {
      affine.load %m[%i] : memref<?xf32>
    }
  }
  return
}

// MAX-NESTED-LABEL: @max_nested
func.func @max_nested(%m: memref<?x?xf32>, %lb0: index, %lb1: index,
                      %ub0: index, %ub1: index) {
  // MAX-NESTED: affine.parallel
  affine.for %i = affine_map<(d0) -> (d0)>(%lb0) to affine_map<(d0) -> (d0)>(%ub0) {
    // MAX-NESTED: affine.for
    affine.for %j = affine_map<(d0) -> (d0)>(%lb1) to affine_map<(d0) -> (d0)>(%ub1) {
      affine.load %m[%i, %j] : memref<?x?xf32>
    }
  }
  return
}

// MAX-NESTED-LABEL: @max_nested_1
func.func @max_nested_1(%arg0: memref<4096x4096xf32>, %arg1: memref<4096x4096xf32>, %arg2: memref<4096x4096xf32>) {
  %0 = memref.alloc() : memref<4096x4096xf32>
  // MAX-NESTED: affine.parallel
  affine.for %arg3 = 0 to 4096 {
    // MAX-NESTED-NEXT: affine.for
    affine.for %arg4 = 0 to 4096 {
      // MAX-NESTED-NEXT: affine.for
      affine.for %arg5 = 0 to 4096 {
        %1 = affine.load %arg0[%arg3, %arg5] : memref<4096x4096xf32>
        %2 = affine.load %arg1[%arg5, %arg4] : memref<4096x4096xf32>
        %3 = affine.load %0[%arg3, %arg4] : memref<4096x4096xf32>
        %4 = arith.mulf %1, %2 : f32
        %5 = arith.addf %3, %4 : f32
        affine.store %5, %0[%arg3, %arg4] : memref<4096x4096xf32>
      }
    }
  }
  return
}
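
// The tests below exercise loop-carried reductions expressed via iter_args.
// By default such loops are left sequential; with parallel-reductions=1,
// loops performing a supported, simple reduction are rewritten into
// affine.parallel with a matching reduce clause.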
// CHECK-LABEL: @iter_args
// REDUCE-LABEL: @iter_args
func.func @iter_args(%in: memref<10xf32>) {
  // REDUCE: %[[init:.*]] = arith.constant
  %cst = arith.constant 0.000000e+00 : f32
  // CHECK-NOT: affine.parallel
  // REDUCE: %[[reduced:.*]] = affine.parallel (%{{.*}}) = (0) to (10) reduce ("addf")
  %final_red = affine.for %i = 0 to 10 iter_args(%red_iter = %cst) -> (f32) {
    // REDUCE: %[[red_value:.*]] = affine.load
    %ld = affine.load %in[%i] : memref<10xf32>
    // REDUCE-NOT: arith.addf
    %add = arith.addf %red_iter, %ld : f32
    // REDUCE: affine.yield %[[red_value]]
    affine.yield %add : f32
  }
  // REDUCE: arith.addf %[[init]], %[[reduced]]
  return
}

// CHECK-LABEL: @nested_iter_args
// REDUCE-LABEL: @nested_iter_args
func.func @nested_iter_args(%in: memref<20x10xf32>) {
  %cst = arith.constant 0.000000e+00 : f32
  // CHECK: affine.parallel
  affine.for %i = 0 to 20 {
    // CHECK-NOT: affine.parallel
    // REDUCE: affine.parallel
    // REDUCE: reduce ("addf")
    %final_red = affine.for %j = 0 to 10 iter_args(%red_iter = %cst) -> (f32) {
      %ld = affine.load %in[%i, %j] : memref<20x10xf32>
      %add = arith.addf %red_iter, %ld : f32
      affine.yield %add : f32
    }
  }
  return
}

// REDUCE-LABEL: @strange_butterfly
func.func @strange_butterfly() {
  %cst1 = arith.constant 0.0 : f32
  %cst2 = arith.constant 1.0 : f32
  // REDUCE-NOT: affine.parallel
  affine.for %i = 0 to 10 iter_args(%it1 = %cst1, %it2 = %cst2) -> (f32, f32) {
    %0 = arith.addf %it1, %it2 : f32
    affine.yield %0, %0 : f32, f32
  }
  return
}

// An iter arg is used more than once. This is not a simple reduction and
// should not be parallelized.
// REDUCE-LABEL: @repeated_use
func.func @repeated_use() {
  %cst1 = arith.constant 0.0 : f32
  // REDUCE-NOT: affine.parallel
  affine.for %i = 0 to 10 iter_args(%it1 = %cst1) -> (f32) {
    %0 = arith.addf %it1, %it1 : f32
    affine.yield %0 : f32
  }
  return
}

// An iter arg is used in the chain of operations defining the value being
// reduced; this is not a simple reduction and should not be parallelized.
// REDUCE-LABEL: @use_in_backward_slice
func.func @use_in_backward_slice() {
  %cst1 = arith.constant 0.0 : f32
  %cst2 = arith.constant 1.0 : f32
  // REDUCE-NOT: affine.parallel
  affine.for %i = 0 to 10 iter_args(%it1 = %cst1, %it2 = %cst2) -> (f32, f32) {
    %0 = "test.some_modification"(%it2) : (f32) -> f32
    %1 = arith.addf %it1, %0 : f32
    affine.yield %1, %1 : f32, f32
  }
  return
}

// REDUCE-LABEL: @nested_min_max
// CHECK-LABEL: @nested_min_max
// CHECK: (%{{.*}}, %[[LB0:.*]]: index, %[[UB0:.*]]: index, %[[UB1:.*]]: index)
func.func @nested_min_max(%m: memref<?xf32>, %lb0: index,
                          %ub0: index, %ub1: index) {
  // CHECK: affine.parallel (%[[J:.*]]) =
  affine.for %j = 0 to 10 {
    // CHECK: affine.parallel (%{{.*}}) = (max(%[[LB0]], %[[J]]))
    // CHECK: to (min(%[[UB0]], %[[UB1]]))
    affine.for %i = max affine_map<(d0, d1) -> (d0, d1)>(%lb0, %j)
            to min affine_map<(d0, d1) -> (d0, d1)>(%ub0, %ub1) {
      affine.load %m[%i] : memref<?xf32>
    }
  }
  return
}

// Test in the presence of locally allocated memrefs.
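// A memref allocated inside the loop body is private to each iteration, so
// stores to it cannot create cross-iteration dependences.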

// CHECK: func @local_alloc
func.func @local_alloc() {
  %cst = arith.constant 0.0 : f32
  affine.for %i = 0 to 100 {
    %m = memref.alloc() : memref<1xf32>
    %ma = memref.alloca() : memref<1xf32>
    affine.store %cst, %m[0] : memref<1xf32>
  }
  // CHECK: affine.parallel
  return
}

// CHECK: func @local_alloc_cast
func.func @local_alloc_cast() {
  %cst = arith.constant 0.0 : f32
  affine.for %i = 0 to 100 {
    %m = memref.alloc() : memref<128xf32>
    affine.for %j = 0 to 128 {
      affine.store %cst, %m[%j] : memref<128xf32>
    }
    affine.for %j = 0 to 128 {
      affine.store %cst, %m[0] : memref<128xf32>
    }
    %r = memref.reinterpret_cast %m to offset: [0], sizes: [8, 16],
         strides: [16, 1] : memref<128xf32> to memref<8x16xf32>
    affine.for %j = 0 to 8 {
      affine.store %cst, %r[%j, %j] : memref<8x16xf32>
    }
  }
  // CHECK: affine.parallel
  // CHECK:   affine.parallel
  // CHECK:   }
  // CHECK:   affine.for
  // CHECK:   }
  // CHECK:   affine.parallel
  // CHECK:   }
  // CHECK: }

  return
}

// CHECK-LABEL: @iter_arg_memrefs
func.func @iter_arg_memrefs(%in: memref<10xf32>) {
  %mi = memref.alloc() : memref<f32>
  // Loop-carried memrefs are treated as serializing the loop.
  // CHECK: affine.for
  %mo = affine.for %i = 0 to 10 iter_args(%m_arg = %mi) -> (memref<f32>) {
    affine.yield %m_arg : memref<f32>
  }
  return
}