//===- Utils.cpp ---- Utilities for affine dialect transformation ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous transformation utilities for the Affine
// dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Utils.h"

#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/LogicalResult.h"
#include <optional>

#define DEBUG_TYPE "affine-utils"

using namespace mlir;
using namespace affine;
using namespace presburger;

namespace {
/// Visit affine expressions recursively and build the sequence of operations
/// that correspond to it. Visitation functions return a Value of the
/// expression subtree they visited or `nullptr` on error.
class AffineApplyExpander
    : public AffineExprVisitor<AffineApplyExpander, Value> {
public:
  /// This internal class expects arguments to be non-null; checks must be
  /// performed at the call site.
  AffineApplyExpander(OpBuilder &builder, ValueRange dimValues,
                      ValueRange symbolValues, Location loc)
      : builder(builder), dimValues(dimValues), symbolValues(symbolValues),
        loc(loc) {}

  template <typename OpTy>
  Value buildBinaryExpr(AffineBinaryOpExpr expr,
                        arith::IntegerOverflowFlags overflowFlags =
                            arith::IntegerOverflowFlags::none) {
    auto lhs = visit(expr.getLHS());
    auto rhs = visit(expr.getRHS());
    if (!lhs || !rhs)
      return nullptr;
    auto op = builder.create<OpTy>(loc, lhs, rhs, overflowFlags);
    return op.getResult();
  }

  Value visitAddExpr(AffineBinaryOpExpr expr) {
    return buildBinaryExpr<arith::AddIOp>(expr);
  }

  Value visitMulExpr(AffineBinaryOpExpr expr) {
    return buildBinaryExpr<arith::MulIOp>(expr,
                                          arith::IntegerOverflowFlags::nsw);
  }

  /// Euclidean modulo operation: negative RHS is not allowed.
  /// Remainder of the euclidean integer division is always non-negative.
  ///
  /// Implemented as
  ///
  ///     a mod b =
  ///         let remainder = srem a, b;
  ///             negative = a < 0 in
  ///         select negative, remainder + b, remainder.
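  ///
  /// As a worked illustration of this scheme (added here for exposition; the
  /// numbers are not taken from any test): for a = -7, b = 3, `srem` yields
  /// -1; since -1 < 0, the select picks -1 + 3 = 2, the Euclidean remainder.
  /// For a = 7, b = 3, `srem` yields 1, which is non-negative and is returned
  /// unchanged.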
83 Value visitModExpr(AffineBinaryOpExpr expr) { 84 if (auto rhsConst = dyn_cast<AffineConstantExpr>(expr.getRHS())) { 85 if (rhsConst.getValue() <= 0) { 86 emitError(loc, "modulo by non-positive value is not supported"); 87 return nullptr; 88 } 89 } 90 91 auto lhs = visit(expr.getLHS()); 92 auto rhs = visit(expr.getRHS()); 93 assert(lhs && rhs && "unexpected affine expr lowering failure"); 94 95 Value remainder = builder.create<arith::RemSIOp>(loc, lhs, rhs); 96 Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0); 97 Value isRemainderNegative = builder.create<arith::CmpIOp>( 98 loc, arith::CmpIPredicate::slt, remainder, zeroCst); 99 Value correctedRemainder = 100 builder.create<arith::AddIOp>(loc, remainder, rhs); 101 Value result = builder.create<arith::SelectOp>( 102 loc, isRemainderNegative, correctedRemainder, remainder); 103 return result; 104 } 105 106 /// Floor division operation (rounds towards negative infinity). 107 /// 108 /// For positive divisors, it can be implemented without branching and with a 109 /// single division operation as 110 /// 111 /// a floordiv b = 112 /// let negative = a < 0 in 113 /// let absolute = negative ? -a - 1 : a in 114 /// let quotient = absolute / b in 115 /// negative ? -quotient - 1 : quotient 116 /// 117 /// Note: this lowering does not use arith.floordivsi because the lowering of 118 /// that to arith.divsi (see populateCeilFloorDivExpandOpsPatterns) generates 119 /// not one but two arith.divsi. That could be changed to one divsi, but one 120 /// way or another, going through arith.floordivsi will result in more complex 121 /// IR because arith.floordivsi is more general than affine floordiv in that 122 /// it supports negative RHS. 123 Value visitFloorDivExpr(AffineBinaryOpExpr expr) { 124 if (auto rhsConst = dyn_cast<AffineConstantExpr>(expr.getRHS())) { 125 if (rhsConst.getValue() <= 0) { 126 emitError(loc, "division by non-positive value is not supported"); 127 return nullptr; 128 } 129 } 130 auto lhs = visit(expr.getLHS()); 131 auto rhs = visit(expr.getRHS()); 132 assert(lhs && rhs && "unexpected affine expr lowering failure"); 133 134 Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0); 135 Value noneCst = builder.create<arith::ConstantIndexOp>(loc, -1); 136 Value negative = builder.create<arith::CmpIOp>( 137 loc, arith::CmpIPredicate::slt, lhs, zeroCst); 138 Value negatedDecremented = builder.create<arith::SubIOp>(loc, noneCst, lhs); 139 Value dividend = 140 builder.create<arith::SelectOp>(loc, negative, negatedDecremented, lhs); 141 Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs); 142 Value correctedQuotient = 143 builder.create<arith::SubIOp>(loc, noneCst, quotient); 144 Value result = builder.create<arith::SelectOp>(loc, negative, 145 correctedQuotient, quotient); 146 return result; 147 } 148 149 /// Ceiling division operation (rounds towards positive infinity). 150 /// 151 /// For positive divisors, it can be implemented without branching and with a 152 /// single division operation as 153 /// 154 /// a ceildiv b = 155 /// let negative = a <= 0 in 156 /// let absolute = negative ? -a : a - 1 in 157 /// let quotient = absolute / b in 158 /// negative ? -quotient : quotient + 1 159 /// 160 /// Note: not using arith.ceildivsi for the same reason as explained in the 161 /// visitFloorDivExpr comment. 
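  ///
  /// A worked example of the formula above (illustrative only, not taken from
  /// the source): for a = 7, b = 3, `negative` is false, so the dividend is
  /// a - 1 = 6, the quotient is 2, and the result is 2 + 1 = 3 = ceil(7/3).
  /// For a = -7, b = 3, `negative` is true, the dividend is -a = 7, the
  /// quotient is 2, and the result is -2 = ceil(-7/3).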
162 Value visitCeilDivExpr(AffineBinaryOpExpr expr) { 163 if (auto rhsConst = dyn_cast<AffineConstantExpr>(expr.getRHS())) { 164 if (rhsConst.getValue() <= 0) { 165 emitError(loc, "division by non-positive value is not supported"); 166 return nullptr; 167 } 168 } 169 auto lhs = visit(expr.getLHS()); 170 auto rhs = visit(expr.getRHS()); 171 assert(lhs && rhs && "unexpected affine expr lowering failure"); 172 173 Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0); 174 Value oneCst = builder.create<arith::ConstantIndexOp>(loc, 1); 175 Value nonPositive = builder.create<arith::CmpIOp>( 176 loc, arith::CmpIPredicate::sle, lhs, zeroCst); 177 Value negated = builder.create<arith::SubIOp>(loc, zeroCst, lhs); 178 Value decremented = builder.create<arith::SubIOp>(loc, lhs, oneCst); 179 Value dividend = 180 builder.create<arith::SelectOp>(loc, nonPositive, negated, decremented); 181 Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs); 182 Value negatedQuotient = 183 builder.create<arith::SubIOp>(loc, zeroCst, quotient); 184 Value incrementedQuotient = 185 builder.create<arith::AddIOp>(loc, quotient, oneCst); 186 Value result = builder.create<arith::SelectOp>( 187 loc, nonPositive, negatedQuotient, incrementedQuotient); 188 return result; 189 } 190 191 Value visitConstantExpr(AffineConstantExpr expr) { 192 auto op = builder.create<arith::ConstantIndexOp>(loc, expr.getValue()); 193 return op.getResult(); 194 } 195 196 Value visitDimExpr(AffineDimExpr expr) { 197 assert(expr.getPosition() < dimValues.size() && 198 "affine dim position out of range"); 199 return dimValues[expr.getPosition()]; 200 } 201 202 Value visitSymbolExpr(AffineSymbolExpr expr) { 203 assert(expr.getPosition() < symbolValues.size() && 204 "symbol dim position out of range"); 205 return symbolValues[expr.getPosition()]; 206 } 207 208 private: 209 OpBuilder &builder; 210 ValueRange dimValues; 211 ValueRange symbolValues; 212 213 Location loc; 214 }; 215 } // namespace 216 217 /// Create a sequence of operations that implement the `expr` applied to the 218 /// given dimension and symbol values. 219 mlir::Value mlir::affine::expandAffineExpr(OpBuilder &builder, Location loc, 220 AffineExpr expr, 221 ValueRange dimValues, 222 ValueRange symbolValues) { 223 return AffineApplyExpander(builder, dimValues, symbolValues, loc).visit(expr); 224 } 225 226 /// Create a sequence of operations that implement the `affineMap` applied to 227 /// the given `operands` (as it it were an AffineApplyOp). 228 std::optional<SmallVector<Value, 8>> 229 mlir::affine::expandAffineMap(OpBuilder &builder, Location loc, 230 AffineMap affineMap, ValueRange operands) { 231 auto numDims = affineMap.getNumDims(); 232 auto expanded = llvm::to_vector<8>( 233 llvm::map_range(affineMap.getResults(), 234 [numDims, &builder, loc, operands](AffineExpr expr) { 235 return expandAffineExpr(builder, loc, expr, 236 operands.take_front(numDims), 237 operands.drop_front(numDims)); 238 })); 239 if (llvm::all_of(expanded, [](Value v) { return v; })) 240 return expanded; 241 return std::nullopt; 242 } 243 244 /// Promotes the `then` or the `else` block of `ifOp` (depending on whether 245 /// `elseBlock` is false or true) into `ifOp`'s containing block, and discards 246 /// the rest of the op. 247 static void promoteIfBlock(AffineIfOp ifOp, bool elseBlock) { 248 if (elseBlock) 249 assert(ifOp.hasElse() && "else block expected"); 250 251 Block *destBlock = ifOp->getBlock(); 252 Block *srcBlock = elseBlock ? 
ifOp.getElseBlock() : ifOp.getThenBlock(); 253 destBlock->getOperations().splice( 254 Block::iterator(ifOp), srcBlock->getOperations(), srcBlock->begin(), 255 std::prev(srcBlock->end())); 256 ifOp.erase(); 257 } 258 259 /// Returns the outermost affine.for/parallel op that the `ifOp` is invariant 260 /// on. The `ifOp` could be hoisted and placed right before such an operation. 261 /// This method assumes that the ifOp has been canonicalized (to be correct and 262 /// effective). 263 static Operation *getOutermostInvariantForOp(AffineIfOp ifOp) { 264 // Walk up the parents past all for op that this conditional is invariant on. 265 auto ifOperands = ifOp.getOperands(); 266 auto *res = ifOp.getOperation(); 267 while (!isa<func::FuncOp>(res->getParentOp())) { 268 auto *parentOp = res->getParentOp(); 269 if (auto forOp = dyn_cast<AffineForOp>(parentOp)) { 270 if (llvm::is_contained(ifOperands, forOp.getInductionVar())) 271 break; 272 } else if (auto parallelOp = dyn_cast<AffineParallelOp>(parentOp)) { 273 for (auto iv : parallelOp.getIVs()) 274 if (llvm::is_contained(ifOperands, iv)) 275 break; 276 } else if (!isa<AffineIfOp>(parentOp)) { 277 // Won't walk up past anything other than affine.for/if ops. 278 break; 279 } 280 // You can always hoist up past any affine.if ops. 281 res = parentOp; 282 } 283 return res; 284 } 285 286 /// A helper for the mechanics of mlir::hoistAffineIfOp. Hoists `ifOp` just over 287 /// `hoistOverOp`. Returns the new hoisted op if any hoisting happened, 288 /// otherwise the same `ifOp`. 289 static AffineIfOp hoistAffineIfOp(AffineIfOp ifOp, Operation *hoistOverOp) { 290 // No hoisting to do. 291 if (hoistOverOp == ifOp) 292 return ifOp; 293 294 // Create the hoisted 'if' first. Then, clone the op we are hoisting over for 295 // the else block. Then drop the else block of the original 'if' in the 'then' 296 // branch while promoting its then block, and analogously drop the 'then' 297 // block of the original 'if' from the 'else' branch while promoting its else 298 // block. 299 IRMapping operandMap; 300 OpBuilder b(hoistOverOp); 301 auto hoistedIfOp = b.create<AffineIfOp>(ifOp.getLoc(), ifOp.getIntegerSet(), 302 ifOp.getOperands(), 303 /*elseBlock=*/true); 304 305 // Create a clone of hoistOverOp to use for the else branch of the hoisted 306 // conditional. The else block may get optimized away if empty. 307 Operation *hoistOverOpClone = nullptr; 308 // We use this unique name to identify/find `ifOp`'s clone in the else 309 // version. 310 StringAttr idForIfOp = b.getStringAttr("__mlir_if_hoisting"); 311 operandMap.clear(); 312 b.setInsertionPointAfter(hoistOverOp); 313 // We'll set an attribute to identify this op in a clone of this sub-tree. 314 ifOp->setAttr(idForIfOp, b.getBoolAttr(true)); 315 hoistOverOpClone = b.clone(*hoistOverOp, operandMap); 316 317 // Promote the 'then' block of the original affine.if in the then version. 318 promoteIfBlock(ifOp, /*elseBlock=*/false); 319 320 // Move the then version to the hoisted if op's 'then' block. 321 auto *thenBlock = hoistedIfOp.getThenBlock(); 322 thenBlock->getOperations().splice(thenBlock->begin(), 323 hoistOverOp->getBlock()->getOperations(), 324 Block::iterator(hoistOverOp)); 325 326 // Find the clone of the original affine.if op in the else version. 
327 AffineIfOp ifCloneInElse; 328 hoistOverOpClone->walk([&](AffineIfOp ifClone) { 329 if (!ifClone->getAttr(idForIfOp)) 330 return WalkResult::advance(); 331 ifCloneInElse = ifClone; 332 return WalkResult::interrupt(); 333 }); 334 assert(ifCloneInElse && "if op clone should exist"); 335 // For the else block, promote the else block of the original 'if' if it had 336 // one; otherwise, the op itself is to be erased. 337 if (!ifCloneInElse.hasElse()) 338 ifCloneInElse.erase(); 339 else 340 promoteIfBlock(ifCloneInElse, /*elseBlock=*/true); 341 342 // Move the else version into the else block of the hoisted if op. 343 auto *elseBlock = hoistedIfOp.getElseBlock(); 344 elseBlock->getOperations().splice( 345 elseBlock->begin(), hoistOverOpClone->getBlock()->getOperations(), 346 Block::iterator(hoistOverOpClone)); 347 348 return hoistedIfOp; 349 } 350 351 LogicalResult 352 mlir::affine::affineParallelize(AffineForOp forOp, 353 ArrayRef<LoopReduction> parallelReductions, 354 AffineParallelOp *resOp) { 355 // Fail early if there are iter arguments that are not reductions. 356 unsigned numReductions = parallelReductions.size(); 357 if (numReductions != forOp.getNumIterOperands()) 358 return failure(); 359 360 Location loc = forOp.getLoc(); 361 OpBuilder outsideBuilder(forOp); 362 AffineMap lowerBoundMap = forOp.getLowerBoundMap(); 363 ValueRange lowerBoundOperands = forOp.getLowerBoundOperands(); 364 AffineMap upperBoundMap = forOp.getUpperBoundMap(); 365 ValueRange upperBoundOperands = forOp.getUpperBoundOperands(); 366 367 // Creating empty 1-D affine.parallel op. 368 auto reducedValues = llvm::to_vector<4>(llvm::map_range( 369 parallelReductions, [](const LoopReduction &red) { return red.value; })); 370 auto reductionKinds = llvm::to_vector<4>(llvm::map_range( 371 parallelReductions, [](const LoopReduction &red) { return red.kind; })); 372 AffineParallelOp newPloop = outsideBuilder.create<AffineParallelOp>( 373 loc, ValueRange(reducedValues).getTypes(), reductionKinds, 374 llvm::ArrayRef(lowerBoundMap), lowerBoundOperands, 375 llvm::ArrayRef(upperBoundMap), upperBoundOperands, 376 llvm::ArrayRef(forOp.getStepAsInt())); 377 // Steal the body of the old affine for op. 378 newPloop.getRegion().takeBody(forOp.getRegion()); 379 Operation *yieldOp = &newPloop.getBody()->back(); 380 381 // Handle the initial values of reductions because the parallel loop always 382 // starts from the neutral value. 383 SmallVector<Value> newResults; 384 newResults.reserve(numReductions); 385 for (unsigned i = 0; i < numReductions; ++i) { 386 Value init = forOp.getInits()[i]; 387 // This works because we are only handling single-op reductions at the 388 // moment. A switch on reduction kind or a mechanism to collect operations 389 // participating in the reduction will be necessary for multi-op reductions. 390 Operation *reductionOp = yieldOp->getOperand(i).getDefiningOp(); 391 assert(reductionOp && "yielded value is expected to be produced by an op"); 392 outsideBuilder.getInsertionBlock()->getOperations().splice( 393 outsideBuilder.getInsertionPoint(), newPloop.getBody()->getOperations(), 394 reductionOp); 395 reductionOp->setOperands({init, newPloop->getResult(i)}); 396 forOp->getResult(i).replaceAllUsesWith(reductionOp->getResult(0)); 397 } 398 399 // Update the loop terminator to yield reduced values bypassing the reduction 400 // operation itself (now moved outside of the loop) and erase the block 401 // arguments that correspond to reductions. 
Note that the loop always has one 402 // "main" induction variable whenc coming from a non-parallel for. 403 unsigned numIVs = 1; 404 yieldOp->setOperands(reducedValues); 405 newPloop.getBody()->eraseArguments(numIVs, numReductions); 406 407 forOp.erase(); 408 if (resOp) 409 *resOp = newPloop; 410 return success(); 411 } 412 413 // Returns success if any hoisting happened. 414 LogicalResult mlir::affine::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) { 415 // Bail out early if the ifOp returns a result. TODO: Consider how to 416 // properly support this case. 417 if (ifOp.getNumResults() != 0) 418 return failure(); 419 420 // Apply canonicalization patterns and folding - this is necessary for the 421 // hoisting check to be correct (operands should be composed), and to be more 422 // effective (no unused operands). Since the pattern rewriter's folding is 423 // entangled with application of patterns, we may fold/end up erasing the op, 424 // in which case we return with `folded` being set. 425 RewritePatternSet patterns(ifOp.getContext()); 426 AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext()); 427 FrozenRewritePatternSet frozenPatterns(std::move(patterns)); 428 GreedyRewriteConfig config; 429 config.strictMode = GreedyRewriteStrictness::ExistingOps; 430 bool erased; 431 (void)applyOpPatternsGreedily(ifOp.getOperation(), frozenPatterns, config, 432 /*changed=*/nullptr, &erased); 433 if (erased) { 434 if (folded) 435 *folded = true; 436 return failure(); 437 } 438 if (folded) 439 *folded = false; 440 441 // The folding above should have ensured this, but the affine.if's 442 // canonicalization is missing composition of affine.applys into it. 443 assert(llvm::all_of(ifOp.getOperands(), 444 [](Value v) { 445 return isTopLevelValue(v) || isAffineForInductionVar(v); 446 }) && 447 "operands not composed"); 448 449 // We are going hoist as high as possible. 450 // TODO: this could be customized in the future. 451 auto *hoistOverOp = getOutermostInvariantForOp(ifOp); 452 453 AffineIfOp hoistedIfOp = ::hoistAffineIfOp(ifOp, hoistOverOp); 454 // Nothing to hoist over. 455 if (hoistedIfOp == ifOp) 456 return failure(); 457 458 // Canonicalize to remove dead else blocks (happens whenever an 'if' moves up 459 // a sequence of affine.fors that are all perfectly nested). 460 (void)applyPatternsGreedily( 461 hoistedIfOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>(), 462 frozenPatterns); 463 464 return success(); 465 } 466 467 // Return the min expr after replacing the given dim. 468 AffineExpr mlir::affine::substWithMin(AffineExpr e, AffineExpr dim, 469 AffineExpr min, AffineExpr max, 470 bool positivePath) { 471 if (e == dim) 472 return positivePath ? 
min : max; 473 if (auto bin = dyn_cast<AffineBinaryOpExpr>(e)) { 474 AffineExpr lhs = bin.getLHS(); 475 AffineExpr rhs = bin.getRHS(); 476 if (bin.getKind() == mlir::AffineExprKind::Add) 477 return substWithMin(lhs, dim, min, max, positivePath) + 478 substWithMin(rhs, dim, min, max, positivePath); 479 480 auto c1 = dyn_cast<AffineConstantExpr>(bin.getLHS()); 481 auto c2 = dyn_cast<AffineConstantExpr>(bin.getRHS()); 482 if (c1 && c1.getValue() < 0) 483 return getAffineBinaryOpExpr( 484 bin.getKind(), c1, substWithMin(rhs, dim, min, max, !positivePath)); 485 if (c2 && c2.getValue() < 0) 486 return getAffineBinaryOpExpr( 487 bin.getKind(), substWithMin(lhs, dim, min, max, !positivePath), c2); 488 return getAffineBinaryOpExpr( 489 bin.getKind(), substWithMin(lhs, dim, min, max, positivePath), 490 substWithMin(rhs, dim, min, max, positivePath)); 491 } 492 return e; 493 } 494 495 void mlir::affine::normalizeAffineParallel(AffineParallelOp op) { 496 // Loops with min/max in bounds are not normalized at the moment. 497 if (op.hasMinMaxBounds()) 498 return; 499 500 AffineMap lbMap = op.getLowerBoundsMap(); 501 SmallVector<int64_t, 8> steps = op.getSteps(); 502 // No need to do any work if the parallel op is already normalized. 503 bool isAlreadyNormalized = 504 llvm::all_of(llvm::zip(steps, lbMap.getResults()), [](auto tuple) { 505 int64_t step = std::get<0>(tuple); 506 auto lbExpr = dyn_cast<AffineConstantExpr>(std::get<1>(tuple)); 507 return lbExpr && lbExpr.getValue() == 0 && step == 1; 508 }); 509 if (isAlreadyNormalized) 510 return; 511 512 AffineValueMap ranges; 513 AffineValueMap::difference(op.getUpperBoundsValueMap(), 514 op.getLowerBoundsValueMap(), &ranges); 515 auto builder = OpBuilder::atBlockBegin(op.getBody()); 516 auto zeroExpr = builder.getAffineConstantExpr(0); 517 SmallVector<AffineExpr, 8> lbExprs; 518 SmallVector<AffineExpr, 8> ubExprs; 519 for (unsigned i = 0, e = steps.size(); i < e; ++i) { 520 int64_t step = steps[i]; 521 522 // Adjust the lower bound to be 0. 523 lbExprs.push_back(zeroExpr); 524 525 // Adjust the upper bound expression: 'range / step'. 526 AffineExpr ubExpr = ranges.getResult(i).ceilDiv(step); 527 ubExprs.push_back(ubExpr); 528 529 // Adjust the corresponding IV: 'lb + i * step'. 530 BlockArgument iv = op.getBody()->getArgument(i); 531 AffineExpr lbExpr = lbMap.getResult(i); 532 unsigned nDims = lbMap.getNumDims(); 533 auto expr = lbExpr + builder.getAffineDimExpr(nDims) * step; 534 auto map = AffineMap::get(/*dimCount=*/nDims + 1, 535 /*symbolCount=*/lbMap.getNumSymbols(), expr); 536 537 // Use an 'affine.apply' op that will be simplified later in subsequent 538 // canonicalizations. 
539 OperandRange lbOperands = op.getLowerBoundsOperands(); 540 OperandRange dimOperands = lbOperands.take_front(nDims); 541 OperandRange symbolOperands = lbOperands.drop_front(nDims); 542 SmallVector<Value, 8> applyOperands{dimOperands}; 543 applyOperands.push_back(iv); 544 applyOperands.append(symbolOperands.begin(), symbolOperands.end()); 545 auto apply = builder.create<AffineApplyOp>(op.getLoc(), map, applyOperands); 546 iv.replaceAllUsesExcept(apply, apply); 547 } 548 549 SmallVector<int64_t, 8> newSteps(op.getNumDims(), 1); 550 op.setSteps(newSteps); 551 auto newLowerMap = AffineMap::get( 552 /*dimCount=*/0, /*symbolCount=*/0, lbExprs, op.getContext()); 553 op.setLowerBounds({}, newLowerMap); 554 auto newUpperMap = AffineMap::get(ranges.getNumDims(), ranges.getNumSymbols(), 555 ubExprs, op.getContext()); 556 op.setUpperBounds(ranges.getOperands(), newUpperMap); 557 } 558 559 LogicalResult mlir::affine::normalizeAffineFor(AffineForOp op, 560 bool promoteSingleIter) { 561 if (promoteSingleIter && succeeded(promoteIfSingleIteration(op))) 562 return success(); 563 564 // Check if the forop is already normalized. 565 if (op.hasConstantLowerBound() && (op.getConstantLowerBound() == 0) && 566 (op.getStep() == 1)) 567 return success(); 568 569 // Check if the lower bound has a single result only. Loops with a max lower 570 // bound can't be normalized without additional support like 571 // affine.execute_region's. If the lower bound does not have a single result 572 // then skip this op. 573 if (op.getLowerBoundMap().getNumResults() != 1) 574 return failure(); 575 576 Location loc = op.getLoc(); 577 OpBuilder opBuilder(op); 578 int64_t origLoopStep = op.getStepAsInt(); 579 580 // Construct the new upper bound value map. 581 AffineMap oldLbMap = op.getLowerBoundMap(); 582 // The upper bound can have multiple results. To use 583 // AffineValueMap::difference, we need to have the same number of results in 584 // both lower and upper bound maps. So, we just create a value map for the 585 // lower bound with the only available lower bound result repeated to pad up 586 // to the number of upper bound results. 587 SmallVector<AffineExpr> lbExprs(op.getUpperBoundMap().getNumResults(), 588 op.getLowerBoundMap().getResult(0)); 589 AffineValueMap lbMap(oldLbMap, op.getLowerBoundOperands()); 590 AffineMap paddedLbMap = 591 AffineMap::get(oldLbMap.getNumDims(), oldLbMap.getNumSymbols(), lbExprs, 592 op.getContext()); 593 AffineValueMap paddedLbValueMap(paddedLbMap, op.getLowerBoundOperands()); 594 AffineValueMap ubValueMap(op.getUpperBoundMap(), op.getUpperBoundOperands()); 595 AffineValueMap newUbValueMap; 596 // Compute the `upper bound - lower bound`. 597 AffineValueMap::difference(ubValueMap, paddedLbValueMap, &newUbValueMap); 598 (void)newUbValueMap.canonicalize(); 599 600 // Scale down the upper bound value map by the loop step. 601 unsigned numResult = newUbValueMap.getNumResults(); 602 SmallVector<AffineExpr> scaleDownExprs(numResult); 603 for (unsigned i = 0; i < numResult; ++i) 604 scaleDownExprs[i] = opBuilder.getAffineDimExpr(i).ceilDiv(origLoopStep); 605 // `scaleDownMap` is (d0, d1, ..., d_n) -> (d0 / step, d1 / step, ..., d_n / 606 // step). Where `n` is the number of results in the upper bound map. 607 AffineMap scaleDownMap = 608 AffineMap::get(numResult, 0, scaleDownExprs, op.getContext()); 609 AffineMap newUbMap = scaleDownMap.compose(newUbValueMap.getAffineMap()); 610 611 // Set the newly create upper bound map and operands. 
  op.setUpperBound(newUbValueMap.getOperands(), newUbMap);
  op.setLowerBound({}, opBuilder.getConstantAffineMap(0));
  op.setStep(1);

  // Calculate the Value of the new loopIV. Create an affine.apply for the
  // value of the loopIV in the normalized loop.
  opBuilder.setInsertionPointToStart(op.getBody());
  // Construct an affine.apply op mapping the new IV to the old IV.
  AffineMap scaleIvMap =
      AffineMap::get(1, 0, -opBuilder.getAffineDimExpr(0) * origLoopStep);
  AffineValueMap scaleIvValueMap(scaleIvMap, ValueRange{op.getInductionVar()});
  AffineValueMap newIvToOldIvMap;
  AffineValueMap::difference(lbMap, scaleIvValueMap, &newIvToOldIvMap);
  (void)newIvToOldIvMap.canonicalize();
  auto newIV = opBuilder.create<AffineApplyOp>(
      loc, newIvToOldIvMap.getAffineMap(), newIvToOldIvMap.getOperands());
  op.getInductionVar().replaceAllUsesExcept(newIV->getResult(0), newIV);
  return success();
}

/// Returns true if the memory operation of `destAccess` depends on `srcAccess`
/// inside of the innermost common surrounding affine loop between the two
/// accesses.
static bool mustReachAtInnermost(const MemRefAccess &srcAccess,
                                 const MemRefAccess &destAccess) {
  // Affine dependence analysis is possible only if both ops are in the same
  // AffineScope.
  if (getAffineScope(srcAccess.opInst) != getAffineScope(destAccess.opInst))
    return false;

  unsigned nsLoops =
      getNumCommonSurroundingLoops(*srcAccess.opInst, *destAccess.opInst);
  DependenceResult result =
      checkMemrefAccessDependence(srcAccess, destAccess, nsLoops + 1);
  return hasDependence(result);
}

/// Returns true if `srcMemOp` may have an effect on `destMemOp` within the
/// scope of the outermost `minSurroundingLoops` loops that surround them.
/// `srcMemOp` and `destMemOp` are expected to be affine read/write ops.
static bool mayHaveEffect(Operation *srcMemOp, Operation *destMemOp,
                          unsigned minSurroundingLoops) {
  MemRefAccess srcAccess(srcMemOp);
  MemRefAccess destAccess(destMemOp);

  // Affine dependence analysis here is applicable only if both ops operate on
  // the same memref and if `srcMemOp` and `destMemOp` are in the same
  // AffineScope. Also, we can only check if our affine scope is isolated from
  // above; otherwise, values can come from outside of the affine scope that
  // the check below cannot analyze.
  Region *srcScope = getAffineScope(srcMemOp);
  if (srcAccess.memref == destAccess.memref &&
      srcScope == getAffineScope(destMemOp)) {
    unsigned nsLoops = getNumCommonSurroundingLoops(*srcMemOp, *destMemOp);
    FlatAffineValueConstraints dependenceConstraints;
    for (unsigned d = nsLoops + 1; d > minSurroundingLoops; d--) {
      DependenceResult result = checkMemrefAccessDependence(
          srcAccess, destAccess, d, &dependenceConstraints,
          /*dependenceComponents=*/nullptr);
      // A dependence failure or the presence of a dependence implies a
      // side effect.
      if (!noDependence(result))
        return true;
    }
    // No side effect was seen.
    return false;
  }
  // TODO: Check here if the memrefs alias: there is no side effect if
  // `srcAccess.memref` and `destAccess.memref` don't alias.
  return true;
}

template <typename EffectType, typename T>
bool mlir::affine::hasNoInterveningEffect(
    Operation *start, T memOp,
    llvm::function_ref<bool(Value, Value)> mayAlias) {
  // A boolean representing whether an intervening operation could have
  // impacted memOp.
  bool hasSideEffect = false;

  // Check whether the effect on memOp can be caused by a given operation op.
  Value memref = memOp.getMemRef();
  std::function<void(Operation *)> checkOperation = [&](Operation *op) {
    // If the effect has already been found, exit early.
    if (hasSideEffect)
      return;

    if (auto memEffect = dyn_cast<MemoryEffectOpInterface>(op)) {
      SmallVector<MemoryEffects::EffectInstance, 1> effects;
      memEffect.getEffects(effects);

      bool opMayHaveEffect = false;
      for (auto effect : effects) {
        // If op causes EffectType on a potentially aliasing location for
        // memOp, mark as having the effect.
        if (isa<EffectType>(effect.getEffect())) {
          if (effect.getValue() && effect.getValue() != memref &&
              !mayAlias(effect.getValue(), memref))
            continue;
          opMayHaveEffect = true;
          break;
        }
      }

      if (!opMayHaveEffect)
        return;

      // If the side effect comes from an affine read or write, try to
      // prove the side effecting `op` cannot reach `memOp`.
      if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
        // For ease, let's consider the case that `op` is a store and
        // we're looking for other potential stores that overwrite memory after
        // `start`, and before being read in `memOp`. In this case, we only
        // need to consider other potential stores with depth >
        // minSurroundingLoops since `start` would overwrite any store with a
        // smaller number of surrounding loops before.
        unsigned minSurroundingLoops =
            getNumCommonSurroundingLoops(*start, *memOp);
        if (mayHaveEffect(op, memOp, minSurroundingLoops))
          hasSideEffect = true;
        return;
      }

      // We have an op with a memory effect and we cannot prove if it
      // intervenes.
      hasSideEffect = true;
      return;
    }

    if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>()) {
      // Recurse into the regions for this op and check whether the internal
      // operations may have the side effect `EffectType` on memOp.
      for (Region &region : op->getRegions())
        for (Block &block : region)
          for (Operation &op : block)
            checkOperation(&op);
      return;
    }

    // Otherwise, conservatively assume generic operations have the effect
    // on memOp.
    hasSideEffect = true;
  };

  // Check all paths from ancestor op `parent` to the operation `to` for the
  // effect. It is known that `to` must be contained within `parent`.
  auto until = [&](Operation *parent, Operation *to) {
    // TODO: check only the paths from `parent` to `to`.
    // Currently we fall back and check the entire parent op, rather than
    // just the paths from the parent to `to`, stopping after reaching `to`.
    // This is conservatively correct, but could be made more aggressive.
    assert(parent->isAncestor(to));
    checkOperation(parent);
  };

  // Check for all paths from operation `from` to operation `untilOp` for the
  // given memory effect.
769 std::function<void(Operation *, Operation *)> recur = 770 [&](Operation *from, Operation *untilOp) { 771 assert( 772 from->getParentRegion()->isAncestor(untilOp->getParentRegion()) && 773 "Checking for side effect between two operations without a common " 774 "ancestor"); 775 776 // If the operations are in different regions, recursively consider all 777 // path from `from` to the parent of `to` and all paths from the parent 778 // of `to` to `to`. 779 if (from->getParentRegion() != untilOp->getParentRegion()) { 780 recur(from, untilOp->getParentOp()); 781 until(untilOp->getParentOp(), untilOp); 782 return; 783 } 784 785 // Now, assuming that `from` and `to` exist in the same region, perform 786 // a CFG traversal to check all the relevant operations. 787 788 // Additional blocks to consider. 789 SmallVector<Block *, 2> todoBlocks; 790 { 791 // First consider the parent block of `from` an check all operations 792 // after `from`. 793 for (auto iter = ++from->getIterator(), end = from->getBlock()->end(); 794 iter != end && &*iter != untilOp; ++iter) { 795 checkOperation(&*iter); 796 } 797 798 // If the parent of `from` doesn't contain `to`, add the successors 799 // to the list of blocks to check. 800 if (untilOp->getBlock() != from->getBlock()) 801 for (Block *succ : from->getBlock()->getSuccessors()) 802 todoBlocks.push_back(succ); 803 } 804 805 SmallPtrSet<Block *, 4> done; 806 // Traverse the CFG until hitting `to`. 807 while (!todoBlocks.empty()) { 808 Block *blk = todoBlocks.pop_back_val(); 809 if (done.count(blk)) 810 continue; 811 done.insert(blk); 812 for (auto &op : *blk) { 813 if (&op == untilOp) 814 break; 815 checkOperation(&op); 816 if (&op == blk->getTerminator()) 817 for (Block *succ : blk->getSuccessors()) 818 todoBlocks.push_back(succ); 819 } 820 } 821 }; 822 recur(start, memOp); 823 return !hasSideEffect; 824 } 825 826 /// Attempt to eliminate loadOp by replacing it with a value stored into memory 827 /// which the load is guaranteed to retrieve. This check involves three 828 /// components: 1) The store and load must be on the same location 2) The store 829 /// must dominate (and therefore must always occur prior to) the load 3) No 830 /// other operations will overwrite the memory loaded between the given load 831 /// and store. If such a value exists, the replaced `loadOp` will be added to 832 /// `loadOpsToErase` and its memref will be added to `memrefsToErase`. 833 static void forwardStoreToLoad( 834 AffineReadOpInterface loadOp, SmallVectorImpl<Operation *> &loadOpsToErase, 835 SmallPtrSetImpl<Value> &memrefsToErase, DominanceInfo &domInfo, 836 llvm::function_ref<bool(Value, Value)> mayAlias) { 837 838 // The store op candidate for forwarding that satisfies all conditions 839 // to replace the load, if any. 840 Operation *lastWriteStoreOp = nullptr; 841 842 for (auto *user : loadOp.getMemRef().getUsers()) { 843 auto storeOp = dyn_cast<AffineWriteOpInterface>(user); 844 if (!storeOp) 845 continue; 846 MemRefAccess srcAccess(storeOp); 847 MemRefAccess destAccess(loadOp); 848 849 // 1. Check if the store and the load have mathematically equivalent 850 // affine access functions; this implies that they statically refer to the 851 // same single memref element. As an example this filters out cases like: 852 // store %A[%i0 + 1] 853 // load %A[%i0] 854 // store %A[%M] 855 // load %A[%N] 856 // Use the AffineValueMap difference based memref access equality checking. 857 if (srcAccess != destAccess) 858 continue; 859 860 // 2. 
The store has to dominate the load op to be candidate. 861 if (!domInfo.dominates(storeOp, loadOp)) 862 continue; 863 864 // 3. The store must reach the load. Access function equivalence only 865 // guarantees this for accesses in the same block. The load could be in a 866 // nested block that is unreachable. 867 if (!mustReachAtInnermost(srcAccess, destAccess)) 868 continue; 869 870 // 4. Ensure there is no intermediate operation which could replace the 871 // value in memory. 872 if (!affine::hasNoInterveningEffect<MemoryEffects::Write>(storeOp, loadOp, 873 mayAlias)) 874 continue; 875 876 // We now have a candidate for forwarding. 877 assert(lastWriteStoreOp == nullptr && 878 "multiple simultaneous replacement stores"); 879 lastWriteStoreOp = storeOp; 880 } 881 882 if (!lastWriteStoreOp) 883 return; 884 885 // Perform the actual store to load forwarding. 886 Value storeVal = 887 cast<AffineWriteOpInterface>(lastWriteStoreOp).getValueToStore(); 888 // Check if 2 values have the same shape. This is needed for affine vector 889 // loads and stores. 890 if (storeVal.getType() != loadOp.getValue().getType()) 891 return; 892 loadOp.getValue().replaceAllUsesWith(storeVal); 893 // Record the memref for a later sweep to optimize away. 894 memrefsToErase.insert(loadOp.getMemRef()); 895 // Record this to erase later. 896 loadOpsToErase.push_back(loadOp); 897 } 898 899 template bool 900 mlir::affine::hasNoInterveningEffect<mlir::MemoryEffects::Read, 901 affine::AffineReadOpInterface>( 902 mlir::Operation *, affine::AffineReadOpInterface, 903 llvm::function_ref<bool(Value, Value)>); 904 905 // This attempts to find stores which have no impact on the final result. 906 // A writing op writeA will be eliminated if there exists an op writeB if 907 // 1) writeA and writeB have mathematically equivalent affine access functions. 908 // 2) writeB postdominates writeA. 909 // 3) There is no potential read between writeA and writeB. 910 static void findUnusedStore(AffineWriteOpInterface writeA, 911 SmallVectorImpl<Operation *> &opsToErase, 912 PostDominanceInfo &postDominanceInfo, 913 llvm::function_ref<bool(Value, Value)> mayAlias) { 914 915 for (Operation *user : writeA.getMemRef().getUsers()) { 916 // Only consider writing operations. 917 auto writeB = dyn_cast<AffineWriteOpInterface>(user); 918 if (!writeB) 919 continue; 920 921 // The operations must be distinct. 922 if (writeB == writeA) 923 continue; 924 925 // Both operations must lie in the same region. 926 if (writeB->getParentRegion() != writeA->getParentRegion()) 927 continue; 928 929 // Both operations must write to the same memory. 930 MemRefAccess srcAccess(writeB); 931 MemRefAccess destAccess(writeA); 932 933 if (srcAccess != destAccess) 934 continue; 935 936 // writeB must postdominate writeA. 937 if (!postDominanceInfo.postDominates(writeB, writeA)) 938 continue; 939 940 // There cannot be an operation which reads from memory between 941 // the two writes. 942 if (!affine::hasNoInterveningEffect<MemoryEffects::Read>(writeA, writeB, 943 mayAlias)) 944 continue; 945 946 opsToErase.push_back(writeA); 947 break; 948 } 949 } 950 951 // The load to load forwarding / redundant load elimination is similar to the 952 // store to load forwarding. 953 // loadA will be be replaced with loadB if: 954 // 1) loadA and loadB have mathematically equivalent affine access functions. 955 // 2) loadB dominates loadA. 956 // 3) There is no write between loadA and loadB. 
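//
// For instance (a hand-written illustration, not taken from a test case):
//
//   %0 = affine.load %A[%i] : memref<100xf32>
//   "use"(%0) : (f32) -> ()
//   %1 = affine.load %A[%i] : memref<100xf32>
//
// With no operation writing to %A between the two loads, %1 satisfies all
// three conditions above and its uses can be replaced with %0.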
static void loadCSE(AffineReadOpInterface loadA,
                    SmallVectorImpl<Operation *> &loadOpsToErase,
                    DominanceInfo &domInfo,
                    llvm::function_ref<bool(Value, Value)> mayAlias) {
  SmallVector<AffineReadOpInterface, 4> loadCandidates;
  for (auto *user : loadA.getMemRef().getUsers()) {
    auto loadB = dyn_cast<AffineReadOpInterface>(user);
    if (!loadB || loadB == loadA)
      continue;

    MemRefAccess srcAccess(loadB);
    MemRefAccess destAccess(loadA);

    // 1. The accesses should be to the same location.
    if (srcAccess != destAccess) {
      continue;
    }

    // 2. loadB should dominate loadA.
    if (!domInfo.dominates(loadB, loadA))
      continue;

    // 3. There should not be a write between loadA and loadB.
    if (!affine::hasNoInterveningEffect<MemoryEffects::Write>(
            loadB.getOperation(), loadA, mayAlias))
      continue;

    // Check if the two values have the same shape. This is needed for affine
    // vector loads.
    if (loadB.getValue().getType() != loadA.getValue().getType())
      continue;

    loadCandidates.push_back(loadB);
  }

  // Of the legal load candidates, use the one that dominates all others
  // to minimize the subsequent need for load CSE.
  Value loadB;
  for (AffineReadOpInterface option : loadCandidates) {
    if (llvm::all_of(loadCandidates, [&](AffineReadOpInterface depStore) {
          return depStore == option ||
                 domInfo.dominates(option.getOperation(),
                                   depStore.getOperation());
        })) {
      loadB = option.getValue();
      break;
    }
  }

  if (loadB) {
    loadA.getValue().replaceAllUsesWith(loadB);
    // Record this to erase later.
    loadOpsToErase.push_back(loadA);
  }
}

// The store to load forwarding and load CSE rely on three conditions:
//
// 1) the store/load providing a replacement value and the load being replaced
// need to have mathematically equivalent affine access functions (checked
// after full composition of load/store operands); this implies that they
// access the same single memref element for all iterations of the common
// surrounding loop,
//
// 2) the store/load op should dominate the load op,
//
// 3) no operation that may write to memory read by the load being replaced can
// occur after executing the instruction (load or store) providing the
// replacement value and before the load being replaced (thus potentially
// allowing overwriting the memory read by the load).
//
// The above conditions are simple to check and powerful for most cases in
// practice --- they are sufficient, but not necessary, since they don't reason
// about loops that are guaranteed to execute at least once or about multiple
// sources to forward from.
//
// TODO: more forwarding can be done when support for
// loop/conditional live-out SSA values is available.
// TODO: do general dead store elimination for memref's. This pass
// currently eliminates the stores only if no other loads/uses (other
// than dealloc) remain.
//
void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
                                       PostDominanceInfo &postDomInfo,
                                       AliasAnalysis &aliasAnalysis) {
  // Load op's whose results were replaced by those forwarded from stores.
  SmallVector<Operation *, 8> opsToErase;

  // A list of memref's that are potentially dead / could be eliminated.
1045 SmallPtrSet<Value, 4> memrefsToErase; 1046 1047 auto mayAlias = [&](Value val1, Value val2) -> bool { 1048 return !aliasAnalysis.alias(val1, val2).isNo(); 1049 }; 1050 1051 // Walk all load's and perform store to load forwarding. 1052 f.walk([&](AffineReadOpInterface loadOp) { 1053 forwardStoreToLoad(loadOp, opsToErase, memrefsToErase, domInfo, mayAlias); 1054 }); 1055 for (auto *op : opsToErase) 1056 op->erase(); 1057 opsToErase.clear(); 1058 1059 // Walk all store's and perform unused store elimination 1060 f.walk([&](AffineWriteOpInterface storeOp) { 1061 findUnusedStore(storeOp, opsToErase, postDomInfo, mayAlias); 1062 }); 1063 for (auto *op : opsToErase) 1064 op->erase(); 1065 opsToErase.clear(); 1066 1067 // Check if the store fwd'ed memrefs are now left with only stores and 1068 // deallocs and can thus be completely deleted. Note: the canonicalize pass 1069 // should be able to do this as well, but we'll do it here since we collected 1070 // these anyway. 1071 for (auto memref : memrefsToErase) { 1072 // If the memref hasn't been locally alloc'ed, skip. 1073 Operation *defOp = memref.getDefiningOp(); 1074 if (!defOp || !hasSingleEffect<MemoryEffects::Allocate>(defOp, memref)) 1075 // TODO: if the memref was returned by a 'call' operation, we 1076 // could still erase it if the call had no side-effects. 1077 continue; 1078 if (llvm::any_of(memref.getUsers(), [&](Operation *ownerOp) { 1079 return !isa<AffineWriteOpInterface>(ownerOp) && 1080 !hasSingleEffect<MemoryEffects::Free>(ownerOp, memref); 1081 })) 1082 continue; 1083 1084 // Erase all stores, the dealloc, and the alloc on the memref. 1085 for (auto *user : llvm::make_early_inc_range(memref.getUsers())) 1086 user->erase(); 1087 defOp->erase(); 1088 } 1089 1090 // To eliminate as many loads as possible, run load CSE after eliminating 1091 // stores. Otherwise, some stores are wrongly seen as having an intervening 1092 // effect. 1093 f.walk([&](AffineReadOpInterface loadOp) { 1094 loadCSE(loadOp, opsToErase, domInfo, mayAlias); 1095 }); 1096 for (auto *op : opsToErase) 1097 op->erase(); 1098 } 1099 1100 // Private helper function to transform memref.load with reduced rank. 1101 // This function will modify the indices of the memref.load to match the 1102 // newMemRef. 
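// For example (an illustrative sketch; the shapes and map below are
// hypothetical): replacing a rank-2 %old with a rank-1 %new under
// indexRemap = (d0, d1) -> (d0 * 256 + d1) rewrites
//   memref.load %old[%i, %j] : memref<16x256xf32>
// into an affine.apply computing %i * 256 + %j followed by
//   memref.load %new[%idx] : memref<4096xf32>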
1103 LogicalResult transformMemRefLoadWithReducedRank( 1104 Operation *op, Value oldMemRef, Value newMemRef, unsigned memRefOperandPos, 1105 ArrayRef<Value> extraIndices, ArrayRef<Value> extraOperands, 1106 ArrayRef<Value> symbolOperands, AffineMap indexRemap) { 1107 unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType()).getRank(); 1108 unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType()).getRank(); 1109 unsigned oldMapNumInputs = oldMemRefRank; 1110 SmallVector<Value, 4> oldMapOperands( 1111 op->operand_begin() + memRefOperandPos + 1, 1112 op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs); 1113 SmallVector<Value, 4> oldMemRefOperands; 1114 oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end()); 1115 SmallVector<Value, 4> remapOperands; 1116 remapOperands.reserve(extraOperands.size() + oldMemRefRank + 1117 symbolOperands.size()); 1118 remapOperands.append(extraOperands.begin(), extraOperands.end()); 1119 remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end()); 1120 remapOperands.append(symbolOperands.begin(), symbolOperands.end()); 1121 1122 SmallVector<Value, 4> remapOutputs; 1123 remapOutputs.reserve(oldMemRefRank); 1124 SmallVector<Value, 4> affineApplyOps; 1125 1126 OpBuilder builder(op); 1127 1128 if (indexRemap && 1129 indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) { 1130 // Remapped indices. 1131 for (auto resultExpr : indexRemap.getResults()) { 1132 auto singleResMap = AffineMap::get( 1133 indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr); 1134 auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap, 1135 remapOperands); 1136 remapOutputs.push_back(afOp); 1137 affineApplyOps.push_back(afOp); 1138 } 1139 } else { 1140 // No remapping specified. 1141 remapOutputs.assign(remapOperands.begin(), remapOperands.end()); 1142 } 1143 1144 SmallVector<Value, 4> newMapOperands; 1145 newMapOperands.reserve(newMemRefRank); 1146 1147 // Prepend 'extraIndices' in 'newMapOperands'. 1148 for (Value extraIndex : extraIndices) { 1149 assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) && 1150 "invalid memory op index"); 1151 newMapOperands.push_back(extraIndex); 1152 } 1153 1154 // Append 'remapOutputs' to 'newMapOperands'. 1155 newMapOperands.append(remapOutputs.begin(), remapOutputs.end()); 1156 1157 // Create new fully composed AffineMap for new op to be created. 1158 assert(newMapOperands.size() == newMemRefRank); 1159 1160 OperationState state(op->getLoc(), op->getName()); 1161 // Construct the new operation using this memref. 1162 state.operands.reserve(newMapOperands.size() + extraIndices.size()); 1163 state.operands.push_back(newMemRef); 1164 1165 // Insert the new memref map operands. 1166 state.operands.append(newMapOperands.begin(), newMapOperands.end()); 1167 1168 state.types.reserve(op->getNumResults()); 1169 for (auto result : op->getResults()) 1170 state.types.push_back(result.getType()); 1171 1172 // Copy over the attributes from the old operation to the new operation. 1173 for (auto namedAttr : op->getAttrs()) { 1174 state.attributes.push_back(namedAttr); 1175 } 1176 1177 // Create the new operation. 1178 auto *repOp = builder.create(state); 1179 op->replaceAllUsesWith(repOp); 1180 op->erase(); 1181 1182 return success(); 1183 } 1184 // Perform the replacement in `op`. 
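// As an illustration (hypothetical shapes and remap, not from the codebase):
// with indexRemap = (d0) -> (d0 floordiv 4, d0 mod 4) and a rank-2 newMemRef,
// an access such as affine.load %old[%i] ends up as
// affine.load %new[%i floordiv 4, %i mod 4], with the remap composed into the
// affine map attribute for dereferencing (AffineMapAccessInterface) ops.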
1185 LogicalResult mlir::affine::replaceAllMemRefUsesWith( 1186 Value oldMemRef, Value newMemRef, Operation *op, 1187 ArrayRef<Value> extraIndices, AffineMap indexRemap, 1188 ArrayRef<Value> extraOperands, ArrayRef<Value> symbolOperands, 1189 bool allowNonDereferencingOps) { 1190 unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType()).getRank(); 1191 (void)newMemRefRank; // unused in opt mode 1192 unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType()).getRank(); 1193 (void)oldMemRefRank; // unused in opt mode 1194 if (indexRemap) { 1195 assert(indexRemap.getNumSymbols() == symbolOperands.size() && 1196 "symbolic operand count mismatch"); 1197 assert(indexRemap.getNumInputs() == 1198 extraOperands.size() + oldMemRefRank + symbolOperands.size()); 1199 assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank); 1200 } else { 1201 assert(oldMemRefRank + extraIndices.size() == newMemRefRank); 1202 } 1203 1204 // Assert same elemental type. 1205 assert(cast<MemRefType>(oldMemRef.getType()).getElementType() == 1206 cast<MemRefType>(newMemRef.getType()).getElementType()); 1207 1208 SmallVector<unsigned, 2> usePositions; 1209 for (const auto &opEntry : llvm::enumerate(op->getOperands())) { 1210 if (opEntry.value() == oldMemRef) 1211 usePositions.push_back(opEntry.index()); 1212 } 1213 1214 // If memref doesn't appear, nothing to do. 1215 if (usePositions.empty()) 1216 return success(); 1217 1218 if (usePositions.size() > 1) { 1219 // TODO: extend it for this case when needed (rare). 1220 assert(false && "multiple dereferencing uses in a single op not supported"); 1221 return failure(); 1222 } 1223 1224 unsigned memRefOperandPos = usePositions.front(); 1225 1226 OpBuilder builder(op); 1227 // The following checks if op is dereferencing memref and performs the access 1228 // index rewrites. 1229 auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op); 1230 if (!affMapAccInterface) { 1231 if (!allowNonDereferencingOps) { 1232 // Failure: memref used in a non-dereferencing context (potentially 1233 // escapes); no replacement in these cases unless allowNonDereferencingOps 1234 // is set. 1235 return failure(); 1236 } 1237 1238 // Check if it is a memref.load 1239 auto memrefLoad = dyn_cast<memref::LoadOp>(op); 1240 bool isReductionLike = 1241 indexRemap.getNumResults() < indexRemap.getNumInputs(); 1242 if (!memrefLoad || !isReductionLike) { 1243 op->setOperand(memRefOperandPos, newMemRef); 1244 return success(); 1245 } 1246 1247 return transformMemRefLoadWithReducedRank( 1248 op, oldMemRef, newMemRef, memRefOperandPos, extraIndices, extraOperands, 1249 symbolOperands, indexRemap); 1250 } 1251 // Perform index rewrites for the dereferencing op and then replace the op 1252 NamedAttribute oldMapAttrPair = 1253 affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef); 1254 AffineMap oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue()).getValue(); 1255 unsigned oldMapNumInputs = oldMap.getNumInputs(); 1256 SmallVector<Value, 4> oldMapOperands( 1257 op->operand_begin() + memRefOperandPos + 1, 1258 op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs); 1259 1260 // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'. 
1261 SmallVector<Value, 4> oldMemRefOperands; 1262 SmallVector<Value, 4> affineApplyOps; 1263 oldMemRefOperands.reserve(oldMemRefRank); 1264 if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) { 1265 for (auto resultExpr : oldMap.getResults()) { 1266 auto singleResMap = AffineMap::get(oldMap.getNumDims(), 1267 oldMap.getNumSymbols(), resultExpr); 1268 auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap, 1269 oldMapOperands); 1270 oldMemRefOperands.push_back(afOp); 1271 affineApplyOps.push_back(afOp); 1272 } 1273 } else { 1274 oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end()); 1275 } 1276 1277 // Construct new indices as a remap of the old ones if a remapping has been 1278 // provided. The indices of a memref come right after it, i.e., 1279 // at position memRefOperandPos + 1. 1280 SmallVector<Value, 4> remapOperands; 1281 remapOperands.reserve(extraOperands.size() + oldMemRefRank + 1282 symbolOperands.size()); 1283 remapOperands.append(extraOperands.begin(), extraOperands.end()); 1284 remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end()); 1285 remapOperands.append(symbolOperands.begin(), symbolOperands.end()); 1286 1287 SmallVector<Value, 4> remapOutputs; 1288 remapOutputs.reserve(oldMemRefRank); 1289 1290 if (indexRemap && 1291 indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) { 1292 // Remapped indices. 1293 for (auto resultExpr : indexRemap.getResults()) { 1294 auto singleResMap = AffineMap::get( 1295 indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr); 1296 auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap, 1297 remapOperands); 1298 remapOutputs.push_back(afOp); 1299 affineApplyOps.push_back(afOp); 1300 } 1301 } else { 1302 // No remapping specified. 1303 remapOutputs.assign(remapOperands.begin(), remapOperands.end()); 1304 } 1305 1306 SmallVector<Value, 4> newMapOperands; 1307 newMapOperands.reserve(newMemRefRank); 1308 1309 // Prepend 'extraIndices' in 'newMapOperands'. 1310 for (Value extraIndex : extraIndices) { 1311 assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) && 1312 "invalid memory op index"); 1313 newMapOperands.push_back(extraIndex); 1314 } 1315 1316 // Append 'remapOutputs' to 'newMapOperands'. 1317 newMapOperands.append(remapOutputs.begin(), remapOutputs.end()); 1318 1319 // Create new fully composed AffineMap for new op to be created. 1320 assert(newMapOperands.size() == newMemRefRank); 1321 auto newMap = builder.getMultiDimIdentityMap(newMemRefRank); 1322 fullyComposeAffineMapAndOperands(&newMap, &newMapOperands); 1323 newMap = simplifyAffineMap(newMap); 1324 canonicalizeMapAndOperands(&newMap, &newMapOperands); 1325 // Remove any affine.apply's that became dead as a result of composition. 1326 for (Value value : affineApplyOps) 1327 if (value.use_empty()) 1328 value.getDefiningOp()->erase(); 1329 1330 OperationState state(op->getLoc(), op->getName()); 1331 // Construct the new operation using this memref. 1332 state.operands.reserve(op->getNumOperands() + extraIndices.size()); 1333 // Insert the non-memref operands. 1334 state.operands.append(op->operand_begin(), 1335 op->operand_begin() + memRefOperandPos); 1336 // Insert the new memref value. 1337 state.operands.push_back(newMemRef); 1338 1339 // Insert the new memref map operands. 1340 state.operands.append(newMapOperands.begin(), newMapOperands.end()); 1341 1342 // Insert the remaining operands unmodified. 
1343 state.operands.append(op->operand_begin() + memRefOperandPos + 1 + 1344 oldMapNumInputs, 1345 op->operand_end()); 1346 1347 // Result types don't change. Both memref's are of the same elemental type. 1348 state.types.reserve(op->getNumResults()); 1349 for (auto result : op->getResults()) 1350 state.types.push_back(result.getType()); 1351 1352 // Add attribute for 'newMap', other Attributes do not change. 1353 auto newMapAttr = AffineMapAttr::get(newMap); 1354 for (auto namedAttr : op->getAttrs()) { 1355 if (namedAttr.getName() == oldMapAttrPair.getName()) 1356 state.attributes.push_back({namedAttr.getName(), newMapAttr}); 1357 else 1358 state.attributes.push_back(namedAttr); 1359 } 1360 1361 // Create the new operation. 1362 auto *repOp = builder.create(state); 1363 op->replaceAllUsesWith(repOp); 1364 op->erase(); 1365 1366 return success(); 1367 } 1368 1369 LogicalResult mlir::affine::replaceAllMemRefUsesWith( 1370 Value oldMemRef, Value newMemRef, ArrayRef<Value> extraIndices, 1371 AffineMap indexRemap, ArrayRef<Value> extraOperands, 1372 ArrayRef<Value> symbolOperands, Operation *domOpFilter, 1373 Operation *postDomOpFilter, bool allowNonDereferencingOps, 1374 bool replaceInDeallocOp) { 1375 unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType()).getRank(); 1376 (void)newMemRefRank; // unused in opt mode 1377 unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType()).getRank(); 1378 (void)oldMemRefRank; 1379 if (indexRemap) { 1380 assert(indexRemap.getNumSymbols() == symbolOperands.size() && 1381 "symbol operand count mismatch"); 1382 assert(indexRemap.getNumInputs() == 1383 extraOperands.size() + oldMemRefRank + symbolOperands.size()); 1384 assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank); 1385 } else { 1386 assert(oldMemRefRank + extraIndices.size() == newMemRefRank); 1387 } 1388 1389 // Assert same elemental type. 1390 assert(cast<MemRefType>(oldMemRef.getType()).getElementType() == 1391 cast<MemRefType>(newMemRef.getType()).getElementType()); 1392 1393 std::unique_ptr<DominanceInfo> domInfo; 1394 std::unique_ptr<PostDominanceInfo> postDomInfo; 1395 if (domOpFilter) 1396 domInfo = std::make_unique<DominanceInfo>( 1397 domOpFilter->getParentOfType<FunctionOpInterface>()); 1398 1399 if (postDomOpFilter) 1400 postDomInfo = std::make_unique<PostDominanceInfo>( 1401 postDomOpFilter->getParentOfType<FunctionOpInterface>()); 1402 1403 // Walk all uses of old memref; collect ops to perform replacement. We use a 1404 // DenseSet since an operation could potentially have multiple uses of a 1405 // memref (although rare), and the replacement later is going to erase ops. 1406 DenseSet<Operation *> opsToReplace; 1407 for (auto *op : oldMemRef.getUsers()) { 1408 // Skip this use if it's not dominated by domOpFilter. 1409 if (domOpFilter && !domInfo->dominates(domOpFilter, op)) 1410 continue; 1411 1412 // Skip this use if it's not post-dominated by postDomOpFilter. 1413 if (postDomOpFilter && !postDomInfo->postDominates(postDomOpFilter, op)) 1414 continue; 1415 1416 // Skip dealloc's - no replacement is necessary, and a memref replacement 1417 // at other uses doesn't hurt these dealloc's. 1418 if (hasSingleEffect<MemoryEffects::Free>(op, oldMemRef) && 1419 !replaceInDeallocOp) 1420 continue; 1421 1422 // Check if the memref was used in a non-dereferencing context. It is fine 1423 // for the memref to be used in a non-dereferencing way outside of the 1424 // region where this replacement is happening. 
    if (!isa<AffineMapAccessInterface>(*op)) {
      if (!allowNonDereferencingOps) {
        LLVM_DEBUG(llvm::dbgs()
                   << "Memref replacement failed: non-dereferencing memref op: \n"
                   << *op << '\n');
        return failure();
      }
      // Non-dereferencing ops with the MemRefsNormalizable trait are
      // supported for replacement.
      if (!op->hasTrait<OpTrait::MemRefsNormalizable>()) {
        LLVM_DEBUG(llvm::dbgs() << "Memref replacement failed: use without a "
                                   "memrefs normalizable trait: \n"
                                << *op << '\n');
        return failure();
      }
    }

    // We'll first collect and then replace --- since replacement erases the op
    // that has the use, and that op could be postDomFilter or domFilter itself!
    opsToReplace.insert(op);
  }

  for (auto *op : opsToReplace) {
    if (failed(replaceAllMemRefUsesWith(
            oldMemRef, newMemRef, op, extraIndices, indexRemap, extraOperands,
            symbolOperands, allowNonDereferencingOps)))
      llvm_unreachable("memref replacement guaranteed to succeed here");
  }

  return success();
}

/// Given an operation, inserts one or more single-result affine apply
/// operations whose results are exclusively used by this operation. The
/// operands of these newly created affine apply ops are guaranteed to be loop
/// iterators or terminal symbols of a function.
///
/// Before
///
///   affine.for %i = 0 to #map(%N)
///     %idx = affine.apply (d0) -> (d0 mod 2) (%i)
///     "send"(%idx, %A, ...)
///     "compute"(%idx)
///
/// After
///
///   affine.for %i = 0 to #map(%N)
///     %idx = affine.apply (d0) -> (d0 mod 2) (%i)
///     "send"(%idx, %A, ...)
///     %idx_ = affine.apply (d0) -> (d0 mod 2) (%i)
///     "compute"(%idx_)
///
/// This allows applying different transformations on send and compute (e.g.
/// different shifts/delays).
///
/// Returns with `sliceOps` empty either if none of opInst's operands were the
/// result of an affine.apply (i.e., there was no affine computation slice to
/// create), or if all the affine.apply ops supplying operands to opInst had no
/// uses besides opInst; otherwise, the affine.apply operations created are
/// returned in the output argument `sliceOps`.
void mlir::affine::createAffineComputationSlice(
    Operation *opInst, SmallVectorImpl<AffineApplyOp> *sliceOps) {
  // Collect all operands that are results of affine apply ops.
  SmallVector<Value, 4> subOperands;
  subOperands.reserve(opInst->getNumOperands());
  for (auto operand : opInst->getOperands())
    if (isa_and_nonnull<AffineApplyOp>(operand.getDefiningOp()))
      subOperands.push_back(operand);

  // Gather sequence of AffineApplyOps reachable from 'subOperands'.
  SmallVector<Operation *, 4> affineApplyOps;
  getReachableAffineApplyOps(subOperands, affineApplyOps);
  // Skip transforming if there are no affine maps to compose.
  if (affineApplyOps.empty())
    return;

  // Check if all uses of the affine apply ops lie only in this op, in
  // which case there would be nothing to do.
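  // ('localized' stays true only if every result of every reachable
  // affine.apply op is used exclusively by `opInst`.)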
  bool localized = true;
  for (auto *op : affineApplyOps) {
    for (auto result : op->getResults()) {
      for (auto *user : result.getUsers()) {
        if (user != opInst) {
          localized = false;
          break;
        }
      }
    }
  }
  if (localized)
    return;

  OpBuilder builder(opInst);
  SmallVector<Value, 4> composedOpOperands(subOperands);
  auto composedMap = builder.getMultiDimIdentityMap(composedOpOperands.size());
  fullyComposeAffineMapAndOperands(&composedMap, &composedOpOperands);

  // Create an affine.apply for each of the map results.
  sliceOps->reserve(composedMap.getNumResults());
  for (auto resultExpr : composedMap.getResults()) {
    auto singleResMap = AffineMap::get(composedMap.getNumDims(),
                                       composedMap.getNumSymbols(), resultExpr);
    sliceOps->push_back(builder.create<AffineApplyOp>(
        opInst->getLoc(), singleResMap, composedOpOperands));
  }

  // Construct the new operands, which include the results of the composed
  // affine apply ops above in place of the existing ones (subOperands). That
  // is, they differ from opInst's operands only at the positions occupied by
  // 'subOperands', each of which is replaced by the corresponding result from
  // 'sliceOps'.
  SmallVector<Value, 4> newOperands(opInst->getOperands());
  for (Value &operand : newOperands) {
    // Replace the subOperands from among the new operands.
    unsigned j, f;
    for (j = 0, f = subOperands.size(); j < f; j++) {
      if (operand == subOperands[j])
        break;
    }
    if (j < subOperands.size())
      operand = (*sliceOps)[j];
  }
  for (unsigned idx = 0, e = newOperands.size(); idx < e; idx++)
    opInst->setOperand(idx, newOperands[idx]);
}

/// Enum to set patterns of affine expr in tiled-layout map.
/// TileFloorDiv: <dim expr> div <tile size>
/// TileMod: <dim expr> mod <tile size>
/// TileNone: None of the above
/// Example:
/// #tiled_2d_128x256 = affine_map<(d0, d1)
///                -> (d0 div 128, d1 div 256, d0 mod 128, d1 mod 256)>
/// "d0 div 128" and "d1 div 256" ==> TileFloorDiv
/// "d0 mod 128" and "d1 mod 256" ==> TileMod
enum TileExprPattern { TileFloorDiv, TileMod, TileNone };

/// Check if `map` is a tiled layout. In a tiled layout, specific k dimensions
/// that are floordiv'ed by their respective tile sizes also appear in a mod
/// with the same tile sizes, and no other expression involves those k
/// dimensions. This function stores a vector of tuples (`tileSizePos`) holding
/// the AffineExpr for each tile size and the positions of the corresponding
/// `floordiv` and `mod`. If `map` is not a tiled layout, an empty vector is
/// returned.
static LogicalResult getTileSizePos(
    AffineMap map,
    SmallVectorImpl<std::tuple<AffineExpr, unsigned, unsigned>> &tileSizePos) {
  // Create `floordivExprs`, a vector of tuples holding the LHS and RHS of each
  // `floordiv` and its position in `map`'s output.
  // Example: #tiled_2d_128x256 = affine_map<(d0, d1)
  //                -> (d0 div 128, d1 div 256, d0 mod 128, d1 mod 256)>
  // In this example, `floordivExprs` includes {d0, 128, 0} and {d1, 256, 1}.
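  // Each tuple holds (floordiv LHS, floordiv RHS i.e. the tile size, position
  // of the result in the map's output).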
  SmallVector<std::tuple<AffineExpr, AffineExpr, unsigned>, 4> floordivExprs;
  unsigned pos = 0;
  for (AffineExpr expr : map.getResults()) {
    if (expr.getKind() == AffineExprKind::FloorDiv) {
      AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
      if (isa<AffineConstantExpr>(binaryExpr.getRHS()))
        floordivExprs.emplace_back(
            std::make_tuple(binaryExpr.getLHS(), binaryExpr.getRHS(), pos));
    }
    pos++;
  }
  // Not a tiled layout if `floordivExprs` is empty.
  if (floordivExprs.empty()) {
    tileSizePos = SmallVector<std::tuple<AffineExpr, unsigned, unsigned>>{};
    return success();
  }

  // Check if the LHS of each `floordiv` is used in the LHS of a `mod`. If not,
  // `map` is not a tiled layout.
  for (std::tuple<AffineExpr, AffineExpr, unsigned> fexpr : floordivExprs) {
    AffineExpr floordivExprLHS = std::get<0>(fexpr);
    AffineExpr floordivExprRHS = std::get<1>(fexpr);
    unsigned floordivPos = std::get<2>(fexpr);

    // Walk each affine expr of `map`'s output except `fexpr`, and check if the
    // LHS and RHS of `fexpr` are used in the LHS and RHS of a `mod`. If the
    // LHS of `fexpr` is used in any other expr, the map is not a tiled layout.
    // Examples of non-tiled layouts:
    //   affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 floordiv 256)>
    //   affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 mod 128)>
    //   affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 mod 256,
    //                               d2 mod 256)>
    bool found = false;
    pos = 0;
    for (AffineExpr expr : map.getResults()) {
      bool notTiled = false;
      if (pos != floordivPos) {
        expr.walk([&](AffineExpr e) {
          if (e == floordivExprLHS) {
            if (expr.getKind() == AffineExprKind::Mod) {
              AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
              // Check if the LHS and RHS of the `mod` are the same as those of
              // the `floordiv`.
              if (floordivExprLHS == binaryExpr.getLHS() &&
                  floordivExprRHS == binaryExpr.getRHS()) {
                // Save the tile size (RHS of `mod`) and the positions of the
                // `floordiv` and `mod` if the same expr has not been found yet.
                if (!found) {
                  tileSizePos.emplace_back(
                      std::make_tuple(binaryExpr.getRHS(), floordivPos, pos));
                  found = true;
                } else {
                  // Not a tiled layout: multiple `mod`s with the same LHS.
                  // e.g. affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256,
                  //                                  d2 mod 256, d2 mod 256)>
                  notTiled = true;
                }
              } else {
                // Not a tiled layout: the RHS of the `mod` differs from that
                // of the `floordiv`.
                // e.g. affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256,
                //                                  d2 mod 128)>
                notTiled = true;
              }
            } else {
              // Not a tiled layout: the LHS is the same, but the enclosing
              // expr is not a `mod`.
              // e.g. affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256,
              //                                  d2 floordiv 256)>
              notTiled = true;
            }
          }
        });
      }
      if (notTiled) {
        tileSizePos = SmallVector<std::tuple<AffineExpr, unsigned, unsigned>>{};
        return success();
      }
      pos++;
    }
  }
  return success();
}

/// Check if the `dim` dimension of a memrefType with `layoutMap` becomes
/// dynamic after normalization. Dimensions whose map output involves a dynamic
/// input dimension become dynamic dimensions. Returns true if `dim` is a
/// dynamic dimension.
///
/// Example:
/// #map0 = affine_map<(d0, d1) -> (d0, d1 floordiv 32, d1 mod 32)>
///
/// If d1 is a dynamic dimension, the 2nd and 3rd dimensions of the map output
/// are dynamic:
/// memref<4x?xf32, #map0> ==> memref<4x?x?xf32>
static bool
isNormalizedMemRefDynamicDim(unsigned dim, AffineMap layoutMap,
                             SmallVectorImpl<unsigned> &inMemrefTypeDynDims) {
  AffineExpr expr = layoutMap.getResults()[dim];
  // Check if the affine expr of this dimension involves a dynamic dimension of
  // the input memrefType.
  MLIRContext *context = layoutMap.getContext();
  return expr
      .walk([&](AffineExpr e) {
        if (isa<AffineDimExpr>(e) &&
            llvm::any_of(inMemrefTypeDynDims, [&](unsigned dim) {
              return e == getAffineDimExpr(dim, context);
            }))
          return WalkResult::interrupt();
        return WalkResult::advance();
      })
      .wasInterrupted();
}

/// Create an affine expr to calculate the dimension size for a tiled-layout
/// map.
static AffineExpr createDimSizeExprForTiledLayout(AffineExpr oldMapOutput,
                                                  TileExprPattern pat) {
  // Create the map output for the patterns.
  // "floordiv <tile size>" ==> "ceildiv <tile size>"
  // "mod <tile size>"      ==> "<tile size>"
  AffineExpr newMapOutput;
  AffineBinaryOpExpr binaryExpr = nullptr;
  switch (pat) {
  case TileExprPattern::TileMod:
    binaryExpr = cast<AffineBinaryOpExpr>(oldMapOutput);
    newMapOutput = binaryExpr.getRHS();
    break;
  case TileExprPattern::TileFloorDiv:
    binaryExpr = cast<AffineBinaryOpExpr>(oldMapOutput);
    newMapOutput = getAffineBinaryOpExpr(
        AffineExprKind::CeilDiv, binaryExpr.getLHS(), binaryExpr.getRHS());
    break;
  default:
    newMapOutput = oldMapOutput;
  }
  return newMapOutput;
}

/// Create new maps to calculate each dimension size of `newMemRefType`, and
/// create `newDynamicSizes` from them by using AffineApplyOp.
///
/// Steps for normalizing dynamic memrefs for a tiled layout map
/// Example:
///    #map0 = affine_map<(d0, d1) -> (d0, d1 floordiv 32, d1 mod 32)>
///    %0 = dim %arg0, %c1 : memref<4x?xf32>
///    %1 = alloc(%0) : memref<4x?xf32, #map0>
///
/// (Before this function)
/// 1. Check if `map` (#map0) is a tiled layout using `getTileSizePos()`. Only
///    a single layout map is supported.
///
/// 2. Create a normalized memrefType using `isNormalizedMemRefDynamicDim()`.
///    It is memref<4x?x?xf32> in the above example.
///
/// (In this function)
/// 3. Create new maps to calculate each dimension of the normalized memrefType
///    using `createDimSizeExprForTiledLayout()`. In the tiled layout, the
///    dimension size can be calculated by replacing "floordiv <tile size>"
///    with "ceildiv <tile size>" and "mod <tile size>" with "<tile size>".
///    - New maps in the above example
///      #map0 = affine_map<(d0, d1) -> (d0)>
///      #map1 = affine_map<(d0, d1) -> (d1 ceildiv 32)>
///      #map2 = affine_map<(d0, d1) -> (32)>
///
/// 4. Create AffineApplyOp to apply the new maps. The output of AffineApplyOp
///    is used in the dynamicSizes of the new AllocOp.
///      %0 = dim %arg0, %c1 : memref<4x?xf32>
///      %c4 = arith.constant 4 : index
///      %1 = affine.apply #map1(%c4, %0)
///      %2 = affine.apply #map2(%c4, %0)
template <typename AllocLikeOp>
static void createNewDynamicSizes(MemRefType oldMemRefType,
                                  MemRefType newMemRefType, AffineMap map,
                                  AllocLikeOp *allocOp, OpBuilder b,
                                  SmallVectorImpl<Value> &newDynamicSizes) {
  // Create new inputs for the AffineApplyOps.
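  // The inputs mirror the original memref's sizes: dynamic sizes are taken
  // from the alloc op's operands, while static sizes are materialized as index
  // constants.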
  SmallVector<Value, 4> inAffineApply;
  ArrayRef<int64_t> oldMemRefShape = oldMemRefType.getShape();
  unsigned dynIdx = 0;
  for (unsigned d = 0; d < oldMemRefType.getRank(); ++d) {
    if (oldMemRefShape[d] < 0) {
      // Use dynamicSizes of allocOp for dynamic dimension.
      inAffineApply.emplace_back(allocOp->getDynamicSizes()[dynIdx]);
      dynIdx++;
    } else {
      // Create ConstantOp for static dimension.
      auto constantAttr = b.getIntegerAttr(b.getIndexType(), oldMemRefShape[d]);
      inAffineApply.emplace_back(
          b.create<arith::ConstantOp>(allocOp->getLoc(), constantAttr));
    }
  }

  // Create a new map to calculate each dimension size of the new memref for
  // each original map output, but only for the dynamic dimensions of
  // `newMemRefType`.
  unsigned newDimIdx = 0;
  ArrayRef<int64_t> newMemRefShape = newMemRefType.getShape();
  SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
  (void)getTileSizePos(map, tileSizePos);
  for (AffineExpr expr : map.getResults()) {
    if (newMemRefShape[newDimIdx] < 0) {
      // Create new maps to calculate each dimension size of new memref.
      enum TileExprPattern pat = TileExprPattern::TileNone;
      for (auto pos : tileSizePos) {
        if (newDimIdx == std::get<1>(pos))
          pat = TileExprPattern::TileFloorDiv;
        else if (newDimIdx == std::get<2>(pos))
          pat = TileExprPattern::TileMod;
      }
      AffineExpr newMapOutput = createDimSizeExprForTiledLayout(expr, pat);
      AffineMap newMap =
          AffineMap::get(map.getNumInputs(), map.getNumSymbols(), newMapOutput);
      Value affineApp =
          b.create<AffineApplyOp>(allocOp->getLoc(), newMap, inAffineApply);
      newDynamicSizes.emplace_back(affineApp);
    }
    newDimIdx++;
  }
}

// TODO: Currently works for static memrefs with a single layout map.
template <typename AllocLikeOp>
LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp *allocOp) {
  MemRefType memrefType = allocOp->getType();
  OpBuilder b(*allocOp);

  // Fetch a new memref type after normalizing the old memref to have an
  // identity map layout.
  MemRefType newMemRefType = normalizeMemRefType(memrefType);
  if (newMemRefType == memrefType)
    // Either memrefType already had an identity map or the map couldn't be
    // transformed to an identity map.
    return failure();

  Value oldMemRef = allocOp->getResult();

  SmallVector<Value, 4> symbolOperands(allocOp->getSymbolOperands());
  AffineMap layoutMap = memrefType.getLayout().getAffineMap();
  AllocLikeOp newAlloc;
  // Check if `layoutMap` is a tiled layout. Only a single layout map is
  // supported for normalizing dynamic memrefs.
  SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
  (void)getTileSizePos(layoutMap, tileSizePos);
  if (newMemRefType.getNumDynamicDims() > 0 && !tileSizePos.empty()) {
    MemRefType oldMemRefType = cast<MemRefType>(oldMemRef.getType());
    SmallVector<Value, 4> newDynamicSizes;
    createNewDynamicSizes(oldMemRefType, newMemRefType, layoutMap, allocOp, b,
                          newDynamicSizes);
    // Add the new dynamic sizes in new AllocOp.
    newAlloc =
        b.create<AllocLikeOp>(allocOp->getLoc(), newMemRefType, newDynamicSizes,
                              allocOp->getAlignmentAttr());
  } else {
    newAlloc = b.create<AllocLikeOp>(allocOp->getLoc(), newMemRefType,
                                     allocOp->getAlignmentAttr());
  }
  // Replace all uses of the old memref.
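  // The old layout map itself serves as the index remapping: indices into the
  // old (non-identity layout) memref are rewritten through `layoutMap` into
  // the normalized index space of the new memref. For example (illustrative),
  // with a transposed layout, an access such as
  //   affine.load %old[%i, %j] : memref<4x8xf32, affine_map<(d0, d1) -> (d1, d0)>>
  // becomes, after replacement with the identity-layout memref,
  //   affine.load %new[%j, %i] : memref<8x4xf32>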
  if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
                                      /*extraIndices=*/{},
                                      /*indexRemap=*/layoutMap,
                                      /*extraOperands=*/{},
                                      /*symbolOperands=*/symbolOperands,
                                      /*domOpFilter=*/nullptr,
                                      /*postDomOpFilter=*/nullptr,
                                      /*allowNonDereferencingOps=*/true))) {
    // If it failed (due to escapes for example), bail out.
    newAlloc.erase();
    return failure();
  }
  // Replace any uses of the original alloc op and erase it. All remaining uses
  // have to be dealloc's; replaceAllMemRefUsesWith above would've failed
  // otherwise.
  assert(llvm::all_of(oldMemRef.getUsers(), [&](Operation *op) {
    return hasSingleEffect<MemoryEffects::Free>(op, oldMemRef);
  }));
  oldMemRef.replaceAllUsesWith(newAlloc);
  allocOp->erase();
  return success();
}

template LogicalResult
mlir::affine::normalizeMemRef<memref::AllocaOp>(memref::AllocaOp *op);
template LogicalResult
mlir::affine::normalizeMemRef<memref::AllocOp>(memref::AllocOp *op);

MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) {
  unsigned rank = memrefType.getRank();
  if (rank == 0)
    return memrefType;

  if (memrefType.getLayout().isIdentity()) {
    // Either no map is associated with this memref or this memref has
    // a trivial (identity) map.
    return memrefType;
  }
  AffineMap layoutMap = memrefType.getLayout().getAffineMap();
  unsigned numSymbolicOperands = layoutMap.getNumSymbols();

  // We don't do any checks for one-to-one'ness; we assume that it is
  // one-to-one.

  // Normalize only static memrefs and dynamic memrefs with a tiled-layout map
  // for now.
  // TODO: Normalize the other types of dynamic memrefs.
  SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
  (void)getTileSizePos(layoutMap, tileSizePos);
  if (memrefType.getNumDynamicDims() > 0 && tileSizePos.empty())
    return memrefType;

  // We have a single map that is not an identity map. Create a new memref
  // with the right shape and an identity layout map.
  ArrayRef<int64_t> shape = memrefType.getShape();
  // FlatAffineValueConstraint may later on use symbolicOperands.
  FlatAffineValueConstraints fac(rank, numSymbolicOperands);
  SmallVector<unsigned, 4> memrefTypeDynDims;
  for (unsigned d = 0; d < rank; ++d) {
    // Use the constraint system only for static dimensions.
    if (shape[d] > 0) {
      fac.addBound(BoundType::LB, d, 0);
      fac.addBound(BoundType::UB, d, shape[d] - 1);
    } else {
      memrefTypeDynDims.emplace_back(d);
    }
  }
  // We compose this map with the original index (logical) space to derive
  // the upper bounds for the new index space.
  unsigned newRank = layoutMap.getNumResults();
  if (failed(fac.composeMatchingMap(layoutMap)))
    return memrefType;
  // TODO: Handle semi-affine maps.
  // Project out the old data dimensions.
  fac.projectOut(newRank, fac.getNumVars() - newRank - fac.getNumLocalVars());
  SmallVector<int64_t, 4> newShape(newRank);
  MLIRContext *context = memrefType.getContext();
  for (unsigned d = 0; d < newRank; ++d) {
    // Check if this dimension is dynamic.
    if (isNormalizedMemRefDynamicDim(d, layoutMap, memrefTypeDynDims)) {
      newShape[d] = ShapedType::kDynamic;
      continue;
    }
    // The lower bound for the shape is always zero.
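    // With an inclusive constant upper bound ub on this dimension, the
    // normalized static extent is ub + 1.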
    std::optional<int64_t> ubConst = fac.getConstantBound64(BoundType::UB, d);
    // For a static memref and an affine map with no symbols, this is
    // always bounded. However, when we have symbols, we may not be able to
    // obtain a constant upper bound. Also, mapping to a negative space is
    // invalid for normalization.
    if (!ubConst.has_value() || *ubConst < 0) {
      LLVM_DEBUG(llvm::dbgs()
                 << "can't normalize map due to unknown/invalid upper bound");
      return memrefType;
    }
    // The constant upper bound is inclusive, so the static extent of this
    // dimension is ubConst + 1.
    newShape[d] = *ubConst + 1;
  }

  // Create the new memref type after trivializing the old layout map.
  auto newMemRefType =
      MemRefType::Builder(memrefType)
          .setShape(newShape)
          .setLayout(AffineMapAttr::get(
              AffineMap::getMultiDimIdentityMap(newRank, context)));
  return newMemRefType;
}

DivModValue mlir::affine::getDivMod(OpBuilder &b, Location loc, Value lhs,
                                    Value rhs) {
  DivModValue result;
  AffineExpr d0, d1;
  bindDims(b.getContext(), d0, d1);
  result.quotient =
      affine::makeComposedAffineApply(b, loc, d0.floorDiv(d1), {lhs, rhs});
  result.remainder =
      affine::makeComposedAffineApply(b, loc, d0 % d1, {lhs, rhs});
  return result;
}

/// Create an affine map that computes `lhs` * `rhs`, composing in any other
/// affine maps.
static FailureOr<OpFoldResult> composedAffineMultiply(OpBuilder &b,
                                                      Location loc,
                                                      OpFoldResult lhs,
                                                      OpFoldResult rhs) {
  AffineExpr s0, s1;
  bindSymbols(b.getContext(), s0, s1);
  return makeComposedFoldedAffineApply(b, loc, s0 * s1, {lhs, rhs});
}

FailureOr<SmallVector<Value>>
mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
                               ArrayRef<Value> basis, bool hasOuterBound) {
  if (hasOuterBound)
    basis = basis.drop_front();

  // Note: the divisors are backwards due to the scan.
  SmallVector<Value> divisors;
  OpFoldResult basisProd = b.getIndexAttr(1);
  for (OpFoldResult basisElem : llvm::reverse(basis)) {
    FailureOr<OpFoldResult> nextProd =
        composedAffineMultiply(b, loc, basisElem, basisProd);
    if (failed(nextProd))
      return failure();
    basisProd = *nextProd;
    divisors.push_back(getValueOrCreateConstantIndexOp(b, loc, basisProd));
  }

  SmallVector<Value> results;
  results.reserve(divisors.size() + 1);
  Value residual = linearIndex;
  for (Value divisor : llvm::reverse(divisors)) {
    DivModValue divMod = getDivMod(b, loc, residual, divisor);
    results.push_back(divMod.quotient);
    residual = divMod.remainder;
  }
  results.push_back(residual);
  return results;
}

FailureOr<SmallVector<Value>>
mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
                               ArrayRef<OpFoldResult> basis,
                               bool hasOuterBound) {
  if (hasOuterBound)
    basis = basis.drop_front();

  // Note: the divisors are backwards due to the scan.
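  // That is, for a basis (b1, ..., bN) (after dropping any outer bound), the
  // divisors are accumulated innermost-first as bN, bN-1*bN, ..., b1*...*bN;
  // the loop below then consumes them in reverse, peeling off one dimension of
  // the multi-index at a time.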
  SmallVector<Value> divisors;
  OpFoldResult basisProd = b.getIndexAttr(1);
  for (OpFoldResult basisElem : llvm::reverse(basis)) {
    FailureOr<OpFoldResult> nextProd =
        composedAffineMultiply(b, loc, basisElem, basisProd);
    if (failed(nextProd))
      return failure();
    basisProd = *nextProd;
    divisors.push_back(getValueOrCreateConstantIndexOp(b, loc, basisProd));
  }

  SmallVector<Value> results;
  results.reserve(divisors.size() + 1);
  Value residual = linearIndex;
  for (Value divisor : llvm::reverse(divisors)) {
    DivModValue divMod = getDivMod(b, loc, residual, divisor);
    results.push_back(divMod.quotient);
    residual = divMod.remainder;
  }
  results.push_back(residual);
  return results;
}

OpFoldResult mlir::affine::linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
                                          ArrayRef<OpFoldResult> basis,
                                          ImplicitLocOpBuilder &builder) {
  return linearizeIndex(builder, builder.getLoc(), multiIndex, basis);
}

OpFoldResult mlir::affine::linearizeIndex(OpBuilder &builder, Location loc,
                                          ArrayRef<OpFoldResult> multiIndex,
                                          ArrayRef<OpFoldResult> basis) {
  assert(multiIndex.size() == basis.size() ||
         multiIndex.size() == basis.size() + 1);
  SmallVector<AffineExpr> basisAffine;

  // Add a fake initial size in order to make the later index linearization
  // computations line up if an outer bound is not provided.
  if (multiIndex.size() == basis.size() + 1)
    basisAffine.push_back(getAffineConstantExpr(1, builder.getContext()));

  for (size_t i = 0; i < basis.size(); ++i) {
    basisAffine.push_back(getAffineSymbolExpr(i, builder.getContext()));
  }

  SmallVector<AffineExpr> stridesAffine = computeStrides(basisAffine);
  SmallVector<OpFoldResult> strides;
  strides.reserve(stridesAffine.size());
  llvm::transform(stridesAffine, std::back_inserter(strides),
                  [&builder, &basis, loc](AffineExpr strideExpr) {
                    return affine::makeComposedFoldedAffineApply(
                        builder, loc, strideExpr, basis);
                  });

  auto &&[linearIndexExpr, multiIndexAndStrides] = computeLinearIndex(
      OpFoldResult(builder.getIndexAttr(0)), strides, multiIndex);
  return affine::makeComposedFoldedAffineApply(builder, loc, linearIndexExpr,
                                               multiIndexAndStrides);
}
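// Illustrative note (a sketch, not part of the build): for a basis (%b0, %b1)
// and a multi-index (%i, %j), linearizeIndex above emits composed affine.apply
// ops computing the equivalent of %i * %b1 + %j (strides are the trailing
// products of the basis, with the innermost stride being 1), while
// delinearizeIndex recovers (%i, %j) from that linear index and the same basis
// via successive floordiv/mod against the trailing-product divisors computed
// above.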