1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "mlir/Dialect/Shape/IR/Shape.h" 12 13 #include "mlir/Dialect/Arith/IR/Arith.h" 14 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 15 #include "mlir/Dialect/CommonFolders.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/Dialect/Traits.h" 18 #include "mlir/Dialect/UB/IR/UBOps.h" 19 #include "mlir/IR/Builders.h" 20 #include "mlir/IR/BuiltinTypes.h" 21 #include "mlir/IR/DialectImplementation.h" 22 #include "mlir/IR/Matchers.h" 23 #include "mlir/IR/PatternMatch.h" 24 #include "mlir/IR/TypeUtilities.h" 25 #include "mlir/Interfaces/FunctionImplementation.h" 26 #include "mlir/Transforms/InliningUtils.h" 27 #include "llvm/ADT/SetOperations.h" 28 #include "llvm/ADT/SmallString.h" 29 #include "llvm/ADT/TypeSwitch.h" 30 #include "llvm/Support/raw_ostream.h" 31 32 using namespace mlir; 33 using namespace mlir::shape; 34 35 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc" 36 37 namespace { 38 #include "ShapeCanonicalization.inc" 39 } // namespace 40 41 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) { 42 return RankedTensorType::get({rank}, IndexType::get(ctx)); 43 } 44 45 bool shape::isExtentTensorType(Type type) { 46 auto ranked = llvm::dyn_cast<RankedTensorType>(type); 47 return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex(); 48 } 49 50 LogicalResult shape::getShapeVec(Value input, 51 SmallVectorImpl<int64_t> &shapeValues) { 52 if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) { 53 auto type = llvm::cast<ShapedType>(inputOp.getArg().getType()); 54 if (!type.hasRank()) 55 return failure(); 56 llvm::append_range(shapeValues, type.getShape()); 57 return success(); 58 } 59 DenseIntElementsAttr attr; 60 if (matchPattern(input, m_Constant(&attr))) { 61 llvm::append_range(shapeValues, attr.getValues<int64_t>()); 62 return success(); 63 } 64 return failure(); 65 } 66 67 static bool isErrorPropagationPossible(TypeRange operandTypes) { 68 return llvm::any_of(operandTypes, 69 llvm::IsaPred<SizeType, ShapeType, ValueShapeType>); 70 } 71 72 static LogicalResult verifySizeOrIndexOp(Operation *op) { 73 assert(op != nullptr && op->getNumResults() == 1); 74 Type resultTy = op->getResultTypes().front(); 75 if (isErrorPropagationPossible(op->getOperandTypes())) { 76 if (!llvm::isa<SizeType>(resultTy)) 77 return op->emitOpError() 78 << "if at least one of the operands can hold error values then " 79 "the result must be of type `size` to propagate them"; 80 } 81 return success(); 82 } 83 84 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) { 85 assert(op != nullptr && op->getNumResults() == 1); 86 Type resultTy = op->getResultTypes().front(); 87 if (isErrorPropagationPossible(op->getOperandTypes())) { 88 if (!llvm::isa<ShapeType>(resultTy)) 89 return op->emitOpError() 90 << "if at least one of the operands can hold error values then " 91 "the result must be of type `shape` to propagate them"; 92 } 93 return success(); 94 } 95 96 template <typename... Ty> 97 static bool eachHasOnlyOneOfTypes(TypeRange typeRange) { 98 return typeRange.size() == 1 && llvm::isa<Ty...>(typeRange.front()); 99 } 100 101 template <typename... Ty, typename... ranges> 102 static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) { 103 return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...); 104 } 105 106 //===----------------------------------------------------------------------===// 107 // InlinerInterface 108 //===----------------------------------------------------------------------===// 109 110 namespace { 111 /// This class defines the interface for inlining shape dialect ops. 112 struct ShapeInlinerInterface : public DialectInlinerInterface { 113 using DialectInlinerInterface::DialectInlinerInterface; 114 115 // Returns true if the given region 'src' can be inlined into the region 116 // 'dest' that is attached to an operation registered to the current dialect. 117 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, 118 IRMapping &) const final { 119 return true; 120 } 121 122 // Returns true if the given operation 'op', that is registered to this 123 // dialect, can be inlined into the region 'dest' that is attached to an 124 // operation registered to the current dialect. 125 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, 126 IRMapping &) const final { 127 return true; 128 } 129 }; 130 } // namespace 131 132 void ShapeDialect::initialize() { 133 addOperations< 134 #define GET_OP_LIST 135 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 136 >(); 137 addTypes< 138 #define GET_TYPEDEF_LIST 139 #include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc" 140 >(); 141 addInterfaces<ShapeInlinerInterface>(); 142 // Allow unknown operations during prototyping and testing. As the dialect is 143 // still evolving it makes it simple to start with an unregistered ops and 144 // try different variants before actually defining the op. 145 allowUnknownOperations(); 146 declarePromisedInterfaces<bufferization::BufferizableOpInterface, AssumingOp, 147 AssumingYieldOp>(); 148 } 149 150 Operation *ShapeDialect::materializeConstant(OpBuilder &builder, 151 Attribute value, Type type, 152 Location loc) { 153 if (auto poison = dyn_cast<ub::PoisonAttr>(value)) 154 return builder.create<ub::PoisonOp>(loc, type, poison); 155 156 if (llvm::isa<ShapeType>(type) || isExtentTensorType(type)) 157 return builder.create<ConstShapeOp>( 158 loc, type, llvm::cast<DenseIntElementsAttr>(value)); 159 if (llvm::isa<SizeType>(type)) 160 return builder.create<ConstSizeOp>(loc, type, 161 llvm::cast<IntegerAttr>(value)); 162 if (llvm::isa<WitnessType>(type)) 163 return builder.create<ConstWitnessOp>(loc, type, 164 llvm::cast<BoolAttr>(value)); 165 166 return arith::ConstantOp::materialize(builder, value, type, loc); 167 } 168 169 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op, 170 NamedAttribute attribute) { 171 // Verify shape.lib attribute. 172 if (attribute.getName() == "shape.lib") { 173 if (!op->hasTrait<OpTrait::SymbolTable>()) 174 return op->emitError( 175 "shape.lib attribute may only be on op implementing SymbolTable"); 176 177 if (auto symbolRef = llvm::dyn_cast<SymbolRefAttr>(attribute.getValue())) { 178 auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef); 179 if (!symbol) 180 return op->emitError("shape function library ") 181 << symbolRef << " not found"; 182 return isa<shape::FunctionLibraryOp>(symbol) 183 ? success() 184 : op->emitError() 185 << symbolRef << " required to be shape function library"; 186 } 187 188 if (auto arr = llvm::dyn_cast<ArrayAttr>(attribute.getValue())) { 189 // Verify all entries are function libraries and mappings in libraries 190 // refer to unique ops. 191 DenseSet<StringAttr> key; 192 for (auto it : arr) { 193 if (!llvm::isa<SymbolRefAttr>(it)) 194 return op->emitError( 195 "only SymbolRefAttr allowed in shape.lib attribute array"); 196 197 auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>( 198 SymbolTable::lookupSymbolIn(op, llvm::cast<SymbolRefAttr>(it))); 199 if (!shapeFnLib) 200 return op->emitError() 201 << it << " does not refer to FunctionLibraryOp"; 202 for (auto mapping : shapeFnLib.getMapping()) { 203 if (!key.insert(mapping.getName()).second) { 204 return op->emitError("only one op to shape mapping allowed, found " 205 "multiple for `") 206 << mapping.getName() << "`"; 207 } 208 } 209 } 210 return success(); 211 } 212 213 return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs " 214 "allowed as shape.lib attribute"); 215 } 216 return success(); 217 } 218 219 //===----------------------------------------------------------------------===// 220 // AnyOp 221 //===----------------------------------------------------------------------===// 222 223 // TODO: Canonicalization should be implemented for shapes that can be 224 // determined through mixtures of the known dimensions of the inputs. 225 OpFoldResult AnyOp::fold(FoldAdaptor adaptor) { 226 // Only the last operand is checked because AnyOp is commutative. 227 if (adaptor.getInputs().back()) 228 return adaptor.getInputs().back(); 229 230 return nullptr; 231 } 232 233 //===----------------------------------------------------------------------===// 234 // AssumingOp 235 //===----------------------------------------------------------------------===// 236 237 ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) { 238 result.regions.reserve(1); 239 Region *doRegion = result.addRegion(); 240 241 auto &builder = parser.getBuilder(); 242 OpAsmParser::UnresolvedOperand cond; 243 if (parser.parseOperand(cond) || 244 parser.resolveOperand(cond, builder.getType<WitnessType>(), 245 result.operands)) 246 return failure(); 247 248 // Parse optional results type list. 249 if (parser.parseOptionalArrowTypeList(result.types)) 250 return failure(); 251 252 // Parse the region and add a terminator if elided. 253 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{})) 254 return failure(); 255 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location); 256 257 // Parse the optional attribute list. 258 if (parser.parseOptionalAttrDict(result.attributes)) 259 return failure(); 260 return success(); 261 } 262 263 void AssumingOp::print(OpAsmPrinter &p) { 264 bool yieldsResults = !getResults().empty(); 265 266 p << " " << getWitness(); 267 if (yieldsResults) 268 p << " -> (" << getResultTypes() << ")"; 269 p << ' '; 270 p.printRegion(getDoRegion(), 271 /*printEntryBlockArgs=*/false, 272 /*printBlockTerminators=*/yieldsResults); 273 p.printOptionalAttrDict((*this)->getAttrs()); 274 } 275 276 namespace { 277 // Removes AssumingOp with a passing witness and inlines the region. 278 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> { 279 using OpRewritePattern<AssumingOp>::OpRewritePattern; 280 281 LogicalResult matchAndRewrite(AssumingOp op, 282 PatternRewriter &rewriter) const override { 283 auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>(); 284 if (!witness || !witness.getPassingAttr()) 285 return failure(); 286 287 AssumingOp::inlineRegionIntoParent(op, rewriter); 288 return success(); 289 } 290 }; 291 292 struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> { 293 using OpRewritePattern<AssumingOp>::OpRewritePattern; 294 295 LogicalResult matchAndRewrite(AssumingOp op, 296 PatternRewriter &rewriter) const override { 297 Block *body = op.getBody(); 298 auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator()); 299 300 // Find used values. 301 SmallVector<Value, 4> newYieldOperands; 302 for (auto [opResult, yieldOperand] : 303 llvm::zip(op.getResults(), yieldOp.getOperands())) { 304 if (!opResult.getUses().empty()) { 305 newYieldOperands.push_back(yieldOperand); 306 } 307 } 308 309 // Rewrite only if redundant results exist. 310 if (newYieldOperands.size() == yieldOp->getNumOperands()) 311 return failure(); 312 313 // Replace yield op in the old assuming op's body and move the entire region 314 // to the new assuming op. 315 rewriter.setInsertionPointToEnd(body); 316 auto newYieldOp = 317 rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands); 318 rewriter.setInsertionPoint(op); 319 auto newOp = rewriter.create<AssumingOp>( 320 op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness()); 321 newOp.getDoRegion().takeBody(op.getDoRegion()); 322 323 // Use the new results to replace the previously used ones. 324 SmallVector<Value, 4> replacementValues; 325 auto src = newOp.getResults().begin(); 326 for (auto it : op.getResults()) { 327 if (it.getUses().empty()) 328 replacementValues.push_back(nullptr); 329 else 330 replacementValues.push_back(*src++); 331 } 332 rewriter.replaceOp(op, replacementValues); 333 return success(); 334 } 335 }; 336 } // namespace 337 338 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 339 MLIRContext *context) { 340 patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context); 341 } 342 343 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td 344 void AssumingOp::getSuccessorRegions( 345 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { 346 // AssumingOp has unconditional control flow into the region and back to the 347 // parent, so return the correct RegionSuccessor purely based on the index 348 // being None or 0. 349 if (!point.isParent()) { 350 regions.push_back(RegionSuccessor(getResults())); 351 return; 352 } 353 354 regions.push_back(RegionSuccessor(&getDoRegion())); 355 } 356 357 void AssumingOp::inlineRegionIntoParent(AssumingOp &op, 358 PatternRewriter &rewriter) { 359 auto *blockBeforeAssuming = rewriter.getInsertionBlock(); 360 auto *assumingBlock = op.getBody(); 361 auto initPosition = rewriter.getInsertionPoint(); 362 auto *blockAfterAssuming = 363 rewriter.splitBlock(blockBeforeAssuming, initPosition); 364 365 // Remove the AssumingOp and AssumingYieldOp. 366 auto &yieldOp = assumingBlock->back(); 367 rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming); 368 rewriter.replaceOp(op, yieldOp.getOperands()); 369 rewriter.eraseOp(&yieldOp); 370 371 // Merge blocks together as there was no branching behavior from the 372 // AssumingOp. 373 rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming); 374 rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming); 375 } 376 377 void AssumingOp::build( 378 OpBuilder &builder, OperationState &result, Value witness, 379 function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) { 380 OpBuilder::InsertionGuard g(builder); 381 382 result.addOperands(witness); 383 Region *bodyRegion = result.addRegion(); 384 builder.createBlock(bodyRegion); 385 386 // Build body. 387 SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location); 388 builder.create<AssumingYieldOp>(result.location, yieldValues); 389 390 SmallVector<Type, 2> assumingTypes; 391 for (Value v : yieldValues) 392 assumingTypes.push_back(v.getType()); 393 result.addTypes(assumingTypes); 394 } 395 396 //===----------------------------------------------------------------------===// 397 // AddOp 398 //===----------------------------------------------------------------------===// 399 400 LogicalResult mlir::shape::AddOp::inferReturnTypes( 401 MLIRContext *context, std::optional<Location> location, 402 AddOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) { 403 if (llvm::isa<SizeType>(adaptor.getLhs().getType()) || 404 llvm::isa<SizeType>(adaptor.getRhs().getType())) 405 inferredReturnTypes.assign({SizeType::get(context)}); 406 else 407 inferredReturnTypes.assign({IndexType::get(context)}); 408 return success(); 409 } 410 411 bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 412 // SizeType is compatible with IndexType. 413 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 414 } 415 416 OpFoldResult mlir::shape::AddOp::fold(FoldAdaptor adaptor) { 417 // add(x, 0) -> x 418 if (matchPattern(getRhs(), m_Zero())) 419 return getLhs(); 420 421 return constFoldBinaryOp<IntegerAttr>( 422 adaptor.getOperands(), 423 [](APInt a, const APInt &b) { return std::move(a) + b; }); 424 } 425 426 LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); } 427 428 //===----------------------------------------------------------------------===// 429 // AssumingAllOp 430 //===----------------------------------------------------------------------===// 431 432 namespace { 433 434 // Merge multiple `shape.assuming_all` operations together. 435 // 436 // %0 = shape.assuming_all %w0, %w1 437 // %1 = shape.assuming_all %w2, %0 438 // 439 // to: 440 // 441 // %0 = shape.assuming_all %w0, %w2, %w2 442 struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> { 443 using OpRewritePattern<AssumingAllOp>::OpRewritePattern; 444 445 LogicalResult matchAndRewrite(AssumingAllOp op, 446 PatternRewriter &rewriter) const override { 447 SmallVector<Value> operands; 448 449 for (Value operand : op.getInputs()) { 450 if (auto assumeAll = operand.getDefiningOp<AssumingAllOp>()) 451 operands.append(assumeAll.operand_begin(), assumeAll->operand_end()); 452 else 453 operands.push_back(operand); 454 } 455 456 // We didn't find any other `assuming_all` ops to merge with. 457 if (operands.size() == op.getNumOperands()) 458 return failure(); 459 460 // Replace with a new `assuming_all` operation with merged constraints. 461 rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands); 462 return success(); 463 } 464 }; 465 466 // Eliminate `cstr_broadcastable` operands from `assuming_all` operation that 467 // are subsumed by others. 468 // 469 // %0 = shape.cstr_broadcastable %shape0, %shape1 470 // %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2 471 // 472 // %2 = shape.cstr_broadcastable %shape3, %shape4 473 // %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5 474 // 475 // %4 = shape.assuming_all %0, %1, %2, %3 476 // 477 // to: 478 // 479 // %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2 480 // %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5 481 // %2 = shape.assuming_all %0, %1 482 // 483 // In this example if shapes [0, 1, 2] are broadcastable, then it means that 484 // shapes [0, 1] are broadcastable too, and can be removed from the list of 485 // constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't 486 // matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]). 487 struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> { 488 using OpRewritePattern<AssumingAllOp>::OpRewritePattern; 489 490 LogicalResult matchAndRewrite(AssumingAllOp op, 491 PatternRewriter &rewriter) const override { 492 // Collect all `CstrBroadcastableOp` operands first. 493 SetVector<CstrBroadcastableOp> operands; 494 for (Value operand : op.getInputs()) { 495 // TODO: Apply this optimization if some of the witnesses are not 496 // produced by the `cstr_broadcastable`. 497 auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>(); 498 if (!broadcastable) 499 return failure(); 500 501 operands.insert(broadcastable); 502 } 503 504 // Skip trivial `assuming_all` operations. 505 if (operands.size() <= 1) 506 return failure(); 507 508 // Collect shapes checked by `cstr_broadcastable` operands. 509 SmallVector<std::pair<CstrBroadcastableOp, DenseSet<Value>>> shapes; 510 for (auto cstr : operands) { 511 DenseSet<Value> shapesSet(cstr->operand_begin(), cstr->operand_end()); 512 shapes.emplace_back(cstr, std::move(shapesSet)); 513 } 514 515 // Sort by the number of shape operands (larger to smaller). 516 llvm::sort(shapes, [](auto a, auto b) { 517 return a.first.getNumOperands() > b.first.getNumOperands(); 518 }); 519 520 // We start from the `cst_broadcastable` operations with largest number of 521 // shape operands, and remove redundant `cst_broadcastable` operations. We 522 // do this until we find a set of `cst_broadcastable` operations with 523 // non-overlapping constraints. 524 SmallVector<CstrBroadcastableOp> markedForErase; 525 526 for (unsigned i = 0; i < shapes.size(); ++i) { 527 auto isSubset = [&](auto pair) { 528 return llvm::set_is_subset(pair.second, shapes[i].second); 529 }; 530 531 // Keep redundant `cstr_broadcastable` operations to be erased. 532 auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset); 533 for (auto *it0 = it; it0 < shapes.end(); ++it0) 534 markedForErase.push_back(it0->first); 535 shapes.erase(it, shapes.end()); 536 } 537 538 // We didn't find any operands that could be removed. 539 if (markedForErase.empty()) 540 return failure(); 541 542 // Collect non-overlapping `cst_broadcastable` constraints. 543 SmallVector<Value> uniqueConstraints; 544 for (auto &shape : shapes) 545 uniqueConstraints.push_back(shape.first.getResult()); 546 547 // Replace with a new `assuming_all` operation ... 548 rewriter.replaceOpWithNewOp<AssumingAllOp>(op, uniqueConstraints); 549 550 // ... and maybe erase `cstr_broadcastable` ops without uses. 551 for (auto &op : markedForErase) 552 if (op->use_empty()) 553 rewriter.eraseOp(op); 554 555 return success(); 556 } 557 }; 558 559 struct AssumingAllToCstrEqCanonicalization 560 : public OpRewritePattern<AssumingAllOp> { 561 using OpRewritePattern<AssumingAllOp>::OpRewritePattern; 562 563 LogicalResult matchAndRewrite(AssumingAllOp op, 564 PatternRewriter &rewriter) const override { 565 SmallVector<Value, 8> shapes; 566 for (Value w : op.getInputs()) { 567 auto cstrEqOp = w.getDefiningOp<CstrEqOp>(); 568 if (!cstrEqOp) 569 return failure(); 570 bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) { 571 return llvm::is_contained(shapes, s); 572 }); 573 if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes) 574 return failure(); 575 shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end()); 576 } 577 rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes); 578 return success(); 579 } 580 }; 581 582 template <typename OpTy> 583 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> { 584 using OpRewritePattern<OpTy>::OpRewritePattern; 585 586 LogicalResult matchAndRewrite(OpTy op, 587 PatternRewriter &rewriter) const override { 588 // Find unique operands. 589 SetVector<Value> unique(op.operand_begin(), op.operand_end()); 590 591 // Reduce op to equivalent with unique operands. 592 if (unique.size() < op.getNumOperands()) { 593 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), 594 unique.takeVector(), op->getAttrs()); 595 return success(); 596 } 597 598 return failure(); 599 } 600 }; 601 } // namespace 602 603 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 604 MLIRContext *context) { 605 patterns 606 .add<MergeAssumingAllOps, AssumingAllOneOp, 607 AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization, 608 RemoveDuplicateOperandsPattern<AssumingAllOp>>(context); 609 } 610 611 OpFoldResult AssumingAllOp::fold(FoldAdaptor adaptor) { 612 // Iterate in reverse to first handle all constant operands. They are 613 // guaranteed to be the tail of the inputs because this is commutative. 614 for (int idx = adaptor.getInputs().size() - 1; idx >= 0; idx--) { 615 Attribute a = adaptor.getInputs()[idx]; 616 // Cannot fold if any inputs are not constant; 617 if (!a) 618 return nullptr; 619 620 // We do not need to keep statically known values after handling them in 621 // this method. 622 getOperation()->eraseOperand(idx); 623 624 // Always false if any input is statically known false 625 if (!llvm::cast<BoolAttr>(a).getValue()) 626 return a; 627 } 628 // If this is reached, all inputs were statically known passing. 629 return BoolAttr::get(getContext(), true); 630 } 631 632 LogicalResult AssumingAllOp::verify() { 633 // Ensure that AssumingAllOp contains at least one operand 634 if (getNumOperands() == 0) 635 return emitOpError("no operands specified"); 636 637 return success(); 638 } 639 640 //===----------------------------------------------------------------------===// 641 // BroadcastOp 642 //===----------------------------------------------------------------------===// 643 644 OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) { 645 if (getShapes().size() == 1) { 646 // Otherwise, we need a cast which would be a canonicalization, not folding. 647 if (getShapes().front().getType() != getType()) 648 return nullptr; 649 return getShapes().front(); 650 } 651 652 // TODO: Support folding with more than 2 input shapes 653 if (getShapes().size() > 2) 654 return nullptr; 655 656 if (!adaptor.getShapes()[0] || !adaptor.getShapes()[1]) 657 return nullptr; 658 auto lhsShape = llvm::to_vector<6>( 659 llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[0]) 660 .getValues<int64_t>()); 661 auto rhsShape = llvm::to_vector<6>( 662 llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[1]) 663 .getValues<int64_t>()); 664 SmallVector<int64_t, 6> resultShape; 665 666 // If the shapes are not compatible, we can't fold it. 667 // TODO: Fold to an "error". 668 if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) 669 return nullptr; 670 671 Builder builder(getContext()); 672 return builder.getIndexTensorAttr(resultShape); 673 } 674 675 LogicalResult BroadcastOp::verify() { 676 return verifyShapeOrExtentTensorOp(*this); 677 } 678 679 namespace { 680 template <typename OpTy> 681 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> { 682 using OpRewritePattern<OpTy>::OpRewritePattern; 683 684 LogicalResult matchAndRewrite(OpTy op, 685 PatternRewriter &rewriter) const override { 686 auto isPotentiallyNonEmptyShape = [](Value shape) { 687 if (auto extentTensorTy = 688 llvm::dyn_cast<RankedTensorType>(shape.getType())) { 689 if (extentTensorTy.getDimSize(0) == 0) 690 return false; 691 } 692 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) { 693 if (constShape.getShape().empty()) 694 return false; 695 } 696 return true; 697 }; 698 auto newOperands = llvm::filter_to_vector<8>(op->getOperands(), 699 isPotentiallyNonEmptyShape); 700 701 // Replace the op with empty shape constant if all operants are reduced to 702 // be empty. 703 if (newOperands.empty()) { 704 rewriter.replaceOpWithNewOp<ConstShapeOp>( 705 op, op->getResultTypes().front(), rewriter.getIndexTensorAttr({})); 706 return success(); 707 } 708 709 // Reduce op to equivalent without empty shape operands. 710 if (newOperands.size() < op.getNumOperands()) { 711 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands, 712 op->getAttrs()); 713 return success(); 714 } 715 716 return failure(); 717 } 718 }; 719 720 struct BroadcastForwardSingleOperandPattern 721 : public OpRewritePattern<BroadcastOp> { 722 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 723 724 LogicalResult matchAndRewrite(BroadcastOp op, 725 PatternRewriter &rewriter) const override { 726 if (op.getNumOperands() != 1) 727 return failure(); 728 Value replacement = op.getShapes().front(); 729 730 // Insert cast if needed. 731 if (replacement.getType() != op.getType()) { 732 auto loc = op.getLoc(); 733 if (llvm::isa<ShapeType>(op.getType())) { 734 replacement = rewriter.create<FromExtentTensorOp>(loc, replacement); 735 } else { 736 assert(!llvm::isa<ShapeType>(op.getType()) && 737 !llvm::isa<ShapeType>(replacement.getType()) && 738 "expect extent tensor cast"); 739 replacement = 740 rewriter.create<tensor::CastOp>(loc, op.getType(), replacement); 741 } 742 } 743 744 rewriter.replaceOp(op, replacement); 745 return success(); 746 } 747 }; 748 749 struct BroadcastFoldConstantOperandsPattern 750 : public OpRewritePattern<BroadcastOp> { 751 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 752 753 LogicalResult matchAndRewrite(BroadcastOp op, 754 PatternRewriter &rewriter) const override { 755 SmallVector<int64_t, 8> foldedConstantShape; 756 SmallVector<Value, 8> newShapeOperands; 757 for (Value shape : op.getShapes()) { 758 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) { 759 SmallVector<int64_t, 8> newFoldedConstantShape; 760 if (OpTrait::util::getBroadcastedShape( 761 foldedConstantShape, 762 llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()), 763 newFoldedConstantShape)) { 764 foldedConstantShape = newFoldedConstantShape; 765 continue; 766 } 767 } 768 newShapeOperands.push_back(shape); 769 } 770 771 // Need at least two constant operands to fold anything. 772 if (op.getNumOperands() - newShapeOperands.size() < 2) 773 return failure(); 774 775 auto foldedConstantOperandsTy = RankedTensorType::get( 776 {static_cast<int64_t>(foldedConstantShape.size())}, 777 rewriter.getIndexType()); 778 newShapeOperands.push_back(rewriter.create<ConstShapeOp>( 779 op.getLoc(), foldedConstantOperandsTy, 780 rewriter.getIndexTensorAttr(foldedConstantShape))); 781 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), 782 newShapeOperands); 783 return success(); 784 } 785 }; 786 787 template <typename OpTy> 788 struct CanonicalizeCastExtentTensorOperandsPattern 789 : public OpRewritePattern<OpTy> { 790 using OpRewritePattern<OpTy>::OpRewritePattern; 791 792 LogicalResult matchAndRewrite(OpTy op, 793 PatternRewriter &rewriter) const override { 794 // Canonicalize operands. 795 bool anyChange = false; 796 auto canonicalizeOperand = [&](Value operand) -> Value { 797 if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) { 798 // Only eliminate the cast if it holds no shape information. 799 bool isInformationLoosingCast = 800 llvm::cast<RankedTensorType>(castOp.getType()).isDynamicDim(0); 801 if (isInformationLoosingCast) { 802 anyChange = true; 803 return castOp.getSource(); 804 } 805 } 806 return operand; 807 }; 808 auto newOperands = llvm::to_vector<8>( 809 llvm::map_range(op.getOperands(), canonicalizeOperand)); 810 811 // Rewrite op if any change required. 812 if (!anyChange) 813 return failure(); 814 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands); 815 return success(); 816 } 817 }; 818 819 struct BroadcastConcretizeResultTypePattern 820 : public OpRewritePattern<BroadcastOp> { 821 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 822 823 LogicalResult matchAndRewrite(BroadcastOp op, 824 PatternRewriter &rewriter) const override { 825 // Only concretize dynamic extent tensor result types. 826 auto resultTy = llvm::dyn_cast<RankedTensorType>(op.getType()); 827 if (!resultTy || !resultTy.isDynamicDim(0)) 828 return failure(); 829 830 // Infer resulting shape rank if possible. 831 int64_t maxRank = 0; 832 for (Value shape : op.getShapes()) { 833 if (auto extentTensorTy = 834 llvm::dyn_cast<RankedTensorType>(shape.getType())) { 835 // Cannot infer resulting shape rank if any operand is dynamically 836 // ranked. 837 if (extentTensorTy.isDynamicDim(0)) 838 return failure(); 839 maxRank = std::max(maxRank, extentTensorTy.getDimSize(0)); 840 } 841 } 842 843 auto newOp = rewriter.create<BroadcastOp>( 844 op.getLoc(), getExtentTensorType(getContext(), maxRank), 845 op.getShapes()); 846 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp); 847 return success(); 848 } 849 }; 850 } // namespace 851 852 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 853 MLIRContext *context) { 854 patterns.add<BroadcastConcretizeResultTypePattern, 855 BroadcastFoldConstantOperandsPattern, 856 BroadcastForwardSingleOperandPattern, 857 CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>, 858 RemoveDuplicateOperandsPattern<BroadcastOp>, 859 RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context); 860 } 861 862 //===----------------------------------------------------------------------===// 863 // ConcatOp 864 //===----------------------------------------------------------------------===// 865 866 OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) { 867 if (!adaptor.getLhs() || !adaptor.getRhs()) 868 return nullptr; 869 auto lhsShape = llvm::to_vector<6>( 870 llvm::cast<DenseIntElementsAttr>(adaptor.getLhs()).getValues<int64_t>()); 871 auto rhsShape = llvm::to_vector<6>( 872 llvm::cast<DenseIntElementsAttr>(adaptor.getRhs()).getValues<int64_t>()); 873 SmallVector<int64_t, 6> resultShape; 874 resultShape.append(lhsShape.begin(), lhsShape.end()); 875 resultShape.append(rhsShape.begin(), rhsShape.end()); 876 Builder builder(getContext()); 877 return builder.getIndexTensorAttr(resultShape); 878 } 879 880 //===----------------------------------------------------------------------===// 881 // ConstShapeOp 882 //===----------------------------------------------------------------------===// 883 884 void ConstShapeOp::print(OpAsmPrinter &p) { 885 p << " "; 886 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape"}); 887 p << "["; 888 interleaveComma(getShape().getValues<int64_t>(), p); 889 p << "] : "; 890 p.printType(getType()); 891 } 892 893 ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) { 894 if (parser.parseOptionalAttrDict(result.attributes)) 895 return failure(); 896 // We piggy-back on ArrayAttr parsing, though we don't internally store the 897 // shape as an ArrayAttr. 898 // TODO: Implement custom parser and maybe make syntax a bit more concise. 899 Attribute extentsRaw; 900 NamedAttrList dummy; 901 if (parser.parseAttribute(extentsRaw, "dummy", dummy)) 902 return failure(); 903 auto extentsArray = llvm::dyn_cast<ArrayAttr>(extentsRaw); 904 if (!extentsArray) 905 return failure(); 906 SmallVector<int64_t, 6> ints; 907 for (Attribute extent : extentsArray) { 908 IntegerAttr attr = llvm::dyn_cast<IntegerAttr>(extent); 909 if (!attr) 910 return failure(); 911 ints.push_back(attr.getInt()); 912 } 913 Builder &builder = parser.getBuilder(); 914 result.addAttribute("shape", builder.getIndexTensorAttr(ints)); 915 Type resultTy; 916 if (parser.parseColonType(resultTy)) 917 return failure(); 918 result.types.push_back(resultTy); 919 return success(); 920 } 921 922 OpFoldResult ConstShapeOp::fold(FoldAdaptor) { return getShapeAttr(); } 923 924 void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 925 MLIRContext *context) { 926 patterns.add<TensorCastConstShape>(context); 927 } 928 929 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes( 930 MLIRContext *context, std::optional<Location> location, 931 ConstShapeOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) { 932 Builder b(context); 933 const Properties prop = adaptor.getProperties(); 934 inferredReturnTypes.assign({RankedTensorType::get( 935 {static_cast<int64_t>(prop.shape.size())}, b.getIndexType())}); 936 return success(); 937 } 938 939 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l, 940 TypeRange r) { 941 if (l.size() != 1 || r.size() != 1) 942 return false; 943 944 Type lhs = l.front(); 945 Type rhs = r.front(); 946 947 if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs)) 948 // Shape type is compatible with all other valid return types. 949 return true; 950 return lhs == rhs; 951 } 952 953 //===----------------------------------------------------------------------===// 954 // CstrBroadcastableOp 955 //===----------------------------------------------------------------------===// 956 957 void CstrBroadcastableOp::getCanonicalizationPatterns( 958 RewritePatternSet &patterns, MLIRContext *context) { 959 // Canonicalization patterns have overlap with the considerations during 960 // folding in case additional shape information is inferred at some point that 961 // does not result in folding. 962 patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>, 963 CstrBroadcastableEqOps, 964 RemoveDuplicateOperandsPattern<CstrBroadcastableOp>, 965 RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context); 966 } 967 968 // Return true if there is exactly one attribute not representing a scalar 969 // broadcast. 970 static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) { 971 bool nonScalarSeen = false; 972 for (Attribute a : attributes) { 973 if (!a || llvm::cast<DenseIntElementsAttr>(a).getNumElements() != 0) { 974 if (nonScalarSeen) 975 return false; 976 nonScalarSeen = true; 977 } 978 } 979 return true; 980 } 981 982 OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) { 983 // No broadcasting is needed if all operands but one are scalar. 984 if (hasAtMostSingleNonScalar(adaptor.getShapes())) 985 return BoolAttr::get(getContext(), true); 986 987 if ([&] { 988 SmallVector<SmallVector<int64_t, 6>, 6> extents; 989 for (const auto &operand : adaptor.getShapes()) { 990 if (!operand) 991 return false; 992 extents.push_back(llvm::to_vector<6>( 993 llvm::cast<DenseIntElementsAttr>(operand).getValues<int64_t>())); 994 } 995 return OpTrait::util::staticallyKnownBroadcastable(extents); 996 }()) 997 return BoolAttr::get(getContext(), true); 998 999 // Lastly, see if folding can be completed based on what constraints are known 1000 // on the input shapes. 1001 if ([&] { 1002 SmallVector<SmallVector<int64_t, 6>, 6> extents; 1003 for (auto shapeValue : getShapes()) { 1004 extents.emplace_back(); 1005 if (failed(getShapeVec(shapeValue, extents.back()))) 1006 return false; 1007 } 1008 return OpTrait::util::staticallyKnownBroadcastable(extents); 1009 }()) 1010 return BoolAttr::get(getContext(), true); 1011 1012 // Because a failing witness result here represents an eventual assertion 1013 // failure, we do not replace it with a constant witness. 1014 return nullptr; 1015 } 1016 1017 LogicalResult CstrBroadcastableOp::verify() { 1018 // Ensure that CstrBroadcastableOp contains at least two operands 1019 if (getNumOperands() < 2) 1020 return emitOpError("required at least 2 input shapes"); 1021 return success(); 1022 } 1023 1024 //===----------------------------------------------------------------------===// 1025 // CstrEqOp 1026 //===----------------------------------------------------------------------===// 1027 1028 void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1029 MLIRContext *context) { 1030 // If inputs are equal, return passing witness 1031 patterns.add<CstrEqEqOps>(context); 1032 } 1033 1034 OpFoldResult CstrEqOp::fold(FoldAdaptor adaptor) { 1035 if (llvm::all_of(adaptor.getShapes(), [&](Attribute a) { 1036 return a && a == adaptor.getShapes().front(); 1037 })) 1038 return BoolAttr::get(getContext(), true); 1039 1040 // Because a failing witness result here represents an eventual assertion 1041 // failure, we do not try to replace it with a constant witness. Similarly, we 1042 // cannot if there are any non-const inputs. 1043 return nullptr; 1044 } 1045 1046 //===----------------------------------------------------------------------===// 1047 // ConstSizeOp 1048 //===----------------------------------------------------------------------===// 1049 1050 void ConstSizeOp::build(OpBuilder &builder, OperationState &result, 1051 int64_t value) { 1052 build(builder, result, builder.getIndexAttr(value)); 1053 } 1054 1055 OpFoldResult ConstSizeOp::fold(FoldAdaptor) { return getValueAttr(); } 1056 1057 void ConstSizeOp::getAsmResultNames( 1058 llvm::function_ref<void(Value, StringRef)> setNameFn) { 1059 SmallString<4> buffer; 1060 llvm::raw_svector_ostream os(buffer); 1061 os << "c" << getValue(); 1062 setNameFn(getResult(), os.str()); 1063 } 1064 1065 //===----------------------------------------------------------------------===// 1066 // ConstWitnessOp 1067 //===----------------------------------------------------------------------===// 1068 1069 OpFoldResult ConstWitnessOp::fold(FoldAdaptor) { return getPassingAttr(); } 1070 1071 //===----------------------------------------------------------------------===// 1072 // CstrRequireOp 1073 //===----------------------------------------------------------------------===// 1074 1075 OpFoldResult CstrRequireOp::fold(FoldAdaptor adaptor) { 1076 return adaptor.getPred(); 1077 } 1078 1079 //===----------------------------------------------------------------------===// 1080 // DimOp 1081 //===----------------------------------------------------------------------===// 1082 1083 std::optional<int64_t> DimOp::getConstantIndex() { 1084 if (auto constSizeOp = getIndex().getDefiningOp<ConstSizeOp>()) 1085 return constSizeOp.getValue().getLimitedValue(); 1086 if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>()) 1087 return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt(); 1088 return std::nullopt; 1089 } 1090 1091 OpFoldResult DimOp::fold(FoldAdaptor adaptor) { 1092 Type valType = getValue().getType(); 1093 auto valShapedType = llvm::dyn_cast<ShapedType>(valType); 1094 if (!valShapedType || !valShapedType.hasRank()) 1095 return nullptr; 1096 std::optional<int64_t> index = getConstantIndex(); 1097 if (!index.has_value()) 1098 return nullptr; 1099 if (index.value() < 0 || index.value() >= valShapedType.getRank()) 1100 return nullptr; 1101 auto extent = valShapedType.getDimSize(*index); 1102 if (ShapedType::isDynamic(extent)) 1103 return nullptr; 1104 return IntegerAttr::get(IndexType::get(getContext()), extent); 1105 } 1106 1107 LogicalResult mlir::shape::DimOp::inferReturnTypes( 1108 MLIRContext *context, std::optional<Location> location, 1109 DimOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) { 1110 inferredReturnTypes.assign({adaptor.getIndex().getType()}); 1111 return success(); 1112 } 1113 1114 bool mlir::shape::DimOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1115 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1116 } 1117 1118 //===----------------------------------------------------------------------===// 1119 // DivOp 1120 //===----------------------------------------------------------------------===// 1121 1122 OpFoldResult DivOp::fold(FoldAdaptor adaptor) { 1123 auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs()); 1124 if (!lhs) 1125 return nullptr; 1126 auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs()); 1127 if (!rhs) 1128 return nullptr; 1129 1130 // Division in APInt does not follow floor(lhs, rhs) when the result is 1131 // negative. Rather, APInt rounds toward zero. 1132 APInt quotient, remainder; 1133 APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder); 1134 if (quotient.isNegative() && !remainder.isZero()) { 1135 quotient -= 1; 1136 } 1137 1138 Type indexTy = IndexType::get(getContext()); 1139 return IntegerAttr::get(indexTy, quotient); 1140 } 1141 1142 LogicalResult mlir::shape::DivOp::inferReturnTypes( 1143 MLIRContext *context, std::optional<Location> location, 1144 DivOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) { 1145 if (llvm::isa<SizeType>(adaptor.getLhs().getType()) || 1146 llvm::isa<SizeType>(adaptor.getRhs().getType())) 1147 inferredReturnTypes.assign({SizeType::get(context)}); 1148 else 1149 inferredReturnTypes.assign({IndexType::get(context)}); 1150 return success(); 1151 } 1152 1153 bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1154 // SizeType is compatible with IndexType. 1155 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1156 } 1157 1158 LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); } 1159 1160 //===----------------------------------------------------------------------===// 1161 // ShapeEqOp 1162 //===----------------------------------------------------------------------===// 1163 1164 OpFoldResult ShapeEqOp::fold(FoldAdaptor adaptor) { 1165 bool allSame = true; 1166 if (!adaptor.getShapes().empty() && !adaptor.getShapes().front()) 1167 return {}; 1168 for (Attribute operand : adaptor.getShapes().drop_front()) { 1169 if (!operand) 1170 return {}; 1171 allSame = allSame && operand == adaptor.getShapes().front(); 1172 } 1173 return BoolAttr::get(getContext(), allSame); 1174 } 1175 1176 //===----------------------------------------------------------------------===// 1177 // IndexToSizeOp 1178 //===----------------------------------------------------------------------===// 1179 1180 OpFoldResult IndexToSizeOp::fold(FoldAdaptor adaptor) { 1181 // Constant values of both types, `shape.size` and `index`, are represented as 1182 // `IntegerAttr`s which makes constant folding simple. 1183 if (Attribute arg = adaptor.getArg()) 1184 return arg; 1185 return {}; 1186 } 1187 1188 void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1189 MLIRContext *context) { 1190 patterns.add<SizeToIndexToSizeCanonicalization>(context); 1191 } 1192 1193 //===----------------------------------------------------------------------===// 1194 // FromExtentsOp 1195 //===----------------------------------------------------------------------===// 1196 1197 OpFoldResult FromExtentsOp::fold(FoldAdaptor adaptor) { 1198 if (llvm::any_of(adaptor.getExtents(), [](Attribute a) { return !a; })) 1199 return nullptr; 1200 SmallVector<int64_t, 6> extents; 1201 for (auto attr : adaptor.getExtents()) 1202 extents.push_back(llvm::cast<IntegerAttr>(attr).getInt()); 1203 Builder builder(getContext()); 1204 return builder.getIndexTensorAttr(extents); 1205 } 1206 1207 //===----------------------------------------------------------------------===// 1208 // FunctionLibraryOp 1209 //===----------------------------------------------------------------------===// 1210 1211 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result, 1212 StringRef name) { 1213 result.attributes.push_back(builder.getNamedAttr( 1214 ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name))); 1215 } 1216 1217 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) { 1218 auto attr = llvm::dyn_cast_or_null<FlatSymbolRefAttr>( 1219 getMapping().get(op->getName().getIdentifier())); 1220 if (!attr) 1221 return nullptr; 1222 return lookupSymbol<FuncOp>(attr); 1223 } 1224 1225 ParseResult FunctionLibraryOp::parse(OpAsmParser &parser, 1226 OperationState &result) { 1227 // Parse the op name. 1228 StringAttr nameAttr; 1229 if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), 1230 result.attributes)) 1231 return failure(); 1232 1233 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 1234 return failure(); 1235 1236 auto *bodyRegion = result.addRegion(); 1237 if (parser.parseRegion(*bodyRegion)) 1238 return failure(); 1239 1240 if (parser.parseKeyword("mapping")) 1241 return failure(); 1242 1243 DictionaryAttr mappingAttr; 1244 if (parser.parseAttribute(mappingAttr, 1245 parser.getBuilder().getType<NoneType>(), "mapping", 1246 result.attributes)) 1247 return failure(); 1248 return success(); 1249 } 1250 1251 void FunctionLibraryOp::print(OpAsmPrinter &p) { 1252 p << ' '; 1253 p.printSymbolName(getName()); 1254 p.printOptionalAttrDictWithKeyword( 1255 (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(), "mapping"}); 1256 p << ' '; 1257 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, 1258 /*printBlockTerminators=*/false); 1259 p << " mapping "; 1260 p.printAttributeWithoutType(getMappingAttr()); 1261 } 1262 1263 //===----------------------------------------------------------------------===// 1264 // FuncOp 1265 //===----------------------------------------------------------------------===// 1266 1267 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, 1268 ArrayRef<NamedAttribute> attrs) { 1269 OpBuilder builder(location->getContext()); 1270 OperationState state(location, getOperationName()); 1271 FuncOp::build(builder, state, name, type, attrs); 1272 return cast<FuncOp>(Operation::create(state)); 1273 } 1274 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, 1275 Operation::dialect_attr_range attrs) { 1276 SmallVector<NamedAttribute, 8> attrRef(attrs); 1277 return create(location, name, type, llvm::ArrayRef(attrRef)); 1278 } 1279 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, 1280 ArrayRef<NamedAttribute> attrs, 1281 ArrayRef<DictionaryAttr> argAttrs) { 1282 FuncOp func = create(location, name, type, attrs); 1283 func.setAllArgAttrs(argAttrs); 1284 return func; 1285 } 1286 1287 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, 1288 FunctionType type, ArrayRef<NamedAttribute> attrs, 1289 ArrayRef<DictionaryAttr> argAttrs) { 1290 state.addAttribute(FuncOp::getSymNameAttrName(state.name), 1291 builder.getStringAttr(name)); 1292 state.addAttribute(FuncOp::getFunctionTypeAttrName(state.name), 1293 TypeAttr::get(type)); 1294 state.attributes.append(attrs.begin(), attrs.end()); 1295 state.addRegion(); 1296 1297 if (argAttrs.empty()) 1298 return; 1299 assert(type.getNumInputs() == argAttrs.size()); 1300 function_interface_impl::addArgAndResultAttrs( 1301 builder, state, argAttrs, /*resultAttrs=*/std::nullopt, 1302 getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); 1303 } 1304 1305 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { 1306 auto buildFuncType = 1307 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, 1308 function_interface_impl::VariadicFlag, 1309 std::string &) { return builder.getFunctionType(argTypes, results); }; 1310 1311 return function_interface_impl::parseFunctionOp( 1312 parser, result, /*allowVariadic=*/false, 1313 getFunctionTypeAttrName(result.name), buildFuncType, 1314 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); 1315 } 1316 1317 void FuncOp::print(OpAsmPrinter &p) { 1318 function_interface_impl::printFunctionOp( 1319 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), 1320 getArgAttrsAttrName(), getResAttrsAttrName()); 1321 } 1322 1323 //===----------------------------------------------------------------------===// 1324 // GetExtentOp 1325 //===----------------------------------------------------------------------===// 1326 1327 std::optional<int64_t> GetExtentOp::getConstantDim() { 1328 if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>()) 1329 return constSizeOp.getValue().getLimitedValue(); 1330 if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>()) 1331 return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt(); 1332 return std::nullopt; 1333 } 1334 1335 OpFoldResult GetExtentOp::fold(FoldAdaptor adaptor) { 1336 auto elements = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape()); 1337 if (!elements) 1338 return nullptr; 1339 std::optional<int64_t> dim = getConstantDim(); 1340 if (!dim.has_value()) 1341 return nullptr; 1342 if (dim.value() >= elements.getNumElements()) 1343 return nullptr; 1344 return elements.getValues<Attribute>()[(uint64_t)dim.value()]; 1345 } 1346 1347 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, 1348 int64_t dim) { 1349 auto loc = result.location; 1350 auto dimAttr = builder.getIndexAttr(dim); 1351 if (llvm::isa<ShapeType>(shape.getType())) { 1352 Value dim = builder.create<ConstSizeOp>(loc, dimAttr); 1353 build(builder, result, builder.getType<SizeType>(), shape, dim); 1354 } else { 1355 Value dim = 1356 builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr); 1357 build(builder, result, builder.getIndexType(), shape, dim); 1358 } 1359 } 1360 1361 LogicalResult mlir::shape::GetExtentOp::inferReturnTypes( 1362 MLIRContext *context, std::optional<Location> location, 1363 GetExtentOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) { 1364 inferredReturnTypes.assign({IndexType::get(context)}); 1365 return success(); 1366 } 1367 1368 bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l, 1369 TypeRange r) { 1370 // SizeType is compatible with IndexType. 1371 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1372 } 1373 1374 LogicalResult GetExtentOp::verify() { return verifySizeOrIndexOp(*this); } 1375 1376 //===----------------------------------------------------------------------===// 1377 // IsBroadcastableOp 1378 //===----------------------------------------------------------------------===// 1379 1380 void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1381 MLIRContext *context) { 1382 patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context); 1383 } 1384 1385 OpFoldResult IsBroadcastableOp::fold(FoldAdaptor adaptor) { 1386 // Can always broadcast fewer than two shapes. 1387 if (adaptor.getShapes().size() < 2) { 1388 return BoolAttr::get(getContext(), true); 1389 } 1390 1391 return nullptr; 1392 } 1393 1394 //===----------------------------------------------------------------------===// 1395 // MeetOp 1396 //===----------------------------------------------------------------------===// 1397 1398 LogicalResult mlir::shape::MeetOp::inferReturnTypes( 1399 MLIRContext *context, std::optional<Location> location, 1400 MeetOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) { 1401 if (adaptor.getOperands().empty()) 1402 return failure(); 1403 1404 auto isShapeType = [](Type arg) { 1405 if (llvm::isa<ShapeType>(arg)) 1406 return true; 1407 return isExtentTensorType(arg); 1408 }; 1409 1410 ValueRange::type_range types = adaptor.getOperands().getTypes(); 1411 Type acc = types.front(); 1412 for (auto t : drop_begin(types)) { 1413 Type l = acc, r = t; 1414 if (!llvm::isa<ShapeType, SizeType>(l)) 1415 std::swap(l, r); 1416 1417 // Handle sizes, propagate error type if present. 1418 if (llvm::isa<SizeType>(l)) { 1419 if (llvm::isa<SizeType, IndexType>(r)) 1420 acc = l; 1421 else 1422 return emitOptionalError(location, "requires all sizes or shapes"); 1423 } else if (llvm::isa<IndexType>(l)) { 1424 if (llvm::isa<IndexType>(r)) 1425 acc = r; 1426 else 1427 return emitOptionalError(location, "requires all sizes or shapes"); 1428 } else if (llvm::isa<ShapeType>(l)) { 1429 // Handle shapes, propagate error type if present. 1430 if (isShapeType(r)) 1431 acc = l; 1432 else 1433 return emitOptionalError(location, "requires all sizes or shapes"); 1434 } else if (isExtentTensorType(l)) { 1435 auto rank1 = llvm::cast<RankedTensorType>(l).getShape()[0]; 1436 auto rank2 = llvm::cast<RankedTensorType>(r).getShape()[0]; 1437 if (ShapedType::isDynamic(rank1)) 1438 acc = l; 1439 else if (ShapedType::isDynamic(rank2)) 1440 acc = r; 1441 else if (rank1 != rank2) 1442 return emitOptionalError(location, "unequal shape cardinality"); 1443 else 1444 acc = l; 1445 } 1446 } 1447 inferredReturnTypes.assign({acc}); 1448 return success(); 1449 } 1450 1451 bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1452 if (l.size() != 1 || r.size() != 1) 1453 return false; 1454 if (l == r) 1455 return true; 1456 1457 Type lhs = l.front(); 1458 Type rhs = r.front(); 1459 1460 if (!llvm::isa<ShapeType, SizeType>(lhs)) 1461 std::swap(lhs, rhs); 1462 1463 if (llvm::isa<SizeType>(lhs)) 1464 return llvm::isa<SizeType, IndexType>(rhs); 1465 if (llvm::isa<ShapeType>(lhs)) 1466 return llvm::isa<ShapeType, TensorType>(rhs); 1467 1468 if (succeeded(verifyCompatibleShapes({lhs, rhs}))) 1469 return true; 1470 return false; 1471 } 1472 1473 //===----------------------------------------------------------------------===// 1474 // RankOp 1475 //===----------------------------------------------------------------------===// 1476 1477 OpFoldResult shape::RankOp::fold(FoldAdaptor adaptor) { 1478 auto shape = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape()); 1479 if (!shape) 1480 return {}; 1481 int64_t rank = shape.getNumElements(); 1482 Builder builder(getContext()); 1483 return builder.getIndexAttr(rank); 1484 } 1485 1486 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time. 1487 /// Constant folding fails in cases where only the rank is constant, not the 1488 /// shape itself. 1489 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`. 1490 /// 1491 /// Example: 1492 /// 1493 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32> 1494 /// %rank = shape.rank %shape 1495 /// 1496 /// becomes 1497 /// 1498 /// %rank = shape.const_size 3 1499 1500 namespace { 1501 struct RankShapeOfCanonicalizationPattern 1502 : public OpRewritePattern<shape::RankOp> { 1503 using OpRewritePattern<shape::RankOp>::OpRewritePattern; 1504 1505 LogicalResult matchAndRewrite(shape::RankOp op, 1506 PatternRewriter &rewriter) const override { 1507 auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>(); 1508 if (!shapeOfOp) 1509 return failure(); 1510 auto rankedTensorType = 1511 llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType()); 1512 if (!rankedTensorType) 1513 return failure(); 1514 int64_t rank = rankedTensorType.getRank(); 1515 if (llvm::isa<IndexType>(op.getType())) { 1516 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(), 1517 rank); 1518 } else if (llvm::isa<shape::SizeType>(op.getType())) { 1519 rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank); 1520 } else { 1521 return failure(); 1522 } 1523 return success(); 1524 } 1525 }; 1526 } // namespace 1527 1528 void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1529 MLIRContext *context) { 1530 patterns.add<RankShapeOfCanonicalizationPattern>(context); 1531 } 1532 1533 LogicalResult mlir::shape::RankOp::inferReturnTypes( 1534 MLIRContext *context, std::optional<Location> location, 1535 RankOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) { 1536 if (llvm::isa<ShapeType>(adaptor.getShape().getType())) 1537 inferredReturnTypes.assign({SizeType::get(context)}); 1538 else 1539 inferredReturnTypes.assign({IndexType::get(context)}); 1540 return success(); 1541 } 1542 1543 bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1544 // SizeType is compatible with IndexType. 1545 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1546 } 1547 1548 LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(*this); } 1549 1550 //===----------------------------------------------------------------------===// 1551 // NumElementsOp 1552 //===----------------------------------------------------------------------===// 1553 1554 OpFoldResult NumElementsOp::fold(FoldAdaptor adaptor) { 1555 1556 // Fold only when argument constant. 1557 Attribute shape = adaptor.getShape(); 1558 if (!shape) 1559 return {}; 1560 1561 APInt product(64, 1); 1562 for (auto value : llvm::cast<DenseIntElementsAttr>(shape)) 1563 product *= value; 1564 Builder builder(getContext()); 1565 return builder.getIndexAttr(product.getLimitedValue()); 1566 } 1567 1568 LogicalResult mlir::shape::NumElementsOp::inferReturnTypes( 1569 MLIRContext *context, std::optional<Location> location, 1570 NumElementsOp::Adaptor adaptor, 1571 SmallVectorImpl<Type> &inferredReturnTypes) { 1572 if (llvm::isa<ShapeType>(adaptor.getShape().getType())) 1573 inferredReturnTypes.assign({SizeType::get(context)}); 1574 else 1575 inferredReturnTypes.assign({IndexType::get(context)}); 1576 return success(); 1577 } 1578 1579 bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l, 1580 TypeRange r) { 1581 // SizeType is compatible with IndexType. 1582 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1583 } 1584 1585 LogicalResult shape::NumElementsOp::verify() { 1586 return verifySizeOrIndexOp(*this); 1587 } 1588 1589 //===----------------------------------------------------------------------===// 1590 // MaxOp 1591 //===----------------------------------------------------------------------===// 1592 1593 OpFoldResult MaxOp::fold(FoldAdaptor adaptor) { 1594 // If operands are equal, just propagate one. 1595 if (getLhs() == getRhs()) 1596 return getLhs(); 1597 return nullptr; 1598 } 1599 1600 LogicalResult mlir::shape::MaxOp::inferReturnTypes( 1601 MLIRContext *context, std::optional<Location> location, 1602 MaxOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) { 1603 if (adaptor.getLhs().getType() == adaptor.getRhs().getType()) 1604 inferredReturnTypes.assign({adaptor.getLhs().getType()}); 1605 else 1606 inferredReturnTypes.assign({SizeType::get(context)}); 1607 return success(); 1608 } 1609 1610 bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1611 if (l.size() != 1 || r.size() != 1) 1612 return false; 1613 if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front())) 1614 return true; 1615 if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front())) 1616 return true; 1617 return false; 1618 } 1619 1620 //===----------------------------------------------------------------------===// 1621 // MinOp 1622 //===----------------------------------------------------------------------===// 1623 1624 OpFoldResult MinOp::fold(FoldAdaptor adaptor) { 1625 // If operands are equal, just propagate one. 1626 if (getLhs() == getRhs()) 1627 return getLhs(); 1628 return nullptr; 1629 } 1630 1631 LogicalResult mlir::shape::MinOp::inferReturnTypes( 1632 MLIRContext *context, std::optional<Location> location, 1633 MinOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) { 1634 if (adaptor.getLhs().getType() == adaptor.getRhs().getType()) 1635 inferredReturnTypes.assign({adaptor.getLhs().getType()}); 1636 else 1637 inferredReturnTypes.assign({SizeType::get(context)}); 1638 return success(); 1639 } 1640 1641 bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1642 if (l.size() != 1 || r.size() != 1) 1643 return false; 1644 if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front())) 1645 return true; 1646 if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front())) 1647 return true; 1648 return false; 1649 } 1650 1651 //===----------------------------------------------------------------------===// 1652 // MulOp 1653 //===----------------------------------------------------------------------===// 1654 1655 OpFoldResult MulOp::fold(FoldAdaptor adaptor) { 1656 auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs()); 1657 if (!lhs) 1658 return nullptr; 1659 auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs()); 1660 if (!rhs) 1661 return nullptr; 1662 APInt folded = lhs.getValue() * rhs.getValue(); 1663 Type indexTy = IndexType::get(getContext()); 1664 return IntegerAttr::get(indexTy, folded); 1665 } 1666 1667 LogicalResult mlir::shape::MulOp::inferReturnTypes( 1668 MLIRContext *context, std::optional<Location> location, 1669 MulOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) { 1670 if (llvm::isa<SizeType>(adaptor.getLhs().getType()) || 1671 llvm::isa<SizeType>(adaptor.getRhs().getType())) 1672 inferredReturnTypes.assign({SizeType::get(context)}); 1673 else 1674 inferredReturnTypes.assign({IndexType::get(context)}); 1675 return success(); 1676 } 1677 1678 bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1679 // SizeType is compatible with IndexType. 1680 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1681 } 1682 1683 LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); } 1684 1685 //===----------------------------------------------------------------------===// 1686 // ShapeOfOp 1687 //===----------------------------------------------------------------------===// 1688 1689 namespace { 1690 /// Replace shape_of(x) where x has a constant shape with a const_shape op. 1691 struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> { 1692 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern; 1693 1694 LogicalResult matchAndRewrite(shape::ShapeOfOp op, 1695 PatternRewriter &rewriter) const override { 1696 auto type = llvm::dyn_cast<ShapedType>(op.getArg().getType()); 1697 if (!type || !type.hasStaticShape()) 1698 return failure(); 1699 Location loc = op.getLoc(); 1700 Value constShape = 1701 rewriter 1702 .create<ConstShapeOp>(loc, 1703 rewriter.getIndexTensorAttr(type.getShape())) 1704 .getResult(); 1705 if (constShape.getType() != op.getResult().getType()) 1706 constShape = rewriter.create<tensor::CastOp>( 1707 loc, op.getResult().getType(), constShape); 1708 rewriter.replaceOp(op, constShape); 1709 return success(); 1710 } 1711 }; 1712 1713 // Canonicalize 1714 // 1715 // %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32> 1716 // %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex> 1717 // 1718 // to 1719 // 1720 // %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32> 1721 // %1 = %shape 1722 // 1723 struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> { 1724 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern; 1725 1726 LogicalResult matchAndRewrite(shape::ShapeOfOp op, 1727 PatternRewriter &rewriter) const override { 1728 auto tensorReshapeOp = op.getArg().getDefiningOp<tensor::ReshapeOp>(); 1729 if (!tensorReshapeOp) 1730 return rewriter.notifyMatchFailure(op, "producer is not tensor.reshape"); 1731 if (!isa<TensorType>(op.getType())) 1732 return rewriter.notifyMatchFailure(op, "result is not a tensor"); 1733 1734 // Operand 'shape' of 'tensor.reshape' may now be used as the result of 1735 // 'shape.shape_of'. While its type is guaranteed to be compatible in well- 1736 // formed IR, it may not be identical (dynamically vs statically shaped), 1737 // in which case it needs to be cast first. 1738 Value shape = tensorReshapeOp.getShape(); 1739 if (op.getType() != shape.getType()) 1740 shape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(), shape); 1741 1742 rewriter.replaceOp(op, shape); 1743 return success(); 1744 } 1745 }; 1746 1747 // Canonicalize 1748 // ``` 1749 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex> 1750 // %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex> 1751 // ``` 1752 // to 1753 // ``` 1754 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex> 1755 // ``` 1756 struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> { 1757 using OpRewritePattern<tensor::CastOp>::OpRewritePattern; 1758 1759 LogicalResult matchAndRewrite(tensor::CastOp op, 1760 PatternRewriter &rewriter) const override { 1761 auto ty = llvm::dyn_cast<RankedTensorType>(op.getType()); 1762 if (!ty || ty.getRank() != 1) 1763 return failure(); 1764 1765 auto shapeOfOp = op.getSource().getDefiningOp<ShapeOfOp>(); 1766 if (!shapeOfOp) 1767 return failure(); 1768 1769 // Argument type must be ranked and must not conflict. 1770 auto argTy = llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType()); 1771 if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank())) 1772 return failure(); 1773 1774 rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg()); 1775 return success(); 1776 } 1777 }; 1778 } // namespace 1779 1780 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1781 MLIRContext *context) { 1782 patterns.add<ShapeOfCastExtentTensor, ShapeOfFromReshape, 1783 ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>( 1784 context); 1785 } 1786 1787 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes( 1788 MLIRContext *context, std::optional<Location> location, 1789 ShapeOfOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) { 1790 if (llvm::isa<ValueShapeType>(adaptor.getArg().getType())) 1791 inferredReturnTypes.assign({ShapeType::get(context)}); 1792 else { 1793 auto shapedTy = llvm::cast<ShapedType>(adaptor.getArg().getType()); 1794 int64_t rank = 1795 shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamic; 1796 Type indexTy = IndexType::get(context); 1797 Type extentTensorTy = RankedTensorType::get({rank}, indexTy); 1798 inferredReturnTypes.assign({extentTensorTy}); 1799 } 1800 return success(); 1801 } 1802 1803 bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1804 if (l.size() != 1 || r.size() != 1) 1805 return false; 1806 if (l == r) 1807 return true; 1808 1809 Type lhs = l.front(); 1810 Type rhs = r.front(); 1811 1812 if (!llvm::isa<ShapeType, ShapedType>(lhs) || 1813 !llvm::isa<ShapeType, ShapedType>(rhs)) 1814 return false; 1815 1816 if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs)) 1817 // Shape type is compatible with all other valid return types. 1818 return true; 1819 1820 if (succeeded(verifyCompatibleShapes({lhs, rhs}))) 1821 return true; 1822 return false; 1823 } 1824 1825 LogicalResult shape::ShapeOfOp::verify() { 1826 return verifyShapeOrExtentTensorOp(*this); 1827 } 1828 1829 //===----------------------------------------------------------------------===// 1830 // SizeToIndexOp 1831 //===----------------------------------------------------------------------===// 1832 1833 OpFoldResult SizeToIndexOp::fold(FoldAdaptor adaptor) { 1834 // Constant values of both types, `shape.size` and `index`, are represented as 1835 // `IntegerAttr`s which makes constant folding simple. 1836 if (Attribute arg = adaptor.getArg()) 1837 return arg; 1838 return OpFoldResult(); 1839 } 1840 1841 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1842 MLIRContext *context) { 1843 patterns.add<IndexToSizeToIndexCanonicalization>(context); 1844 } 1845 1846 bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1847 if (inputs.size() != 1 || outputs.size() != 1) 1848 return false; 1849 return llvm::isa<IndexType, SizeType>(inputs[0]) && 1850 llvm::isa<IndexType>(outputs[0]); 1851 } 1852 1853 //===----------------------------------------------------------------------===// 1854 // YieldOp 1855 //===----------------------------------------------------------------------===// 1856 1857 LogicalResult shape::YieldOp::verify() { 1858 auto *parentOp = (*this)->getParentOp(); 1859 auto results = parentOp->getResults(); 1860 auto operands = getOperands(); 1861 1862 if (parentOp->getNumResults() != getNumOperands()) 1863 return emitOpError() << "number of operands does not match number of " 1864 "results of its parent"; 1865 for (auto e : llvm::zip(results, operands)) 1866 if (std::get<0>(e).getType() != std::get<1>(e).getType()) 1867 return emitOpError() << "types mismatch between yield op and its parent"; 1868 1869 return success(); 1870 } 1871 1872 //===----------------------------------------------------------------------===// 1873 // SplitAtOp 1874 //===----------------------------------------------------------------------===// 1875 1876 LogicalResult SplitAtOp::fold(FoldAdaptor adaptor, 1877 SmallVectorImpl<OpFoldResult> &results) { 1878 if (!adaptor.getOperand() || !adaptor.getIndex()) 1879 return failure(); 1880 auto shapeVec = llvm::to_vector<6>( 1881 llvm::cast<DenseIntElementsAttr>(adaptor.getOperand()).getValues<int64_t>()); 1882 auto shape = llvm::ArrayRef(shapeVec); 1883 auto splitPoint = llvm::cast<IntegerAttr>(adaptor.getIndex()).getInt(); 1884 // Verify that the split point is in the correct range. 1885 // TODO: Constant fold to an "error". 1886 int64_t rank = shape.size(); 1887 if (-rank > splitPoint || splitPoint > rank) 1888 return failure(); 1889 if (splitPoint < 0) 1890 splitPoint += shape.size(); 1891 Builder builder(adaptor.getOperand().getContext()); 1892 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint))); 1893 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint))); 1894 return success(); 1895 } 1896 1897 //===----------------------------------------------------------------------===// 1898 // ToExtentTensorOp 1899 //===----------------------------------------------------------------------===// 1900 1901 OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) { 1902 if (!adaptor.getInput()) 1903 return OpFoldResult(); 1904 Builder builder(getContext()); 1905 auto shape = llvm::to_vector<6>( 1906 llvm::cast<DenseIntElementsAttr>(adaptor.getInput()).getValues<int64_t>()); 1907 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())}, 1908 builder.getIndexType()); 1909 return DenseIntElementsAttr::get(type, shape); 1910 } 1911 1912 bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1913 if (inputs.size() != 1 || outputs.size() != 1) 1914 return false; 1915 if (auto inputTensor = llvm::dyn_cast<RankedTensorType>(inputs[0])) { 1916 if (!llvm::isa<IndexType>(inputTensor.getElementType()) || 1917 inputTensor.getRank() != 1) 1918 return false; 1919 } else if (!llvm::isa<ShapeType>(inputs[0])) { 1920 return false; 1921 } 1922 1923 TensorType outputTensor = llvm::dyn_cast<TensorType>(outputs[0]); 1924 return outputTensor && llvm::isa<IndexType>(outputTensor.getElementType()); 1925 } 1926 1927 //===----------------------------------------------------------------------===// 1928 // ReduceOp 1929 //===----------------------------------------------------------------------===// 1930 1931 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape, 1932 ValueRange initVals) { 1933 OpBuilder::InsertionGuard g(builder); 1934 result.addOperands(shape); 1935 result.addOperands(initVals); 1936 1937 Region *bodyRegion = result.addRegion(); 1938 Block *bodyBlock = builder.createBlock( 1939 bodyRegion, /*insertPt=*/{}, builder.getIndexType(), result.location); 1940 1941 Type elementType; 1942 if (auto tensorType = llvm::dyn_cast<TensorType>(shape.getType())) 1943 elementType = tensorType.getElementType(); 1944 else 1945 elementType = SizeType::get(builder.getContext()); 1946 bodyBlock->addArgument(elementType, shape.getLoc()); 1947 1948 for (Value initVal : initVals) { 1949 bodyBlock->addArgument(initVal.getType(), initVal.getLoc()); 1950 result.addTypes(initVal.getType()); 1951 } 1952 } 1953 1954 LogicalResult ReduceOp::verify() { 1955 // Verify block arg types. 1956 Block &block = getRegion().front(); 1957 1958 // The block takes index, extent, and aggregated values as arguments. 1959 auto blockArgsCount = getInitVals().size() + 2; 1960 if (block.getNumArguments() != blockArgsCount) 1961 return emitOpError() << "ReduceOp body is expected to have " 1962 << blockArgsCount << " arguments"; 1963 1964 // The first block argument is the index and must always be of type `index`. 1965 if (!llvm::isa<IndexType>(block.getArgument(0).getType())) 1966 return emitOpError( 1967 "argument 0 of ReduceOp body is expected to be of IndexType"); 1968 1969 // The second block argument is the extent and must be of type `size` or 1970 // `index`, depending on whether the reduce operation is applied to a shape or 1971 // to an extent tensor. 1972 Type extentTy = block.getArgument(1).getType(); 1973 if (llvm::isa<ShapeType>(getShape().getType())) { 1974 if (!llvm::isa<SizeType>(extentTy)) 1975 return emitOpError("argument 1 of ReduceOp body is expected to be of " 1976 "SizeType if the ReduceOp operates on a ShapeType"); 1977 } else { 1978 if (!llvm::isa<IndexType>(extentTy)) 1979 return emitOpError( 1980 "argument 1 of ReduceOp body is expected to be of IndexType if the " 1981 "ReduceOp operates on an extent tensor"); 1982 } 1983 1984 for (const auto &type : llvm::enumerate(getInitVals())) 1985 if (block.getArgument(type.index() + 2).getType() != type.value().getType()) 1986 return emitOpError() << "type mismatch between argument " 1987 << type.index() + 2 1988 << " of ReduceOp body and initial value " 1989 << type.index(); 1990 return success(); 1991 } 1992 1993 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) { 1994 // Parse operands. 1995 SmallVector<OpAsmParser::UnresolvedOperand, 3> operands; 1996 Type shapeOrExtentTensorType; 1997 if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1, 1998 OpAsmParser::Delimiter::Paren) || 1999 parser.parseColonType(shapeOrExtentTensorType) || 2000 parser.parseOptionalArrowTypeList(result.types)) 2001 return failure(); 2002 2003 // Resolve operands. 2004 auto initVals = llvm::ArrayRef(operands).drop_front(); 2005 if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType, 2006 result.operands) || 2007 parser.resolveOperands(initVals, result.types, parser.getNameLoc(), 2008 result.operands)) 2009 return failure(); 2010 2011 // Parse the body. 2012 Region *body = result.addRegion(); 2013 if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{})) 2014 return failure(); 2015 2016 // Parse attributes. 2017 if (parser.parseOptionalAttrDict(result.attributes)) 2018 return failure(); 2019 2020 return success(); 2021 } 2022 2023 void ReduceOp::print(OpAsmPrinter &p) { 2024 p << '(' << getShape() << ", " << getInitVals() 2025 << ") : " << getShape().getType(); 2026 p.printOptionalArrowTypeList(getResultTypes()); 2027 p << ' '; 2028 p.printRegion(getRegion()); 2029 p.printOptionalAttrDict((*this)->getAttrs()); 2030 } 2031 2032 #define GET_OP_CLASSES 2033 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 2034 2035 #define GET_TYPEDEF_CLASSES 2036 #include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc" 2037