1 //===- AffineOps.cpp - MLIR Affine Operations -----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Affine/IR/AffineOps.h" 10 #include "mlir/Dialect/Affine/IR/AffineValueMap.h" 11 #include "mlir/Dialect/MemRef/IR/MemRef.h" 12 #include "mlir/Dialect/UB/IR/UBOps.h" 13 #include "mlir/Dialect/Utils/StaticValueUtils.h" 14 #include "mlir/IR/AffineExprVisitor.h" 15 #include "mlir/IR/IRMapping.h" 16 #include "mlir/IR/IntegerSet.h" 17 #include "mlir/IR/Matchers.h" 18 #include "mlir/IR/OpDefinition.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/Interfaces/ShapedOpInterfaces.h" 21 #include "mlir/Interfaces/ValueBoundsOpInterface.h" 22 #include "mlir/Transforms/InliningUtils.h" 23 #include "llvm/ADT/STLExtras.h" 24 #include "llvm/ADT/ScopeExit.h" 25 #include "llvm/ADT/SmallBitVector.h" 26 #include "llvm/ADT/SmallVectorExtras.h" 27 #include "llvm/ADT/TypeSwitch.h" 28 #include "llvm/Support/Debug.h" 29 #include "llvm/Support/MathExtras.h" 30 #include <numeric> 31 #include <optional> 32 33 using namespace mlir; 34 using namespace mlir::affine; 35 36 using llvm::divideCeilSigned; 37 using llvm::divideFloorSigned; 38 using llvm::mod; 39 40 #define DEBUG_TYPE "affine-ops" 41 42 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc" 43 44 /// A utility function to check if a value is defined at the top level of 45 /// `region` or is an argument of `region`. A value of index type defined at the 46 /// top level of a `AffineScope` region is always a valid symbol for all 47 /// uses in that region. 48 bool mlir::affine::isTopLevelValue(Value value, Region *region) { 49 if (auto arg = llvm::dyn_cast<BlockArgument>(value)) 50 return arg.getParentRegion() == region; 51 return value.getDefiningOp()->getParentRegion() == region; 52 } 53 54 /// Checks if `value` known to be a legal affine dimension or symbol in `src` 55 /// region remains legal if the operation that uses it is inlined into `dest` 56 /// with the given value mapping. `legalityCheck` is either `isValidDim` or 57 /// `isValidSymbol`, depending on the value being required to remain a valid 58 /// dimension or symbol. 59 static bool 60 remainsLegalAfterInline(Value value, Region *src, Region *dest, 61 const IRMapping &mapping, 62 function_ref<bool(Value, Region *)> legalityCheck) { 63 // If the value is a valid dimension for any other reason than being 64 // a top-level value, it will remain valid: constants get inlined 65 // with the function, transitive affine applies also get inlined and 66 // will be checked themselves, etc. 67 if (!isTopLevelValue(value, src)) 68 return true; 69 70 // If it's a top-level value because it's a block operand, i.e. a 71 // function argument, check whether the value replacing it after 72 // inlining is a valid dimension in the new region. 73 if (llvm::isa<BlockArgument>(value)) 74 return legalityCheck(mapping.lookup(value), dest); 75 76 // If it's a top-level value because it's defined in the region, 77 // it can only be inlined if the defining op is a constant or a 78 // `dim`, which can appear anywhere and be valid, since the defining 79 // op won't be top-level anymore after inlining. 80 Attribute operandCst; 81 bool isDimLikeOp = isa<ShapedDimOpInterface>(value.getDefiningOp()); 82 return matchPattern(value.getDefiningOp(), m_Constant(&operandCst)) || 83 isDimLikeOp; 84 } 85 86 /// Checks if all values known to be legal affine dimensions or symbols in `src` 87 /// remain so if their respective users are inlined into `dest`. 88 static bool 89 remainsLegalAfterInline(ValueRange values, Region *src, Region *dest, 90 const IRMapping &mapping, 91 function_ref<bool(Value, Region *)> legalityCheck) { 92 return llvm::all_of(values, [&](Value v) { 93 return remainsLegalAfterInline(v, src, dest, mapping, legalityCheck); 94 }); 95 } 96 97 /// Checks if an affine read or write operation remains legal after inlining 98 /// from `src` to `dest`. 99 template <typename OpTy> 100 static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest, 101 const IRMapping &mapping) { 102 static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface, 103 AffineWriteOpInterface>::value, 104 "only ops with affine read/write interface are supported"); 105 106 AffineMap map = op.getAffineMap(); 107 ValueRange dimOperands = op.getMapOperands().take_front(map.getNumDims()); 108 ValueRange symbolOperands = 109 op.getMapOperands().take_back(map.getNumSymbols()); 110 if (!remainsLegalAfterInline( 111 dimOperands, src, dest, mapping, 112 static_cast<bool (*)(Value, Region *)>(isValidDim))) 113 return false; 114 if (!remainsLegalAfterInline( 115 symbolOperands, src, dest, mapping, 116 static_cast<bool (*)(Value, Region *)>(isValidSymbol))) 117 return false; 118 return true; 119 } 120 121 /// Checks if an affine apply operation remains legal after inlining from `src` 122 /// to `dest`. 123 // Use "unused attribute" marker to silence clang-tidy warning stemming from 124 // the inability to see through "llvm::TypeSwitch". 125 template <> 126 bool LLVM_ATTRIBUTE_UNUSED remainsLegalAfterInline(AffineApplyOp op, 127 Region *src, Region *dest, 128 const IRMapping &mapping) { 129 // If it's a valid dimension, we need to check that it remains so. 130 if (isValidDim(op.getResult(), src)) 131 return remainsLegalAfterInline( 132 op.getMapOperands(), src, dest, mapping, 133 static_cast<bool (*)(Value, Region *)>(isValidDim)); 134 135 // Otherwise it must be a valid symbol, check that it remains so. 136 return remainsLegalAfterInline( 137 op.getMapOperands(), src, dest, mapping, 138 static_cast<bool (*)(Value, Region *)>(isValidSymbol)); 139 } 140 141 //===----------------------------------------------------------------------===// 142 // AffineDialect Interfaces 143 //===----------------------------------------------------------------------===// 144 145 namespace { 146 /// This class defines the interface for handling inlining with affine 147 /// operations. 148 struct AffineInlinerInterface : public DialectInlinerInterface { 149 using DialectInlinerInterface::DialectInlinerInterface; 150 151 //===--------------------------------------------------------------------===// 152 // Analysis Hooks 153 //===--------------------------------------------------------------------===// 154 155 /// Returns true if the given region 'src' can be inlined into the region 156 /// 'dest' that is attached to an operation registered to the current dialect. 157 /// 'wouldBeCloned' is set if the region is cloned into its new location 158 /// rather than moved, indicating there may be other users. 159 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, 160 IRMapping &valueMapping) const final { 161 // We can inline into affine loops and conditionals if this doesn't break 162 // affine value categorization rules. 163 Operation *destOp = dest->getParentOp(); 164 if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp)) 165 return false; 166 167 // Multi-block regions cannot be inlined into affine constructs, all of 168 // which require single-block regions. 169 if (!llvm::hasSingleElement(*src)) 170 return false; 171 172 // Side-effecting operations that the affine dialect cannot understand 173 // should not be inlined. 174 Block &srcBlock = src->front(); 175 for (Operation &op : srcBlock) { 176 // Ops with no side effects are fine, 177 if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) { 178 if (iface.hasNoEffect()) 179 continue; 180 } 181 182 // Assuming the inlined region is valid, we only need to check if the 183 // inlining would change it. 184 bool remainsValid = 185 llvm::TypeSwitch<Operation *, bool>(&op) 186 .Case<AffineApplyOp, AffineReadOpInterface, 187 AffineWriteOpInterface>([&](auto op) { 188 return remainsLegalAfterInline(op, src, dest, valueMapping); 189 }) 190 .Default([](Operation *) { 191 // Conservatively disallow inlining ops we cannot reason about. 192 return false; 193 }); 194 195 if (!remainsValid) 196 return false; 197 } 198 199 return true; 200 } 201 202 /// Returns true if the given operation 'op', that is registered to this 203 /// dialect, can be inlined into the given region, false otherwise. 204 bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned, 205 IRMapping &valueMapping) const final { 206 // Always allow inlining affine operations into a region that is marked as 207 // affine scope, or into affine loops and conditionals. There are some edge 208 // cases when inlining *into* affine structures, but that is handled in the 209 // other 'isLegalToInline' hook above. 210 Operation *parentOp = region->getParentOp(); 211 return parentOp->hasTrait<OpTrait::AffineScope>() || 212 isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp); 213 } 214 215 /// Affine regions should be analyzed recursively. 216 bool shouldAnalyzeRecursively(Operation *op) const final { return true; } 217 }; 218 } // namespace 219 220 //===----------------------------------------------------------------------===// 221 // AffineDialect 222 //===----------------------------------------------------------------------===// 223 224 void AffineDialect::initialize() { 225 addOperations<AffineDmaStartOp, AffineDmaWaitOp, 226 #define GET_OP_LIST 227 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc" 228 >(); 229 addInterfaces<AffineInlinerInterface>(); 230 declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp, 231 AffineMinOp>(); 232 } 233 234 /// Materialize a single constant operation from a given attribute value with 235 /// the desired resultant type. 236 Operation *AffineDialect::materializeConstant(OpBuilder &builder, 237 Attribute value, Type type, 238 Location loc) { 239 if (auto poison = dyn_cast<ub::PoisonAttr>(value)) 240 return builder.create<ub::PoisonOp>(loc, type, poison); 241 return arith::ConstantOp::materialize(builder, value, type, loc); 242 } 243 244 /// A utility function to check if a value is defined at the top level of an 245 /// op with trait `AffineScope`. If the value is defined in an unlinked region, 246 /// conservatively assume it is not top-level. A value of index type defined at 247 /// the top level is always a valid symbol. 248 bool mlir::affine::isTopLevelValue(Value value) { 249 if (auto arg = llvm::dyn_cast<BlockArgument>(value)) { 250 // The block owning the argument may be unlinked, e.g. when the surrounding 251 // region has not yet been attached to an Op, at which point the parent Op 252 // is null. 253 Operation *parentOp = arg.getOwner()->getParentOp(); 254 return parentOp && parentOp->hasTrait<OpTrait::AffineScope>(); 255 } 256 // The defining Op may live in an unlinked block so its parent Op may be null. 257 Operation *parentOp = value.getDefiningOp()->getParentOp(); 258 return parentOp && parentOp->hasTrait<OpTrait::AffineScope>(); 259 } 260 261 /// Returns the closest region enclosing `op` that is held by an operation with 262 /// trait `AffineScope`; `nullptr` if there is no such region. 263 Region *mlir::affine::getAffineScope(Operation *op) { 264 auto *curOp = op; 265 while (auto *parentOp = curOp->getParentOp()) { 266 if (parentOp->hasTrait<OpTrait::AffineScope>()) 267 return curOp->getParentRegion(); 268 curOp = parentOp; 269 } 270 return nullptr; 271 } 272 273 // A Value can be used as a dimension id iff it meets one of the following 274 // conditions: 275 // *) It is valid as a symbol. 276 // *) It is an induction variable. 277 // *) It is the result of affine apply operation with dimension id arguments. 278 bool mlir::affine::isValidDim(Value value) { 279 // The value must be an index type. 280 if (!value.getType().isIndex()) 281 return false; 282 283 if (auto *defOp = value.getDefiningOp()) 284 return isValidDim(value, getAffineScope(defOp)); 285 286 // This value has to be a block argument for an op that has the 287 // `AffineScope` trait or for an affine.for or affine.parallel. 288 auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp(); 289 return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() || 290 isa<AffineForOp, AffineParallelOp>(parentOp)); 291 } 292 293 // Value can be used as a dimension id iff it meets one of the following 294 // conditions: 295 // *) It is valid as a symbol. 296 // *) It is an induction variable. 297 // *) It is the result of an affine apply operation with dimension id operands. 298 bool mlir::affine::isValidDim(Value value, Region *region) { 299 // The value must be an index type. 300 if (!value.getType().isIndex()) 301 return false; 302 303 // All valid symbols are okay. 304 if (isValidSymbol(value, region)) 305 return true; 306 307 auto *op = value.getDefiningOp(); 308 if (!op) { 309 // This value has to be a block argument for an affine.for or an 310 // affine.parallel. 311 auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp(); 312 return isa<AffineForOp, AffineParallelOp>(parentOp); 313 } 314 315 // Affine apply operation is ok if all of its operands are ok. 316 if (auto applyOp = dyn_cast<AffineApplyOp>(op)) 317 return applyOp.isValidDim(region); 318 // The dim op is okay if its operand memref/tensor is defined at the top 319 // level. 320 if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op)) 321 return isTopLevelValue(dimOp.getShapedValue()); 322 return false; 323 } 324 325 /// Returns true if the 'index' dimension of the `memref` defined by 326 /// `memrefDefOp` is a statically shaped one or defined using a valid symbol 327 /// for `region`. 328 template <typename AnyMemRefDefOp> 329 static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index, 330 Region *region) { 331 MemRefType memRefType = memrefDefOp.getType(); 332 333 // Dimension index is out of bounds. 334 if (index >= memRefType.getRank()) { 335 return false; 336 } 337 338 // Statically shaped. 339 if (!memRefType.isDynamicDim(index)) 340 return true; 341 // Get the position of the dimension among dynamic dimensions; 342 unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index); 343 return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos), 344 region); 345 } 346 347 /// Returns true if the result of the dim op is a valid symbol for `region`. 348 static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) { 349 // The dim op is okay if its source is defined at the top level. 350 if (isTopLevelValue(dimOp.getShapedValue())) 351 return true; 352 353 // Conservatively handle remaining BlockArguments as non-valid symbols. 354 // E.g. scf.for iterArgs. 355 if (llvm::isa<BlockArgument>(dimOp.getShapedValue())) 356 return false; 357 358 // The dim op is also okay if its operand memref is a view/subview whose 359 // corresponding size is a valid symbol. 360 std::optional<int64_t> index = getConstantIntValue(dimOp.getDimension()); 361 362 // Be conservative if we can't understand the dimension. 363 if (!index.has_value()) 364 return false; 365 366 // Skip over all memref.cast ops (if any). 367 Operation *op = dimOp.getShapedValue().getDefiningOp(); 368 while (auto castOp = dyn_cast<memref::CastOp>(op)) { 369 // Bail on unranked memrefs. 370 if (isa<UnrankedMemRefType>(castOp.getSource().getType())) 371 return false; 372 op = castOp.getSource().getDefiningOp(); 373 if (!op) 374 return false; 375 } 376 377 int64_t i = index.value(); 378 return TypeSwitch<Operation *, bool>(op) 379 .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>( 380 [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); }) 381 .Default([](Operation *) { return false; }); 382 } 383 384 // A value can be used as a symbol (at all its use sites) iff it meets one of 385 // the following conditions: 386 // *) It is a constant. 387 // *) Its defining op or block arg appearance is immediately enclosed by an op 388 // with `AffineScope` trait. 389 // *) It is the result of an affine.apply operation with symbol operands. 390 // *) It is a result of the dim op on a memref whose corresponding size is a 391 // valid symbol. 392 bool mlir::affine::isValidSymbol(Value value) { 393 if (!value) 394 return false; 395 396 // The value must be an index type. 397 if (!value.getType().isIndex()) 398 return false; 399 400 // Check that the value is a top level value. 401 if (isTopLevelValue(value)) 402 return true; 403 404 if (auto *defOp = value.getDefiningOp()) 405 return isValidSymbol(value, getAffineScope(defOp)); 406 407 return false; 408 } 409 410 /// A value can be used as a symbol for `region` iff it meets one of the 411 /// following conditions: 412 /// *) It is a constant. 413 /// *) It is a result of a `Pure` operation whose operands are valid symbolic 414 /// *) identifiers. 415 /// *) It is a result of the dim op on a memref whose corresponding size is 416 /// a valid symbol. 417 /// *) It is defined at the top level of 'region' or is its argument. 418 /// *) It dominates `region`'s parent op. 419 /// If `region` is null, conservatively assume the symbol definition scope does 420 /// not exist and only accept the values that would be symbols regardless of 421 /// the surrounding region structure, i.e. the first three cases above. 422 bool mlir::affine::isValidSymbol(Value value, Region *region) { 423 // The value must be an index type. 424 if (!value.getType().isIndex()) 425 return false; 426 427 // A top-level value is a valid symbol. 428 if (region && ::isTopLevelValue(value, region)) 429 return true; 430 431 auto *defOp = value.getDefiningOp(); 432 if (!defOp) { 433 // A block argument that is not a top-level value is a valid symbol if it 434 // dominates region's parent op. 435 Operation *regionOp = region ? region->getParentOp() : nullptr; 436 if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>()) 437 if (auto *parentOpRegion = region->getParentOp()->getParentRegion()) 438 return isValidSymbol(value, parentOpRegion); 439 return false; 440 } 441 442 // Constant operation is ok. 443 Attribute operandCst; 444 if (matchPattern(defOp, m_Constant(&operandCst))) 445 return true; 446 447 // `Pure` operation that whose operands are valid symbolic identifiers. 448 if (isPure(defOp) && llvm::all_of(defOp->getOperands(), [&](Value operand) { 449 return affine::isValidSymbol(operand, region); 450 })) { 451 return true; 452 } 453 454 // Dim op results could be valid symbols at any level. 455 if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp)) 456 return isDimOpValidSymbol(dimOp, region); 457 458 // Check for values dominating `region`'s parent op. 459 Operation *regionOp = region ? region->getParentOp() : nullptr; 460 if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>()) 461 if (auto *parentRegion = region->getParentOp()->getParentRegion()) 462 return isValidSymbol(value, parentRegion); 463 464 return false; 465 } 466 467 // Returns true if 'value' is a valid index to an affine operation (e.g. 468 // affine.load, affine.store, affine.dma_start, affine.dma_wait) where 469 // `region` provides the polyhedral symbol scope. Returns false otherwise. 470 static bool isValidAffineIndexOperand(Value value, Region *region) { 471 return isValidDim(value, region) || isValidSymbol(value, region); 472 } 473 474 /// Prints dimension and symbol list. 475 static void printDimAndSymbolList(Operation::operand_iterator begin, 476 Operation::operand_iterator end, 477 unsigned numDims, OpAsmPrinter &printer) { 478 OperandRange operands(begin, end); 479 printer << '(' << operands.take_front(numDims) << ')'; 480 if (operands.size() > numDims) 481 printer << '[' << operands.drop_front(numDims) << ']'; 482 } 483 484 /// Parses dimension and symbol list and returns true if parsing failed. 485 ParseResult mlir::affine::parseDimAndSymbolList( 486 OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) { 487 SmallVector<OpAsmParser::UnresolvedOperand, 8> opInfos; 488 if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren)) 489 return failure(); 490 // Store number of dimensions for validation by caller. 491 numDims = opInfos.size(); 492 493 // Parse the optional symbol operands. 494 auto indexTy = parser.getBuilder().getIndexType(); 495 return failure(parser.parseOperandList( 496 opInfos, OpAsmParser::Delimiter::OptionalSquare) || 497 parser.resolveOperands(opInfos, indexTy, operands)); 498 } 499 500 /// Utility function to verify that a set of operands are valid dimension and 501 /// symbol identifiers. The operands should be laid out such that the dimension 502 /// operands are before the symbol operands. This function returns failure if 503 /// there was an invalid operand. An operation is provided to emit any necessary 504 /// errors. 505 template <typename OpTy> 506 static LogicalResult 507 verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands, 508 unsigned numDims) { 509 unsigned opIt = 0; 510 for (auto operand : operands) { 511 if (opIt++ < numDims) { 512 if (!isValidDim(operand, getAffineScope(op))) 513 return op.emitOpError("operand cannot be used as a dimension id"); 514 } else if (!isValidSymbol(operand, getAffineScope(op))) { 515 return op.emitOpError("operand cannot be used as a symbol"); 516 } 517 } 518 return success(); 519 } 520 521 //===----------------------------------------------------------------------===// 522 // AffineApplyOp 523 //===----------------------------------------------------------------------===// 524 525 AffineValueMap AffineApplyOp::getAffineValueMap() { 526 return AffineValueMap(getAffineMap(), getOperands(), getResult()); 527 } 528 529 ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) { 530 auto &builder = parser.getBuilder(); 531 auto indexTy = builder.getIndexType(); 532 533 AffineMapAttr mapAttr; 534 unsigned numDims; 535 if (parser.parseAttribute(mapAttr, "map", result.attributes) || 536 parseDimAndSymbolList(parser, result.operands, numDims) || 537 parser.parseOptionalAttrDict(result.attributes)) 538 return failure(); 539 auto map = mapAttr.getValue(); 540 541 if (map.getNumDims() != numDims || 542 numDims + map.getNumSymbols() != result.operands.size()) { 543 return parser.emitError(parser.getNameLoc(), 544 "dimension or symbol index mismatch"); 545 } 546 547 result.types.append(map.getNumResults(), indexTy); 548 return success(); 549 } 550 551 void AffineApplyOp::print(OpAsmPrinter &p) { 552 p << " " << getMapAttr(); 553 printDimAndSymbolList(operand_begin(), operand_end(), 554 getAffineMap().getNumDims(), p); 555 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"map"}); 556 } 557 558 LogicalResult AffineApplyOp::verify() { 559 // Check input and output dimensions match. 560 AffineMap affineMap = getMap(); 561 562 // Verify that operand count matches affine map dimension and symbol count. 563 if (getNumOperands() != affineMap.getNumDims() + affineMap.getNumSymbols()) 564 return emitOpError( 565 "operand count and affine map dimension and symbol count must match"); 566 567 // Verify that the map only produces one result. 568 if (affineMap.getNumResults() != 1) 569 return emitOpError("mapping must produce one value"); 570 571 return success(); 572 } 573 574 // The result of the affine apply operation can be used as a dimension id if all 575 // its operands are valid dimension ids. 576 bool AffineApplyOp::isValidDim() { 577 return llvm::all_of(getOperands(), 578 [](Value op) { return affine::isValidDim(op); }); 579 } 580 581 // The result of the affine apply operation can be used as a dimension id if all 582 // its operands are valid dimension ids with the parent operation of `region` 583 // defining the polyhedral scope for symbols. 584 bool AffineApplyOp::isValidDim(Region *region) { 585 return llvm::all_of(getOperands(), 586 [&](Value op) { return ::isValidDim(op, region); }); 587 } 588 589 // The result of the affine apply operation can be used as a symbol if all its 590 // operands are symbols. 591 bool AffineApplyOp::isValidSymbol() { 592 return llvm::all_of(getOperands(), 593 [](Value op) { return affine::isValidSymbol(op); }); 594 } 595 596 // The result of the affine apply operation can be used as a symbol in `region` 597 // if all its operands are symbols in `region`. 598 bool AffineApplyOp::isValidSymbol(Region *region) { 599 return llvm::all_of(getOperands(), [&](Value operand) { 600 return affine::isValidSymbol(operand, region); 601 }); 602 } 603 604 OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) { 605 auto map = getAffineMap(); 606 607 // Fold dims and symbols to existing values. 608 auto expr = map.getResult(0); 609 if (auto dim = dyn_cast<AffineDimExpr>(expr)) 610 return getOperand(dim.getPosition()); 611 if (auto sym = dyn_cast<AffineSymbolExpr>(expr)) 612 return getOperand(map.getNumDims() + sym.getPosition()); 613 614 // Otherwise, default to folding the map. 615 SmallVector<Attribute, 1> result; 616 bool hasPoison = false; 617 auto foldResult = 618 map.constantFold(adaptor.getMapOperands(), result, &hasPoison); 619 if (hasPoison) 620 return ub::PoisonAttr::get(getContext()); 621 if (failed(foldResult)) 622 return {}; 623 return result[0]; 624 } 625 626 /// Returns the largest known divisor of `e`. Exploits information from the 627 /// values in `operands`. 628 static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef<Value> operands) { 629 // This method isn't aware of `operands`. 630 int64_t div = e.getLargestKnownDivisor(); 631 632 // We now make use of operands for the case `e` is a dim expression. 633 // TODO: More powerful simplification would have to modify 634 // getLargestKnownDivisor to take `operands` and exploit that information as 635 // well for dim/sym expressions, but in that case, getLargestKnownDivisor 636 // can't be part of the IR library but of the `Analysis` library. The IR 637 // library can only really depend on simple O(1) checks. 638 auto dimExpr = dyn_cast<AffineDimExpr>(e); 639 // If it's not a dim expr, `div` is the best we have. 640 if (!dimExpr) 641 return div; 642 643 // We simply exploit information from loop IVs. 644 // We don't need to use mlir::getLargestKnownDivisorOfValue since the other 645 // desired simplifications are expected to be part of other 646 // canonicalizations. Also, mlir::getLargestKnownDivisorOfValue is part of the 647 // LoopAnalysis library. 648 Value operand = operands[dimExpr.getPosition()]; 649 int64_t operandDivisor = 1; 650 // TODO: With the right accessors, this can be extended to 651 // LoopLikeOpInterface. 652 if (AffineForOp forOp = getForInductionVarOwner(operand)) { 653 if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) { 654 operandDivisor = forOp.getStepAsInt(); 655 } else { 656 uint64_t lbLargestKnownDivisor = 657 forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs(); 658 operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStepAsInt()); 659 } 660 } 661 return operandDivisor; 662 } 663 664 /// Check if `e` is known to be: 0 <= `e` < `k`. Handles the simple cases of `e` 665 /// being an affine dim expression or a constant. 666 static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef<Value> operands, 667 int64_t k) { 668 if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) { 669 int64_t constVal = constExpr.getValue(); 670 return constVal >= 0 && constVal < k; 671 } 672 auto dimExpr = dyn_cast<AffineDimExpr>(e); 673 if (!dimExpr) 674 return false; 675 Value operand = operands[dimExpr.getPosition()]; 676 // TODO: With the right accessors, this can be extended to 677 // LoopLikeOpInterface. 678 if (AffineForOp forOp = getForInductionVarOwner(operand)) { 679 if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 && 680 forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) { 681 return true; 682 } 683 } 684 685 // We don't consider other cases like `operand` being defined by a constant or 686 // an affine.apply op since such cases will already be handled by other 687 // patterns and propagation of loop IVs or constant would happen. 688 return false; 689 } 690 691 /// Check if expression `e` is of the form d*e_1 + e_2 where 0 <= e_2 < d. 692 /// Set `div` to `d`, `quotientTimesDiv` to e_1 and `rem` to e_2 if the 693 /// expression is in that form. 694 static bool isQTimesDPlusR(AffineExpr e, ArrayRef<Value> operands, int64_t &div, 695 AffineExpr "ientTimesDiv, AffineExpr &rem) { 696 auto bin = dyn_cast<AffineBinaryOpExpr>(e); 697 if (!bin || bin.getKind() != AffineExprKind::Add) 698 return false; 699 700 AffineExpr llhs = bin.getLHS(); 701 AffineExpr rlhs = bin.getRHS(); 702 div = getLargestKnownDivisor(llhs, operands); 703 if (isNonNegativeBoundedBy(rlhs, operands, div)) { 704 quotientTimesDiv = llhs; 705 rem = rlhs; 706 return true; 707 } 708 div = getLargestKnownDivisor(rlhs, operands); 709 if (isNonNegativeBoundedBy(llhs, operands, div)) { 710 quotientTimesDiv = rlhs; 711 rem = llhs; 712 return true; 713 } 714 return false; 715 } 716 717 /// Gets the constant lower bound on an `iv`. 718 static std::optional<int64_t> getLowerBound(Value iv) { 719 AffineForOp forOp = getForInductionVarOwner(iv); 720 if (forOp && forOp.hasConstantLowerBound()) 721 return forOp.getConstantLowerBound(); 722 return std::nullopt; 723 } 724 725 /// Gets the constant upper bound on an affine.for `iv`. 726 static std::optional<int64_t> getUpperBound(Value iv) { 727 AffineForOp forOp = getForInductionVarOwner(iv); 728 if (!forOp || !forOp.hasConstantUpperBound()) 729 return std::nullopt; 730 731 // If its lower bound is also known, we can get a more precise bound 732 // whenever the step is not one. 733 if (forOp.hasConstantLowerBound()) { 734 return forOp.getConstantUpperBound() - 1 - 735 (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) % 736 forOp.getStepAsInt(); 737 } 738 return forOp.getConstantUpperBound() - 1; 739 } 740 741 /// Determine a constant upper bound for `expr` if one exists while exploiting 742 /// values in `operands`. Note that the upper bound is an inclusive one. `expr` 743 /// is guaranteed to be less than or equal to it. 744 static std::optional<int64_t> getUpperBound(AffineExpr expr, unsigned numDims, 745 unsigned numSymbols, 746 ArrayRef<Value> operands) { 747 // Get the constant lower or upper bounds on the operands. 748 SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds; 749 constLowerBounds.reserve(operands.size()); 750 constUpperBounds.reserve(operands.size()); 751 for (Value operand : operands) { 752 constLowerBounds.push_back(getLowerBound(operand)); 753 constUpperBounds.push_back(getUpperBound(operand)); 754 } 755 756 if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) 757 return constExpr.getValue(); 758 759 return getBoundForAffineExpr(expr, numDims, numSymbols, constLowerBounds, 760 constUpperBounds, 761 /*isUpper=*/true); 762 } 763 764 /// Determine a constant lower bound for `expr` if one exists while exploiting 765 /// values in `operands`. Note that the upper bound is an inclusive one. `expr` 766 /// is guaranteed to be less than or equal to it. 767 static std::optional<int64_t> getLowerBound(AffineExpr expr, unsigned numDims, 768 unsigned numSymbols, 769 ArrayRef<Value> operands) { 770 // Get the constant lower or upper bounds on the operands. 771 SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds; 772 constLowerBounds.reserve(operands.size()); 773 constUpperBounds.reserve(operands.size()); 774 for (Value operand : operands) { 775 constLowerBounds.push_back(getLowerBound(operand)); 776 constUpperBounds.push_back(getUpperBound(operand)); 777 } 778 779 std::optional<int64_t> lowerBound; 780 if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) { 781 lowerBound = constExpr.getValue(); 782 } else { 783 lowerBound = getBoundForAffineExpr(expr, numDims, numSymbols, 784 constLowerBounds, constUpperBounds, 785 /*isUpper=*/false); 786 } 787 return lowerBound; 788 } 789 790 /// Simplify `expr` while exploiting information from the values in `operands`. 791 static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims, 792 unsigned numSymbols, 793 ArrayRef<Value> operands) { 794 // We do this only for certain floordiv/mod expressions. 795 auto binExpr = dyn_cast<AffineBinaryOpExpr>(expr); 796 if (!binExpr) 797 return; 798 799 // Simplify the child expressions first. 800 AffineExpr lhs = binExpr.getLHS(); 801 AffineExpr rhs = binExpr.getRHS(); 802 simplifyExprAndOperands(lhs, numDims, numSymbols, operands); 803 simplifyExprAndOperands(rhs, numDims, numSymbols, operands); 804 expr = getAffineBinaryOpExpr(binExpr.getKind(), lhs, rhs); 805 806 binExpr = dyn_cast<AffineBinaryOpExpr>(expr); 807 if (!binExpr || (expr.getKind() != AffineExprKind::FloorDiv && 808 expr.getKind() != AffineExprKind::CeilDiv && 809 expr.getKind() != AffineExprKind::Mod)) { 810 return; 811 } 812 813 // The `lhs` and `rhs` may be different post construction of simplified expr. 814 lhs = binExpr.getLHS(); 815 rhs = binExpr.getRHS(); 816 auto rhsConst = dyn_cast<AffineConstantExpr>(rhs); 817 if (!rhsConst) 818 return; 819 820 int64_t rhsConstVal = rhsConst.getValue(); 821 // Undefined exprsessions aren't touched; IR can still be valid with them. 822 if (rhsConstVal <= 0) 823 return; 824 825 // Exploit constant lower/upper bounds to simplify a floordiv or mod. 826 MLIRContext *context = expr.getContext(); 827 std::optional<int64_t> lhsLbConst = 828 getLowerBound(lhs, numDims, numSymbols, operands); 829 std::optional<int64_t> lhsUbConst = 830 getUpperBound(lhs, numDims, numSymbols, operands); 831 if (lhsLbConst && lhsUbConst) { 832 int64_t lhsLbConstVal = *lhsLbConst; 833 int64_t lhsUbConstVal = *lhsUbConst; 834 // lhs floordiv c is a single value lhs is bounded in a range `c` that has 835 // the same quotient. 836 if (binExpr.getKind() == AffineExprKind::FloorDiv && 837 divideFloorSigned(lhsLbConstVal, rhsConstVal) == 838 divideFloorSigned(lhsUbConstVal, rhsConstVal)) { 839 expr = getAffineConstantExpr( 840 divideFloorSigned(lhsLbConstVal, rhsConstVal), context); 841 return; 842 } 843 // lhs ceildiv c is a single value if the entire range has the same ceil 844 // quotient. 845 if (binExpr.getKind() == AffineExprKind::CeilDiv && 846 divideCeilSigned(lhsLbConstVal, rhsConstVal) == 847 divideCeilSigned(lhsUbConstVal, rhsConstVal)) { 848 expr = getAffineConstantExpr(divideCeilSigned(lhsLbConstVal, rhsConstVal), 849 context); 850 return; 851 } 852 // lhs mod c is lhs if the entire range has quotient 0 w.r.t the rhs. 853 if (binExpr.getKind() == AffineExprKind::Mod && lhsLbConstVal >= 0 && 854 lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) { 855 expr = lhs; 856 return; 857 } 858 } 859 860 // Simplify expressions of the form e = (e_1 + e_2) floordiv c or (e_1 + e_2) 861 // mod c, where e_1 is a multiple of `k` and 0 <= e_2 < k. In such cases, if 862 // `c` % `k` == 0, (e_1 + e_2) floordiv c can be simplified to e_1 floordiv c. 863 // And when k % c == 0, (e_1 + e_2) mod c can be simplified to e_2 mod c. 864 AffineExpr quotientTimesDiv, rem; 865 int64_t divisor; 866 if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) { 867 if (rhsConstVal % divisor == 0 && 868 binExpr.getKind() == AffineExprKind::FloorDiv) { 869 expr = quotientTimesDiv.floorDiv(rhsConst); 870 } else if (divisor % rhsConstVal == 0 && 871 binExpr.getKind() == AffineExprKind::Mod) { 872 expr = rem % rhsConst; 873 } 874 return; 875 } 876 877 // Handle the simple case when the LHS expression can be either upper 878 // bounded or is a known multiple of RHS constant. 879 // lhs floordiv c -> 0 if 0 <= lhs < c, 880 // lhs mod c -> 0 if lhs % c = 0. 881 if ((isNonNegativeBoundedBy(lhs, operands, rhsConstVal) && 882 binExpr.getKind() == AffineExprKind::FloorDiv) || 883 (getLargestKnownDivisor(lhs, operands) % rhsConstVal == 0 && 884 binExpr.getKind() == AffineExprKind::Mod)) { 885 expr = getAffineConstantExpr(0, expr.getContext()); 886 } 887 } 888 889 /// Simplify the expressions in `map` while making use of lower or upper bounds 890 /// of its operands. If `isMax` is true, the map is to be treated as a max of 891 /// its result expressions, and min otherwise. Eg: min (d0, d1) -> (8, 4 * d0 + 892 /// d1) can be simplified to (8) if the operands are respectively lower bounded 893 /// by 2 and 0 (the second expression can't be lower than 8). 894 static void simplifyMinOrMaxExprWithOperands(AffineMap &map, 895 ArrayRef<Value> operands, 896 bool isMax) { 897 // Can't simplify. 898 if (operands.empty()) 899 return; 900 901 // Get the upper or lower bound on an affine.for op IV using its range. 902 // Get the constant lower or upper bounds on the operands. 903 SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds; 904 constLowerBounds.reserve(operands.size()); 905 constUpperBounds.reserve(operands.size()); 906 for (Value operand : operands) { 907 constLowerBounds.push_back(getLowerBound(operand)); 908 constUpperBounds.push_back(getUpperBound(operand)); 909 } 910 911 // We will compute the lower and upper bounds on each of the expressions 912 // Then, we will check (depending on max or min) as to whether a specific 913 // bound is redundant by checking if its highest (in case of max) and its 914 // lowest (in the case of min) value is already lower than (or higher than) 915 // the lower bound (or upper bound in the case of min) of another bound. 916 SmallVector<std::optional<int64_t>, 4> lowerBounds, upperBounds; 917 lowerBounds.reserve(map.getNumResults()); 918 upperBounds.reserve(map.getNumResults()); 919 for (AffineExpr e : map.getResults()) { 920 if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) { 921 lowerBounds.push_back(constExpr.getValue()); 922 upperBounds.push_back(constExpr.getValue()); 923 } else { 924 lowerBounds.push_back( 925 getBoundForAffineExpr(e, map.getNumDims(), map.getNumSymbols(), 926 constLowerBounds, constUpperBounds, 927 /*isUpper=*/false)); 928 upperBounds.push_back( 929 getBoundForAffineExpr(e, map.getNumDims(), map.getNumSymbols(), 930 constLowerBounds, constUpperBounds, 931 /*isUpper=*/true)); 932 } 933 } 934 935 // Collect expressions that are not redundant. 936 SmallVector<AffineExpr, 4> irredundantExprs; 937 for (auto exprEn : llvm::enumerate(map.getResults())) { 938 AffineExpr e = exprEn.value(); 939 unsigned i = exprEn.index(); 940 // Some expressions can be turned into constants. 941 if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i]) 942 e = getAffineConstantExpr(*lowerBounds[i], e.getContext()); 943 944 // Check if the expression is redundant. 945 if (isMax) { 946 if (!upperBounds[i]) { 947 irredundantExprs.push_back(e); 948 continue; 949 } 950 // If there exists another expression such that its lower bound is greater 951 // than this expression's upper bound, it's redundant. 952 if (!llvm::any_of(llvm::enumerate(lowerBounds), [&](const auto &en) { 953 auto otherLowerBound = en.value(); 954 unsigned pos = en.index(); 955 if (pos == i || !otherLowerBound) 956 return false; 957 if (*otherLowerBound > *upperBounds[i]) 958 return true; 959 if (*otherLowerBound < *upperBounds[i]) 960 return false; 961 // Equality case. When both expressions are considered redundant, we 962 // don't want to get both of them. We keep the one that appears 963 // first. 964 if (upperBounds[pos] && lowerBounds[i] && 965 lowerBounds[i] == upperBounds[i] && 966 otherLowerBound == *upperBounds[pos] && i < pos) 967 return false; 968 return true; 969 })) 970 irredundantExprs.push_back(e); 971 } else { 972 if (!lowerBounds[i]) { 973 irredundantExprs.push_back(e); 974 continue; 975 } 976 // Likewise for the `min` case. Use the complement of the condition above. 977 if (!llvm::any_of(llvm::enumerate(upperBounds), [&](const auto &en) { 978 auto otherUpperBound = en.value(); 979 unsigned pos = en.index(); 980 if (pos == i || !otherUpperBound) 981 return false; 982 if (*otherUpperBound < *lowerBounds[i]) 983 return true; 984 if (*otherUpperBound > *lowerBounds[i]) 985 return false; 986 if (lowerBounds[pos] && upperBounds[i] && 987 lowerBounds[i] == upperBounds[i] && 988 otherUpperBound == lowerBounds[pos] && i < pos) 989 return false; 990 return true; 991 })) 992 irredundantExprs.push_back(e); 993 } 994 } 995 996 // Create the map without the redundant expressions. 997 map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), irredundantExprs, 998 map.getContext()); 999 } 1000 1001 /// Simplify the map while exploiting information on the values in `operands`. 1002 // Use "unused attribute" marker to silence warning stemming from the inability 1003 // to see through the template expansion. 1004 static void LLVM_ATTRIBUTE_UNUSED 1005 simplifyMapWithOperands(AffineMap &map, ArrayRef<Value> operands) { 1006 assert(map.getNumInputs() == operands.size() && "invalid operands for map"); 1007 SmallVector<AffineExpr> newResults; 1008 newResults.reserve(map.getNumResults()); 1009 for (AffineExpr expr : map.getResults()) { 1010 simplifyExprAndOperands(expr, map.getNumDims(), map.getNumSymbols(), 1011 operands); 1012 newResults.push_back(expr); 1013 } 1014 map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults, 1015 map.getContext()); 1016 } 1017 1018 /// Replace all occurrences of AffineExpr at position `pos` in `map` by the 1019 /// defining AffineApplyOp expression and operands. 1020 /// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced. 1021 /// When `dimOrSymbolPosition >= dims.size()`, 1022 /// AffineSymbolExpr@[pos - dims.size()] is replaced. 1023 /// Mutate `map`,`dims` and `syms` in place as follows: 1024 /// 1. `dims` and `syms` are only appended to. 1025 /// 2. `map` dim and symbols are gradually shifted to higher positions. 1026 /// 3. Old `dim` and `sym` entries are replaced by nullptr 1027 /// This avoids the need for any bookkeeping. 1028 static LogicalResult replaceDimOrSym(AffineMap *map, 1029 unsigned dimOrSymbolPosition, 1030 SmallVectorImpl<Value> &dims, 1031 SmallVectorImpl<Value> &syms) { 1032 MLIRContext *ctx = map->getContext(); 1033 bool isDimReplacement = (dimOrSymbolPosition < dims.size()); 1034 unsigned pos = isDimReplacement ? dimOrSymbolPosition 1035 : dimOrSymbolPosition - dims.size(); 1036 Value &v = isDimReplacement ? dims[pos] : syms[pos]; 1037 if (!v) 1038 return failure(); 1039 1040 auto affineApply = v.getDefiningOp<AffineApplyOp>(); 1041 if (!affineApply) 1042 return failure(); 1043 1044 // At this point we will perform a replacement of `v`, set the entry in `dim` 1045 // or `sym` to nullptr immediately. 1046 v = nullptr; 1047 1048 // Compute the map, dims and symbols coming from the AffineApplyOp. 1049 AffineMap composeMap = affineApply.getAffineMap(); 1050 assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results"); 1051 SmallVector<Value> composeOperands(affineApply.getMapOperands().begin(), 1052 affineApply.getMapOperands().end()); 1053 // Canonicalize the map to promote dims to symbols when possible. This is to 1054 // avoid generating invalid maps. 1055 canonicalizeMapAndOperands(&composeMap, &composeOperands); 1056 AffineExpr replacementExpr = 1057 composeMap.shiftDims(dims.size()).shiftSymbols(syms.size()).getResult(0); 1058 ValueRange composeDims = 1059 ArrayRef<Value>(composeOperands).take_front(composeMap.getNumDims()); 1060 ValueRange composeSyms = 1061 ArrayRef<Value>(composeOperands).take_back(composeMap.getNumSymbols()); 1062 AffineExpr toReplace = isDimReplacement ? getAffineDimExpr(pos, ctx) 1063 : getAffineSymbolExpr(pos, ctx); 1064 1065 // Append the dims and symbols where relevant and perform the replacement. 1066 dims.append(composeDims.begin(), composeDims.end()); 1067 syms.append(composeSyms.begin(), composeSyms.end()); 1068 *map = map->replace(toReplace, replacementExpr, dims.size(), syms.size()); 1069 1070 return success(); 1071 } 1072 1073 /// Iterate over `operands` and fold away all those produced by an AffineApplyOp 1074 /// iteratively. Perform canonicalization of map and operands as well as 1075 /// AffineMap simplification. `map` and `operands` are mutated in place. 1076 static void composeAffineMapAndOperands(AffineMap *map, 1077 SmallVectorImpl<Value> *operands) { 1078 if (map->getNumResults() == 0) { 1079 canonicalizeMapAndOperands(map, operands); 1080 *map = simplifyAffineMap(*map); 1081 return; 1082 } 1083 1084 MLIRContext *ctx = map->getContext(); 1085 SmallVector<Value, 4> dims(operands->begin(), 1086 operands->begin() + map->getNumDims()); 1087 SmallVector<Value, 4> syms(operands->begin() + map->getNumDims(), 1088 operands->end()); 1089 1090 // Iterate over dims and symbols coming from AffineApplyOp and replace until 1091 // exhaustion. This iteratively mutates `map`, `dims` and `syms`. Both `dims` 1092 // and `syms` can only increase by construction. 1093 // The implementation uses a `while` loop to support the case of symbols 1094 // that may be constructed from dims ;this may be overkill. 1095 while (true) { 1096 bool changed = false; 1097 for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos) 1098 if ((changed |= succeeded(replaceDimOrSym(map, pos, dims, syms)))) 1099 break; 1100 if (!changed) 1101 break; 1102 } 1103 1104 // Clear operands so we can fill them anew. 1105 operands->clear(); 1106 1107 // At this point we may have introduced null operands, prune them out before 1108 // canonicalizing map and operands. 1109 unsigned nDims = 0, nSyms = 0; 1110 SmallVector<AffineExpr, 4> dimReplacements, symReplacements; 1111 dimReplacements.reserve(dims.size()); 1112 symReplacements.reserve(syms.size()); 1113 for (auto *container : {&dims, &syms}) { 1114 bool isDim = (container == &dims); 1115 auto &repls = isDim ? dimReplacements : symReplacements; 1116 for (const auto &en : llvm::enumerate(*container)) { 1117 Value v = en.value(); 1118 if (!v) { 1119 assert(isDim ? !map->isFunctionOfDim(en.index()) 1120 : !map->isFunctionOfSymbol(en.index()) && 1121 "map is function of unexpected expr@pos"); 1122 repls.push_back(getAffineConstantExpr(0, ctx)); 1123 continue; 1124 } 1125 repls.push_back(isDim ? getAffineDimExpr(nDims++, ctx) 1126 : getAffineSymbolExpr(nSyms++, ctx)); 1127 operands->push_back(v); 1128 } 1129 } 1130 *map = map->replaceDimsAndSymbols(dimReplacements, symReplacements, nDims, 1131 nSyms); 1132 1133 // Canonicalize and simplify before returning. 1134 canonicalizeMapAndOperands(map, operands); 1135 *map = simplifyAffineMap(*map); 1136 } 1137 1138 void mlir::affine::fullyComposeAffineMapAndOperands( 1139 AffineMap *map, SmallVectorImpl<Value> *operands) { 1140 while (llvm::any_of(*operands, [](Value v) { 1141 return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp()); 1142 })) { 1143 composeAffineMapAndOperands(map, operands); 1144 } 1145 } 1146 1147 AffineApplyOp 1148 mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, 1149 ArrayRef<OpFoldResult> operands) { 1150 SmallVector<Value> valueOperands; 1151 map = foldAttributesIntoMap(b, map, operands, valueOperands); 1152 composeAffineMapAndOperands(&map, &valueOperands); 1153 assert(map); 1154 return b.create<AffineApplyOp>(loc, map, valueOperands); 1155 } 1156 1157 AffineApplyOp 1158 mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineExpr e, 1159 ArrayRef<OpFoldResult> operands) { 1160 return makeComposedAffineApply( 1161 b, loc, 1162 AffineMap::inferFromExprList(ArrayRef<AffineExpr>{e}, b.getContext()) 1163 .front(), 1164 operands); 1165 } 1166 1167 /// Composes the given affine map with the given list of operands, pulling in 1168 /// the maps from any affine.apply operations that supply the operands. 1169 static void composeMultiResultAffineMap(AffineMap &map, 1170 SmallVectorImpl<Value> &operands) { 1171 // Compose and canonicalize each expression in the map individually because 1172 // composition only applies to single-result maps, collecting potentially 1173 // duplicate operands in a single list with shifted dimensions and symbols. 1174 SmallVector<Value> dims, symbols; 1175 SmallVector<AffineExpr> exprs; 1176 for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) { 1177 SmallVector<Value> submapOperands(operands.begin(), operands.end()); 1178 AffineMap submap = map.getSubMap({i}); 1179 fullyComposeAffineMapAndOperands(&submap, &submapOperands); 1180 canonicalizeMapAndOperands(&submap, &submapOperands); 1181 unsigned numNewDims = submap.getNumDims(); 1182 submap = submap.shiftDims(dims.size()).shiftSymbols(symbols.size()); 1183 llvm::append_range(dims, 1184 ArrayRef<Value>(submapOperands).take_front(numNewDims)); 1185 llvm::append_range(symbols, 1186 ArrayRef<Value>(submapOperands).drop_front(numNewDims)); 1187 exprs.push_back(submap.getResult(0)); 1188 } 1189 1190 // Canonicalize the map created from composed expressions to deduplicate the 1191 // dimension and symbol operands. 1192 operands = llvm::to_vector(llvm::concat<Value>(dims, symbols)); 1193 map = AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext()); 1194 canonicalizeMapAndOperands(&map, &operands); 1195 } 1196 1197 OpFoldResult 1198 mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc, 1199 AffineMap map, 1200 ArrayRef<OpFoldResult> operands) { 1201 assert(map.getNumResults() == 1 && "building affine.apply with !=1 result"); 1202 1203 // Create new builder without a listener, so that no notification is 1204 // triggered if the op is folded. 1205 // TODO: OpBuilder::createOrFold should return OpFoldResults, then this 1206 // workaround is no longer needed. 1207 OpBuilder newBuilder(b.getContext()); 1208 newBuilder.setInsertionPoint(b.getInsertionBlock(), b.getInsertionPoint()); 1209 1210 // Create op. 1211 AffineApplyOp applyOp = 1212 makeComposedAffineApply(newBuilder, loc, map, operands); 1213 1214 // Get constant operands. 1215 SmallVector<Attribute> constOperands(applyOp->getNumOperands()); 1216 for (unsigned i = 0, e = constOperands.size(); i != e; ++i) 1217 matchPattern(applyOp->getOperand(i), m_Constant(&constOperands[i])); 1218 1219 // Try to fold the operation. 1220 SmallVector<OpFoldResult> foldResults; 1221 if (failed(applyOp->fold(constOperands, foldResults)) || 1222 foldResults.empty()) { 1223 if (OpBuilder::Listener *listener = b.getListener()) 1224 listener->notifyOperationInserted(applyOp, /*previous=*/{}); 1225 return applyOp.getResult(); 1226 } 1227 1228 applyOp->erase(); 1229 assert(foldResults.size() == 1 && "expected 1 folded result"); 1230 return foldResults.front(); 1231 } 1232 1233 OpFoldResult 1234 mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc, 1235 AffineExpr expr, 1236 ArrayRef<OpFoldResult> operands) { 1237 return makeComposedFoldedAffineApply( 1238 b, loc, 1239 AffineMap::inferFromExprList(ArrayRef<AffineExpr>{expr}, b.getContext()) 1240 .front(), 1241 operands); 1242 } 1243 1244 SmallVector<OpFoldResult> 1245 mlir::affine::makeComposedFoldedMultiResultAffineApply( 1246 OpBuilder &b, Location loc, AffineMap map, 1247 ArrayRef<OpFoldResult> operands) { 1248 return llvm::map_to_vector(llvm::seq<unsigned>(0, map.getNumResults()), 1249 [&](unsigned i) { 1250 return makeComposedFoldedAffineApply( 1251 b, loc, map.getSubMap({i}), operands); 1252 }); 1253 } 1254 1255 template <typename OpTy> 1256 static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map, 1257 ArrayRef<OpFoldResult> operands) { 1258 SmallVector<Value> valueOperands; 1259 map = foldAttributesIntoMap(b, map, operands, valueOperands); 1260 composeMultiResultAffineMap(map, valueOperands); 1261 return b.create<OpTy>(loc, b.getIndexType(), map, valueOperands); 1262 } 1263 1264 AffineMinOp 1265 mlir::affine::makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map, 1266 ArrayRef<OpFoldResult> operands) { 1267 return makeComposedMinMax<AffineMinOp>(b, loc, map, operands); 1268 } 1269 1270 template <typename OpTy> 1271 static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc, 1272 AffineMap map, 1273 ArrayRef<OpFoldResult> operands) { 1274 // Create new builder without a listener, so that no notification is 1275 // triggered if the op is folded. 1276 // TODO: OpBuilder::createOrFold should return OpFoldResults, then this 1277 // workaround is no longer needed. 1278 OpBuilder newBuilder(b.getContext()); 1279 newBuilder.setInsertionPoint(b.getInsertionBlock(), b.getInsertionPoint()); 1280 1281 // Create op. 1282 auto minMaxOp = makeComposedMinMax<OpTy>(newBuilder, loc, map, operands); 1283 1284 // Get constant operands. 1285 SmallVector<Attribute> constOperands(minMaxOp->getNumOperands()); 1286 for (unsigned i = 0, e = constOperands.size(); i != e; ++i) 1287 matchPattern(minMaxOp->getOperand(i), m_Constant(&constOperands[i])); 1288 1289 // Try to fold the operation. 1290 SmallVector<OpFoldResult> foldResults; 1291 if (failed(minMaxOp->fold(constOperands, foldResults)) || 1292 foldResults.empty()) { 1293 if (OpBuilder::Listener *listener = b.getListener()) 1294 listener->notifyOperationInserted(minMaxOp, /*previous=*/{}); 1295 return minMaxOp.getResult(); 1296 } 1297 1298 minMaxOp->erase(); 1299 assert(foldResults.size() == 1 && "expected 1 folded result"); 1300 return foldResults.front(); 1301 } 1302 1303 OpFoldResult 1304 mlir::affine::makeComposedFoldedAffineMin(OpBuilder &b, Location loc, 1305 AffineMap map, 1306 ArrayRef<OpFoldResult> operands) { 1307 return makeComposedFoldedMinMax<AffineMinOp>(b, loc, map, operands); 1308 } 1309 1310 OpFoldResult 1311 mlir::affine::makeComposedFoldedAffineMax(OpBuilder &b, Location loc, 1312 AffineMap map, 1313 ArrayRef<OpFoldResult> operands) { 1314 return makeComposedFoldedMinMax<AffineMaxOp>(b, loc, map, operands); 1315 } 1316 1317 // A symbol may appear as a dim in affine.apply operations. This function 1318 // canonicalizes dims that are valid symbols into actual symbols. 1319 template <class MapOrSet> 1320 static void canonicalizePromotedSymbols(MapOrSet *mapOrSet, 1321 SmallVectorImpl<Value> *operands) { 1322 if (!mapOrSet || operands->empty()) 1323 return; 1324 1325 assert(mapOrSet->getNumInputs() == operands->size() && 1326 "map/set inputs must match number of operands"); 1327 1328 auto *context = mapOrSet->getContext(); 1329 SmallVector<Value, 8> resultOperands; 1330 resultOperands.reserve(operands->size()); 1331 SmallVector<Value, 8> remappedSymbols; 1332 remappedSymbols.reserve(operands->size()); 1333 unsigned nextDim = 0; 1334 unsigned nextSym = 0; 1335 unsigned oldNumSyms = mapOrSet->getNumSymbols(); 1336 SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims()); 1337 for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) { 1338 if (i < mapOrSet->getNumDims()) { 1339 if (isValidSymbol((*operands)[i])) { 1340 // This is a valid symbol that appears as a dim, canonicalize it. 1341 dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context); 1342 remappedSymbols.push_back((*operands)[i]); 1343 } else { 1344 dimRemapping[i] = getAffineDimExpr(nextDim++, context); 1345 resultOperands.push_back((*operands)[i]); 1346 } 1347 } else { 1348 resultOperands.push_back((*operands)[i]); 1349 } 1350 } 1351 1352 resultOperands.append(remappedSymbols.begin(), remappedSymbols.end()); 1353 *operands = resultOperands; 1354 *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim, 1355 oldNumSyms + nextSym); 1356 1357 assert(mapOrSet->getNumInputs() == operands->size() && 1358 "map/set inputs must match number of operands"); 1359 } 1360 1361 // Works for either an affine map or an integer set. 1362 template <class MapOrSet> 1363 static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, 1364 SmallVectorImpl<Value> *operands) { 1365 static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value, 1366 "Argument must be either of AffineMap or IntegerSet type"); 1367 1368 if (!mapOrSet || operands->empty()) 1369 return; 1370 1371 assert(mapOrSet->getNumInputs() == operands->size() && 1372 "map/set inputs must match number of operands"); 1373 1374 canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands); 1375 1376 // Check to see what dims are used. 1377 llvm::SmallBitVector usedDims(mapOrSet->getNumDims()); 1378 llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols()); 1379 mapOrSet->walkExprs([&](AffineExpr expr) { 1380 if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) 1381 usedDims[dimExpr.getPosition()] = true; 1382 else if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) 1383 usedSyms[symExpr.getPosition()] = true; 1384 }); 1385 1386 auto *context = mapOrSet->getContext(); 1387 1388 SmallVector<Value, 8> resultOperands; 1389 resultOperands.reserve(operands->size()); 1390 1391 llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims; 1392 SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims()); 1393 unsigned nextDim = 0; 1394 for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) { 1395 if (usedDims[i]) { 1396 // Remap dim positions for duplicate operands. 1397 auto it = seenDims.find((*operands)[i]); 1398 if (it == seenDims.end()) { 1399 dimRemapping[i] = getAffineDimExpr(nextDim++, context); 1400 resultOperands.push_back((*operands)[i]); 1401 seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i])); 1402 } else { 1403 dimRemapping[i] = it->second; 1404 } 1405 } 1406 } 1407 llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols; 1408 SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols()); 1409 unsigned nextSym = 0; 1410 for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) { 1411 if (!usedSyms[i]) 1412 continue; 1413 // Handle constant operands (only needed for symbolic operands since 1414 // constant operands in dimensional positions would have already been 1415 // promoted to symbolic positions above). 1416 IntegerAttr operandCst; 1417 if (matchPattern((*operands)[i + mapOrSet->getNumDims()], 1418 m_Constant(&operandCst))) { 1419 symRemapping[i] = 1420 getAffineConstantExpr(operandCst.getValue().getSExtValue(), context); 1421 continue; 1422 } 1423 // Remap symbol positions for duplicate operands. 1424 auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]); 1425 if (it == seenSymbols.end()) { 1426 symRemapping[i] = getAffineSymbolExpr(nextSym++, context); 1427 resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]); 1428 seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()], 1429 symRemapping[i])); 1430 } else { 1431 symRemapping[i] = it->second; 1432 } 1433 } 1434 *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping, 1435 nextDim, nextSym); 1436 *operands = resultOperands; 1437 } 1438 1439 void mlir::affine::canonicalizeMapAndOperands( 1440 AffineMap *map, SmallVectorImpl<Value> *operands) { 1441 canonicalizeMapOrSetAndOperands<AffineMap>(map, operands); 1442 } 1443 1444 void mlir::affine::canonicalizeSetAndOperands( 1445 IntegerSet *set, SmallVectorImpl<Value> *operands) { 1446 canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands); 1447 } 1448 1449 namespace { 1450 /// Simplify AffineApply, AffineLoad, and AffineStore operations by composing 1451 /// maps that supply results into them. 1452 /// 1453 template <typename AffineOpTy> 1454 struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> { 1455 using OpRewritePattern<AffineOpTy>::OpRewritePattern; 1456 1457 /// Replace the affine op with another instance of it with the supplied 1458 /// map and mapOperands. 1459 void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp, 1460 AffineMap map, ArrayRef<Value> mapOperands) const; 1461 1462 LogicalResult matchAndRewrite(AffineOpTy affineOp, 1463 PatternRewriter &rewriter) const override { 1464 static_assert( 1465 llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp, 1466 AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp, 1467 AffineVectorStoreOp, AffineVectorLoadOp>::value, 1468 "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op " 1469 "expected"); 1470 auto map = affineOp.getAffineMap(); 1471 AffineMap oldMap = map; 1472 auto oldOperands = affineOp.getMapOperands(); 1473 SmallVector<Value, 8> resultOperands(oldOperands); 1474 composeAffineMapAndOperands(&map, &resultOperands); 1475 canonicalizeMapAndOperands(&map, &resultOperands); 1476 simplifyMapWithOperands(map, resultOperands); 1477 if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(), 1478 resultOperands.begin())) 1479 return failure(); 1480 1481 replaceAffineOp(rewriter, affineOp, map, resultOperands); 1482 return success(); 1483 } 1484 }; 1485 1486 // Specialize the template to account for the different build signatures for 1487 // affine load, store, and apply ops. 1488 template <> 1489 void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp( 1490 PatternRewriter &rewriter, AffineLoadOp load, AffineMap map, 1491 ArrayRef<Value> mapOperands) const { 1492 rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map, 1493 mapOperands); 1494 } 1495 template <> 1496 void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp( 1497 PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map, 1498 ArrayRef<Value> mapOperands) const { 1499 rewriter.replaceOpWithNewOp<AffinePrefetchOp>( 1500 prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(), 1501 prefetch.getLocalityHint(), prefetch.getIsDataCache()); 1502 } 1503 template <> 1504 void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp( 1505 PatternRewriter &rewriter, AffineStoreOp store, AffineMap map, 1506 ArrayRef<Value> mapOperands) const { 1507 rewriter.replaceOpWithNewOp<AffineStoreOp>( 1508 store, store.getValueToStore(), store.getMemRef(), map, mapOperands); 1509 } 1510 template <> 1511 void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp( 1512 PatternRewriter &rewriter, AffineVectorLoadOp vectorload, AffineMap map, 1513 ArrayRef<Value> mapOperands) const { 1514 rewriter.replaceOpWithNewOp<AffineVectorLoadOp>( 1515 vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map, 1516 mapOperands); 1517 } 1518 template <> 1519 void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp( 1520 PatternRewriter &rewriter, AffineVectorStoreOp vectorstore, AffineMap map, 1521 ArrayRef<Value> mapOperands) const { 1522 rewriter.replaceOpWithNewOp<AffineVectorStoreOp>( 1523 vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map, 1524 mapOperands); 1525 } 1526 1527 // Generic version for ops that don't have extra operands. 1528 template <typename AffineOpTy> 1529 void SimplifyAffineOp<AffineOpTy>::replaceAffineOp( 1530 PatternRewriter &rewriter, AffineOpTy op, AffineMap map, 1531 ArrayRef<Value> mapOperands) const { 1532 rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands); 1533 } 1534 } // namespace 1535 1536 void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results, 1537 MLIRContext *context) { 1538 results.add<SimplifyAffineOp<AffineApplyOp>>(context); 1539 } 1540 1541 //===----------------------------------------------------------------------===// 1542 // AffineDmaStartOp 1543 //===----------------------------------------------------------------------===// 1544 1545 // TODO: Check that map operands are loop IVs or symbols. 1546 void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result, 1547 Value srcMemRef, AffineMap srcMap, 1548 ValueRange srcIndices, Value destMemRef, 1549 AffineMap dstMap, ValueRange destIndices, 1550 Value tagMemRef, AffineMap tagMap, 1551 ValueRange tagIndices, Value numElements, 1552 Value stride, Value elementsPerStride) { 1553 result.addOperands(srcMemRef); 1554 result.addAttribute(getSrcMapAttrStrName(), AffineMapAttr::get(srcMap)); 1555 result.addOperands(srcIndices); 1556 result.addOperands(destMemRef); 1557 result.addAttribute(getDstMapAttrStrName(), AffineMapAttr::get(dstMap)); 1558 result.addOperands(destIndices); 1559 result.addOperands(tagMemRef); 1560 result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap)); 1561 result.addOperands(tagIndices); 1562 result.addOperands(numElements); 1563 if (stride) { 1564 result.addOperands({stride, elementsPerStride}); 1565 } 1566 } 1567 1568 void AffineDmaStartOp::print(OpAsmPrinter &p) { 1569 p << " " << getSrcMemRef() << '['; 1570 p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices()); 1571 p << "], " << getDstMemRef() << '['; 1572 p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices()); 1573 p << "], " << getTagMemRef() << '['; 1574 p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices()); 1575 p << "], " << getNumElements(); 1576 if (isStrided()) { 1577 p << ", " << getStride(); 1578 p << ", " << getNumElementsPerStride(); 1579 } 1580 p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", " 1581 << getTagMemRefType(); 1582 } 1583 1584 // Parse AffineDmaStartOp. 1585 // Ex: 1586 // affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size, 1587 // %stride, %num_elt_per_stride 1588 // : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32> 1589 // 1590 ParseResult AffineDmaStartOp::parse(OpAsmParser &parser, 1591 OperationState &result) { 1592 OpAsmParser::UnresolvedOperand srcMemRefInfo; 1593 AffineMapAttr srcMapAttr; 1594 SmallVector<OpAsmParser::UnresolvedOperand, 4> srcMapOperands; 1595 OpAsmParser::UnresolvedOperand dstMemRefInfo; 1596 AffineMapAttr dstMapAttr; 1597 SmallVector<OpAsmParser::UnresolvedOperand, 4> dstMapOperands; 1598 OpAsmParser::UnresolvedOperand tagMemRefInfo; 1599 AffineMapAttr tagMapAttr; 1600 SmallVector<OpAsmParser::UnresolvedOperand, 4> tagMapOperands; 1601 OpAsmParser::UnresolvedOperand numElementsInfo; 1602 SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo; 1603 1604 SmallVector<Type, 3> types; 1605 auto indexType = parser.getBuilder().getIndexType(); 1606 1607 // Parse and resolve the following list of operands: 1608 // *) dst memref followed by its affine maps operands (in square brackets). 1609 // *) src memref followed by its affine map operands (in square brackets). 1610 // *) tag memref followed by its affine map operands (in square brackets). 1611 // *) number of elements transferred by DMA operation. 1612 if (parser.parseOperand(srcMemRefInfo) || 1613 parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr, 1614 getSrcMapAttrStrName(), 1615 result.attributes) || 1616 parser.parseComma() || parser.parseOperand(dstMemRefInfo) || 1617 parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr, 1618 getDstMapAttrStrName(), 1619 result.attributes) || 1620 parser.parseComma() || parser.parseOperand(tagMemRefInfo) || 1621 parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr, 1622 getTagMapAttrStrName(), 1623 result.attributes) || 1624 parser.parseComma() || parser.parseOperand(numElementsInfo)) 1625 return failure(); 1626 1627 // Parse optional stride and elements per stride. 1628 if (parser.parseTrailingOperandList(strideInfo)) 1629 return failure(); 1630 1631 if (!strideInfo.empty() && strideInfo.size() != 2) { 1632 return parser.emitError(parser.getNameLoc(), 1633 "expected two stride related operands"); 1634 } 1635 bool isStrided = strideInfo.size() == 2; 1636 1637 if (parser.parseColonTypeList(types)) 1638 return failure(); 1639 1640 if (types.size() != 3) 1641 return parser.emitError(parser.getNameLoc(), "expected three types"); 1642 1643 if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) || 1644 parser.resolveOperands(srcMapOperands, indexType, result.operands) || 1645 parser.resolveOperand(dstMemRefInfo, types[1], result.operands) || 1646 parser.resolveOperands(dstMapOperands, indexType, result.operands) || 1647 parser.resolveOperand(tagMemRefInfo, types[2], result.operands) || 1648 parser.resolveOperands(tagMapOperands, indexType, result.operands) || 1649 parser.resolveOperand(numElementsInfo, indexType, result.operands)) 1650 return failure(); 1651 1652 if (isStrided) { 1653 if (parser.resolveOperands(strideInfo, indexType, result.operands)) 1654 return failure(); 1655 } 1656 1657 // Check that src/dst/tag operand counts match their map.numInputs. 1658 if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() || 1659 dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() || 1660 tagMapOperands.size() != tagMapAttr.getValue().getNumInputs()) 1661 return parser.emitError(parser.getNameLoc(), 1662 "memref operand count not equal to map.numInputs"); 1663 return success(); 1664 } 1665 1666 LogicalResult AffineDmaStartOp::verifyInvariantsImpl() { 1667 if (!llvm::isa<MemRefType>(getOperand(getSrcMemRefOperandIndex()).getType())) 1668 return emitOpError("expected DMA source to be of memref type"); 1669 if (!llvm::isa<MemRefType>(getOperand(getDstMemRefOperandIndex()).getType())) 1670 return emitOpError("expected DMA destination to be of memref type"); 1671 if (!llvm::isa<MemRefType>(getOperand(getTagMemRefOperandIndex()).getType())) 1672 return emitOpError("expected DMA tag to be of memref type"); 1673 1674 unsigned numInputsAllMaps = getSrcMap().getNumInputs() + 1675 getDstMap().getNumInputs() + 1676 getTagMap().getNumInputs(); 1677 if (getNumOperands() != numInputsAllMaps + 3 + 1 && 1678 getNumOperands() != numInputsAllMaps + 3 + 1 + 2) { 1679 return emitOpError("incorrect number of operands"); 1680 } 1681 1682 Region *scope = getAffineScope(*this); 1683 for (auto idx : getSrcIndices()) { 1684 if (!idx.getType().isIndex()) 1685 return emitOpError("src index to dma_start must have 'index' type"); 1686 if (!isValidAffineIndexOperand(idx, scope)) 1687 return emitOpError( 1688 "src index must be a valid dimension or symbol identifier"); 1689 } 1690 for (auto idx : getDstIndices()) { 1691 if (!idx.getType().isIndex()) 1692 return emitOpError("dst index to dma_start must have 'index' type"); 1693 if (!isValidAffineIndexOperand(idx, scope)) 1694 return emitOpError( 1695 "dst index must be a valid dimension or symbol identifier"); 1696 } 1697 for (auto idx : getTagIndices()) { 1698 if (!idx.getType().isIndex()) 1699 return emitOpError("tag index to dma_start must have 'index' type"); 1700 if (!isValidAffineIndexOperand(idx, scope)) 1701 return emitOpError( 1702 "tag index must be a valid dimension or symbol identifier"); 1703 } 1704 return success(); 1705 } 1706 1707 LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands, 1708 SmallVectorImpl<OpFoldResult> &results) { 1709 /// dma_start(memrefcast) -> dma_start 1710 return memref::foldMemRefCast(*this); 1711 } 1712 1713 void AffineDmaStartOp::getEffects( 1714 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 1715 &effects) { 1716 effects.emplace_back(MemoryEffects::Read::get(), &getSrcMemRefMutable(), 1717 SideEffects::DefaultResource::get()); 1718 effects.emplace_back(MemoryEffects::Write::get(), &getDstMemRefMutable(), 1719 SideEffects::DefaultResource::get()); 1720 effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(), 1721 SideEffects::DefaultResource::get()); 1722 } 1723 1724 //===----------------------------------------------------------------------===// 1725 // AffineDmaWaitOp 1726 //===----------------------------------------------------------------------===// 1727 1728 // TODO: Check that map operands are loop IVs or symbols. 1729 void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result, 1730 Value tagMemRef, AffineMap tagMap, 1731 ValueRange tagIndices, Value numElements) { 1732 result.addOperands(tagMemRef); 1733 result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap)); 1734 result.addOperands(tagIndices); 1735 result.addOperands(numElements); 1736 } 1737 1738 void AffineDmaWaitOp::print(OpAsmPrinter &p) { 1739 p << " " << getTagMemRef() << '['; 1740 SmallVector<Value, 2> operands(getTagIndices()); 1741 p.printAffineMapOfSSAIds(getTagMapAttr(), operands); 1742 p << "], "; 1743 p.printOperand(getNumElements()); 1744 p << " : " << getTagMemRef().getType(); 1745 } 1746 1747 // Parse AffineDmaWaitOp. 1748 // Eg: 1749 // affine.dma_wait %tag[%index], %num_elements 1750 // : memref<1 x i32, (d0) -> (d0), 4> 1751 // 1752 ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser, 1753 OperationState &result) { 1754 OpAsmParser::UnresolvedOperand tagMemRefInfo; 1755 AffineMapAttr tagMapAttr; 1756 SmallVector<OpAsmParser::UnresolvedOperand, 2> tagMapOperands; 1757 Type type; 1758 auto indexType = parser.getBuilder().getIndexType(); 1759 OpAsmParser::UnresolvedOperand numElementsInfo; 1760 1761 // Parse tag memref, its map operands, and dma size. 1762 if (parser.parseOperand(tagMemRefInfo) || 1763 parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr, 1764 getTagMapAttrStrName(), 1765 result.attributes) || 1766 parser.parseComma() || parser.parseOperand(numElementsInfo) || 1767 parser.parseColonType(type) || 1768 parser.resolveOperand(tagMemRefInfo, type, result.operands) || 1769 parser.resolveOperands(tagMapOperands, indexType, result.operands) || 1770 parser.resolveOperand(numElementsInfo, indexType, result.operands)) 1771 return failure(); 1772 1773 if (!llvm::isa<MemRefType>(type)) 1774 return parser.emitError(parser.getNameLoc(), 1775 "expected tag to be of memref type"); 1776 1777 if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs()) 1778 return parser.emitError(parser.getNameLoc(), 1779 "tag memref operand count != to map.numInputs"); 1780 return success(); 1781 } 1782 1783 LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() { 1784 if (!llvm::isa<MemRefType>(getOperand(0).getType())) 1785 return emitOpError("expected DMA tag to be of memref type"); 1786 Region *scope = getAffineScope(*this); 1787 for (auto idx : getTagIndices()) { 1788 if (!idx.getType().isIndex()) 1789 return emitOpError("index to dma_wait must have 'index' type"); 1790 if (!isValidAffineIndexOperand(idx, scope)) 1791 return emitOpError( 1792 "index must be a valid dimension or symbol identifier"); 1793 } 1794 return success(); 1795 } 1796 1797 LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands, 1798 SmallVectorImpl<OpFoldResult> &results) { 1799 /// dma_wait(memrefcast) -> dma_wait 1800 return memref::foldMemRefCast(*this); 1801 } 1802 1803 void AffineDmaWaitOp::getEffects( 1804 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 1805 &effects) { 1806 effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(), 1807 SideEffects::DefaultResource::get()); 1808 } 1809 1810 //===----------------------------------------------------------------------===// 1811 // AffineForOp 1812 //===----------------------------------------------------------------------===// 1813 1814 /// 'bodyBuilder' is used to build the body of affine.for. If iterArgs and 1815 /// bodyBuilder are empty/null, we include default terminator op. 1816 void AffineForOp::build(OpBuilder &builder, OperationState &result, 1817 ValueRange lbOperands, AffineMap lbMap, 1818 ValueRange ubOperands, AffineMap ubMap, int64_t step, 1819 ValueRange iterArgs, BodyBuilderFn bodyBuilder) { 1820 assert(((!lbMap && lbOperands.empty()) || 1821 lbOperands.size() == lbMap.getNumInputs()) && 1822 "lower bound operand count does not match the affine map"); 1823 assert(((!ubMap && ubOperands.empty()) || 1824 ubOperands.size() == ubMap.getNumInputs()) && 1825 "upper bound operand count does not match the affine map"); 1826 assert(step > 0 && "step has to be a positive integer constant"); 1827 1828 OpBuilder::InsertionGuard guard(builder); 1829 1830 // Set variadic segment sizes. 1831 result.addAttribute( 1832 getOperandSegmentSizeAttr(), 1833 builder.getDenseI32ArrayAttr({static_cast<int32_t>(lbOperands.size()), 1834 static_cast<int32_t>(ubOperands.size()), 1835 static_cast<int32_t>(iterArgs.size())})); 1836 1837 for (Value val : iterArgs) 1838 result.addTypes(val.getType()); 1839 1840 // Add an attribute for the step. 1841 result.addAttribute(getStepAttrName(result.name), 1842 builder.getIntegerAttr(builder.getIndexType(), step)); 1843 1844 // Add the lower bound. 1845 result.addAttribute(getLowerBoundMapAttrName(result.name), 1846 AffineMapAttr::get(lbMap)); 1847 result.addOperands(lbOperands); 1848 1849 // Add the upper bound. 1850 result.addAttribute(getUpperBoundMapAttrName(result.name), 1851 AffineMapAttr::get(ubMap)); 1852 result.addOperands(ubOperands); 1853 1854 result.addOperands(iterArgs); 1855 // Create a region and a block for the body. The argument of the region is 1856 // the loop induction variable. 1857 Region *bodyRegion = result.addRegion(); 1858 Block *bodyBlock = builder.createBlock(bodyRegion); 1859 Value inductionVar = 1860 bodyBlock->addArgument(builder.getIndexType(), result.location); 1861 for (Value val : iterArgs) 1862 bodyBlock->addArgument(val.getType(), val.getLoc()); 1863 1864 // Create the default terminator if the builder is not provided and if the 1865 // iteration arguments are not provided. Otherwise, leave this to the caller 1866 // because we don't know which values to return from the loop. 1867 if (iterArgs.empty() && !bodyBuilder) { 1868 ensureTerminator(*bodyRegion, builder, result.location); 1869 } else if (bodyBuilder) { 1870 OpBuilder::InsertionGuard guard(builder); 1871 builder.setInsertionPointToStart(bodyBlock); 1872 bodyBuilder(builder, result.location, inductionVar, 1873 bodyBlock->getArguments().drop_front()); 1874 } 1875 } 1876 1877 void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb, 1878 int64_t ub, int64_t step, ValueRange iterArgs, 1879 BodyBuilderFn bodyBuilder) { 1880 auto lbMap = AffineMap::getConstantMap(lb, builder.getContext()); 1881 auto ubMap = AffineMap::getConstantMap(ub, builder.getContext()); 1882 return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs, 1883 bodyBuilder); 1884 } 1885 1886 LogicalResult AffineForOp::verifyRegions() { 1887 // Check that the body defines as single block argument for the induction 1888 // variable. 1889 auto *body = getBody(); 1890 if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex()) 1891 return emitOpError("expected body to have a single index argument for the " 1892 "induction variable"); 1893 1894 // Verify that the bound operands are valid dimension/symbols. 1895 /// Lower bound. 1896 if (getLowerBoundMap().getNumInputs() > 0) 1897 if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundOperands(), 1898 getLowerBoundMap().getNumDims()))) 1899 return failure(); 1900 /// Upper bound. 1901 if (getUpperBoundMap().getNumInputs() > 0) 1902 if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundOperands(), 1903 getUpperBoundMap().getNumDims()))) 1904 return failure(); 1905 1906 unsigned opNumResults = getNumResults(); 1907 if (opNumResults == 0) 1908 return success(); 1909 1910 // If ForOp defines values, check that the number and types of the defined 1911 // values match ForOp initial iter operands and backedge basic block 1912 // arguments. 1913 if (getNumIterOperands() != opNumResults) 1914 return emitOpError( 1915 "mismatch between the number of loop-carried values and results"); 1916 if (getNumRegionIterArgs() != opNumResults) 1917 return emitOpError( 1918 "mismatch between the number of basic block args and results"); 1919 1920 return success(); 1921 } 1922 1923 /// Parse a for operation loop bounds. 1924 static ParseResult parseBound(bool isLower, OperationState &result, 1925 OpAsmParser &p) { 1926 // 'min' / 'max' prefixes are generally syntactic sugar, but are required if 1927 // the map has multiple results. 1928 bool failedToParsedMinMax = 1929 failed(p.parseOptionalKeyword(isLower ? "max" : "min")); 1930 1931 auto &builder = p.getBuilder(); 1932 auto boundAttrStrName = 1933 isLower ? AffineForOp::getLowerBoundMapAttrName(result.name) 1934 : AffineForOp::getUpperBoundMapAttrName(result.name); 1935 1936 // Parse ssa-id as identity map. 1937 SmallVector<OpAsmParser::UnresolvedOperand, 1> boundOpInfos; 1938 if (p.parseOperandList(boundOpInfos)) 1939 return failure(); 1940 1941 if (!boundOpInfos.empty()) { 1942 // Check that only one operand was parsed. 1943 if (boundOpInfos.size() > 1) 1944 return p.emitError(p.getNameLoc(), 1945 "expected only one loop bound operand"); 1946 1947 // TODO: improve error message when SSA value is not of index type. 1948 // Currently it is 'use of value ... expects different type than prior uses' 1949 if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(), 1950 result.operands)) 1951 return failure(); 1952 1953 // Create an identity map using symbol id. This representation is optimized 1954 // for storage. Analysis passes may expand it into a multi-dimensional map 1955 // if desired. 1956 AffineMap map = builder.getSymbolIdentityMap(); 1957 result.addAttribute(boundAttrStrName, AffineMapAttr::get(map)); 1958 return success(); 1959 } 1960 1961 // Get the attribute location. 1962 SMLoc attrLoc = p.getCurrentLocation(); 1963 1964 Attribute boundAttr; 1965 if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrStrName, 1966 result.attributes)) 1967 return failure(); 1968 1969 // Parse full form - affine map followed by dim and symbol list. 1970 if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(boundAttr)) { 1971 unsigned currentNumOperands = result.operands.size(); 1972 unsigned numDims; 1973 if (parseDimAndSymbolList(p, result.operands, numDims)) 1974 return failure(); 1975 1976 auto map = affineMapAttr.getValue(); 1977 if (map.getNumDims() != numDims) 1978 return p.emitError( 1979 p.getNameLoc(), 1980 "dim operand count and affine map dim count must match"); 1981 1982 unsigned numDimAndSymbolOperands = 1983 result.operands.size() - currentNumOperands; 1984 if (numDims + map.getNumSymbols() != numDimAndSymbolOperands) 1985 return p.emitError( 1986 p.getNameLoc(), 1987 "symbol operand count and affine map symbol count must match"); 1988 1989 // If the map has multiple results, make sure that we parsed the min/max 1990 // prefix. 1991 if (map.getNumResults() > 1 && failedToParsedMinMax) { 1992 if (isLower) { 1993 return p.emitError(attrLoc, "lower loop bound affine map with " 1994 "multiple results requires 'max' prefix"); 1995 } 1996 return p.emitError(attrLoc, "upper loop bound affine map with multiple " 1997 "results requires 'min' prefix"); 1998 } 1999 return success(); 2000 } 2001 2002 // Parse custom assembly form. 2003 if (auto integerAttr = llvm::dyn_cast<IntegerAttr>(boundAttr)) { 2004 result.attributes.pop_back(); 2005 result.addAttribute( 2006 boundAttrStrName, 2007 AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt()))); 2008 return success(); 2009 } 2010 2011 return p.emitError( 2012 p.getNameLoc(), 2013 "expected valid affine map representation for loop bounds"); 2014 } 2015 2016 ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) { 2017 auto &builder = parser.getBuilder(); 2018 OpAsmParser::Argument inductionVariable; 2019 inductionVariable.type = builder.getIndexType(); 2020 // Parse the induction variable followed by '='. 2021 if (parser.parseArgument(inductionVariable) || parser.parseEqual()) 2022 return failure(); 2023 2024 // Parse loop bounds. 2025 int64_t numOperands = result.operands.size(); 2026 if (parseBound(/*isLower=*/true, result, parser)) 2027 return failure(); 2028 int64_t numLbOperands = result.operands.size() - numOperands; 2029 if (parser.parseKeyword("to", " between bounds")) 2030 return failure(); 2031 numOperands = result.operands.size(); 2032 if (parseBound(/*isLower=*/false, result, parser)) 2033 return failure(); 2034 int64_t numUbOperands = result.operands.size() - numOperands; 2035 2036 // Parse the optional loop step, we default to 1 if one is not present. 2037 if (parser.parseOptionalKeyword("step")) { 2038 result.addAttribute( 2039 getStepAttrName(result.name), 2040 builder.getIntegerAttr(builder.getIndexType(), /*value=*/1)); 2041 } else { 2042 SMLoc stepLoc = parser.getCurrentLocation(); 2043 IntegerAttr stepAttr; 2044 if (parser.parseAttribute(stepAttr, builder.getIndexType(), 2045 getStepAttrName(result.name).data(), 2046 result.attributes)) 2047 return failure(); 2048 2049 if (stepAttr.getValue().isNegative()) 2050 return parser.emitError( 2051 stepLoc, 2052 "expected step to be representable as a positive signed integer"); 2053 } 2054 2055 // Parse the optional initial iteration arguments. 2056 SmallVector<OpAsmParser::Argument, 4> regionArgs; 2057 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands; 2058 2059 // Induction variable. 2060 regionArgs.push_back(inductionVariable); 2061 2062 if (succeeded(parser.parseOptionalKeyword("iter_args"))) { 2063 // Parse assignment list and results type list. 2064 if (parser.parseAssignmentList(regionArgs, operands) || 2065 parser.parseArrowTypeList(result.types)) 2066 return failure(); 2067 // Resolve input operands. 2068 for (auto argOperandType : 2069 llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) { 2070 Type type = std::get<2>(argOperandType); 2071 std::get<0>(argOperandType).type = type; 2072 if (parser.resolveOperand(std::get<1>(argOperandType), type, 2073 result.operands)) 2074 return failure(); 2075 } 2076 } 2077 2078 result.addAttribute( 2079 getOperandSegmentSizeAttr(), 2080 builder.getDenseI32ArrayAttr({static_cast<int32_t>(numLbOperands), 2081 static_cast<int32_t>(numUbOperands), 2082 static_cast<int32_t>(operands.size())})); 2083 2084 // Parse the body region. 2085 Region *body = result.addRegion(); 2086 if (regionArgs.size() != result.types.size() + 1) 2087 return parser.emitError( 2088 parser.getNameLoc(), 2089 "mismatch between the number of loop-carried values and results"); 2090 if (parser.parseRegion(*body, regionArgs)) 2091 return failure(); 2092 2093 AffineForOp::ensureTerminator(*body, builder, result.location); 2094 2095 // Parse the optional attribute list. 2096 return parser.parseOptionalAttrDict(result.attributes); 2097 } 2098 2099 static void printBound(AffineMapAttr boundMap, 2100 Operation::operand_range boundOperands, 2101 const char *prefix, OpAsmPrinter &p) { 2102 AffineMap map = boundMap.getValue(); 2103 2104 // Check if this bound should be printed using custom assembly form. 2105 // The decision to restrict printing custom assembly form to trivial cases 2106 // comes from the will to roundtrip MLIR binary -> text -> binary in a 2107 // lossless way. 2108 // Therefore, custom assembly form parsing and printing is only supported for 2109 // zero-operand constant maps and single symbol operand identity maps. 2110 if (map.getNumResults() == 1) { 2111 AffineExpr expr = map.getResult(0); 2112 2113 // Print constant bound. 2114 if (map.getNumDims() == 0 && map.getNumSymbols() == 0) { 2115 if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) { 2116 p << constExpr.getValue(); 2117 return; 2118 } 2119 } 2120 2121 // Print bound that consists of a single SSA symbol if the map is over a 2122 // single symbol. 2123 if (map.getNumDims() == 0 && map.getNumSymbols() == 1) { 2124 if (dyn_cast<AffineSymbolExpr>(expr)) { 2125 p.printOperand(*boundOperands.begin()); 2126 return; 2127 } 2128 } 2129 } else { 2130 // Map has multiple results. Print 'min' or 'max' prefix. 2131 p << prefix << ' '; 2132 } 2133 2134 // Print the map and its operands. 2135 p << boundMap; 2136 printDimAndSymbolList(boundOperands.begin(), boundOperands.end(), 2137 map.getNumDims(), p); 2138 } 2139 2140 unsigned AffineForOp::getNumIterOperands() { 2141 AffineMap lbMap = getLowerBoundMapAttr().getValue(); 2142 AffineMap ubMap = getUpperBoundMapAttr().getValue(); 2143 2144 return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs(); 2145 } 2146 2147 std::optional<MutableArrayRef<OpOperand>> 2148 AffineForOp::getYieldedValuesMutable() { 2149 return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable(); 2150 } 2151 2152 void AffineForOp::print(OpAsmPrinter &p) { 2153 p << ' '; 2154 p.printRegionArgument(getBody()->getArgument(0), /*argAttrs=*/{}, 2155 /*omitType=*/true); 2156 p << " = "; 2157 printBound(getLowerBoundMapAttr(), getLowerBoundOperands(), "max", p); 2158 p << " to "; 2159 printBound(getUpperBoundMapAttr(), getUpperBoundOperands(), "min", p); 2160 2161 if (getStepAsInt() != 1) 2162 p << " step " << getStepAsInt(); 2163 2164 bool printBlockTerminators = false; 2165 if (getNumIterOperands() > 0) { 2166 p << " iter_args("; 2167 auto regionArgs = getRegionIterArgs(); 2168 auto operands = getInits(); 2169 2170 llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) { 2171 p << std::get<0>(it) << " = " << std::get<1>(it); 2172 }); 2173 p << ") -> (" << getResultTypes() << ")"; 2174 printBlockTerminators = true; 2175 } 2176 2177 p << ' '; 2178 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, 2179 printBlockTerminators); 2180 p.printOptionalAttrDict( 2181 (*this)->getAttrs(), 2182 /*elidedAttrs=*/{getLowerBoundMapAttrName(getOperation()->getName()), 2183 getUpperBoundMapAttrName(getOperation()->getName()), 2184 getStepAttrName(getOperation()->getName()), 2185 getOperandSegmentSizeAttr()}); 2186 } 2187 2188 /// Fold the constant bounds of a loop. 2189 static LogicalResult foldLoopBounds(AffineForOp forOp) { 2190 auto foldLowerOrUpperBound = [&forOp](bool lower) { 2191 // Check to see if each of the operands is the result of a constant. If 2192 // so, get the value. If not, ignore it. 2193 SmallVector<Attribute, 8> operandConstants; 2194 auto boundOperands = 2195 lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands(); 2196 for (auto operand : boundOperands) { 2197 Attribute operandCst; 2198 matchPattern(operand, m_Constant(&operandCst)); 2199 operandConstants.push_back(operandCst); 2200 } 2201 2202 AffineMap boundMap = 2203 lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap(); 2204 assert(boundMap.getNumResults() >= 1 && 2205 "bound maps should have at least one result"); 2206 SmallVector<Attribute, 4> foldedResults; 2207 if (failed(boundMap.constantFold(operandConstants, foldedResults))) 2208 return failure(); 2209 2210 // Compute the max or min as applicable over the results. 2211 assert(!foldedResults.empty() && "bounds should have at least one result"); 2212 auto maxOrMin = llvm::cast<IntegerAttr>(foldedResults[0]).getValue(); 2213 for (unsigned i = 1, e = foldedResults.size(); i < e; i++) { 2214 auto foldedResult = llvm::cast<IntegerAttr>(foldedResults[i]).getValue(); 2215 maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult) 2216 : llvm::APIntOps::smin(maxOrMin, foldedResult); 2217 } 2218 lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue()) 2219 : forOp.setConstantUpperBound(maxOrMin.getSExtValue()); 2220 return success(); 2221 }; 2222 2223 // Try to fold the lower bound. 2224 bool folded = false; 2225 if (!forOp.hasConstantLowerBound()) 2226 folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true)); 2227 2228 // Try to fold the upper bound. 2229 if (!forOp.hasConstantUpperBound()) 2230 folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false)); 2231 return success(folded); 2232 } 2233 2234 /// Canonicalize the bounds of the given loop. 2235 static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) { 2236 SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands()); 2237 SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands()); 2238 2239 auto lbMap = forOp.getLowerBoundMap(); 2240 auto ubMap = forOp.getUpperBoundMap(); 2241 auto prevLbMap = lbMap; 2242 auto prevUbMap = ubMap; 2243 2244 composeAffineMapAndOperands(&lbMap, &lbOperands); 2245 canonicalizeMapAndOperands(&lbMap, &lbOperands); 2246 simplifyMinOrMaxExprWithOperands(lbMap, lbOperands, /*isMax=*/true); 2247 simplifyMinOrMaxExprWithOperands(ubMap, ubOperands, /*isMax=*/false); 2248 lbMap = removeDuplicateExprs(lbMap); 2249 2250 composeAffineMapAndOperands(&ubMap, &ubOperands); 2251 canonicalizeMapAndOperands(&ubMap, &ubOperands); 2252 ubMap = removeDuplicateExprs(ubMap); 2253 2254 // Any canonicalization change always leads to updated map(s). 2255 if (lbMap == prevLbMap && ubMap == prevUbMap) 2256 return failure(); 2257 2258 if (lbMap != prevLbMap) 2259 forOp.setLowerBound(lbOperands, lbMap); 2260 if (ubMap != prevUbMap) 2261 forOp.setUpperBound(ubOperands, ubMap); 2262 return success(); 2263 } 2264 2265 namespace { 2266 /// Returns constant trip count in trivial cases. 2267 static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) { 2268 int64_t step = forOp.getStepAsInt(); 2269 if (!forOp.hasConstantBounds() || step <= 0) 2270 return std::nullopt; 2271 int64_t lb = forOp.getConstantLowerBound(); 2272 int64_t ub = forOp.getConstantUpperBound(); 2273 return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step; 2274 } 2275 2276 /// This is a pattern to fold trivially empty loop bodies. 2277 /// TODO: This should be moved into the folding hook. 2278 struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> { 2279 using OpRewritePattern<AffineForOp>::OpRewritePattern; 2280 2281 LogicalResult matchAndRewrite(AffineForOp forOp, 2282 PatternRewriter &rewriter) const override { 2283 // Check that the body only contains a yield. 2284 if (!llvm::hasSingleElement(*forOp.getBody())) 2285 return failure(); 2286 if (forOp.getNumResults() == 0) 2287 return success(); 2288 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp); 2289 if (tripCount && *tripCount == 0) { 2290 // The initial values of the iteration arguments would be the op's 2291 // results. 2292 rewriter.replaceOp(forOp, forOp.getInits()); 2293 return success(); 2294 } 2295 SmallVector<Value, 4> replacements; 2296 auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator()); 2297 auto iterArgs = forOp.getRegionIterArgs(); 2298 bool hasValDefinedOutsideLoop = false; 2299 bool iterArgsNotInOrder = false; 2300 for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) { 2301 Value val = yieldOp.getOperand(i); 2302 auto *iterArgIt = llvm::find(iterArgs, val); 2303 if (iterArgIt == iterArgs.end()) { 2304 // `val` is defined outside of the loop. 2305 assert(forOp.isDefinedOutsideOfLoop(val) && 2306 "must be defined outside of the loop"); 2307 hasValDefinedOutsideLoop = true; 2308 replacements.push_back(val); 2309 } else { 2310 unsigned pos = std::distance(iterArgs.begin(), iterArgIt); 2311 if (pos != i) 2312 iterArgsNotInOrder = true; 2313 replacements.push_back(forOp.getInits()[pos]); 2314 } 2315 } 2316 // Bail out when the trip count is unknown and the loop returns any value 2317 // defined outside of the loop or any iterArg out of order. 2318 if (!tripCount.has_value() && 2319 (hasValDefinedOutsideLoop || iterArgsNotInOrder)) 2320 return failure(); 2321 // Bail out when the loop iterates more than once and it returns any iterArg 2322 // out of order. 2323 if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder) 2324 return failure(); 2325 rewriter.replaceOp(forOp, replacements); 2326 return success(); 2327 } 2328 }; 2329 } // namespace 2330 2331 void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results, 2332 MLIRContext *context) { 2333 results.add<AffineForEmptyLoopFolder>(context); 2334 } 2335 2336 OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) { 2337 assert((point.isParent() || point == getRegion()) && "invalid region point"); 2338 2339 // The initial operands map to the loop arguments after the induction 2340 // variable or are forwarded to the results when the trip count is zero. 2341 return getInits(); 2342 } 2343 2344 void AffineForOp::getSuccessorRegions( 2345 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { 2346 assert((point.isParent() || point == getRegion()) && "expected loop region"); 2347 // The loop may typically branch back to its body or to the parent operation. 2348 // If the predecessor is the parent op and the trip count is known to be at 2349 // least one, branch into the body using the iterator arguments. And in cases 2350 // we know the trip count is zero, it can only branch back to its parent. 2351 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this); 2352 if (point.isParent() && tripCount.has_value()) { 2353 if (tripCount.value() > 0) { 2354 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); 2355 return; 2356 } 2357 if (tripCount.value() == 0) { 2358 regions.push_back(RegionSuccessor(getResults())); 2359 return; 2360 } 2361 } 2362 2363 // From the loop body, if the trip count is one, we can only branch back to 2364 // the parent. 2365 if (!point.isParent() && tripCount && *tripCount == 1) { 2366 regions.push_back(RegionSuccessor(getResults())); 2367 return; 2368 } 2369 2370 // In all other cases, the loop may branch back to itself or the parent 2371 // operation. 2372 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); 2373 regions.push_back(RegionSuccessor(getResults())); 2374 } 2375 2376 /// Returns true if the affine.for has zero iterations in trivial cases. 2377 static bool hasTrivialZeroTripCount(AffineForOp op) { 2378 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(op); 2379 return tripCount && *tripCount == 0; 2380 } 2381 2382 LogicalResult AffineForOp::fold(FoldAdaptor adaptor, 2383 SmallVectorImpl<OpFoldResult> &results) { 2384 bool folded = succeeded(foldLoopBounds(*this)); 2385 folded |= succeeded(canonicalizeLoopBounds(*this)); 2386 if (hasTrivialZeroTripCount(*this) && getNumResults() != 0) { 2387 // The initial values of the loop-carried variables (iter_args) are the 2388 // results of the op. But this must be avoided for an affine.for op that 2389 // does not return any results. Since ops that do not return results cannot 2390 // be folded away, we would enter an infinite loop of folds on the same 2391 // affine.for op. 2392 results.assign(getInits().begin(), getInits().end()); 2393 folded = true; 2394 } 2395 return success(folded); 2396 } 2397 2398 AffineBound AffineForOp::getLowerBound() { 2399 return AffineBound(*this, getLowerBoundOperands(), getLowerBoundMap()); 2400 } 2401 2402 AffineBound AffineForOp::getUpperBound() { 2403 return AffineBound(*this, getUpperBoundOperands(), getUpperBoundMap()); 2404 } 2405 2406 void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) { 2407 assert(lbOperands.size() == map.getNumInputs()); 2408 assert(map.getNumResults() >= 1 && "bound map has at least one result"); 2409 getLowerBoundOperandsMutable().assign(lbOperands); 2410 setLowerBoundMap(map); 2411 } 2412 2413 void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) { 2414 assert(ubOperands.size() == map.getNumInputs()); 2415 assert(map.getNumResults() >= 1 && "bound map has at least one result"); 2416 getUpperBoundOperandsMutable().assign(ubOperands); 2417 setUpperBoundMap(map); 2418 } 2419 2420 bool AffineForOp::hasConstantLowerBound() { 2421 return getLowerBoundMap().isSingleConstant(); 2422 } 2423 2424 bool AffineForOp::hasConstantUpperBound() { 2425 return getUpperBoundMap().isSingleConstant(); 2426 } 2427 2428 int64_t AffineForOp::getConstantLowerBound() { 2429 return getLowerBoundMap().getSingleConstantResult(); 2430 } 2431 2432 int64_t AffineForOp::getConstantUpperBound() { 2433 return getUpperBoundMap().getSingleConstantResult(); 2434 } 2435 2436 void AffineForOp::setConstantLowerBound(int64_t value) { 2437 setLowerBound({}, AffineMap::getConstantMap(value, getContext())); 2438 } 2439 2440 void AffineForOp::setConstantUpperBound(int64_t value) { 2441 setUpperBound({}, AffineMap::getConstantMap(value, getContext())); 2442 } 2443 2444 AffineForOp::operand_range AffineForOp::getControlOperands() { 2445 return {operand_begin(), operand_begin() + getLowerBoundOperands().size() + 2446 getUpperBoundOperands().size()}; 2447 } 2448 2449 bool AffineForOp::matchingBoundOperandList() { 2450 auto lbMap = getLowerBoundMap(); 2451 auto ubMap = getUpperBoundMap(); 2452 if (lbMap.getNumDims() != ubMap.getNumDims() || 2453 lbMap.getNumSymbols() != ubMap.getNumSymbols()) 2454 return false; 2455 2456 unsigned numOperands = lbMap.getNumInputs(); 2457 for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) { 2458 // Compare Value 's. 2459 if (getOperand(i) != getOperand(numOperands + i)) 2460 return false; 2461 } 2462 return true; 2463 } 2464 2465 SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; } 2466 2467 std::optional<SmallVector<Value>> AffineForOp::getLoopInductionVars() { 2468 return SmallVector<Value>{getInductionVar()}; 2469 } 2470 2471 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopLowerBounds() { 2472 if (!hasConstantLowerBound()) 2473 return std::nullopt; 2474 OpBuilder b(getContext()); 2475 return SmallVector<OpFoldResult>{ 2476 OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))}; 2477 } 2478 2479 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopSteps() { 2480 OpBuilder b(getContext()); 2481 return SmallVector<OpFoldResult>{ 2482 OpFoldResult(b.getI64IntegerAttr(getStepAsInt()))}; 2483 } 2484 2485 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopUpperBounds() { 2486 if (!hasConstantUpperBound()) 2487 return {}; 2488 OpBuilder b(getContext()); 2489 return SmallVector<OpFoldResult>{ 2490 OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))}; 2491 } 2492 2493 FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields( 2494 RewriterBase &rewriter, ValueRange newInitOperands, 2495 bool replaceInitOperandUsesInLoop, 2496 const NewYieldValuesFn &newYieldValuesFn) { 2497 // Create a new loop before the existing one, with the extra operands. 2498 OpBuilder::InsertionGuard g(rewriter); 2499 rewriter.setInsertionPoint(getOperation()); 2500 auto inits = llvm::to_vector(getInits()); 2501 inits.append(newInitOperands.begin(), newInitOperands.end()); 2502 AffineForOp newLoop = rewriter.create<AffineForOp>( 2503 getLoc(), getLowerBoundOperands(), getLowerBoundMap(), 2504 getUpperBoundOperands(), getUpperBoundMap(), getStepAsInt(), inits); 2505 2506 // Generate the new yield values and append them to the scf.yield operation. 2507 auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator()); 2508 ArrayRef<BlockArgument> newIterArgs = 2509 newLoop.getBody()->getArguments().take_back(newInitOperands.size()); 2510 { 2511 OpBuilder::InsertionGuard g(rewriter); 2512 rewriter.setInsertionPoint(yieldOp); 2513 SmallVector<Value> newYieldedValues = 2514 newYieldValuesFn(rewriter, getLoc(), newIterArgs); 2515 assert(newInitOperands.size() == newYieldedValues.size() && 2516 "expected as many new yield values as new iter operands"); 2517 rewriter.modifyOpInPlace(yieldOp, [&]() { 2518 yieldOp.getOperandsMutable().append(newYieldedValues); 2519 }); 2520 } 2521 2522 // Move the loop body to the new op. 2523 rewriter.mergeBlocks(getBody(), newLoop.getBody(), 2524 newLoop.getBody()->getArguments().take_front( 2525 getBody()->getNumArguments())); 2526 2527 if (replaceInitOperandUsesInLoop) { 2528 // Replace all uses of `newInitOperands` with the corresponding basic block 2529 // arguments. 2530 for (auto it : llvm::zip(newInitOperands, newIterArgs)) { 2531 rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it), 2532 [&](OpOperand &use) { 2533 Operation *user = use.getOwner(); 2534 return newLoop->isProperAncestor(user); 2535 }); 2536 } 2537 } 2538 2539 // Replace the old loop. 2540 rewriter.replaceOp(getOperation(), 2541 newLoop->getResults().take_front(getNumResults())); 2542 return cast<LoopLikeOpInterface>(newLoop.getOperation()); 2543 } 2544 2545 Speculation::Speculatability AffineForOp::getSpeculatability() { 2546 // `affine.for (I = Start; I < End; I += 1)` terminates for all values of 2547 // Start and End. 2548 // 2549 // For Step != 1, the loop may not terminate. We can add more smarts here if 2550 // needed. 2551 return getStepAsInt() == 1 ? Speculation::RecursivelySpeculatable 2552 : Speculation::NotSpeculatable; 2553 } 2554 2555 /// Returns true if the provided value is the induction variable of a 2556 /// AffineForOp. 2557 bool mlir::affine::isAffineForInductionVar(Value val) { 2558 return getForInductionVarOwner(val) != AffineForOp(); 2559 } 2560 2561 bool mlir::affine::isAffineParallelInductionVar(Value val) { 2562 return getAffineParallelInductionVarOwner(val) != nullptr; 2563 } 2564 2565 bool mlir::affine::isAffineInductionVar(Value val) { 2566 return isAffineForInductionVar(val) || isAffineParallelInductionVar(val); 2567 } 2568 2569 AffineForOp mlir::affine::getForInductionVarOwner(Value val) { 2570 auto ivArg = llvm::dyn_cast<BlockArgument>(val); 2571 if (!ivArg || !ivArg.getOwner() || !ivArg.getOwner()->getParent()) 2572 return AffineForOp(); 2573 if (auto forOp = 2574 ivArg.getOwner()->getParent()->getParentOfType<AffineForOp>()) 2575 // Check to make sure `val` is the induction variable, not an iter_arg. 2576 return forOp.getInductionVar() == val ? forOp : AffineForOp(); 2577 return AffineForOp(); 2578 } 2579 2580 AffineParallelOp mlir::affine::getAffineParallelInductionVarOwner(Value val) { 2581 auto ivArg = llvm::dyn_cast<BlockArgument>(val); 2582 if (!ivArg || !ivArg.getOwner()) 2583 return nullptr; 2584 Operation *containingOp = ivArg.getOwner()->getParentOp(); 2585 auto parallelOp = dyn_cast<AffineParallelOp>(containingOp); 2586 if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val)) 2587 return parallelOp; 2588 return nullptr; 2589 } 2590 2591 /// Extracts the induction variables from a list of AffineForOps and returns 2592 /// them. 2593 void mlir::affine::extractForInductionVars(ArrayRef<AffineForOp> forInsts, 2594 SmallVectorImpl<Value> *ivs) { 2595 ivs->reserve(forInsts.size()); 2596 for (auto forInst : forInsts) 2597 ivs->push_back(forInst.getInductionVar()); 2598 } 2599 2600 void mlir::affine::extractInductionVars(ArrayRef<mlir::Operation *> affineOps, 2601 SmallVectorImpl<mlir::Value> &ivs) { 2602 ivs.reserve(affineOps.size()); 2603 for (Operation *op : affineOps) { 2604 // Add constraints from forOp's bounds. 2605 if (auto forOp = dyn_cast<AffineForOp>(op)) 2606 ivs.push_back(forOp.getInductionVar()); 2607 else if (auto parallelOp = dyn_cast<AffineParallelOp>(op)) 2608 for (size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++) 2609 ivs.push_back(parallelOp.getBody()->getArgument(i)); 2610 } 2611 } 2612 2613 /// Builds an affine loop nest, using "loopCreatorFn" to create individual loop 2614 /// operations. 2615 template <typename BoundListTy, typename LoopCreatorTy> 2616 static void buildAffineLoopNestImpl( 2617 OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs, 2618 ArrayRef<int64_t> steps, 2619 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn, 2620 LoopCreatorTy &&loopCreatorFn) { 2621 assert(lbs.size() == ubs.size() && "Mismatch in number of arguments"); 2622 assert(lbs.size() == steps.size() && "Mismatch in number of arguments"); 2623 2624 // If there are no loops to be constructed, construct the body anyway. 2625 OpBuilder::InsertionGuard guard(builder); 2626 if (lbs.empty()) { 2627 if (bodyBuilderFn) 2628 bodyBuilderFn(builder, loc, ValueRange()); 2629 return; 2630 } 2631 2632 // Create the loops iteratively and store the induction variables. 2633 SmallVector<Value, 4> ivs; 2634 ivs.reserve(lbs.size()); 2635 for (unsigned i = 0, e = lbs.size(); i < e; ++i) { 2636 // Callback for creating the loop body, always creates the terminator. 2637 auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, 2638 ValueRange iterArgs) { 2639 ivs.push_back(iv); 2640 // In the innermost loop, call the body builder. 2641 if (i == e - 1 && bodyBuilderFn) { 2642 OpBuilder::InsertionGuard nestedGuard(nestedBuilder); 2643 bodyBuilderFn(nestedBuilder, nestedLoc, ivs); 2644 } 2645 nestedBuilder.create<AffineYieldOp>(nestedLoc); 2646 }; 2647 2648 // Delegate actual loop creation to the callback in order to dispatch 2649 // between constant- and variable-bound loops. 2650 auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody); 2651 builder.setInsertionPointToStart(loop.getBody()); 2652 } 2653 } 2654 2655 /// Creates an affine loop from the bounds known to be constants. 2656 static AffineForOp 2657 buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb, 2658 int64_t ub, int64_t step, 2659 AffineForOp::BodyBuilderFn bodyBuilderFn) { 2660 return builder.create<AffineForOp>(loc, lb, ub, step, 2661 /*iterArgs=*/std::nullopt, bodyBuilderFn); 2662 } 2663 2664 /// Creates an affine loop from the bounds that may or may not be constants. 2665 static AffineForOp 2666 buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub, 2667 int64_t step, 2668 AffineForOp::BodyBuilderFn bodyBuilderFn) { 2669 std::optional<int64_t> lbConst = getConstantIntValue(lb); 2670 std::optional<int64_t> ubConst = getConstantIntValue(ub); 2671 if (lbConst && ubConst) 2672 return buildAffineLoopFromConstants(builder, loc, lbConst.value(), 2673 ubConst.value(), step, bodyBuilderFn); 2674 return builder.create<AffineForOp>(loc, lb, builder.getDimIdentityMap(), ub, 2675 builder.getDimIdentityMap(), step, 2676 /*iterArgs=*/std::nullopt, bodyBuilderFn); 2677 } 2678 2679 void mlir::affine::buildAffineLoopNest( 2680 OpBuilder &builder, Location loc, ArrayRef<int64_t> lbs, 2681 ArrayRef<int64_t> ubs, ArrayRef<int64_t> steps, 2682 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) { 2683 buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn, 2684 buildAffineLoopFromConstants); 2685 } 2686 2687 void mlir::affine::buildAffineLoopNest( 2688 OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, 2689 ArrayRef<int64_t> steps, 2690 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) { 2691 buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn, 2692 buildAffineLoopFromValues); 2693 } 2694 2695 //===----------------------------------------------------------------------===// 2696 // AffineIfOp 2697 //===----------------------------------------------------------------------===// 2698 2699 namespace { 2700 /// Remove else blocks that have nothing other than a zero value yield. 2701 struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> { 2702 using OpRewritePattern<AffineIfOp>::OpRewritePattern; 2703 2704 LogicalResult matchAndRewrite(AffineIfOp ifOp, 2705 PatternRewriter &rewriter) const override { 2706 if (ifOp.getElseRegion().empty() || 2707 !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults()) 2708 return failure(); 2709 2710 rewriter.startOpModification(ifOp); 2711 rewriter.eraseBlock(ifOp.getElseBlock()); 2712 rewriter.finalizeOpModification(ifOp); 2713 return success(); 2714 } 2715 }; 2716 2717 /// Removes affine.if cond if the condition is always true or false in certain 2718 /// trivial cases. Promotes the then/else block in the parent operation block. 2719 struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> { 2720 using OpRewritePattern<AffineIfOp>::OpRewritePattern; 2721 2722 LogicalResult matchAndRewrite(AffineIfOp op, 2723 PatternRewriter &rewriter) const override { 2724 2725 auto isTriviallyFalse = [](IntegerSet iSet) { 2726 return iSet.isEmptyIntegerSet(); 2727 }; 2728 2729 auto isTriviallyTrue = [](IntegerSet iSet) { 2730 return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 && 2731 iSet.getConstraint(0) == 0); 2732 }; 2733 2734 IntegerSet affineIfConditions = op.getIntegerSet(); 2735 Block *blockToMove; 2736 if (isTriviallyFalse(affineIfConditions)) { 2737 // The absence, or equivalently, the emptiness of the else region need not 2738 // be checked when affine.if is returning results because if an affine.if 2739 // operation is returning results, it always has a non-empty else region. 2740 if (op.getNumResults() == 0 && !op.hasElse()) { 2741 // If the else region is absent, or equivalently, empty, remove the 2742 // affine.if operation (which is not returning any results). 2743 rewriter.eraseOp(op); 2744 return success(); 2745 } 2746 blockToMove = op.getElseBlock(); 2747 } else if (isTriviallyTrue(affineIfConditions)) { 2748 blockToMove = op.getThenBlock(); 2749 } else { 2750 return failure(); 2751 } 2752 Operation *blockToMoveTerminator = blockToMove->getTerminator(); 2753 // Promote the "blockToMove" block to the parent operation block between the 2754 // prologue and epilogue of "op". 2755 rewriter.inlineBlockBefore(blockToMove, op); 2756 // Replace the "op" operation with the operands of the 2757 // "blockToMoveTerminator" operation. Note that "blockToMoveTerminator" is 2758 // the affine.yield operation present in the "blockToMove" block. It has no 2759 // operands when affine.if is not returning results and therefore, in that 2760 // case, replaceOp just erases "op". When affine.if is not returning 2761 // results, the affine.yield operation can be omitted. It gets inserted 2762 // implicitly. 2763 rewriter.replaceOp(op, blockToMoveTerminator->getOperands()); 2764 // Erase the "blockToMoveTerminator" operation since it is now in the parent 2765 // operation block, which already has its own terminator. 2766 rewriter.eraseOp(blockToMoveTerminator); 2767 return success(); 2768 } 2769 }; 2770 } // namespace 2771 2772 /// AffineIfOp has two regions -- `then` and `else`. The flow of data should be 2773 /// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp 2774 void AffineIfOp::getSuccessorRegions( 2775 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { 2776 // If the predecessor is an AffineIfOp, then branching into both `then` and 2777 // `else` region is valid. 2778 if (point.isParent()) { 2779 regions.reserve(2); 2780 regions.push_back( 2781 RegionSuccessor(&getThenRegion(), getThenRegion().getArguments())); 2782 // If the "else" region is empty, branch bach into parent. 2783 if (getElseRegion().empty()) { 2784 regions.push_back(getResults()); 2785 } else { 2786 regions.push_back( 2787 RegionSuccessor(&getElseRegion(), getElseRegion().getArguments())); 2788 } 2789 return; 2790 } 2791 2792 // If the predecessor is the `else`/`then` region, then branching into parent 2793 // op is valid. 2794 regions.push_back(RegionSuccessor(getResults())); 2795 } 2796 2797 LogicalResult AffineIfOp::verify() { 2798 // Verify that we have a condition attribute. 2799 // FIXME: This should be specified in the arguments list in ODS. 2800 auto conditionAttr = 2801 (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName()); 2802 if (!conditionAttr) 2803 return emitOpError("requires an integer set attribute named 'condition'"); 2804 2805 // Verify that there are enough operands for the condition. 2806 IntegerSet condition = conditionAttr.getValue(); 2807 if (getNumOperands() != condition.getNumInputs()) 2808 return emitOpError("operand count and condition integer set dimension and " 2809 "symbol count must match"); 2810 2811 // Verify that the operands are valid dimension/symbols. 2812 if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(), 2813 condition.getNumDims()))) 2814 return failure(); 2815 2816 return success(); 2817 } 2818 2819 ParseResult AffineIfOp::parse(OpAsmParser &parser, OperationState &result) { 2820 // Parse the condition attribute set. 2821 IntegerSetAttr conditionAttr; 2822 unsigned numDims; 2823 if (parser.parseAttribute(conditionAttr, 2824 AffineIfOp::getConditionAttrStrName(), 2825 result.attributes) || 2826 parseDimAndSymbolList(parser, result.operands, numDims)) 2827 return failure(); 2828 2829 // Verify the condition operands. 2830 auto set = conditionAttr.getValue(); 2831 if (set.getNumDims() != numDims) 2832 return parser.emitError( 2833 parser.getNameLoc(), 2834 "dim operand count and integer set dim count must match"); 2835 if (numDims + set.getNumSymbols() != result.operands.size()) 2836 return parser.emitError( 2837 parser.getNameLoc(), 2838 "symbol operand count and integer set symbol count must match"); 2839 2840 if (parser.parseOptionalArrowTypeList(result.types)) 2841 return failure(); 2842 2843 // Create the regions for 'then' and 'else'. The latter must be created even 2844 // if it remains empty for the validity of the operation. 2845 result.regions.reserve(2); 2846 Region *thenRegion = result.addRegion(); 2847 Region *elseRegion = result.addRegion(); 2848 2849 // Parse the 'then' region. 2850 if (parser.parseRegion(*thenRegion, {}, {})) 2851 return failure(); 2852 AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(), 2853 result.location); 2854 2855 // If we find an 'else' keyword then parse the 'else' region. 2856 if (!parser.parseOptionalKeyword("else")) { 2857 if (parser.parseRegion(*elseRegion, {}, {})) 2858 return failure(); 2859 AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(), 2860 result.location); 2861 } 2862 2863 // Parse the optional attribute list. 2864 if (parser.parseOptionalAttrDict(result.attributes)) 2865 return failure(); 2866 2867 return success(); 2868 } 2869 2870 void AffineIfOp::print(OpAsmPrinter &p) { 2871 auto conditionAttr = 2872 (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName()); 2873 p << " " << conditionAttr; 2874 printDimAndSymbolList(operand_begin(), operand_end(), 2875 conditionAttr.getValue().getNumDims(), p); 2876 p.printOptionalArrowTypeList(getResultTypes()); 2877 p << ' '; 2878 p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false, 2879 /*printBlockTerminators=*/getNumResults()); 2880 2881 // Print the 'else' regions if it has any blocks. 2882 auto &elseRegion = this->getElseRegion(); 2883 if (!elseRegion.empty()) { 2884 p << " else "; 2885 p.printRegion(elseRegion, 2886 /*printEntryBlockArgs=*/false, 2887 /*printBlockTerminators=*/getNumResults()); 2888 } 2889 2890 // Print the attribute list. 2891 p.printOptionalAttrDict((*this)->getAttrs(), 2892 /*elidedAttrs=*/getConditionAttrStrName()); 2893 } 2894 2895 IntegerSet AffineIfOp::getIntegerSet() { 2896 return (*this) 2897 ->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName()) 2898 .getValue(); 2899 } 2900 2901 void AffineIfOp::setIntegerSet(IntegerSet newSet) { 2902 (*this)->setAttr(getConditionAttrStrName(), IntegerSetAttr::get(newSet)); 2903 } 2904 2905 void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) { 2906 setIntegerSet(set); 2907 (*this)->setOperands(operands); 2908 } 2909 2910 void AffineIfOp::build(OpBuilder &builder, OperationState &result, 2911 TypeRange resultTypes, IntegerSet set, ValueRange args, 2912 bool withElseRegion) { 2913 assert(resultTypes.empty() || withElseRegion); 2914 OpBuilder::InsertionGuard guard(builder); 2915 2916 result.addTypes(resultTypes); 2917 result.addOperands(args); 2918 result.addAttribute(getConditionAttrStrName(), IntegerSetAttr::get(set)); 2919 2920 Region *thenRegion = result.addRegion(); 2921 builder.createBlock(thenRegion); 2922 if (resultTypes.empty()) 2923 AffineIfOp::ensureTerminator(*thenRegion, builder, result.location); 2924 2925 Region *elseRegion = result.addRegion(); 2926 if (withElseRegion) { 2927 builder.createBlock(elseRegion); 2928 if (resultTypes.empty()) 2929 AffineIfOp::ensureTerminator(*elseRegion, builder, result.location); 2930 } 2931 } 2932 2933 void AffineIfOp::build(OpBuilder &builder, OperationState &result, 2934 IntegerSet set, ValueRange args, bool withElseRegion) { 2935 AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args, 2936 withElseRegion); 2937 } 2938 2939 /// Compose any affine.apply ops feeding into `operands` of the integer set 2940 /// `set` by composing the maps of such affine.apply ops with the integer 2941 /// set constraints. 2942 static void composeSetAndOperands(IntegerSet &set, 2943 SmallVectorImpl<Value> &operands) { 2944 // We will simply reuse the API of the map composition by viewing the LHSs of 2945 // the equalities and inequalities of `set` as the affine exprs of an affine 2946 // map. Convert to equivalent map, compose, and convert back to set. 2947 auto map = AffineMap::get(set.getNumDims(), set.getNumSymbols(), 2948 set.getConstraints(), set.getContext()); 2949 // Check if any composition is possible. 2950 if (llvm::none_of(operands, 2951 [](Value v) { return v.getDefiningOp<AffineApplyOp>(); })) 2952 return; 2953 2954 composeAffineMapAndOperands(&map, &operands); 2955 set = IntegerSet::get(map.getNumDims(), map.getNumSymbols(), map.getResults(), 2956 set.getEqFlags()); 2957 } 2958 2959 /// Canonicalize an affine if op's conditional (integer set + operands). 2960 LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) { 2961 auto set = getIntegerSet(); 2962 SmallVector<Value, 4> operands(getOperands()); 2963 composeSetAndOperands(set, operands); 2964 canonicalizeSetAndOperands(&set, &operands); 2965 2966 // Check if the canonicalization or composition led to any change. 2967 if (getIntegerSet() == set && llvm::equal(operands, getOperands())) 2968 return failure(); 2969 2970 setConditional(set, operands); 2971 return success(); 2972 } 2973 2974 void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results, 2975 MLIRContext *context) { 2976 results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context); 2977 } 2978 2979 //===----------------------------------------------------------------------===// 2980 // AffineLoadOp 2981 //===----------------------------------------------------------------------===// 2982 2983 void AffineLoadOp::build(OpBuilder &builder, OperationState &result, 2984 AffineMap map, ValueRange operands) { 2985 assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands"); 2986 result.addOperands(operands); 2987 if (map) 2988 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map)); 2989 auto memrefType = llvm::cast<MemRefType>(operands[0].getType()); 2990 result.types.push_back(memrefType.getElementType()); 2991 } 2992 2993 void AffineLoadOp::build(OpBuilder &builder, OperationState &result, 2994 Value memref, AffineMap map, ValueRange mapOperands) { 2995 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info"); 2996 result.addOperands(memref); 2997 result.addOperands(mapOperands); 2998 auto memrefType = llvm::cast<MemRefType>(memref.getType()); 2999 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map)); 3000 result.types.push_back(memrefType.getElementType()); 3001 } 3002 3003 void AffineLoadOp::build(OpBuilder &builder, OperationState &result, 3004 Value memref, ValueRange indices) { 3005 auto memrefType = llvm::cast<MemRefType>(memref.getType()); 3006 int64_t rank = memrefType.getRank(); 3007 // Create identity map for memrefs with at least one dimension or () -> () 3008 // for zero-dimensional memrefs. 3009 auto map = 3010 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap(); 3011 build(builder, result, memref, map, indices); 3012 } 3013 3014 ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) { 3015 auto &builder = parser.getBuilder(); 3016 auto indexTy = builder.getIndexType(); 3017 3018 MemRefType type; 3019 OpAsmParser::UnresolvedOperand memrefInfo; 3020 AffineMapAttr mapAttr; 3021 SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands; 3022 return failure( 3023 parser.parseOperand(memrefInfo) || 3024 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr, 3025 AffineLoadOp::getMapAttrStrName(), 3026 result.attributes) || 3027 parser.parseOptionalAttrDict(result.attributes) || 3028 parser.parseColonType(type) || 3029 parser.resolveOperand(memrefInfo, type, result.operands) || 3030 parser.resolveOperands(mapOperands, indexTy, result.operands) || 3031 parser.addTypeToList(type.getElementType(), result.types)); 3032 } 3033 3034 void AffineLoadOp::print(OpAsmPrinter &p) { 3035 p << " " << getMemRef() << '['; 3036 if (AffineMapAttr mapAttr = 3037 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName())) 3038 p.printAffineMapOfSSAIds(mapAttr, getMapOperands()); 3039 p << ']'; 3040 p.printOptionalAttrDict((*this)->getAttrs(), 3041 /*elidedAttrs=*/{getMapAttrStrName()}); 3042 p << " : " << getMemRefType(); 3043 } 3044 3045 /// Verify common indexing invariants of affine.load, affine.store, 3046 /// affine.vector_load and affine.vector_store. 3047 static LogicalResult 3048 verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr, 3049 Operation::operand_range mapOperands, 3050 MemRefType memrefType, unsigned numIndexOperands) { 3051 AffineMap map = mapAttr.getValue(); 3052 if (map.getNumResults() != memrefType.getRank()) 3053 return op->emitOpError("affine map num results must equal memref rank"); 3054 if (map.getNumInputs() != numIndexOperands) 3055 return op->emitOpError("expects as many subscripts as affine map inputs"); 3056 3057 Region *scope = getAffineScope(op); 3058 for (auto idx : mapOperands) { 3059 if (!idx.getType().isIndex()) 3060 return op->emitOpError("index to load must have 'index' type"); 3061 if (!isValidAffineIndexOperand(idx, scope)) 3062 return op->emitOpError( 3063 "index must be a valid dimension or symbol identifier"); 3064 } 3065 3066 return success(); 3067 } 3068 3069 LogicalResult AffineLoadOp::verify() { 3070 auto memrefType = getMemRefType(); 3071 if (getType() != memrefType.getElementType()) 3072 return emitOpError("result type must match element type of memref"); 3073 3074 if (failed(verifyMemoryOpIndexing( 3075 getOperation(), 3076 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()), 3077 getMapOperands(), memrefType, 3078 /*numIndexOperands=*/getNumOperands() - 1))) 3079 return failure(); 3080 3081 return success(); 3082 } 3083 3084 void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results, 3085 MLIRContext *context) { 3086 results.add<SimplifyAffineOp<AffineLoadOp>>(context); 3087 } 3088 3089 OpFoldResult AffineLoadOp::fold(FoldAdaptor adaptor) { 3090 /// load(memrefcast) -> load 3091 if (succeeded(memref::foldMemRefCast(*this))) 3092 return getResult(); 3093 3094 // Fold load from a global constant memref. 3095 auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>(); 3096 if (!getGlobalOp) 3097 return {}; 3098 // Get to the memref.global defining the symbol. 3099 auto *symbolTableOp = getGlobalOp->getParentWithTrait<OpTrait::SymbolTable>(); 3100 if (!symbolTableOp) 3101 return {}; 3102 auto global = dyn_cast_or_null<memref::GlobalOp>( 3103 SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.getNameAttr())); 3104 if (!global) 3105 return {}; 3106 3107 // Check if the global memref is a constant. 3108 auto cstAttr = 3109 llvm::dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue()); 3110 if (!cstAttr) 3111 return {}; 3112 // If it's a splat constant, we can fold irrespective of indices. 3113 if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(cstAttr)) 3114 return splatAttr.getSplatValue<Attribute>(); 3115 // Otherwise, we can fold only if we know the indices. 3116 if (!getAffineMap().isConstant()) 3117 return {}; 3118 auto indices = llvm::to_vector<4>( 3119 llvm::map_range(getAffineMap().getConstantResults(), 3120 [](int64_t v) -> uint64_t { return v; })); 3121 return cstAttr.getValues<Attribute>()[indices]; 3122 } 3123 3124 //===----------------------------------------------------------------------===// 3125 // AffineStoreOp 3126 //===----------------------------------------------------------------------===// 3127 3128 void AffineStoreOp::build(OpBuilder &builder, OperationState &result, 3129 Value valueToStore, Value memref, AffineMap map, 3130 ValueRange mapOperands) { 3131 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info"); 3132 result.addOperands(valueToStore); 3133 result.addOperands(memref); 3134 result.addOperands(mapOperands); 3135 result.getOrAddProperties<Properties>().map = AffineMapAttr::get(map); 3136 } 3137 3138 // Use identity map. 3139 void AffineStoreOp::build(OpBuilder &builder, OperationState &result, 3140 Value valueToStore, Value memref, 3141 ValueRange indices) { 3142 auto memrefType = llvm::cast<MemRefType>(memref.getType()); 3143 int64_t rank = memrefType.getRank(); 3144 // Create identity map for memrefs with at least one dimension or () -> () 3145 // for zero-dimensional memrefs. 3146 auto map = 3147 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap(); 3148 build(builder, result, valueToStore, memref, map, indices); 3149 } 3150 3151 ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) { 3152 auto indexTy = parser.getBuilder().getIndexType(); 3153 3154 MemRefType type; 3155 OpAsmParser::UnresolvedOperand storeValueInfo; 3156 OpAsmParser::UnresolvedOperand memrefInfo; 3157 AffineMapAttr mapAttr; 3158 SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands; 3159 return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() || 3160 parser.parseOperand(memrefInfo) || 3161 parser.parseAffineMapOfSSAIds( 3162 mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(), 3163 result.attributes) || 3164 parser.parseOptionalAttrDict(result.attributes) || 3165 parser.parseColonType(type) || 3166 parser.resolveOperand(storeValueInfo, type.getElementType(), 3167 result.operands) || 3168 parser.resolveOperand(memrefInfo, type, result.operands) || 3169 parser.resolveOperands(mapOperands, indexTy, result.operands)); 3170 } 3171 3172 void AffineStoreOp::print(OpAsmPrinter &p) { 3173 p << " " << getValueToStore(); 3174 p << ", " << getMemRef() << '['; 3175 if (AffineMapAttr mapAttr = 3176 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName())) 3177 p.printAffineMapOfSSAIds(mapAttr, getMapOperands()); 3178 p << ']'; 3179 p.printOptionalAttrDict((*this)->getAttrs(), 3180 /*elidedAttrs=*/{getMapAttrStrName()}); 3181 p << " : " << getMemRefType(); 3182 } 3183 3184 LogicalResult AffineStoreOp::verify() { 3185 // The value to store must have the same type as memref element type. 3186 auto memrefType = getMemRefType(); 3187 if (getValueToStore().getType() != memrefType.getElementType()) 3188 return emitOpError( 3189 "value to store must have the same type as memref element type"); 3190 3191 if (failed(verifyMemoryOpIndexing( 3192 getOperation(), 3193 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()), 3194 getMapOperands(), memrefType, 3195 /*numIndexOperands=*/getNumOperands() - 2))) 3196 return failure(); 3197 3198 return success(); 3199 } 3200 3201 void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results, 3202 MLIRContext *context) { 3203 results.add<SimplifyAffineOp<AffineStoreOp>>(context); 3204 } 3205 3206 LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor, 3207 SmallVectorImpl<OpFoldResult> &results) { 3208 /// store(memrefcast) -> store 3209 return memref::foldMemRefCast(*this, getValueToStore()); 3210 } 3211 3212 //===----------------------------------------------------------------------===// 3213 // AffineMinMaxOpBase 3214 //===----------------------------------------------------------------------===// 3215 3216 template <typename T> 3217 static LogicalResult verifyAffineMinMaxOp(T op) { 3218 // Verify that operand count matches affine map dimension and symbol count. 3219 if (op.getNumOperands() != 3220 op.getMap().getNumDims() + op.getMap().getNumSymbols()) 3221 return op.emitOpError( 3222 "operand count and affine map dimension and symbol count must match"); 3223 3224 if (op.getMap().getNumResults() == 0) 3225 return op.emitOpError("affine map expect at least one result"); 3226 return success(); 3227 } 3228 3229 template <typename T> 3230 static void printAffineMinMaxOp(OpAsmPrinter &p, T op) { 3231 p << ' ' << op->getAttr(T::getMapAttrStrName()); 3232 auto operands = op.getOperands(); 3233 unsigned numDims = op.getMap().getNumDims(); 3234 p << '(' << operands.take_front(numDims) << ')'; 3235 3236 if (operands.size() != numDims) 3237 p << '[' << operands.drop_front(numDims) << ']'; 3238 p.printOptionalAttrDict(op->getAttrs(), 3239 /*elidedAttrs=*/{T::getMapAttrStrName()}); 3240 } 3241 3242 template <typename T> 3243 static ParseResult parseAffineMinMaxOp(OpAsmParser &parser, 3244 OperationState &result) { 3245 auto &builder = parser.getBuilder(); 3246 auto indexType = builder.getIndexType(); 3247 SmallVector<OpAsmParser::UnresolvedOperand, 8> dimInfos; 3248 SmallVector<OpAsmParser::UnresolvedOperand, 8> symInfos; 3249 AffineMapAttr mapAttr; 3250 return failure( 3251 parser.parseAttribute(mapAttr, T::getMapAttrStrName(), 3252 result.attributes) || 3253 parser.parseOperandList(dimInfos, OpAsmParser::Delimiter::Paren) || 3254 parser.parseOperandList(symInfos, 3255 OpAsmParser::Delimiter::OptionalSquare) || 3256 parser.parseOptionalAttrDict(result.attributes) || 3257 parser.resolveOperands(dimInfos, indexType, result.operands) || 3258 parser.resolveOperands(symInfos, indexType, result.operands) || 3259 parser.addTypeToList(indexType, result.types)); 3260 } 3261 3262 /// Fold an affine min or max operation with the given operands. The operand 3263 /// list may contain nulls, which are interpreted as the operand not being a 3264 /// constant. 3265 template <typename T> 3266 static OpFoldResult foldMinMaxOp(T op, ArrayRef<Attribute> operands) { 3267 static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value, 3268 "expected affine min or max op"); 3269 3270 // Fold the affine map. 3271 // TODO: Fold more cases: 3272 // min(some_affine, some_affine + constant, ...), etc. 3273 SmallVector<int64_t, 2> results; 3274 auto foldedMap = op.getMap().partialConstantFold(operands, &results); 3275 3276 if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity()) 3277 return op.getOperand(0); 3278 3279 // If some of the map results are not constant, try changing the map in-place. 3280 if (results.empty()) { 3281 // If the map is the same, report that folding did not happen. 3282 if (foldedMap == op.getMap()) 3283 return {}; 3284 op->setAttr("map", AffineMapAttr::get(foldedMap)); 3285 return op.getResult(); 3286 } 3287 3288 // Otherwise, completely fold the op into a constant. 3289 auto resultIt = std::is_same<T, AffineMinOp>::value 3290 ? llvm::min_element(results) 3291 : llvm::max_element(results); 3292 if (resultIt == results.end()) 3293 return {}; 3294 return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt); 3295 } 3296 3297 /// Remove duplicated expressions in affine min/max ops. 3298 template <typename T> 3299 struct DeduplicateAffineMinMaxExpressions : public OpRewritePattern<T> { 3300 using OpRewritePattern<T>::OpRewritePattern; 3301 3302 LogicalResult matchAndRewrite(T affineOp, 3303 PatternRewriter &rewriter) const override { 3304 AffineMap oldMap = affineOp.getAffineMap(); 3305 3306 SmallVector<AffineExpr, 4> newExprs; 3307 for (AffineExpr expr : oldMap.getResults()) { 3308 // This is a linear scan over newExprs, but it should be fine given that 3309 // we typically just have a few expressions per op. 3310 if (!llvm::is_contained(newExprs, expr)) 3311 newExprs.push_back(expr); 3312 } 3313 3314 if (newExprs.size() == oldMap.getNumResults()) 3315 return failure(); 3316 3317 auto newMap = AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), 3318 newExprs, rewriter.getContext()); 3319 rewriter.replaceOpWithNewOp<T>(affineOp, newMap, affineOp.getMapOperands()); 3320 3321 return success(); 3322 } 3323 }; 3324 3325 /// Merge an affine min/max op to its consumers if its consumer is also an 3326 /// affine min/max op. 3327 /// 3328 /// This pattern requires the producer affine min/max op is bound to a 3329 /// dimension/symbol that is used as a standalone expression in the consumer 3330 /// affine op's map. 3331 /// 3332 /// For example, a pattern like the following: 3333 /// 3334 /// %0 = affine.min affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%sym1] 3335 /// %1 = affine.min affine_map<(d0)[s0] -> (s0 + 4, d0)> (%0)[%sym2] 3336 /// 3337 /// Can be turned into: 3338 /// 3339 /// %1 = affine.min affine_map< 3340 /// ()[s0, s1] -> (s0 + 4, s1 + 16, s1 * 8)> ()[%sym2, %sym1] 3341 template <typename T> 3342 struct MergeAffineMinMaxOp : public OpRewritePattern<T> { 3343 using OpRewritePattern<T>::OpRewritePattern; 3344 3345 LogicalResult matchAndRewrite(T affineOp, 3346 PatternRewriter &rewriter) const override { 3347 AffineMap oldMap = affineOp.getAffineMap(); 3348 ValueRange dimOperands = 3349 affineOp.getMapOperands().take_front(oldMap.getNumDims()); 3350 ValueRange symOperands = 3351 affineOp.getMapOperands().take_back(oldMap.getNumSymbols()); 3352 3353 auto newDimOperands = llvm::to_vector<8>(dimOperands); 3354 auto newSymOperands = llvm::to_vector<8>(symOperands); 3355 SmallVector<AffineExpr, 4> newExprs; 3356 SmallVector<T, 4> producerOps; 3357 3358 // Go over each expression to see whether it's a single dimension/symbol 3359 // with the corresponding operand which is the result of another affine 3360 // min/max op. If So it can be merged into this affine op. 3361 for (AffineExpr expr : oldMap.getResults()) { 3362 if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) { 3363 Value symValue = symOperands[symExpr.getPosition()]; 3364 if (auto producerOp = symValue.getDefiningOp<T>()) { 3365 producerOps.push_back(producerOp); 3366 continue; 3367 } 3368 } else if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) { 3369 Value dimValue = dimOperands[dimExpr.getPosition()]; 3370 if (auto producerOp = dimValue.getDefiningOp<T>()) { 3371 producerOps.push_back(producerOp); 3372 continue; 3373 } 3374 } 3375 // For the above cases we will remove the expression by merging the 3376 // producer affine min/max's affine expressions. Otherwise we need to 3377 // keep the existing expression. 3378 newExprs.push_back(expr); 3379 } 3380 3381 if (producerOps.empty()) 3382 return failure(); 3383 3384 unsigned numUsedDims = oldMap.getNumDims(); 3385 unsigned numUsedSyms = oldMap.getNumSymbols(); 3386 3387 // Now go over all producer affine ops and merge their expressions. 3388 for (T producerOp : producerOps) { 3389 AffineMap producerMap = producerOp.getAffineMap(); 3390 unsigned numProducerDims = producerMap.getNumDims(); 3391 unsigned numProducerSyms = producerMap.getNumSymbols(); 3392 3393 // Collect all dimension/symbol values. 3394 ValueRange dimValues = 3395 producerOp.getMapOperands().take_front(numProducerDims); 3396 ValueRange symValues = 3397 producerOp.getMapOperands().take_back(numProducerSyms); 3398 newDimOperands.append(dimValues.begin(), dimValues.end()); 3399 newSymOperands.append(symValues.begin(), symValues.end()); 3400 3401 // For expressions we need to shift to avoid overlap. 3402 for (AffineExpr expr : producerMap.getResults()) { 3403 newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims) 3404 .shiftSymbols(numProducerSyms, numUsedSyms)); 3405 } 3406 3407 numUsedDims += numProducerDims; 3408 numUsedSyms += numProducerSyms; 3409 } 3410 3411 auto newMap = AffineMap::get(numUsedDims, numUsedSyms, newExprs, 3412 rewriter.getContext()); 3413 auto newOperands = 3414 llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands)); 3415 rewriter.replaceOpWithNewOp<T>(affineOp, newMap, newOperands); 3416 3417 return success(); 3418 } 3419 }; 3420 3421 /// Canonicalize the result expression order of an affine map and return success 3422 /// if the order changed. 3423 /// 3424 /// The function flattens the map's affine expressions to coefficient arrays and 3425 /// sorts them in lexicographic order. A coefficient array contains a multiplier 3426 /// for every dimension/symbol and a constant term. The canonicalization fails 3427 /// if a result expression is not pure or if the flattening requires local 3428 /// variables that, unlike dimensions and symbols, have no global order. 3429 static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map) { 3430 SmallVector<SmallVector<int64_t>> flattenedExprs; 3431 for (const AffineExpr &resultExpr : map.getResults()) { 3432 // Fail if the expression is not pure. 3433 if (!resultExpr.isPureAffine()) 3434 return failure(); 3435 3436 SimpleAffineExprFlattener flattener(map.getNumDims(), map.getNumSymbols()); 3437 auto flattenResult = flattener.walkPostOrder(resultExpr); 3438 if (failed(flattenResult)) 3439 return failure(); 3440 3441 // Fail if the flattened expression has local variables. 3442 if (flattener.operandExprStack.back().size() != 3443 map.getNumDims() + map.getNumSymbols() + 1) 3444 return failure(); 3445 3446 flattenedExprs.emplace_back(flattener.operandExprStack.back().begin(), 3447 flattener.operandExprStack.back().end()); 3448 } 3449 3450 // Fail if sorting is not necessary. 3451 if (llvm::is_sorted(flattenedExprs)) 3452 return failure(); 3453 3454 // Reorder the result expressions according to their flattened form. 3455 SmallVector<unsigned> resultPermutation = 3456 llvm::to_vector(llvm::seq<unsigned>(0, map.getNumResults())); 3457 llvm::sort(resultPermutation, [&](unsigned lhs, unsigned rhs) { 3458 return flattenedExprs[lhs] < flattenedExprs[rhs]; 3459 }); 3460 SmallVector<AffineExpr> newExprs; 3461 for (unsigned idx : resultPermutation) 3462 newExprs.push_back(map.getResult(idx)); 3463 3464 map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newExprs, 3465 map.getContext()); 3466 return success(); 3467 } 3468 3469 /// Canonicalize the affine map result expression order of an affine min/max 3470 /// operation. 3471 /// 3472 /// The pattern calls `canonicalizeMapExprAndTermOrder` to order the result 3473 /// expressions and replaces the operation if the order changed. 3474 /// 3475 /// For example, the following operation: 3476 /// 3477 /// %0 = affine.min affine_map<(d0, d1) -> (d0 + d1, d1 + 16, 32)> (%i0, %i1) 3478 /// 3479 /// Turns into: 3480 /// 3481 /// %0 = affine.min affine_map<(d0, d1) -> (32, d1 + 16, d0 + d1)> (%i0, %i1) 3482 template <typename T> 3483 struct CanonicalizeAffineMinMaxOpExprAndTermOrder : public OpRewritePattern<T> { 3484 using OpRewritePattern<T>::OpRewritePattern; 3485 3486 LogicalResult matchAndRewrite(T affineOp, 3487 PatternRewriter &rewriter) const override { 3488 AffineMap map = affineOp.getAffineMap(); 3489 if (failed(canonicalizeMapExprAndTermOrder(map))) 3490 return failure(); 3491 rewriter.replaceOpWithNewOp<T>(affineOp, map, affineOp.getMapOperands()); 3492 return success(); 3493 } 3494 }; 3495 3496 template <typename T> 3497 struct CanonicalizeSingleResultAffineMinMaxOp : public OpRewritePattern<T> { 3498 using OpRewritePattern<T>::OpRewritePattern; 3499 3500 LogicalResult matchAndRewrite(T affineOp, 3501 PatternRewriter &rewriter) const override { 3502 if (affineOp.getMap().getNumResults() != 1) 3503 return failure(); 3504 rewriter.replaceOpWithNewOp<AffineApplyOp>(affineOp, affineOp.getMap(), 3505 affineOp.getOperands()); 3506 return success(); 3507 } 3508 }; 3509 3510 //===----------------------------------------------------------------------===// 3511 // AffineMinOp 3512 //===----------------------------------------------------------------------===// 3513 // 3514 // %0 = affine.min (d0) -> (1000, d0 + 512) (%i0) 3515 // 3516 3517 OpFoldResult AffineMinOp::fold(FoldAdaptor adaptor) { 3518 return foldMinMaxOp(*this, adaptor.getOperands()); 3519 } 3520 3521 void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 3522 MLIRContext *context) { 3523 patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMinOp>, 3524 DeduplicateAffineMinMaxExpressions<AffineMinOp>, 3525 MergeAffineMinMaxOp<AffineMinOp>, SimplifyAffineOp<AffineMinOp>, 3526 CanonicalizeAffineMinMaxOpExprAndTermOrder<AffineMinOp>>( 3527 context); 3528 } 3529 3530 LogicalResult AffineMinOp::verify() { return verifyAffineMinMaxOp(*this); } 3531 3532 ParseResult AffineMinOp::parse(OpAsmParser &parser, OperationState &result) { 3533 return parseAffineMinMaxOp<AffineMinOp>(parser, result); 3534 } 3535 3536 void AffineMinOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); } 3537 3538 //===----------------------------------------------------------------------===// 3539 // AffineMaxOp 3540 //===----------------------------------------------------------------------===// 3541 // 3542 // %0 = affine.max (d0) -> (1000, d0 + 512) (%i0) 3543 // 3544 3545 OpFoldResult AffineMaxOp::fold(FoldAdaptor adaptor) { 3546 return foldMinMaxOp(*this, adaptor.getOperands()); 3547 } 3548 3549 void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 3550 MLIRContext *context) { 3551 patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMaxOp>, 3552 DeduplicateAffineMinMaxExpressions<AffineMaxOp>, 3553 MergeAffineMinMaxOp<AffineMaxOp>, SimplifyAffineOp<AffineMaxOp>, 3554 CanonicalizeAffineMinMaxOpExprAndTermOrder<AffineMaxOp>>( 3555 context); 3556 } 3557 3558 LogicalResult AffineMaxOp::verify() { return verifyAffineMinMaxOp(*this); } 3559 3560 ParseResult AffineMaxOp::parse(OpAsmParser &parser, OperationState &result) { 3561 return parseAffineMinMaxOp<AffineMaxOp>(parser, result); 3562 } 3563 3564 void AffineMaxOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); } 3565 3566 //===----------------------------------------------------------------------===// 3567 // AffinePrefetchOp 3568 //===----------------------------------------------------------------------===// 3569 3570 // 3571 // affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32> 3572 // 3573 ParseResult AffinePrefetchOp::parse(OpAsmParser &parser, 3574 OperationState &result) { 3575 auto &builder = parser.getBuilder(); 3576 auto indexTy = builder.getIndexType(); 3577 3578 MemRefType type; 3579 OpAsmParser::UnresolvedOperand memrefInfo; 3580 IntegerAttr hintInfo; 3581 auto i32Type = parser.getBuilder().getIntegerType(32); 3582 StringRef readOrWrite, cacheType; 3583 3584 AffineMapAttr mapAttr; 3585 SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands; 3586 if (parser.parseOperand(memrefInfo) || 3587 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr, 3588 AffinePrefetchOp::getMapAttrStrName(), 3589 result.attributes) || 3590 parser.parseComma() || parser.parseKeyword(&readOrWrite) || 3591 parser.parseComma() || parser.parseKeyword("locality") || 3592 parser.parseLess() || 3593 parser.parseAttribute(hintInfo, i32Type, 3594 AffinePrefetchOp::getLocalityHintAttrStrName(), 3595 result.attributes) || 3596 parser.parseGreater() || parser.parseComma() || 3597 parser.parseKeyword(&cacheType) || 3598 parser.parseOptionalAttrDict(result.attributes) || 3599 parser.parseColonType(type) || 3600 parser.resolveOperand(memrefInfo, type, result.operands) || 3601 parser.resolveOperands(mapOperands, indexTy, result.operands)) 3602 return failure(); 3603 3604 if (readOrWrite != "read" && readOrWrite != "write") 3605 return parser.emitError(parser.getNameLoc(), 3606 "rw specifier has to be 'read' or 'write'"); 3607 result.addAttribute(AffinePrefetchOp::getIsWriteAttrStrName(), 3608 parser.getBuilder().getBoolAttr(readOrWrite == "write")); 3609 3610 if (cacheType != "data" && cacheType != "instr") 3611 return parser.emitError(parser.getNameLoc(), 3612 "cache type has to be 'data' or 'instr'"); 3613 3614 result.addAttribute(AffinePrefetchOp::getIsDataCacheAttrStrName(), 3615 parser.getBuilder().getBoolAttr(cacheType == "data")); 3616 3617 return success(); 3618 } 3619 3620 void AffinePrefetchOp::print(OpAsmPrinter &p) { 3621 p << " " << getMemref() << '['; 3622 AffineMapAttr mapAttr = 3623 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()); 3624 if (mapAttr) 3625 p.printAffineMapOfSSAIds(mapAttr, getMapOperands()); 3626 p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", " 3627 << "locality<" << getLocalityHint() << ">, " 3628 << (getIsDataCache() ? "data" : "instr"); 3629 p.printOptionalAttrDict( 3630 (*this)->getAttrs(), 3631 /*elidedAttrs=*/{getMapAttrStrName(), getLocalityHintAttrStrName(), 3632 getIsDataCacheAttrStrName(), getIsWriteAttrStrName()}); 3633 p << " : " << getMemRefType(); 3634 } 3635 3636 LogicalResult AffinePrefetchOp::verify() { 3637 auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()); 3638 if (mapAttr) { 3639 AffineMap map = mapAttr.getValue(); 3640 if (map.getNumResults() != getMemRefType().getRank()) 3641 return emitOpError("affine.prefetch affine map num results must equal" 3642 " memref rank"); 3643 if (map.getNumInputs() + 1 != getNumOperands()) 3644 return emitOpError("too few operands"); 3645 } else { 3646 if (getNumOperands() != 1) 3647 return emitOpError("too few operands"); 3648 } 3649 3650 Region *scope = getAffineScope(*this); 3651 for (auto idx : getMapOperands()) { 3652 if (!isValidAffineIndexOperand(idx, scope)) 3653 return emitOpError( 3654 "index must be a valid dimension or symbol identifier"); 3655 } 3656 return success(); 3657 } 3658 3659 void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results, 3660 MLIRContext *context) { 3661 // prefetch(memrefcast) -> prefetch 3662 results.add<SimplifyAffineOp<AffinePrefetchOp>>(context); 3663 } 3664 3665 LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor, 3666 SmallVectorImpl<OpFoldResult> &results) { 3667 /// prefetch(memrefcast) -> prefetch 3668 return memref::foldMemRefCast(*this); 3669 } 3670 3671 //===----------------------------------------------------------------------===// 3672 // AffineParallelOp 3673 //===----------------------------------------------------------------------===// 3674 3675 void AffineParallelOp::build(OpBuilder &builder, OperationState &result, 3676 TypeRange resultTypes, 3677 ArrayRef<arith::AtomicRMWKind> reductions, 3678 ArrayRef<int64_t> ranges) { 3679 SmallVector<AffineMap> lbs(ranges.size(), builder.getConstantAffineMap(0)); 3680 auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) { 3681 return builder.getConstantAffineMap(value); 3682 })); 3683 SmallVector<int64_t> steps(ranges.size(), 1); 3684 build(builder, result, resultTypes, reductions, lbs, /*lbArgs=*/{}, ubs, 3685 /*ubArgs=*/{}, steps); 3686 } 3687 3688 void AffineParallelOp::build(OpBuilder &builder, OperationState &result, 3689 TypeRange resultTypes, 3690 ArrayRef<arith::AtomicRMWKind> reductions, 3691 ArrayRef<AffineMap> lbMaps, ValueRange lbArgs, 3692 ArrayRef<AffineMap> ubMaps, ValueRange ubArgs, 3693 ArrayRef<int64_t> steps) { 3694 assert(llvm::all_of(lbMaps, 3695 [lbMaps](AffineMap m) { 3696 return m.getNumDims() == lbMaps[0].getNumDims() && 3697 m.getNumSymbols() == lbMaps[0].getNumSymbols(); 3698 }) && 3699 "expected all lower bounds maps to have the same number of dimensions " 3700 "and symbols"); 3701 assert(llvm::all_of(ubMaps, 3702 [ubMaps](AffineMap m) { 3703 return m.getNumDims() == ubMaps[0].getNumDims() && 3704 m.getNumSymbols() == ubMaps[0].getNumSymbols(); 3705 }) && 3706 "expected all upper bounds maps to have the same number of dimensions " 3707 "and symbols"); 3708 assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) && 3709 "expected lower bound maps to have as many inputs as lower bound " 3710 "operands"); 3711 assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) && 3712 "expected upper bound maps to have as many inputs as upper bound " 3713 "operands"); 3714 3715 OpBuilder::InsertionGuard guard(builder); 3716 result.addTypes(resultTypes); 3717 3718 // Convert the reductions to integer attributes. 3719 SmallVector<Attribute, 4> reductionAttrs; 3720 for (arith::AtomicRMWKind reduction : reductions) 3721 reductionAttrs.push_back( 3722 builder.getI64IntegerAttr(static_cast<int64_t>(reduction))); 3723 result.addAttribute(getReductionsAttrStrName(), 3724 builder.getArrayAttr(reductionAttrs)); 3725 3726 // Concatenates maps defined in the same input space (same dimensions and 3727 // symbols), assumes there is at least one map. 3728 auto concatMapsSameInput = [&builder](ArrayRef<AffineMap> maps, 3729 SmallVectorImpl<int32_t> &groups) { 3730 if (maps.empty()) 3731 return AffineMap::get(builder.getContext()); 3732 SmallVector<AffineExpr> exprs; 3733 groups.reserve(groups.size() + maps.size()); 3734 exprs.reserve(maps.size()); 3735 for (AffineMap m : maps) { 3736 llvm::append_range(exprs, m.getResults()); 3737 groups.push_back(m.getNumResults()); 3738 } 3739 return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs, 3740 maps[0].getContext()); 3741 }; 3742 3743 // Set up the bounds. 3744 SmallVector<int32_t> lbGroups, ubGroups; 3745 AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups); 3746 AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups); 3747 result.addAttribute(getLowerBoundsMapAttrStrName(), 3748 AffineMapAttr::get(lbMap)); 3749 result.addAttribute(getLowerBoundsGroupsAttrStrName(), 3750 builder.getI32TensorAttr(lbGroups)); 3751 result.addAttribute(getUpperBoundsMapAttrStrName(), 3752 AffineMapAttr::get(ubMap)); 3753 result.addAttribute(getUpperBoundsGroupsAttrStrName(), 3754 builder.getI32TensorAttr(ubGroups)); 3755 result.addAttribute(getStepsAttrStrName(), builder.getI64ArrayAttr(steps)); 3756 result.addOperands(lbArgs); 3757 result.addOperands(ubArgs); 3758 3759 // Create a region and a block for the body. 3760 auto *bodyRegion = result.addRegion(); 3761 Block *body = builder.createBlock(bodyRegion); 3762 3763 // Add all the block arguments. 3764 for (unsigned i = 0, e = steps.size(); i < e; ++i) 3765 body->addArgument(IndexType::get(builder.getContext()), result.location); 3766 if (resultTypes.empty()) 3767 ensureTerminator(*bodyRegion, builder, result.location); 3768 } 3769 3770 SmallVector<Region *> AffineParallelOp::getLoopRegions() { 3771 return {&getRegion()}; 3772 } 3773 3774 unsigned AffineParallelOp::getNumDims() { return getSteps().size(); } 3775 3776 AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() { 3777 return getOperands().take_front(getLowerBoundsMap().getNumInputs()); 3778 } 3779 3780 AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() { 3781 return getOperands().drop_front(getLowerBoundsMap().getNumInputs()); 3782 } 3783 3784 AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) { 3785 auto values = getLowerBoundsGroups().getValues<int32_t>(); 3786 unsigned start = 0; 3787 for (unsigned i = 0; i < pos; ++i) 3788 start += values[i]; 3789 return getLowerBoundsMap().getSliceMap(start, values[pos]); 3790 } 3791 3792 AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) { 3793 auto values = getUpperBoundsGroups().getValues<int32_t>(); 3794 unsigned start = 0; 3795 for (unsigned i = 0; i < pos; ++i) 3796 start += values[i]; 3797 return getUpperBoundsMap().getSliceMap(start, values[pos]); 3798 } 3799 3800 AffineValueMap AffineParallelOp::getLowerBoundsValueMap() { 3801 return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands()); 3802 } 3803 3804 AffineValueMap AffineParallelOp::getUpperBoundsValueMap() { 3805 return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands()); 3806 } 3807 3808 std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() { 3809 if (hasMinMaxBounds()) 3810 return std::nullopt; 3811 3812 // Try to convert all the ranges to constant expressions. 3813 SmallVector<int64_t, 8> out; 3814 AffineValueMap rangesValueMap; 3815 AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(), 3816 &rangesValueMap); 3817 out.reserve(rangesValueMap.getNumResults()); 3818 for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) { 3819 auto expr = rangesValueMap.getResult(i); 3820 auto cst = dyn_cast<AffineConstantExpr>(expr); 3821 if (!cst) 3822 return std::nullopt; 3823 out.push_back(cst.getValue()); 3824 } 3825 return out; 3826 } 3827 3828 Block *AffineParallelOp::getBody() { return &getRegion().front(); } 3829 3830 OpBuilder AffineParallelOp::getBodyBuilder() { 3831 return OpBuilder(getBody(), std::prev(getBody()->end())); 3832 } 3833 3834 void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) { 3835 assert(lbOperands.size() == map.getNumInputs() && 3836 "operands to map must match number of inputs"); 3837 3838 auto ubOperands = getUpperBoundsOperands(); 3839 3840 SmallVector<Value, 4> newOperands(lbOperands); 3841 newOperands.append(ubOperands.begin(), ubOperands.end()); 3842 (*this)->setOperands(newOperands); 3843 3844 setLowerBoundsMapAttr(AffineMapAttr::get(map)); 3845 } 3846 3847 void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) { 3848 assert(ubOperands.size() == map.getNumInputs() && 3849 "operands to map must match number of inputs"); 3850 3851 SmallVector<Value, 4> newOperands(getLowerBoundsOperands()); 3852 newOperands.append(ubOperands.begin(), ubOperands.end()); 3853 (*this)->setOperands(newOperands); 3854 3855 setUpperBoundsMapAttr(AffineMapAttr::get(map)); 3856 } 3857 3858 void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) { 3859 setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps)); 3860 } 3861 3862 // check whether resultType match op or not in affine.parallel 3863 static bool isResultTypeMatchAtomicRMWKind(Type resultType, 3864 arith::AtomicRMWKind op) { 3865 switch (op) { 3866 case arith::AtomicRMWKind::addf: 3867 return isa<FloatType>(resultType); 3868 case arith::AtomicRMWKind::addi: 3869 return isa<IntegerType>(resultType); 3870 case arith::AtomicRMWKind::assign: 3871 return true; 3872 case arith::AtomicRMWKind::mulf: 3873 return isa<FloatType>(resultType); 3874 case arith::AtomicRMWKind::muli: 3875 return isa<IntegerType>(resultType); 3876 case arith::AtomicRMWKind::maximumf: 3877 return isa<FloatType>(resultType); 3878 case arith::AtomicRMWKind::minimumf: 3879 return isa<FloatType>(resultType); 3880 case arith::AtomicRMWKind::maxs: { 3881 auto intType = llvm::dyn_cast<IntegerType>(resultType); 3882 return intType && intType.isSigned(); 3883 } 3884 case arith::AtomicRMWKind::mins: { 3885 auto intType = llvm::dyn_cast<IntegerType>(resultType); 3886 return intType && intType.isSigned(); 3887 } 3888 case arith::AtomicRMWKind::maxu: { 3889 auto intType = llvm::dyn_cast<IntegerType>(resultType); 3890 return intType && intType.isUnsigned(); 3891 } 3892 case arith::AtomicRMWKind::minu: { 3893 auto intType = llvm::dyn_cast<IntegerType>(resultType); 3894 return intType && intType.isUnsigned(); 3895 } 3896 case arith::AtomicRMWKind::ori: 3897 return isa<IntegerType>(resultType); 3898 case arith::AtomicRMWKind::andi: 3899 return isa<IntegerType>(resultType); 3900 default: 3901 return false; 3902 } 3903 } 3904 3905 LogicalResult AffineParallelOp::verify() { 3906 auto numDims = getNumDims(); 3907 if (getLowerBoundsGroups().getNumElements() != numDims || 3908 getUpperBoundsGroups().getNumElements() != numDims || 3909 getSteps().size() != numDims || getBody()->getNumArguments() != numDims) { 3910 return emitOpError() << "the number of region arguments (" 3911 << getBody()->getNumArguments() 3912 << ") and the number of map groups for lower (" 3913 << getLowerBoundsGroups().getNumElements() 3914 << ") and upper bound (" 3915 << getUpperBoundsGroups().getNumElements() 3916 << "), and the number of steps (" << getSteps().size() 3917 << ") must all match"; 3918 } 3919 3920 unsigned expectedNumLBResults = 0; 3921 for (APInt v : getLowerBoundsGroups()) 3922 expectedNumLBResults += v.getZExtValue(); 3923 if (expectedNumLBResults != getLowerBoundsMap().getNumResults()) 3924 return emitOpError() << "expected lower bounds map to have " 3925 << expectedNumLBResults << " results"; 3926 unsigned expectedNumUBResults = 0; 3927 for (APInt v : getUpperBoundsGroups()) 3928 expectedNumUBResults += v.getZExtValue(); 3929 if (expectedNumUBResults != getUpperBoundsMap().getNumResults()) 3930 return emitOpError() << "expected upper bounds map to have " 3931 << expectedNumUBResults << " results"; 3932 3933 if (getReductions().size() != getNumResults()) 3934 return emitOpError("a reduction must be specified for each output"); 3935 3936 // Verify reduction ops are all valid and each result type matches reduction 3937 // ops 3938 for (auto it : llvm::enumerate((getReductions()))) { 3939 Attribute attr = it.value(); 3940 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr); 3941 if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt())) 3942 return emitOpError("invalid reduction attribute"); 3943 auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value(); 3944 if (!isResultTypeMatchAtomicRMWKind(getResult(it.index()).getType(), kind)) 3945 return emitOpError("result type cannot match reduction attribute"); 3946 } 3947 3948 // Verify that the bound operands are valid dimension/symbols. 3949 /// Lower bounds. 3950 if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundsOperands(), 3951 getLowerBoundsMap().getNumDims()))) 3952 return failure(); 3953 /// Upper bounds. 3954 if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundsOperands(), 3955 getUpperBoundsMap().getNumDims()))) 3956 return failure(); 3957 return success(); 3958 } 3959 3960 LogicalResult AffineValueMap::canonicalize() { 3961 SmallVector<Value, 4> newOperands{operands}; 3962 auto newMap = getAffineMap(); 3963 composeAffineMapAndOperands(&newMap, &newOperands); 3964 if (newMap == getAffineMap() && newOperands == operands) 3965 return failure(); 3966 reset(newMap, newOperands); 3967 return success(); 3968 } 3969 3970 /// Canonicalize the bounds of the given loop. 3971 static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) { 3972 AffineValueMap lb = op.getLowerBoundsValueMap(); 3973 bool lbCanonicalized = succeeded(lb.canonicalize()); 3974 3975 AffineValueMap ub = op.getUpperBoundsValueMap(); 3976 bool ubCanonicalized = succeeded(ub.canonicalize()); 3977 3978 // Any canonicalization change always leads to updated map(s). 3979 if (!lbCanonicalized && !ubCanonicalized) 3980 return failure(); 3981 3982 if (lbCanonicalized) 3983 op.setLowerBounds(lb.getOperands(), lb.getAffineMap()); 3984 if (ubCanonicalized) 3985 op.setUpperBounds(ub.getOperands(), ub.getAffineMap()); 3986 3987 return success(); 3988 } 3989 3990 LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor, 3991 SmallVectorImpl<OpFoldResult> &results) { 3992 return canonicalizeLoopBounds(*this); 3993 } 3994 3995 /// Prints a lower(upper) bound of an affine parallel loop with max(min) 3996 /// conditions in it. `mapAttr` is a flat list of affine expressions and `group` 3997 /// identifies which of the those expressions form max/min groups. `operands` 3998 /// are the SSA values of dimensions and symbols and `keyword` is either "min" 3999 /// or "max". 4000 static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr, 4001 DenseIntElementsAttr group, ValueRange operands, 4002 StringRef keyword) { 4003 AffineMap map = mapAttr.getValue(); 4004 unsigned numDims = map.getNumDims(); 4005 ValueRange dimOperands = operands.take_front(numDims); 4006 ValueRange symOperands = operands.drop_front(numDims); 4007 unsigned start = 0; 4008 for (llvm::APInt groupSize : group) { 4009 if (start != 0) 4010 p << ", "; 4011 4012 unsigned size = groupSize.getZExtValue(); 4013 if (size == 1) { 4014 p.printAffineExprOfSSAIds(map.getResult(start), dimOperands, symOperands); 4015 ++start; 4016 } else { 4017 p << keyword << '('; 4018 AffineMap submap = map.getSliceMap(start, size); 4019 p.printAffineMapOfSSAIds(AffineMapAttr::get(submap), operands); 4020 p << ')'; 4021 start += size; 4022 } 4023 } 4024 } 4025 4026 void AffineParallelOp::print(OpAsmPrinter &p) { 4027 p << " (" << getBody()->getArguments() << ") = ("; 4028 printMinMaxBound(p, getLowerBoundsMapAttr(), getLowerBoundsGroupsAttr(), 4029 getLowerBoundsOperands(), "max"); 4030 p << ") to ("; 4031 printMinMaxBound(p, getUpperBoundsMapAttr(), getUpperBoundsGroupsAttr(), 4032 getUpperBoundsOperands(), "min"); 4033 p << ')'; 4034 SmallVector<int64_t, 8> steps = getSteps(); 4035 bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; }); 4036 if (!elideSteps) { 4037 p << " step ("; 4038 llvm::interleaveComma(steps, p); 4039 p << ')'; 4040 } 4041 if (getNumResults()) { 4042 p << " reduce ("; 4043 llvm::interleaveComma(getReductions(), p, [&](auto &attr) { 4044 arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind( 4045 llvm::cast<IntegerAttr>(attr).getInt()); 4046 p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\""; 4047 }); 4048 p << ") -> (" << getResultTypes() << ")"; 4049 } 4050 4051 p << ' '; 4052 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, 4053 /*printBlockTerminators=*/getNumResults()); 4054 p.printOptionalAttrDict( 4055 (*this)->getAttrs(), 4056 /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrStrName(), 4057 AffineParallelOp::getLowerBoundsMapAttrStrName(), 4058 AffineParallelOp::getLowerBoundsGroupsAttrStrName(), 4059 AffineParallelOp::getUpperBoundsMapAttrStrName(), 4060 AffineParallelOp::getUpperBoundsGroupsAttrStrName(), 4061 AffineParallelOp::getStepsAttrStrName()}); 4062 } 4063 4064 /// Given a list of lists of parsed operands, populates `uniqueOperands` with 4065 /// unique operands. Also populates `replacements with affine expressions of 4066 /// `kind` that can be used to update affine maps previously accepting a 4067 /// `operands` to accept `uniqueOperands` instead. 4068 static ParseResult deduplicateAndResolveOperands( 4069 OpAsmParser &parser, 4070 ArrayRef<SmallVector<OpAsmParser::UnresolvedOperand>> operands, 4071 SmallVectorImpl<Value> &uniqueOperands, 4072 SmallVectorImpl<AffineExpr> &replacements, AffineExprKind kind) { 4073 assert((kind == AffineExprKind::DimId || kind == AffineExprKind::SymbolId) && 4074 "expected operands to be dim or symbol expression"); 4075 4076 Type indexType = parser.getBuilder().getIndexType(); 4077 for (const auto &list : operands) { 4078 SmallVector<Value> valueOperands; 4079 if (parser.resolveOperands(list, indexType, valueOperands)) 4080 return failure(); 4081 for (Value operand : valueOperands) { 4082 unsigned pos = std::distance(uniqueOperands.begin(), 4083 llvm::find(uniqueOperands, operand)); 4084 if (pos == uniqueOperands.size()) 4085 uniqueOperands.push_back(operand); 4086 replacements.push_back( 4087 kind == AffineExprKind::DimId 4088 ? getAffineDimExpr(pos, parser.getContext()) 4089 : getAffineSymbolExpr(pos, parser.getContext())); 4090 } 4091 } 4092 return success(); 4093 } 4094 4095 namespace { 4096 enum class MinMaxKind { Min, Max }; 4097 } // namespace 4098 4099 /// Parses an affine map that can contain a min/max for groups of its results, 4100 /// e.g., max(expr-1, expr-2), expr-3, max(expr-4, expr-5, expr-6). Populates 4101 /// `result` attributes with the map (flat list of expressions) and the grouping 4102 /// (list of integers that specify how many expressions to put into each 4103 /// min/max) attributes. Deduplicates repeated operands. 4104 /// 4105 /// parallel-bound ::= `(` parallel-group-list `)` 4106 /// parallel-group-list ::= parallel-group (`,` parallel-group-list)? 4107 /// parallel-group ::= simple-group | min-max-group 4108 /// simple-group ::= expr-of-ssa-ids 4109 /// min-max-group ::= ( `min` | `max` ) `(` expr-of-ssa-ids-list `)` 4110 /// expr-of-ssa-ids-list ::= expr-of-ssa-ids (`,` expr-of-ssa-id-list)? 4111 /// 4112 /// Examples: 4113 /// (%0, min(%1 + %2, %3), %4, min(%5 floordiv 32, %6)) 4114 /// (%0, max(%1 - 2 * %2)) 4115 static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser, 4116 OperationState &result, 4117 MinMaxKind kind) { 4118 // Using `const` not `constexpr` below to workaround a MSVC optimizer bug, 4119 // see: https://reviews.llvm.org/D134227#3821753 4120 const llvm::StringLiteral tmpAttrStrName = "__pseudo_bound_map"; 4121 4122 StringRef mapName = kind == MinMaxKind::Min 4123 ? AffineParallelOp::getUpperBoundsMapAttrStrName() 4124 : AffineParallelOp::getLowerBoundsMapAttrStrName(); 4125 StringRef groupsName = 4126 kind == MinMaxKind::Min 4127 ? AffineParallelOp::getUpperBoundsGroupsAttrStrName() 4128 : AffineParallelOp::getLowerBoundsGroupsAttrStrName(); 4129 4130 if (failed(parser.parseLParen())) 4131 return failure(); 4132 4133 if (succeeded(parser.parseOptionalRParen())) { 4134 result.addAttribute( 4135 mapName, AffineMapAttr::get(parser.getBuilder().getEmptyAffineMap())); 4136 result.addAttribute(groupsName, parser.getBuilder().getI32TensorAttr({})); 4137 return success(); 4138 } 4139 4140 SmallVector<AffineExpr> flatExprs; 4141 SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> flatDimOperands; 4142 SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> flatSymOperands; 4143 SmallVector<int32_t> numMapsPerGroup; 4144 SmallVector<OpAsmParser::UnresolvedOperand> mapOperands; 4145 auto parseOperands = [&]() { 4146 if (succeeded(parser.parseOptionalKeyword( 4147 kind == MinMaxKind::Min ? "min" : "max"))) { 4148 mapOperands.clear(); 4149 AffineMapAttr map; 4150 if (failed(parser.parseAffineMapOfSSAIds(mapOperands, map, tmpAttrStrName, 4151 result.attributes, 4152 OpAsmParser::Delimiter::Paren))) 4153 return failure(); 4154 result.attributes.erase(tmpAttrStrName); 4155 llvm::append_range(flatExprs, map.getValue().getResults()); 4156 auto operandsRef = llvm::ArrayRef(mapOperands); 4157 auto dimsRef = operandsRef.take_front(map.getValue().getNumDims()); 4158 SmallVector<OpAsmParser::UnresolvedOperand> dims(dimsRef); 4159 auto symsRef = operandsRef.drop_front(map.getValue().getNumDims()); 4160 SmallVector<OpAsmParser::UnresolvedOperand> syms(symsRef); 4161 flatDimOperands.append(map.getValue().getNumResults(), dims); 4162 flatSymOperands.append(map.getValue().getNumResults(), syms); 4163 numMapsPerGroup.push_back(map.getValue().getNumResults()); 4164 } else { 4165 if (failed(parser.parseAffineExprOfSSAIds(flatDimOperands.emplace_back(), 4166 flatSymOperands.emplace_back(), 4167 flatExprs.emplace_back()))) 4168 return failure(); 4169 numMapsPerGroup.push_back(1); 4170 } 4171 return success(); 4172 }; 4173 if (parser.parseCommaSeparatedList(parseOperands) || parser.parseRParen()) 4174 return failure(); 4175 4176 unsigned totalNumDims = 0; 4177 unsigned totalNumSyms = 0; 4178 for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) { 4179 unsigned numDims = flatDimOperands[i].size(); 4180 unsigned numSyms = flatSymOperands[i].size(); 4181 flatExprs[i] = flatExprs[i] 4182 .shiftDims(numDims, totalNumDims) 4183 .shiftSymbols(numSyms, totalNumSyms); 4184 totalNumDims += numDims; 4185 totalNumSyms += numSyms; 4186 } 4187 4188 // Deduplicate map operands. 4189 SmallVector<Value> dimOperands, symOperands; 4190 SmallVector<AffineExpr> dimRplacements, symRepacements; 4191 if (deduplicateAndResolveOperands(parser, flatDimOperands, dimOperands, 4192 dimRplacements, AffineExprKind::DimId) || 4193 deduplicateAndResolveOperands(parser, flatSymOperands, symOperands, 4194 symRepacements, AffineExprKind::SymbolId)) 4195 return failure(); 4196 4197 result.operands.append(dimOperands.begin(), dimOperands.end()); 4198 result.operands.append(symOperands.begin(), symOperands.end()); 4199 4200 Builder &builder = parser.getBuilder(); 4201 auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs, 4202 parser.getContext()); 4203 flatMap = flatMap.replaceDimsAndSymbols( 4204 dimRplacements, symRepacements, dimOperands.size(), symOperands.size()); 4205 4206 result.addAttribute(mapName, AffineMapAttr::get(flatMap)); 4207 result.addAttribute(groupsName, builder.getI32TensorAttr(numMapsPerGroup)); 4208 return success(); 4209 } 4210 4211 // 4212 // operation ::= `affine.parallel` `(` ssa-ids `)` `=` parallel-bound 4213 // `to` parallel-bound steps? region attr-dict? 4214 // steps ::= `steps` `(` integer-literals `)` 4215 // 4216 ParseResult AffineParallelOp::parse(OpAsmParser &parser, 4217 OperationState &result) { 4218 auto &builder = parser.getBuilder(); 4219 auto indexType = builder.getIndexType(); 4220 SmallVector<OpAsmParser::Argument, 4> ivs; 4221 if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) || 4222 parser.parseEqual() || 4223 parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) || 4224 parser.parseKeyword("to") || 4225 parseAffineMapWithMinMax(parser, result, MinMaxKind::Min)) 4226 return failure(); 4227 4228 AffineMapAttr stepsMapAttr; 4229 NamedAttrList stepsAttrs; 4230 SmallVector<OpAsmParser::UnresolvedOperand, 4> stepsMapOperands; 4231 if (failed(parser.parseOptionalKeyword("step"))) { 4232 SmallVector<int64_t, 4> steps(ivs.size(), 1); 4233 result.addAttribute(AffineParallelOp::getStepsAttrStrName(), 4234 builder.getI64ArrayAttr(steps)); 4235 } else { 4236 if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr, 4237 AffineParallelOp::getStepsAttrStrName(), 4238 stepsAttrs, 4239 OpAsmParser::Delimiter::Paren)) 4240 return failure(); 4241 4242 // Convert steps from an AffineMap into an I64ArrayAttr. 4243 SmallVector<int64_t, 4> steps; 4244 auto stepsMap = stepsMapAttr.getValue(); 4245 for (const auto &result : stepsMap.getResults()) { 4246 auto constExpr = dyn_cast<AffineConstantExpr>(result); 4247 if (!constExpr) 4248 return parser.emitError(parser.getNameLoc(), 4249 "steps must be constant integers"); 4250 steps.push_back(constExpr.getValue()); 4251 } 4252 result.addAttribute(AffineParallelOp::getStepsAttrStrName(), 4253 builder.getI64ArrayAttr(steps)); 4254 } 4255 4256 // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the 4257 // quoted strings are a member of the enum AtomicRMWKind. 4258 SmallVector<Attribute, 4> reductions; 4259 if (succeeded(parser.parseOptionalKeyword("reduce"))) { 4260 if (parser.parseLParen()) 4261 return failure(); 4262 auto parseAttributes = [&]() -> ParseResult { 4263 // Parse a single quoted string via the attribute parsing, and then 4264 // verify it is a member of the enum and convert to it's integer 4265 // representation. 4266 StringAttr attrVal; 4267 NamedAttrList attrStorage; 4268 auto loc = parser.getCurrentLocation(); 4269 if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce", 4270 attrStorage)) 4271 return failure(); 4272 std::optional<arith::AtomicRMWKind> reduction = 4273 arith::symbolizeAtomicRMWKind(attrVal.getValue()); 4274 if (!reduction) 4275 return parser.emitError(loc, "invalid reduction value: ") << attrVal; 4276 reductions.push_back( 4277 builder.getI64IntegerAttr(static_cast<int64_t>(reduction.value()))); 4278 // While we keep getting commas, keep parsing. 4279 return success(); 4280 }; 4281 if (parser.parseCommaSeparatedList(parseAttributes) || parser.parseRParen()) 4282 return failure(); 4283 } 4284 result.addAttribute(AffineParallelOp::getReductionsAttrStrName(), 4285 builder.getArrayAttr(reductions)); 4286 4287 // Parse return types of reductions (if any) 4288 if (parser.parseOptionalArrowTypeList(result.types)) 4289 return failure(); 4290 4291 // Now parse the body. 4292 Region *body = result.addRegion(); 4293 for (auto &iv : ivs) 4294 iv.type = indexType; 4295 if (parser.parseRegion(*body, ivs) || 4296 parser.parseOptionalAttrDict(result.attributes)) 4297 return failure(); 4298 4299 // Add a terminator if none was parsed. 4300 AffineParallelOp::ensureTerminator(*body, builder, result.location); 4301 return success(); 4302 } 4303 4304 //===----------------------------------------------------------------------===// 4305 // AffineYieldOp 4306 //===----------------------------------------------------------------------===// 4307 4308 LogicalResult AffineYieldOp::verify() { 4309 auto *parentOp = (*this)->getParentOp(); 4310 auto results = parentOp->getResults(); 4311 auto operands = getOperands(); 4312 4313 if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp)) 4314 return emitOpError() << "only terminates affine.if/for/parallel regions"; 4315 if (parentOp->getNumResults() != getNumOperands()) 4316 return emitOpError() << "parent of yield must have same number of " 4317 "results as the yield operands"; 4318 for (auto it : llvm::zip(results, operands)) { 4319 if (std::get<0>(it).getType() != std::get<1>(it).getType()) 4320 return emitOpError() << "types mismatch between yield op and its parent"; 4321 } 4322 4323 return success(); 4324 } 4325 4326 //===----------------------------------------------------------------------===// 4327 // AffineVectorLoadOp 4328 //===----------------------------------------------------------------------===// 4329 4330 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result, 4331 VectorType resultType, AffineMap map, 4332 ValueRange operands) { 4333 assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands"); 4334 result.addOperands(operands); 4335 if (map) 4336 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map)); 4337 result.types.push_back(resultType); 4338 } 4339 4340 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result, 4341 VectorType resultType, Value memref, 4342 AffineMap map, ValueRange mapOperands) { 4343 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info"); 4344 result.addOperands(memref); 4345 result.addOperands(mapOperands); 4346 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map)); 4347 result.types.push_back(resultType); 4348 } 4349 4350 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result, 4351 VectorType resultType, Value memref, 4352 ValueRange indices) { 4353 auto memrefType = llvm::cast<MemRefType>(memref.getType()); 4354 int64_t rank = memrefType.getRank(); 4355 // Create identity map for memrefs with at least one dimension or () -> () 4356 // for zero-dimensional memrefs. 4357 auto map = 4358 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap(); 4359 build(builder, result, resultType, memref, map, indices); 4360 } 4361 4362 void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results, 4363 MLIRContext *context) { 4364 results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context); 4365 } 4366 4367 ParseResult AffineVectorLoadOp::parse(OpAsmParser &parser, 4368 OperationState &result) { 4369 auto &builder = parser.getBuilder(); 4370 auto indexTy = builder.getIndexType(); 4371 4372 MemRefType memrefType; 4373 VectorType resultType; 4374 OpAsmParser::UnresolvedOperand memrefInfo; 4375 AffineMapAttr mapAttr; 4376 SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands; 4377 return failure( 4378 parser.parseOperand(memrefInfo) || 4379 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr, 4380 AffineVectorLoadOp::getMapAttrStrName(), 4381 result.attributes) || 4382 parser.parseOptionalAttrDict(result.attributes) || 4383 parser.parseColonType(memrefType) || parser.parseComma() || 4384 parser.parseType(resultType) || 4385 parser.resolveOperand(memrefInfo, memrefType, result.operands) || 4386 parser.resolveOperands(mapOperands, indexTy, result.operands) || 4387 parser.addTypeToList(resultType, result.types)); 4388 } 4389 4390 void AffineVectorLoadOp::print(OpAsmPrinter &p) { 4391 p << " " << getMemRef() << '['; 4392 if (AffineMapAttr mapAttr = 4393 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName())) 4394 p.printAffineMapOfSSAIds(mapAttr, getMapOperands()); 4395 p << ']'; 4396 p.printOptionalAttrDict((*this)->getAttrs(), 4397 /*elidedAttrs=*/{getMapAttrStrName()}); 4398 p << " : " << getMemRefType() << ", " << getType(); 4399 } 4400 4401 /// Verify common invariants of affine.vector_load and affine.vector_store. 4402 static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, 4403 VectorType vectorType) { 4404 // Check that memref and vector element types match. 4405 if (memrefType.getElementType() != vectorType.getElementType()) 4406 return op->emitOpError( 4407 "requires memref and vector types of the same elemental type"); 4408 return success(); 4409 } 4410 4411 LogicalResult AffineVectorLoadOp::verify() { 4412 MemRefType memrefType = getMemRefType(); 4413 if (failed(verifyMemoryOpIndexing( 4414 getOperation(), 4415 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()), 4416 getMapOperands(), memrefType, 4417 /*numIndexOperands=*/getNumOperands() - 1))) 4418 return failure(); 4419 4420 if (failed(verifyVectorMemoryOp(getOperation(), memrefType, getVectorType()))) 4421 return failure(); 4422 4423 return success(); 4424 } 4425 4426 //===----------------------------------------------------------------------===// 4427 // AffineVectorStoreOp 4428 //===----------------------------------------------------------------------===// 4429 4430 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result, 4431 Value valueToStore, Value memref, AffineMap map, 4432 ValueRange mapOperands) { 4433 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info"); 4434 result.addOperands(valueToStore); 4435 result.addOperands(memref); 4436 result.addOperands(mapOperands); 4437 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map)); 4438 } 4439 4440 // Use identity map. 4441 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result, 4442 Value valueToStore, Value memref, 4443 ValueRange indices) { 4444 auto memrefType = llvm::cast<MemRefType>(memref.getType()); 4445 int64_t rank = memrefType.getRank(); 4446 // Create identity map for memrefs with at least one dimension or () -> () 4447 // for zero-dimensional memrefs. 4448 auto map = 4449 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap(); 4450 build(builder, result, valueToStore, memref, map, indices); 4451 } 4452 void AffineVectorStoreOp::getCanonicalizationPatterns( 4453 RewritePatternSet &results, MLIRContext *context) { 4454 results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context); 4455 } 4456 4457 ParseResult AffineVectorStoreOp::parse(OpAsmParser &parser, 4458 OperationState &result) { 4459 auto indexTy = parser.getBuilder().getIndexType(); 4460 4461 MemRefType memrefType; 4462 VectorType resultType; 4463 OpAsmParser::UnresolvedOperand storeValueInfo; 4464 OpAsmParser::UnresolvedOperand memrefInfo; 4465 AffineMapAttr mapAttr; 4466 SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands; 4467 return failure( 4468 parser.parseOperand(storeValueInfo) || parser.parseComma() || 4469 parser.parseOperand(memrefInfo) || 4470 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr, 4471 AffineVectorStoreOp::getMapAttrStrName(), 4472 result.attributes) || 4473 parser.parseOptionalAttrDict(result.attributes) || 4474 parser.parseColonType(memrefType) || parser.parseComma() || 4475 parser.parseType(resultType) || 4476 parser.resolveOperand(storeValueInfo, resultType, result.operands) || 4477 parser.resolveOperand(memrefInfo, memrefType, result.operands) || 4478 parser.resolveOperands(mapOperands, indexTy, result.operands)); 4479 } 4480 4481 void AffineVectorStoreOp::print(OpAsmPrinter &p) { 4482 p << " " << getValueToStore(); 4483 p << ", " << getMemRef() << '['; 4484 if (AffineMapAttr mapAttr = 4485 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName())) 4486 p.printAffineMapOfSSAIds(mapAttr, getMapOperands()); 4487 p << ']'; 4488 p.printOptionalAttrDict((*this)->getAttrs(), 4489 /*elidedAttrs=*/{getMapAttrStrName()}); 4490 p << " : " << getMemRefType() << ", " << getValueToStore().getType(); 4491 } 4492 4493 LogicalResult AffineVectorStoreOp::verify() { 4494 MemRefType memrefType = getMemRefType(); 4495 if (failed(verifyMemoryOpIndexing( 4496 *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()), 4497 getMapOperands(), memrefType, 4498 /*numIndexOperands=*/getNumOperands() - 2))) 4499 return failure(); 4500 4501 if (failed(verifyVectorMemoryOp(*this, memrefType, getVectorType()))) 4502 return failure(); 4503 4504 return success(); 4505 } 4506 4507 //===----------------------------------------------------------------------===// 4508 // DelinearizeIndexOp 4509 //===----------------------------------------------------------------------===// 4510 4511 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder, 4512 OperationState &odsState, 4513 Value linearIndex, ValueRange dynamicBasis, 4514 ArrayRef<int64_t> staticBasis, 4515 bool hasOuterBound) { 4516 SmallVector<Type> returnTypes(hasOuterBound ? staticBasis.size() 4517 : staticBasis.size() + 1, 4518 linearIndex.getType()); 4519 build(odsBuilder, odsState, returnTypes, linearIndex, dynamicBasis, 4520 staticBasis); 4521 } 4522 4523 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder, 4524 OperationState &odsState, 4525 Value linearIndex, ValueRange basis, 4526 bool hasOuterBound) { 4527 if (hasOuterBound && !basis.empty() && basis.front() == nullptr) { 4528 hasOuterBound = false; 4529 basis = basis.drop_front(); 4530 } 4531 SmallVector<Value> dynamicBasis; 4532 SmallVector<int64_t> staticBasis; 4533 dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis, 4534 staticBasis); 4535 build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis, 4536 hasOuterBound); 4537 } 4538 4539 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder, 4540 OperationState &odsState, 4541 Value linearIndex, 4542 ArrayRef<OpFoldResult> basis, 4543 bool hasOuterBound) { 4544 if (hasOuterBound && !basis.empty() && basis.front() == OpFoldResult()) { 4545 hasOuterBound = false; 4546 basis = basis.drop_front(); 4547 } 4548 SmallVector<Value> dynamicBasis; 4549 SmallVector<int64_t> staticBasis; 4550 dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis); 4551 build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis, 4552 hasOuterBound); 4553 } 4554 4555 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder, 4556 OperationState &odsState, 4557 Value linearIndex, ArrayRef<int64_t> basis, 4558 bool hasOuterBound) { 4559 build(odsBuilder, odsState, linearIndex, ValueRange{}, basis, hasOuterBound); 4560 } 4561 4562 LogicalResult AffineDelinearizeIndexOp::verify() { 4563 ArrayRef<int64_t> staticBasis = getStaticBasis(); 4564 if (getNumResults() != staticBasis.size() && 4565 getNumResults() != staticBasis.size() + 1) 4566 return emitOpError("should return an index for each basis element and up " 4567 "to one extra index"); 4568 4569 auto dynamicMarkersCount = llvm::count_if(staticBasis, ShapedType::isDynamic); 4570 if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size()) 4571 return emitOpError( 4572 "mismatch between dynamic and static basis (kDynamic marker but no " 4573 "corresponding dynamic basis entry) -- this can only happen due to an " 4574 "incorrect fold/rewrite"); 4575 4576 if (!llvm::all_of(staticBasis, [](int64_t v) { 4577 return v > 0 || ShapedType::isDynamic(v); 4578 })) 4579 return emitOpError("no basis element may be statically non-positive"); 4580 4581 return success(); 4582 } 4583 4584 /// Given mixed basis of affine.delinearize_index/linearize_index replace 4585 /// constant SSA values with the constant integer value and return the new 4586 /// static basis. In case no such candidate for replacement exists, this utility 4587 /// returns std::nullopt. 4588 static std::optional<SmallVector<int64_t>> 4589 foldCstValueToCstAttrBasis(ArrayRef<OpFoldResult> mixedBasis, 4590 MutableOperandRange mutableDynamicBasis, 4591 ArrayRef<Attribute> dynamicBasis) { 4592 uint64_t dynamicBasisIndex = 0; 4593 for (OpFoldResult basis : dynamicBasis) { 4594 if (basis) { 4595 mutableDynamicBasis.erase(dynamicBasisIndex); 4596 } else { 4597 ++dynamicBasisIndex; 4598 } 4599 } 4600 4601 // No constant SSA value exists. 4602 if (dynamicBasisIndex == dynamicBasis.size()) 4603 return std::nullopt; 4604 4605 SmallVector<int64_t> staticBasis; 4606 for (OpFoldResult basis : mixedBasis) { 4607 std::optional<int64_t> basisVal = getConstantIntValue(basis); 4608 if (!basisVal) 4609 staticBasis.push_back(ShapedType::kDynamic); 4610 else 4611 staticBasis.push_back(*basisVal); 4612 } 4613 4614 return staticBasis; 4615 } 4616 4617 LogicalResult 4618 AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor, 4619 SmallVectorImpl<OpFoldResult> &result) { 4620 std::optional<SmallVector<int64_t>> maybeStaticBasis = 4621 foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(), 4622 adaptor.getDynamicBasis()); 4623 if (maybeStaticBasis) { 4624 setStaticBasis(*maybeStaticBasis); 4625 return success(); 4626 } 4627 // If we won't be doing any division or modulo (no basis or the one basis 4628 // element is purely advisory), simply return the input value. 4629 if (getNumResults() == 1) { 4630 result.push_back(getLinearIndex()); 4631 return success(); 4632 } 4633 4634 if (adaptor.getLinearIndex() == nullptr) 4635 return failure(); 4636 4637 if (!adaptor.getDynamicBasis().empty()) 4638 return failure(); 4639 4640 int64_t highPart = cast<IntegerAttr>(adaptor.getLinearIndex()).getInt(); 4641 Type attrType = getLinearIndex().getType(); 4642 4643 ArrayRef<int64_t> staticBasis = getStaticBasis(); 4644 if (hasOuterBound()) 4645 staticBasis = staticBasis.drop_front(); 4646 for (int64_t modulus : llvm::reverse(staticBasis)) { 4647 result.push_back(IntegerAttr::get(attrType, llvm::mod(highPart, modulus))); 4648 highPart = llvm::divideFloorSigned(highPart, modulus); 4649 } 4650 result.push_back(IntegerAttr::get(attrType, highPart)); 4651 std::reverse(result.begin(), result.end()); 4652 return success(); 4653 } 4654 4655 SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getEffectiveBasis() { 4656 OpBuilder builder(getContext()); 4657 if (hasOuterBound()) { 4658 if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic) 4659 return getMixedValues(getStaticBasis().drop_front(), 4660 getDynamicBasis().drop_front(), builder); 4661 4662 return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(), 4663 builder); 4664 } 4665 4666 return getMixedValues(getStaticBasis(), getDynamicBasis(), builder); 4667 } 4668 4669 SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getPaddedBasis() { 4670 SmallVector<OpFoldResult> ret = getMixedBasis(); 4671 if (!hasOuterBound()) 4672 ret.insert(ret.begin(), OpFoldResult()); 4673 return ret; 4674 } 4675 4676 namespace { 4677 4678 // Drops delinearization indices that correspond to unit-extent basis 4679 struct DropUnitExtentBasis 4680 : public OpRewritePattern<affine::AffineDelinearizeIndexOp> { 4681 using OpRewritePattern::OpRewritePattern; 4682 4683 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp, 4684 PatternRewriter &rewriter) const override { 4685 SmallVector<Value> replacements(delinearizeOp->getNumResults(), nullptr); 4686 std::optional<Value> zero = std::nullopt; 4687 Location loc = delinearizeOp->getLoc(); 4688 auto getZero = [&]() -> Value { 4689 if (!zero) 4690 zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); 4691 return zero.value(); 4692 }; 4693 4694 // Replace all indices corresponding to unit-extent basis with 0. 4695 // Remaining basis can be used to get a new `affine.delinearize_index` op. 4696 SmallVector<OpFoldResult> newBasis; 4697 for (auto [index, basis] : 4698 llvm::enumerate(delinearizeOp.getPaddedBasis())) { 4699 std::optional<int64_t> basisVal = 4700 basis ? getConstantIntValue(basis) : std::nullopt; 4701 if (basisVal && *basisVal == 1) 4702 replacements[index] = getZero(); 4703 else 4704 newBasis.push_back(basis); 4705 } 4706 4707 if (newBasis.size() == delinearizeOp.getNumResults()) 4708 return rewriter.notifyMatchFailure(delinearizeOp, 4709 "no unit basis elements"); 4710 4711 if (!newBasis.empty()) { 4712 // Will drop the leading nullptr from `basis` if there was no outer bound. 4713 auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>( 4714 loc, delinearizeOp.getLinearIndex(), newBasis); 4715 int newIndex = 0; 4716 // Map back the new delinearized indices to the values they replace. 4717 for (auto &replacement : replacements) { 4718 if (replacement) 4719 continue; 4720 replacement = newDelinearizeOp->getResult(newIndex++); 4721 } 4722 } 4723 4724 rewriter.replaceOp(delinearizeOp, replacements); 4725 return success(); 4726 } 4727 }; 4728 4729 /// If a `affine.delinearize_index`'s input is a `affine.linearize_index 4730 /// disjoint` and the two operations end with the same basis elements, 4731 /// cancel those parts of the operations out because they are inverses 4732 /// of each other. 4733 /// 4734 /// If the operations have the same basis, cancel them entirely. 4735 /// 4736 /// The `disjoint` flag is needed on the `affine.linearize_index` because 4737 /// otherwise, there is no guarantee that the inputs to the linearization are 4738 /// in-bounds the way the outputs of the delinearization would be. 4739 struct CancelDelinearizeOfLinearizeDisjointExactTail 4740 : public OpRewritePattern<affine::AffineDelinearizeIndexOp> { 4741 using OpRewritePattern::OpRewritePattern; 4742 4743 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp, 4744 PatternRewriter &rewriter) const override { 4745 auto linearizeOp = delinearizeOp.getLinearIndex() 4746 .getDefiningOp<affine::AffineLinearizeIndexOp>(); 4747 if (!linearizeOp) 4748 return rewriter.notifyMatchFailure(delinearizeOp, 4749 "index doesn't come from linearize"); 4750 4751 if (!linearizeOp.getDisjoint()) 4752 return rewriter.notifyMatchFailure(linearizeOp, "not disjoint"); 4753 4754 ValueRange linearizeIns = linearizeOp.getMultiIndex(); 4755 // Note: we use the full basis so we don't lose outer bounds later. 4756 SmallVector<OpFoldResult> linearizeBasis = linearizeOp.getMixedBasis(); 4757 SmallVector<OpFoldResult> delinearizeBasis = delinearizeOp.getMixedBasis(); 4758 size_t numMatches = 0; 4759 for (auto [linSize, delinSize] : llvm::zip( 4760 llvm::reverse(linearizeBasis), llvm::reverse(delinearizeBasis))) { 4761 if (linSize != delinSize) 4762 break; 4763 ++numMatches; 4764 } 4765 4766 if (numMatches == 0) 4767 return rewriter.notifyMatchFailure( 4768 delinearizeOp, "final basis element doesn't match linearize"); 4769 4770 // The easy case: everything lines up and the basis match sup completely. 4771 if (numMatches == linearizeBasis.size() && 4772 numMatches == delinearizeBasis.size() && 4773 linearizeIns.size() == delinearizeOp.getNumResults()) { 4774 rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex()); 4775 return success(); 4776 } 4777 4778 Value newLinearize = rewriter.create<affine::AffineLinearizeIndexOp>( 4779 linearizeOp.getLoc(), linearizeIns.drop_back(numMatches), 4780 ArrayRef<OpFoldResult>{linearizeBasis}.drop_back(numMatches), 4781 linearizeOp.getDisjoint()); 4782 auto newDelinearize = rewriter.create<affine::AffineDelinearizeIndexOp>( 4783 delinearizeOp.getLoc(), newLinearize, 4784 ArrayRef<OpFoldResult>{delinearizeBasis}.drop_back(numMatches), 4785 delinearizeOp.hasOuterBound()); 4786 SmallVector<Value> mergedResults(newDelinearize.getResults()); 4787 mergedResults.append(linearizeIns.take_back(numMatches).begin(), 4788 linearizeIns.take_back(numMatches).end()); 4789 rewriter.replaceOp(delinearizeOp, mergedResults); 4790 return success(); 4791 } 4792 }; 4793 4794 /// If the input to a delinearization is a disjoint linearization, and the 4795 /// last k > 1 components of the delinearization basis multiply to the 4796 /// last component of the linearization basis, break the linearization and 4797 /// delinearization into two parts, peeling off the last input to linearization. 4798 /// 4799 /// For example: 4800 /// %0 = affine.linearize_index [%z, %y, %x] by (3, 2, 32) : index 4801 /// %1:4 = affine.delinearize_index %0 by (2, 3, 8, 4) : index, ... 4802 /// becomes 4803 /// %0 = affine.linearize_index [%z, %y] by (3, 2) : index 4804 /// %1:2 = affine.delinearize_index %0 by (2, 3) : index 4805 /// %2:2 = affine.delinearize_index %x by (8, 4) : index 4806 /// where the original %1:4 is replaced by %1:2 ++ %2:2 4807 struct SplitDelinearizeSpanningLastLinearizeArg final 4808 : OpRewritePattern<affine::AffineDelinearizeIndexOp> { 4809 using OpRewritePattern::OpRewritePattern; 4810 4811 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp, 4812 PatternRewriter &rewriter) const override { 4813 auto linearizeOp = delinearizeOp.getLinearIndex() 4814 .getDefiningOp<affine::AffineLinearizeIndexOp>(); 4815 if (!linearizeOp) 4816 return rewriter.notifyMatchFailure(delinearizeOp, 4817 "index doesn't come from linearize"); 4818 4819 if (!linearizeOp.getDisjoint()) 4820 return rewriter.notifyMatchFailure(linearizeOp, 4821 "linearize isn't disjoint"); 4822 4823 int64_t target = linearizeOp.getStaticBasis().back(); 4824 if (ShapedType::isDynamic(target)) 4825 return rewriter.notifyMatchFailure( 4826 linearizeOp, "linearize ends with dynamic basis value"); 4827 4828 int64_t sizeToSplit = 1; 4829 size_t elemsToSplit = 0; 4830 ArrayRef<int64_t> basis = delinearizeOp.getStaticBasis(); 4831 for (int64_t basisElem : llvm::reverse(basis)) { 4832 if (ShapedType::isDynamic(basisElem)) 4833 return rewriter.notifyMatchFailure( 4834 delinearizeOp, "dynamic basis element while scanning for split"); 4835 sizeToSplit *= basisElem; 4836 elemsToSplit += 1; 4837 4838 if (sizeToSplit > target) 4839 return rewriter.notifyMatchFailure(delinearizeOp, 4840 "overshot last argument size"); 4841 if (sizeToSplit == target) 4842 break; 4843 } 4844 4845 if (sizeToSplit < target) 4846 return rewriter.notifyMatchFailure( 4847 delinearizeOp, "product of known basis elements doesn't exceed last " 4848 "linearize argument"); 4849 4850 if (elemsToSplit < 2) 4851 return rewriter.notifyMatchFailure( 4852 delinearizeOp, 4853 "need at least two elements to form the basis product"); 4854 4855 Value linearizeWithoutBack = 4856 rewriter.create<affine::AffineLinearizeIndexOp>( 4857 linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(), 4858 linearizeOp.getDynamicBasis(), 4859 linearizeOp.getStaticBasis().drop_back(), 4860 linearizeOp.getDisjoint()); 4861 auto delinearizeWithoutSplitPart = 4862 rewriter.create<affine::AffineDelinearizeIndexOp>( 4863 delinearizeOp.getLoc(), linearizeWithoutBack, 4864 delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit), 4865 delinearizeOp.hasOuterBound()); 4866 auto delinearizeBack = rewriter.create<affine::AffineDelinearizeIndexOp>( 4867 delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(), 4868 basis.take_back(elemsToSplit), /*hasOuterBound=*/true); 4869 SmallVector<Value> results = llvm::to_vector( 4870 llvm::concat<Value>(delinearizeWithoutSplitPart.getResults(), 4871 delinearizeBack.getResults())); 4872 rewriter.replaceOp(delinearizeOp, results); 4873 4874 return success(); 4875 } 4876 }; 4877 } // namespace 4878 4879 void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns( 4880 RewritePatternSet &patterns, MLIRContext *context) { 4881 patterns 4882 .insert<CancelDelinearizeOfLinearizeDisjointExactTail, 4883 DropUnitExtentBasis, SplitDelinearizeSpanningLastLinearizeArg>( 4884 context); 4885 } 4886 4887 //===----------------------------------------------------------------------===// 4888 // LinearizeIndexOp 4889 //===----------------------------------------------------------------------===// 4890 4891 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder, 4892 OperationState &odsState, 4893 ValueRange multiIndex, ValueRange basis, 4894 bool disjoint) { 4895 if (!basis.empty() && basis.front() == Value()) 4896 basis = basis.drop_front(); 4897 SmallVector<Value> dynamicBasis; 4898 SmallVector<int64_t> staticBasis; 4899 dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis, 4900 staticBasis); 4901 build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint); 4902 } 4903 4904 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder, 4905 OperationState &odsState, 4906 ValueRange multiIndex, 4907 ArrayRef<OpFoldResult> basis, 4908 bool disjoint) { 4909 if (!basis.empty() && basis.front() == OpFoldResult()) 4910 basis = basis.drop_front(); 4911 SmallVector<Value> dynamicBasis; 4912 SmallVector<int64_t> staticBasis; 4913 dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis); 4914 build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint); 4915 } 4916 4917 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder, 4918 OperationState &odsState, 4919 ValueRange multiIndex, 4920 ArrayRef<int64_t> basis, bool disjoint) { 4921 build(odsBuilder, odsState, multiIndex, ValueRange{}, basis, disjoint); 4922 } 4923 4924 LogicalResult AffineLinearizeIndexOp::verify() { 4925 size_t numIndexes = getMultiIndex().size(); 4926 size_t numBasisElems = getStaticBasis().size(); 4927 if (numIndexes != numBasisElems && numIndexes != numBasisElems + 1) 4928 return emitOpError("should be passed a basis element for each index except " 4929 "possibly the first"); 4930 4931 auto dynamicMarkersCount = 4932 llvm::count_if(getStaticBasis(), ShapedType::isDynamic); 4933 if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size()) 4934 return emitOpError( 4935 "mismatch between dynamic and static basis (kDynamic marker but no " 4936 "corresponding dynamic basis entry) -- this can only happen due to an " 4937 "incorrect fold/rewrite"); 4938 4939 return success(); 4940 } 4941 4942 OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) { 4943 std::optional<SmallVector<int64_t>> maybeStaticBasis = 4944 foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(), 4945 adaptor.getDynamicBasis()); 4946 if (maybeStaticBasis) { 4947 setStaticBasis(*maybeStaticBasis); 4948 return getResult(); 4949 } 4950 // No indices linearizes to zero. 4951 if (getMultiIndex().empty()) 4952 return IntegerAttr::get(getResult().getType(), 0); 4953 4954 // One single index linearizes to itself. 4955 if (getMultiIndex().size() == 1) 4956 return getMultiIndex().front(); 4957 4958 if (llvm::any_of(adaptor.getMultiIndex(), 4959 [](Attribute a) { return a == nullptr; })) 4960 return nullptr; 4961 4962 if (!adaptor.getDynamicBasis().empty()) 4963 return nullptr; 4964 4965 int64_t result = 0; 4966 int64_t stride = 1; 4967 for (auto [length, indexAttr] : 4968 llvm::zip_first(llvm::reverse(getStaticBasis()), 4969 llvm::reverse(adaptor.getMultiIndex()))) { 4970 result = result + cast<IntegerAttr>(indexAttr).getInt() * stride; 4971 stride = stride * length; 4972 } 4973 // Handle the index element with no basis element. 4974 if (!hasOuterBound()) 4975 result = 4976 result + 4977 cast<IntegerAttr>(adaptor.getMultiIndex().front()).getInt() * stride; 4978 4979 return IntegerAttr::get(getResult().getType(), result); 4980 } 4981 4982 SmallVector<OpFoldResult> AffineLinearizeIndexOp::getEffectiveBasis() { 4983 OpBuilder builder(getContext()); 4984 if (hasOuterBound()) { 4985 if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic) 4986 return getMixedValues(getStaticBasis().drop_front(), 4987 getDynamicBasis().drop_front(), builder); 4988 4989 return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(), 4990 builder); 4991 } 4992 4993 return getMixedValues(getStaticBasis(), getDynamicBasis(), builder); 4994 } 4995 4996 SmallVector<OpFoldResult> AffineLinearizeIndexOp::getPaddedBasis() { 4997 SmallVector<OpFoldResult> ret = getMixedBasis(); 4998 if (!hasOuterBound()) 4999 ret.insert(ret.begin(), OpFoldResult()); 5000 return ret; 5001 } 5002 5003 namespace { 5004 /// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1, 5005 /// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c, 5006 /// %...d)`. 5007 5008 /// Note that `disjoint` is required here, because, without it, we could have 5009 /// `affine.linearize_index [%...a, %c64, %...b] by (%...c, 1, %...d)` 5010 /// is a valid operation where the `%c64` cannot be trivially dropped. 5011 /// 5012 /// Alternatively, if `%x` in the above is a known constant 0, remove it even if 5013 /// the operation isn't asserted to be `disjoint`. 5014 struct DropLinearizeUnitComponentsIfDisjointOrZero final 5015 : OpRewritePattern<affine::AffineLinearizeIndexOp> { 5016 using OpRewritePattern::OpRewritePattern; 5017 5018 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op, 5019 PatternRewriter &rewriter) const override { 5020 ValueRange multiIndex = op.getMultiIndex(); 5021 size_t numIndices = multiIndex.size(); 5022 SmallVector<Value> newIndices; 5023 newIndices.reserve(numIndices); 5024 SmallVector<OpFoldResult> newBasis; 5025 newBasis.reserve(numIndices); 5026 5027 if (!op.hasOuterBound()) { 5028 newIndices.push_back(multiIndex.front()); 5029 multiIndex = multiIndex.drop_front(); 5030 } 5031 5032 SmallVector<OpFoldResult> basis = op.getMixedBasis(); 5033 for (auto [index, basisElem] : llvm::zip_equal(multiIndex, basis)) { 5034 std::optional<int64_t> basisEntry = getConstantIntValue(basisElem); 5035 if (!basisEntry || *basisEntry != 1) { 5036 newIndices.push_back(index); 5037 newBasis.push_back(basisElem); 5038 continue; 5039 } 5040 5041 std::optional<int64_t> indexValue = getConstantIntValue(index); 5042 if (!op.getDisjoint() && (!indexValue || *indexValue != 0)) { 5043 newIndices.push_back(index); 5044 newBasis.push_back(basisElem); 5045 continue; 5046 } 5047 } 5048 if (newIndices.size() == numIndices) 5049 return rewriter.notifyMatchFailure(op, 5050 "no unit basis entries to replace"); 5051 5052 if (newIndices.size() == 0) { 5053 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0); 5054 return success(); 5055 } 5056 rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>( 5057 op, newIndices, newBasis, op.getDisjoint()); 5058 return success(); 5059 } 5060 }; 5061 5062 /// Return the product of `terms`, creating an `affine.apply` if any of them are 5063 /// non-constant values. If any of `terms` is `nullptr`, return `nullptr`. 5064 static OpFoldResult computeProduct(Location loc, OpBuilder &builder, 5065 ArrayRef<OpFoldResult> terms) { 5066 int64_t nDynamic = 0; 5067 SmallVector<Value> dynamicPart; 5068 AffineExpr result = builder.getAffineConstantExpr(1); 5069 for (OpFoldResult term : terms) { 5070 if (!term) 5071 return term; 5072 std::optional<int64_t> maybeConst = getConstantIntValue(term); 5073 if (maybeConst) { 5074 result = result * builder.getAffineConstantExpr(*maybeConst); 5075 } else { 5076 dynamicPart.push_back(cast<Value>(term)); 5077 result = result * builder.getAffineSymbolExpr(nDynamic++); 5078 } 5079 } 5080 if (auto constant = dyn_cast<AffineConstantExpr>(result)) 5081 return getAsIndexOpFoldResult(builder.getContext(), constant.getValue()); 5082 return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult(); 5083 } 5084 5085 /// If conseceutive outputs of a delinearize_index are linearized with the same 5086 /// bounds, canonicalize away the redundant arithmetic. 5087 /// 5088 /// That is, if we have 5089 /// ``` 5090 /// %s:N = affine.delinearize_index %x into (...a, B1, B2, ... BK, ...b) 5091 /// %t = affine.linearize_index [...c, %s#I, %s#(I + 1), ... %s#(I+K-1), ...d] 5092 /// by (...e, B1, B2, ..., BK, ...f) 5093 /// ``` 5094 /// 5095 /// We can rewrite this to 5096 /// ``` 5097 /// B = B1 * B2 ... BK 5098 /// %sMerged:(N-K+1) affine.delinearize_index %x into (...a, B, ...b) 5099 /// %t = affine.linearize_index [...c, %s#I, ...d] by (...e, B, ...f) 5100 /// ``` 5101 /// where we replace all results of %s unaffected by the change with results 5102 /// from %sMerged. 5103 /// 5104 /// As a special case, if all results of the delinearize are merged in this way 5105 /// we can replace those usages with %x, thus cancelling the delinearization 5106 /// entirely, as in 5107 /// ``` 5108 /// %s:3 = affine.delinearize_index %x into (2, 4, 8) 5109 /// %t = affine.linearize_index [%s#0, %s#1, %s#2, %c0] by (2, 4, 8, 16) 5110 /// ``` 5111 /// becoming `%t = affine.linearize_index [%x, %c0] by (64, 16)` 5112 struct CancelLinearizeOfDelinearizePortion final 5113 : OpRewritePattern<affine::AffineLinearizeIndexOp> { 5114 using OpRewritePattern::OpRewritePattern; 5115 5116 private: 5117 // Struct representing a case where the cancellation pattern 5118 // applies. A `Match` means that `length` inputs to the linearize operation 5119 // starting at `linStart` can be cancelled with `length` outputs of 5120 // `delinearize`, starting from `delinStart`. 5121 struct Match { 5122 AffineDelinearizeIndexOp delinearize; 5123 unsigned linStart = 0; 5124 unsigned delinStart = 0; 5125 unsigned length = 0; 5126 }; 5127 5128 public: 5129 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp, 5130 PatternRewriter &rewriter) const override { 5131 SmallVector<Match> matches; 5132 5133 const SmallVector<OpFoldResult> linBasis = linearizeOp.getPaddedBasis(); 5134 ArrayRef<OpFoldResult> linBasisRef = linBasis; 5135 5136 ValueRange multiIndex = linearizeOp.getMultiIndex(); 5137 unsigned numLinArgs = multiIndex.size(); 5138 unsigned linArgIdx = 0; 5139 // We only want to replace one run from the same delinearize op per 5140 // pattern invocation lest we run into invalidation issues. 5141 llvm::SmallPtrSet<Operation *, 2> alreadyMatchedDelinearize; 5142 while (linArgIdx < numLinArgs) { 5143 auto asResult = dyn_cast<OpResult>(multiIndex[linArgIdx]); 5144 if (!asResult) { 5145 linArgIdx++; 5146 continue; 5147 } 5148 5149 auto delinearizeOp = 5150 dyn_cast<AffineDelinearizeIndexOp>(asResult.getOwner()); 5151 if (!delinearizeOp) { 5152 linArgIdx++; 5153 continue; 5154 } 5155 5156 /// Result 0 of the delinearize and argument 0 of the linearize can 5157 /// leave their maximum value unspecified. However, even if this happens 5158 /// we can still sometimes start the match process. Specifically, if 5159 /// - The argument we're matching is result 0 and argument 0 (so the 5160 /// bounds don't matter). For example, 5161 /// 5162 /// %0:2 = affine.delinearize_index %x into (8) : index, index 5163 /// %1 = affine.linearize_index [%s#0, %s#1, ...] (8, ...) 5164 /// allows cancellation 5165 /// - The delinearization doesn't specify a bound, but the linearization 5166 /// is `disjoint`, which asserts that the bound on the linearization is 5167 /// correct. 5168 unsigned delinArgIdx = asResult.getResultNumber(); 5169 SmallVector<OpFoldResult> delinBasis = delinearizeOp.getPaddedBasis(); 5170 OpFoldResult firstDelinBound = delinBasis[delinArgIdx]; 5171 OpFoldResult firstLinBound = linBasis[linArgIdx]; 5172 bool boundsMatch = firstDelinBound == firstLinBound; 5173 bool bothAtFront = linArgIdx == 0 && delinArgIdx == 0; 5174 bool knownByDisjoint = 5175 linearizeOp.getDisjoint() && delinArgIdx == 0 && !firstDelinBound; 5176 if (!boundsMatch && !bothAtFront && !knownByDisjoint) { 5177 linArgIdx++; 5178 continue; 5179 } 5180 5181 unsigned j = 1; 5182 unsigned numDelinOuts = delinearizeOp.getNumResults(); 5183 for (; j + linArgIdx < numLinArgs && j + delinArgIdx < numDelinOuts; 5184 ++j) { 5185 if (multiIndex[linArgIdx + j] != 5186 delinearizeOp.getResult(delinArgIdx + j)) 5187 break; 5188 if (linBasis[linArgIdx + j] != delinBasis[delinArgIdx + j]) 5189 break; 5190 } 5191 // If there're multiple matches against the same delinearize_index, 5192 // only rewrite the first one we find to prevent invalidations. The next 5193 // ones will be taken care of by subsequent pattern invocations. 5194 if (j <= 1 || !alreadyMatchedDelinearize.insert(delinearizeOp).second) { 5195 linArgIdx++; 5196 continue; 5197 } 5198 matches.push_back(Match{delinearizeOp, linArgIdx, delinArgIdx, j}); 5199 linArgIdx += j; 5200 } 5201 5202 if (matches.empty()) 5203 return rewriter.notifyMatchFailure( 5204 linearizeOp, "no run of delinearize outputs to deal with"); 5205 5206 // Record all the delinearize replacements so we can do them after creating 5207 // the new linearization operation, since the new operation might use 5208 // outputs of something we're replacing. 5209 SmallVector<SmallVector<Value>> delinearizeReplacements; 5210 5211 SmallVector<Value> newIndex; 5212 newIndex.reserve(numLinArgs); 5213 SmallVector<OpFoldResult> newBasis; 5214 newBasis.reserve(numLinArgs); 5215 unsigned prevMatchEnd = 0; 5216 for (Match m : matches) { 5217 unsigned gap = m.linStart - prevMatchEnd; 5218 llvm::append_range(newIndex, multiIndex.slice(prevMatchEnd, gap)); 5219 llvm::append_range(newBasis, linBasisRef.slice(prevMatchEnd, gap)); 5220 // Update here so we don't forget this during early continues 5221 prevMatchEnd = m.linStart + m.length; 5222 5223 PatternRewriter::InsertionGuard g(rewriter); 5224 rewriter.setInsertionPoint(m.delinearize); 5225 5226 ArrayRef<OpFoldResult> basisToMerge = 5227 linBasisRef.slice(m.linStart, m.length); 5228 // We use the slice from the linearize's basis above because of the 5229 // "bounds inferred from `disjoint`" case above. 5230 OpFoldResult newSize = 5231 computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge); 5232 5233 // Trivial case where we can just skip past the delinearize all together 5234 if (m.length == m.delinearize.getNumResults()) { 5235 newIndex.push_back(m.delinearize.getLinearIndex()); 5236 newBasis.push_back(newSize); 5237 // Pad out set of replacements so we don't do anything with this one. 5238 delinearizeReplacements.push_back(SmallVector<Value>()); 5239 continue; 5240 } 5241 5242 SmallVector<Value> newDelinResults; 5243 SmallVector<OpFoldResult> newDelinBasis = m.delinearize.getPaddedBasis(); 5244 newDelinBasis.erase(newDelinBasis.begin() + m.delinStart, 5245 newDelinBasis.begin() + m.delinStart + m.length); 5246 newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize); 5247 auto newDelinearize = rewriter.create<AffineDelinearizeIndexOp>( 5248 m.delinearize.getLoc(), m.delinearize.getLinearIndex(), 5249 newDelinBasis); 5250 5251 // Since there may be other uses of the indices we just merged together, 5252 // create a residual affine.delinearize_index that delinearizes the 5253 // merged output into its component parts. 5254 Value combinedElem = newDelinearize.getResult(m.delinStart); 5255 auto residualDelinearize = rewriter.create<AffineDelinearizeIndexOp>( 5256 m.delinearize.getLoc(), combinedElem, basisToMerge); 5257 5258 // Swap all the uses of the unaffected delinearize outputs to the new 5259 // delinearization so that the old code can be removed if this 5260 // linearize_index is the only user of the merged results. 5261 llvm::append_range(newDelinResults, 5262 newDelinearize.getResults().take_front(m.delinStart)); 5263 llvm::append_range(newDelinResults, residualDelinearize.getResults()); 5264 llvm::append_range( 5265 newDelinResults, 5266 newDelinearize.getResults().drop_front(m.delinStart + 1)); 5267 5268 delinearizeReplacements.push_back(newDelinResults); 5269 newIndex.push_back(combinedElem); 5270 newBasis.push_back(newSize); 5271 } 5272 llvm::append_range(newIndex, multiIndex.drop_front(prevMatchEnd)); 5273 llvm::append_range(newBasis, linBasisRef.drop_front(prevMatchEnd)); 5274 rewriter.replaceOpWithNewOp<AffineLinearizeIndexOp>( 5275 linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint()); 5276 5277 for (auto [m, newResults] : 5278 llvm::zip_equal(matches, delinearizeReplacements)) { 5279 if (newResults.empty()) 5280 continue; 5281 rewriter.replaceOp(m.delinearize, newResults); 5282 } 5283 5284 return success(); 5285 } 5286 }; 5287 5288 /// Strip leading zero from affine.linearize_index. 5289 /// 5290 /// `affine.linearize_index [%c0, ...a] by (%x, ...b)` can be rewritten 5291 /// to `affine.linearize_index [...a] by (...b)` in all cases. 5292 struct DropLinearizeLeadingZero final 5293 : OpRewritePattern<affine::AffineLinearizeIndexOp> { 5294 using OpRewritePattern::OpRewritePattern; 5295 5296 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op, 5297 PatternRewriter &rewriter) const override { 5298 Value leadingIdx = op.getMultiIndex().front(); 5299 if (!matchPattern(leadingIdx, m_Zero())) 5300 return failure(); 5301 5302 if (op.getMultiIndex().size() == 1) { 5303 rewriter.replaceOp(op, leadingIdx); 5304 return success(); 5305 } 5306 5307 SmallVector<OpFoldResult> mixedBasis = op.getMixedBasis(); 5308 ArrayRef<OpFoldResult> newMixedBasis = mixedBasis; 5309 if (op.hasOuterBound()) 5310 newMixedBasis = newMixedBasis.drop_front(); 5311 5312 rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>( 5313 op, op.getMultiIndex().drop_front(), newMixedBasis, op.getDisjoint()); 5314 return success(); 5315 } 5316 }; 5317 } // namespace 5318 5319 void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns( 5320 RewritePatternSet &patterns, MLIRContext *context) { 5321 patterns.add<CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero, 5322 DropLinearizeUnitComponentsIfDisjointOrZero>(context); 5323 } 5324 5325 //===----------------------------------------------------------------------===// 5326 // TableGen'd op method definitions 5327 //===----------------------------------------------------------------------===// 5328 5329 #define GET_OP_CLASSES 5330 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc" 5331