//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;
using namespace mlir::memref;

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}

//===----------------------------------------------------------------------===//
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//

/// This is a common class used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto cast = operand.get().getDefiningOp<CastOp>();
    if (cast && operand.get() != inner &&
        !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
      operand.set(cast.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

/// Return an unranked/ranked tensor type for the given unranked/ranked memref
/// type.
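/// For example (illustrative only): `memref<4x?xf32>` maps to
/// `tensor<4x?xf32>`, `memref<*xf32>` maps to `tensor<*xf32>`, and any
/// non-memref type maps to the `none` type.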
Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
  if (auto memref = llvm::dyn_cast<MemRefType>(type))
    return RankedTensorType::get(memref.getShape(), memref.getElementType());
  if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
    return UnrankedTensorType::get(memref.getElementType());
  return NoneType::get(type.getContext());
}

OpFoldResult memref::getMixedSize(OpBuilder &builder, Location loc, Value value,
                                  int64_t dim) {
  auto memrefType = llvm::cast<MemRefType>(value.getType());
  if (memrefType.isDynamicDim(dim))
    return builder.createOrFold<memref::DimOp>(loc, value, dim);

  return builder.getIndexAttr(memrefType.getDimSize(dim));
}

SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
                                                Location loc, Value value) {
  auto memrefType = llvm::cast<MemRefType>(value.getType());
  SmallVector<OpFoldResult> result;
  for (int64_t i = 0; i < memrefType.getRank(); ++i)
    result.push_back(getMixedSize(builder, loc, value, i));
  return result;
}

//===----------------------------------------------------------------------===//
// Utility functions for propagating static information
//===----------------------------------------------------------------------===//

/// Helper function that infers the constant values from a list of \p values,
/// a \p memRefTy, and another helper function \p getAttributes.
/// The inferred constant values replace the related `OpFoldResult` in
/// \p values.
///
/// \note This function shouldn't be used directly; instead, use the
/// `getConstifiedMixedXXX` methods from the related operations.
///
/// \p getAttributes returns a list of potentially constant values, as
/// determined by \p isDynamic, from the given \p memRefTy. The returned list
/// must have as many elements as \p values or be empty.
///
/// E.g., consider the following example:
/// ```
/// memref.reinterpret_cast %base to <...> strides: [2, %dyn_stride] :
/// memref<f32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
/// ```
/// `ReinterpretCastOp::getMixedStrides()` will return `[2, %dyn_stride]`.
/// Now using this helper function with:
/// - `values == [2, %dyn_stride]`,
/// - `memRefTy == memref<?x?xf32, strided<[?, 1], offset: ?>>`
/// - `getAttributes == getConstantStrides` (i.e., a wrapper around
/// `getStridesAndOffset`), and
/// - `isDynamic == ShapedType::isDynamic`
/// Will yield: `values == [2, 1]`
static void constifyIndexValues(
    SmallVectorImpl<OpFoldResult> &values, MemRefType memRefTy,
    MLIRContext *ctxt,
    llvm::function_ref<SmallVector<int64_t>(MemRefType)> getAttributes,
    llvm::function_ref<bool(int64_t)> isDynamic) {
  SmallVector<int64_t> constValues = getAttributes(memRefTy);
  Builder builder(ctxt);
  for (const auto &it : llvm::enumerate(constValues)) {
    int64_t constValue = it.value();
    if (!isDynamic(constValue))
      values[it.index()] = builder.getIndexAttr(constValue);
  }
  for (OpFoldResult &ofr : values) {
    if (auto attr = dyn_cast<Attribute>(ofr)) {
      // FIXME: We shouldn't need to do that, but right now, the static indices
      // are created with the wrong type: `i64` instead of `index`.
      // As a result, if we were to keep the attribute as is, we may fail to see
      // that two attributes are equal because one would have the i64 type and
      // the other the index type.
      // The alternative would be to create constant indices with getI64Attr in
      // this and the previous loop, but it doesn't logically make sense (we are
      // dealing with indices here) and would only strengthen the inconsistency
      // around how static indices are created (some places use getI64Attr,
      // others use getIndexAttr).
      // The workaround here is to stick to the IndexAttr type for all the
      // values, hence we recreate the attribute even when it is already static
      // to make sure the type is consistent.
      ofr = builder.getIndexAttr(llvm::cast<IntegerAttr>(attr).getInt());
      continue;
    }
    std::optional<int64_t> maybeConstant =
        getConstantIntValue(cast<Value>(ofr));
    if (maybeConstant)
      ofr = builder.getIndexAttr(*maybeConstant);
  }
}

/// Wrapper around `getShape` that conforms to the function signature
/// expected for `getAttributes` in `constifyIndexValues`.
static SmallVector<int64_t> getConstantSizes(MemRefType memRefTy) {
  ArrayRef<int64_t> sizes = memRefTy.getShape();
  return SmallVector<int64_t>(sizes);
}

/// Wrapper around `getStridesAndOffset` that returns only the offset and
/// conforms to the function signature expected for `getAttributes` in
/// `constifyIndexValues`.
static SmallVector<int64_t> getConstantOffset(MemRefType memrefType) {
  SmallVector<int64_t> strides;
  int64_t offset;
  LogicalResult hasStaticInformation =
      memrefType.getStridesAndOffset(strides, offset);
  if (failed(hasStaticInformation))
    return SmallVector<int64_t>();
  return SmallVector<int64_t>(1, offset);
}

/// Wrapper around `getStridesAndOffset` that returns only the strides and
/// conforms to the function signature expected for `getAttributes` in
/// `constifyIndexValues`.
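/// For illustration only (hypothetical input, not from the original source):
/// for `memref<?x4xf32, strided<[?, 1], offset: ?>>` this returns
/// `{ShapedType::kDynamic, 1}`; if the layout is not strided, it returns an
/// empty vector.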
175 static SmallVector<int64_t> getConstantStrides(MemRefType memrefType) { 176 SmallVector<int64_t> strides; 177 int64_t offset; 178 LogicalResult hasStaticInformation = 179 memrefType.getStridesAndOffset(strides, offset); 180 if (failed(hasStaticInformation)) 181 return SmallVector<int64_t>(); 182 return strides; 183 } 184 185 //===----------------------------------------------------------------------===// 186 // AllocOp / AllocaOp 187 //===----------------------------------------------------------------------===// 188 189 void AllocOp::getAsmResultNames( 190 function_ref<void(Value, StringRef)> setNameFn) { 191 setNameFn(getResult(), "alloc"); 192 } 193 194 void AllocaOp::getAsmResultNames( 195 function_ref<void(Value, StringRef)> setNameFn) { 196 setNameFn(getResult(), "alloca"); 197 } 198 199 template <typename AllocLikeOp> 200 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) { 201 static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value, 202 "applies to only alloc or alloca"); 203 auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType()); 204 if (!memRefType) 205 return op.emitOpError("result must be a memref"); 206 207 if (op.getDynamicSizes().size() != memRefType.getNumDynamicDims()) 208 return op.emitOpError("dimension operand count does not equal memref " 209 "dynamic dimension count"); 210 211 unsigned numSymbols = 0; 212 if (!memRefType.getLayout().isIdentity()) 213 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols(); 214 if (op.getSymbolOperands().size() != numSymbols) 215 return op.emitOpError("symbol operand count does not equal memref symbol " 216 "count: expected ") 217 << numSymbols << ", got " << op.getSymbolOperands().size(); 218 219 return success(); 220 } 221 222 LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); } 223 224 LogicalResult AllocaOp::verify() { 225 // An alloca op needs to have an ancestor with an allocation scope trait. 226 if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>()) 227 return emitOpError( 228 "requires an ancestor op with AutomaticAllocationScope trait"); 229 230 return verifyAllocLikeOp(*this); 231 } 232 233 namespace { 234 /// Fold constant dimensions into an alloc like operation. 235 template <typename AllocLikeOp> 236 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> { 237 using OpRewritePattern<AllocLikeOp>::OpRewritePattern; 238 239 LogicalResult matchAndRewrite(AllocLikeOp alloc, 240 PatternRewriter &rewriter) const override { 241 // Check to see if any dimensions operands are constants. If so, we can 242 // substitute and drop them. 243 if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) { 244 APInt constSizeArg; 245 if (!matchPattern(operand, m_ConstantInt(&constSizeArg))) 246 return false; 247 return constSizeArg.isNonNegative(); 248 })) 249 return failure(); 250 251 auto memrefType = alloc.getType(); 252 253 // Ok, we have one or more constant operands. Collect the non-constant ones 254 // and keep track of the resultant memref type to build. 255 SmallVector<int64_t, 4> newShapeConstants; 256 newShapeConstants.reserve(memrefType.getRank()); 257 SmallVector<Value, 4> dynamicSizes; 258 259 unsigned dynamicDimPos = 0; 260 for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { 261 int64_t dimSize = memrefType.getDimSize(dim); 262 // If this is already static dimension, keep it. 
263 if (!ShapedType::isDynamic(dimSize)) { 264 newShapeConstants.push_back(dimSize); 265 continue; 266 } 267 auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos]; 268 APInt constSizeArg; 269 if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) && 270 constSizeArg.isNonNegative()) { 271 // Dynamic shape dimension will be folded. 272 newShapeConstants.push_back(constSizeArg.getZExtValue()); 273 } else { 274 // Dynamic shape dimension not folded; copy dynamicSize from old memref. 275 newShapeConstants.push_back(ShapedType::kDynamic); 276 dynamicSizes.push_back(dynamicSize); 277 } 278 dynamicDimPos++; 279 } 280 281 // Create new memref type (which will have fewer dynamic dimensions). 282 MemRefType newMemRefType = 283 MemRefType::Builder(memrefType).setShape(newShapeConstants); 284 assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims()); 285 286 // Create and insert the alloc op for the new memref. 287 auto newAlloc = rewriter.create<AllocLikeOp>( 288 alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(), 289 alloc.getAlignmentAttr()); 290 // Insert a cast so we have the same type as the old alloc. 291 rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc); 292 return success(); 293 } 294 }; 295 296 /// Fold alloc operations with no users or only store and dealloc uses. 297 template <typename T> 298 struct SimplifyDeadAlloc : public OpRewritePattern<T> { 299 using OpRewritePattern<T>::OpRewritePattern; 300 301 LogicalResult matchAndRewrite(T alloc, 302 PatternRewriter &rewriter) const override { 303 if (llvm::any_of(alloc->getUsers(), [&](Operation *op) { 304 if (auto storeOp = dyn_cast<StoreOp>(op)) 305 return storeOp.getValue() == alloc; 306 return !isa<DeallocOp>(op); 307 })) 308 return failure(); 309 310 for (Operation *user : llvm::make_early_inc_range(alloc->getUsers())) 311 rewriter.eraseOp(user); 312 313 rewriter.eraseOp(alloc); 314 return success(); 315 } 316 }; 317 } // namespace 318 319 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results, 320 MLIRContext *context) { 321 results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context); 322 } 323 324 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results, 325 MLIRContext *context) { 326 results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>( 327 context); 328 } 329 330 //===----------------------------------------------------------------------===// 331 // ReallocOp 332 //===----------------------------------------------------------------------===// 333 334 LogicalResult ReallocOp::verify() { 335 auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType()); 336 MemRefType resultType = getType(); 337 338 // The source memref should have identity layout (or none). 339 if (!sourceType.getLayout().isIdentity()) 340 return emitError("unsupported layout for source memref type ") 341 << sourceType; 342 343 // The result memref should have identity layout (or none). 344 if (!resultType.getLayout().isIdentity()) 345 return emitError("unsupported layout for result memref type ") 346 << resultType; 347 348 // The source memref and the result memref should be in the same memory space. 349 if (sourceType.getMemorySpace() != resultType.getMemorySpace()) 350 return emitError("different memory spaces specified for source memref " 351 "type ") 352 << sourceType << " and result memref type " << resultType; 353 354 // The source memref and the result memref should have the same element type. 
355 if (sourceType.getElementType() != resultType.getElementType()) 356 return emitError("different element types specified for source memref " 357 "type ") 358 << sourceType << " and result memref type " << resultType; 359 360 // Verify that we have the dynamic dimension operand when it is needed. 361 if (resultType.getNumDynamicDims() && !getDynamicResultSize()) 362 return emitError("missing dimension operand for result type ") 363 << resultType; 364 if (!resultType.getNumDynamicDims() && getDynamicResultSize()) 365 return emitError("unnecessary dimension operand for result type ") 366 << resultType; 367 368 return success(); 369 } 370 371 void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results, 372 MLIRContext *context) { 373 results.add<SimplifyDeadAlloc<ReallocOp>>(context); 374 } 375 376 //===----------------------------------------------------------------------===// 377 // AllocaScopeOp 378 //===----------------------------------------------------------------------===// 379 380 void AllocaScopeOp::print(OpAsmPrinter &p) { 381 bool printBlockTerminators = false; 382 383 p << ' '; 384 if (!getResults().empty()) { 385 p << " -> (" << getResultTypes() << ")"; 386 printBlockTerminators = true; 387 } 388 p << ' '; 389 p.printRegion(getBodyRegion(), 390 /*printEntryBlockArgs=*/false, 391 /*printBlockTerminators=*/printBlockTerminators); 392 p.printOptionalAttrDict((*this)->getAttrs()); 393 } 394 395 ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) { 396 // Create a region for the body. 397 result.regions.reserve(1); 398 Region *bodyRegion = result.addRegion(); 399 400 // Parse optional results type list. 401 if (parser.parseOptionalArrowTypeList(result.types)) 402 return failure(); 403 404 // Parse the body region. 405 if (parser.parseRegion(*bodyRegion, /*arguments=*/{})) 406 return failure(); 407 AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(), 408 result.location); 409 410 // Parse the optional attribute list. 411 if (parser.parseOptionalAttrDict(result.attributes)) 412 return failure(); 413 414 return success(); 415 } 416 417 void AllocaScopeOp::getSuccessorRegions( 418 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { 419 if (!point.isParent()) { 420 regions.push_back(RegionSuccessor(getResults())); 421 return; 422 } 423 424 regions.push_back(RegionSuccessor(&getBodyRegion())); 425 } 426 427 /// Given an operation, return whether this op is guaranteed to 428 /// allocate an AutomaticAllocationScopeResource 429 static bool isGuaranteedAutomaticAllocation(Operation *op) { 430 MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op); 431 if (!interface) 432 return false; 433 for (auto res : op->getResults()) { 434 if (auto effect = 435 interface.getEffectOnValue<MemoryEffects::Allocate>(res)) { 436 if (isa<SideEffects::AutomaticAllocationScopeResource>( 437 effect->getResource())) 438 return true; 439 } 440 } 441 return false; 442 } 443 444 /// Given an operation, return whether this op itself could 445 /// allocate an AutomaticAllocationScopeResource. Note that 446 /// this will not check whether an operation contained within 447 /// the op can allocate. 448 static bool isOpItselfPotentialAutomaticAllocation(Operation *op) { 449 // This op itself doesn't create a stack allocation, 450 // the inner allocation should be handled separately. 
451 if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>()) 452 return false; 453 MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op); 454 if (!interface) 455 return true; 456 for (auto res : op->getResults()) { 457 if (auto effect = 458 interface.getEffectOnValue<MemoryEffects::Allocate>(res)) { 459 if (isa<SideEffects::AutomaticAllocationScopeResource>( 460 effect->getResource())) 461 return true; 462 } 463 } 464 return false; 465 } 466 467 /// Return whether this op is the last non terminating op 468 /// in a region. That is to say, it is in a one-block region 469 /// and is only followed by a terminator. This prevents 470 /// extending the lifetime of allocations. 471 static bool lastNonTerminatorInRegion(Operation *op) { 472 return op->getNextNode() == op->getBlock()->getTerminator() && 473 op->getParentRegion()->getBlocks().size() == 1; 474 } 475 476 /// Inline an AllocaScopeOp if either the direct parent is an allocation scope 477 /// or it contains no allocation. 478 struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> { 479 using OpRewritePattern<AllocaScopeOp>::OpRewritePattern; 480 481 LogicalResult matchAndRewrite(AllocaScopeOp op, 482 PatternRewriter &rewriter) const override { 483 bool hasPotentialAlloca = 484 op->walk<WalkOrder::PreOrder>([&](Operation *alloc) { 485 if (alloc == op) 486 return WalkResult::advance(); 487 if (isOpItselfPotentialAutomaticAllocation(alloc)) 488 return WalkResult::interrupt(); 489 if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>()) 490 return WalkResult::skip(); 491 return WalkResult::advance(); 492 }).wasInterrupted(); 493 494 // If this contains no potential allocation, it is always legal to 495 // inline. Otherwise, consider two conditions: 496 if (hasPotentialAlloca) { 497 // If the parent isn't an allocation scope, or we are not the last 498 // non-terminator op in the parent, we will extend the lifetime. 499 if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>()) 500 return failure(); 501 if (!lastNonTerminatorInRegion(op)) 502 return failure(); 503 } 504 505 Block *block = &op.getRegion().front(); 506 Operation *terminator = block->getTerminator(); 507 ValueRange results = terminator->getOperands(); 508 rewriter.inlineBlockBefore(block, op); 509 rewriter.replaceOp(op, results); 510 rewriter.eraseOp(terminator); 511 return success(); 512 } 513 }; 514 515 /// Move allocations into an allocation scope, if it is legal to 516 /// move them (e.g. their operands are available at the location 517 /// the op would be moved to). 
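///
/// A rough illustration (hypothetical IR, not from the original source): an
/// alloca that is guaranteed to allocate in the scope, and whose operands are
/// available at the hoisting point, can be hoisted out of the loop up to the
/// enclosing automatic-allocation-scope op:
/// ```mlir
/// func.func @f() {            // automatic allocation scope
///   scf.for %i = %lb to %ub step %s {
///     memref.alloca_scope {
///       %a = memref.alloca() : memref<4xf32>  // candidate for hoisting
///       "test.use"(%a) : (memref<4xf32>) -> ()
///     }
///   }
///   return
/// }
/// ```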
struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocaScopeOp op,
                                PatternRewriter &rewriter) const override {

    if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
      return failure();

    Operation *lastParentWithoutScope = op->getParentOp();

    if (!lastParentWithoutScope ||
        lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
      return failure();

    // Only apply if this is the last non-terminator op in the block of a
    // single-block region (lest the lifetime be extended).
    if (!lastNonTerminatorInRegion(op) ||
        !lastNonTerminatorInRegion(lastParentWithoutScope))
      return failure();

    while (!lastParentWithoutScope->getParentOp()
                ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
      lastParentWithoutScope = lastParentWithoutScope->getParentOp();
      if (!lastParentWithoutScope ||
          !lastNonTerminatorInRegion(lastParentWithoutScope))
        return failure();
    }
    assert(lastParentWithoutScope->getParentOp()
               ->hasTrait<OpTrait::AutomaticAllocationScope>());

    Region *containingRegion = nullptr;
    for (auto &r : lastParentWithoutScope->getRegions()) {
      if (r.isAncestor(op->getParentRegion())) {
        assert(containingRegion == nullptr &&
               "only one region can contain the op");
        containingRegion = &r;
      }
    }
    assert(containingRegion && "op must be contained in a region");

    SmallVector<Operation *> toHoist;
    op->walk([&](Operation *alloc) {
      if (!isGuaranteedAutomaticAllocation(alloc))
        return WalkResult::skip();

      // If any operand is not defined before the location of
      // lastParentWithoutScope (i.e. where we would hoist to), skip.
      if (llvm::any_of(alloc->getOperands(), [&](Value v) {
            return containingRegion->isAncestor(v.getParentRegion());
          }))
        return WalkResult::skip();
      toHoist.push_back(alloc);
      return WalkResult::advance();
    });

    if (toHoist.empty())
      return failure();
    rewriter.setInsertionPoint(lastParentWithoutScope);
    for (auto *op : toHoist) {
      auto *cloned = rewriter.clone(*op);
      rewriter.replaceOp(op, cloned->getResults());
    }
    return success();
  }
};

void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
}

//===----------------------------------------------------------------------===//
// AssumeAlignmentOp
//===----------------------------------------------------------------------===//

LogicalResult AssumeAlignmentOp::verify() {
  if (!llvm::isPowerOf2_32(getAlignment()))
    return emitOpError("alignment must be power of 2");
  return success();
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "cast");
}

/// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of memref.cast operations.
Such foldable memref.cast 613 /// operations are typically inserted as `view` and `subview` ops are 614 /// canonicalized, to preserve the type compatibility of their uses. 615 /// 616 /// Returns true when all conditions are met: 617 /// 1. source and result are ranked memrefs with strided semantics and same 618 /// element type and rank. 619 /// 2. each of the source's size, offset or stride has more static information 620 /// than the corresponding result's size, offset or stride. 621 /// 622 /// Example 1: 623 /// ```mlir 624 /// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32> 625 /// %2 = consumer %1 ... : memref<?x?xf32> ... 626 /// ``` 627 /// 628 /// may fold into: 629 /// 630 /// ```mlir 631 /// %2 = consumer %0 ... : memref<8x16xf32> ... 632 /// ``` 633 /// 634 /// Example 2: 635 /// ``` 636 /// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>> 637 /// to memref<?x?xf32> 638 /// consumer %1 : memref<?x?xf32> ... 639 /// ``` 640 /// 641 /// may fold into: 642 /// 643 /// ``` 644 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>> 645 /// ``` 646 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) { 647 MemRefType sourceType = 648 llvm::dyn_cast<MemRefType>(castOp.getSource().getType()); 649 MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType()); 650 651 // Requires ranked MemRefType. 652 if (!sourceType || !resultType) 653 return false; 654 655 // Requires same elemental type. 656 if (sourceType.getElementType() != resultType.getElementType()) 657 return false; 658 659 // Requires same rank. 660 if (sourceType.getRank() != resultType.getRank()) 661 return false; 662 663 // Only fold casts between strided memref forms. 664 int64_t sourceOffset, resultOffset; 665 SmallVector<int64_t, 4> sourceStrides, resultStrides; 666 if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) || 667 failed(resultType.getStridesAndOffset(resultStrides, resultOffset))) 668 return false; 669 670 // If cast is towards more static sizes along any dimension, don't fold. 671 for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) { 672 auto ss = std::get<0>(it), st = std::get<1>(it); 673 if (ss != st) 674 if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st)) 675 return false; 676 } 677 678 // If cast is towards more static offset along any dimension, don't fold. 679 if (sourceOffset != resultOffset) 680 if (ShapedType::isDynamic(sourceOffset) && 681 !ShapedType::isDynamic(resultOffset)) 682 return false; 683 684 // If cast is towards more static strides along any dimension, don't fold. 
685 for (auto it : llvm::zip(sourceStrides, resultStrides)) { 686 auto ss = std::get<0>(it), st = std::get<1>(it); 687 if (ss != st) 688 if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st)) 689 return false; 690 } 691 692 return true; 693 } 694 695 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 696 if (inputs.size() != 1 || outputs.size() != 1) 697 return false; 698 Type a = inputs.front(), b = outputs.front(); 699 auto aT = llvm::dyn_cast<MemRefType>(a); 700 auto bT = llvm::dyn_cast<MemRefType>(b); 701 702 auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a); 703 auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b); 704 705 if (aT && bT) { 706 if (aT.getElementType() != bT.getElementType()) 707 return false; 708 if (aT.getLayout() != bT.getLayout()) { 709 int64_t aOffset, bOffset; 710 SmallVector<int64_t, 4> aStrides, bStrides; 711 if (failed(aT.getStridesAndOffset(aStrides, aOffset)) || 712 failed(bT.getStridesAndOffset(bStrides, bOffset)) || 713 aStrides.size() != bStrides.size()) 714 return false; 715 716 // Strides along a dimension/offset are compatible if the value in the 717 // source memref is static and the value in the target memref is the 718 // same. They are also compatible if either one is dynamic (see 719 // description of MemRefCastOp for details). 720 auto checkCompatible = [](int64_t a, int64_t b) { 721 return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b); 722 }; 723 if (!checkCompatible(aOffset, bOffset)) 724 return false; 725 for (const auto &aStride : enumerate(aStrides)) 726 if (!checkCompatible(aStride.value(), bStrides[aStride.index()])) 727 return false; 728 } 729 if (aT.getMemorySpace() != bT.getMemorySpace()) 730 return false; 731 732 // They must have the same rank, and any specified dimensions must match. 733 if (aT.getRank() != bT.getRank()) 734 return false; 735 736 for (unsigned i = 0, e = aT.getRank(); i != e; ++i) { 737 int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i); 738 if (!ShapedType::isDynamic(aDim) && !ShapedType::isDynamic(bDim) && 739 aDim != bDim) 740 return false; 741 } 742 return true; 743 } else { 744 if (!aT && !uaT) 745 return false; 746 if (!bT && !ubT) 747 return false; 748 // Unranked to unranked casting is unsupported 749 if (uaT && ubT) 750 return false; 751 752 auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType(); 753 auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType(); 754 if (aEltType != bEltType) 755 return false; 756 757 auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace(); 758 auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace(); 759 return aMemSpace == bMemSpace; 760 } 761 762 return false; 763 } 764 765 OpFoldResult CastOp::fold(FoldAdaptor adaptor) { 766 return succeeded(foldMemRefCast(*this)) ? getResult() : Value(); 767 } 768 769 //===----------------------------------------------------------------------===// 770 // CopyOp 771 //===----------------------------------------------------------------------===// 772 773 namespace { 774 /// If the source/target of a CopyOp is a CastOp that does not modify the shape 775 /// and element type, the cast can be skipped. Such CastOps only cast the layout 776 /// of the type. 777 struct FoldCopyOfCast : public OpRewritePattern<CopyOp> { 778 using OpRewritePattern<CopyOp>::OpRewritePattern; 779 780 LogicalResult matchAndRewrite(CopyOp copyOp, 781 PatternRewriter &rewriter) const override { 782 bool modified = false; 783 784 // Check source. 
    if (auto castOp = copyOp.getSource().getDefiningOp<CastOp>()) {
      auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
      auto toType = llvm::dyn_cast<MemRefType>(castOp.getType());

      if (fromType && toType) {
        if (fromType.getShape() == toType.getShape() &&
            fromType.getElementType() == toType.getElementType()) {
          rewriter.modifyOpInPlace(copyOp, [&] {
            copyOp.getSourceMutable().assign(castOp.getSource());
          });
          modified = true;
        }
      }
    }

    // Check target.
    if (auto castOp = copyOp.getTarget().getDefiningOp<CastOp>()) {
      auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
      auto toType = llvm::dyn_cast<MemRefType>(castOp.getType());

      if (fromType && toType) {
        if (fromType.getShape() == toType.getShape() &&
            fromType.getElementType() == toType.getElementType()) {
          rewriter.modifyOpInPlace(copyOp, [&] {
            copyOp.getTargetMutable().assign(castOp.getSource());
          });
          modified = true;
        }
      }
    }

    return success(modified);
  }
};

/// Fold memref.copy(%x, %x).
struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getSource() != copyOp.getTarget())
      return failure();

    rewriter.eraseOp(copyOp);
    return success();
  }
};

struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  static bool isEmptyMemRef(BaseMemRefType type) {
    return type.hasRank() && llvm::is_contained(type.getShape(), 0);
  }

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (isEmptyMemRef(copyOp.getSource().getType()) ||
        isEmptyMemRef(copyOp.getTarget().getType())) {
      rewriter.eraseOp(copyOp);
      return success();
    }

    return failure();
  }
};
} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldCopyOfCast, FoldEmptyCopy, FoldSelfCopy>(context);
}

LogicalResult CopyOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {
  /// copy(memrefcast) -> copy
  bool folded = false;
  Operation *op = *this;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dealloc(memrefcast) -> dealloc
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "dim");
}

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
  build(builder, result, source, indexValue);
}

std::optional<int64_t> DimOp::getConstantIndex() {
  return getConstantIntValue(getIndex());
}

Speculation::Speculatability DimOp::getSpeculatability() {
  auto constantIndex = getConstantIndex();
  if (!constantIndex)
    return Speculation::NotSpeculatable;

  auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
  if (!rankedSourceType)
    return Speculation::NotSpeculatable;

  if (rankedSourceType.getRank() <= constantIndex)
    return Speculation::NotSpeculatable;

  return Speculation::Speculatable;
}

/// Return a map with key being elements in `vals` and data being number of
/// occurrences of it. Use std::map, since the `vals` here are strides and the
/// dynamic stride value is the same as the tombstone value for
/// `DenseMap<int64_t>`.
static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
  std::map<int64_t, unsigned> numOccurences;
  for (auto val : vals)
    numOccurences[val]++;
  return numOccurences;
}

/// Given the `originalType` and a `reducedType` whose shape is assumed to be a
/// subset of `originalType` with some `1` entries erased, return the set of
/// indices that specifies which of the entries of the original shape are
/// dropped to obtain the reduced shape.
/// This accounts for cases where there are multiple unit-dims, but only a
/// subset of those are dropped. For MemRefTypes these can be disambiguated
/// using the strides. If a dimension is dropped the stride must be dropped too.
static FailureOr<llvm::SmallBitVector>
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
                               ArrayRef<OpFoldResult> sizes) {
  llvm::SmallBitVector unusedDims(originalType.getRank());
  if (originalType.getRank() == reducedType.getRank())
    return unusedDims;

  for (const auto &dim : llvm::enumerate(sizes))
    if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
      if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
        unusedDims.set(dim.index());

  // Early exit for the case where the number of unused dims matches the number
  // of ranks reduced.
  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
      originalType.getRank())
    return unusedDims;

  SmallVector<int64_t> originalStrides, candidateStrides;
  int64_t originalOffset, candidateOffset;
  if (failed(
          originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
      failed(
          reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
    return failure();

  // For memrefs, a dimension is truly dropped if its corresponding stride is
  // also dropped. This is particularly important when more than one of the
  // dims is 1. Track the number of occurrences of the strides in the original
  // type and the candidate type. For each unused dim, that stride should not
  // be present in the candidate type. Note that there could be multiple
  // dimensions that have the same size. We don't need to figure out exactly
  // which dim corresponds to which stride; we just need to verify that the
  // number of repetitions of a stride in the original + the number of unused
  // dims with that stride == the number of repetitions of that stride in the
  // candidate.
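  //
  // A rough illustration (hypothetical strides, not from the original source):
  // with original strides [1, 2, 1], two unit-dim candidates, and candidate
  // strides [1, 1], stride 2 occurs once in the original and zero times in the
  // candidate, so its unit dim is considered dropped; stride 1 occurs twice in
  // both, so the unit dim carrying it is kept.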
  std::map<int64_t, unsigned> currUnaccountedStrides =
      getNumOccurences(originalStrides);
  std::map<int64_t, unsigned> candidateStridesNumOccurences =
      getNumOccurences(candidateStrides);
  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
    if (!unusedDims.test(dim))
      continue;
    int64_t originalStride = originalStrides[dim];
    if (currUnaccountedStrides[originalStride] >
        candidateStridesNumOccurences[originalStride]) {
      // This dim can be treated as dropped.
      currUnaccountedStrides[originalStride]--;
      continue;
    }
    if (currUnaccountedStrides[originalStride] ==
        candidateStridesNumOccurences[originalStride]) {
      // The stride for this is not dropped. Keep as is.
      unusedDims.reset(dim);
      continue;
    }
    if (currUnaccountedStrides[originalStride] <
        candidateStridesNumOccurences[originalStride]) {
      // This should never happen. Can't have a stride in the reduced-rank type
      // that wasn't in the original one.
      return failure();
    }
  }

  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
      originalType.getRank())
    return failure();
  return unusedDims;
}

llvm::SmallBitVector SubViewOp::getDroppedDims() {
  MemRefType sourceType = getSourceType();
  MemRefType resultType = getType();
  FailureOr<llvm::SmallBitVector> unusedDims =
      computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
  assert(succeeded(unusedDims) && "unable to find unused dims of subview");
  return *unusedDims;
}

OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
  // All forms of folding require a known index.
  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!index)
    return {};

  // Folding for unranked types (UnrankedMemRefType) is not supported.
  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
  if (!memrefType)
    return {};

  // Out of bound indices produce undefined behavior but are still valid IR.
  // Don't choke on them.
  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= memrefType.getRank())
    return {};

  // Fold if the shape extent along the given index is known.
  if (!memrefType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
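  // For instance (hypothetical values, not from the original source):
  //   %m = memref.alloc(%d) : memref<4x?xf32>
  //   %s = memref.dim %m, %c1 : memref<4x?xf32>
  // folds %s to %d, i.e. the dynamic size operand of the alloc.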
1041 Operation *definingOp = getSource().getDefiningOp(); 1042 1043 if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp)) 1044 return *(alloc.getDynamicSizes().begin() + 1045 memrefType.getDynamicDimIndex(unsignedIndex)); 1046 1047 if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp)) 1048 return *(alloca.getDynamicSizes().begin() + 1049 memrefType.getDynamicDimIndex(unsignedIndex)); 1050 1051 if (auto view = dyn_cast_or_null<ViewOp>(definingOp)) 1052 return *(view.getDynamicSizes().begin() + 1053 memrefType.getDynamicDimIndex(unsignedIndex)); 1054 1055 if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) { 1056 llvm::SmallBitVector unusedDims = subview.getDroppedDims(); 1057 unsigned resultIndex = 0; 1058 unsigned sourceRank = subview.getSourceType().getRank(); 1059 unsigned sourceIndex = 0; 1060 for (auto i : llvm::seq<unsigned>(0, sourceRank)) { 1061 if (unusedDims.test(i)) 1062 continue; 1063 if (resultIndex == unsignedIndex) { 1064 sourceIndex = i; 1065 break; 1066 } 1067 resultIndex++; 1068 } 1069 assert(subview.isDynamicSize(sourceIndex) && 1070 "expected dynamic subview size"); 1071 return subview.getDynamicSize(sourceIndex); 1072 } 1073 1074 if (auto sizeInterface = 1075 dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) { 1076 assert(sizeInterface.isDynamicSize(unsignedIndex) && 1077 "Expected dynamic subview size"); 1078 return sizeInterface.getDynamicSize(unsignedIndex); 1079 } 1080 1081 // dim(memrefcast) -> dim 1082 if (succeeded(foldMemRefCast(*this))) 1083 return getResult(); 1084 1085 return {}; 1086 } 1087 1088 namespace { 1089 /// Fold dim of a memref reshape operation to a load into the reshape's shape 1090 /// operand. 1091 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> { 1092 using OpRewritePattern<DimOp>::OpRewritePattern; 1093 1094 LogicalResult matchAndRewrite(DimOp dim, 1095 PatternRewriter &rewriter) const override { 1096 auto reshape = dim.getSource().getDefiningOp<ReshapeOp>(); 1097 1098 if (!reshape) 1099 return rewriter.notifyMatchFailure( 1100 dim, "Dim op is not defined by a reshape op."); 1101 1102 // dim of a memref reshape can be folded if dim.getIndex() dominates the 1103 // reshape. Instead of using `DominanceInfo` (which is usually costly) we 1104 // cheaply check that either of the following conditions hold: 1105 // 1. dim.getIndex() is defined in the same block as reshape but before 1106 // reshape. 1107 // 2. dim.getIndex() is defined in a parent block of 1108 // reshape. 1109 1110 // Check condition 1 1111 if (dim.getIndex().getParentBlock() == reshape->getBlock()) { 1112 if (auto *definingOp = dim.getIndex().getDefiningOp()) { 1113 if (reshape->isBeforeInBlock(definingOp)) { 1114 return rewriter.notifyMatchFailure( 1115 dim, 1116 "dim.getIndex is not defined before reshape in the same block."); 1117 } 1118 } // else dim.getIndex is a block argument to reshape->getBlock and 1119 // dominates reshape 1120 } // Check condition 2 1121 else if (dim->getBlock() != reshape->getBlock() && 1122 !dim.getIndex().getParentRegion()->isProperAncestor( 1123 reshape->getParentRegion())) { 1124 // If dim and reshape are in the same block but dim.getIndex() isn't, we 1125 // already know dim.getIndex() dominates reshape without calling 1126 // `isProperAncestor` 1127 return rewriter.notifyMatchFailure( 1128 dim, "dim.getIndex does not dominate reshape."); 1129 } 1130 1131 // Place the load directly after the reshape to ensure that the shape memref 1132 // was not mutated. 
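    // As a sketch (hypothetical IR, not from the original source), this turns
    //   %r = memref.reshape %src(%shape)
    //       : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
    //   %d = memref.dim %r, %idx : memref<?x?xf32>
    // into a load of the requested extent from the shape operand:
    //   %d = memref.load %shape[%idx] : memref<2xindex>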
1133 rewriter.setInsertionPointAfter(reshape); 1134 Location loc = dim.getLoc(); 1135 Value load = 1136 rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex()); 1137 if (load.getType() != dim.getType()) 1138 load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load); 1139 rewriter.replaceOp(dim, load); 1140 return success(); 1141 } 1142 }; 1143 1144 } // namespace 1145 1146 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results, 1147 MLIRContext *context) { 1148 results.add<DimOfMemRefReshape>(context); 1149 } 1150 1151 // --------------------------------------------------------------------------- 1152 // DmaStartOp 1153 // --------------------------------------------------------------------------- 1154 1155 void DmaStartOp::build(OpBuilder &builder, OperationState &result, 1156 Value srcMemRef, ValueRange srcIndices, Value destMemRef, 1157 ValueRange destIndices, Value numElements, 1158 Value tagMemRef, ValueRange tagIndices, Value stride, 1159 Value elementsPerStride) { 1160 result.addOperands(srcMemRef); 1161 result.addOperands(srcIndices); 1162 result.addOperands(destMemRef); 1163 result.addOperands(destIndices); 1164 result.addOperands({numElements, tagMemRef}); 1165 result.addOperands(tagIndices); 1166 if (stride) 1167 result.addOperands({stride, elementsPerStride}); 1168 } 1169 1170 void DmaStartOp::print(OpAsmPrinter &p) { 1171 p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], " 1172 << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements() 1173 << ", " << getTagMemRef() << '[' << getTagIndices() << ']'; 1174 if (isStrided()) 1175 p << ", " << getStride() << ", " << getNumElementsPerStride(); 1176 1177 p.printOptionalAttrDict((*this)->getAttrs()); 1178 p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType() 1179 << ", " << getTagMemRef().getType(); 1180 } 1181 1182 // Parse DmaStartOp. 1183 // Ex: 1184 // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size, 1185 // %tag[%index], %stride, %num_elt_per_stride : 1186 // : memref<3076 x f32, 0>, 1187 // memref<1024 x f32, 2>, 1188 // memref<1 x i32> 1189 // 1190 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) { 1191 OpAsmParser::UnresolvedOperand srcMemRefInfo; 1192 SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos; 1193 OpAsmParser::UnresolvedOperand dstMemRefInfo; 1194 SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos; 1195 OpAsmParser::UnresolvedOperand numElementsInfo; 1196 OpAsmParser::UnresolvedOperand tagMemrefInfo; 1197 SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos; 1198 SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo; 1199 1200 SmallVector<Type, 3> types; 1201 auto indexType = parser.getBuilder().getIndexType(); 1202 1203 // Parse and resolve the following list of operands: 1204 // *) source memref followed by its indices (in square brackets). 1205 // *) destination memref followed by its indices (in square brackets). 1206 // *) dma size in KiB. 
1207 if (parser.parseOperand(srcMemRefInfo) || 1208 parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) || 1209 parser.parseComma() || parser.parseOperand(dstMemRefInfo) || 1210 parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) || 1211 parser.parseComma() || parser.parseOperand(numElementsInfo) || 1212 parser.parseComma() || parser.parseOperand(tagMemrefInfo) || 1213 parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square)) 1214 return failure(); 1215 1216 // Parse optional stride and elements per stride. 1217 if (parser.parseTrailingOperandList(strideInfo)) 1218 return failure(); 1219 1220 bool isStrided = strideInfo.size() == 2; 1221 if (!strideInfo.empty() && !isStrided) { 1222 return parser.emitError(parser.getNameLoc(), 1223 "expected two stride related operands"); 1224 } 1225 1226 if (parser.parseColonTypeList(types)) 1227 return failure(); 1228 if (types.size() != 3) 1229 return parser.emitError(parser.getNameLoc(), "fewer/more types expected"); 1230 1231 if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) || 1232 parser.resolveOperands(srcIndexInfos, indexType, result.operands) || 1233 parser.resolveOperand(dstMemRefInfo, types[1], result.operands) || 1234 parser.resolveOperands(dstIndexInfos, indexType, result.operands) || 1235 // size should be an index. 1236 parser.resolveOperand(numElementsInfo, indexType, result.operands) || 1237 parser.resolveOperand(tagMemrefInfo, types[2], result.operands) || 1238 // tag indices should be index. 1239 parser.resolveOperands(tagIndexInfos, indexType, result.operands)) 1240 return failure(); 1241 1242 if (isStrided) { 1243 if (parser.resolveOperands(strideInfo, indexType, result.operands)) 1244 return failure(); 1245 } 1246 1247 return success(); 1248 } 1249 1250 LogicalResult DmaStartOp::verify() { 1251 unsigned numOperands = getNumOperands(); 1252 1253 // Mandatory non-variadic operands are: src memref, dst memref, tag memref and 1254 // the number of elements. 1255 if (numOperands < 4) 1256 return emitOpError("expected at least 4 operands"); 1257 1258 // Check types of operands. The order of these calls is important: the later 1259 // calls rely on some type properties to compute the operand position. 1260 // 1. Source memref. 1261 if (!llvm::isa<MemRefType>(getSrcMemRef().getType())) 1262 return emitOpError("expected source to be of memref type"); 1263 if (numOperands < getSrcMemRefRank() + 4) 1264 return emitOpError() << "expected at least " << getSrcMemRefRank() + 4 1265 << " operands"; 1266 if (!getSrcIndices().empty() && 1267 !llvm::all_of(getSrcIndices().getTypes(), 1268 [](Type t) { return t.isIndex(); })) 1269 return emitOpError("expected source indices to be of index type"); 1270 1271 // 2. Destination memref. 1272 if (!llvm::isa<MemRefType>(getDstMemRef().getType())) 1273 return emitOpError("expected destination to be of memref type"); 1274 unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4; 1275 if (numOperands < numExpectedOperands) 1276 return emitOpError() << "expected at least " << numExpectedOperands 1277 << " operands"; 1278 if (!getDstIndices().empty() && 1279 !llvm::all_of(getDstIndices().getTypes(), 1280 [](Type t) { return t.isIndex(); })) 1281 return emitOpError("expected destination indices to be of index type"); 1282 1283 // 3. Number of elements. 1284 if (!getNumElements().getType().isIndex()) 1285 return emitOpError("expected num elements to be of index type"); 1286 1287 // 4. Tag memref. 
1288 if (!llvm::isa<MemRefType>(getTagMemRef().getType())) 1289 return emitOpError("expected tag to be of memref type"); 1290 numExpectedOperands += getTagMemRefRank(); 1291 if (numOperands < numExpectedOperands) 1292 return emitOpError() << "expected at least " << numExpectedOperands 1293 << " operands"; 1294 if (!getTagIndices().empty() && 1295 !llvm::all_of(getTagIndices().getTypes(), 1296 [](Type t) { return t.isIndex(); })) 1297 return emitOpError("expected tag indices to be of index type"); 1298 1299 // Optional stride-related operands must be either both present or both 1300 // absent. 1301 if (numOperands != numExpectedOperands && 1302 numOperands != numExpectedOperands + 2) 1303 return emitOpError("incorrect number of operands"); 1304 1305 // 5. Strides. 1306 if (isStrided()) { 1307 if (!getStride().getType().isIndex() || 1308 !getNumElementsPerStride().getType().isIndex()) 1309 return emitOpError( 1310 "expected stride and num elements per stride to be of type index"); 1311 } 1312 1313 return success(); 1314 } 1315 1316 LogicalResult DmaStartOp::fold(FoldAdaptor adaptor, 1317 SmallVectorImpl<OpFoldResult> &results) { 1318 /// dma_start(memrefcast) -> dma_start 1319 return foldMemRefCast(*this); 1320 } 1321 1322 // --------------------------------------------------------------------------- 1323 // DmaWaitOp 1324 // --------------------------------------------------------------------------- 1325 1326 LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor, 1327 SmallVectorImpl<OpFoldResult> &results) { 1328 /// dma_wait(memrefcast) -> dma_wait 1329 return foldMemRefCast(*this); 1330 } 1331 1332 LogicalResult DmaWaitOp::verify() { 1333 // Check that the number of tag indices matches the tagMemRef rank. 1334 unsigned numTagIndices = getTagIndices().size(); 1335 unsigned tagMemRefRank = getTagMemRefRank(); 1336 if (numTagIndices != tagMemRefRank) 1337 return emitOpError() << "expected tagIndices to have the same number of " 1338 "elements as the tagMemRef rank, expected " 1339 << tagMemRefRank << ", but got " << numTagIndices; 1340 return success(); 1341 } 1342 1343 //===----------------------------------------------------------------------===// 1344 // ExtractAlignedPointerAsIndexOp 1345 //===----------------------------------------------------------------------===// 1346 1347 void ExtractAlignedPointerAsIndexOp::getAsmResultNames( 1348 function_ref<void(Value, StringRef)> setNameFn) { 1349 setNameFn(getResult(), "intptr"); 1350 } 1351 1352 //===----------------------------------------------------------------------===// 1353 // ExtractStridedMetadataOp 1354 //===----------------------------------------------------------------------===// 1355 1356 /// The number and type of the results are inferred from the 1357 /// shape of the source. 1358 LogicalResult ExtractStridedMetadataOp::inferReturnTypes( 1359 MLIRContext *context, std::optional<Location> location, 1360 ExtractStridedMetadataOp::Adaptor adaptor, 1361 SmallVectorImpl<Type> &inferredReturnTypes) { 1362 auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType()); 1363 if (!sourceType) 1364 return failure(); 1365 1366 unsigned sourceRank = sourceType.getRank(); 1367 IndexType indexType = IndexType::get(context); 1368 auto memrefType = 1369 MemRefType::get({}, sourceType.getElementType(), 1370 MemRefLayoutAttrInterface{}, sourceType.getMemorySpace()); 1371 // Base. 1372 inferredReturnTypes.push_back(memrefType); 1373 // Offset. 1374 inferredReturnTypes.push_back(indexType); 1375 // Sizes and strides. 
  for (unsigned i = 0; i < sourceRank * 2; ++i)
    inferredReturnTypes.push_back(indexType);
  return success();
}

void ExtractStridedMetadataOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getBaseBuffer(), "base_buffer");
  setNameFn(getOffset(), "offset");
  // For multi-result to work properly with pretty names and packed syntax `x:3`
  // we can only give a pretty name to the first value in the pack.
  if (!getSizes().empty()) {
    setNameFn(getSizes().front(), "sizes");
    setNameFn(getStrides().front(), "strides");
  }
}

/// Helper function to perform the replacement of all constant uses of `values`
/// by a materialized constant extracted from `maybeConstants`.
/// `values` and `maybeConstants` are expected to have the same size.
template <typename Container>
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
                                  Container values,
                                  ArrayRef<OpFoldResult> maybeConstants) {
  assert(values.size() == maybeConstants.size() &&
         " expected values and maybeConstants of the same size");
  bool atLeastOneReplacement = false;
  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
    // Don't materialize a constant if there are no uses: this would induce
    // infinite loops in the driver.
    if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
      continue;
    assert(isa<Attribute>(maybeConstant) &&
           "The constified value should be either unchanged (i.e., == result) "
           "or a constant");
    Value constantVal = rewriter.create<arith::ConstantIndexOp>(
        loc, llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
    for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
      // modifyOpInPlace: lambda cannot capture structured bindings in C++17
      // yet.
1416 op->replaceUsesOfWith(result, constantVal); 1417 atLeastOneReplacement = true; 1418 } 1419 } 1420 return atLeastOneReplacement; 1421 } 1422 1423 LogicalResult 1424 ExtractStridedMetadataOp::fold(FoldAdaptor adaptor, 1425 SmallVectorImpl<OpFoldResult> &results) { 1426 OpBuilder builder(*this); 1427 1428 bool atLeastOneReplacement = replaceConstantUsesOf( 1429 builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()), 1430 getConstifiedMixedOffset()); 1431 atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(), 1432 getConstifiedMixedSizes()); 1433 atLeastOneReplacement |= replaceConstantUsesOf( 1434 builder, getLoc(), getStrides(), getConstifiedMixedStrides()); 1435 1436 return success(atLeastOneReplacement); 1437 } 1438 1439 SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() { 1440 SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes()); 1441 constifyIndexValues(values, getSource().getType(), getContext(), 1442 getConstantSizes, ShapedType::isDynamic); 1443 return values; 1444 } 1445 1446 SmallVector<OpFoldResult> 1447 ExtractStridedMetadataOp::getConstifiedMixedStrides() { 1448 SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides()); 1449 constifyIndexValues(values, getSource().getType(), getContext(), 1450 getConstantStrides, ShapedType::isDynamic); 1451 return values; 1452 } 1453 1454 OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() { 1455 OpFoldResult offsetOfr = getAsOpFoldResult(getOffset()); 1456 SmallVector<OpFoldResult> values(1, offsetOfr); 1457 constifyIndexValues(values, getSource().getType(), getContext(), 1458 getConstantOffset, ShapedType::isDynamic); 1459 return values[0]; 1460 } 1461 1462 //===----------------------------------------------------------------------===// 1463 // GenericAtomicRMWOp 1464 //===----------------------------------------------------------------------===// 1465 1466 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result, 1467 Value memref, ValueRange ivs) { 1468 OpBuilder::InsertionGuard g(builder); 1469 result.addOperands(memref); 1470 result.addOperands(ivs); 1471 1472 if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) { 1473 Type elementType = memrefType.getElementType(); 1474 result.addTypes(elementType); 1475 1476 Region *bodyRegion = result.addRegion(); 1477 builder.createBlock(bodyRegion); 1478 bodyRegion->addArgument(elementType, memref.getLoc()); 1479 } 1480 } 1481 1482 LogicalResult GenericAtomicRMWOp::verify() { 1483 auto &body = getRegion(); 1484 if (body.getNumArguments() != 1) 1485 return emitOpError("expected single number of entry block arguments"); 1486 1487 if (getResult().getType() != body.getArgument(0).getType()) 1488 return emitOpError("expected block argument of the same type result type"); 1489 1490 bool hasSideEffects = 1491 body.walk([&](Operation *nestedOp) { 1492 if (isMemoryEffectFree(nestedOp)) 1493 return WalkResult::advance(); 1494 nestedOp->emitError( 1495 "body of 'memref.generic_atomic_rmw' should contain " 1496 "only operations with no side effects"); 1497 return WalkResult::interrupt(); 1498 }) 1499 .wasInterrupted(); 1500 return hasSideEffects ? 
failure() : success(); 1501 } 1502 1503 ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser, 1504 OperationState &result) { 1505 OpAsmParser::UnresolvedOperand memref; 1506 Type memrefType; 1507 SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs; 1508 1509 Type indexType = parser.getBuilder().getIndexType(); 1510 if (parser.parseOperand(memref) || 1511 parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) || 1512 parser.parseColonType(memrefType) || 1513 parser.resolveOperand(memref, memrefType, result.operands) || 1514 parser.resolveOperands(ivs, indexType, result.operands)) 1515 return failure(); 1516 1517 Region *body = result.addRegion(); 1518 if (parser.parseRegion(*body, {}) || 1519 parser.parseOptionalAttrDict(result.attributes)) 1520 return failure(); 1521 result.types.push_back(llvm::cast<MemRefType>(memrefType).getElementType()); 1522 return success(); 1523 } 1524 1525 void GenericAtomicRMWOp::print(OpAsmPrinter &p) { 1526 p << ' ' << getMemref() << "[" << getIndices() 1527 << "] : " << getMemref().getType() << ' '; 1528 p.printRegion(getRegion()); 1529 p.printOptionalAttrDict((*this)->getAttrs()); 1530 } 1531 1532 //===----------------------------------------------------------------------===// 1533 // AtomicYieldOp 1534 //===----------------------------------------------------------------------===// 1535 1536 LogicalResult AtomicYieldOp::verify() { 1537 Type parentType = (*this)->getParentOp()->getResultTypes().front(); 1538 Type resultType = getResult().getType(); 1539 if (parentType != resultType) 1540 return emitOpError() << "types mismatch between yield op: " << resultType 1541 << " and its parent: " << parentType; 1542 return success(); 1543 } 1544 1545 //===----------------------------------------------------------------------===// 1546 // GlobalOp 1547 //===----------------------------------------------------------------------===// 1548 1549 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, 1550 TypeAttr type, 1551 Attribute initialValue) { 1552 p << type; 1553 if (!op.isExternal()) { 1554 p << " = "; 1555 if (op.isUninitialized()) 1556 p << "uninitialized"; 1557 else 1558 p.printAttributeWithoutType(initialValue); 1559 } 1560 } 1561 1562 static ParseResult 1563 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, 1564 Attribute &initialValue) { 1565 Type type; 1566 if (parser.parseType(type)) 1567 return failure(); 1568 1569 auto memrefType = llvm::dyn_cast<MemRefType>(type); 1570 if (!memrefType || !memrefType.hasStaticShape()) 1571 return parser.emitError(parser.getNameLoc()) 1572 << "type should be static shaped memref, but got " << type; 1573 typeAttr = TypeAttr::get(type); 1574 1575 if (parser.parseOptionalEqual()) 1576 return success(); 1577 1578 if (succeeded(parser.parseOptionalKeyword("uninitialized"))) { 1579 initialValue = UnitAttr::get(parser.getContext()); 1580 return success(); 1581 } 1582 1583 Type tensorType = getTensorTypeFromMemRefType(memrefType); 1584 if (parser.parseAttribute(initialValue, tensorType)) 1585 return failure(); 1586 if (!llvm::isa<ElementsAttr>(initialValue)) 1587 return parser.emitError(parser.getNameLoc()) 1588 << "initial value should be a unit or elements attribute"; 1589 return success(); 1590 } 1591 1592 LogicalResult GlobalOp::verify() { 1593 auto memrefType = llvm::dyn_cast<MemRefType>(getType()); 1594 if (!memrefType || !memrefType.hasStaticShape()) 1595 return emitOpError("type should be static shaped memref, but got ") 1596 << getType(); 1597 1598 // 
Verify that the initial value, if present, is either a unit attribute or 1599 // an elements attribute. 1600 if (getInitialValue().has_value()) { 1601 Attribute initValue = getInitialValue().value(); 1602 if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue)) 1603 return emitOpError("initial value should be a unit or elements " 1604 "attribute, but got ") 1605 << initValue; 1606 1607 // Check that the type of the initial value is compatible with the type of 1608 // the global variable. 1609 if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) { 1610 Type initType = elementsAttr.getType(); 1611 Type tensorType = getTensorTypeFromMemRefType(memrefType); 1612 if (initType != tensorType) 1613 return emitOpError("initial value expected to be of type ") 1614 << tensorType << ", but was of type " << initType; 1615 } 1616 } 1617 1618 if (std::optional<uint64_t> alignAttr = getAlignment()) { 1619 uint64_t alignment = *alignAttr; 1620 1621 if (!llvm::isPowerOf2_64(alignment)) 1622 return emitError() << "alignment attribute value " << alignment 1623 << " is not a power of 2"; 1624 } 1625 1626 // TODO: verify visibility for declarations. 1627 return success(); 1628 } 1629 1630 ElementsAttr GlobalOp::getConstantInitValue() { 1631 auto initVal = getInitialValue(); 1632 if (getConstant() && initVal.has_value()) 1633 return llvm::cast<ElementsAttr>(initVal.value()); 1634 return {}; 1635 } 1636 1637 //===----------------------------------------------------------------------===// 1638 // GetGlobalOp 1639 //===----------------------------------------------------------------------===// 1640 1641 LogicalResult 1642 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1643 // Verify that the result type is same as the type of the referenced 1644 // memref.global op. 
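  // For instance (an illustrative sketch; `@gv` and `%0` are placeholder
  // names), given
  //   memref.global "private" @gv : memref<2xf32> = dense<0.0>
  // a well-formed use is
  //   %0 = memref.get_global @gv : memref<2xf32>
  // and any other result type is diagnosed below.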
1645 auto global = 1646 symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr()); 1647 if (!global) 1648 return emitOpError("'") 1649 << getName() << "' does not reference a valid global memref"; 1650 1651 Type resultType = getResult().getType(); 1652 if (global.getType() != resultType) 1653 return emitOpError("result type ") 1654 << resultType << " does not match type " << global.getType() 1655 << " of the global memref @" << getName(); 1656 return success(); 1657 } 1658 1659 //===----------------------------------------------------------------------===// 1660 // LoadOp 1661 //===----------------------------------------------------------------------===// 1662 1663 LogicalResult LoadOp::verify() { 1664 if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank()) { 1665 return emitOpError("incorrect number of indices for load, expected ") 1666 << getMemRefType().getRank() << " but got " << getIndices().size(); 1667 } 1668 return success(); 1669 } 1670 1671 OpFoldResult LoadOp::fold(FoldAdaptor adaptor) { 1672 /// load(memrefcast) -> load 1673 if (succeeded(foldMemRefCast(*this))) 1674 return getResult(); 1675 return OpFoldResult(); 1676 } 1677 1678 //===----------------------------------------------------------------------===// 1679 // MemorySpaceCastOp 1680 //===----------------------------------------------------------------------===// 1681 1682 void MemorySpaceCastOp::getAsmResultNames( 1683 function_ref<void(Value, StringRef)> setNameFn) { 1684 setNameFn(getResult(), "memspacecast"); 1685 } 1686 1687 bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1688 if (inputs.size() != 1 || outputs.size() != 1) 1689 return false; 1690 Type a = inputs.front(), b = outputs.front(); 1691 auto aT = llvm::dyn_cast<MemRefType>(a); 1692 auto bT = llvm::dyn_cast<MemRefType>(b); 1693 1694 auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a); 1695 auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b); 1696 1697 if (aT && bT) { 1698 if (aT.getElementType() != bT.getElementType()) 1699 return false; 1700 if (aT.getLayout() != bT.getLayout()) 1701 return false; 1702 if (aT.getShape() != bT.getShape()) 1703 return false; 1704 return true; 1705 } 1706 if (uaT && ubT) { 1707 return uaT.getElementType() == ubT.getElementType(); 1708 } 1709 return false; 1710 } 1711 1712 OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) { 1713 // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v, 1714 // t2) 1715 if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) { 1716 getSourceMutable().assign(parentCast.getSource()); 1717 return getResult(); 1718 } 1719 return Value{}; 1720 } 1721 1722 //===----------------------------------------------------------------------===// 1723 // PrefetchOp 1724 //===----------------------------------------------------------------------===// 1725 1726 void PrefetchOp::print(OpAsmPrinter &p) { 1727 p << " " << getMemref() << '['; 1728 p.printOperands(getIndices()); 1729 p << ']' << ", " << (getIsWrite() ? "write" : "read"); 1730 p << ", locality<" << getLocalityHint(); 1731 p << ">, " << (getIsDataCache() ? 
"data" : "instr"); 1732 p.printOptionalAttrDict( 1733 (*this)->getAttrs(), 1734 /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"}); 1735 p << " : " << getMemRefType(); 1736 } 1737 1738 ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) { 1739 OpAsmParser::UnresolvedOperand memrefInfo; 1740 SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo; 1741 IntegerAttr localityHint; 1742 MemRefType type; 1743 StringRef readOrWrite, cacheType; 1744 1745 auto indexTy = parser.getBuilder().getIndexType(); 1746 auto i32Type = parser.getBuilder().getIntegerType(32); 1747 if (parser.parseOperand(memrefInfo) || 1748 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || 1749 parser.parseComma() || parser.parseKeyword(&readOrWrite) || 1750 parser.parseComma() || parser.parseKeyword("locality") || 1751 parser.parseLess() || 1752 parser.parseAttribute(localityHint, i32Type, "localityHint", 1753 result.attributes) || 1754 parser.parseGreater() || parser.parseComma() || 1755 parser.parseKeyword(&cacheType) || parser.parseColonType(type) || 1756 parser.resolveOperand(memrefInfo, type, result.operands) || 1757 parser.resolveOperands(indexInfo, indexTy, result.operands)) 1758 return failure(); 1759 1760 if (readOrWrite != "read" && readOrWrite != "write") 1761 return parser.emitError(parser.getNameLoc(), 1762 "rw specifier has to be 'read' or 'write'"); 1763 result.addAttribute(PrefetchOp::getIsWriteAttrStrName(), 1764 parser.getBuilder().getBoolAttr(readOrWrite == "write")); 1765 1766 if (cacheType != "data" && cacheType != "instr") 1767 return parser.emitError(parser.getNameLoc(), 1768 "cache type has to be 'data' or 'instr'"); 1769 1770 result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(), 1771 parser.getBuilder().getBoolAttr(cacheType == "data")); 1772 1773 return success(); 1774 } 1775 1776 LogicalResult PrefetchOp::verify() { 1777 if (getNumOperands() != 1 + getMemRefType().getRank()) 1778 return emitOpError("too few indices"); 1779 1780 return success(); 1781 } 1782 1783 LogicalResult PrefetchOp::fold(FoldAdaptor adaptor, 1784 SmallVectorImpl<OpFoldResult> &results) { 1785 // prefetch(memrefcast) -> prefetch 1786 return foldMemRefCast(*this); 1787 } 1788 1789 //===----------------------------------------------------------------------===// 1790 // RankOp 1791 //===----------------------------------------------------------------------===// 1792 1793 OpFoldResult RankOp::fold(FoldAdaptor adaptor) { 1794 // Constant fold rank when the rank of the operand is known. 1795 auto type = getOperand().getType(); 1796 auto shapedType = llvm::dyn_cast<ShapedType>(type); 1797 if (shapedType && shapedType.hasRank()) 1798 return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank()); 1799 return IntegerAttr(); 1800 } 1801 1802 //===----------------------------------------------------------------------===// 1803 // ReinterpretCastOp 1804 //===----------------------------------------------------------------------===// 1805 1806 void ReinterpretCastOp::getAsmResultNames( 1807 function_ref<void(Value, StringRef)> setNameFn) { 1808 setNameFn(getResult(), "reinterpret_cast"); 1809 } 1810 1811 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`, 1812 /// `staticSizes` and `staticStrides` are automatically filled with 1813 /// source-memref-rank sentinel values that encode dynamic entries. 
1814 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1815 MemRefType resultType, Value source, 1816 OpFoldResult offset, ArrayRef<OpFoldResult> sizes, 1817 ArrayRef<OpFoldResult> strides, 1818 ArrayRef<NamedAttribute> attrs) { 1819 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1820 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1821 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets); 1822 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes); 1823 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); 1824 result.addAttributes(attrs); 1825 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 1826 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets), 1827 b.getDenseI64ArrayAttr(staticSizes), 1828 b.getDenseI64ArrayAttr(staticStrides)); 1829 } 1830 1831 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1832 Value source, OpFoldResult offset, 1833 ArrayRef<OpFoldResult> sizes, 1834 ArrayRef<OpFoldResult> strides, 1835 ArrayRef<NamedAttribute> attrs) { 1836 auto sourceType = cast<BaseMemRefType>(source.getType()); 1837 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1838 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1839 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets); 1840 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes); 1841 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); 1842 auto stridedLayout = StridedLayoutAttr::get( 1843 b.getContext(), staticOffsets.front(), staticStrides); 1844 auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(), 1845 stridedLayout, sourceType.getMemorySpace()); 1846 build(b, result, resultType, source, offset, sizes, strides, attrs); 1847 } 1848 1849 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1850 MemRefType resultType, Value source, 1851 int64_t offset, ArrayRef<int64_t> sizes, 1852 ArrayRef<int64_t> strides, 1853 ArrayRef<NamedAttribute> attrs) { 1854 SmallVector<OpFoldResult> sizeValues = 1855 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1856 return b.getI64IntegerAttr(v); 1857 })); 1858 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1859 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1860 return b.getI64IntegerAttr(v); 1861 })); 1862 build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues, 1863 strideValues, attrs); 1864 } 1865 1866 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1867 MemRefType resultType, Value source, Value offset, 1868 ValueRange sizes, ValueRange strides, 1869 ArrayRef<NamedAttribute> attrs) { 1870 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1871 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 1872 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1873 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1874 build(b, result, resultType, source, offset, sizeValues, strideValues, attrs); 1875 } 1876 1877 // TODO: ponder whether we want to allow missing trailing sizes/strides that are 1878 // completed automatically, like we have for subview and extract_slice. 1879 LogicalResult ReinterpretCastOp::verify() { 1880 // The source and result memrefs should be in the same memory space. 
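  // The op under verification has the general form (an illustrative sketch;
  // `%base` and the other %-prefixed names are placeholders):
  //   memref.reinterpret_cast %base to offset: [%off], sizes: [%sz0, 10],
  //     strides: [%st0, 1]
  //     : memref<?x?xf32> to memref<?x10xf32, strided<[?, 1], offset: ?>>
  // The checks below match memory space, element type, and the static
  // offset/sizes/strides against the result type's layout.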
1881 auto srcType = llvm::cast<BaseMemRefType>(getSource().getType()); 1882 auto resultType = llvm::cast<MemRefType>(getType()); 1883 if (srcType.getMemorySpace() != resultType.getMemorySpace()) 1884 return emitError("different memory spaces specified for source type ") 1885 << srcType << " and result memref type " << resultType; 1886 if (srcType.getElementType() != resultType.getElementType()) 1887 return emitError("different element types specified for source type ") 1888 << srcType << " and result memref type " << resultType; 1889 1890 // Match sizes in result memref type and in static_sizes attribute. 1891 for (auto [idx, resultSize, expectedSize] : 1892 llvm::enumerate(resultType.getShape(), getStaticSizes())) { 1893 if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize) 1894 return emitError("expected result type with size = ") 1895 << (ShapedType::isDynamic(expectedSize) 1896 ? std::string("dynamic") 1897 : std::to_string(expectedSize)) 1898 << " instead of " << resultSize << " in dim = " << idx; 1899 } 1900 1901 // Match offset and strides in static_offset and static_strides attributes. If 1902 // result memref type has no affine map specified, this will assume an 1903 // identity layout. 1904 int64_t resultOffset; 1905 SmallVector<int64_t, 4> resultStrides; 1906 if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset))) 1907 return emitError("expected result type to have strided layout but found ") 1908 << resultType; 1909 1910 // Match offset in result memref type and in static_offsets attribute. 1911 int64_t expectedOffset = getStaticOffsets().front(); 1912 if (!ShapedType::isDynamic(resultOffset) && resultOffset != expectedOffset) 1913 return emitError("expected result type with offset = ") 1914 << (ShapedType::isDynamic(expectedOffset) 1915 ? std::string("dynamic") 1916 : std::to_string(expectedOffset)) 1917 << " instead of " << resultOffset; 1918 1919 // Match strides in result memref type and in static_strides attribute. 1920 for (auto [idx, resultStride, expectedStride] : 1921 llvm::enumerate(resultStrides, getStaticStrides())) { 1922 if (!ShapedType::isDynamic(resultStride) && resultStride != expectedStride) 1923 return emitError("expected result type with stride = ") 1924 << (ShapedType::isDynamic(expectedStride) 1925 ? std::string("dynamic") 1926 : std::to_string(expectedStride)) 1927 << " instead of " << resultStride << " in dim = " << idx; 1928 } 1929 1930 return success(); 1931 } 1932 1933 OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) { 1934 Value src = getSource(); 1935 auto getPrevSrc = [&]() -> Value { 1936 // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x). 1937 if (auto prev = src.getDefiningOp<ReinterpretCastOp>()) 1938 return prev.getSource(); 1939 1940 // reinterpret_cast(cast(x)) -> reinterpret_cast(x). 1941 if (auto prev = src.getDefiningOp<CastOp>()) 1942 return prev.getSource(); 1943 1944 // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets 1945 // are 0. 
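    // E.g. (an illustrative sketch; `%x` and `%sv` are placeholder names):
    //   %sv = memref.subview %x[0, 0] [4, 4] [1, 1]
    //         : memref<8x8xf32> to memref<4x4xf32, strided<[8, 1]>>
    //   ... = memref.reinterpret_cast %sv to ...
    // can reinterpret `%x` directly, since a zero-offset subview shares the
    // same base pointer.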
1946 if (auto prev = src.getDefiningOp<SubViewOp>()) 1947 if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) { 1948 return isConstantIntValue(val, 0); 1949 })) 1950 return prev.getSource(); 1951 1952 return nullptr; 1953 }; 1954 1955 if (auto prevSrc = getPrevSrc()) { 1956 getSourceMutable().assign(prevSrc); 1957 return getResult(); 1958 } 1959 1960 // reinterpret_cast(x) w/o offset/shape/stride changes -> x 1961 if (!ShapedType::isDynamicShape(getType().getShape()) && 1962 src.getType() == getType() && getStaticOffsets().front() == 0) { 1963 return src; 1964 } 1965 1966 return nullptr; 1967 } 1968 1969 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() { 1970 SmallVector<OpFoldResult> values = getMixedSizes(); 1971 constifyIndexValues(values, getType(), getContext(), getConstantSizes, 1972 ShapedType::isDynamic); 1973 return values; 1974 } 1975 1976 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() { 1977 SmallVector<OpFoldResult> values = getMixedStrides(); 1978 constifyIndexValues(values, getType(), getContext(), getConstantStrides, 1979 ShapedType::isDynamic); 1980 return values; 1981 } 1982 1983 OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() { 1984 SmallVector<OpFoldResult> values = getMixedOffsets(); 1985 assert(values.size() == 1 && 1986 "reinterpret_cast must have one and only one offset"); 1987 constifyIndexValues(values, getType(), getContext(), getConstantOffset, 1988 ShapedType::isDynamic); 1989 return values[0]; 1990 } 1991 1992 namespace { 1993 /// Replace the sequence: 1994 /// ``` 1995 /// base, offset, sizes, strides = extract_strided_metadata src 1996 /// dst = reinterpret_cast base to offset, sizes, strides 1997 /// ``` 1998 /// With 1999 /// 2000 /// ``` 2001 /// dst = memref.cast src 2002 /// ``` 2003 /// 2004 /// Note: The cast operation is only inserted when the type of dst and src 2005 /// are not the same. E.g., when going from <4xf32> to <?xf32>. 2006 /// 2007 /// This pattern also matches when the offset, sizes, and strides don't come 2008 /// directly from the `extract_strided_metadata`'s results but it can be 2009 /// statically proven that they would hold the same values. 2010 /// 2011 /// For instance, the following sequence would be replaced: 2012 /// ``` 2013 /// base, offset, sizes, strides = 2014 /// extract_strided_metadata memref : memref<3x4xty> 2015 /// dst = reinterpret_cast base to 0, [3, 4], strides 2016 /// ``` 2017 /// Because we know (thanks to the type of the input memref) that variable 2018 /// `offset` and `sizes` will respectively hold 0 and [3, 4]. 2019 /// 2020 /// Similarly, the following sequence would be replaced: 2021 /// ``` 2022 /// c0 = arith.constant 0 2023 /// c4 = arith.constant 4 2024 /// base, offset, sizes, strides = 2025 /// extract_strided_metadata memref : memref<3x4xty> 2026 /// dst = reinterpret_cast base to c0, [3, c4], strides 2027 /// ``` 2028 /// Because we know that `offset`and `c0` will hold 0 2029 /// and `c4` will hold 4. 
2030 struct ReinterpretCastOpExtractStridedMetadataFolder 2031 : public OpRewritePattern<ReinterpretCastOp> { 2032 public: 2033 using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern; 2034 2035 LogicalResult matchAndRewrite(ReinterpretCastOp op, 2036 PatternRewriter &rewriter) const override { 2037 auto extractStridedMetadata = 2038 op.getSource().getDefiningOp<ExtractStridedMetadataOp>(); 2039 if (!extractStridedMetadata) 2040 return failure(); 2041 // Check if the reinterpret cast reconstructs a memref with the exact same 2042 // properties as the extract strided metadata. 2043 2044 // First, check that the strides are the same. 2045 SmallVector<OpFoldResult> extractStridesOfr = 2046 extractStridedMetadata.getConstifiedMixedStrides(); 2047 SmallVector<OpFoldResult> reinterpretStridesOfr = 2048 op.getConstifiedMixedStrides(); 2049 if (extractStridesOfr.size() != reinterpretStridesOfr.size()) 2050 return failure(); 2051 2052 unsigned rank = op.getType().getRank(); 2053 for (unsigned i = 0; i < rank; ++i) { 2054 if (extractStridesOfr[i] != reinterpretStridesOfr[i]) 2055 return failure(); 2056 } 2057 2058 // Second, check the sizes. 2059 assert(extractStridedMetadata.getSizes().size() == 2060 op.getMixedSizes().size() && 2061 "Strides and sizes rank must match"); 2062 SmallVector<OpFoldResult> extractSizesOfr = 2063 extractStridedMetadata.getConstifiedMixedSizes(); 2064 SmallVector<OpFoldResult> reinterpretSizesOfr = 2065 op.getConstifiedMixedSizes(); 2066 for (unsigned i = 0; i < rank; ++i) { 2067 if (extractSizesOfr[i] != reinterpretSizesOfr[i]) 2068 return failure(); 2069 } 2070 // Finally, check the offset. 2071 assert(op.getMixedOffsets().size() == 1 && 2072 "reinterpret_cast with more than one offset should have been " 2073 "rejected by the verifier"); 2074 OpFoldResult extractOffsetOfr = 2075 extractStridedMetadata.getConstifiedMixedOffset(); 2076 OpFoldResult reinterpretOffsetOfr = op.getConstifiedMixedOffset(); 2077 if (extractOffsetOfr != reinterpretOffsetOfr) 2078 return failure(); 2079 2080 // At this point, we know that the back and forth between extract strided 2081 // metadata and reinterpret cast is a noop. However, the final type of the 2082 // reinterpret cast may not be exactly the same as the original memref. 2083 // E.g., it could be changing a dimension from static to dynamic. Check that 2084 // here and add a cast if necessary. 
2085     Type srcTy = extractStridedMetadata.getSource().getType();
2086     if (srcTy == op.getResult().getType())
2087       rewriter.replaceOp(op, extractStridedMetadata.getSource());
2088     else
2089       rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
2090                                           extractStridedMetadata.getSource());
2091 
2092     return success();
2093   }
2094 };
2095 } // namespace
2096 
2097 void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2098                                                     MLIRContext *context) {
2099   results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
2100 }
2101 
2102 //===----------------------------------------------------------------------===//
2103 // Reassociative reshape ops
2104 //===----------------------------------------------------------------------===//
2105 
2106 void CollapseShapeOp::getAsmResultNames(
2107     function_ref<void(Value, StringRef)> setNameFn) {
2108   setNameFn(getResult(), "collapse_shape");
2109 }
2110 
2111 void ExpandShapeOp::getAsmResultNames(
2112     function_ref<void(Value, StringRef)> setNameFn) {
2113   setNameFn(getResult(), "expand_shape");
2114 }
2115 
2116 LogicalResult ExpandShapeOp::reifyResultShapes(
2117     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
2118   reifiedResultShapes = {
2119       getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
2120   return success();
2121 }
2122 
2123 /// Helper function for verifying the shape of ExpandShapeOp and CollapseShapeOp
2124 /// result and operand. Layout maps are verified separately.
2125 ///
2126 /// If `allowMultipleDynamicDimsPerGroup` is set, multiple dynamic dimensions
2127 /// are allowed in a reassociation group.
2128 static LogicalResult
2129 verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
2130                      ArrayRef<int64_t> expandedShape,
2131                      ArrayRef<ReassociationIndices> reassociation,
2132                      bool allowMultipleDynamicDimsPerGroup) {
2133   // There must be one reassociation group per collapsed dimension.
2134   if (collapsedShape.size() != reassociation.size())
2135     return op->emitOpError("invalid number of reassociation groups: found ")
2136            << reassociation.size() << ", expected " << collapsedShape.size();
2137 
2138   // The next expected expanded dimension index (while iterating over
2139   // reassociation indices).
2140   int64_t nextDim = 0;
2141   for (const auto &it : llvm::enumerate(reassociation)) {
2142     ReassociationIndices group = it.value();
2143     int64_t collapsedDim = it.index();
2144 
2145     bool foundDynamic = false;
2146     for (int64_t expandedDim : group) {
2147       if (expandedDim != nextDim++)
2148         return op->emitOpError("reassociation indices must be contiguous");
2149 
2150       if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
2151         return op->emitOpError("reassociation index ")
2152                << expandedDim << " is out of bounds";
2153 
2154       // Check if there are multiple dynamic dims in a reassociation group.
2155       if (ShapedType::isDynamic(expandedShape[expandedDim])) {
2156         if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
2157           return op->emitOpError(
2158               "at most one dimension in a reassociation group may be dynamic");
2159         foundDynamic = true;
2160       }
2161     }
2162 
2163     // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
2164     if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
2165       return op->emitOpError("collapsed dim (")
2166              << collapsedDim
2167              << ") must be dynamic if and only if reassociation group is "
2168                 "dynamic";
2169 
2170     // If all dims in the reassociation group are static, the size of the
2171     // collapsed dim can be verified.
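    // E.g. (illustrative): with expandedShape [2, 3, 4] and reassociation
    // [[0, 1], [2]], the collapsed shape must be [6, 4] because 2 * 3 = 6.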
2172 if (!foundDynamic) { 2173 int64_t groupSize = 1; 2174 for (int64_t expandedDim : group) 2175 groupSize *= expandedShape[expandedDim]; 2176 if (groupSize != collapsedShape[collapsedDim]) 2177 return op->emitOpError("collapsed dim size (") 2178 << collapsedShape[collapsedDim] 2179 << ") must equal reassociation group size (" << groupSize << ")"; 2180 } 2181 } 2182 2183 if (collapsedShape.empty()) { 2184 // Rank 0: All expanded dimensions must be 1. 2185 for (int64_t d : expandedShape) 2186 if (d != 1) 2187 return op->emitOpError( 2188 "rank 0 memrefs can only be extended/collapsed with/from ones"); 2189 } else if (nextDim != static_cast<int64_t>(expandedShape.size())) { 2190 // Rank >= 1: Number of dimensions among all reassociation groups must match 2191 // the result memref rank. 2192 return op->emitOpError("expanded rank (") 2193 << expandedShape.size() 2194 << ") inconsistent with number of reassociation indices (" << nextDim 2195 << ")"; 2196 } 2197 2198 return success(); 2199 } 2200 2201 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() { 2202 return getSymbolLessAffineMaps(getReassociationExprs()); 2203 } 2204 2205 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() { 2206 return convertReassociationIndicesToExprs(getContext(), 2207 getReassociationIndices()); 2208 } 2209 2210 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() { 2211 return getSymbolLessAffineMaps(getReassociationExprs()); 2212 } 2213 2214 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() { 2215 return convertReassociationIndicesToExprs(getContext(), 2216 getReassociationIndices()); 2217 } 2218 2219 /// Compute the layout map after expanding a given source MemRef type with the 2220 /// specified reassociation indices. 2221 static FailureOr<StridedLayoutAttr> 2222 computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape, 2223 ArrayRef<ReassociationIndices> reassociation) { 2224 int64_t srcOffset; 2225 SmallVector<int64_t> srcStrides; 2226 if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset))) 2227 return failure(); 2228 assert(srcStrides.size() == reassociation.size() && "invalid reassociation"); 2229 2230 // 1-1 mapping between srcStrides and reassociation packs. 2231 // Each srcStride starts with the given value and gets expanded according to 2232 // the proper entries in resultShape. 2233 // Example: 2234 // srcStrides = [10000, 1 , 100 ], 2235 // reassociations = [ [0], [1], [2, 3, 4]], 2236 // resultSizes = [2, 5, 4, 3, 2] = [ [2], [5], [4, 3, 2]] 2237 // -> For the purpose of stride calculation, the useful sizes are: 2238 // [x, x, x, 3, 2] = [ [x], [x], [x, 3, 2]]. 2239 // resultStrides = [10000, 1, 600, 200, 100] 2240 // Note that a stride does not get expanded along the first entry of each 2241 // shape pack. 
2242 SmallVector<int64_t> reverseResultStrides; 2243 reverseResultStrides.reserve(resultShape.size()); 2244 unsigned shapeIndex = resultShape.size() - 1; 2245 for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) { 2246 ReassociationIndices reassoc = std::get<0>(it); 2247 int64_t currentStrideToExpand = std::get<1>(it); 2248 for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) { 2249 reverseResultStrides.push_back(currentStrideToExpand); 2250 currentStrideToExpand = 2251 (SaturatedInteger::wrap(currentStrideToExpand) * 2252 SaturatedInteger::wrap(resultShape[shapeIndex--])) 2253 .asInteger(); 2254 } 2255 } 2256 auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides)); 2257 resultStrides.resize(resultShape.size(), 1); 2258 return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides); 2259 } 2260 2261 FailureOr<MemRefType> ExpandShapeOp::computeExpandedType( 2262 MemRefType srcType, ArrayRef<int64_t> resultShape, 2263 ArrayRef<ReassociationIndices> reassociation) { 2264 if (srcType.getLayout().isIdentity()) { 2265 // If the source is contiguous (i.e., no layout map specified), so is the 2266 // result. 2267 MemRefLayoutAttrInterface layout; 2268 return MemRefType::get(resultShape, srcType.getElementType(), layout, 2269 srcType.getMemorySpace()); 2270 } 2271 2272 // Source may not be contiguous. Compute the layout map. 2273 FailureOr<StridedLayoutAttr> computedLayout = 2274 computeExpandedLayoutMap(srcType, resultShape, reassociation); 2275 if (failed(computedLayout)) 2276 return failure(); 2277 return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout, 2278 srcType.getMemorySpace()); 2279 } 2280 2281 FailureOr<SmallVector<OpFoldResult>> 2282 ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc, 2283 MemRefType expandedType, 2284 ArrayRef<ReassociationIndices> reassociation, 2285 ArrayRef<OpFoldResult> inputShape) { 2286 std::optional<SmallVector<OpFoldResult>> outputShape = 2287 inferExpandShapeOutputShape(b, loc, expandedType, reassociation, 2288 inputShape); 2289 if (!outputShape) 2290 return failure(); 2291 return *outputShape; 2292 } 2293 2294 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result, 2295 Type resultType, Value src, 2296 ArrayRef<ReassociationIndices> reassociation, 2297 ArrayRef<OpFoldResult> outputShape) { 2298 auto [staticOutputShape, dynamicOutputShape] = 2299 decomposeMixedValues(SmallVector<OpFoldResult>(outputShape)); 2300 build(builder, result, llvm::cast<MemRefType>(resultType), src, 2301 getReassociationIndicesAttribute(builder, reassociation), 2302 dynamicOutputShape, staticOutputShape); 2303 } 2304 2305 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result, 2306 Type resultType, Value src, 2307 ArrayRef<ReassociationIndices> reassociation) { 2308 SmallVector<OpFoldResult> inputShape = 2309 getMixedSizes(builder, result.location, src); 2310 MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType); 2311 FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape( 2312 builder, result.location, memrefResultTy, reassociation, inputShape); 2313 // Failure of this assertion usually indicates presence of multiple 2314 // dynamic dimensions in the same reassociation group. 
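  // For instance (illustrative), expanding memref<?xf32> to memref<?x?xf32>
  // with reassociation [[0, 1]] leaves two dynamic dims in one group, so the
  // output shape cannot be derived from the input sizes alone.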
2315 assert(succeeded(outputShape) && "unable to infer output shape"); 2316 build(builder, result, memrefResultTy, src, reassociation, *outputShape); 2317 } 2318 2319 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result, 2320 ArrayRef<int64_t> resultShape, Value src, 2321 ArrayRef<ReassociationIndices> reassociation) { 2322 // Only ranked memref source values are supported. 2323 auto srcType = llvm::cast<MemRefType>(src.getType()); 2324 FailureOr<MemRefType> resultType = 2325 ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation); 2326 // Failure of this assertion usually indicates a problem with the source 2327 // type, e.g., could not get strides/offset. 2328 assert(succeeded(resultType) && "could not compute layout"); 2329 build(builder, result, *resultType, src, reassociation); 2330 } 2331 2332 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result, 2333 ArrayRef<int64_t> resultShape, Value src, 2334 ArrayRef<ReassociationIndices> reassociation, 2335 ArrayRef<OpFoldResult> outputShape) { 2336 // Only ranked memref source values are supported. 2337 auto srcType = llvm::cast<MemRefType>(src.getType()); 2338 FailureOr<MemRefType> resultType = 2339 ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation); 2340 // Failure of this assertion usually indicates a problem with the source 2341 // type, e.g., could not get strides/offset. 2342 assert(succeeded(resultType) && "could not compute layout"); 2343 build(builder, result, *resultType, src, reassociation, outputShape); 2344 } 2345 2346 LogicalResult ExpandShapeOp::verify() { 2347 MemRefType srcType = getSrcType(); 2348 MemRefType resultType = getResultType(); 2349 2350 if (srcType.getRank() > resultType.getRank()) { 2351 auto r0 = srcType.getRank(); 2352 auto r1 = resultType.getRank(); 2353 return emitOpError("has source rank ") 2354 << r0 << " and result rank " << r1 << ". This is not an expansion (" 2355 << r0 << " > " << r1 << ")."; 2356 } 2357 2358 // Verify result shape. 2359 if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(), 2360 resultType.getShape(), 2361 getReassociationIndices(), 2362 /*allowMultipleDynamicDimsPerGroup=*/true))) 2363 return failure(); 2364 2365 // Compute expected result type (including layout map). 2366 FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType( 2367 srcType, resultType.getShape(), getReassociationIndices()); 2368 if (failed(expectedResultType)) 2369 return emitOpError("invalid source layout map"); 2370 2371 // Check actual result type. 2372 if (*expectedResultType != resultType) 2373 return emitOpError("expected expanded type to be ") 2374 << *expectedResultType << " but found " << resultType; 2375 2376 if ((int64_t)getStaticOutputShape().size() != resultType.getRank()) 2377 return emitOpError("expected number of static shape bounds to be equal to " 2378 "the output rank (") 2379 << resultType.getRank() << ") but found " 2380 << getStaticOutputShape().size() << " inputs instead"; 2381 2382 if ((int64_t)getOutputShape().size() != 2383 llvm::count(getStaticOutputShape(), ShapedType::kDynamic)) 2384 return emitOpError("mismatch in dynamic dims in output_shape and " 2385 "static_output_shape: static_output_shape has ") 2386 << llvm::count(getStaticOutputShape(), ShapedType::kDynamic) 2387 << " dynamic dims while output_shape has " << getOutputShape().size() 2388 << " values"; 2389 2390 // Verify if provided output shapes are in agreement with output type. 
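  // E.g. (illustrative): for a result type memref<2x?x5xf32>,
  // static_output_shape must read [2, dynamic, 5]; a value such as 4 at
  // position 0 would be rejected here.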
2391 DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr(); 2392 ArrayRef<int64_t> resShape = getResult().getType().getShape(); 2393 for (auto [pos, shape] : llvm::enumerate(resShape)) { 2394 if (!ShapedType::isDynamic(shape) && shape != staticOutputShapes[pos]) { 2395 return emitOpError("invalid output shape provided at pos ") << pos; 2396 } 2397 } 2398 2399 return success(); 2400 } 2401 2402 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, 2403 MLIRContext *context) { 2404 results.add< 2405 ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>, 2406 ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context); 2407 } 2408 2409 /// Compute the layout map after collapsing a given source MemRef type with the 2410 /// specified reassociation indices. 2411 /// 2412 /// Note: All collapsed dims in a reassociation group must be contiguous. It is 2413 /// not possible to check this by inspecting a MemRefType in the general case. 2414 /// If non-contiguity cannot be checked statically, the collapse is assumed to 2415 /// be valid (and thus accepted by this function) unless `strict = true`. 2416 static FailureOr<StridedLayoutAttr> 2417 computeCollapsedLayoutMap(MemRefType srcType, 2418 ArrayRef<ReassociationIndices> reassociation, 2419 bool strict = false) { 2420 int64_t srcOffset; 2421 SmallVector<int64_t> srcStrides; 2422 auto srcShape = srcType.getShape(); 2423 if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset))) 2424 return failure(); 2425 2426 // The result stride of a reassociation group is the stride of the last entry 2427 // of the reassociation. (TODO: Should be the minimum stride in the 2428 // reassociation because strides are not necessarily sorted. E.g., when using 2429 // memref.transpose.) Dimensions of size 1 should be skipped, because their 2430 // strides are meaningless and could have any arbitrary value. 2431 SmallVector<int64_t> resultStrides; 2432 resultStrides.reserve(reassociation.size()); 2433 for (const ReassociationIndices &reassoc : reassociation) { 2434 ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc); 2435 while (srcShape[ref.back()] == 1 && ref.size() > 1) 2436 ref = ref.drop_back(); 2437 if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) { 2438 resultStrides.push_back(srcStrides[ref.back()]); 2439 } else { 2440 // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so 2441 // the corresponding stride may have to be skipped. (See above comment.) 2442 // Therefore, the result stride cannot be statically determined and must 2443 // be dynamic. 2444 resultStrides.push_back(ShapedType::kDynamic); 2445 } 2446 } 2447 2448 // Validate that each reassociation group is contiguous. 2449 unsigned resultStrideIndex = resultStrides.size() - 1; 2450 for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) { 2451 auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front(); 2452 auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]); 2453 for (int64_t idx : llvm::reverse(trailingReassocs)) { 2454 stride = stride * SaturatedInteger::wrap(srcShape[idx]); 2455 2456 // Both source and result stride must have the same static value. In that 2457 // case, we can be sure, that the dimensions are collapsible (because they 2458 // are contiguous). 2459 // If `strict = false` (default during op verification), we accept cases 2460 // where one or both strides are dynamic. 
This is best effort: We reject 2461 // ops where obviously non-contiguous dims are collapsed, but accept ops 2462 // where we cannot be sure statically. Such ops may fail at runtime. See 2463 // the op documentation for details. 2464 auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]); 2465 if (strict && (stride.saturated || srcStride.saturated)) 2466 return failure(); 2467 2468 // Dimensions of size 1 should be skipped, because their strides are 2469 // meaningless and could have any arbitrary value. 2470 if (srcShape[idx - 1] == 1) 2471 continue; 2472 2473 if (!stride.saturated && !srcStride.saturated && stride != srcStride) 2474 return failure(); 2475 } 2476 } 2477 return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides); 2478 } 2479 2480 bool CollapseShapeOp::isGuaranteedCollapsible( 2481 MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) { 2482 // MemRefs with identity layout are always collapsible. 2483 if (srcType.getLayout().isIdentity()) 2484 return true; 2485 2486 return succeeded(computeCollapsedLayoutMap(srcType, reassociation, 2487 /*strict=*/true)); 2488 } 2489 2490 MemRefType CollapseShapeOp::computeCollapsedType( 2491 MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) { 2492 SmallVector<int64_t> resultShape; 2493 resultShape.reserve(reassociation.size()); 2494 for (const ReassociationIndices &group : reassociation) { 2495 auto groupSize = SaturatedInteger::wrap(1); 2496 for (int64_t srcDim : group) 2497 groupSize = 2498 groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim)); 2499 resultShape.push_back(groupSize.asInteger()); 2500 } 2501 2502 if (srcType.getLayout().isIdentity()) { 2503 // If the source is contiguous (i.e., no layout map specified), so is the 2504 // result. 2505 MemRefLayoutAttrInterface layout; 2506 return MemRefType::get(resultShape, srcType.getElementType(), layout, 2507 srcType.getMemorySpace()); 2508 } 2509 2510 // Source may not be fully contiguous. Compute the layout map. 2511 // Note: Dimensions that are collapsed into a single dim are assumed to be 2512 // contiguous. 2513 FailureOr<StridedLayoutAttr> computedLayout = 2514 computeCollapsedLayoutMap(srcType, reassociation); 2515 assert(succeeded(computedLayout) && 2516 "invalid source layout map or collapsing non-contiguous dims"); 2517 return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout, 2518 srcType.getMemorySpace()); 2519 } 2520 2521 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src, 2522 ArrayRef<ReassociationIndices> reassociation, 2523 ArrayRef<NamedAttribute> attrs) { 2524 auto srcType = llvm::cast<MemRefType>(src.getType()); 2525 MemRefType resultType = 2526 CollapseShapeOp::computeCollapsedType(srcType, reassociation); 2527 result.addAttribute(::mlir::getReassociationAttrName(), 2528 getReassociationIndicesAttribute(b, reassociation)); 2529 build(b, result, resultType, src, attrs); 2530 } 2531 2532 LogicalResult CollapseShapeOp::verify() { 2533 MemRefType srcType = getSrcType(); 2534 MemRefType resultType = getResultType(); 2535 2536 if (srcType.getRank() < resultType.getRank()) { 2537 auto r0 = srcType.getRank(); 2538 auto r1 = resultType.getRank(); 2539 return emitOpError("has source rank ") 2540 << r0 << " and result rank " << r1 << ". This is not a collapse (" 2541 << r0 << " < " << r1 << ")."; 2542 } 2543 2544 // Verify result shape. 
2545 if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(), 2546 srcType.getShape(), getReassociationIndices(), 2547 /*allowMultipleDynamicDimsPerGroup=*/true))) 2548 return failure(); 2549 2550 // Compute expected result type (including layout map). 2551 MemRefType expectedResultType; 2552 if (srcType.getLayout().isIdentity()) { 2553 // If the source is contiguous (i.e., no layout map specified), so is the 2554 // result. 2555 MemRefLayoutAttrInterface layout; 2556 expectedResultType = 2557 MemRefType::get(resultType.getShape(), srcType.getElementType(), layout, 2558 srcType.getMemorySpace()); 2559 } else { 2560 // Source may not be fully contiguous. Compute the layout map. 2561 // Note: Dimensions that are collapsed into a single dim are assumed to be 2562 // contiguous. 2563 FailureOr<StridedLayoutAttr> computedLayout = 2564 computeCollapsedLayoutMap(srcType, getReassociationIndices()); 2565 if (failed(computedLayout)) 2566 return emitOpError( 2567 "invalid source layout map or collapsing non-contiguous dims"); 2568 expectedResultType = 2569 MemRefType::get(resultType.getShape(), srcType.getElementType(), 2570 *computedLayout, srcType.getMemorySpace()); 2571 } 2572 2573 if (expectedResultType != resultType) 2574 return emitOpError("expected collapsed type to be ") 2575 << expectedResultType << " but found " << resultType; 2576 2577 return success(); 2578 } 2579 2580 struct CollapseShapeOpMemRefCastFolder 2581 : public OpRewritePattern<CollapseShapeOp> { 2582 public: 2583 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern; 2584 2585 LogicalResult matchAndRewrite(CollapseShapeOp op, 2586 PatternRewriter &rewriter) const override { 2587 auto cast = op.getOperand().getDefiningOp<CastOp>(); 2588 if (!cast) 2589 return failure(); 2590 2591 if (!CastOp::canFoldIntoConsumerOp(cast)) 2592 return failure(); 2593 2594 Type newResultType = CollapseShapeOp::computeCollapsedType( 2595 llvm::cast<MemRefType>(cast.getOperand().getType()), 2596 op.getReassociationIndices()); 2597 2598 if (newResultType == op.getResultType()) { 2599 rewriter.modifyOpInPlace( 2600 op, [&]() { op.getSrcMutable().assign(cast.getSource()); }); 2601 } else { 2602 Value newOp = rewriter.create<CollapseShapeOp>( 2603 op->getLoc(), cast.getSource(), op.getReassociationIndices()); 2604 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp); 2605 } 2606 return success(); 2607 } 2608 }; 2609 2610 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, 2611 MLIRContext *context) { 2612 results.add< 2613 ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>, 2614 ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp, 2615 memref::DimOp, MemRefType>, 2616 CollapseShapeOpMemRefCastFolder>(context); 2617 } 2618 2619 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) { 2620 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, 2621 adaptor.getOperands()); 2622 } 2623 2624 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) { 2625 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, 2626 adaptor.getOperands()); 2627 } 2628 2629 //===----------------------------------------------------------------------===// 2630 // ReshapeOp 2631 //===----------------------------------------------------------------------===// 2632 2633 void ReshapeOp::getAsmResultNames( 2634 function_ref<void(Value, StringRef)> setNameFn) { 2635 setNameFn(getResult(), "reshape"); 2636 } 2637 2638 LogicalResult ReshapeOp::verify() { 2639 Type operandType = 
getSource().getType(); 2640 Type resultType = getResult().getType(); 2641 2642 Type operandElementType = 2643 llvm::cast<ShapedType>(operandType).getElementType(); 2644 Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType(); 2645 if (operandElementType != resultElementType) 2646 return emitOpError("element types of source and destination memref " 2647 "types should be the same"); 2648 2649 if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType)) 2650 if (!operandMemRefType.getLayout().isIdentity()) 2651 return emitOpError("source memref type should have identity affine map"); 2652 2653 int64_t shapeSize = 2654 llvm::cast<MemRefType>(getShape().getType()).getDimSize(0); 2655 auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType); 2656 if (resultMemRefType) { 2657 if (!resultMemRefType.getLayout().isIdentity()) 2658 return emitOpError("result memref type should have identity affine map"); 2659 if (shapeSize == ShapedType::kDynamic) 2660 return emitOpError("cannot use shape operand with dynamic length to " 2661 "reshape to statically-ranked memref type"); 2662 if (shapeSize != resultMemRefType.getRank()) 2663 return emitOpError( 2664 "length of shape operand differs from the result's memref rank"); 2665 } 2666 return success(); 2667 } 2668 2669 //===----------------------------------------------------------------------===// 2670 // StoreOp 2671 //===----------------------------------------------------------------------===// 2672 2673 LogicalResult StoreOp::verify() { 2674 if (getNumOperands() != 2 + getMemRefType().getRank()) 2675 return emitOpError("store index operand count not equal to memref rank"); 2676 2677 return success(); 2678 } 2679 2680 LogicalResult StoreOp::fold(FoldAdaptor adaptor, 2681 SmallVectorImpl<OpFoldResult> &results) { 2682 /// store(memrefcast) -> store 2683 return foldMemRefCast(*this, getValueToStore()); 2684 } 2685 2686 //===----------------------------------------------------------------------===// 2687 // SubViewOp 2688 //===----------------------------------------------------------------------===// 2689 2690 void SubViewOp::getAsmResultNames( 2691 function_ref<void(Value, StringRef)> setNameFn) { 2692 setNameFn(getResult(), "subview"); 2693 } 2694 2695 /// A subview result type can be fully inferred from the source type and the 2696 /// static representation of offsets, sizes and strides. Special sentinels 2697 /// encode the dynamic case. 2698 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 2699 ArrayRef<int64_t> staticOffsets, 2700 ArrayRef<int64_t> staticSizes, 2701 ArrayRef<int64_t> staticStrides) { 2702 unsigned rank = sourceMemRefType.getRank(); 2703 (void)rank; 2704 assert(staticOffsets.size() == rank && "staticOffsets length mismatch"); 2705 assert(staticSizes.size() == rank && "staticSizes length mismatch"); 2706 assert(staticStrides.size() == rank && "staticStrides length mismatch"); 2707 2708 // Extract source offset and strides. 2709 auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset(); 2710 2711 // Compute target offset whose value is: 2712 // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`. 
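  // E.g. (illustrative): for a memref<8x16x4xf32> source (strides [64, 4, 1],
  // offset 0) and static offsets [3, 4, 2]:
  //   targetOffset = 0 + 3 * 64 + 4 * 4 + 2 * 1 = 210.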
2713 int64_t targetOffset = sourceOffset; 2714 for (auto it : llvm::zip(staticOffsets, sourceStrides)) { 2715 auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it); 2716 targetOffset = (SaturatedInteger::wrap(targetOffset) + 2717 SaturatedInteger::wrap(staticOffset) * 2718 SaturatedInteger::wrap(sourceStride)) 2719 .asInteger(); 2720 } 2721 2722 // Compute target stride whose value is: 2723 // `sourceStrides_i * staticStrides_i`. 2724 SmallVector<int64_t, 4> targetStrides; 2725 targetStrides.reserve(staticOffsets.size()); 2726 for (auto it : llvm::zip(sourceStrides, staticStrides)) { 2727 auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it); 2728 targetStrides.push_back((SaturatedInteger::wrap(sourceStride) * 2729 SaturatedInteger::wrap(staticStride)) 2730 .asInteger()); 2731 } 2732 2733 // The type is now known. 2734 return MemRefType::get(staticSizes, sourceMemRefType.getElementType(), 2735 StridedLayoutAttr::get(sourceMemRefType.getContext(), 2736 targetOffset, targetStrides), 2737 sourceMemRefType.getMemorySpace()); 2738 } 2739 2740 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 2741 ArrayRef<OpFoldResult> offsets, 2742 ArrayRef<OpFoldResult> sizes, 2743 ArrayRef<OpFoldResult> strides) { 2744 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 2745 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 2746 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); 2747 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes); 2748 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); 2749 if (!hasValidSizesOffsets(staticOffsets)) 2750 return {}; 2751 if (!hasValidSizesOffsets(staticSizes)) 2752 return {}; 2753 if (!hasValidStrides(staticStrides)) 2754 return {}; 2755 return SubViewOp::inferResultType(sourceMemRefType, staticOffsets, 2756 staticSizes, staticStrides); 2757 } 2758 2759 Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape, 2760 MemRefType sourceRankedTensorType, 2761 ArrayRef<int64_t> offsets, 2762 ArrayRef<int64_t> sizes, 2763 ArrayRef<int64_t> strides) { 2764 auto inferredType = llvm::cast<MemRefType>( 2765 inferResultType(sourceRankedTensorType, offsets, sizes, strides)); 2766 assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) && 2767 "expected "); 2768 if (inferredType.getRank() == static_cast<int64_t>(resultShape.size())) 2769 return inferredType; 2770 2771 // Compute which dimensions are dropped. 2772 std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject = 2773 computeRankReductionMask(inferredType.getShape(), resultShape); 2774 assert(dimsToProject.has_value() && "invalid rank reduction"); 2775 2776 // Compute the layout and result type. 
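  // E.g. (illustrative): reducing an inferred
  // memref<1x4xf32, strided<[64, 1], offset: 210>> to resultShape [4] drops
  // the unit dim and keeps its offset and remaining stride, yielding
  // memref<4xf32, strided<[1], offset: 210>>.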
2777 auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout()); 2778 SmallVector<int64_t> rankReducedStrides; 2779 rankReducedStrides.reserve(resultShape.size()); 2780 for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) { 2781 if (!dimsToProject->contains(idx)) 2782 rankReducedStrides.push_back(value); 2783 } 2784 return MemRefType::get(resultShape, inferredType.getElementType(), 2785 StridedLayoutAttr::get(inferredLayout.getContext(), 2786 inferredLayout.getOffset(), 2787 rankReducedStrides), 2788 inferredType.getMemorySpace()); 2789 } 2790 2791 Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape, 2792 MemRefType sourceRankedTensorType, 2793 ArrayRef<OpFoldResult> offsets, 2794 ArrayRef<OpFoldResult> sizes, 2795 ArrayRef<OpFoldResult> strides) { 2796 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 2797 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 2798 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); 2799 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes); 2800 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); 2801 return SubViewOp::inferRankReducedResultType( 2802 resultShape, sourceRankedTensorType, staticOffsets, staticSizes, 2803 staticStrides); 2804 } 2805 2806 // Build a SubViewOp with mixed static and dynamic entries and custom result 2807 // type. If the type passed is nullptr, it is inferred. 2808 void SubViewOp::build(OpBuilder &b, OperationState &result, 2809 MemRefType resultType, Value source, 2810 ArrayRef<OpFoldResult> offsets, 2811 ArrayRef<OpFoldResult> sizes, 2812 ArrayRef<OpFoldResult> strides, 2813 ArrayRef<NamedAttribute> attrs) { 2814 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 2815 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 2816 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); 2817 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes); 2818 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); 2819 auto sourceMemRefType = llvm::cast<MemRefType>(source.getType()); 2820 // Structuring implementation this way avoids duplication between builders. 2821 if (!resultType) { 2822 resultType = llvm::cast<MemRefType>(SubViewOp::inferResultType( 2823 sourceMemRefType, staticOffsets, staticSizes, staticStrides)); 2824 } 2825 result.addAttributes(attrs); 2826 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 2827 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets), 2828 b.getDenseI64ArrayAttr(staticSizes), 2829 b.getDenseI64ArrayAttr(staticStrides)); 2830 } 2831 2832 // Build a SubViewOp with mixed static and dynamic entries and inferred result 2833 // type. 2834 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 2835 ArrayRef<OpFoldResult> offsets, 2836 ArrayRef<OpFoldResult> sizes, 2837 ArrayRef<OpFoldResult> strides, 2838 ArrayRef<NamedAttribute> attrs) { 2839 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs); 2840 } 2841 2842 // Build a SubViewOp with static entries and inferred result type. 
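// For reference (an illustrative sketch; `%0` and `%1` are placeholder names),
// offsets [0, 4], sizes [4, 8] and strides [1, 1] on a memref<8x16xf32> source
// infer:
//   %1 = memref.subview %0[0, 4] [4, 8] [1, 1]
//        : memref<8x16xf32> to memref<4x8xf32, strided<[16, 1], offset: 4>>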
2843 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 2844 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, 2845 ArrayRef<int64_t> strides, 2846 ArrayRef<NamedAttribute> attrs) { 2847 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 2848 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult { 2849 return b.getI64IntegerAttr(v); 2850 })); 2851 SmallVector<OpFoldResult> sizeValues = 2852 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 2853 return b.getI64IntegerAttr(v); 2854 })); 2855 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 2856 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 2857 return b.getI64IntegerAttr(v); 2858 })); 2859 build(b, result, source, offsetValues, sizeValues, strideValues, attrs); 2860 } 2861 2862 // Build a SubViewOp with dynamic entries and custom result type. If the 2863 // type passed is nullptr, it is inferred. 2864 void SubViewOp::build(OpBuilder &b, OperationState &result, 2865 MemRefType resultType, Value source, 2866 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, 2867 ArrayRef<int64_t> strides, 2868 ArrayRef<NamedAttribute> attrs) { 2869 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 2870 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult { 2871 return b.getI64IntegerAttr(v); 2872 })); 2873 SmallVector<OpFoldResult> sizeValues = 2874 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 2875 return b.getI64IntegerAttr(v); 2876 })); 2877 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 2878 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 2879 return b.getI64IntegerAttr(v); 2880 })); 2881 build(b, result, resultType, source, offsetValues, sizeValues, strideValues, 2882 attrs); 2883 } 2884 2885 // Build a SubViewOp with dynamic entries and custom result type. If the type 2886 // passed is nullptr, it is inferred. 2887 void SubViewOp::build(OpBuilder &b, OperationState &result, 2888 MemRefType resultType, Value source, ValueRange offsets, 2889 ValueRange sizes, ValueRange strides, 2890 ArrayRef<NamedAttribute> attrs) { 2891 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 2892 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; })); 2893 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 2894 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 2895 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 2896 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 2897 build(b, result, resultType, source, offsetValues, sizeValues, strideValues); 2898 } 2899 2900 // Build a SubViewOp with dynamic entries and inferred result type. 2901 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 2902 ValueRange offsets, ValueRange sizes, ValueRange strides, 2903 ArrayRef<NamedAttribute> attrs) { 2904 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs); 2905 } 2906 2907 /// For ViewLikeOpInterface. 2908 Value SubViewOp::getViewSource() { return getSource(); } 2909 2910 /// Return true if `t1` and `t2` have equal offsets (both dynamic or of same 2911 /// static value). 
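/// E.g. (illustrative), strided<[1], offset: 8> and strided<[2], offset: 8>
/// have compatible offsets, whereas a static offset of 8 and a dynamic offset
/// do not compare equal here.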
2912 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) { 2913 int64_t t1Offset, t2Offset; 2914 SmallVector<int64_t> t1Strides, t2Strides; 2915 auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset); 2916 auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset); 2917 return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset; 2918 } 2919 2920 /// Return true if `t1` and `t2` have equal strides (both dynamic or of same 2921 /// static value). Dimensions of `t1` may be dropped in `t2`; these must be 2922 /// marked as dropped in `droppedDims`. 2923 static bool haveCompatibleStrides(MemRefType t1, MemRefType t2, 2924 const llvm::SmallBitVector &droppedDims) { 2925 assert(size_t(t1.getRank()) == droppedDims.size() && 2926 "incorrect number of bits"); 2927 assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() && 2928 "incorrect number of dropped dims"); 2929 int64_t t1Offset, t2Offset; 2930 SmallVector<int64_t> t1Strides, t2Strides; 2931 auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset); 2932 auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset); 2933 if (failed(res1) || failed(res2)) 2934 return false; 2935 for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) { 2936 if (droppedDims[i]) 2937 continue; 2938 if (t1Strides[i] != t2Strides[j]) 2939 return false; 2940 ++j; 2941 } 2942 return true; 2943 } 2944 2945 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, 2946 Operation *op, Type expectedType) { 2947 auto memrefType = llvm::cast<ShapedType>(expectedType); 2948 switch (result) { 2949 case SliceVerificationResult::Success: 2950 return success(); 2951 case SliceVerificationResult::RankTooLarge: 2952 return op->emitError("expected result rank to be smaller or equal to ") 2953 << "the source rank. "; 2954 case SliceVerificationResult::SizeMismatch: 2955 return op->emitError("expected result type to be ") 2956 << expectedType 2957 << " or a rank-reduced version. (mismatch of result sizes) "; 2958 case SliceVerificationResult::ElemTypeMismatch: 2959 return op->emitError("expected result element type to be ") 2960 << memrefType.getElementType(); 2961 case SliceVerificationResult::MemSpaceMismatch: 2962 return op->emitError("expected result and source memory spaces to match."); 2963 case SliceVerificationResult::LayoutMismatch: 2964 return op->emitError("expected result type to be ") 2965 << expectedType 2966 << " or a rank-reduced version. (mismatch of result layout) "; 2967 } 2968 llvm_unreachable("unexpected subview verification result"); 2969 } 2970 2971 /// Verifier for SubViewOp. 2972 LogicalResult SubViewOp::verify() { 2973 MemRefType baseType = getSourceType(); 2974 MemRefType subViewType = getType(); 2975 2976 // The base memref and the view memref should be in the same memory space. 2977 if (baseType.getMemorySpace() != subViewType.getMemorySpace()) 2978 return emitError("different memory spaces specified for base memref " 2979 "type ") 2980 << baseType << " and subview memref type " << subViewType; 2981 2982 // Verify that the base memref type has a strided layout map. 2983 if (!baseType.isStrided()) 2984 return emitError("base type ") << baseType << " is not strided"; 2985 2986 // Compute the expected result type, assuming that there are no rank 2987 // reductions. 2988 auto expectedType = cast<MemRefType>(SubViewOp::inferResultType( 2989 baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides())); 2990 2991 // Verify all properties of a shaped type: rank, element type and dimension 2992 // sizes. 
This takes into account potential rank reductions. 2993 auto shapedTypeVerification = isRankReducedType( 2994 /*originalType=*/expectedType, /*candidateReducedType=*/subViewType); 2995 if (shapedTypeVerification != SliceVerificationResult::Success) 2996 return produceSubViewErrorMsg(shapedTypeVerification, *this, expectedType); 2997 2998 // Make sure that the memory space did not change. 2999 if (expectedType.getMemorySpace() != subViewType.getMemorySpace()) 3000 return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch, 3001 *this, expectedType); 3002 3003 // Verify the offset of the layout map. 3004 if (!haveCompatibleOffsets(expectedType, subViewType)) 3005 return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch, 3006 *this, expectedType); 3007 3008 // The only thing that's left to verify now are the strides. First, compute 3009 // the unused dimensions due to rank reductions. We have to look at sizes and 3010 // strides to decide which dimensions were dropped. This function also 3011 // partially verifies strides in case of rank reductions. 3012 auto unusedDims = computeMemRefRankReductionMask(expectedType, subViewType, 3013 getMixedSizes()); 3014 if (failed(unusedDims)) 3015 return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch, 3016 *this, expectedType); 3017 3018 // Strides must match. 3019 if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims)) 3020 return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch, 3021 *this, expectedType); 3022 3023 return success(); 3024 } 3025 3026 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) { 3027 return os << "range " << range.offset << ":" << range.size << ":" 3028 << range.stride; 3029 } 3030 3031 /// Return the list of Range (i.e. offset, size, stride). Each Range 3032 /// entry contains either the dynamic value or a ConstantIndexOp constructed 3033 /// with `b` at location `loc`. 3034 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op, 3035 OpBuilder &b, Location loc) { 3036 std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks(); 3037 assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks"); 3038 assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks"); 3039 SmallVector<Range, 8> res; 3040 unsigned rank = ranks[0]; 3041 res.reserve(rank); 3042 for (unsigned idx = 0; idx < rank; ++idx) { 3043 Value offset = 3044 op.isDynamicOffset(idx) 3045 ? op.getDynamicOffset(idx) 3046 : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx)); 3047 Value size = 3048 op.isDynamicSize(idx) 3049 ? op.getDynamicSize(idx) 3050 : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx)); 3051 Value stride = 3052 op.isDynamicStride(idx) 3053 ? op.getDynamicStride(idx) 3054 : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx)); 3055 res.emplace_back(Range{offset, size, stride}); 3056 } 3057 return res; 3058 } 3059 3060 /// Compute the canonical result type of a SubViewOp. Call `inferResultType` 3061 /// to deduce the result type for the given `sourceType`. Additionally, reduce 3062 /// the rank of the inferred result type if `currentResultType` is lower rank 3063 /// than `currentSourceType`. Use this signature if `sourceType` is updated 3064 /// together with the result type. In this case, it is important to compute 3065 /// the dropped dimensions using `currentSourceType` whose strides align with 3066 /// `currentResultType`. 
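///
/// For instance (an illustrative sketch): with `currentSourceType` =
/// memref<?x16xf32>, `currentResultType` =
/// memref<4xf32, strided<[1], offset: 36>> (the unit dimension was dropped),
/// a new `sourceType` = memref<8x16xf32>, offsets [2, 4], sizes [1, 4], and
/// strides [1, 1], the non-rank-reduced inferred type is
/// memref<1x4xf32, strided<[16, 1], offset: 36>>; dropping the unit dimension
/// again yields memref<4xf32, strided<[1], offset: 36>>.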
3067 static MemRefType getCanonicalSubViewResultType( 3068 MemRefType currentResultType, MemRefType currentSourceType, 3069 MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets, 3070 ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) { 3071 auto nonRankReducedType = llvm::cast<MemRefType>(SubViewOp::inferResultType( 3072 sourceType, mixedOffsets, mixedSizes, mixedStrides)); 3073 FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask( 3074 currentSourceType, currentResultType, mixedSizes); 3075 if (failed(unusedDims)) 3076 return nullptr; 3077 3078 auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout()); 3079 SmallVector<int64_t> shape, strides; 3080 unsigned numDimsAfterReduction = 3081 nonRankReducedType.getRank() - unusedDims->count(); 3082 shape.reserve(numDimsAfterReduction); 3083 strides.reserve(numDimsAfterReduction); 3084 for (const auto &[idx, size, stride] : 3085 llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()), 3086 nonRankReducedType.getShape(), layout.getStrides())) { 3087 if (unusedDims->test(idx)) 3088 continue; 3089 shape.push_back(size); 3090 strides.push_back(stride); 3091 } 3092 3093 return MemRefType::get(shape, nonRankReducedType.getElementType(), 3094 StridedLayoutAttr::get(sourceType.getContext(), 3095 layout.getOffset(), strides), 3096 nonRankReducedType.getMemorySpace()); 3097 } 3098 3099 Value mlir::memref::createCanonicalRankReducingSubViewOp( 3100 OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) { 3101 auto memrefType = llvm::cast<MemRefType>(memref.getType()); 3102 unsigned rank = memrefType.getRank(); 3103 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0)); 3104 SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref); 3105 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1)); 3106 auto targetType = 3107 llvm::cast<MemRefType>(SubViewOp::inferRankReducedResultType( 3108 targetShape, memrefType, offsets, sizes, strides)); 3109 return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets, 3110 sizes, strides); 3111 } 3112 3113 FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc, 3114 Value value, 3115 ArrayRef<int64_t> desiredShape) { 3116 auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType()); 3117 assert(sourceMemrefType && "not a ranked memref type"); 3118 auto sourceShape = sourceMemrefType.getShape(); 3119 if (sourceShape.equals(desiredShape)) 3120 return value; 3121 auto maybeRankReductionMask = 3122 mlir::computeRankReductionMask(sourceShape, desiredShape); 3123 if (!maybeRankReductionMask) 3124 return failure(); 3125 return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape); 3126 } 3127 3128 /// Helper method to check if a `subview` operation is trivially a no-op. This 3129 /// is the case if the all offsets are zero, all strides are 1, and the source 3130 /// shape is same as the size of the subview. In such cases, the subview can 3131 /// be folded into its source. 3132 static bool isTrivialSubViewOp(SubViewOp subViewOp) { 3133 if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank()) 3134 return false; 3135 3136 auto mixedOffsets = subViewOp.getMixedOffsets(); 3137 auto mixedSizes = subViewOp.getMixedSizes(); 3138 auto mixedStrides = subViewOp.getMixedStrides(); 3139 3140 // Check offsets are zero. 
3141   if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
3142         std::optional<int64_t> intValue = getConstantIntValue(ofr);
3143         return !intValue || intValue.value() != 0;
3144       }))
3145     return false;
3146
3147   // Check strides are one.
3148   if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
3149         std::optional<int64_t> intValue = getConstantIntValue(ofr);
3150         return !intValue || intValue.value() != 1;
3151       }))
3152     return false;
3153
3154   // Check that all size values are static and match the (static) source shape.
3155   ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
3156   for (const auto &size : llvm::enumerate(mixedSizes)) {
3157     std::optional<int64_t> intValue = getConstantIntValue(size.value());
3158     if (!intValue || *intValue != sourceShape[size.index()])
3159       return false;
3160   }
3161   // All conditions met. The `SubViewOp` is foldable as a no-op.
3162   return true;
3163 }
3164
3165 namespace {
3166 /// Pattern to rewrite a subview op with MemRefCast arguments.
3167 /// This essentially pushes memref.cast past its consuming subview when
3168 /// `canFoldIntoConsumerOp` is true.
3169 ///
3170 /// Example:
3171 /// ```
3172 /// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
3173 /// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
3174 ///        memref<?x?xf32> to memref<3x4xf32, strided<[?, 1], offset: ?>>
3175 /// ```
3176 /// is rewritten into:
3177 /// ```
3178 /// %0 = memref.subview %V[0, 0][3, 4][1, 1] : memref<16x16xf32> to memref<3x4xf32, strided<[16, 1], offset: 0>>
3179 /// %1 = memref.cast %0 : memref<3x4xf32, strided<[16, 1], offset: 0>> to
3180 ///        memref<3x4xf32, strided<[?, 1], offset: ?>>
3181 /// ```
3182 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
3183 public:
3184   using OpRewritePattern<SubViewOp>::OpRewritePattern;
3185
3186   LogicalResult matchAndRewrite(SubViewOp subViewOp,
3187                                 PatternRewriter &rewriter) const override {
3188     // If any operand is a constant, just return to let SubViewOpConstantFolder
3189     // kick in.
3190     if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
3191           return matchPattern(operand, matchConstantIndex());
3192         }))
3193       return failure();
3194
3195     auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
3196     if (!castOp)
3197       return failure();
3198
3199     if (!CastOp::canFoldIntoConsumerOp(castOp))
3200       return failure();
3201
3202     // Compute the SubViewOp result type after folding the MemRefCastOp. Use
3203     // the MemRefCastOp source operand type to infer the result type and the
3204     // current SubViewOp source operand type to compute the dropped dimensions
3205     // if the operation is rank-reducing.
3206     auto resultType = getCanonicalSubViewResultType(
3207         subViewOp.getType(), subViewOp.getSourceType(),
3208         llvm::cast<MemRefType>(castOp.getSource().getType()),
3209         subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
3210         subViewOp.getMixedStrides());
3211     if (!resultType)
3212       return failure();
3213
3214     Value newSubView = rewriter.create<SubViewOp>(
3215         subViewOp.getLoc(), resultType, castOp.getSource(),
3216         subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
3217         subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
3218         subViewOp.getStaticStrides());
3219     rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3220                                         newSubView);
3221     return success();
3222   }
3223 };
3224
3225 /// Canonicalize subview ops that are no-ops. If the result type differs from
3226 /// the source type only in its layout, the subview is replaced by a cast.
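///
/// Example (illustrative):
/// ```
/// %0 = memref.subview %arg[0, 0] [4, 4] [1, 1]
///        : memref<4x4xf32> to memref<4x4xf32, strided<[4, 1]>>
/// ```
/// is replaced by a `memref.cast` of %arg to the subview's result type, or by
/// %arg directly when the two types already match.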
3227 class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> { 3228 public: 3229 using OpRewritePattern<SubViewOp>::OpRewritePattern; 3230 3231 LogicalResult matchAndRewrite(SubViewOp subViewOp, 3232 PatternRewriter &rewriter) const override { 3233 if (!isTrivialSubViewOp(subViewOp)) 3234 return failure(); 3235 if (subViewOp.getSourceType() == subViewOp.getType()) { 3236 rewriter.replaceOp(subViewOp, subViewOp.getSource()); 3237 return success(); 3238 } 3239 rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(), 3240 subViewOp.getSource()); 3241 return success(); 3242 } 3243 }; 3244 } // namespace 3245 3246 /// Return the canonical type of the result of a subview. 3247 struct SubViewReturnTypeCanonicalizer { 3248 MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets, 3249 ArrayRef<OpFoldResult> mixedSizes, 3250 ArrayRef<OpFoldResult> mixedStrides) { 3251 // Infer a memref type without taking into account any rank reductions. 3252 auto resTy = SubViewOp::inferResultType(op.getSourceType(), mixedOffsets, 3253 mixedSizes, mixedStrides); 3254 if (!resTy) 3255 return {}; 3256 MemRefType nonReducedType = cast<MemRefType>(resTy); 3257 3258 // Directly return the non-rank reduced type if there are no dropped dims. 3259 llvm::SmallBitVector droppedDims = op.getDroppedDims(); 3260 if (droppedDims.none()) 3261 return nonReducedType; 3262 3263 // Take the strides and offset from the non-rank reduced type. 3264 auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset(); 3265 3266 // Drop dims from shape and strides. 3267 SmallVector<int64_t> targetShape; 3268 SmallVector<int64_t> targetStrides; 3269 for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) { 3270 if (droppedDims.test(i)) 3271 continue; 3272 targetStrides.push_back(nonReducedStrides[i]); 3273 targetShape.push_back(nonReducedType.getDimSize(i)); 3274 } 3275 3276 return MemRefType::get(targetShape, nonReducedType.getElementType(), 3277 StridedLayoutAttr::get(nonReducedType.getContext(), 3278 offset, targetStrides), 3279 nonReducedType.getMemorySpace()); 3280 } 3281 }; 3282 3283 /// A canonicalizer wrapper to replace SubViewOps. 3284 struct SubViewCanonicalizer { 3285 void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) { 3286 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp); 3287 } 3288 }; 3289 3290 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results, 3291 MLIRContext *context) { 3292 results 3293 .add<OpWithOffsetSizesAndStridesConstantArgumentFolder< 3294 SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>, 3295 SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context); 3296 } 3297 3298 OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) { 3299 MemRefType sourceMemrefType = getSource().getType(); 3300 MemRefType resultMemrefType = getResult().getType(); 3301 auto resultLayout = 3302 dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout()); 3303 3304 if (resultMemrefType == sourceMemrefType && 3305 resultMemrefType.hasStaticShape() && 3306 (!resultLayout || resultLayout.hasStaticLayout())) { 3307 return getViewSource(); 3308 } 3309 3310 // Fold subview(subview(x)), where both subviews have the same size and the 3311 // second subview's offsets are all zero. (I.e., the second subview is a 3312 // no-op.) 
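  // Illustrative IR for this case (hypothetical types):
  //   %0 = memref.subview %src[%i, %j] [4, 4] [1, 1]
  //       : memref<?x?xf32> to memref<4x4xf32, strided<[?, 1], offset: ?>>
  //   %1 = memref.subview %0[0, 0] [4, 4] [1, 1]
  //       : memref<4x4xf32, strided<[?, 1], offset: ?>>
  //         to memref<4x4xf32, strided<[?, 1], offset: ?>>
  // Here %1 folds to %0.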
3313 if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) { 3314 auto srcSizes = srcSubview.getMixedSizes(); 3315 auto sizes = getMixedSizes(); 3316 auto offsets = getMixedOffsets(); 3317 bool allOffsetsZero = llvm::all_of( 3318 offsets, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }); 3319 auto strides = getMixedStrides(); 3320 bool allStridesOne = llvm::all_of( 3321 strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); }); 3322 bool allSizesSame = llvm::equal(sizes, srcSizes); 3323 if (allOffsetsZero && allStridesOne && allSizesSame && 3324 resultMemrefType == sourceMemrefType) 3325 return getViewSource(); 3326 } 3327 3328 return {}; 3329 } 3330 3331 //===----------------------------------------------------------------------===// 3332 // TransposeOp 3333 //===----------------------------------------------------------------------===// 3334 3335 void TransposeOp::getAsmResultNames( 3336 function_ref<void(Value, StringRef)> setNameFn) { 3337 setNameFn(getResult(), "transpose"); 3338 } 3339 3340 /// Build a strided memref type by applying `permutationMap` to `memRefType`. 3341 static MemRefType inferTransposeResultType(MemRefType memRefType, 3342 AffineMap permutationMap) { 3343 auto originalSizes = memRefType.getShape(); 3344 auto [originalStrides, offset] = memRefType.getStridesAndOffset(); 3345 assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank())); 3346 3347 // Compute permuted sizes and strides. 3348 auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes); 3349 auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides); 3350 3351 return MemRefType::Builder(memRefType) 3352 .setShape(sizes) 3353 .setLayout( 3354 StridedLayoutAttr::get(memRefType.getContext(), offset, strides)); 3355 } 3356 3357 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in, 3358 AffineMapAttr permutation, 3359 ArrayRef<NamedAttribute> attrs) { 3360 auto permutationMap = permutation.getValue(); 3361 assert(permutationMap); 3362 3363 auto memRefType = llvm::cast<MemRefType>(in.getType()); 3364 // Compute result type. 
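  // E.g. (illustrative): with the permutation (d0, d1) -> (d1, d0), an input
  // of type memref<2x3xf32> (strides [3, 1]) yields the result type
  // memref<3x2xf32, strided<[1, 3]>>.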
3365 MemRefType resultType = inferTransposeResultType(memRefType, permutationMap); 3366 3367 result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation); 3368 build(b, result, resultType, in, attrs); 3369 } 3370 3371 // transpose $in $permutation attr-dict : type($in) `to` type(results) 3372 void TransposeOp::print(OpAsmPrinter &p) { 3373 p << " " << getIn() << " " << getPermutation(); 3374 p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()}); 3375 p << " : " << getIn().getType() << " to " << getType(); 3376 } 3377 3378 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) { 3379 OpAsmParser::UnresolvedOperand in; 3380 AffineMap permutation; 3381 MemRefType srcType, dstType; 3382 if (parser.parseOperand(in) || parser.parseAffineMap(permutation) || 3383 parser.parseOptionalAttrDict(result.attributes) || 3384 parser.parseColonType(srcType) || 3385 parser.resolveOperand(in, srcType, result.operands) || 3386 parser.parseKeywordType("to", dstType) || 3387 parser.addTypeToList(dstType, result.types)) 3388 return failure(); 3389 3390 result.addAttribute(TransposeOp::getPermutationAttrStrName(), 3391 AffineMapAttr::get(permutation)); 3392 return success(); 3393 } 3394 3395 LogicalResult TransposeOp::verify() { 3396 if (!getPermutation().isPermutation()) 3397 return emitOpError("expected a permutation map"); 3398 if (getPermutation().getNumDims() != getIn().getType().getRank()) 3399 return emitOpError("expected a permutation map of same rank as the input"); 3400 3401 auto srcType = llvm::cast<MemRefType>(getIn().getType()); 3402 auto resultType = llvm::cast<MemRefType>(getType()); 3403 auto canonicalResultType = inferTransposeResultType(srcType, getPermutation()) 3404 .canonicalizeStridedLayout(); 3405 3406 if (resultType.canonicalizeStridedLayout() != canonicalResultType) 3407 return emitOpError("result type ") 3408 << resultType 3409 << " is not equivalent to the canonical transposed input type " 3410 << canonicalResultType; 3411 return success(); 3412 } 3413 3414 OpFoldResult TransposeOp::fold(FoldAdaptor) { 3415 // First check for identity permutation, we can fold it away if input and 3416 // result types are identical already. 3417 if (getPermutation().isIdentity() && getType() == getIn().getType()) 3418 return getIn(); 3419 // Fold two consecutive memref.transpose Ops into one by composing their 3420 // permutation maps. 3421 if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) { 3422 AffineMap composedPermutation = 3423 getPermutation().compose(otherTransposeOp.getPermutation()); 3424 getInMutable().assign(otherTransposeOp.getIn()); 3425 setPermutation(composedPermutation); 3426 return getResult(); 3427 } 3428 return {}; 3429 } 3430 3431 //===----------------------------------------------------------------------===// 3432 // ViewOp 3433 //===----------------------------------------------------------------------===// 3434 3435 void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { 3436 setNameFn(getResult(), "view"); 3437 } 3438 3439 LogicalResult ViewOp::verify() { 3440 auto baseType = llvm::cast<MemRefType>(getOperand(0).getType()); 3441 auto viewType = getType(); 3442 3443 // The base memref should have identity layout map (or none). 3444 if (!baseType.getLayout().isIdentity()) 3445 return emitError("unsupported map for base memref type ") << baseType; 3446 3447 // The result memref should have identity layout map (or none). 
3448 if (!viewType.getLayout().isIdentity()) 3449 return emitError("unsupported map for result memref type ") << viewType; 3450 3451 // The base memref and the view memref should be in the same memory space. 3452 if (baseType.getMemorySpace() != viewType.getMemorySpace()) 3453 return emitError("different memory spaces specified for base memref " 3454 "type ") 3455 << baseType << " and view memref type " << viewType; 3456 3457 // Verify that we have the correct number of sizes for the result type. 3458 unsigned numDynamicDims = viewType.getNumDynamicDims(); 3459 if (getSizes().size() != numDynamicDims) 3460 return emitError("incorrect number of size operands for type ") << viewType; 3461 3462 return success(); 3463 } 3464 3465 Value ViewOp::getViewSource() { return getSource(); } 3466 3467 namespace { 3468 3469 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> { 3470 using OpRewritePattern<ViewOp>::OpRewritePattern; 3471 3472 LogicalResult matchAndRewrite(ViewOp viewOp, 3473 PatternRewriter &rewriter) const override { 3474 // Return if none of the operands are constants. 3475 if (llvm::none_of(viewOp.getOperands(), [](Value operand) { 3476 return matchPattern(operand, matchConstantIndex()); 3477 })) 3478 return failure(); 3479 3480 // Get result memref type. 3481 auto memrefType = viewOp.getType(); 3482 3483 // Get offset from old memref view type 'memRefType'. 3484 int64_t oldOffset; 3485 SmallVector<int64_t, 4> oldStrides; 3486 if (failed(memrefType.getStridesAndOffset(oldStrides, oldOffset))) 3487 return failure(); 3488 assert(oldOffset == 0 && "Expected 0 offset"); 3489 3490 SmallVector<Value, 4> newOperands; 3491 3492 // Offset cannot be folded into result type. 3493 3494 // Fold any dynamic dim operands which are produced by a constant. 3495 SmallVector<int64_t, 4> newShapeConstants; 3496 newShapeConstants.reserve(memrefType.getRank()); 3497 3498 unsigned dynamicDimPos = 0; 3499 unsigned rank = memrefType.getRank(); 3500 for (unsigned dim = 0, e = rank; dim < e; ++dim) { 3501 int64_t dimSize = memrefType.getDimSize(dim); 3502 // If this is already static dimension, keep it. 3503 if (!ShapedType::isDynamic(dimSize)) { 3504 newShapeConstants.push_back(dimSize); 3505 continue; 3506 } 3507 auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp(); 3508 if (auto constantIndexOp = 3509 dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) { 3510 // Dynamic shape dimension will be folded. 3511 newShapeConstants.push_back(constantIndexOp.value()); 3512 } else { 3513 // Dynamic shape dimension not folded; copy operand from old memref. 3514 newShapeConstants.push_back(dimSize); 3515 newOperands.push_back(viewOp.getSizes()[dynamicDimPos]); 3516 } 3517 dynamicDimPos++; 3518 } 3519 3520 // Create new memref type with constant folded dims. 3521 MemRefType newMemRefType = 3522 MemRefType::Builder(memrefType).setShape(newShapeConstants); 3523 // Nothing new, don't fold. 3524 if (newMemRefType == memrefType) 3525 return failure(); 3526 3527 // Create new ViewOp. 3528 auto newViewOp = rewriter.create<ViewOp>( 3529 viewOp.getLoc(), newMemRefType, viewOp.getOperand(0), 3530 viewOp.getByteShift(), newOperands); 3531 // Insert a cast so we have the same type as the old memref type. 
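    // (Illustrative: if the original view type was memref<?x4xf32> and all of
    // its dynamic sizes folded to constants, the new view might have type
    // memref<16x4xf32>; the cast converts memref<16x4xf32> back to
    // memref<?x4xf32> so existing users keep seeing the original type.)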
3532 rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp); 3533 return success(); 3534 } 3535 }; 3536 3537 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> { 3538 using OpRewritePattern<ViewOp>::OpRewritePattern; 3539 3540 LogicalResult matchAndRewrite(ViewOp viewOp, 3541 PatternRewriter &rewriter) const override { 3542 Value memrefOperand = viewOp.getOperand(0); 3543 CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>(); 3544 if (!memrefCastOp) 3545 return failure(); 3546 Value allocOperand = memrefCastOp.getOperand(); 3547 AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>(); 3548 if (!allocOp) 3549 return failure(); 3550 rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand, 3551 viewOp.getByteShift(), 3552 viewOp.getSizes()); 3553 return success(); 3554 } 3555 }; 3556 3557 } // namespace 3558 3559 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results, 3560 MLIRContext *context) { 3561 results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context); 3562 } 3563 3564 //===----------------------------------------------------------------------===// 3565 // AtomicRMWOp 3566 //===----------------------------------------------------------------------===// 3567 3568 LogicalResult AtomicRMWOp::verify() { 3569 if (getMemRefType().getRank() != getNumOperands() - 2) 3570 return emitOpError( 3571 "expects the number of subscripts to be equal to memref rank"); 3572 switch (getKind()) { 3573 case arith::AtomicRMWKind::addf: 3574 case arith::AtomicRMWKind::maximumf: 3575 case arith::AtomicRMWKind::minimumf: 3576 case arith::AtomicRMWKind::mulf: 3577 if (!llvm::isa<FloatType>(getValue().getType())) 3578 return emitOpError() << "with kind '" 3579 << arith::stringifyAtomicRMWKind(getKind()) 3580 << "' expects a floating-point type"; 3581 break; 3582 case arith::AtomicRMWKind::addi: 3583 case arith::AtomicRMWKind::maxs: 3584 case arith::AtomicRMWKind::maxu: 3585 case arith::AtomicRMWKind::mins: 3586 case arith::AtomicRMWKind::minu: 3587 case arith::AtomicRMWKind::muli: 3588 case arith::AtomicRMWKind::ori: 3589 case arith::AtomicRMWKind::andi: 3590 if (!llvm::isa<IntegerType>(getValue().getType())) 3591 return emitOpError() << "with kind '" 3592 << arith::stringifyAtomicRMWKind(getKind()) 3593 << "' expects an integer type"; 3594 break; 3595 default: 3596 break; 3597 } 3598 return success(); 3599 } 3600 3601 OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) { 3602 /// atomicrmw(memrefcast) -> atomicrmw 3603 if (succeeded(foldMemRefCast(*this, getValue()))) 3604 return getResult(); 3605 return OpFoldResult(); 3606 } 3607 3608 //===----------------------------------------------------------------------===// 3609 // TableGen'd op method definitions 3610 //===----------------------------------------------------------------------===// 3611 3612 #define GET_OP_CLASSES 3613 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc" 3614