//===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "mlir/Dialect/Shape/IR/Shape.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::shape;

#include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"

namespace {
#include "ShapeCanonicalization.inc"
} // namespace

RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
  return RankedTensorType::get({rank}, IndexType::get(ctx));
}

bool shape::isExtentTensorType(Type type) {
  auto ranked = llvm::dyn_cast<RankedTensorType>(type);
  return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
}

LogicalResult shape::getShapeVec(Value input,
                                 SmallVectorImpl<int64_t> &shapeValues) {
  if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
    auto type = llvm::cast<ShapedType>(inputOp.getArg().getType());
    if (!type.hasRank())
      return failure();
    llvm::append_range(shapeValues, type.getShape());
    return success();
  }
  DenseIntElementsAttr attr;
  if (matchPattern(input, m_Constant(&attr))) {
    llvm::append_range(shapeValues, attr.getValues<int64_t>());
    return success();
  }
  return failure();
}

static bool isErrorPropagationPossible(TypeRange operandTypes) {
  return llvm::any_of(operandTypes,
                      llvm::IsaPred<SizeType, ShapeType, ValueShapeType>);
}

static LogicalResult verifySizeOrIndexOp(Operation *op) {
  assert(op != nullptr && op->getNumResults() == 1);
  Type resultTy = op->getResultTypes().front();
  if (isErrorPropagationPossible(op->getOperandTypes())) {
    if (!llvm::isa<SizeType>(resultTy))
      return op->emitOpError()
             << "if at least one of the operands can hold error values then "
                "the result must be of type `size` to propagate them";
  }
  return success();
}

static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
  assert(op != nullptr && op->getNumResults() == 1);
  Type resultTy = op->getResultTypes().front();
  if (isErrorPropagationPossible(op->getOperandTypes())) {
    if (!llvm::isa<ShapeType>(resultTy))
      return op->emitOpError()
             << "if at least one of the operands can hold error values then "
                "the result must be of type `shape` to propagate them";
  }
  return success();
}

template <typename... Ty>
static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
  return typeRange.size() == 1 && llvm::isa<Ty...>(typeRange.front());
}

template <typename... Ty, typename... ranges>
static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
  return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
}

//===----------------------------------------------------------------------===//
// InlinerInterface
//===----------------------------------------------------------------------===//

namespace {
/// This class defines the interface for inlining shape dialect ops.
struct ShapeInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  // Returns true if the given region 'src' can be inlined into the region
  // 'dest' that is attached to an operation registered to the current dialect.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       IRMapping &) const final {
    return true;
  }

  // Returns true if the given operation 'op', that is registered to this
  // dialect, can be inlined into the region 'dest' that is attached to an
  // operation registered to the current dialect.
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       IRMapping &) const final {
    return true;
  }
};
} // namespace

void ShapeDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
      >();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect
  // is still evolving, this makes it simple to start with unregistered ops
  // and try different variants before actually defining an op.
  allowUnknownOperations();
  declarePromisedInterfaces<bufferization::BufferizableOpInterface, AssumingOp,
                            AssumingYieldOp>();
}

Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
                                             Attribute value, Type type,
                                             Location loc) {
  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
    return builder.create<ub::PoisonOp>(loc, type, poison);

  if (llvm::isa<ShapeType>(type) || isExtentTensorType(type))
    return builder.create<ConstShapeOp>(
        loc, type, llvm::cast<DenseIntElementsAttr>(value));
  if (llvm::isa<SizeType>(type))
    return builder.create<ConstSizeOp>(loc, type,
                                       llvm::cast<IntegerAttr>(value));
  if (llvm::isa<WitnessType>(type))
    return builder.create<ConstWitnessOp>(loc, type,
                                          llvm::cast<BoolAttr>(value));

  return arith::ConstantOp::materialize(builder, value, type, loc);
}

LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
                                                     NamedAttribute attribute) {
  // Verify shape.lib attribute.
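  // An illustrative (hypothetical) use would be a module carrying
  //
  //   module attributes { shape.lib = @shape_lib } { ... }
  //
  // where @shape_lib names a shape.function_library in the module.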
  if (attribute.getName() == "shape.lib") {
    if (!op->hasTrait<OpTrait::SymbolTable>())
      return op->emitError(
          "shape.lib attribute may only be on op implementing SymbolTable");

    if (auto symbolRef = llvm::dyn_cast<SymbolRefAttr>(attribute.getValue())) {
      auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
      if (!symbol)
        return op->emitError("shape function library ")
               << symbolRef << " not found";
      return isa<shape::FunctionLibraryOp>(symbol)
                 ? success()
                 : op->emitError()
                       << symbolRef << " required to be shape function library";
    }

    if (auto arr = llvm::dyn_cast<ArrayAttr>(attribute.getValue())) {
      // Verify all entries are function libraries and mappings in libraries
      // refer to unique ops.
      DenseSet<StringAttr> key;
      for (auto it : arr) {
        if (!llvm::isa<SymbolRefAttr>(it))
          return op->emitError(
              "only SymbolRefAttr allowed in shape.lib attribute array");

        auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
            SymbolTable::lookupSymbolIn(op, llvm::cast<SymbolRefAttr>(it)));
        if (!shapeFnLib)
          return op->emitError()
                 << it << " does not refer to FunctionLibraryOp";
        for (auto mapping : shapeFnLib.getMapping()) {
          if (!key.insert(mapping.getName()).second) {
            return op->emitError("only one op to shape mapping allowed, found "
                                 "multiple for `")
                   << mapping.getName() << "`";
          }
        }
      }
      return success();
    }

    return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
                         "allowed as shape.lib attribute");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// AnyOp
//===----------------------------------------------------------------------===//

// TODO: Canonicalization should be implemented for shapes that can be
// determined through mixtures of the known dimensions of the inputs.
OpFoldResult AnyOp::fold(FoldAdaptor adaptor) {
  // Only the last operand is checked because AnyOp is commutative.
  if (adaptor.getInputs().back())
    return adaptor.getInputs().back();

  return nullptr;
}

//===----------------------------------------------------------------------===//
// AssumingOp
//===----------------------------------------------------------------------===//

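// The custom assembly format handled below looks like this (illustrative):
//
//   %0 = shape.assuming %witness -> (tensor<2xf32>) {
//     ...
//     shape.assuming_yield %result : tensor<2xf32>
//   }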
ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) {
  result.regions.reserve(1);
  Region *doRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  OpAsmParser::UnresolvedOperand cond;
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, builder.getType<WitnessType>(),
                            result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the region and add a terminator if elided.
  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

void AssumingOp::print(OpAsmPrinter &p) {
  bool yieldsResults = !getResults().empty();

  p << " " << getWitness();
  if (yieldsResults)
    p << " -> (" << getResultTypes() << ")";
  p << ' ';
  p.printRegion(getDoRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/yieldsResults);
  p.printOptionalAttrDict((*this)->getAttrs());
}

namespace {
// Removes AssumingOp with a passing witness and inlines the region.
struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
    if (!witness || !witness.getPassingAttr())
      return failure();

    AssumingOp::inlineRegionIntoParent(op, rewriter);
    return success();
  }
};

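// Removes results of an `assuming` op that are unused by rebuilding the op so
// that only the values with uses are yielded, e.g. (illustrative): a
// two-result `shape.assuming` whose second result is dead becomes a
// one-result op.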
struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    Block *body = op.getBody();
    auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());

    // Find used values.
    SmallVector<Value, 4> newYieldOperands;
    for (auto [opResult, yieldOperand] :
         llvm::zip(op.getResults(), yieldOp.getOperands())) {
      if (!opResult.getUses().empty()) {
        newYieldOperands.push_back(yieldOperand);
      }
    }

    // Rewrite only if redundant results exist.
    if (newYieldOperands.size() == yieldOp->getNumOperands())
      return failure();

    // Replace yield op in the old assuming op's body and move the entire
    // region to the new assuming op.
    rewriter.setInsertionPointToEnd(body);
    auto newYieldOp =
        rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
    rewriter.setInsertionPoint(op);
    auto newOp = rewriter.create<AssumingOp>(
        op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
    newOp.getDoRegion().takeBody(op.getDoRegion());

    // Use the new results to replace the previously used ones.
    SmallVector<Value, 4> replacementValues;
    auto src = newOp.getResults().begin();
    for (auto it : op.getResults()) {
      if (it.getUses().empty())
        replacementValues.push_back(nullptr);
      else
        replacementValues.push_back(*src++);
    }
    rewriter.replaceOp(op, replacementValues);
    return success();
  }
};
} // namespace

void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
}

// See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
void AssumingOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  // AssumingOp has unconditional control flow into the region and back to the
  // parent, so return the correct RegionSuccessor purely based on the index
  // being None or 0.
  if (!point.isParent()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  regions.push_back(RegionSuccessor(&getDoRegion()));
}

void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}

void AssumingOp::build(
    OpBuilder &builder, OperationState &result, Value witness,
    function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) {
  OpBuilder::InsertionGuard g(builder);

  result.addOperands(witness);
  Region *bodyRegion = result.addRegion();
  builder.createBlock(bodyRegion);

  // Build body.
  SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
  builder.create<AssumingYieldOp>(result.location, yieldValues);

  SmallVector<Type, 2> assumingTypes;
  for (Value v : yieldValues)
    assumingTypes.push_back(v.getType());
  result.addTypes(assumingTypes);
}

//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//

LogicalResult mlir::shape::AddOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    AddOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
      llvm::isa<SizeType>(adaptor.getRhs().getType()))
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

OpFoldResult mlir::shape::AddOp::fold(FoldAdaptor adaptor) {
  // add(x, 0) -> x
  if (matchPattern(getRhs(), m_Zero()))
    return getLhs();

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) + b; });
}

LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// AssumingAllOp
//===----------------------------------------------------------------------===//

namespace {

// Merge multiple `shape.assuming_all` operations together.
//
//   %0 = shape.assuming_all %w0, %w1
//   %1 = shape.assuming_all %w2, %0
//
// to:
//
//   %0 = shape.assuming_all %w0, %w1, %w2
struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> operands;

    for (Value operand : op.getInputs()) {
      if (auto assumeAll = operand.getDefiningOp<AssumingAllOp>())
        operands.append(assumeAll.operand_begin(), assumeAll->operand_end());
      else
        operands.push_back(operand);
    }

    // We didn't find any other `assuming_all` ops to merge with.
    if (operands.size() == op.getNumOperands())
      return failure();

    // Replace with a new `assuming_all` operation with merged constraints.
    rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands);
    return success();
  }
};

// Eliminate `cstr_broadcastable` operands of an `assuming_all` operation that
// are subsumed by others.
//
//   %0 = shape.cstr_broadcastable %shape0, %shape1
//   %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2
//
//   %2 = shape.cstr_broadcastable %shape3, %shape4
//   %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5
//
//   %4 = shape.assuming_all %0, %1, %2, %3
//
// to:
//
//   %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2
//   %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5
//   %2 = shape.assuming_all %0, %1
//
// In this example if shapes [0, 1, 2] are broadcastable, then it means that
// shapes [0, 1] are broadcastable too, and can be removed from the list of
// constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't
// matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]).
struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    // Collect all `CstrBroadcastableOp` operands first.
    SetVector<CstrBroadcastableOp> operands;
    for (Value operand : op.getInputs()) {
      // TODO: Apply this optimization if some of the witnesses are not
      // produced by the `cstr_broadcastable`.
      auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
      if (!broadcastable)
        return failure();

      operands.insert(broadcastable);
    }

    // Skip trivial `assuming_all` operations.
    if (operands.size() <= 1)
      return failure();

    // Collect shapes checked by `cstr_broadcastable` operands.
    SmallVector<std::pair<CstrBroadcastableOp, DenseSet<Value>>> shapes;
    for (auto cstr : operands) {
      DenseSet<Value> shapesSet(cstr->operand_begin(), cstr->operand_end());
      shapes.emplace_back(cstr, std::move(shapesSet));
    }

    // Sort by the number of shape operands (larger to smaller).
    llvm::sort(shapes, [](auto a, auto b) {
      return a.first.getNumOperands() > b.first.getNumOperands();
    });

    // We start from the `cstr_broadcastable` operations with the largest
    // number of shape operands and remove redundant `cstr_broadcastable`
    // operations until we are left with a set of `cstr_broadcastable`
    // operations with non-overlapping constraints.
    SmallVector<CstrBroadcastableOp> markedForErase;

    for (unsigned i = 0; i < shapes.size(); ++i) {
      auto isSubset = [&](auto pair) {
        return llvm::set_is_subset(pair.second, shapes[i].second);
      };

      // Keep redundant `cstr_broadcastable` operations to be erased.
      auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset);
      for (auto *it0 = it; it0 < shapes.end(); ++it0)
        markedForErase.push_back(it0->first);
      shapes.erase(it, shapes.end());
    }

    // We didn't find any operands that could be removed.
    if (markedForErase.empty())
      return failure();

    // Collect non-overlapping `cstr_broadcastable` constraints.
    SmallVector<Value> uniqueConstraints;
    for (auto &shape : shapes)
      uniqueConstraints.push_back(shape.first.getResult());

    // Replace with a new `assuming_all` operation ...
    rewriter.replaceOpWithNewOp<AssumingAllOp>(op, uniqueConstraints);

    // ... and maybe erase `cstr_broadcastable` ops without uses.
    for (auto &op : markedForErase)
      if (op->use_empty())
        rewriter.eraseOp(op);

    return success();
  }
};

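// Collapse an `assuming_all` whose inputs are all `cstr_eq` ops into a single
// `cstr_eq` over all of their shapes, provided each `cstr_eq` shares at least
// one shape with those already seen, e.g. (illustrative):
//
//   %0 = shape.cstr_eq %a, %b
//   %1 = shape.cstr_eq %b, %c
//   %2 = shape.assuming_all %0, %1
//
// to:
//
//   %2 = shape.cstr_eq %a, %b, %b, %c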
struct AssumingAllToCstrEqCanonicalization
    : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 8> shapes;
    for (Value w : op.getInputs()) {
      auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
      if (!cstrEqOp)
        return failure();
      bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) {
        return llvm::is_contained(shapes, s);
      });
      if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
        return failure();
      shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
    }
    rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
    return success();
  }
};

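// Remove duplicate operands, e.g. (illustrative):
//
//   %0 = shape.cstr_broadcastable %a, %b, %a
//
// to:
//
//   %0 = shape.cstr_broadcastable %a, %b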
template <typename OpTy>
struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Find unique operands.
    SetVector<Value> unique(op.operand_begin(), op.operand_end());

    // Reduce op to equivalent with unique operands.
    if (unique.size() < op.getNumOperands()) {
      rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
                                        unique.takeVector(), op->getAttrs());
      return success();
    }

    return failure();
  }
};
} // namespace

void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns
      .add<MergeAssumingAllOps, AssumingAllOneOp,
           AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
           RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
}

OpFoldResult AssumingAllOp::fold(FoldAdaptor adaptor) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this is commutative.
  for (int idx = adaptor.getInputs().size() - 1; idx >= 0; idx--) {
    Attribute a = adaptor.getInputs()[idx];
    // Cannot fold if any inputs are not constant.
    if (!a)
      return nullptr;

    // We do not need to keep statically known values after handling them in
    // this method.
    getOperation()->eraseOperand(idx);

    // Always false if any input is statically known false
    if (!llvm::cast<BoolAttr>(a).getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(getContext(), true);
}

LogicalResult AssumingAllOp::verify() {
  // Ensure that AssumingAllOp contains at least one operand
  if (getNumOperands() == 0)
    return emitOpError("no operands specified");

  return success();
}

//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//

OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
  if (getShapes().size() == 1) {
    // Otherwise, we need a cast which would be a canonicalization, not folding.
    if (getShapes().front().getType() != getType())
      return nullptr;
    return getShapes().front();
  }

  // TODO: Support folding with more than 2 input shapes
  if (getShapes().size() > 2)
    return nullptr;

  if (!adaptor.getShapes()[0] || !adaptor.getShapes()[1])
    return nullptr;
  auto lhsShape = llvm::to_vector<6>(
      llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[0])
          .getValues<int64_t>());
  auto rhsShape = llvm::to_vector<6>(
      llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[1])
          .getValues<int64_t>());
  SmallVector<int64_t, 6> resultShape;

  // If the shapes are not compatible, we can't fold it.
  // TODO: Fold to an "error".
  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
    return nullptr;

  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}

LogicalResult BroadcastOp::verify() {
  return verifyShapeOrExtentTensorOp(*this);
}

namespace {
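// Remove operands that are statically known to be empty shapes (zero-length
// extent tensors or empty `const_shape` ops), since they cannot affect the
// result, e.g. (illustrative):
//
//   %e = shape.const_shape [] : !shape.shape
//   %0 = shape.broadcast %a, %e, %b : ...
//
// to:
//
//   %0 = shape.broadcast %a, %b : ...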
template <typename OpTy>
struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    auto isPotentiallyNonEmptyShape = [](Value shape) {
      if (auto extentTensorTy =
              llvm::dyn_cast<RankedTensorType>(shape.getType())) {
        if (extentTensorTy.getDimSize(0) == 0)
          return false;
      }
      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
        if (constShape.getShape().empty())
          return false;
      }
      return true;
    };
    auto newOperands = llvm::filter_to_vector<8>(op->getOperands(),
                                                 isPotentiallyNonEmptyShape);

    // Replace the op with an empty shape constant if all operands are reduced
    // to empty shapes.
    if (newOperands.empty()) {
      rewriter.replaceOpWithNewOp<ConstShapeOp>(
          op, op->getResultTypes().front(), rewriter.getIndexTensorAttr({}));
      return success();
    }

    // Reduce op to equivalent without empty shape operands.
    if (newOperands.size() < op.getNumOperands()) {
      rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
                                        op->getAttrs());
      return success();
    }

    return failure();
  }
};

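// Forward the only operand of a single-operand broadcast, inserting a cast
// when the operand and result types differ.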
struct BroadcastForwardSingleOperandPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 1)
      return failure();
    Value replacement = op.getShapes().front();

    // Insert cast if needed.
    if (replacement.getType() != op.getType()) {
      auto loc = op.getLoc();
      if (llvm::isa<ShapeType>(op.getType())) {
        replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
      } else {
        assert(!llvm::isa<ShapeType>(op.getType()) &&
               !llvm::isa<ShapeType>(replacement.getType()) &&
               "expect extent tensor cast");
        replacement =
            rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
      }
    }

    rewriter.replaceOp(op, replacement);
    return success();
  }
};

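// Fold all constant shape operands of a broadcast into one constant shape,
// e.g. (illustrative):
//
//   %0 = shape.const_shape [2, 1] : tensor<2xindex>
//   %1 = shape.const_shape [3] : tensor<1xindex>
//   %2 = shape.broadcast %arg, %0, %1 : ...
//
// to:
//
//   %0 = shape.const_shape [2, 3] : tensor<2xindex>
//   %1 = shape.broadcast %arg, %0 : ...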
struct BroadcastFoldConstantOperandsPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<int64_t, 8> foldedConstantShape;
    SmallVector<Value, 8> newShapeOperands;
    for (Value shape : op.getShapes()) {
      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
        SmallVector<int64_t, 8> newFoldedConstantShape;
        if (OpTrait::util::getBroadcastedShape(
                foldedConstantShape,
                llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
                newFoldedConstantShape)) {
          foldedConstantShape = newFoldedConstantShape;
          continue;
        }
      }
      newShapeOperands.push_back(shape);
    }

    // Need at least two constant operands to fold anything.
    if (op.getNumOperands() - newShapeOperands.size() < 2)
      return failure();

    auto foldedConstantOperandsTy = RankedTensorType::get(
        {static_cast<int64_t>(foldedConstantShape.size())},
        rewriter.getIndexType());
    newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
        op.getLoc(), foldedConstantOperandsTy,
        rewriter.getIndexTensorAttr(foldedConstantShape)));
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
                                             newShapeOperands);
    return success();
  }
};

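// Strip `tensor.cast` operands that only erase static extent information,
// e.g. (illustrative):
//
//   %c = tensor.cast %s : tensor<3xindex> to tensor<?xindex>
//   %0 = shape.broadcast %c, %t : ...
//
// to:
//
//   %0 = shape.broadcast %s, %t : ...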
template <typename OpTy>
struct CanonicalizeCastExtentTensorOperandsPattern
    : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Canonicalize operands.
    bool anyChange = false;
    auto canonicalizeOperand = [&](Value operand) -> Value {
      if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
        // Only eliminate the cast if it holds no shape information.
        bool isInformationLosingCast =
            llvm::cast<RankedTensorType>(castOp.getType()).isDynamicDim(0);
        if (isInformationLosingCast) {
          anyChange = true;
          return castOp.getSource();
        }
      }
      return operand;
    };
    auto newOperands = llvm::to_vector<8>(
        llvm::map_range(op.getOperands(), canonicalizeOperand));

    // Rewrite op if any change required.
    if (!anyChange)
      return failure();
    rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
    return success();
  }
};

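// Concretize a broadcast's dynamic extent tensor result type when the maximal
// rank of the operands is statically known, e.g. (illustrative): a result of
// type tensor<?xindex> becomes tensor<3xindex>, followed by a tensor.cast
// back to the original result type.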
struct BroadcastConcretizeResultTypePattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    // Only concretize dynamic extent tensor result types.
    auto resultTy = llvm::dyn_cast<RankedTensorType>(op.getType());
    if (!resultTy || !resultTy.isDynamicDim(0))
      return failure();

    // Infer resulting shape rank if possible.
    int64_t maxRank = 0;
    for (Value shape : op.getShapes()) {
      if (auto extentTensorTy =
              llvm::dyn_cast<RankedTensorType>(shape.getType())) {
        // Cannot infer resulting shape rank if any operand is dynamically
        // ranked.
        if (extentTensorTy.isDynamicDim(0))
          return failure();
        maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
      }
    }

    auto newOp = rewriter.create<BroadcastOp>(
        op.getLoc(), getExtentTensorType(getContext(), maxRank),
        op.getShapes());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};
} // namespace

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<BroadcastConcretizeResultTypePattern,
               BroadcastFoldConstantOperandsPattern,
               BroadcastForwardSingleOperandPattern,
               CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
               RemoveDuplicateOperandsPattern<BroadcastOp>,
               RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
}

//===----------------------------------------------------------------------===//
// ConcatOp
//===----------------------------------------------------------------------===//

OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
  if (!adaptor.getLhs() || !adaptor.getRhs())
    return nullptr;
  auto lhsShape = llvm::to_vector<6>(
      llvm::cast<DenseIntElementsAttr>(adaptor.getLhs()).getValues<int64_t>());
  auto rhsShape = llvm::to_vector<6>(
      llvm::cast<DenseIntElementsAttr>(adaptor.getRhs()).getValues<int64_t>());
  SmallVector<int64_t, 6> resultShape;
  resultShape.append(lhsShape.begin(), lhsShape.end());
  resultShape.append(rhsShape.begin(), rhsShape.end());
  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}

//===----------------------------------------------------------------------===//
// ConstShapeOp
//===----------------------------------------------------------------------===//

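// The custom assembly format handled below looks like this (illustrative):
//
//   %0 = shape.const_shape [1, 2, 3] : !shape.shape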
void ConstShapeOp::print(OpAsmPrinter &p) {
  p << " ";
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape"});
  p << "[";
  interleaveComma(getShape().getValues<int64_t>(), p);
  p << "] : ";
  p.printType(getType());
}

ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // We piggy-back on ArrayAttr parsing, though we don't internally store the
  // shape as an ArrayAttr.
  // TODO: Implement custom parser and maybe make syntax a bit more concise.
  Attribute extentsRaw;
  NamedAttrList dummy;
  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
    return failure();
  auto extentsArray = llvm::dyn_cast<ArrayAttr>(extentsRaw);
  if (!extentsArray)
    return failure();
  SmallVector<int64_t, 6> ints;
  for (Attribute extent : extentsArray) {
    IntegerAttr attr = llvm::dyn_cast<IntegerAttr>(extent);
    if (!attr)
      return failure();
    ints.push_back(attr.getInt());
  }
  Builder &builder = parser.getBuilder();
  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
  Type resultTy;
  if (parser.parseColonType(resultTy))
    return failure();
  result.types.push_back(resultTy);
  return success();
}

OpFoldResult ConstShapeOp::fold(FoldAdaptor) { return getShapeAttr(); }

void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  patterns.add<TensorCastConstShape>(context);
}

LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ConstShapeOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  Builder b(context);
  const Properties prop = adaptor.getProperties();
  inferredReturnTypes.assign({RankedTensorType::get(
      {static_cast<int64_t>(prop.shape.size())}, b.getIndexType())});
  return success();
}

bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
                                                        TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;

  Type lhs = l.front();
  Type rhs = r.front();

  if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
    // Shape type is compatible with all other valid return types.
    return true;
  return lhs == rhs;
}

//===----------------------------------------------------------------------===//
// CstrBroadcastableOp
//===----------------------------------------------------------------------===//

void CstrBroadcastableOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // These canonicalization patterns overlap with the considerations made
  // during folding; they cover cases where additional shape information is
  // inferred at some point but does not, by itself, result in folding.
  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
               CstrBroadcastableEqOps,
               RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
               RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
}

// Return true if there is at most one attribute not representing a scalar
// broadcast.
static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
  bool nonScalarSeen = false;
  for (Attribute a : attributes) {
    if (!a || llvm::cast<DenseIntElementsAttr>(a).getNumElements() != 0) {
      if (nonScalarSeen)
        return false;
      nonScalarSeen = true;
    }
  }
  return true;
}

OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) {
  // No broadcasting is needed if all operands but one are scalar.
  if (hasAtMostSingleNonScalar(adaptor.getShapes()))
    return BoolAttr::get(getContext(), true);

  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (const auto &operand : adaptor.getShapes()) {
          if (!operand)
            return false;
          extents.push_back(llvm::to_vector<6>(
              llvm::cast<DenseIntElementsAttr>(operand).getValues<int64_t>()));
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Lastly, see if folding can be completed based on what constraints are known
  // on the input shapes.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (auto shapeValue : getShapes()) {
          extents.emplace_back();
          if (failed(getShapeVec(shapeValue, extents.back())))
            return false;
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not replace it with a constant witness.
  return nullptr;
}

LogicalResult CstrBroadcastableOp::verify() {
  // Ensure that CstrBroadcastableOp contains at least two operands
  if (getNumOperands() < 2)
    return emitOpError("required at least 2 input shapes");
  return success();
}

//===----------------------------------------------------------------------===//
// CstrEqOp
//===----------------------------------------------------------------------===//

void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  // If inputs are equal, return passing witness
  patterns.add<CstrEqEqOps>(context);
}

OpFoldResult CstrEqOp::fold(FoldAdaptor adaptor) {
  if (llvm::all_of(adaptor.getShapes(), [&](Attribute a) {
        return a && a == adaptor.getShapes().front();
      }))
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not try to replace it with a constant witness. Similarly, we
  // cannot if there are any non-const inputs.
  return nullptr;
}

//===----------------------------------------------------------------------===//
// ConstSizeOp
//===----------------------------------------------------------------------===//

void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
                        int64_t value) {
  build(builder, result, builder.getIndexAttr(value));
}

OpFoldResult ConstSizeOp::fold(FoldAdaptor) { return getValueAttr(); }

void ConstSizeOp::getAsmResultNames(
    llvm::function_ref<void(Value, StringRef)> setNameFn) {
  SmallString<4> buffer;
  llvm::raw_svector_ostream os(buffer);
  os << "c" << getValue();
  setNameFn(getResult(), os.str());
}

//===----------------------------------------------------------------------===//
// ConstWitnessOp
//===----------------------------------------------------------------------===//

OpFoldResult ConstWitnessOp::fold(FoldAdaptor) { return getPassingAttr(); }

//===----------------------------------------------------------------------===//
// CstrRequireOp
//===----------------------------------------------------------------------===//

OpFoldResult CstrRequireOp::fold(FoldAdaptor adaptor) {
  return adaptor.getPred();
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

std::optional<int64_t> DimOp::getConstantIndex() {
  if (auto constSizeOp = getIndex().getDefiningOp<ConstSizeOp>())
    return constSizeOp.getValue().getLimitedValue();
  if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
    return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
  return std::nullopt;
}

OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
  Type valType = getValue().getType();
  auto valShapedType = llvm::dyn_cast<ShapedType>(valType);
  if (!valShapedType || !valShapedType.hasRank())
    return nullptr;
  std::optional<int64_t> index = getConstantIndex();
  if (!index.has_value())
    return nullptr;
  if (index.value() < 0 || index.value() >= valShapedType.getRank())
    return nullptr;
  auto extent = valShapedType.getDimSize(*index);
  if (ShapedType::isDynamic(extent))
    return nullptr;
  return IntegerAttr::get(IndexType::get(getContext()), extent);
}

LogicalResult mlir::shape::DimOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    DimOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  inferredReturnTypes.assign({adaptor.getIndex().getType()});
  return success();
}

bool mlir::shape::DimOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

//===----------------------------------------------------------------------===//
// DivOp
//===----------------------------------------------------------------------===//

OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
  auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
  if (!lhs)
    return nullptr;
  auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
  if (!rhs)
    return nullptr;

  // Division in APInt does not follow floor(lhs / rhs) when the result is
  // negative. Rather, APInt rounds toward zero.
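  // For example, floor(-7 / 2) is -4, while APInt::sdivrem yields quotient -3
  // and remainder -1; the fixup below subtracts one in exactly that case.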
  APInt quotient, remainder;
  APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
  if (quotient.isNegative() && !remainder.isZero()) {
    quotient -= 1;
  }

  Type indexTy = IndexType::get(getContext());
  return IntegerAttr::get(indexTy, quotient);
}

LogicalResult mlir::shape::DivOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    DivOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
      llvm::isa<SizeType>(adaptor.getRhs().getType()))
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// ShapeEqOp
//===----------------------------------------------------------------------===//

OpFoldResult ShapeEqOp::fold(FoldAdaptor adaptor) {
  bool allSame = true;
  if (!adaptor.getShapes().empty() && !adaptor.getShapes().front())
    return {};
  for (Attribute operand : adaptor.getShapes().drop_front()) {
    if (!operand)
      return {};
    allSame = allSame && operand == adaptor.getShapes().front();
  }
  return BoolAttr::get(getContext(), allSame);
}

//===----------------------------------------------------------------------===//
// IndexToSizeOp
//===----------------------------------------------------------------------===//

OpFoldResult IndexToSizeOp::fold(FoldAdaptor adaptor) {
  // Constant values of both types, `shape.size` and `index`, are represented
  // as `IntegerAttr`s, which makes constant folding simple.
  if (Attribute arg = adaptor.getArg())
    return arg;
  return {};
}

void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<SizeToIndexToSizeCanonicalization>(context);
}

//===----------------------------------------------------------------------===//
// FromExtentsOp
//===----------------------------------------------------------------------===//

OpFoldResult FromExtentsOp::fold(FoldAdaptor adaptor) {
  if (llvm::any_of(adaptor.getExtents(), [](Attribute a) { return !a; }))
    return nullptr;
  SmallVector<int64_t, 6> extents;
  for (auto attr : adaptor.getExtents())
    extents.push_back(llvm::cast<IntegerAttr>(attr).getInt());
  Builder builder(getContext());
  return builder.getIndexTensorAttr(extents);
}

//===----------------------------------------------------------------------===//
// FunctionLibraryOp
//===----------------------------------------------------------------------===//

void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
                              StringRef name) {
  result.attributes.push_back(builder.getNamedAttr(
      ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
}

FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
  auto attr = llvm::dyn_cast_or_null<FlatSymbolRefAttr>(
      getMapping().get(op->getName().getIdentifier()));
  if (!attr)
    return nullptr;
  return lookupSymbol<FuncOp>(attr);
}

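// The custom assembly format handled below looks like this (illustrative; the
// function and op names are hypothetical):
//
//   shape.function_library @shape_lib {
//     shape.func @same_result_shape(%arg : !shape.value_shape)
//         -> !shape.shape {
//       ...
//     }
//   } mapping {
//     some_dialect.some_op = @same_result_shape
//   }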
ParseResult FunctionLibraryOp::parse(OpAsmParser &parser,
                                     OperationState &result) {
  // Parse the op name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  auto *bodyRegion = result.addRegion();
  if (parser.parseRegion(*bodyRegion))
    return failure();

  if (parser.parseKeyword("mapping"))
    return failure();

  DictionaryAttr mappingAttr;
  if (parser.parseAttribute(mappingAttr,
                            parser.getBuilder().getType<NoneType>(), "mapping",
                            result.attributes))
    return failure();
  return success();
}

void FunctionLibraryOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printSymbolName(getName());
  p.printOptionalAttrDictWithKeyword(
      (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(), "mapping"});
  p << ' ';
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
  p << " mapping ";
  p.printAttributeWithoutType(getMappingAttr());
}

//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//

FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
                      ArrayRef<NamedAttribute> attrs) {
  OpBuilder builder(location->getContext());
  OperationState state(location, getOperationName());
  FuncOp::build(builder, state, name, type, attrs);
  return cast<FuncOp>(Operation::create(state));
}
FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
                      Operation::dialect_attr_range attrs) {
  SmallVector<NamedAttribute, 8> attrRef(attrs);
  return create(location, name, type, llvm::ArrayRef(attrRef));
}
FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
                      ArrayRef<NamedAttribute> attrs,
                      ArrayRef<DictionaryAttr> argAttrs) {
  FuncOp func = create(location, name, type, attrs);
  func.setAllArgAttrs(argAttrs);
  return func;
}

void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
                   FunctionType type, ArrayRef<NamedAttribute> attrs,
                   ArrayRef<DictionaryAttr> argAttrs) {
  state.addAttribute(FuncOp::getSymNameAttrName(state.name),
                     builder.getStringAttr(name));
  state.addAttribute(FuncOp::getFunctionTypeAttrName(state.name),
                     TypeAttr::get(type));
  state.attributes.append(attrs.begin(), attrs.end());
  state.addRegion();

  if (argAttrs.empty())
    return;
  assert(type.getNumInputs() == argAttrs.size());
  function_interface_impl::addArgAndResultAttrs(
      builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
      getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
}

ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
  auto buildFuncType =
      [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
         function_interface_impl::VariadicFlag,
         std::string &) { return builder.getFunctionType(argTypes, results); };

  return function_interface_impl::parseFunctionOp(
      parser, result, /*allowVariadic=*/false,
      getFunctionTypeAttrName(result.name), buildFuncType,
      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}

void FuncOp::print(OpAsmPrinter &p) {
  function_interface_impl::printFunctionOp(
      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
      getArgAttrsAttrName(), getResAttrsAttrName());
}

//===----------------------------------------------------------------------===//
// GetExtentOp
//===----------------------------------------------------------------------===//

std::optional<int64_t> GetExtentOp::getConstantDim() {
  if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
    return constSizeOp.getValue().getLimitedValue();
  if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
    return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
  return std::nullopt;
}

OpFoldResult GetExtentOp::fold(FoldAdaptor adaptor) {
  auto elements =
      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
1337   if (!elements)
1338     return nullptr;
1339   std::optional<int64_t> dim = getConstantDim();
1340   if (!dim.has_value())
1341     return nullptr;
1342   if (dim.value() >= elements.getNumElements())
1343     return nullptr;
1344   return elements.getValues<Attribute>()[(uint64_t)dim.value()];
1345 }
1346 
1347 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
1348                         int64_t dim) {
1349   auto loc = result.location;
1350   auto dimAttr = builder.getIndexAttr(dim);
1351   if (llvm::isa<ShapeType>(shape.getType())) {
1352     Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
1353     build(builder, result, builder.getType<SizeType>(), shape, dim);
1354   } else {
1355     Value dim =
1356         builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr);
1357     build(builder, result, builder.getIndexType(), shape, dim);
1358   }
1359 }
1360 
1361 LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
1362     MLIRContext *context, std::optional<Location> location,
1363     GetExtentOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1364   inferredReturnTypes.assign({IndexType::get(context)});
1365   return success();
1366 }
1367 
1368 bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
1369                                                        TypeRange r) {
1370   // SizeType is compatible with IndexType.
1371   return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1372 }
1373 
1374 LogicalResult GetExtentOp::verify() { return verifySizeOrIndexOp(*this); }
1375 
1376 //===----------------------------------------------------------------------===//
1377 // IsBroadcastableOp
1378 //===----------------------------------------------------------------------===//
1379 
1380 void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1381                                                     MLIRContext *context) {
1382   patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
1383 }
1384 
1385 OpFoldResult IsBroadcastableOp::fold(FoldAdaptor adaptor) {
1386   // Fewer than two shapes are always broadcastable.
1387   if (adaptor.getShapes().size() < 2) {
1388     return BoolAttr::get(getContext(), true);
1389   }
1390 
1391   return nullptr;
1392 }
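
// E.g. a `shape.is_broadcastable` with a single operand folds to a constant
// `true` (illustrative IR):
//
//   %w = shape.is_broadcastable %0 : tensor<?xindex>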
1393 
1394 //===----------------------------------------------------------------------===//
1395 // MeetOp
1396 //===----------------------------------------------------------------------===//
1397 
1398 LogicalResult mlir::shape::MeetOp::inferReturnTypes(
1399     MLIRContext *context, std::optional<Location> location,
1400     MeetOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1401   if (adaptor.getOperands().empty())
1402     return failure();
1403 
1404   auto isShapeType = [](Type arg) {
1405     if (llvm::isa<ShapeType>(arg))
1406       return true;
1407     return isExtentTensorType(arg);
1408   };
1409 
1410   ValueRange::type_range types = adaptor.getOperands().getTypes();
1411   Type acc = types.front();
1412   for (auto t : drop_begin(types)) {
1413     Type l = acc, r = t;
1414     if (!llvm::isa<ShapeType, SizeType>(l))
1415       std::swap(l, r);
1416 
1417     // Handle sizes; propagate the error type if present.
1418     if (llvm::isa<SizeType>(l)) {
1419       if (llvm::isa<SizeType, IndexType>(r))
1420         acc = l;
1421       else
1422         return emitOptionalError(location, "requires all sizes or shapes");
1423     } else if (llvm::isa<IndexType>(l)) {
1424       if (llvm::isa<IndexType>(r))
1425         acc = r;
1426       else
1427         return emitOptionalError(location, "requires all sizes or shapes");
1428     } else if (llvm::isa<ShapeType>(l)) {
1429       // Handle shapes; propagate the error type if present.
1430       if (isShapeType(r))
1431         acc = l;
1432       else
1433         return emitOptionalError(location, "requires all sizes or shapes");
1434     } else if (isExtentTensorType(l)) {
           // `r` may still be any operand type here; bail out unless it is
           // also an extent tensor, otherwise the casts below would be
           // invalid.
           if (!isExtentTensorType(r))
             return emitOptionalError(location, "requires all sizes or shapes");
1435       auto rank1 = llvm::cast<RankedTensorType>(l).getShape()[0];
1436       auto rank2 = llvm::cast<RankedTensorType>(r).getShape()[0];
1437       if (ShapedType::isDynamic(rank1))
1438         acc = l;
1439       else if (ShapedType::isDynamic(rank2))
1440         acc = r;
1441       else if (rank1 != rank2)
1442         return emitOptionalError(location, "unequal shape cardinality");
1443       else
1444         acc = l;
1445     }
1446   }
1447   inferredReturnTypes.assign({acc});
1448   return success();
1449 }
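
// Summary of the inference above (derived from the loop, illustrative):
//
//   (size,  size or index)             -> size
//   (index, index)                     -> index
//   (shape, shape or extent tensor)    -> shape
//   (tensor<?xindex>, tensor<Nxindex>) -> tensor<?xindex>
//   (tensor<Nxindex>, tensor<Mxindex>) -> tensor<Nxindex> if N == M,
//                                         otherwise an error is emitted.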
1450 
1451 bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1452   if (l.size() != 1 || r.size() != 1)
1453     return false;
1454   if (l == r)
1455     return true;
1456 
1457   Type lhs = l.front();
1458   Type rhs = r.front();
1459 
1460   if (!llvm::isa<ShapeType, SizeType>(lhs))
1461     std::swap(lhs, rhs);
1462 
1463   if (llvm::isa<SizeType>(lhs))
1464     return llvm::isa<SizeType, IndexType>(rhs);
1465   if (llvm::isa<ShapeType>(lhs))
1466     return llvm::isa<ShapeType, TensorType>(rhs);
1467 
1468   if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1469     return true;
1470   return false;
1471 }
1472 
1473 //===----------------------------------------------------------------------===//
1474 // RankOp
1475 //===----------------------------------------------------------------------===//
1476 
1477 OpFoldResult shape::RankOp::fold(FoldAdaptor adaptor) {
1478   auto shape =
           llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
1479   if (!shape)
1480     return {};
1481   int64_t rank = shape.getNumElements();
1482   Builder builder(getContext());
1483   return builder.getIndexAttr(rank);
1484 }
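
// E.g. the fold above turns (illustrative IR)
//
//   %shape = shape.const_shape [2, 4, 8] : !shape.shape
//   %rank = shape.rank %shape
//
// into a constant 3, materialized as `shape.const_size 3`.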
1485 
1486 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
1487 /// Constant folding fails in cases where only the rank is constant, not the
1488 /// shape itself.
1489 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
1490 ///
1491 /// Example:
1492 ///
1493 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
1494 /// %rank = shape.rank %shape
1495 ///
1496 /// becomes
1497 ///
1498 /// %rank = shape.const_size 3
1499 
1500 namespace {
1501 struct RankShapeOfCanonicalizationPattern
1502     : public OpRewritePattern<shape::RankOp> {
1503   using OpRewritePattern<shape::RankOp>::OpRewritePattern;
1504 
1505   LogicalResult matchAndRewrite(shape::RankOp op,
1506                                 PatternRewriter &rewriter) const override {
1507     auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
1508     if (!shapeOfOp)
1509       return failure();
1510     auto rankedTensorType =
1511         llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
1512     if (!rankedTensorType)
1513       return failure();
1514     int64_t rank = rankedTensorType.getRank();
1515     if (llvm::isa<IndexType>(op.getType())) {
1516       rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(),
1517                                                           rank);
1518     } else if (llvm::isa<shape::SizeType>(op.getType())) {
1519       rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
1520     } else {
1521       return failure();
1522     }
1523     return success();
1524   }
1525 };
1526 } // namespace
1527 
1528 void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1529                                                 MLIRContext *context) {
1530   patterns.add<RankShapeOfCanonicalizationPattern>(context);
1531 }
1532 
1533 LogicalResult mlir::shape::RankOp::inferReturnTypes(
1534     MLIRContext *context, std::optional<Location> location,
1535     RankOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1536   if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
1537     inferredReturnTypes.assign({SizeType::get(context)});
1538   else
1539     inferredReturnTypes.assign({IndexType::get(context)});
1540   return success();
1541 }
1542 
1543 bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1544   // SizeType is compatible with IndexType.
1545   return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1546 }
1547 
1548 LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(*this); }
1549 
1550 //===----------------------------------------------------------------------===//
1551 // NumElementsOp
1552 //===----------------------------------------------------------------------===//
1553 
1554 OpFoldResult NumElementsOp::fold(FoldAdaptor adaptor) {
1555 
1556   // Fold only when the argument is a constant.
1557   Attribute shape = adaptor.getShape();
1558   if (!shape)
1559     return {};
1560 
1561   APInt product(64, 1);
1562   for (auto value : llvm::cast<DenseIntElementsAttr>(shape))
1563     product *= value;
1564   Builder builder(getContext());
1565   return builder.getIndexAttr(product.getLimitedValue());
1566 }
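
// E.g. the fold above computes the product of all extents (illustrative IR):
//
//   %shape = shape.const_shape [2, 3, 4] : !shape.shape
//   %n = shape.num_elements %shape
//
// folds to a constant 24.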
1567 
1568 LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
1569     MLIRContext *context, std::optional<Location> location,
1570     NumElementsOp::Adaptor adaptor,
1571     SmallVectorImpl<Type> &inferredReturnTypes) {
1572   if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
1573     inferredReturnTypes.assign({SizeType::get(context)});
1574   else
1575     inferredReturnTypes.assign({IndexType::get(context)});
1576   return success();
1577 }
1578 
1579 bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
1580                                                          TypeRange r) {
1581   // SizeType is compatible with IndexType.
1582   return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1583 }
1584 
1585 LogicalResult shape::NumElementsOp::verify() {
1586   return verifySizeOrIndexOp(*this);
1587 }
1588 
1589 //===----------------------------------------------------------------------===//
1590 // MaxOp
1591 //===----------------------------------------------------------------------===//
1592 
1593 OpFoldResult MaxOp::fold(FoldAdaptor adaptor) {
1594   // If operands are equal, just propagate one.
1595   if (getLhs() == getRhs())
1596     return getLhs();
1597   return nullptr;
1598 }
1599 
1600 LogicalResult mlir::shape::MaxOp::inferReturnTypes(
1601     MLIRContext *context, std::optional<Location> location,
1602     MaxOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1603   if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
1604     inferredReturnTypes.assign({adaptor.getLhs().getType()});
1605   else
1606     inferredReturnTypes.assign({SizeType::get(context)});
1607   return success();
1608 }
1609 
1610 bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1611   if (l.size() != 1 || r.size() != 1)
1612     return false;
1613   if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
1614     return true;
1615   if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
1616     return true;
1617   return false;
1618 }
1619 
1620 //===----------------------------------------------------------------------===//
1621 // MinOp
1622 //===----------------------------------------------------------------------===//
1623 
1624 OpFoldResult MinOp::fold(FoldAdaptor adaptor) {
1625   // If operands are equal, just propagate one.
1626   if (getLhs() == getRhs())
1627     return getLhs();
1628   return nullptr;
1629 }
1630 
1631 LogicalResult mlir::shape::MinOp::inferReturnTypes(
1632     MLIRContext *context, std::optional<Location> location,
1633     MinOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1634   if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
1635     inferredReturnTypes.assign({adaptor.getLhs().getType()});
1636   else
1637     inferredReturnTypes.assign({SizeType::get(context)});
1638   return success();
1639 }
1640 
1641 bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1642   if (l.size() != 1 || r.size() != 1)
1643     return false;
1644   if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
1645     return true;
1646   if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
1647     return true;
1648   return false;
1649 }
1650 
1651 //===----------------------------------------------------------------------===//
1652 // MulOp
1653 //===----------------------------------------------------------------------===//
1654 
1655 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1656   auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
1657   if (!lhs)
1658     return nullptr;
1659   auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
1660   if (!rhs)
1661     return nullptr;
1662   APInt folded = lhs.getValue() * rhs.getValue();
1663   Type indexTy = IndexType::get(getContext());
1664   return IntegerAttr::get(indexTy, folded);
1665 }
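
// E.g. two constant sizes fold to their product (illustrative IR):
//
//   %a = shape.const_size 3
//   %b = shape.const_size 5
//   %c = shape.mul %a, %b
//
// folds to a constant 15; the attribute is always index-typed, as above.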
1666 
1667 LogicalResult mlir::shape::MulOp::inferReturnTypes(
1668     MLIRContext *context, std::optional<Location> location,
1669     MulOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1670   if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
1671       llvm::isa<SizeType>(adaptor.getRhs().getType()))
1672     inferredReturnTypes.assign({SizeType::get(context)});
1673   else
1674     inferredReturnTypes.assign({IndexType::get(context)});
1675   return success();
1676 }
1677 
1678 bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1679   // SizeType is compatible with IndexType.
1680   return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1681 }
1682 
1683 LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }
1684 
1685 //===----------------------------------------------------------------------===//
1686 // ShapeOfOp
1687 //===----------------------------------------------------------------------===//
1688 
1689 namespace {
1690 /// Replace shape_of(x) where x has a constant shape with a const_shape op.
1691 struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
1692   using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
1693 
1694   LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1695                                 PatternRewriter &rewriter) const override {
1696     auto type = llvm::dyn_cast<ShapedType>(op.getArg().getType());
1697     if (!type || !type.hasStaticShape())
1698       return failure();
1699     Location loc = op.getLoc();
1700     Value constShape =
1701         rewriter
1702             .create<ConstShapeOp>(loc,
1703                                   rewriter.getIndexTensorAttr(type.getShape()))
1704             .getResult();
1705     if (constShape.getType() != op.getResult().getType())
1706       constShape = rewriter.create<tensor::CastOp>(
1707           loc, op.getResult().getType(), constShape);
1708     rewriter.replaceOp(op, constShape);
1709     return success();
1710   }
1711 };
1712 
1713 // Canonicalize
1714 //
1715 // %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
1716 // %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
1717 //
1718 // to
1719 //
1720 // %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
1721 // %1 = %shape
1722 //
1723 struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
1724   using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
1725 
1726   LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1727                                 PatternRewriter &rewriter) const override {
1728     auto tensorReshapeOp = op.getArg().getDefiningOp<tensor::ReshapeOp>();
1729     if (!tensorReshapeOp)
1730       return rewriter.notifyMatchFailure(op, "producer is not tensor.reshape");
1731     if (!isa<TensorType>(op.getType()))
1732       return rewriter.notifyMatchFailure(op, "result is not a tensor");
1733 
1734     // Operand 'shape' of 'tensor.reshape' may now be used as the result of
1735     // 'shape.shape_of'. While its type is guaranteed to be compatible in well-
1736     // formed IR, it may not be identical (dynamically vs statically shaped),
1737     // in which case it needs to be cast first.
1738     Value shape = tensorReshapeOp.getShape();
1739     if (op.getType() != shape.getType())
1740       shape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(), shape);
1741 
1742     rewriter.replaceOp(op, shape);
1743     return success();
1744   }
1745 };
1746 
1747 // Canonicalize
1748 // ```
1749 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
1750 // %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
1751 // ```
1752 // to
1753 // ```
1754 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
1755 // ```
1756 struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
1757   using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
1758 
1759   LogicalResult matchAndRewrite(tensor::CastOp op,
1760                                 PatternRewriter &rewriter) const override {
1761     auto ty = llvm::dyn_cast<RankedTensorType>(op.getType());
1762     if (!ty || ty.getRank() != 1)
1763       return failure();
1764 
1765     auto shapeOfOp = op.getSource().getDefiningOp<ShapeOfOp>();
1766     if (!shapeOfOp)
1767       return failure();
1768 
1769     // Argument type must be ranked and must not conflict.
1770     auto argTy = llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
1771     if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
1772       return failure();
1773 
1774     rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg());
1775     return success();
1776   }
1777 };
1778 } // namespace
1779 
1780 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1781                                             MLIRContext *context) {
1782   patterns.add<ShapeOfCastExtentTensor, ShapeOfFromReshape,
1783                ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
1784       context);
1785 }
1786 
1787 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
1788     MLIRContext *context, std::optional<Location> location,
1789     ShapeOfOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1790   if (llvm::isa<ValueShapeType>(adaptor.getArg().getType()))
1791     inferredReturnTypes.assign({ShapeType::get(context)});
1792   else {
1793     auto shapedTy = llvm::cast<ShapedType>(adaptor.getArg().getType());
1794     int64_t rank =
1795         shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamic;
1796     Type indexTy = IndexType::get(context);
1797     Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
1798     inferredReturnTypes.assign({extentTensorTy});
1799   }
1800   return success();
1801 }
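
// Inference sketch derived from the code above: a `!shape.value_shape`
// argument infers `!shape.shape`; a ranked operand such as `tensor<2x?xf32>`
// infers `tensor<2xindex>`; an unranked operand infers `tensor<?xindex>`.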
1802 
1803 bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1804   if (l.size() != 1 || r.size() != 1)
1805     return false;
1806   if (l == r)
1807     return true;
1808 
1809   Type lhs = l.front();
1810   Type rhs = r.front();
1811 
1812   if (!llvm::isa<ShapeType, ShapedType>(lhs) ||
1813       !llvm::isa<ShapeType, ShapedType>(rhs))
1814     return false;
1815 
1816   if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
1817     // Shape type is compatible with all other valid return types.
1818     return true;
1819 
1820   if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1821     return true;
1822   return false;
1823 }
1824 
1825 LogicalResult shape::ShapeOfOp::verify() {
1826   return verifyShapeOrExtentTensorOp(*this);
1827 }
1828 
1829 //===----------------------------------------------------------------------===//
1830 // SizeToIndexOp
1831 //===----------------------------------------------------------------------===//
1832 
1833 OpFoldResult SizeToIndexOp::fold(FoldAdaptor adaptor) {
1834   // Constant values of both types, `shape.size` and `index`, are represented as
1835   // `IntegerAttr`s, which makes constant folding simple.
1836   if (Attribute arg = adaptor.getArg())
1837     return arg;
1838   return OpFoldResult();
1839 }
1840 
1841 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1842                                                 MLIRContext *context) {
1843   patterns.add<IndexToSizeToIndexCanonicalization>(context);
1844 }
1845 
1846 bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1847   if (inputs.size() != 1 || outputs.size() != 1)
1848     return false;
1849   return llvm::isa<IndexType, SizeType>(inputs[0]) &&
1850          llvm::isa<IndexType>(outputs[0]);
1851 }
1852 
1853 //===----------------------------------------------------------------------===//
1854 // YieldOp
1855 //===----------------------------------------------------------------------===//
1856 
1857 LogicalResult shape::YieldOp::verify() {
1858   auto *parentOp = (*this)->getParentOp();
1859   auto results = parentOp->getResults();
1860   auto operands = getOperands();
1861 
1862   if (parentOp->getNumResults() != getNumOperands())
1863     return emitOpError() << "number of operands does not match number of "
1864                             "results of its parent";
1865   for (auto e : llvm::zip(results, operands))
1866     if (std::get<0>(e).getType() != std::get<1>(e).getType())
1867       return emitOpError() << "types mismatch between yield op and its parent";
1868 
1869   return success();
1870 }
1871 
1872 //===----------------------------------------------------------------------===//
1873 // SplitAtOp
1874 //===----------------------------------------------------------------------===//
1875 
1876 LogicalResult SplitAtOp::fold(FoldAdaptor adaptor,
1877                               SmallVectorImpl<OpFoldResult> &results) {
1878   if (!adaptor.getOperand() || !adaptor.getIndex())
1879     return failure();
1880   auto shapeVec = llvm::to_vector<6>(
1881       llvm::cast<DenseIntElementsAttr>(adaptor.getOperand()).getValues<int64_t>());
1882   auto shape = llvm::ArrayRef(shapeVec);
1883   auto splitPoint = llvm::cast<IntegerAttr>(adaptor.getIndex()).getInt();
1884   // Verify that the split point is in the correct range.
1885   // TODO: Constant fold to an "error".
1886   int64_t rank = shape.size();
1887   if (-rank > splitPoint || splitPoint > rank)
1888     return failure();
1889   if (splitPoint < 0)
1890     splitPoint += shape.size();
1891   Builder builder(adaptor.getOperand().getContext());
1892   results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
1893   results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
1894   return success();
1895 }
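
// E.g. folding with constant operands (worked example): splitting [4, 5, 6]
// at index -1 first normalizes the index to rank + (-1) = 2 and then yields
// the head [4, 5] and the tail [6] as index tensor attributes.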
1896 
1897 //===----------------------------------------------------------------------===//
1898 // ToExtentTensorOp
1899 //===----------------------------------------------------------------------===//
1900 
1901 OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) {
1902   if (!adaptor.getInput())
1903     return OpFoldResult();
1904   Builder builder(getContext());
1905   auto shape = llvm::to_vector<6>(
1906       llvm::cast<DenseIntElementsAttr>(adaptor.getInput()).getValues<int64_t>());
1907   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
1908                                     builder.getIndexType());
1909   return DenseIntElementsAttr::get(type, shape);
1910 }
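
// E.g. a constant input shape [1, 2] folds to a `DenseIntElementsAttr` of
// type `tensor<2xindex>` holding [1, 2] (illustrative).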
1911 
1912 bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1913   if (inputs.size() != 1 || outputs.size() != 1)
1914     return false;
1915   if (auto inputTensor = llvm::dyn_cast<RankedTensorType>(inputs[0])) {
1916     if (!llvm::isa<IndexType>(inputTensor.getElementType()) ||
1917         inputTensor.getRank() != 1)
1918       return false;
1919   } else if (!llvm::isa<ShapeType>(inputs[0])) {
1920     return false;
1921   }
1922 
1923   TensorType outputTensor = llvm::dyn_cast<TensorType>(outputs[0]);
1924   return outputTensor && llvm::isa<IndexType>(outputTensor.getElementType());
1925 }
1926 
1927 //===----------------------------------------------------------------------===//
1928 // ReduceOp
1929 //===----------------------------------------------------------------------===//
1930 
1931 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
1932                      ValueRange initVals) {
1933   OpBuilder::InsertionGuard g(builder);
1934   result.addOperands(shape);
1935   result.addOperands(initVals);
1936 
1937   Region *bodyRegion = result.addRegion();
1938   Block *bodyBlock = builder.createBlock(
1939       bodyRegion, /*insertPt=*/{}, builder.getIndexType(), result.location);
1940 
1941   Type elementType;
1942   if (auto tensorType = llvm::dyn_cast<TensorType>(shape.getType()))
1943     elementType = tensorType.getElementType();
1944   else
1945     elementType = SizeType::get(builder.getContext());
1946   bodyBlock->addArgument(elementType, shape.getLoc());
1947 
1948   for (Value initVal : initVals) {
1949     bodyBlock->addArgument(initVal.getType(), initVal.getLoc());
1950     result.addTypes(initVal.getType());
1951   }
1952 }
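
// Region shape produced by the builder above for an extent-tensor operand
// (illustrative IR; the body is filled in by the caller):
//
//   %r = shape.reduce(%shape, %init) : tensor<?xindex> -> index {
//     ^bb0(%i : index, %extent : index, %acc : index):
//       %next = arith.muli %acc, %extent : index
//       shape.yield %next : index
//   }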
1953 
1954 LogicalResult ReduceOp::verify() {
1955   // Verify block arg types.
1956   Block &block = getRegion().front();
1957 
1958   // The block takes index, extent, and aggregated values as arguments.
1959   auto blockArgsCount = getInitVals().size() + 2;
1960   if (block.getNumArguments() != blockArgsCount)
1961     return emitOpError() << "ReduceOp body is expected to have "
1962                          << blockArgsCount << " arguments";
1963 
1964   // The first block argument is the index and must always be of type `index`.
1965   if (!llvm::isa<IndexType>(block.getArgument(0).getType()))
1966     return emitOpError(
1967         "argument 0 of ReduceOp body is expected to be of IndexType");
1968 
1969   // The second block argument is the extent and must be of type `size` or
1970   // `index`, depending on whether the reduce operation is applied to a shape or
1971   // to an extent tensor.
1972   Type extentTy = block.getArgument(1).getType();
1973   if (llvm::isa<ShapeType>(getShape().getType())) {
1974     if (!llvm::isa<SizeType>(extentTy))
1975       return emitOpError("argument 1 of ReduceOp body is expected to be of "
1976                          "SizeType if the ReduceOp operates on a ShapeType");
1977   } else {
1978     if (!llvm::isa<IndexType>(extentTy))
1979       return emitOpError(
1980           "argument 1 of ReduceOp body is expected to be of IndexType if the "
1981           "ReduceOp operates on an extent tensor");
1982   }
1983 
1984   for (const auto &type : llvm::enumerate(getInitVals()))
1985     if (block.getArgument(type.index() + 2).getType() != type.value().getType())
1986       return emitOpError() << "type mismatch between argument "
1987                            << type.index() + 2
1988                            << " of ReduceOp body and initial value "
1989                            << type.index();
1990   return success();
1991 }
1992 
1993 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1994   // Parse operands.
1995   SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
1996   Type shapeOrExtentTensorType;
1997   if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
1998                               OpAsmParser::Delimiter::Paren) ||
1999       parser.parseColonType(shapeOrExtentTensorType) ||
2000       parser.parseOptionalArrowTypeList(result.types))
2001     return failure();
2002 
2003   // Resolve operands.
2004   auto initVals = llvm::ArrayRef(operands).drop_front();
2005   if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
2006                             result.operands) ||
2007       parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
2008                              result.operands))
2009     return failure();
2010 
2011   // Parse the body.
2012   Region *body = result.addRegion();
2013   if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
2014     return failure();
2015 
2016   // Parse attributes.
2017   if (parser.parseOptionalAttrDict(result.attributes))
2018     return failure();
2019 
2020   return success();
2021 }
2022 
2023 void ReduceOp::print(OpAsmPrinter &p) {
2024   p << '(' << getShape() << ", " << getInitVals()
2025     << ") : " << getShape().getType();
2026   p.printOptionalArrowTypeList(getResultTypes());
2027   p << ' ';
2028   p.printRegion(getRegion());
2029   p.printOptionalAttrDict((*this)->getAttrs());
2030 }
2031 
2032 #define GET_OP_CLASSES
2033 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
2034 
2035 #define GET_TYPEDEF_CLASSES
2036 #include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
2037