xref: /llvm-project/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (revision 092372da15e5165be14cdbb7cac3cf4976fd82d0)
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Arith/IR/Arith.h"
11 #include "mlir/Dialect/Arith/Utils/Utils.h"
12 #include "mlir/Dialect/Complex/IR/Complex.h"
13 #include "mlir/Dialect/Tensor/IR/Tensor.h"
14 #include "mlir/Dialect/Utils/IndexingUtils.h"
15 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
16 #include "mlir/Dialect/Utils/StaticValueUtils.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinAttributeInterfaces.h"
19 #include "mlir/IR/BuiltinTypeInterfaces.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/IRMapping.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/OpDefinition.h"
24 #include "mlir/IR/TypeUtilities.h"
25 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
26 #include "mlir/Interfaces/LoopLikeInterface.h"
27 #include "mlir/Support/LLVM.h"
28 #include "llvm/ADT/DenseSet.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/SmallBitVector.h"
31 #include "llvm/ADT/StringRef.h"
32 #include "llvm/Support/MathExtras.h"
33 #include <algorithm>
34 #include <optional>
35 
36 using namespace mlir;
37 using namespace mlir::tensor;
38 
39 using llvm::divideCeilSigned;
40 using llvm::divideFloorSigned;
41 using llvm::mod;
42 
43 /// Materialize a single constant operation from a given attribute value with
44 /// the desired resultant type.
45 Operation *TensorDialect::materializeConstant(OpBuilder &builder,
46                                               Attribute value, Type type,
47                                               Location loc) {
48   if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
49     return op;
50   if (complex::ConstantOp::isBuildableWith(value, type))
51     return builder.create<complex::ConstantOp>(loc, type,
52                                                llvm::cast<ArrayAttr>(value));
53   return nullptr;
54 }
55 
56 OpFoldResult tensor::getMixedSize(OpBuilder &builder, Location loc, Value value,
57                                   int64_t dim) {
58   auto tensorType = llvm::cast<RankedTensorType>(value.getType());
60   if (tensorType.isDynamicDim(dim))
61     return builder.createOrFold<tensor::DimOp>(loc, value, dim);
62 
63   return builder.getIndexAttr(tensorType.getDimSize(dim));
64 }
65 
66 SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
67                                                 Location loc, Value value) {
68   auto tensorType = llvm::cast<RankedTensorType>(value.getType());
69   SmallVector<OpFoldResult> result;
70   for (int64_t i = 0; i < tensorType.getRank(); ++i)
71     result.push_back(getMixedSize(builder, loc, value, i));
72   return result;
73 }
74 
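/// Note (illustrative, not tied to a specific op): for a producer that does not
/// implement DestinationStyleOpInterface and returns a tensor<?x8xf32>, this
/// reifies the result shape and materializes a destination such as
/// `%dest = tensor.empty(%d0) : tensor<?x8xf32>`, where %d0 is the reified
/// dynamic extent.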
75 FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
76                                                 OpResult opResult) {
77   auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
78   assert(tensorType && "expected tensor type");
79 
80   // If the op has a destination, it implements DestinationStyleOpInterface and
81   // we can query the destination operand from that interface.
82   auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
83   if (destOp)
84     return destOp.getTiedOpOperand(opResult)->get();
85 
86   // Otherwise, create a new destination tensor with the same shape.
87   OpBuilder::InsertionGuard g(b);
88   b.setInsertionPoint(opResult.getDefiningOp());
89 
90   // Compute sizes.
91   SmallVector<OpFoldResult> mixedSizes;
92   if (!tensorType.hasStaticShape()) {
93     // Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
94     ReifiedRankedShapedTypeDims reifiedShapes;
95     if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
96       return failure();
97     mixedSizes = reifiedShapes[opResult.getResultNumber()];
98   } else {
99     // Static shape: Take static sizes directly.
100     for (int64_t sz : tensorType.getShape())
101       mixedSizes.push_back(b.getIndexAttr(sz));
102   }
103 
104   // Create empty tensor.
105   Value emptyTensor =
106       b.create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());
107   return emptyTensor;
108 }
109 
110 LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
111                                               Operation *op,
112                                               SmallVector<Value> &result) {
113   for (OpResult opResult : op->getResults()) {
114     if (llvm::isa<TensorType>(opResult.getType())) {
115       FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
116       if (failed(destination))
117         return failure();
118       result.push_back(*destination);
119     }
120   }
121   return success();
122 }
123 
124 bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
125   if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
126     if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
127       return rtp1.getShape() == rtp2.getShape() &&
128              rtp1.getElementType() == rtp2.getElementType();
129     return false;
130   }
131   return tp1 == tp2; // default implementation
132 }
133 
134 /// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
135 /// rank-extending tensor.insert_slice op.
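///
/// As a hypothetical illustration: with reducedShape = [4, 8] and
/// mixedSizes = [1, 4, 1, 8], the sizes are matched against the reduced shape
/// from the back; the 8 and the 4 correspond to reduced dims, while the two
/// static unit sizes at positions 0 and 2 have no counterpart, so the returned
/// bit vector has bits 0 and 2 set.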
136 static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
137                                            ArrayRef<OpFoldResult> mixedSizes) {
138   llvm::SmallBitVector droppedDims(mixedSizes.size());
139   int64_t shapePos = reducedShape.size() - 1;
140 
141   for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
142     size_t idx = mixedSizes.size() - size.index() - 1;
143     // Rank-reduced dims must have a static unit dimension.
144     bool isStaticUnitSize =
145         isa<Attribute>(size.value()) &&
146         llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;
147 
148     if (shapePos < 0) {
149       // There are no more dims in the reduced shape. All remaining sizes must
150       // be rank-reduced dims.
151       assert(isStaticUnitSize && "expected unit dim");
152       droppedDims.set(idx);
153       continue;
154     }
155 
156     // Dim is preserved if the size is not a static 1.
157     if (!isStaticUnitSize) {
158       --shapePos;
159       continue;
160     }
161 
162     // Dim is preserved if the reduced shape dim is also 1.
163     if (reducedShape[shapePos] == 1) {
164       --shapePos;
165       continue;
166     }
167 
168     // Otherwise: Dim is dropped.
169     droppedDims.set(idx);
170   }
171 
172   assert(shapePos < 0 && "dimension mismatch");
173   return droppedDims;
174 }
175 
176 /// Given a ranked tensor type and a range of values that defines its dynamic
177 /// dimension sizes, turn all dynamic sizes that have a constant value into
178 /// static dimension sizes.
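///
/// Sketch of the intended effect (hypothetical SSA names): for
/// type = tensor<?x?xf32> and dynamicSizes = [%c5, %n], where %c5 is an
/// `arith.constant 5 : index`, the returned type is tensor<5x?xf32> and
/// foldedDynamicSizes = [%n].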
179 static RankedTensorType
180 foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
181                             SmallVector<Value> &foldedDynamicSizes) {
182   SmallVector<int64_t> staticShape(type.getShape());
183   assert(type.getNumDynamicDims() == dynamicSizes.size() &&
184          "incorrect number of dynamic sizes");
185 
186   // Compute new static and dynamic sizes.
187   unsigned ctr = 0;
188   for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
189     if (type.isDynamicDim(i)) {
190       Value dynamicSize = dynamicSizes[ctr++];
191       std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
192       if (cst.has_value()) {
193         // Dynamic size must be non-negative.
194         if (cst.value() < 0) {
195           foldedDynamicSizes.push_back(dynamicSize);
196           continue;
197         }
198         staticShape[i] = *cst;
199       } else {
200         foldedDynamicSizes.push_back(dynamicSize);
201       }
202     }
203   }
204 
205   return RankedTensorType::get(staticShape, type.getElementType(),
206                                type.getEncoding());
207 }
208 
209 //===----------------------------------------------------------------------===//
210 // BitcastOp
211 //===----------------------------------------------------------------------===//
212 
213 bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
214   if (inputs.size() != 1 || outputs.size() != 1)
215     return false;
216   Type a = inputs.front(), b = outputs.front();
217   auto aT = dyn_cast<TensorType>(a);
218   auto bT = dyn_cast<TensorType>(b);
219   if (!aT || !bT)
220     return false;
221 
222   if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
223     return false;
224 
225   return succeeded(verifyCompatibleShape(aT, bT));
226 }
227 
228 namespace {
229 
230 /// Replaces chains of two tensor.bitcast operations by a single tensor.bitcast
231 /// operation.
232 struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
233   using OpRewritePattern<BitcastOp>::OpRewritePattern;
234 
235   LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
236                                 PatternRewriter &rewriter) const final {
237     auto tensorBitcastOperand =
238         tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
239     if (!tensorBitcastOperand)
240       return failure();
241 
242     auto resultType = cast<TensorType>(tensorBitcast.getType());
243     rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
244                                            tensorBitcastOperand.getOperand());
245     return success();
246   }
247 };
248 
249 } // namespace
250 
251 void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
252                                             MLIRContext *context) {
253   results.add<ChainedTensorBitcast>(context);
254 }
255 
256 //===----------------------------------------------------------------------===//
257 // CastOp
258 //===----------------------------------------------------------------------===//
259 
260 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
261   setNameFn(getResult(), "cast");
262 }
263 
264 /// Returns true if `target` is a ranked tensor type that preserves static
265 /// information available in the `source` ranked tensor type.
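///
/// For example, preservesStaticInformation(tensor<8x?xf32>, tensor<?x?xf32>)
/// is false because the static extent 8 would be lost, whereas
/// preservesStaticInformation(tensor<?x?xf32>, tensor<8x?xf32>) is true.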
266 bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
267   auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
268   auto targetType = llvm::dyn_cast<RankedTensorType>(target);
269 
270   // Requires RankedTensorType.
271   if (!sourceType || !targetType)
272     return false;
273 
274   // Requires same elemental type.
275   if (sourceType.getElementType() != targetType.getElementType())
276     return false;
277 
278   // Requires same rank.
279   if (sourceType.getRank() != targetType.getRank())
280     return false;
281 
282   // Requires same encoding.
283   if (sourceType.getEncoding() != targetType.getEncoding())
284     return false;
285 
286   // Fail if `target` drops a static size of `source` along any dimension.
287   for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
288     if (!ShapedType::isDynamic(std::get<0>(t)) &&
289         ShapedType::isDynamic(std::get<1>(t)))
290       return false;
291   }
292 
293   return true;
294 }
295 
296 /// Determines whether tensor::CastOp casts to a more dynamic version of the
297 /// source tensor. This is useful to fold a tensor.cast into a consuming op and
298 /// implement canonicalization patterns for ops in different dialects that may
299 /// consume the results of tensor.cast operations. Such foldable tensor.cast
300 /// operations are typically introduced when `slice` ops are canonicalized, in
301 /// order to preserve the type compatibility of their uses.
302 ///
303 /// Returns true when all conditions are met:
304 /// 1. source and result are ranked tensors with same element type and rank.
305 /// 2. the source type has more static information than the result type.
306 ///
307 /// Example:
308 /// ```mlir
309 ///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
310 ///   %2 = consumer %1 ... : tensor<?x?xf32> ...
311 /// ```
312 ///
313 /// folds into:
314 ///
315 /// ```mlir
316 ///   %2 = consumer %0 ... : tensor<8x16xf32> ...
317 /// ```
318 bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
319   if (!castOp)
320     return false;
321 
322   // Can fold if the source of cast has at least as much static information as
323   // its results.
324   return preservesStaticInformation(castOp.getType(),
325                                     castOp.getSource().getType());
326 }
327 
328 /// Determines whether the tensor::CastOp casts to a more static version of the
329 /// source tensor. This is useful for folding the cast into a producing op,
330 /// implementing canonicalization patterns rooted at `tensor.cast` with a
331 /// producer from another dialect. Returns true when all conditions are met:
332 /// 1. source and result are ranked tensors with same element type and rank.
333 /// 2. the result type has more static information than the source.
334 ///
335 /// Example:
336 /// ```mlir
337 ///   %1 = producer ... : tensor<?x?xf32>
338 ///   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
339 /// ```
340 ///
341 /// can be canonicalized to :
342 ///
343 /// ```mlir
344 ///   %2 = producer ... : tensor<8x16xf32>
345 /// ```
346 /// Not all ops might be canonicalizable this way, but for those that can be,
347 /// this method provides a check that it is worth doing the canonicalization.
348 bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
349   if (!castOp)
350     return false;
351   return preservesStaticInformation(castOp.getSource().getType(),
352                                     castOp.getType());
353 }
354 
355 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
356 /// that can be folded.
357 LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
358   bool folded = false;
359   for (OpOperand &operand : op->getOpOperands()) {
360     auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
361     if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
362       operand.set(castOp.getOperand());
363       folded = true;
364     }
365   }
366   return success(folded);
367 }
368 
369 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
370   if (inputs.size() != 1 || outputs.size() != 1)
371     return false;
372   Type a = inputs.front(), b = outputs.front();
373   auto aT = llvm::dyn_cast<TensorType>(a);
374   auto bT = llvm::dyn_cast<TensorType>(b);
375   if (!aT || !bT)
376     return false;
377 
378   if (aT.getElementType() != bT.getElementType())
379     return false;
380 
381   return succeeded(verifyCompatibleShape(aT, bT));
382 }
383 
384 /// Compute a TensorType that has the joined shape knowledge of the two
385 /// given TensorTypes. The element types need to match.
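///
/// For example, joining tensor<?x8xf32> with tensor<4x?xf32> yields
/// tensor<4x8xf32>, while joining tensor<4x8xf32> with tensor<5x8xf32> is
/// inconsistent and returns a null type.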
386 static TensorType joinShapes(TensorType one, TensorType two) {
387   assert(one.getElementType() == two.getElementType());
388 
389   if (!one.hasRank())
390     return two;
391   if (!two.hasRank())
392     return one;
393 
394   int64_t rank = one.getRank();
395   if (rank != two.getRank())
396     return {};
397 
398   SmallVector<int64_t, 4> join;
399   join.reserve(rank);
400   for (int64_t i = 0; i < rank; ++i) {
401     if (one.isDynamicDim(i)) {
402       join.push_back(two.getDimSize(i));
403       continue;
404     }
405     if (two.isDynamicDim(i)) {
406       join.push_back(one.getDimSize(i));
407       continue;
408     }
409     if (one.getDimSize(i) != two.getDimSize(i))
410       return {};
411     join.push_back(one.getDimSize(i));
412   }
413   return RankedTensorType::get(join, one.getElementType());
414 }
415 
416 namespace {
417 
418 /// Replaces chains of two tensor.cast operations by a single tensor.cast
419 /// operation if doing so does not remove runtime constraints.
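///
/// As an illustrative (hypothetical) example, the chain
/// ```mlir
///   %1 = tensor.cast %0 : tensor<4x?xf32> to tensor<?x?xf32>
///   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<?x8xf32>
/// ```
/// can be rewritten as a single `tensor.cast %0 : tensor<4x?xf32> to
/// tensor<?x8xf32>`, whereas the chain tensor<?x?xf32> -> tensor<4x?xf32> ->
/// tensor<?x8xf32> cannot drop its intermediate cast without losing the
/// runtime check that the first extent is 4.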
420 struct ChainedTensorCast : public OpRewritePattern<CastOp> {
421   using OpRewritePattern<CastOp>::OpRewritePattern;
422 
423   LogicalResult matchAndRewrite(CastOp tensorCast,
424                                 PatternRewriter &rewriter) const final {
425     auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
426 
427     if (!tensorCastOperand)
428       return failure();
429 
430     auto sourceType =
431         llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
432     auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
433     auto resultType = llvm::cast<TensorType>(tensorCast.getType());
434 
435     // We can remove the intermediate cast if joining all three produces the
436     // same result as just joining the source and result shapes.
437     auto firstJoin =
438         joinShapes(joinShapes(sourceType, intermediateType), resultType);
439 
440     // The join might not exist if the cast sequence would fail at runtime.
441     if (!firstJoin)
442       return failure();
443 
444     // The newJoin always exists if the above join exists; it might just contain
445     // less information. If so, we cannot drop the intermediate cast, as doing
446     // so would remove runtime checks.
447     auto newJoin = joinShapes(sourceType, resultType);
448     if (firstJoin != newJoin)
449       return failure();
450 
451     rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
452                                         tensorCastOperand.getOperand());
453     return success();
454   }
455 };
456 
456 /// Fold tensor.cast into its tensor.extract_slice producer.
458 /// Example:
459 /// ```
460 ///  %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
461 ///    tensor<128x512xf32> to tensor<?x512xf32>
462 ///  %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
463 /// ```
464 /// ->
465 /// ```
466 /// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
467 ///   tensor<128x512xf32> to tensor<16x512xf32>
468 /// ```
469 struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
470   using OpRewritePattern<CastOp>::OpRewritePattern;
471 
472   LogicalResult matchAndRewrite(CastOp tensorCast,
473                                 PatternRewriter &rewriter) const final {
474     auto extractOperand =
475         tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
476 
477     // Cannot fold cast to unranked tensor.
478     auto rankedResultType =
479         llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
480     if (!rankedResultType)
481       return failure();
482 
483     if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
484         rankedResultType.getShape() ==
485             llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
486                 .getShape())
487       return failure();
488 
489     SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
490     auto dimMask = computeRankReductionMask(
491         extractOperand.getStaticSizes(), extractOperand.getType().getShape());
492     size_t dimIndex = 0;
493     for (size_t i = 0, e = sizes.size(); i < e; i++) {
494       if (dimMask && dimMask->count(i))
495         continue;
496       int64_t dim = rankedResultType.getShape()[dimIndex++];
497       if (ShapedType::isDynamic(dim))
498         continue;
499       sizes[i] = rewriter.getIndexAttr(dim);
500     }
501 
502     rewriter.replaceOpWithNewOp<ExtractSliceOp>(
503         tensorCast, rankedResultType, extractOperand.getSource(),
504         extractOperand.getMixedOffsets(), sizes,
505         extractOperand.getMixedStrides());
506     return success();
507   }
508 };
509 
510 } // namespace
511 
512 void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
513                                          MLIRContext *context) {
514   results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
515 }
516 
517 //===----------------------------------------------------------------------===//
518 // ConcatOp
519 //===----------------------------------------------------------------------===//
520 
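/// Infers the result type of concatenating `inputTypes` along `dim`:
/// non-concatenated dimensions take the most static size seen among the
/// inputs, and the concatenated dimension is the sum of the input sizes
/// (dynamic if any of them is dynamic). For example (hypothetical types),
/// concatenating tensor<3x?xf32> and tensor<5x4xf32> along dim 0 infers
/// tensor<8x4xf32>, and concatenating tensor<3x4xf32> and tensor<?x4xf32>
/// along dim 0 infers tensor<?x4xf32>.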
521 RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
522   assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
523   auto tensorTypes =
524       llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
525         return llvm::cast<RankedTensorType>(type);
526       }));
527   int64_t concatRank = tensorTypes[0].getRank();
528 
529   // The concatenation dim must be in the range [0, rank).
530   assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");
531 
532   SmallVector<int64_t> sizes(concatRank);
533   for (int64_t i = 0, e = concatRank; i < e; ++i) {
534     if (i == dim)
535       continue;
536     SaturatedInteger size;
537     for (auto tensorType : tensorTypes)
538       size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
539     sizes[i] = size.asInteger();
540   }
541   auto concatSize = SaturatedInteger::wrap(0);
542   for (auto tensorType : tensorTypes)
543     concatSize =
544         concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
545   sizes[dim] = concatSize.asInteger();
546   return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
547 }
548 
549 void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
550                      ValueRange inputs) {
551   FailureOr<RankedTensorType> resultType =
552       inferResultType(dim, inputs.getTypes());
553   assert(succeeded(resultType) && "failed to infer concatenation result type");
554   build(builder, result, *resultType, dim, inputs);
555 }
556 
557 LogicalResult ConcatOp::verify() {
558   if (getInputs().empty())
559     return emitOpError("requires at least one input");
560 
561   SmallVector<RankedTensorType> inputTypes;
562   for (auto input : getInputs())
563     inputTypes.push_back(cast<RankedTensorType>(input.getType()));
564 
565   RankedTensorType resultType = getResultType();
566   int64_t resultRank = getRank();
567   if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
568         return type.getRank() != resultRank;
569       }))
570     return emitOpError("rank of concatenated inputs must match result rank");
571 
572   Type resultElementType = resultType.getElementType();
573   if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
574         return type.getElementType() != resultElementType;
575       }))
576     return emitOpError("inputs and result element type must match");
577 
578   int64_t dim = getDim();
579   if (dim >= resultRank)
580     return emitOpError("concatenation dim must be less than the tensor rank");
581 
582   SmallVector<int64_t> sizes(resultRank);
583   for (int64_t i = 0, e = resultRank; i < e; ++i) {
584     if (i == dim)
585       continue;
586     SaturatedInteger size;
587     for (auto tensorType : inputTypes) {
588       FailureOr<SaturatedInteger> maybeSize =
589           size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
590       if (failed(maybeSize))
591         return emitOpError("static concatenation size mismatch along ")
592                << "non-concatenated dimension " << i;
593       size = *maybeSize;
594     }
595     sizes[i] = size.asInteger();
596   }
597   auto concatSize = SaturatedInteger::wrap(0);
598   for (auto tensorType : inputTypes)
599     concatSize =
600         concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
601   sizes[dim] = concatSize.asInteger();
602   auto inferredResultType =
603       RankedTensorType::get(sizes, inputTypes[0].getElementType());
604 
605   for (auto [inferredSize, actualSize] :
606        llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
607     bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
608                       ShapedType::isDynamic(actualSize);
609     if (!hasDynamic && inferredSize != actualSize)
610       return emitOpError("result type ")
611              << resultType << " does not match inferred shape "
612              << inferredResultType << " static sizes";
613   }
614 
615   return success();
616 }
617 
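/// Decomposes the concat into a tensor.empty of the concatenated shape
/// followed by one tensor.insert_slice per input. A rough sketch for a static
/// 1-D case (hypothetical IR):
/// ```mlir
///   %r = tensor.concat dim(0) %a, %b : (tensor<2xf32>, tensor<3xf32>) -> tensor<5xf32>
/// ```
/// becomes approximately
/// ```mlir
///   %empty = tensor.empty() : tensor<5xf32>
///   %0 = tensor.insert_slice %a into %empty[0] [2] [1] : tensor<2xf32> into tensor<5xf32>
///   %1 = tensor.insert_slice %b into %0[2] [3] [1] : tensor<3xf32> into tensor<5xf32>
/// ```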
618 FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
619   size_t numInputs = getInputs().size();
620   uint64_t concatDim = getDim();
621 
622   SmallVector<SmallVector<OpFoldResult>> inputShapes;
623   inputShapes.reserve(numInputs);
624   SmallVector<OpFoldResult> concatOffsets;
625   concatOffsets.reserve(numInputs);
626   SmallVector<OpFoldResult> outputShape;
627 
628   AffineExpr addExpr =
629       builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
630   OpFoldResult zero = builder.getIndexAttr(0);
631   Location loc = getLoc();
632   for (auto [index, input] : llvm::enumerate(getInputs())) {
633     SmallVector<OpFoldResult> inputShape =
634         tensor::getMixedSizes(builder, input.getLoc(), input);
635     if (index == 0) {
636       outputShape = inputShape;
637       concatOffsets.push_back(zero);
638     } else {
639       concatOffsets.push_back(outputShape[concatDim]);
640       outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
641           builder, loc, addExpr,
642           {outputShape[concatDim], inputShape[concatDim]});
643     }
644     inputShapes.emplace_back(std::move(inputShape));
645   }
646 
647   Value replacement = builder.create<tensor::EmptyOp>(
648       loc, outputShape, getType().getElementType());
649 
650   int64_t rank = getType().getRank();
651   OpFoldResult one = builder.getIndexAttr(1);
652   SmallVector<OpFoldResult> strides(rank, one);
653   SmallVector<OpFoldResult> offsets(rank, zero);
654   for (auto [index, input] : llvm::enumerate(getInputs())) {
655     offsets[concatDim] = concatOffsets[index];
656     auto insertSlice = builder.create<tensor::InsertSliceOp>(
657         loc, input, replacement, offsets, inputShapes[index], strides);
658     replacement = insertSlice.getResult();
659   }
660   if (replacement.getType() != getType()) {
661     replacement = builder.create<tensor::CastOp>(loc, getType(), replacement);
662   }
663   return SmallVector<Value>{replacement};
664 }
665 
666 LogicalResult
667 ConcatOp::reifyResultShapes(OpBuilder &builder,
668                             ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
669   ValueRange inputs = getInputs();
670   int64_t dim = getDim();
671   RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());
672 
673   Value init = inputs[0];
674   int64_t rank = getType().getRank();
675 
676   reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));
677 
678   // Pre-populate the result sizes with as much static information as possible
679   // from the given result type and the inferred result type; otherwise, fall
680   // back to the dim sizes of the first input.
681   for (int64_t i = 0; i < rank; ++i) {
682     if (i == dim)
683       continue;
684     if (!getType().isDynamicDim(i)) {
685       reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
686     } else if (!inferredResultType.isDynamicDim(i)) {
687       reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
688           builder, getLoc(),
689           builder.getIndexAttr(inferredResultType.getDimSize(i)));
690     } else {
691       reifiedReturnShapes[0][i] =
692           builder.create<tensor::DimOp>(init.getLoc(), init, i).getResult();
693     }
694   }
695 
696   if (getType().isDynamicDim(dim)) {
697     // Take the sum of the input sizes along the concatenated dim.
698     AffineExpr sum = builder.getAffineDimExpr(0);
699     SmallVector<OpFoldResult> sizes = {
700         builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
701     for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
702       sum = sum + builder.getAffineDimExpr(idx + 1);
703       sizes.push_back(
704           builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
705     }
706     reifiedReturnShapes[0][dim] = getValueOrCreateConstantIndexOp(
707         builder, getLoc(),
708         affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes));
709   } else {
710     // If the result shape is static along the concatenated dim, use the static
711     // shape.
712     reifiedReturnShapes[0][dim] =
713         builder.getIndexAttr(getType().getDimSize(dim));
714   }
715   return success();
716 }
717 
718 void ConcatOp::getAsmResultNames(
719     function_ref<void(Value, StringRef)> setNameFn) {
720   setNameFn(getResult(), "concat");
721 }
722 
723 OpFoldResult ConcatOp::fold(FoldAdaptor) {
724   ValueRange inputs = getInputs();
725   if (inputs.size() == 1 && inputs[0].getType() == getResultType())
726     return inputs[0];
727   return {};
728 }
729 
730 namespace {
731 /// Fold a concat op with a single input to a cast.
732 struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
733   using OpRewritePattern<ConcatOp>::OpRewritePattern;
734 
735   LogicalResult matchAndRewrite(ConcatOp concatOp,
736                                 PatternRewriter &rewriter) const override {
737     if (concatOp.getInputs().size() != 1)
738       return failure();
739     rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
740                                         concatOp.getInputs()[0]);
741     return success();
742   }
743 };
744 } // namespace
745 
746 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
747                                            MLIRContext *context) {
748   results.add<SingleInputConcatOp>(context);
749 }
750 
751 //===----------------------------------------------------------------------===//
752 // DimOp
753 //===----------------------------------------------------------------------===//
754 
755 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
756   setNameFn(getResult(), "dim");
757 }
758 
759 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
760                   int64_t index) {
761   auto loc = result.location;
762   Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
763   build(builder, result, source, indexValue);
764 }
765 
766 std::optional<int64_t> DimOp::getConstantIndex() {
767   return getConstantIntValue(getIndex());
768 }
769 
770 Speculation::Speculatability DimOp::getSpeculatability() {
771   auto constantIndex = getConstantIndex();
772   if (!constantIndex)
773     return Speculation::NotSpeculatable;
774 
775   auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
776   if (!rankedSourceType)
777     return Speculation::NotSpeculatable;
778 
779   if (rankedSourceType.getRank() <= constantIndex)
780     return Speculation::NotSpeculatable;
781 
782   return Speculation::Speculatable;
783 }
784 
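// Folding examples (illustrative): `tensor.dim %t, %c1 : tensor<4x8xf32>`
// folds to the index constant 8; a dim of a dynamic extent of a
// tensor.generate folds to the corresponding dynamic-extent operand; and a
// dim of a tensor.cast folds to a dim of the cast's source.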
785 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
786   // All forms of folding require a known index.
787   auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
788   if (!index)
789     return {};
790 
791   // Folding for unranked types (UnrankedTensorType) is not supported.
792   auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
793   if (!tensorType)
794     return {};
795 
796   // Out of bound indices produce undefined behavior but are still valid IR.
797   // Don't choke on them.
798   int64_t indexVal = index.getInt();
799   if (indexVal < 0 || indexVal >= tensorType.getRank())
800     return {};
801 
802   // Fold if the shape extent along the given index is known.
803   if (!tensorType.isDynamicDim(index.getInt())) {
804     Builder builder(getContext());
805     return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
806   }
807 
808   Operation *definingOp = getSource().getDefiningOp();
809 
810   // Fold dim to the operand of tensor.generate.
811   if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
812     auto resultType =
813         llvm::cast<RankedTensorType>(fromElements.getResult().getType());
814     // The case where the type encodes the size of the dimension is handled
815     // above.
816     assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
817 
818     // Find the operand of the fromElements that corresponds to this index.
819     auto dynExtents = fromElements.getDynamicExtents().begin();
820     for (auto dim : resultType.getShape().take_front(index.getInt()))
821       if (ShapedType::isDynamic(dim))
822         dynExtents++;
823 
824     return Value{*dynExtents};
825   }
826 
827   // The size at the given index is now known to be a dynamic size.
828   unsigned unsignedIndex = index.getValue().getZExtValue();
829 
830   if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
831     // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
832     // `resolve-shaped-type-result-dims` pass.
833     if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
834         sliceOp.isDynamicSize(unsignedIndex)) {
835       return {sliceOp.getDynamicSize(unsignedIndex)};
836     }
837   }
838 
839   // dim(cast) -> dim
840   if (succeeded(foldTensorCast(*this)))
841     return getResult();
842 
843   return {};
844 }
845 
846 namespace {
847 /// Fold dim of a cast into the dim of the source of the tensor cast.
848 struct DimOfCastOp : public OpRewritePattern<DimOp> {
849   using OpRewritePattern<DimOp>::OpRewritePattern;
850 
851   LogicalResult matchAndRewrite(DimOp dimOp,
852                                 PatternRewriter &rewriter) const override {
853     auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
854     if (!castOp)
855       return failure();
856     Value newSource = castOp.getOperand();
857     rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
858     return success();
859   }
860 };
861 
862 /// Fold dim of a destination passing style op into the dim of the corresponding
863 /// init.
864 struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
865   using OpRewritePattern<DimOp>::OpRewritePattern;
866 
867   LogicalResult matchAndRewrite(DimOp dimOp,
868                                 PatternRewriter &rewriter) const override {
869     auto source = dimOp.getSource();
870     auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
871     if (!destOp)
872       return failure();
873 
874     auto resultIndex = cast<OpResult>(source).getResultNumber();
875     auto *initOperand = destOp.getDpsInitOperand(resultIndex);
876 
877     rewriter.modifyOpInPlace(
878         dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
879     return success();
880   }
881 };
882 
883 /// Fold dim of a tensor.reshape operation to an extract from the reshape's
884 /// shape operand.
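///
/// For example (hypothetical IR):
/// ```mlir
///   %r = tensor.reshape %t(%shape) : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
///   %d = tensor.dim %r, %c1 : tensor<?x?xf32>
/// ```
/// becomes
/// ```mlir
///   %d = tensor.extract %shape[%c1] : tensor<2xindex>
/// ```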
885 struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
886   using OpRewritePattern<DimOp>::OpRewritePattern;
887 
888   LogicalResult matchAndRewrite(DimOp dim,
889                                 PatternRewriter &rewriter) const override {
890     auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
891 
892     if (!reshape)
893       return failure();
894 
895     // Since tensors are immutable, we don't need to worry about where to place
896     // the extract call.
897     rewriter.setInsertionPointAfter(dim);
898     Location loc = dim.getLoc();
899     Value extract =
900         rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
901     if (extract.getType() != dim.getType())
902       extract =
903           rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract);
904     rewriter.replaceOp(dim, extract);
905     return success();
906   }
907 };
908 } // namespace
909 
910 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
911                                         MLIRContext *context) {
912   results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
913 }
914 
915 //===----------------------------------------------------------------------===//
916 // EmptyOp
917 //===----------------------------------------------------------------------===//
918 
919 void EmptyOp::build(OpBuilder &builder, OperationState &result,
920                     ArrayRef<int64_t> staticShape, Type elementType,
921                     Attribute encoding) {
922   assert(all_of(staticShape,
923                 [](int64_t sz) { return !ShapedType::isDynamic(sz); }) &&
924          "expected only static sizes");
925   build(builder, result, staticShape, elementType, ValueRange{}, encoding);
926 }
927 
928 void EmptyOp::build(OpBuilder &builder, OperationState &result,
929                     ArrayRef<int64_t> staticShape, Type elementType,
930                     ValueRange dynamicSizes, Attribute encoding) {
931   auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
932   build(builder, result, tensorType, dynamicSizes);
933 }
934 
935 void EmptyOp::build(OpBuilder &builder, OperationState &result,
936                     ArrayRef<OpFoldResult> sizes, Type elementType,
937                     Attribute encoding) {
938   SmallVector<int64_t> staticShape;
939   SmallVector<Value> dynamicSizes;
940   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
941   build(builder, result, staticShape, elementType, dynamicSizes, encoding);
942 }
943 
944 LogicalResult EmptyOp::verify() {
945   if (getType().getNumDynamicDims() != getDynamicSizes().size())
946     return emitOpError("incorrect number of dynamic sizes, has ")
947            << getDynamicSizes().size() << ", expected "
948            << getType().getNumDynamicDims();
949   return success();
950 }
951 
952 LogicalResult
953 EmptyOp::reifyResultShapes(OpBuilder &builder,
954                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
955   reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
956   unsigned ctr = 0;
957   for (int64_t i = 0; i < getType().getRank(); ++i) {
958     if (getType().isDynamicDim(i)) {
959       reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
960     } else {
961       reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
962     }
963   }
964   return success();
965 }
966 
967 Value EmptyOp::getDynamicSize(unsigned idx) {
968   assert(getType().isDynamicDim(idx) && "expected dynamic dim");
969   unsigned ctr = 0;
970   for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
971     if (getType().isDynamicDim(i))
972       ++ctr;
973   return getDynamicSizes()[ctr];
974 }
975 
976 SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
977   SmallVector<OpFoldResult> result;
978   unsigned ctr = 0;
979   OpBuilder b(getContext());
980   for (int64_t i = 0; i < getType().getRank(); ++i) {
981     if (getType().isDynamicDim(i)) {
982       result.push_back(getDynamicSizes()[ctr++]);
983     } else {
984       result.push_back(b.getIndexAttr(getType().getShape()[i]));
985     }
986   }
987   return result;
988 }
989 
990 namespace {
991 /// Change the type of the result of a `tensor.empty` by making the result
992 /// type statically sized along dimensions that the original operation defined
993 /// as dynamic but whose size was produced by a `constant` op. For
994 /// example:
995 ///
996 ///  %c5 = arith.constant 5: index
997 ///  %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32>
998 ///
999 ///  to
1000 ///
1001 ///  %0 = tensor.empty(%arg0) : tensor<?x5xf32>
1002 struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
1003   using OpRewritePattern<EmptyOp>::OpRewritePattern;
1004 
1005   LogicalResult matchAndRewrite(EmptyOp op,
1006                                 PatternRewriter &rewriter) const override {
1007     SmallVector<Value> foldedDynamicSizes;
1008     RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1009         op.getType(), op.getDynamicSizes(), foldedDynamicSizes);
1010 
1011     // Stop here if no dynamic size was promoted to static.
1012     if (foldedTensorType == op.getType())
1013       return failure();
1014 
1015     auto newOp = rewriter.create<EmptyOp>(op.getLoc(), foldedTensorType,
1016                                           foldedDynamicSizes);
1017     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
1018     return success();
1019   }
1020 };
1021 
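/// Fold `tensor.dim` of a `tensor.empty` to the corresponding dynamic size
/// operand. For example (hypothetical IR), given
/// `%0 = tensor.empty(%sz) : tensor<?x4xf32>`, the op
/// `tensor.dim %0, %c0 : tensor<?x4xf32>` folds to `%sz`.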
1022 struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
1023   using OpRewritePattern<DimOp>::OpRewritePattern;
1024 
1025   LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1026                                 PatternRewriter &rewriter) const override {
1027     std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
1028     auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
1029     if (!emptyTensorOp || !maybeConstantIndex)
1030       return failure();
1031     auto emptyTensorType = emptyTensorOp.getType();
1032     if (*maybeConstantIndex < 0 ||
1033         *maybeConstantIndex >= emptyTensorType.getRank() ||
1034         !emptyTensorType.isDynamicDim(*maybeConstantIndex))
1035       return failure();
1036     rewriter.replaceOp(dimOp,
1037                        emptyTensorOp.getDynamicSize(*maybeConstantIndex));
1038     return success();
1039   }
1040 };
1041 
1042 /// Canonicalize
1043 ///
1044 /// ```mlir
1045 ///   %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
1046 ///   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
1047 /// ```
1048 ///
1049 /// into
1050 ///
1051 /// ```mlir
1052 ///   %0 = tensor.empty(%d1) : tensor<4x?xf32>
1053 /// ```
1054 ///
1055 /// This assumes the input program is correct in terms of its shape. So it is
1056 /// safe to assume that `%d0` is in fact 4.
1057 struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
1058   using OpRewritePattern<CastOp>::OpRewritePattern;
1059 
1060   LogicalResult matchAndRewrite(CastOp castOp,
1061                                 PatternRewriter &rewriter) const override {
1062     if (!canFoldIntoProducerOp(castOp))
1063       return failure();
1064     auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
1065     if (!producer)
1066       return failure();
1067 
1068     auto resultType =
1069         llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
1070     ArrayRef<int64_t> resultShape = resultType.getShape();
1071     SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
1072     SmallVector<OpFoldResult> newMixedSizes;
1073     newMixedSizes.reserve(currMixedSizes.size());
1074     assert(resultShape.size() == currMixedSizes.size() &&
1075            "mismatch in result shape and sizes of empty op");
1076     for (auto it : llvm::zip(resultShape, currMixedSizes)) {
1077       int64_t newDim = std::get<0>(it);
1078       OpFoldResult currDim = std::get<1>(it);
1079       // Case 1: The empty tensor dim is static. Check that the tensor cast
1080       // result dim matches.
1081       if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
1082         if (ShapedType::isDynamic(newDim) ||
1083             newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
1084           // Something is off, the cast result shape cannot be more dynamic
1085           // than the empty tensor result shape (enforced by
1086           // `canFoldIntoProducer`). Abort for now.
1087           return rewriter.notifyMatchFailure(
1088               producer, "mismatch in static value of shape of empty tensor "
1089                         "result and cast result");
1090         }
1091         newMixedSizes.push_back(attr);
1092         continue;
1093       }
1094 
1095       // Case 2 : The tensor cast shape is static, but empty tensor result
1096       // shape is dynamic.
1097       if (!ShapedType::isDynamic(newDim)) {
1098         newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
1099         continue;
1100       }
1101 
1102       // Case 3 : The tensor cast shape is dynamic and empty tensor result
1103       // shape is dynamic. Use the dynamic value from the empty tensor op.
1104       newMixedSizes.push_back(currDim);
1105     }
1106 
1107     // TODO: Do not drop tensor encoding.
1108     rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
1109                                          resultType.getElementType());
1110     return success();
1111   }
1112 };
1113 
1114 } // namespace
1115 
1116 void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1117                                           MLIRContext *context) {
1118   results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
1119               ReplaceEmptyTensorStaticShapeDims>(context);
1120 }
1121 
1122 /// Folds a tensor operation that would only reshape a splat constant: returns
1123 /// a new splat constant of the result shape, with which the op's result can be
1124 /// replaced. When an optional cst attribute is passed, the source is reshaped
1125 /// only if its splat value matches that attribute.
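///
/// For example (hypothetical attributes), a splat `dense<1.0> : tensor<4xf32>`
/// reshaped to a `tensor<2x2xf32>` result yields `dense<1.0> : tensor<2x2xf32>`.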
1126 static OpFoldResult
1127 reshapeConstantSource(DenseElementsAttr source, TensorType result,
1128                       std::optional<Attribute> cst = std::nullopt) {
1129   if (source && source.isSplat() && result.hasStaticShape() &&
1130       (!cst.has_value() || source.getSplatValue<Attribute>() == cst.value()))
1131     return source.resizeSplat(result);
1132 
1133   return {};
1134 }
1135 
1136 //===----------------------------------------------------------------------===//
1137 // ExtractOp
1138 //===----------------------------------------------------------------------===//
1139 
1140 namespace {
1141 
1142 /// Canonicalizes the pattern of the form
1143 ///
1144 /// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
1145 /// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
1146 ///
1147 /// to
1148 ///
1149 /// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
1150 struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
1151   using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1152 
1153   LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1154                                 PatternRewriter &rewriter) const final {
1155     auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
1156     if (!tensorCast)
1157       return failure();
1158     if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
1159       return failure();
1160     rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1161         extract, tensorCast.getSource(), extract.getIndices());
1162     return success();
1163   }
1164 };
1165 
1166 } // namespace
1167 
1168 void ExtractOp::getAsmResultNames(
1169     function_ref<void(Value, StringRef)> setNameFn) {
1170   setNameFn(getResult(), "extracted");
1171 }
1172 
1173 LogicalResult ExtractOp::verify() {
1174   // Verify the # indices match if we have a ranked type.
1175   auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
1176   if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
1177     return emitOpError("incorrect number of indices for extract_element");
1178   return success();
1179 }
1180 
1181 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
1182   if (Attribute tensor = adaptor.getTensor()) {
1183     // If this is a splat elements attribute, simply return the value.
1184     // All of the elements of a splat attribute are the same.
1185     if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
1186       return splatTensor.getSplatValue<Attribute>();
1187 
1188     // If this is a dense resource elements attribute, return.
1189     if (isa<DenseResourceElementsAttr>(tensor))
1190       return {};
1191   }
1192 
1193   // Collect the constant indices into the tensor.
1194   SmallVector<uint64_t, 8> indices;
1195   for (Attribute indice : adaptor.getIndices()) {
1196     if (!indice || !llvm::isa<IntegerAttr>(indice))
1197       return {};
1198     indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
1199   }
1200 
1201   // Fold extract(from_elements(...)).
1202   if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
1203     auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
1204     auto rank = tensorType.getRank();
1205     assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
1206            "rank mismatch");
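    // Linearize the indices in row-major order to index into the flat list of
    // elements of the from_elements op. For example (hypothetical values), for
    // a tensor<2x3xi32> and indices [1, 2], flatIndex = 1 * 3 + 2 = 5.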
1207     int flatIndex = 0;
1208     int stride = 1;
1209     for (int i = rank - 1; i >= 0; --i) {
1210       flatIndex += indices[i] * stride;
1211       stride *= tensorType.getDimSize(i);
1212     }
1213     // Prevent out of bounds accesses. This can happen in invalid code that
1214     // will never execute.
1215     if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
1216         flatIndex < 0)
1217       return {};
1218     return fromElementsOp.getElements()[flatIndex];
1219   }
1220 
1221   // If this is an elements attribute, query the value at the given indices.
1222   if (Attribute tensor = adaptor.getTensor()) {
1223     auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
1224     if (elementsAttr && elementsAttr.isValidIndex(indices))
1225       return elementsAttr.getValues<Attribute>()[indices];
1226   }
1227 
1228   return {};
1229 }
1230 
1231 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
1232                                             MLIRContext *context) {
1233   results.add<ExtractFromTensorCast>(context);
1234 }
1235 
1236 //===----------------------------------------------------------------------===//
1237 // FromElementsOp
1238 //===----------------------------------------------------------------------===//
1239 
1240 void FromElementsOp::getAsmResultNames(
1241     function_ref<void(Value, StringRef)> setNameFn) {
1242   setNameFn(getResult(), "from_elements");
1243 }
1244 
1245 void FromElementsOp::build(OpBuilder &builder, OperationState &result,
1246                            ValueRange elements) {
1247   assert(!elements.empty() && "expected at least one element");
1248   Type resultType = RankedTensorType::get(
1249       {static_cast<int64_t>(elements.size())}, elements.front().getType());
1250   build(builder, result, resultType, elements);
1251 }
1252 
1253 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
1254   if (!llvm::is_contained(adaptor.getElements(), nullptr))
1255     return DenseElementsAttr::get(getType(), adaptor.getElements());
1256   return {};
1257 }
1258 
1259 namespace {
1260 
1261 // Pushes the index_casts that occur before extractions to after the extract.
1262 // This minimizes type conversion in some cases and enables the extract
1263 // canonicalizer. This changes:
1264 //
1265 // %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
1266 // %extract = tensor.extract %cast[%index] : tensor<1xindex>
1267 //
1268 // to the following:
1269 //
1270 // %extract = tensor.extract %tensor[%index] : tensor<1xi32>
1271 // %cast = arith.index_cast %extract : i32 to index
1272 //
1273 // which existing canonicalizations can often reduce further, e.g. to just the
1274 // extracted element.
1275 //
1276 // Consider expanding this to a template to handle all tensor cast operations.
1277 struct ExtractElementFromIndexCast
1278     : public OpRewritePattern<tensor::ExtractOp> {
1279   using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1280 
1281   LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1282                                 PatternRewriter &rewriter) const final {
1283     Location loc = extract.getLoc();
1284     auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
1285     if (!indexCast)
1286       return failure();
1287 
1288     Type elementTy = getElementTypeOrSelf(indexCast.getIn());
1289 
1290     auto newExtract = rewriter.create<tensor::ExtractOp>(
1291         loc, elementTy, indexCast.getIn(), extract.getIndices());
1292 
1293     rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
1294                                                     newExtract);
1295 
1296     return success();
1297   }
1298 };
1299 
1300 } // namespace
1301 
1302 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
1303                                                  MLIRContext *context) {
1304   results.add<ExtractElementFromIndexCast>(context);
1305 }
1306 
1307 //===----------------------------------------------------------------------===//
1308 // GatherOp
1309 //===----------------------------------------------------------------------===//
1310 
1311 void GatherOp::getAsmResultNames(
1312     function_ref<void(Value, StringRef)> setNameFn) {
1313   setNameFn(getResult(), "gather");
1314 }
1315 
1316 /// Return the inferred result type for a gatherOp where:
1317 ///   - sourceType is the type of the source tensor gathered from
1318 ///   - indicesType is the type of the indices used to gather
1319 ///   - gatherDims are the dims along which the gather occurs.
1320 /// Return a full-rank or rank-reduced variant of the type depending on
1321 /// the value of rankReduced.
1322 ///
1323 /// The leading dimensions of the index tensor give the result tensor its
1324 /// leading dimensions.
1325 /// The trailing dimensions of the result tensor are obtained from the source
1326 /// tensor by setting the dimensions specified in gather_dims to `1` (if
1327 /// rankReduced is false), or skipping them (otherwise).
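///
/// For example (hypothetical types), with source tensor<4x5x6xf32>, indices
/// tensor<2x3x1xindex>, and gather_dims = [1], the inferred type is
/// tensor<2x3x4x1x6xf32>, or tensor<2x3x4x6xf32> in the rank-reduced form.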
1328 RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
1329                                            RankedTensorType indicesType,
1330                                            ArrayRef<int64_t> gatherDims,
1331                                            bool rankReduced) {
1332   SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
1333   resultShape.reserve(resultShape.size() + sourceType.getRank());
1334   for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
1335     if (std::binary_search(gatherDims.begin(), gatherDims.end(), idx)) {
1336       if (!rankReduced)
1337         resultShape.push_back(1);
1338       continue;
1339     }
1340     resultShape.push_back(sourceType.getDimSize(idx));
1341   }
1342   return RankedTensorType::Builder(sourceType).setShape(resultShape);
1343 }
1344 
1345 static LogicalResult
1346 verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims,
1347                           ArrayRef<int64_t> indices, int64_t rank,
1348                           StringRef gatherOrScatter, StringRef sourceOrDest) {
1349   if (dims.empty())
1350     return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
1351 
1352   int64_t numGatherDims = dims.size();
1353   if (numGatherDims > rank)
1354     return op->emitOpError(gatherOrScatter)
1355            << "_dims overflow " << sourceOrDest << " rank";
1356   if (indices.empty() || indices.back() != numGatherDims)
1357     return op->emitOpError(gatherOrScatter)
1358            << "_dims length must match the size of last dimension of indices";
1359   for (int64_t val : dims) {
1360     if (val < 0)
1361       return op->emitOpError(gatherOrScatter)
1362              << "_dims value must be non-negative";
1363     if (val >= rank)
1364       return op->emitOpError(gatherOrScatter)
1365              << "_dims value must be smaller than " << sourceOrDest << " rank";
1366   }
1367   for (int64_t i = 1; i < numGatherDims; ++i) {
1368     if (dims[i - 1] >= dims[i])
1369       return op->emitOpError(gatherOrScatter)
1370              << "_dims values must be strictly increasing";
1371   }
1372   return success();
1373 }
1374 
1375 LogicalResult GatherOp::verify() {
1376   int64_t sourceRank = getSourceType().getRank();
1377   ArrayRef<int64_t> gatherDims = getGatherDims();
1378   if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
1379                                        getIndicesType().getShape(), sourceRank,
1380                                        "gather", "source")))
1381     return failure();
1382 
1383   RankedTensorType expectedResultType = GatherOp::inferResultType(
1384       getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
1385   RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
1386       getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
1387   if (getResultType() != expectedResultType &&
1388       getResultType() != expectedRankReducedResultType) {
1389     return emitOpError("result type mismatch: expected ")
1392            << expectedResultType << " or its rank-reduced variant "
1393            << expectedRankReducedResultType << " (got: " << getResultType()
1394            << ")";
1395   }
1396 
1397   return success();
1398 }
1399 
1400 OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
1401   if (OpFoldResult reshapedSource = reshapeConstantSource(
1402           llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1403           getResult().getType()))
1404     return reshapedSource;
1405   return {};
1406 }
1407 
1408 //===----------------------------------------------------------------------===//
1409 // InsertOp
1410 //===----------------------------------------------------------------------===//
1411 
1412 void InsertOp::getAsmResultNames(
1413     function_ref<void(Value, StringRef)> setNameFn) {
1414   setNameFn(getResult(), "inserted");
1415 }
1416 
1417 LogicalResult InsertOp::verify() {
1418   // Verify that the number of indices matches the rank of the destination.
1419   auto destType = llvm::cast<RankedTensorType>(getDest().getType());
1420   if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
1421     return emitOpError("incorrect number of indices");
1422   return success();
1423 }
1424 
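/// Fold an insert of a value into a splat constant of that same value back to
/// the destination; e.g. (illustrative) inserting an i32 zero into a
/// dense<0 : i32> splat tensor is a no-op and folds to the splat itself.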
1425 OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
1426   Attribute scalar = adaptor.getScalar();
1427   Attribute dest = adaptor.getDest();
1428   if (scalar && dest)
1429     if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
1430       if (scalar == splatDest.getSplatValue<Attribute>())
1431         return dest;
1432   return {};
1433 }
1434 
1435 //===----------------------------------------------------------------------===//
1436 // GenerateOp
1437 //===----------------------------------------------------------------------===//
1438 
1439 void GenerateOp::getAsmResultNames(
1440     function_ref<void(Value, StringRef)> setNameFn) {
1441   setNameFn(getResult(), "generated");
1442 }
1443 
1444 LogicalResult GenerateOp::reifyResultShapes(
1445     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1446   reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1447   int idx = 0;
1448   for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1449     if (getType().isDynamicDim(dim)) {
1450       reifiedReturnShapes[0][dim] = getOperand(idx++);
1451     } else {
1452       reifiedReturnShapes[0][dim] =
1453           builder.getIndexAttr(getType().getDimSize(dim));
1454     }
1455   }
1456   return success();
1457 }
1458 
1459 LogicalResult GenerateOp::verify() {
1460   // Ensure that the tensor type has as many dynamic dimensions as are
1461   // specified by the operands.
1462   RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
1463   if (getNumOperands() != resultType.getNumDynamicDims())
1464     return emitError("must have as many index operands as dynamic extents "
1465                      "in the result type");
1466   return success();
1467 }
1468 
1469 LogicalResult GenerateOp::verifyRegions() {
1470   RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
1471   // Ensure that region arguments span the index space.
1472   if (!llvm::all_of(getBody().getArgumentTypes(),
1473                     [](Type ty) { return ty.isIndex(); }))
1474     return emitError("all body arguments must be index");
1475   if (getBody().getNumArguments() != resultTy.getRank())
1476     return emitError("must have one body argument per input dimension");
1477 
1478   // Ensure that the region yields an element of the right type.
1479   auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
1480 
1481   if (yieldOp.getValue().getType() != resultTy.getElementType())
1482     return emitOpError(
1483         "body must be terminated with a `yield` operation of the tensor "
1484         "element type");
1485 
1486   return success();
1487 }
1488 
1489 void GenerateOp::build(
1490     OpBuilder &b, OperationState &result, Type resultTy,
1491     ValueRange dynamicExtents,
1492     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
1493   build(b, result, resultTy, dynamicExtents);
1494 
1495   // Build and populate body.
1496   OpBuilder::InsertionGuard guard(b);
1497   Region *bodyRegion = result.regions.front().get();
1498   auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
1499   SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
1500   SmallVector<Location, 2> argumentLocs(rank, result.location);
1501   Block *bodyBlock =
1502       b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
1503   bodyBuilder(b, result.location, bodyBlock->getArguments());
1504 }
1505 
1506 namespace {
1507 
1508 /// Canonicalizes tensor.generate operations with constant extent operands
1509 /// into an equivalent operation whose result type carries those extents as
1510 /// static dimensions. A tensor.cast is inserted so that the surrounding IR
1511 /// remains well-typed.
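///
/// A sketch of the rewrite on illustrative IR (assuming %c8 is an
/// arith.constant 8 : index):
///
///   %0 = tensor.generate %c8 {
///     ...
///   } : tensor<?xf32>
///
/// becomes
///
///   %0 = tensor.generate {
///     ...
///   } : tensor<8xf32>
///   %1 = tensor.cast %0 : tensor<8xf32> to tensor<?xf32>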
1512 struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
1513   using OpRewritePattern<GenerateOp>::OpRewritePattern;
1514 
1515   LogicalResult matchAndRewrite(GenerateOp generateOp,
1516                                 PatternRewriter &rewriter) const final {
1517     SmallVector<Value> foldedDynamicSizes;
1518     RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1519         generateOp.getType(), generateOp.getDynamicExtents(),
1520         foldedDynamicSizes);
1521 
1522     // Stop here if no dynamic size was promoted to static.
1523     if (foldedTensorType == generateOp.getType())
1524       return failure();
1525 
1526     auto loc = generateOp.getLoc();
1527     auto newOp =
1528         rewriter.create<GenerateOp>(loc, foldedTensorType, foldedDynamicSizes);
1529     rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
1530                                 newOp.getBody().begin());
1531     rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
1532                                                 generateOp.getType(), newOp);
1533     return success();
1534   }
1535 };
1536 
1537 /// Canonicalizes the pattern of the form
1538 ///
1539 /// %tensor = tensor.generate %x {
1540 ///   ^bb0(%arg0: index):
1541 ///   <computation>
1542 ///   tensor.yield %1 : index
1543 /// } : tensor<?xindex>
1544 /// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
1545 ///
1546 /// to just <computation> with %arg0 replaced by %c0. We only do this if the
1547 /// tensor.generate operation has no side-effects.
1548 struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
1549   using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1550 
1551   LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1552                                 PatternRewriter &rewriter) const final {
1553     auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
1554     if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
1555       return failure();
1556 
1557     IRMapping mapping;
1558     Block *body = &tensorFromElements.getBody().front();
1559     mapping.map(body->getArguments(), extract.getIndices());
1560     for (auto &op : body->without_terminator())
1561       rewriter.clone(op, mapping);
1562 
1563     auto yield = cast<YieldOp>(body->getTerminator());
1564 
1565     rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
1566     return success();
1567   }
1568 };
1569 
1570 } // namespace
1571 
1572 void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
1573                                              MLIRContext *context) {
1574   // TODO: Move extract pattern to tensor::ExtractOp.
1575   results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
1576 }
1577 
1578 //===----------------------------------------------------------------------===//
1579 // RankOp
1580 //===----------------------------------------------------------------------===//
1581 
1582 void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1583   setNameFn(getResult(), "rank");
1584 }
1585 
1586 OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1587   // Constant fold rank when the rank of the operand is known.
1588   auto type = getOperand().getType();
1589   auto shapedType = llvm::dyn_cast<ShapedType>(type);
1590   if (shapedType && shapedType.hasRank())
1591     return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1592   return IntegerAttr();
1593 }
1594 
1595 //===----------------------------------------------------------------------===//
1596 // ReshapeOp
1597 //===----------------------------------------------------------------------===//
1598 
1599 void ReshapeOp::getAsmResultNames(
1600     function_ref<void(Value, StringRef)> setNameFn) {
1601   setNameFn(getResult(), "reshape");
1602 }
1603 
1604 static int64_t getNumElements(ShapedType type) {
1605   int64_t numElements = 1;
1606   for (auto dim : type.getShape())
1607     numElements *= dim;
1608   return numElements;
1609 }
1610 
1611 LogicalResult ReshapeOp::verify() {
1612   TensorType operandType = llvm::cast<TensorType>(getSource().getType());
1613   TensorType resultType = llvm::cast<TensorType>(getResult().getType());
1614 
1615   if (operandType.getElementType() != resultType.getElementType())
1616     return emitOpError("element types of source and destination tensor "
1617                        "types should be the same");
1618 
1619   int64_t shapeSize =
1620       llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0);
1621   auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
1622   auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);
1623 
1624   if (resultRankedType) {
1625     if (operandRankedType && resultRankedType.hasStaticShape() &&
1626         operandRankedType.hasStaticShape()) {
1627       if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
1628         return emitOpError("source and destination tensor should have the "
1629                            "same number of elements");
1630     }
1631     if (ShapedType::isDynamic(shapeSize))
1632       return emitOpError("cannot use shape operand with dynamic length to "
1633                          "reshape to statically-ranked tensor type");
1634     if (shapeSize != resultRankedType.getRank())
1635       return emitOpError(
1636           "length of shape operand differs from the result's tensor rank");
1637   }
1638   return success();
1639 }
1640 
1641 OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1642   if (OpFoldResult reshapedSource = reshapeConstantSource(
1643           llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1644           getResult().getType()))
1645     return reshapedSource;
1646 
1647   // If the producer of the 'source' operand is another 'tensor.reshape' op,
1648   // reshape the producer's input directly instead. This may render the
1649   // producer dead.
1650   if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
1651     getSourceMutable().assign(reshapeOpProducer.getSource());
1652     return getResult();
1653   }
1654 
1655   auto source = getSource();
1656   auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
1657   auto resultTy = dyn_cast<RankedTensorType>(getType());
1658   if (!sourceTy || !resultTy || sourceTy != resultTy)
1659     return {};
1660 
1661   // If the source and result are both 1D tensors and have the same type, the
1662   // reshape has no effect, even if the tensor is dynamically shaped.
1663   if (sourceTy.getRank() == 1)
1664     return source;
1665 
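  // The reshape is also a no-op when the shape operand is assembled, via
  // tensor.from_elements, from the source's own dimensions. Illustrative
  // example (names are hypothetical):
  //   %d0 = tensor.dim %src, %c0 : tensor<?x4xf32>
  //   %shape = tensor.from_elements %d0, %c4 : tensor<2xindex>
  //   %r = tensor.reshape %src(%shape)
  //       : (tensor<?x4xf32>, tensor<2xindex>) -> tensor<?x4xf32>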
1666   if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
1667     auto elements = fromElements.getElements();
1668     bool dynamicNoop =
1669         sourceTy.getRank() == static_cast<int64_t>(elements.size());
1670     for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
1671       auto element = elements[id];
1672 
1673       if (auto cst = getConstantIntValue(element)) {
1674         dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
1675         continue;
1676       }
1677 
1678       if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
1679         dynamicNoop &= dimOp.getSource() == source;
1680 
1681         APSInt dim;
1683         dynamicNoop &=
1684             cst.has_value() && cst.value() == static_cast<int64_t>(id);
1685         continue;
1686       }
1687 
1688       dynamicNoop = false;
1689       break;
1690     }
1691 
1692     if (dynamicNoop)
1693       return source;
1694   }
1695 
1696   return {};
1697 }
1698 
1699 //===----------------------------------------------------------------------===//
1700 // Reassociative reshape ops
1701 //===----------------------------------------------------------------------===//
1702 
1703 void CollapseShapeOp::getAsmResultNames(
1704     function_ref<void(Value, StringRef)> setNameFn) {
1705   setNameFn(getResult(), "collapsed");
1706 }
1707 
1708 void ExpandShapeOp::getAsmResultNames(
1709     function_ref<void(Value, StringRef)> setNameFn) {
1710   setNameFn(getResult(), "expanded");
1711 }
1712 
1713 int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
1714   assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
1715          "invalid resultDim");
1716   for (const auto &it : llvm::enumerate(getReassociationIndices()))
1717     if (llvm::is_contained(it.value(), resultDim))
1718       return it.index();
1719   llvm_unreachable("could not find reassociation group");
1720 }
1721 
1722 FailureOr<SmallVector<OpFoldResult>>
1723 ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
1724                                 RankedTensorType expandedType,
1725                                 ArrayRef<ReassociationIndices> reassociation,
1726                                 ArrayRef<OpFoldResult> inputShape) {
1727   std::optional<SmallVector<OpFoldResult>> outputShape =
1728       inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
1729                                   inputShape);
1730   if (!outputShape)
1731     return failure();
1732   return *outputShape;
1733 }
1734 
1735 SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
1736   return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext());
1737 }
1738 
1739 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1740                           Type resultType, Value src,
1741                           ArrayRef<ReassociationIndices> reassociation,
1742                           ArrayRef<OpFoldResult> outputShape) {
1743   auto [staticOutputShape, dynamicOutputShape] =
1744       decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
1745   build(builder, result, cast<RankedTensorType>(resultType), src,
1746         getReassociationIndicesAttribute(builder, reassociation),
1747         dynamicOutputShape, staticOutputShape);
1748 }
1749 
1750 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1751                           Type resultType, Value src,
1752                           ArrayRef<ReassociationIndices> reassociation) {
1753   SmallVector<OpFoldResult> inputShape =
1754       getMixedSizes(builder, result.location, src);
1755   auto tensorResultTy = cast<RankedTensorType>(resultType);
1756   FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
1757       builder, result.location, tensorResultTy, reassociation, inputShape);
1758   SmallVector<OpFoldResult> outputShapeOrEmpty;
1759   if (succeeded(outputShape)) {
1760     outputShapeOrEmpty = *outputShape;
1761   }
1762   build(builder, result, tensorResultTy, src, reassociation,
1763         outputShapeOrEmpty);
1764 }
1765 
1766 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1767   return getSymbolLessAffineMaps(getReassociationExprs());
1768 }
1769 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1770   return convertReassociationIndicesToExprs(getContext(),
1771                                             getReassociationIndices());
1772 }
1773 
1774 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1775   return getSymbolLessAffineMaps(getReassociationExprs());
1776 }
1777 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1778   return convertReassociationIndicesToExprs(getContext(),
1779                                             getReassociationIndices());
1780 }
1781 
1782 RankedTensorType CollapseShapeOp::inferCollapsedType(
1783     RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
1784   return inferCollapsedType(
1785       type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1786                 type.getContext(), reassociation)));
1787 }
1788 
1789 /// Compute the RankedTensorType obtained by applying `reassociation` to
1790 /// `type`.
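/// For example (illustrative): collapsing tensor<2x3x?xf32> with reassociation
/// [[0, 1], [2]] yields tensor<6x?xf32>; any group that contains a dynamic
/// dimension produces a dynamic dimension in the result.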
1791 RankedTensorType
1792 CollapseShapeOp::inferCollapsedType(RankedTensorType type,
1793                                     ArrayRef<AffineMap> reassociation) {
1794   auto shape = type.getShape();
1795   SmallVector<int64_t, 4> newShape;
1796   newShape.reserve(reassociation.size());
1797 
1798   // Use the fact that reassociation is valid to simplify the logic: only use
1799   // each map's rank.
1800   assert(isReassociationValid(reassociation) && "invalid reassociation");
1801   unsigned currentDim = 0;
1802   for (AffineMap m : reassociation) {
1803     unsigned dim = m.getNumResults();
1804     auto band = shape.slice(currentDim, dim);
1805     int64_t size = 1;
1806     if (llvm::is_contained(band, ShapedType::kDynamic))
1807       size = ShapedType::kDynamic;
1808     else
1809       for (unsigned d = 0; d < dim; ++d)
1810         size *= shape[currentDim + d];
1811     newShape.push_back(size);
1812     currentDim += dim;
1813   }
1814 
1815   return RankedTensorType::get(newShape, type.getElementType());
1816 }
1817 
1818 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
1819                             ArrayRef<ReassociationIndices> reassociation,
1820                             ArrayRef<NamedAttribute> attrs) {
1821   auto resultType = inferCollapsedType(
1822       llvm::cast<RankedTensorType>(src.getType()),
1823       getSymbolLessAffineMaps(
1824           convertReassociationIndicesToExprs(b.getContext(), reassociation)));
1825   result.addAttribute(getReassociationAttrStrName(),
1826                       getReassociationIndicesAttribute(b, reassociation));
1827   build(b, result, resultType, src, attrs);
1828 }
1829 
1830 template <typename TensorReshapeOp, bool isExpansion = std::is_same<
1831                                         TensorReshapeOp, ExpandShapeOp>::value>
1832 static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
1833                                            RankedTensorType expandedType,
1834                                            RankedTensorType collapsedType) {
1835   if (failed(
1836           verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
1837     return failure();
1838 
1839   auto maps = op.getReassociationMaps();
1840   RankedTensorType expectedType =
1841       CollapseShapeOp::inferCollapsedType(expandedType, maps);
1842   if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
1843     return op.emitOpError("expected collapsed type to be ")
1844            << expectedType << ", but got " << collapsedType;
1845   return success();
1846 }
1847 
1848 LogicalResult ExpandShapeOp::verify() {
1849   auto srcType = getSrcType();
1850   auto resultType = getResultType();
1851 
1852   if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
1853     return emitOpError("expected number of static shape dims to be equal to "
1854                        "the output rank (")
1855            << resultType.getRank() << ") but found "
1856            << getStaticOutputShape().size() << " inputs instead";
1857 
1858   if ((int64_t)getOutputShape().size() !=
1859       llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
1860     return emitOpError("mismatch in dynamic dims in output_shape and "
1861                        "static_output_shape: static_output_shape has ")
1862            << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
1863            << " dynamic dims while output_shape has " << getOutputShape().size()
1864            << " values";
1865 
1866   return verifyTensorReshapeOp(*this, resultType, srcType);
1867 }
1868 
1869 LogicalResult CollapseShapeOp::verify() {
1870   return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
1871 }
1872 
1873 namespace {
1874 /// Reshape of a splat constant can be replaced with a constant of the result
1875 /// type.
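/// E.g. (illustrative) collapsing a constant dense<1.0> : tensor<2x4xf32>
/// splat is rewritten to an arith.constant dense<1.0> : tensor<8xf32>.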
1876 template <typename TensorReshapeOp>
1877 struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
1878   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1879   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1880                                 PatternRewriter &rewriter) const override {
1881     DenseElementsAttr attr;
1882     if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
1883       return failure();
1884     if (!attr || !attr.isSplat())
1885       return failure();
1886     DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
1887         reshapeOp.getResultType(), attr.getRawData());
1888     rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
1889     return success();
1890   }
1891 };
1892 
1893 // Folds TensorReshapeOp(splat x : src_type) : res_type into splat x : res_type.
1894 template <typename TensorReshapeOp>
1895 class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
1896 public:
1897   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1898 
1899   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1900                                 PatternRewriter &rewriter) const override {
1901     auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
1902     if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
1903       return failure();
1904 
1905     rewriter.replaceOpWithNewOp<tensor::SplatOp>(
1906         reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
1907     return success();
1908   }
1909 };
1910 
1911 /// Reshape of a FromElements can be replaced with a FromElements of the
1912 /// result type
1913 template <typename TensorReshapeOp>
1914 struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
1915   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1916   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1917                                 PatternRewriter &rewriter) const override {
1918     auto fromElements =
1919         reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
1920     if (!fromElements)
1921       return failure();
1922 
1923     auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
1924 
1925     if (!shapedTy.hasStaticShape())
1926       return failure();
1927 
1928     rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
1929                                                 fromElements.getElements());
1930     return success();
1931   }
1932 };
1933 
1934 // Fold CastOp into CollapseShapeOp when adding static information.
1935 struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
1936   using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
1937 
1938   LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
1939                                 PatternRewriter &rewriter) const override {
1940     auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
1941     if (!tensor::canFoldIntoConsumerOp(castOp))
1942       return failure();
1943 
1944     RankedTensorType srcType =
1945         llvm::cast<RankedTensorType>(castOp.getSource().getType());
1946     RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
1947         srcType, collapseShapeOp.getReassociationMaps());
1948 
1949     if (newResultType == collapseShapeOp.getResultType()) {
1950       rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
1951         collapseShapeOp.getSrcMutable().assign(castOp.getSource());
1952       });
1953     } else {
1954       auto newOp = rewriter.create<CollapseShapeOp>(
1955           collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
1956           collapseShapeOp.getReassociation());
1957       rewriter.replaceOpWithNewOp<tensor::CastOp>(
1958           collapseShapeOp, collapseShapeOp.getResultType(), newOp);
1959     }
1960     return success();
1961   }
1962 };
1963 
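/// Fold a tensor.dim of a dynamic result dimension of a tensor.expand_shape
/// into an affine.apply on the corresponding source dimension. A sketch on
/// illustrative IR: for
///   %e = tensor.expand_shape %src [[0, 1]] output_shape [%sz, 8]
///       : tensor<?xf32> into tensor<?x8xf32>
/// a tensor.dim %e, %c0 is rewritten to
///   affine.apply affine_map<()[s0] -> (s0 floordiv 8)>()[%d0]
/// where %d0 = tensor.dim %src, %c0.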
1964 struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
1965   using OpRewritePattern<DimOp>::OpRewritePattern;
1966 
1967   LogicalResult matchAndRewrite(DimOp dimOp,
1968                                 PatternRewriter &rewriter) const override {
1969     auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
1970     if (!expandShapeOp)
1971       return failure();
1972 
1973     // Only constant dimension values are supported.
1974     std::optional<int64_t> dim = dimOp.getConstantIndex();
1975     if (!dim.has_value())
1976       return failure();
1977 
1978     // Skip static dims. These are folded to constant ops.
1979     RankedTensorType resultType = expandShapeOp.getResultType();
1980     if (!resultType.isDynamicDim(*dim))
1981       return failure();
1982 
1983     // Find reassociation group that contains this result dimension.
1984     int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
1985 
1986     // `dim` is the only dynamic dimension in its reassociation group `grp`.
1987     // (Otherwise, the ExpandShapeOp would be ambiguous.)
1988     int64_t product = 1;
1989     ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
1990     for (int64_t d : grp) {
1991       if (d != dim) {
1992         assert(!resultType.isDynamicDim(d) && "expected static dim");
1993         product *= resultType.getDimSize(d);
1994       }
1995     }
1996 
1997     // result dim size = src dim size / (product(other dims in reassoc group))
1998     Value srcDimSz =
1999         rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
2000     AffineExpr expr;
2001     bindSymbols(dimOp.getContext(), expr);
2002     rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
2003         dimOp, expr.floorDiv(product), srcDimSz);
2004     return success();
2005   }
2006 };
2007 
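/// Fold a tensor.dim of a dynamic result dimension of a tensor.collapse_shape
/// into the product of the corresponding source dimensions, materialized as an
/// affine.apply over tensor.dim ops on the collapse_shape source (e.g.,
/// illustratively, affine_map<()[s0, s1] -> (s0 * s1)> for a group of two
/// source dimensions).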
2008 struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
2009   using OpRewritePattern<DimOp>::OpRewritePattern;
2010 
2011   LogicalResult matchAndRewrite(DimOp dimOp,
2012                                 PatternRewriter &rewriter) const override {
2013     auto collapseShapeOp = dimOp.getSource().getDefiningOp<CollapseShapeOp>();
2014     if (!collapseShapeOp)
2015       return failure();
2016 
2017     // Only constant dimension values are supported.
2018     std::optional<int64_t> dim = dimOp.getConstantIndex();
2019     if (!dim.has_value() ||
2020         dim.value() >= collapseShapeOp.getResultType().getRank())
2021       return failure();
2022 
2023     // Skip static dims. These are folded to constant ops.
2024     RankedTensorType resultType = collapseShapeOp.getResultType();
2025     if (!resultType.isDynamicDim(*dim))
2026       return failure();
2027 
2028     // Get reassociation group of the result dimension.
2029     ReassociationIndices group =
2030         collapseShapeOp.getReassociationIndices()[*dim];
2031 
2032     // result dim size = product(dims in reassoc group)
2033     SmallVector<Value> srcDimSizes;
2034     SmallVector<AffineExpr> syms;
2035     AffineExpr product;
2036     for (const auto &it : llvm::enumerate(group)) {
2037       srcDimSizes.push_back(rewriter.create<DimOp>(
2038           dimOp.getLoc(), collapseShapeOp.getSrc(), it.value()));
2039       syms.push_back(rewriter.getAffineSymbolExpr(it.index()));
2040       product = product ? product * syms.back() : syms.back();
2041     }
2042     rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(dimOp, product,
2043                                                        srcDimSizes);
2044     return success();
2045   }
2046 };
2047 
2048 /// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
2049 /// matching constant output_shape operands of the expand. This makes the
2050 /// `tensor.expand_shape` more static and creates a consumer cast that can be
2051 /// propagated further.
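///
/// A sketch on illustrative IR (assuming %c2 is an arith.constant 2 : index):
///
///   %cast = tensor.cast %src : tensor<8xf32> to tensor<?xf32>
///   %e = tensor.expand_shape %cast [[0, 1]] output_shape [%c2, 4]
///       : tensor<?xf32> into tensor<?x4xf32>
///
/// is rewritten so that the expand operates on a tensor<8xf32> source and
/// produces tensor<2x4xf32>, followed by a tensor.cast back to the original
/// tensor<?x4xf32> type.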
2052 struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
2053   using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
2054 
2055   LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
2056                                 PatternRewriter &rewriter) const override {
2057     auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
2058     if (!canFoldIntoConsumerOp(castOp))
2059       return failure();
2060 
2061     ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
2062     SmallVector<ReassociationIndices, 4> reassoc =
2063         expandOp.getReassociationIndices();
2064 
2065     SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
2066     SmallVector<Value> dynamicOutputShape;
2067     auto outputIt = expandOp.getOutputShape().begin();
2068 
2069     for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
2070       for (uint64_t outDim : innerReassoc) {
2071         if (!ShapedType::isDynamic(newOutputShape[outDim]))
2072           continue;
2073 
2074         // If the cast's src type is dynamic, don't infer any of the
2075         // corresponding expanded dimensions. `tensor.expand_shape` requires at
2076         // least one of the expanded dimensions to be dynamic if the input is
2077         // dynamic.
2078         Value val = *outputIt;
2079         ++outputIt;
2080         if (ShapedType::isDynamic(castSrcShape[inputDim])) {
2081           dynamicOutputShape.push_back(val);
2082           continue;
2083         }
2084 
2085         APInt cst;
2086         if (matchPattern(val, m_ConstantInt(&cst))) {
2087           newOutputShape[outDim] = cst.getSExtValue();
2088         } else {
2089           dynamicOutputShape.push_back(val);
2090         }
2091       }
2092     }
2093 
2094     // Couldn't match any values, nothing to change
2095     if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
2096       return failure();
2097 
2098     // Calculate the input shape from the output
2099     SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), int64_t(1));
2100     for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
2101       for (auto outDim : reassoc[inDim]) {
2102         auto ofr = newOutputShape[outDim];
2103         if (ShapedType::isDynamic(ofr)) {
2104           newInputShape[inDim] = ShapedType::kDynamic;
2105           break;
2106         }
2107         newInputShape[inDim] *= ofr;
2108       }
2109     }
2110 
2111     SmallVector<OpFoldResult> outputOfr =
2112         getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
2113     auto inputType = RankedTensorType::get(
2114         newInputShape, expandOp.getSrcType().getElementType());
2115     auto outputType = RankedTensorType::get(
2116         newOutputShape, expandOp.getSrcType().getElementType());
2117     auto inputCast = rewriter.create<CastOp>(expandOp.getLoc(), inputType,
2118                                              expandOp.getSrc());
2119     auto newExpand = rewriter.create<ExpandShapeOp>(
2120         expandOp.getLoc(), outputType, inputCast.getResult(),
2121         expandOp.getReassociationIndices(), outputOfr);
2122     rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
2123                                         newExpand.getResult());
2124     return success();
2125   }
2126 };
2127 } // namespace
2128 
2129 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2130                                                 MLIRContext *context) {
2131   results.add<
2132       ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2133       ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
2134       ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
2135       FoldReshapeWithSplat<ExpandShapeOp>,
2136       FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
2137       FoldDimOfCollapseShape>(context);
2138 }
2139 
2140 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2141                                                   MLIRContext *context) {
2142   results.add<
2143       ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2144       ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2145                                 tensor::DimOp, RankedTensorType>,
2146       FoldReshapeWithConstant<CollapseShapeOp>,
2147       FoldReshapeWithSplat<CollapseShapeOp>,
2148       FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
2149       context);
2150 }
2151 
2152 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2153   return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2154                                                        adaptor.getOperands());
2155 }
2156 
2157 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2158   return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2159                                                        adaptor.getOperands());
2160 }
2161 
2162 //===----------------------------------------------------------------------===//
2163 // ExtractSliceOp
2164 //===----------------------------------------------------------------------===//
2165 
2166 void ExtractSliceOp::getAsmResultNames(
2167     function_ref<void(Value, StringRef)> setNameFn) {
2168   setNameFn(getResult(), "extracted_slice");
2169 }
2170 
2171 /// An extract_slice result type can be inferred, when it is not
2172 /// rank-reduced, from the source type and the static representation of
2173 /// offsets, sizes and strides. Special sentinels encode the dynamic case.
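/// E.g. (illustrative): slicing a tensor<8x16x4xf32> source with static sizes
/// [4, kDynamic, 4] infers the type tensor<4x?x4xf32>.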
2174 RankedTensorType ExtractSliceOp::inferResultType(
2175     RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
2176     ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
2177   // An extract_slice op may specify only a leading subset of offsets/sizes/
2178   // strides, in which case we complete them with offset=0, sizes from the
2179   // source tensor type, and strides=1.
2180   assert(static_cast<int64_t>(staticSizes.size()) ==
2181              sourceTensorType.getRank() &&
2182          "unexpected staticSizes not equal to rank of source");
2183   return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2184                                sourceTensorType.getEncoding());
2185 }
2186 
2187 RankedTensorType ExtractSliceOp::inferResultType(
2188     RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
2189     ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
2190   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2191   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2192   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2193   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2194   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2195   return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets,
2196                                          staticSizes, staticStrides);
2197 }
2198 
2199 /// If the rank is reduced (i.e. the desiredResultRank is smaller than the
2200 /// number of sizes), drop as many size-1 dimensions as needed to produce an
2201 /// inferred type with the desired rank.
2202 ///
2203 /// Note that there may be multiple ways to compute this rank-reduced type:
2204 ///   e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
2205 ///
2206 /// To disambiguate, this function always drops the first size-1 dimensions it
/// encounters; e.g. rank-reducing 1x6x1 to rank 2 yields 6x1, not 1x6.
2207 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2208     unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2209     ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2210     ArrayRef<int64_t> strides) {
2211   // Type inferred in the absence of rank-reducing behavior.
2212   auto inferredType = llvm::cast<RankedTensorType>(
2213       inferResultType(sourceRankedTensorType, offsets, sizes, strides));
2214   int rankDiff = inferredType.getRank() - desiredResultRank;
2215   if (rankDiff > 0) {
2216     auto shape = inferredType.getShape();
2217     llvm::SmallBitVector dimsToProject =
2218         getPositionsOfShapeOne(rankDiff, shape);
2219     SmallVector<int64_t> projectedShape;
2220     // Best effort rank-reducing: drop 1s in order.
2221     for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
2222       if (!dimsToProject.test(pos))
2223         projectedShape.push_back(shape[pos]);
2224     inferredType =
2225         RankedTensorType::get(projectedShape, inferredType.getElementType());
2226   }
2227   return inferredType;
2228 }
2229 
2230 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2231     unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2232     ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
2233     ArrayRef<OpFoldResult> strides) {
2234   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2235   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2236   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2237   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2238   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2239   return ExtractSliceOp::inferCanonicalRankReducedResultType(
2240       desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
2241       staticStrides);
2242 }
2243 
2244 /// Build an ExtractSliceOp with mixed static and dynamic entries and custom
2245 /// result type. If the type passed is nullptr, it is inferred.
2246 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2247                            RankedTensorType resultType, Value source,
2248                            ArrayRef<OpFoldResult> offsets,
2249                            ArrayRef<OpFoldResult> sizes,
2250                            ArrayRef<OpFoldResult> strides,
2251                            ArrayRef<NamedAttribute> attrs) {
2252   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2253   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2254   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2255   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2256   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2257   auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
2258   // Structuring implementation this way avoids duplication between builders.
2259   if (!resultType) {
2260     resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
2261         sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
2262   }
2263   result.addAttributes(attrs);
2264   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2265         dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2266         b.getDenseI64ArrayAttr(staticSizes),
2267         b.getDenseI64ArrayAttr(staticStrides));
2268 }
2269 
2270 /// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
2271 /// result type.
2272 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2273                            ArrayRef<OpFoldResult> offsets,
2274                            ArrayRef<OpFoldResult> sizes,
2275                            ArrayRef<OpFoldResult> strides,
2276                            ArrayRef<NamedAttribute> attrs) {
2277   build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2278 }
2279 
2280 /// Build an ExtractSliceOp with mixed static and dynamic entries packed into
2281 /// a Range vector.
2282 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2283                            ArrayRef<Range> ranges,
2284                            ArrayRef<NamedAttribute> attrs) {
2285   auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2286   build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2287 }
2288 
2289 /// Build an ExtractSliceOp with dynamic entries and custom result type. If
2290 /// the type passed is nullptr, it is inferred.
2291 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2292                            RankedTensorType resultType, Value source,
2293                            ValueRange offsets, ValueRange sizes,
2294                            ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2295   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2296       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2297   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2298       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2299   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2300       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2301   build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2302 }
2303 
2304 /// Build an ExtractSliceOp with dynamic entries and inferred result type.
2305 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2306                            ValueRange offsets, ValueRange sizes,
2307                            ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2308   build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2309 }
2310 
2311 static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
2312                                           Operation *op,
2313                                           RankedTensorType expectedType) {
2314   switch (result) {
2315   case SliceVerificationResult::Success:
2316     return success();
2317   case SliceVerificationResult::RankTooLarge:
2318     return op->emitError("expected rank to be smaller than or equal to ")
2319            << "the other rank. ";
2320   case SliceVerificationResult::SizeMismatch:
2321     return op->emitError("expected type to be ")
2322            << expectedType << " or a rank-reduced version. (size mismatch) ";
2323   case SliceVerificationResult::ElemTypeMismatch:
2324     return op->emitError("expected element type to be ")
2325            << expectedType.getElementType();
2326   default:
2327     llvm_unreachable("unexpected extract_slice op verification result");
2328   }
2329 }
2330 
2331 /// Verifier for ExtractSliceOp.
2332 LogicalResult ExtractSliceOp::verify() {
2333   // Verify result type against inferred type.
2334   RankedTensorType expectedType = ExtractSliceOp::inferResultType(
2335       getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
2336   SliceVerificationResult result = isRankReducedType(expectedType, getType());
2337   return produceSliceErrorMsg(result, *this, expectedType);
2338 }
2339 
2340 llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
2341   return ::getDroppedDims(getType().getShape(), getMixedSizes());
2342 }
2343 
2344 FailureOr<Value>
2345 ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
2346                                    ArrayRef<int64_t> desiredShape) {
2347   auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
2348   assert(sourceTensorType && "not a ranked tensor type");
2349   auto sourceShape = sourceTensorType.getShape();
2350   if (sourceShape.equals(desiredShape))
2351     return value;
2352   auto maybeRankReductionMask =
2353       mlir::computeRankReductionMask(sourceShape, desiredShape);
2354   if (!maybeRankReductionMask)
2355     return failure();
2356   return createCanonicalRankReducingExtractSliceOp(
2357       b, loc, value,
2358       RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
2359 }
2360 
2361 LogicalResult ExtractSliceOp::reifyResultShapes(
2362     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2363   reifiedReturnShapes.resize(1);
2364   reifiedReturnShapes[0].reserve(getType().getRank());
2365   SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
2366   llvm::SmallBitVector droppedDims = getDroppedDims();
2367   for (const auto &size : enumerate(mixedSizes)) {
2368     if (droppedDims.test(size.index()))
2369       continue;
2370     reifiedReturnShapes[0].push_back(size.value());
2371   }
2372   return success();
2373 }
2374 
2375 namespace {
2376 /// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
2377 /// This essentially pushes the tensor.cast past its consuming slice when
2378 /// `canFoldIntoConsumerOp` is true.
2379 ///
2380 /// Example:
2381 /// ```
2382 ///   %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
2383 ///   %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
2384 ///   tensor<3x4xf32>
2385 /// ```
2386 /// is rewritten into:
2387 /// ```
2388 ///   %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to
2389 ///   tensor<3x4xf32>
///   %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<3x4xf32>
2390 /// ```
2391 class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
2392 public:
2393   using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2394 
2395   LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
2396                                 PatternRewriter &rewriter) const override {
2397     // If any operand is a constant index, bail out and let the constant
     // folder canonicalize it first.
2398     if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
2399           return matchPattern(operand, matchConstantIndex());
2400         }))
2401       return failure();
2402 
2403     auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
2404     if (!castOp)
2405       return failure();
2406 
2407     if (!canFoldIntoConsumerOp(castOp))
2408       return failure();
2409 
2410     // Create folded extract.
2411     Location loc = sliceOp.getLoc();
2412     Value newResult = rewriter.create<ExtractSliceOp>(
2413         loc, sliceOp.getType(), castOp.getSource(), sliceOp.getOffsets(),
2414         sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
2415         sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
2416     if (newResult.getType() != sliceOp.getType())
2417       newResult = rewriter.create<CastOp>(loc, sliceOp.getType(), newResult);
2418     rewriter.replaceOp(sliceOp, newResult);
2419     return success();
2420   }
2421 };
2422 
2423 /// Slice elements from `values` into `outValues`. `counts` holds the
2424 /// row-major strides of `values` (elements spanned per index step in each
2425 /// dimension). The output values can be used to construct a DenseElementsAttr.
2426 template <typename IterTy, typename ElemTy>
2427 static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
2428                           ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2429                           ArrayRef<int64_t> strides,
2430                           llvm::SmallVectorImpl<ElemTy> *outValues) {
2431   assert(offsets.size() == sizes.size());
2432   assert(offsets.size() == strides.size());
2433   if (offsets.empty())
2434     return;
2435 
2436   int64_t offset = offsets.front();
2437   int64_t size = sizes.front();
2438   int64_t stride = strides.front();
2439   if (offsets.size() == 1) {
2440     for (int64_t i = 0; i < size; ++i, offset += stride)
2441       outValues->push_back(*(values + offset));
2442 
2443     return;
2444   }
2445 
2446   for (int64_t i = 0; i < size; ++i, offset += stride) {
2447     auto begin = values + offset * counts.front();
2448     sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
2449                                   offsets.drop_front(), sizes.drop_front(),
2450                                   strides.drop_front(), outValues);
2451   }
2452 }
2453 
2454 /// Fold arith.constant and tensor.extract_slice into arith.constant. The
2455 /// folded operation might introduce more constant data; users can control
2456 /// this heuristic via the control function.
2457 class ConstantOpExtractSliceFolder final
2458     : public OpRewritePattern<ExtractSliceOp> {
2459 public:
2460   using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2461 
2462   ConstantOpExtractSliceFolder(MLIRContext *context,
2463                                ControlConstantExtractSliceFusionFn controlFn)
2464       : OpRewritePattern<ExtractSliceOp>(context),
2465         controlFn(std::move(controlFn)) {}
2466 
2467   LogicalResult matchAndRewrite(ExtractSliceOp op,
2468                                 PatternRewriter &rewriter) const override {
2469     DenseElementsAttr attr;
2470     if (!matchPattern(op.getSource(), m_Constant(&attr)))
2471       return failure();
2472 
2473     // A constant splat is handled by fold().
2474     if (attr.isSplat())
2475       return failure();
2476 
2477     // Dynamic result shape is not supported.
2478     auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
2479     auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
2480     if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
2481       return failure();
2482 
2483     // Customized control over the folding.
2484     if (!controlFn(op))
2485       return failure();
2486 
2487     int64_t count = sourceType.getNumElements();
2488     if (count == 0)
2489       return failure();
2490 
2491     // Check if there are any dynamic parts, which are not supported.
2492     auto offsets = op.getStaticOffsets();
2493     if (llvm::is_contained(offsets, ShapedType::kDynamic))
2494       return failure();
2495     auto sizes = op.getStaticSizes();
2496     if (llvm::is_contained(sizes, ShapedType::kDynamic))
2497       return failure();
2498     auto strides = op.getStaticStrides();
2499     if (llvm::is_contained(strides, ShapedType::kDynamic))
2500       return failure();
2501 
2502     // Compute the stride for each dimension.
2503     SmallVector<int64_t> counts;
2504     ArrayRef<int64_t> shape = sourceType.getShape();
2505     counts.reserve(shape.size());
2506     for (int64_t v : shape) {
2507       count = count / v;
2508       counts.push_back(count);
2509     }
2510 
2511     // New attribute constructed by the sliced values.
2512     DenseElementsAttr newAttr;
2513 
2514     if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
2515       SmallVector<APInt> outValues;
2516       outValues.reserve(sourceType.getNumElements());
2517       sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
2518           elems.begin(), counts, offsets, sizes, strides, &outValues);
2519       newAttr = DenseElementsAttr::get(resultType, outValues);
2520     } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
2521       SmallVector<APFloat> outValues;
2522       outValues.reserve(sourceType.getNumElements());
2523       sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
2524           elems.begin(), counts, offsets, sizes, strides, &outValues);
2525       newAttr = DenseElementsAttr::get(resultType, outValues);
2526     }
2527 
2528     if (newAttr) {
2529       rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
2530       return success();
2531     }
2532 
2533     return failure();
2534   }
2535 
2536 private:
2537   /// This additionally controls whether the fold happens or not. Users can
2538   /// impose their heuristics in the function.
2539   ControlConstantExtractSliceFusionFn controlFn;
2540 };
2541 
2542 } // namespace
2543 
2544 void mlir::tensor::populateFoldConstantExtractSlicePatterns(
2545     RewritePatternSet &patterns,
2546     const ControlConstantExtractSliceFusionFn &controlFn) {
2547   patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
2548 }
2549 
2550 /// Return the canonical type of the result of an extract_slice op.
2551 struct SliceReturnTypeCanonicalizer {
2552   RankedTensorType operator()(ExtractSliceOp op,
2553                               ArrayRef<OpFoldResult> mixedOffsets,
2554                               ArrayRef<OpFoldResult> mixedSizes,
2555                               ArrayRef<OpFoldResult> mixedStrides) {
2556     return ExtractSliceOp::inferCanonicalRankReducedResultType(
2557         op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
2558         mixedStrides);
2559   }
2560 };
2561 
2562 /// A canonicalizer wrapper to replace ExtractSliceOps.
2563 struct SliceCanonicalizer {
2564   void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
2565                   ExtractSliceOp newOp) {
2566     Value replacement = newOp.getResult();
2567     if (replacement.getType() != op.getType())
2568       replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
2569                                                     replacement);
2570     rewriter.replaceOp(op, replacement);
2571   }
2572 };
2573 
2574 void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2575                                                  MLIRContext *context) {
2576   results.add<
2577       OpWithOffsetSizesAndStridesConstantArgumentFolder<
2578           ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
2579       ExtractSliceOpCastFolder>(context);
2580 }
2581 
2582 //
2583 static LogicalResult
2584 foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
2585                                            ShapedType shapedType) {
2586   OpBuilder b(op.getContext());
2587   for (OpFoldResult ofr : op.getMixedOffsets())
2588     if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
2589       return failure();
2590   // Rank-reducing noops only need to inspect the leading dimensions:
2591   // llvm::zip is appropriate.
2592   auto shape = shapedType.getShape();
2593   for (auto it : llvm::zip(op.getMixedSizes(), shape))
2594     if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
2595       return failure();
2596   for (OpFoldResult ofr : op.getMixedStrides())
2597     if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
2598       return failure();
2599   return success();
2600 }
2601 
2602 /// If we have an ExtractSliceOp consuming an InsertSliceOp with the same
2603 /// slice, we can return the InsertSliceOp's source directly.
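/// E.g. (illustrative):
///   %i = tensor.insert_slice %src into %dest[0, 1] [2, 3] [1, 1]
///       : tensor<2x3xf32> into tensor<8x8xf32>
///   %e = tensor.extract_slice %i[0, 1] [2, 3] [1, 1]
///       : tensor<8x8xf32> to tensor<2x3xf32>
/// folds to %src.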
2604 // TODO: This only checks the immediate producer; extend to go up the
2605 // insert/extract chain if the slices are disjoint.
2606 static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
2607   auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
2608 
2609   auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2610   if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2611       insertOp.isSameAs(extractOp, isSame))
2612     return insertOp.getSource();
2613 
2614   return {};
2615 }
2616 
2617 OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
2618   if (OpFoldResult reshapedSource = reshapeConstantSource(
2619           llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
2620           getResult().getType()))
2621     return reshapedSource;
2622   if (getSourceType() == getType() &&
2623       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2624     return this->getSource();
2625   if (Value slice = foldExtractAfterInsertSlice(*this))
2626     return slice;
2627 
2628   return OpFoldResult();
2629 }
2630 
2631 Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
2632     OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
2633   auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
2634   unsigned rank = rankedTensorType.getRank();
2635   SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2636   SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
2637   SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2638   return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
2639                                                 offsets, sizes, strides);
2640 }
2641 
2642 //===----------------------------------------------------------------------===//
2643 // InsertSliceOp
2644 //===----------------------------------------------------------------------===//
2645 
2646 void InsertSliceOp::getAsmResultNames(
2647     function_ref<void(Value, StringRef)> setNameFn) {
2648   setNameFn(getResult(), "inserted_slice");
2649 }
2650 
2651 // Build an InsertSliceOp with mixed static and dynamic entries.
2652 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2653                           Value dest, ArrayRef<OpFoldResult> offsets,
2654                           ArrayRef<OpFoldResult> sizes,
2655                           ArrayRef<OpFoldResult> strides,
2656                           ArrayRef<NamedAttribute> attrs) {
2657   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2658   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2659   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2660   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2661   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2662   result.addAttributes(attrs);
2663   build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
2664         dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2665         b.getDenseI64ArrayAttr(staticSizes),
2666         b.getDenseI64ArrayAttr(staticStrides));
2667 }
2668 
2669 /// Build an InsertSliceOp with mixed static and dynamic entries packed into a
2670 /// Range vector.
2671 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2672                           Value dest, ArrayRef<Range> ranges,
2673                           ArrayRef<NamedAttribute> attrs) {
2674   auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2675   build(b, result, source, dest, offsets, sizes, strides, attrs);
2676 }
2677 
2678 // Build an InsertSliceOp with dynamic entries.
2679 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2680                           Value dest, ValueRange offsets, ValueRange sizes,
2681                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2682   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2683       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2684   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2685       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2686   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2687       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2688   build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2689 }
2690 
2691 /// Rank-reducing type verification for both InsertSliceOp and
2692 /// ParallelInsertSliceOp.
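     ///
     /// For illustration (hypothetical types): inserting a tensor<4xf32> source
     /// into a tensor<8x8xf32> destination with static sizes [1, 4] verifies,
     /// because the type inferred for the slice, tensor<1x4xf32>, rank-reduces
     /// to the source type.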
2693 static SliceVerificationResult verifyInsertSliceOp(
2694     RankedTensorType srcType, RankedTensorType dstType,
2695     ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
2696     ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
2697   // insert_slice is the inverse of extract_slice, so use the same type
2698   // inference.
2699   RankedTensorType expected = ExtractSliceOp::inferResultType(
2700       dstType, staticOffsets, staticSizes, staticStrides);
2701   if (expectedType)
2702     *expectedType = expected;
2703   return isRankReducedType(expected, srcType);
2704 }
2705 
2706 /// Verifier for InsertSliceOp.
2707 LogicalResult InsertSliceOp::verify() {
2708   RankedTensorType expectedType;
2709   SliceVerificationResult result =
2710       verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
2711                           getStaticSizes(), getStaticStrides(), &expectedType);
2712   return produceSliceErrorMsg(result, *this, expectedType);
2713 }
2714 
2715 /// If we have two consecutive InsertSliceOps writing to the same slice, we
2716 /// can mutate the second InsertSliceOp's destination to the first one's.
2717 ///
2718 /// Example:
2719 ///
2720 /// ```mlir
2721 ///   %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
2722 ///   %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
2723 /// ```
2724 ///
2725 /// folds into:
2726 ///
2727 /// ```mlir
2728 ///   %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
2729 /// ```
2730 ///
2731 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2732 static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
2733   auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2734 
2735   auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2736   if (!prevInsertOp ||
2737       prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2738       !prevInsertOp.isSameAs(insertOp, isSame))
2739     return failure();
2740 
2741   insertOp.getDestMutable().assign(prevInsertOp.getDest());
2742   return success();
2743 }
2744 
2745 /// Folds round-trip extract/insert slice op pairs.
2746 /// Example:
2747 /// ```mlir
2748 /// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2749 /// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2750 /// ```
2751 /// can be folded into %val.
2752 static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
2753   auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
2754 
2755   auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2756   if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2757       !extractOp.isSameAs(insertOp, isSame))
2758     return nullptr;
2759 
2760   return extractOp.getSource();
2761 }
2762 
2763 OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
2764   if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
2765       getSourceType() == getType() &&
2766       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2767     return this->getSource();
2768   if (succeeded(foldInsertAfterInsertSlice(*this)))
2769     return getResult();
2770   if (auto result = foldInsertAfterExtractSlice(*this))
2771     return result;
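       // Inserting a zero-sized slice leaves the destination unchanged.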
2772   if (llvm::any_of(getMixedSizes(),
2773                    [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }))
2774     return getDest();
2775   return OpFoldResult();
2776 }
2777 
2778 LogicalResult InsertSliceOp::reifyResultShapes(
2779     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2780   reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
2781   reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
2782   return success();
2783 }
2784 
2785 namespace {
2786 /// Pattern to rewrite an insert_slice op with constant arguments.
2787 ///
2788 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
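     ///
     /// For illustration (hypothetical IR, with %c0 = arith.constant 0 : index):
     ///
     /// ```mlir
     ///   %r = tensor.insert_slice %s into %d[%c0] [4] [1]
     ///       : tensor<4xf32> into tensor<8xf32>
     /// ```
     ///
     /// is rewritten to carry the offset 0 as a static attribute instead.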
2789 template <typename InsertOpTy>
2790 class InsertSliceOpConstantArgumentFolder final
2791     : public OpRewritePattern<InsertOpTy> {
2792 public:
2793   using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2794 
2795   LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2796                                 PatternRewriter &rewriter) const override {
2797     SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
2798     SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2799     SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
2800 
2801     // If no constant operands were folded, there is nothing to do.
2802     if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
2803         failed(foldDynamicOffsetSizeList(mixedSizes)) &&
2804         failed(foldDynamicStrideList(mixedStrides)))
2805       return failure();
2806 
2807     // Create the new op in canonical form.
2808     auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
2809         insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
2810         mixedOffsets, mixedSizes, mixedStrides);
2811     Value toInsert = insertSliceOp.getSource();
2812     if (sourceType != insertSliceOp.getSourceType()) {
2813       OpBuilder::InsertionGuard g(rewriter);
2814       // The only difference between InsertSliceOp and ParallelInsertSliceOp
2815       // is that the insertion point is just before the ParallelCombiningOp in
2816       // the parallel case.
2817       if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2818         rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2819       toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2820                                                  sourceType, toInsert);
2821     }
2822     rewriter.replaceOpWithNewOp<InsertOpTy>(
2823         insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
2824         mixedSizes, mixedStrides);
2825     return success();
2826   }
2827 };
2828 
2829 /// Fold tensor_casts with insert_slice operations. If the source or
2830 /// destination tensor is a tensor_cast that removes static type information,
2831 /// the cast is folded into the insert_slice operation. E.g.:
2832 ///
2833 /// ```mlir
2834 ///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
2835 ///   %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
2836 /// ```
2837 ///
2838 /// folds into:
2839 ///
2840 /// ```mlir
2841 ///   %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
2842 /// ```
2843 ///
2844 /// Note: When folding a cast on the destination tensor, the result of the
2845 /// insert_slice operation is cast back so that the type of the result does
2846 /// not change.
2847 ///
2848 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2849 template <typename InsertOpTy>
2850 struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
2851   using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2852 
2853   LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2854                                 PatternRewriter &rewriter) const override {
2855     if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
2856           return matchPattern(operand, matchConstantIndex());
2857         }))
2858       return failure();
2859 
2860     auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
2861       auto castOp = v.getDefiningOp<tensor::CastOp>();
2862       if (!castOp || !canFoldIntoConsumerOp(castOp))
2863         return std::nullopt;
2864       return castOp.getSource();
2865     };
2866     std::optional<Value> sourceCastSource =
2867         getSourceOfCastOp(insertSliceOp.getSource());
2868     std::optional<Value> destCastSource =
2869         getSourceOfCastOp(insertSliceOp.getDest());
2870     if (!sourceCastSource && !destCastSource)
2871       return failure();
2872 
2873     auto src =
2874         (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
2875     auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
2876     auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
2877     auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
2878     if (!srcType || !dstType)
2879       return failure();
2880 
2881     // The tensor.cast source could have additional static information not seen
2882     // in the insert slice op static sizes, so we ignore dynamic dims when
2883     // computing the rank reduction mask.
2884     SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
2885     auto rankReductionMask = computeRankReductionMask(
2886         staticSizes, srcType.getShape(), /*matchDynamic=*/true);
2887     if (!rankReductionMask.has_value())
2888       return failure();
2889     // Replace dimensions in the insert slice op with corresponding static dims
2890     // from the cast source type. If the insert slice sizes have static dims
2891     // that are not static in the tensor.cast source (i.e., when the cast op
2892     // casts a dynamic dim to static), the dim should not be replaced, and the
2893     // pattern will fail later in `verifyInsertSliceOp`.
2894     SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2895     int64_t rankReducedIdx = 0;
2896     for (auto [idx, size] : enumerate(staticSizes)) {
2897       if (!rankReductionMask.value().contains(idx) &&
2898           !srcType.isDynamicDim(rankReducedIdx)) {
2899         mixedSizes[idx] = getAsIndexOpFoldResult(
2900             rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
2901         size = srcType.getDimSize(rankReducedIdx++);
2902       }
2903     }
2904     if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
2905                             staticSizes, insertSliceOp.getStaticStrides()) !=
2906         SliceVerificationResult::Success)
2907       return failure();
2908 
2909     Operation *replacement = rewriter.create<InsertOpTy>(
2910         insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
2911         mixedSizes, insertSliceOp.getMixedStrides());
2912 
2913     // In the parallel case there is no result and so nothing to cast.
2914     bool isParallelInsert =
2915         std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
2916     if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
2917       replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2918                                                     insertSliceOp.getDestType(),
2919                                                     replacement->getResult(0));
2920     }
2921     rewriter.replaceOp(insertSliceOp, replacement->getResults());
2922     return success();
2923   }
2924 };
2925 
2926 /// If additional static type information can be deduced from an insert_slice's
2927 /// size operands, insert an explicit cast of the op's source operand. This
2928 /// enables other canonicalization patterns that are matching for tensor_cast
2929 /// ops such as `ForOpTensorCastFolder` in SCF.
2930 ///
2931 /// Example:
2932 ///
2933 /// ```mlir
2934 ///   %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
2935 ///       : tensor<?x?xf32> into ...
2936 /// ```
2937 ///
2938 /// folds into:
2939 ///
2940 /// ```mlir
2941 ///   %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
2942 ///   %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
2943 ///       : tensor<64x64xf32> into ...
2944 /// ```
2945 ///
2946 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2947 template <typename InsertOpTy>
2948 struct InsertSliceOpSourceCastInserter final
2949     : public OpRewritePattern<InsertOpTy> {
2950   using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2951 
2952   LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2953                                 PatternRewriter &rewriter) const override {
2954     RankedTensorType srcType = insertSliceOp.getSourceType();
2955     if (srcType.getRank() != insertSliceOp.getDestType().getRank())
2956       return failure();
2957     SmallVector<int64_t> newSrcShape(srcType.getShape());
2958     for (int64_t i = 0; i < srcType.getRank(); ++i) {
2959       if (std::optional<int64_t> constInt =
2960               getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
2961         // Bail on invalid IR.
2962         if (*constInt < 0)
2963           return failure();
2964         newSrcShape[i] = *constInt;
2965       }
2966     }
2967     if (!hasValidSizesOffsets(newSrcShape))
2968       return failure();
2969 
2970     RankedTensorType newSrcType = RankedTensorType::get(
2971         newSrcShape, srcType.getElementType(), srcType.getEncoding());
2972     if (srcType == newSrcType ||
2973         !preservesStaticInformation(srcType, newSrcType) ||
2974         !tensor::CastOp::areCastCompatible(srcType, newSrcType))
2975       return failure();
2976 
2977     // newSrcType is:
2978     //   1) Different from srcType.
2979     //   2) "More static" than srcType.
2980     //   3) Cast-compatible with srcType.
2981     // Insert the cast.
2982     OpBuilder::InsertionGuard g(rewriter);
2983     // The only difference between InsertSliceOp and ParallelInsertSliceOp is
2984     // that the insertion point is just before the ParallelCombiningOp in the
2985     // parallel case.
2986     if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2987       rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2988     Value cast = rewriter.create<tensor::CastOp>(
2989         insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
2990     rewriter.replaceOpWithNewOp<InsertOpTy>(
2991         insertSliceOp, cast, insertSliceOp.getDest(),
2992         insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
2993         insertSliceOp.getMixedStrides());
2994     return success();
2995   }
2996 };
2997 } // namespace
2998 
2999 llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
3000   return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3001 }
3002 
3003 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3004                                                 MLIRContext *context) {
3005   results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
3006               InsertSliceOpCastFolder<InsertSliceOp>,
3007               InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
3008 }
3009 
3010 Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
3011                                                              Location loc,
3012                                                              Value tensor,
3013                                                              Value dest) {
3014   auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
3015   unsigned rank = rankedTensorType.getRank();
3016   SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3017   SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
3018   SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3019   return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
3020                                                sizes, strides);
3021 }
3022 
3023 //===----------------------------------------------------------------------===//
3024 // PadOp
3025 //===----------------------------------------------------------------------===//
3026 
3027 void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3028   setNameFn(getResult(), "padded");
3029 }
3030 
3031 // TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
3032 // supports optional types.
3033 void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
3034                     Type typeToInfer, Type typeToInferFrom) {}
3035 
3036 ParseResult
3037 parseInferType(OpAsmParser &parser,
3038                std::optional<OpAsmParser::UnresolvedOperand> optOperand,
3039                Type &typeToInfer, Type typeToInferFrom) {
3040   if (optOperand)
3041     typeToInfer = typeToInferFrom;
3042   return success();
3043 }
3044 
3045 LogicalResult PadOp::verify() {
3046   auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
3047   auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
3048   auto expectedType =
3049       PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
3050   if (!expectedType) {
3051     return emitError("failed to infer expectedType from sourceType ")
3052            << sourceType << ", specified resultType is " << resultType;
3053   }
3054   if (resultType.getRank() != expectedType.getRank()) {
3055     return emitError("specified type ")
3056            << resultType << " does not match the inferred type "
3057            << expectedType;
3058   }
3059   for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
3060     if (resultType.getDimSize(i) == expectedType.getDimSize(i))
3061       continue;
3062     if (expectedType.isDynamicDim(i))
3063       continue;
3064     return emitError("specified type ")
3065            << resultType << " does not match the inferred type "
3066            << expectedType;
3067   }
3068 
3069   return success();
3070 }
3071 
3072 LogicalResult PadOp::verifyRegions() {
3073   auto &region = getRegion();
3074   unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
3075   Block &block = region.front();
3076   if (block.getNumArguments() != rank)
3077     return emitError("expected the block to have ") << rank << " arguments";
3078 
3079   // Note: the number and type of yield values are checked in the YieldOp.
3080   for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
3081     if (!en.value().isIndex())
3082       return emitOpError("expected block argument ")
3083              << (en.index() + 1) << " to be an index";
3084   }
3085 
3086   // Ensure that the region yields an element of the right type.
3087   auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
3088   if (yieldOp.getValue().getType() !=
3089       llvm::cast<ShapedType>(getType()).getElementType())
3090     return emitOpError("expected yield type to match shape element type");
3091 
3092   return success();
3093 }
3094 
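     /// For illustration (hypothetical values): padding a tensor<4x?xf32> source
     /// with staticLow = [1, 0] and staticHigh = [2, kDynamic] infers the type
     /// tensor<7x?xf32>; a dimension stays dynamic whenever the source size or a
     /// padding amount is dynamic, unless `resultShape` supplies it.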
3095 RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
3096                                         ArrayRef<int64_t> staticLow,
3097                                         ArrayRef<int64_t> staticHigh,
3098                                         ArrayRef<int64_t> resultShape) {
3099   unsigned rank = sourceType.getRank();
3100   if (staticLow.size() != rank)
3101     return RankedTensorType();
3102   if (staticHigh.size() != rank)
3103     return RankedTensorType();
3104   if (!resultShape.empty() && resultShape.size() != rank)
3105     return RankedTensorType();
3106 
3107   SmallVector<int64_t, 4> inferredShape;
3108   for (auto i : llvm::seq<unsigned>(0, rank)) {
3109     if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
3110         staticHigh[i] == ShapedType::kDynamic) {
3111       inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
3112                                                   : resultShape[i]);
3113     } else {
3114       int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
3115       assert((resultShape.empty() || size == resultShape[i] ||
3116               resultShape[i] == ShapedType::kDynamic) &&
3117              "mismatch between inferred shape and result shape");
3118       inferredShape.push_back(size);
3119     }
3120   }
3121 
3122   return RankedTensorType::get(inferredShape, sourceType.getElementType());
3123 }
3124 
3125 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3126                   Value source, ArrayRef<int64_t> staticLow,
3127                   ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
3128                   bool nofold, ArrayRef<NamedAttribute> attrs) {
3129   auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3130   if (!resultType)
3131     resultType = inferResultType(sourceType, staticLow, staticHigh);
3132   result.addAttributes(attrs);
3133   build(b, result, resultType, source, low, high,
3134         b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3135         nofold ? b.getUnitAttr() : UnitAttr());
3136 }
3137 
3138 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3139                   Value source, ValueRange low, ValueRange high, bool nofold,
3140                   ArrayRef<NamedAttribute> attrs) {
3141   auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3142   unsigned rank = sourceType.getRank();
3143   SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
3144   build(b, result, resultType, source, staticVector, staticVector, low, high,
3145         nofold, attrs);
3146 }
3147 
3148 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3149                   Value source, ArrayRef<OpFoldResult> low,
3150                   ArrayRef<OpFoldResult> high, bool nofold,
3151                   ArrayRef<NamedAttribute> attrs) {
3152   auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3153   SmallVector<Value, 4> dynamicLow, dynamicHigh;
3154   SmallVector<int64_t, 4> staticLow, staticHigh;
3155   // staticLow and staticHigh hold the full padding configuration. Each entry
3156   // grows staticLow and staticHigh by one value. If the entry is dynamic
3157   // (i.e., not a constant), dynamicLow and dynamicHigh grow by one value as
3158   // well.
3159   dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
3160   dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
3161   if (!resultType) {
3162     resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
3163   }
3164   assert(llvm::isa<RankedTensorType>(resultType));
3165   result.addAttributes(attrs);
3166   build(b, result, resultType, source, dynamicLow, dynamicHigh,
3167         b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3168         nofold ? b.getUnitAttr() : UnitAttr());
3169 }
3170 
3171 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3172                   Value source, ArrayRef<OpFoldResult> low,
3173                   ArrayRef<OpFoldResult> high, Value constantPadValue,
3174                   bool nofold, ArrayRef<NamedAttribute> attrs) {
3175   build(b, result, resultType, source, low, high, nofold, attrs);
3176 
3177   // Add a region and a block to yield the pad value.
3178   Region *region = result.regions[0].get();
3179   int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
3180   SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
3181   SmallVector<Location> blockArgLocs(sourceRank, result.location);
3182 
3183   // `builder.createBlock` changes the insertion point within the block. Create
3184   // a guard to restore the insertion point once the guard is destroyed.
3185   OpBuilder::InsertionGuard guard(b);
3186   b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
3187   b.create<tensor::YieldOp>(result.location, constantPadValue);
3188 }
3189 
3190 llvm::SmallBitVector PadOp::getPaddedDims() {
3191   llvm::SmallBitVector paddedDims(getSourceType().getRank());
3192   auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
3193     for (const auto &en : enumerate(paddingWidths))
3194       if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
3195         paddedDims.set(en.index());
3196   };
3197   extractPaddedDims(getMixedLowPad());
3198   extractPaddedDims(getMixedHighPad());
3199   return paddedDims;
3200 }
3201 
3202 namespace {
3203 // Folds tensor.pad when the padding is all static zeros and the `nofold`
3204 // attribute is not set.
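     //
     // For illustration (hypothetical IR):
     //
     //   %p = tensor.pad %t low[0, 0] high[0, 0] {
     //       tensor.yield %cst
     //     } : tensor<4x8xf32> to tensor<4x8xf32>
     //
     // folds to a tensor.cast of %t.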
3205 struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
3206   using OpRewritePattern<PadOp>::OpRewritePattern;
3207 
3208   LogicalResult matchAndRewrite(PadOp padTensorOp,
3209                                 PatternRewriter &rewriter) const override {
3210     if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
3211       return failure();
3212     if (padTensorOp.getNofold())
3213       return failure();
3214     rewriter.replaceOpWithNewOp<tensor::CastOp>(
3215         padTensorOp, padTensorOp.getResult().getType(),
3216         padTensorOp.getSource());
3217     return success();
3218   }
3219 };
3220 
3221 // Fold CastOp into PadOp when adding static information.
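     //
     // For illustration (hypothetical IR):
     //
     //   %c = tensor.cast %t : tensor<4x8xf32> to tensor<?x?xf32>
     //   %p = tensor.pad %c low[0, 0] high[1, 1] {
     //       tensor.yield %cst
     //     } : tensor<?x?xf32> to tensor<?x?xf32>
     //
     // becomes a pad of %t with the more static inferred result type, followed
     // by a tensor.cast back to the original result type when the types differ.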
3222 struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
3223   using OpRewritePattern<PadOp>::OpRewritePattern;
3224 
3225   LogicalResult matchAndRewrite(PadOp padTensorOp,
3226                                 PatternRewriter &rewriter) const override {
3227     auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
3228     if (!tensor::canFoldIntoConsumerOp(castOp))
3229       return failure();
3230 
3231     auto newResultType = PadOp::inferResultType(
3232         llvm::cast<RankedTensorType>(castOp.getSource().getType()),
3233         padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3234         padTensorOp.getResultType().getShape());
3235 
3236     if (newResultType == padTensorOp.getResultType()) {
3237       rewriter.modifyOpInPlace(padTensorOp, [&]() {
3238         padTensorOp.getSourceMutable().assign(castOp.getSource());
3239       });
3240     } else {
3241       auto newOp = rewriter.create<PadOp>(
3242           padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
3243           padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3244           padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
3245           getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3246       IRMapping mapper;
3247       padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3248 
3249       rewriter.replaceOpWithNewOp<tensor::CastOp>(
3250           padTensorOp, padTensorOp.getResultType(), newOp);
3251     }
3252     return success();
3253   }
3254 };
3255 
3256 // Fold CastOp using the result of PadOp back into the latter if it adds
3257 // static information.
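     //
     // For illustration (hypothetical IR):
     //
     //   %p = tensor.pad %t low[0] high[2] {
     //       tensor.yield %cst
     //     } : tensor<?xf32> to tensor<?xf32>
     //   %c = tensor.cast %p : tensor<?xf32> to tensor<8xf32>
     //
     // is rewritten so the pad directly yields tensor<8xf32> and the cast is
     // removed (the pad result must have a single use).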
3258 struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
3259   using OpRewritePattern<PadOp>::OpRewritePattern;
3260 
3261   LogicalResult matchAndRewrite(PadOp padTensorOp,
3262                                 PatternRewriter &rewriter) const override {
3263     if (!padTensorOp.getResult().hasOneUse())
3264       return failure();
3265     auto tensorCastOp =
3266         dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
3267     if (!tensorCastOp)
3268       return failure();
3269     if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
3270                                             tensorCastOp.getDest().getType()))
3271       return failure();
3272 
3273     auto replacementOp = rewriter.create<PadOp>(
3274         padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
3275         padTensorOp.getSource(), padTensorOp.getStaticLow(),
3276         padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3277         padTensorOp.getHigh(), padTensorOp.getNofold(),
3278         getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3279     replacementOp.getRegion().takeBody(padTensorOp.getRegion());
3280 
3281     rewriter.replaceOp(padTensorOp, replacementOp.getResult());
3282     rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
3283     return success();
3284   }
3285 };
3286 
3287 /// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
3288 /// different dimensions. The pattern applies if the following preconditions
3289 /// hold:
3290 ///   1) the tensor::ExtractSliceOps are not rank-reducing,
3291 ///   2) the tensor::ExtractSliceOps have only unit-strides,
3292 ///   3) the tensor::PadOps perform only high-padding,
3293 ///   4) the tensor::PadOps have the same constant padding value,
3294 ///   5) the tensor::PadOps do not have common padding dimensions,
3295 ///   6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
3296 ///      zero-offset for every dimension.
3297 ///   7) the tensor::ExtractSliceOp sizes match the source tensor sizes for
3298 ///      the padded source dimensions.
3300 ///
3301 /// Example:
3302 ///
3303 /// ```mlir
3304 ///   %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
3305 ///       : tensor<64x64xf32> to tensor<?x64xf32>
3306 ///   %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
3307 ///     } : tensor<?x64xf32> to tensor<8x64xf32>
3308 ///   %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
3309 ///        : tensor<8x64xf32> to tensor<8x?xf32>
3310 ///   %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
3311 ///     } : tensor<8x?xf32> to tensor<8x4xf32>
3312 /// ```
3313 ///
3314 /// folds into:
3315 ///
3316 /// ```mlir
3317 ///   %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
3318 ///        : tensor<64x64xf32> to tensor<?x?xf32>
3319 ///   %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
3320 ///     } : tensor<?x?xf32> to tensor<8x4xf32>
3321 /// ```
3322 struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
3323   using OpRewritePattern<PadOp>::OpRewritePattern;
3324 
3325   LogicalResult matchAndRewrite(PadOp padOp,
3326                                 PatternRewriter &rewriter) const override {
3327     auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
3328     if (!innerSliceOp)
3329       return failure();
3330     auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
3331     if (!outerPadOp || outerPadOp.getNofold())
3332       return failure();
3333     auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
3334     if (!outerSliceOp)
3335       return failure();
3336 
3337     // 1) Fail if the chain is rank-reducing.
3338     int64_t rank = padOp.getSourceType().getRank();
3339     if (outerSliceOp.getSourceType().getRank() != rank) {
3340       return rewriter.notifyMatchFailure(padOp,
3341                                          "cannot fold rank-reducing chain");
3342     }
3343 
3344     // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
3345     if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
3346       return rewriter.notifyMatchFailure(
3347           padOp, "cannot fold non-unit stride ExtractSliceOps");
3348     }
3349 
3350     // 3) Fail if the tensor::PadOps have non-zero low padding.
3351     if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
3352       return rewriter.notifyMatchFailure(padOp,
3353                                          "cannot fold PadOps with low padding");
3354     }
3355 
3356     // 4) Fail if the tensor::PadOps padding values do not match.
3357     Attribute innerAttr, outerAttr;
3358     Value innerValue = padOp.getConstantPaddingValue();
3359     Value outerValue = outerPadOp.getConstantPaddingValue();
3360     if (!innerValue || !outerValue ||
3361         !matchPattern(innerValue, m_Constant(&innerAttr)) ||
3362         !matchPattern(outerValue, m_Constant(&outerAttr)) ||
3363         innerAttr != outerAttr) {
3364       return rewriter.notifyMatchFailure(
3365           padOp, "cannot fold PadOps with different padding values");
3366     }
3367 
3368     // 5) Fail if a dimension is padded by both tensor::PadOps.
3369     llvm::SmallBitVector innerDims = padOp.getPaddedDims();
3370     llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
3371     if (innerDims.anyCommon(outerDims)) {
3372       return rewriter.notifyMatchFailure(
3373           padOp, "cannot fold PadOps with common padding dimensions");
3374     }
3375 
3376     // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
3377     // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3378     // for every dimension, and use the offset of the other pair. Fail if no
3379     // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3380     // exists.
3381     SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
3382     for (auto en : enumerate(newOffsets)) {
3383       OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
3384       OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
3385       if (!innerDims.test(en.index()) &&
3386           (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
3387         en.value() = outerOffset;
3388         continue;
3389       }
3390       if (!outerDims.test(en.index()) &&
3391           (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
3392         en.value() = innerOffset;
3393         continue;
3394       }
3395       return rewriter.notifyMatchFailure(
3396           padOp, "cannot find zero-offset and zero-padding pair");
3397     }
3398 
3399     // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size
3400     // of the outer tensor::ExtractSliceOp for the dimensions padded by the
3401     // outer tensor::PadOp and fail if the size of the inner
3402     // tensor::ExtractSliceOp does not match the size of the padded dimension.
3403     // Otherwise, take the size of the inner tensor::ExtractSliceOp.
3404     SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
3405     for (auto en : enumerate(newSizes)) {
3406       if (!outerDims.test(en.index()))
3407         continue;
3408       OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
3409       int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
3410       assert(!ShapedType::isDynamic(sourceSize) &&
3411              "expected padded dimension to have a static size");
3412       if (getConstantIntValue(sliceSize) != sourceSize) {
3413         return rewriter.notifyMatchFailure(
3414             padOp, "cannot fold since the inner ExtractSliceOp size does not "
3415                    "match the size of the outer padding");
3416       }
3417       en.value() = outerSliceOp.getMixedSizes()[en.index()];
3418     }
3419 
3420     // Combine the high paddings of the two tensor::PadOps.
3421     SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
3422     for (auto en : enumerate(newHighPad)) {
3423       if (innerDims.test(en.index()))
3424         newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
3425       if (outerDims.test(en.index()))
3426         newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
3427     }
3428 
3429     // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
3430     // the two paddings in one step.
3431     auto newSliceOp = rewriter.create<ExtractSliceOp>(
3432         padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
3433         innerSliceOp.getMixedStrides());
3434     auto newPadOp = rewriter.create<PadOp>(
3435         padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
3436         padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
3437         getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
3438     rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3439                                 newPadOp.getRegion().begin());
3440     rewriter.replaceOp(padOp, newPadOp.getResult());
3441     return success();
3442   }
3443 };
3444 
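     // Folds constant values in the `low` and `high` pad amount operands into the
     // static attributes, computing a more static result type when possible and
     // casting back to the original result type.
     //
     // For illustration (hypothetical IR, with %c2 = arith.constant 2 : index):
     //
     //   %p = tensor.pad %t low[0] high[%c2] {
     //       tensor.yield %cst
     //     } : tensor<4xf32> to tensor<?xf32>
     //
     // becomes a pad with static high[2] producing tensor<6xf32>, followed by a
     // tensor.cast back to tensor<?xf32>.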
3445 struct FoldStaticPadding : public OpRewritePattern<PadOp> {
3446   using OpRewritePattern<PadOp>::OpRewritePattern;
3447 
3448   LogicalResult matchAndRewrite(PadOp padTensorOp,
3449                                 PatternRewriter &rewriter) const override {
3450     Value input = padTensorOp.getSource();
3451     if (!llvm::isa<RankedTensorType>(input.getType()))
3452       return failure();
3453     auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
3454     auto inputRank = inputDims.size();
3455 
3456     auto oldResultType =
3457         dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
3458     if (!oldResultType)
3459       return failure();
3460 
3461     auto outputDims = oldResultType.getShape();
3462 
3463     // Extract the static info from the high and low operands.
3464     SmallVector<int64_t> constOperandsLow;
3465     SmallVector<Value> newLows;
3466     for (auto operand : padTensorOp.getLow()) {
3467       APSInt intOp;
3468       if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3469         constOperandsLow.push_back(ShapedType::kDynamic);
3470         newLows.push_back(operand);
3471         continue;
3472       }
3473       constOperandsLow.push_back(intOp.getExtValue());
3474     }
3475     SmallVector<int64_t> constOperandsHigh;
3476     SmallVector<Value> newHighs;
3477     for (auto operand : padTensorOp.getHigh()) {
3478       APSInt intOp;
3479       if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3480         constOperandsHigh.push_back(ShapedType::kDynamic);
3481         newHighs.push_back(operand);
3482         continue;
3483       }
3484       constOperandsHigh.push_back(intOp.getExtValue());
3485     }
3486 
3487     SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
3488     SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());
3489 
3490     // Verify the op is well-formed.
3491     if (inputDims.size() != outputDims.size() ||
3492         inputDims.size() != constLow.size() ||
3493         inputDims.size() != constHigh.size())
3494       return failure();
3495 
3496     auto lowCount = 0;
3497     auto highCount = 0;
3498     for (size_t i = 0; i < inputRank; i++) {
3499       if (constLow[i] == ShapedType::kDynamic)
3500         constLow[i] = constOperandsLow[lowCount++];
3501       if (constHigh[i] == ShapedType::kDynamic)
3502         constHigh[i] = constOperandsHigh[highCount++];
3503     }
3504 
3505     auto staticLow = ArrayRef<int64_t>(constLow);
3506     auto staticHigh = ArrayRef<int64_t>(constHigh);
3507 
3508     // Calculate the output sizes with the static information.
3509     SmallVector<int64_t> newOutDims;
3510     for (size_t i = 0; i < inputRank; i++) {
3511       if (outputDims[i] == ShapedType::kDynamic) {
3512         newOutDims.push_back(
3513             (staticLow[i] == ShapedType::kDynamic ||
3514                      staticHigh[i] == ShapedType::kDynamic ||
3515                      inputDims[i] == ShapedType::kDynamic
3516                  ? ShapedType::kDynamic
3517                  : inputDims[i] + staticLow[i] + staticHigh[i]));
3518       } else {
3519         newOutDims.push_back(outputDims[i]);
3520       }
3521     }
3522 
3523     if (SmallVector<int64_t>(outputDims) == newOutDims ||
3524         llvm::all_of(newOutDims,
3525                      [&](int64_t x) { return x == ShapedType::kDynamic; }))
3526       return failure();
3527 
3528     // Rewrite the op using the new static type.
3529     auto newResultType = RankedTensorType::get(
3530         newOutDims, padTensorOp.getType().getElementType());
3531     auto newOp = rewriter.create<PadOp>(
3532         padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
3533         newLows, newHighs, padTensorOp.getNofold(),
3534         getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3535 
3536     IRMapping mapper;
3537     padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3538     rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
3539                                                 newOp);
3540 
3541     return success();
3542   }
3543 };
3544 
3545 /// Folds a chain of `tensor.pad` ops with the same constant padding value.
3546 ///
3547 /// Example:
3548 ///
3549 /// ```mlir
3550 ///   %1 = tensor.pad %0 low[0, 1] high[0, 2] {
3551 ///       tensor.yield %val
3552 ///     } : tensor<1x2xf32> to tensor<2x5xf32>
3553 ///   %res = tensor.pad %1 low[0, 2] high[3, 0] {
3554 ///       tensor.yield %val
3555 ///     } : tensor<1x5xf32> to tensor<5x7xf32>
3556 /// ```
3557 ///
3558 /// folds into:
3559 ///
3560 /// ```mlir
3561 ///   %res = tensor.pad %0 low[0, 3] high[3, 2] {
3562 ///       tensor.yield %val
3563 ///     } : tensor<1x2xf32> to tensor<5x7xf32>
3564 /// ```
3565 struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
3566   using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
3567 
3568   LogicalResult matchAndRewrite(tensor::PadOp padOp,
3569                                 PatternRewriter &rewriter) const override {
3570     if (padOp.getNofold()) {
3571       return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
3572     }
3573 
3574     auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
3575     if (!producerPad || producerPad.getNofold()) {
3576       return rewriter.notifyMatchFailure(
3577           padOp, "producer is not a foldable tensor.pad op");
3578     }
3579 
3580     // Fail if the tensor::PadOps padding values do not match.
3581     Value consumerPadValue = padOp.getConstantPaddingValue();
3582     Value producerPadValue = producerPad.getConstantPaddingValue();
3583     if (!consumerPadValue || !producerPadValue ||
3584         consumerPadValue != producerPadValue) {
3585       return rewriter.notifyMatchFailure(
3586           padOp,
3587           "cannot fold PadOps with different or non-constant padding values");
3588     }
3589 
3590     Location loc = padOp.getLoc();
3591     AffineExpr d0, d1;
3592     bindDims(rewriter.getContext(), d0, d1);
3593 
3594     // Combine the low/high paddings of the two tensor::PadOps.
3595     auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
3596                            ArrayRef<OpFoldResult> producerPaddings) {
3597       SmallVector<OpFoldResult> sumPaddings;
3598       for (auto [consumerIndex, producerIndex] :
3599            llvm::zip_equal(consumerPaddings, producerPaddings)) {
3600         sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
3601             rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
3602       }
3603       return sumPaddings;
3604     };
3605 
3606     SmallVector<OpFoldResult> newHighPad =
3607         addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
3608     SmallVector<OpFoldResult> newLowPad =
3609         addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());
3610 
3611     auto newPadOp = rewriter.create<tensor::PadOp>(
3612         padOp.getLoc(), padOp.getResultType(), producerPad.getSource(),
3613         newLowPad, newHighPad, padOp.getNofold(),
3614         getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
3615     rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3616                                 newPadOp.getRegion().begin());
3617     rewriter.replaceOp(padOp, newPadOp.getResult());
3618     return success();
3619   }
3620 };
3621 
3622 } // namespace
3623 
3624 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3625                                         MLIRContext *context) {
3626   results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3627               FoldOrthogonalPaddings, FoldStaticPadding,
3628               FoldConsecutiveConstantPadding>(context);
3629 }
3630 
3631 /// Return the padding value of the PadOp if it is constant. In this context,
3632 /// "constant" means an actual constant or "defined outside of the block".
3633 ///
3634 /// Values are considered constant in three cases:
3635 ///  - A ConstantLike value.
3636 ///  - A basic block argument from a different block.
3637 ///  - A value defined outside of the block.
3638 ///
3639 /// If the padding value is not constant, an empty Value is returned.
3640 Value PadOp::getConstantPaddingValue() {
3641   auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3642   if (!yieldOp)
3643     return {};
3644   Value padValue = yieldOp.getValue();
3645   // Check if yield value is a constant.
3646   if (matchPattern(padValue, m_Constant()))
3647     return padValue;
3648   // Check if yield value is defined inside the PadOp block.
3649   if (padValue.getParentBlock() == &getRegion().front())
3650     return {};
3651   // Else: Yield value defined outside of the PadOp block.
3652   return padValue;
3653 }
3654 
3655 OpFoldResult PadOp::fold(FoldAdaptor) {
3656   if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3657       !getNofold())
3658     return getSource();
3659   return {};
3660 }
3661 
3662 //===----------------------------------------------------------------------===//
3663 // ParallelInsertSliceOp
3664 //===----------------------------------------------------------------------===//
3665 
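     /// Return the result of the parent ParallelCombiningOpInterface op that is
     /// tied to this op, i.e. the parent result at this op's position among the
     /// yielding ops.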
3666 OpResult ParallelInsertSliceOp::getTiedOpResult() {
3667   ParallelCombiningOpInterface parallelCombiningParent =
3668       getParallelCombiningParent();
3669   for (const auto &it :
3670        llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
3671     Operation &nextOp = it.value();
3672     if (&nextOp == getOperation())
3673       return parallelCombiningParent.getParentResult(it.index());
3674   }
3675   llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
3676 }
3677 
3678 // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
3679 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3680                                   Value source, Value dest,
3681                                   ArrayRef<OpFoldResult> offsets,
3682                                   ArrayRef<OpFoldResult> sizes,
3683                                   ArrayRef<OpFoldResult> strides,
3684                                   ArrayRef<NamedAttribute> attrs) {
3685   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3686   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3687   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3688   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3689   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3690   result.addAttributes(attrs);
3691   build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
3692         dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
3693         b.getDenseI64ArrayAttr(staticSizes),
3694         b.getDenseI64ArrayAttr(staticStrides));
3695 }
3696 
3697 /// Build a ParallelInsertSliceOp with mixed static and dynamic entries
3698 /// packed into a Range vector.
3699 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3700                                   Value source, Value dest,
3701                                   ArrayRef<Range> ranges,
3702                                   ArrayRef<NamedAttribute> attrs) {
3703   auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
3704   build(b, result, source, dest, offsets, sizes, strides, attrs);
3705 }
3706 
3707 // Build a ParallelInsertSliceOp with dynamic entries.
3708 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3709                                   Value source, Value dest, ValueRange offsets,
3710                                   ValueRange sizes, ValueRange strides,
3711                                   ArrayRef<NamedAttribute> attrs) {
3712   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3713       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
3714   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
3715       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
3716   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3717       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
3718   build(b, result, source, dest, offsetValues, sizeValues, strideValues);
3719 }
3720 
3721 LogicalResult ParallelInsertSliceOp::verify() {
3722   if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
3723     return this->emitError("expected ParallelCombiningOpInterface parent, got:")
3724            << *(getOperation()->getParentOp());
3725 
3726   RankedTensorType expectedType;
3727   SliceVerificationResult result =
3728       verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
3729                           getStaticSizes(), getStaticStrides(), &expectedType);
3730   return produceSliceErrorMsg(result, *this, expectedType);
3731 }
3732 
3733 void ParallelInsertSliceOp::getCanonicalizationPatterns(
3734     RewritePatternSet &results, MLIRContext *context) {
3735   results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
3736               InsertSliceOpCastFolder<ParallelInsertSliceOp>,
3737               InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
3738 }
3739 
3740 llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
3741   return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3742 }
3743 
3744 //===----------------------------------------------------------------------===//
3745 // ScatterOp
3746 //===----------------------------------------------------------------------===//
3747 
3748 void ScatterOp::getAsmResultNames(
3749     function_ref<void(Value, StringRef)> setNameFn) {
3750   setNameFn(getResult(), "scatter");
3751 }
3752 
3753 LogicalResult ScatterOp::verify() {
3754   int64_t destRank = getDestType().getRank();
3755   ArrayRef<int64_t> scatterDims = getScatterDims();
3756   if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
3757                                        getIndicesType().getShape(), destRank,
3758                                        "scatter", "dest")))
3759     return failure();
3760 
3761   if (!getUnique())
3762     return emitOpError("requires 'unique' attribute to be set");
3763   // TODO: we could also check statically that there are fewer leading index
3764   // tensor dims than the dest dims. If this is not the case, the unique
3765   // attribute cannot be true.
3766 
3767   // Use the GatherOp::inferResultType on the `dest` type and verify the
3768   // expected type matches the source type.
3769   RankedTensorType expectedSourceType = GatherOp::inferResultType(
3770       getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
3771   RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
3772       getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
3773   if (getSourceType() != expectedSourceType &&
3774       getSourceType() != expectedRankReducedSourceType) {
3775     return emitOpError("source type mismatch: expected ")
3778            << expectedSourceType << " or its rank-reduced variant "
3779            << expectedRankReducedSourceType << " (got: " << getSourceType()
3780            << ")";
3781   }
3782 
3783   return success();
3784 }
3785 
3786 //===----------------------------------------------------------------------===//
3787 // SplatOp
3788 //===----------------------------------------------------------------------===//
3789 
3790 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3791                     Type aggregateType, ValueRange dynamicSizes) {
3792   build(builder, result, aggregateType, element, dynamicSizes);
3793 }
3794 
3795 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3796                     ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
3797   auto aggregateType = RankedTensorType::get(staticShape, element.getType());
3798   build(builder, result, aggregateType, element, dynamicSizes);
3799 }
3800 
3801 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3802                     ArrayRef<OpFoldResult> sizes) {
3803   SmallVector<int64_t> staticShape;
3804   SmallVector<Value> dynamicSizes;
3805   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
3806   build(builder, result, element, staticShape, dynamicSizes);
3807 }
3808 
3809 void SplatOp::getAsmResultNames(
3810     function_ref<void(Value, StringRef)> setNameFn) {
3811   setNameFn(getResult(), "splat");
3812 }
3813 
3814 LogicalResult SplatOp::verify() {
3815   if (getType().getNumDynamicDims() != getDynamicSizes().size())
3816     return emitOpError("incorrect number of dynamic sizes, has ")
3817            << getDynamicSizes().size() << ", expected "
3818            << getType().getNumDynamicDims();
3819   return success();
3820 }
3821 
3822 LogicalResult
3823 SplatOp::reifyResultShapes(OpBuilder &builder,
3824                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3825   reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
3826   unsigned ctr = 0;
3827   for (int64_t i = 0; i < getType().getRank(); ++i) {
3828     if (getType().isDynamicDim(i)) {
3829       reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
3830     } else {
3831       reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
3832     }
3833   }
3834   return success();
3835 }
3836 
3837 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
3838   auto constOperand = adaptor.getInput();
3839   if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
3840     return {};
3841 
3842   // Do not fold if the splat is not statically shaped.
3843   if (!getType().hasStaticShape())
3844     return {};
3845 
3846   // SplatElementsAttr::get treats a single value for the second arg as being
3847   // a splat.
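       // For example, splatting the constant 1.0 into a static tensor<2x3xf32>
       // folds to dense<1.000000e+00> : tensor<2x3xf32>.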
3848   return SplatElementsAttr::get(getType(), {constOperand});
3849 }
3850 
3851 //===----------------------------------------------------------------------===//
3852 // PackOp/UnPackOp Common
3853 //===----------------------------------------------------------------------===//
3854 
3855 template <typename OpTy>
3856 static LogicalResult
3857 reifyResultShapesImpl(OpTy op, OpBuilder &builder,
3858                       ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3859   static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3860                 "applies to only pack or unpack operations");
3861   int64_t destRank = op.getDestRank();
3862   reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
3863   reifiedReturnShapes[0] =
3864       tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
3865   return success();
3866 }
3867 
3868 template <typename OpTy>
3869 static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
3870   static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3871                 "applies to only pack or unpack operations");
3872   DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
3873   ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
3874   SmallVector<OpFoldResult> tiles = op.getMixedTiles();
3875   assert(tiles.size() == dimsToTile.size() &&
3876          "tiles must match indices of dimension to block");
3877   // Bind the dimension `i` to its tile factor.
3878   for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
3879     dimAndTileMapping[dimsToTile[i]] = tiles[i];
3880   return dimAndTileMapping;
3881 }
3882 
3883 template <typename OpTy>
3884 static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
3885   static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3886                 "applies to only pack or unpack operations");
3887   Builder builder(op);
3888   SmallVector<OpFoldResult> mixedInnerTiles;
3889   unsigned dynamicValIndex = 0;
3890   for (int64_t staticTile : op.getStaticInnerTiles()) {
3891     if (!ShapedType::isDynamic(staticTile))
3892       mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
3893     else
3894       mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
3895   }
3896   return mixedInnerTiles;
3897 }
3898 
3899 template <typename OpTy>
3900 static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
3901   static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3902                 "applies to only pack or unpack operations");
3903   SmallVector<Value> dynamicTiles;
3904   SmallVector<int64_t> staticTiles;
3905   dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
3906   return staticTiles;
3907 }
3908 
3909 /// Returns true if `dimsPos` is invalid. It is invalid when:
3910 /// a) It contains duplicates.
3911 /// b) At least one dimension is out of bounds, i.e., not in [0, `rank`).
3912 /// c) The number of elements in `dimsPos` is greater than `rank`.
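     /// For example, with `rank` = 3: [0, 0] is invalid (duplicate), [0, 3] is
     /// invalid (out of bounds), [0, 1, 2, 0] is invalid (too many entries),
     /// while [2, 0] is valid.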
3913 static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
3914                                              size_t rank) {
3915   size_t dimsPosSize = dimsPos.size();
3916   if (dimsPosSize > rank)
3917     return true;
3918   DenseSet<int64_t> uniqued;
3919   for (int64_t dim : dimsPos)
3920     uniqued.insert(dim);
3921   if (dimsPosSize != uniqued.size())
3922     return true;
3923   return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
3924     return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
3925   });
3926 }
3927 
3928 /// Returns true if each dimension of `sourceShape` is no larger than the
3929 /// corresponding dimension of `limitShape`. Dynamic dimensions always pass.
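     /// For example, [4, ?, 8] is in bound w.r.t. [4, 10, 16], while [5, 2] is
     /// not in bound w.r.t. [4, 2].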
3930 static bool areAllInBound(ArrayRef<int64_t> sourceShape,
3931                           ArrayRef<int64_t> limitShape) {
3932   assert(
3933       sourceShape.size() == limitShape.size() &&
3934       "expected source shape and limit shape to have the same rank");
3935   return llvm::all_of(
3936       llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
3937         int64_t sourceExtent = std::get<0>(it);
3938         int64_t limit = std::get<1>(it);
3939         return ShapedType::isDynamic(sourceExtent) ||
3940                ShapedType::isDynamic(limit) || sourceExtent <= limit;
3941       });
3942 }
3943 
3944 template <typename OpTy>
3945 static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
3946   static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3947                 "applies to only pack or unpack operations");
3948   Operation *op = packOrUnPack.getOperation();
3949 
3950   // Return true if we have a zero-value tile.
3951   auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
3952     return llvm::any_of(
3953         tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
3954   };
3955 
3956   // Verify tiles. Do not allow zero tiles.
3957   SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
3958   if (hasZeros(mixedTiles))
3959     return op->emitError("invalid zero tile factor");
3960 
3961   // Verify inner_dims_pos and outer_dims_perm.
3962   RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
3963                                       ? packOrUnPack.getSourceType()
3964                                       : packOrUnPack.getDestType();
3965   size_t unpackedRank = unpackedType.getRank();
3966   ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
3967   ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
3968   if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
3969     return op->emitError("invalid inner_dims_pos vector");
3970   if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
3971     return op->emitError("invalid outer_dims_perm vector");
3972   if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
3973     return op->emitError("outer_dims_perm must be a permutation or empty");
3974 
3975   // The number of tiling factors must not exceed the input rank for pack (or
3976   // the output rank for unpack), and must match the number of `inner_dims_pos`.
3977   if (mixedTiles.size() > unpackedRank) {
3978     return op->emitError("tiling factors must be less than or equal to the "
3979                          "input rank for pack or output rank for unpack");
3980   }
3981   if (mixedTiles.size() != innerDimsPos.size()) {
3982     return op->emitError(
3983         "tiling factors must equal the number of dimensions to tile");
3984   }
3985 
3986   ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
3987                               ? packOrUnPack.getDestType()
3988                               : packOrUnPack.getSourceType();
3989   size_t packedRank = packedType.getRank();
3990   // Require output rank to match input rank + number of blocking factors.
3991   size_t expectedPackedRank = unpackedRank + mixedTiles.size();
3992   if (expectedPackedRank != packedRank) {
3993     return op->emitError(
3994                "packed rank != (unpacked rank + num tiling factors), got ")
3995            << packedRank << " != " << expectedPackedRank;
3996   }
3997 
3998   // Verify that the result shape is at least as large as the minimum shape
3999   // expected by the pack operation, and that the output shape represents
4000   // full tiles.
4001   RankedTensorType expectedPackedType = PackOp::inferPackedType(
4002       unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
4003   if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
4004     return op->emitError("the shape of output is not large enough to hold the "
4005                          "packed data. Expected at least ")
4006            << expectedPackedType << ", got " << packedType;
4007   }
4008   if (!llvm::all_of(
4009           llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
4010                     mixedTiles),
4011           [](std::tuple<int64_t, OpFoldResult> it) {
4012             int64_t shape = std::get<0>(it);
4013             if (Attribute attr =
4014                     llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
4015               IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
4016               int64_t staticTileSize = intAttr.getValue().getSExtValue();
4017               return shape == staticTileSize;
4018             }
4019             return ShapedType::isDynamic(shape);
4020           })) {
4021     return op->emitError("mismatch between the specified inner tile sizes "
4022                          "and the shapes of the tiled dimensions in the packed type");
4023   }
4024   return success();
4025 }
4026 
4027 namespace {
4028 /// Subset of PackOp/UnPackOp fields used to compute the result of applying
4029 /// various permutations to the op.
4030 // TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
4031 // these. These may or may not become true foldings / canonicalizations
4032 // depending on how aggressive we want to be in automatically folding
4033 // transposes.
4034 struct PackOrUnPackTransposeResult {
4035   SmallVector<int64_t> innerDimsPos;
4036   SmallVector<OpFoldResult> innerTiles;
4037   SmallVector<int64_t> outerDimsPerm;
4038 };
4039 } // namespace
4040 
4041 template <typename OpTy>
4042 static PackOrUnPackTransposeResult
4043 commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
4044                                    ArrayRef<int64_t> innerPermutation,
4045                                    ArrayRef<int64_t> outerPermutation) {
4046   static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4047                 "applies to only pack or unpack operations");
4048   assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
4049          "some permutation must be non-empty");
4050   PackOrUnPackTransposeResult metadata;
4051   metadata.innerDimsPos =
4052       SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
4053   metadata.innerTiles =
4054       SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
4055   int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
4056                              ? packOrUnPackOp.getSourceRank()
4057                              : packOrUnPackOp.getDestRank();
4058   metadata.outerDimsPerm =
4059       packOrUnPackOp.getOuterDimsPerm().empty()
4060           ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
4061           : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
4062   if (!innerPermutation.empty()) {
4063     assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
4064            isPermutationVector(innerPermutation) &&
4065            "invalid inner permutation");
4066     applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
4067     applyPermutationToVector(metadata.innerTiles, innerPermutation);
4068   }
4069   if (!outerPermutation.empty()) {
4070     assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
4071            isPermutationVector(outerPermutation) &&
4072            "invalid outer permutation");
4073     applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
4074   }
4075   return metadata;
4076 }
4077 
4078 //===----------------------------------------------------------------------===//
4079 // PackOp
4080 //===----------------------------------------------------------------------===//
4081 
4082 void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
4083   setNameFn(getResult(), "pack");
4084 }
4085 
4086 void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
4087                    Value dest, ArrayRef<int64_t> innerDimsPos,
4088                    ArrayRef<OpFoldResult> innerTiles,
4089                    std::optional<Value> paddingValue,
4090                    ArrayRef<int64_t> outerDimsPerm) {
4091   assert(innerDimsPos.size() == innerTiles.size() &&
4092          "number of tile sizes specified must match the specified number of "
4093          "original dimensions to be tiled");
4094   SmallVector<int64_t> staticTileSizes;
4095   SmallVector<Value> dynamicTileSizes;
4096   dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
4097   build(builder, state, dest.getType(), source, dest,
4098         paddingValue ? *paddingValue : nullptr,
4099         outerDimsPerm.empty() ? nullptr
4100                               : builder.getDenseI64ArrayAttr(outerDimsPerm),
4101         builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
4102         builder.getDenseI64ArrayAttr(staticTileSizes));
4103 }
4104 
4105 LogicalResult
4106 PackOp::reifyResultShapes(OpBuilder &builder,
4107                           ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4108   return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
4109 }
4110 
4111 DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
4112   return getDimAndTileMappingImpl(*this);
4113 }
4114 
4115 SmallVector<OpFoldResult> PackOp::getMixedTiles() {
4116   return getMixedTilesImpl(*this);
4117 }
4118 
4119 SmallVector<int64_t> PackOp::getStaticTiles() {
4120   return getStaticTilesImpl(*this);
4121 }
4122 
4123 ArrayRef<int64_t> PackOp::getAllOuterDims() {
4124   ShapedType inputType = getSourceType();
4125   int64_t inputRank = inputType.getRank();
4126   return getDestType().getShape().take_front(inputRank);
4127 }
4128 
4129 SmallVector<int64_t> PackOp::getTiledOuterDims() {
4130   auto innerDimsPos = getInnerDimsPos();
4131   auto packedShape = getDestType().getShape();
4132   SmallVector<int64_t> res;
4133 
4134   for (auto index : innerDimsPos)
4135     res.push_back(packedShape[index]);
4136 
4137   return res;
4138 }
4139 
4140 bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
4141                                  ArrayRef<int64_t> innerDimsPos,
4142                                  ArrayRef<int64_t> outputShape,
4143                                  ArrayRef<int64_t> outerDimsPerm,
4144                                  ArrayRef<OpFoldResult> innerTiles) {
4145   SmallVector<int64_t> outputTileSizes(
4146       outputShape.take_front(inputShape.size()));
4147   if (!outerDimsPerm.empty()) {
4148     assert(outerDimsPerm.size() == outputTileSizes.size() &&
4149            "expected output and outer_dims_perm to have same size");
4150     applyPermutationToVector(outputTileSizes,
4151                              invertPermutationVector(outerDimsPerm));
4152   }
4153   for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
4154     if (ShapedType::isDynamic(inputShape[pos]))
4155       continue;
4156     std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
4157 
4158     if (!constantTile) {
4159       if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
4160           (inputShape[pos] % outputTileSizes[pos] != 0))
4161         return true;
4162     } else if (inputShape[pos] % (*constantTile) != 0) {
4163       return true;
4164     }
4165   }
4166   return false;
4167 }
4168 
4169 LogicalResult PackOp::verify() {
4170   if (failed(commonVerifierPackAndUnPackOp(*this)))
4171     return failure();
4172 
4173   // Verify padding value, and bail out if the tile does not divide the
4174   // dimension fully. In the case of dynamic tile factors or dimensions, having
4175   // a partial tile is undefined behavior.
4176   auto paddingValue = getPaddingValue();
4177   if (paddingValue &&
4178       paddingValue.getType() != getSourceType().getElementType()) {
4179     return emitOpError("expected padding_value has ")
4180            << getSourceType().getElementType()
4181            << " but got: " << paddingValue.getType();
4182   }
4183 
4184   if (!paddingValue &&
4185       requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
4186                           getDestType().getShape(), getOuterDimsPerm(),
4187                           getMixedTiles())) {
4188     return emitOpError(
4189         "invalid tile factor or output size provided. Only full tiles are "
4190         "supported when padding_value is not set");
4191   }
4192   return success();
4193 }
4194 
4195 /// Converts OpFoldResults to int64_t shape entries, unconditionally mapping all
4196 /// Values to kDynamic, even if they are arith.constant values.
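     /// For example, {Value %v, IntegerAttr 8} maps to {kDynamic, 8}, even if
     /// %v is defined by an arith.constant.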
4197 static SmallVector<int64_t>
4198 asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
4199   SmallVector<int64_t> result;
4200   for (auto o : ofrs) {
4201     // Have to do this first, as getConstantIntValue special-cases constants.
4202     if (llvm::dyn_cast_if_present<Value>(o))
4203       result.push_back(ShapedType::kDynamic);
4204     else
4205       result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
4206   }
4207   return result;
4208 }
4209 
4210 /// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
4211 /// the packed type. Sharing one helper ensures that the two methods agree on
4212 /// which dimensions are dynamic.
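     /// For example, sourceShape = [20, 16], innerTileSizes = [8, 16],
     /// innerDimsPos = [0, 1], outerDimsPerm = [1, 0] yields
     /// [ceil(16/16), ceil(20/8), 8, 16] = [1, 3, 8, 16].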
4213 static SmallVector<int64_t> getPackOpResultTypeShape(
4214     ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
4215     ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
4216   SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
4217   for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
4218     if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
4219       continue;
4220     if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
4221       resultShape[tiledDim.value()] = ShapedType::kDynamic;
4222       continue;
4223     }
4224     resultShape[tiledDim.value()] = divideCeilSigned(
4225         resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
4226   }
4227 
4228   // Swap tile loops if outer_dims_perm is available.
4229   if (!outerDimsPerm.empty())
4230     applyPermutationToVector(resultShape, outerDimsPerm);
4231 
4232   // Append the inner tile dimensions.
4233   resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
4234   return resultShape;
4235 }
4236 
4237 SmallVector<OpFoldResult> PackOp::getResultShape(
4238     OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
4239     ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
4240     ArrayRef<int64_t> outerDimsPerm) {
4241   SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
4242 
4243   AffineExpr s0, s1;
4244   bindSymbols(builder.getContext(), s0, s1);
4245   AffineExpr ceilDivExpr = s0.ceilDiv(s1);
4246   for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
4247     resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
4248         builder, loc, ceilDivExpr,
4249         {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
4250   }
4251   if (!outerDimsPerm.empty())
4252     applyPermutationToVector(resultDims, outerDimsPerm);
4253   resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
4254 
4255   SmallVector<int64_t> resultTypeShape =
4256       getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(sourceDims),
4257                                asShapeWithAnyValueAsDynamic(innerTileSizes),
4258                                innerDimsPos, outerDimsPerm);
4259 
4260   // Fix up `resultDims` to ensure that they are Values if and only if the
4261   // result type shape says the corresponding dim is dynamic. This is needed as
4262   // callers may use dispatchIndexOpFoldResults on the result and rely on the
4263   // exact number of dynamic dims it returns.
4264   for (unsigned i = 0; i < resultDims.size(); ++i) {
4265     if (!ShapedType::isDynamic(resultTypeShape[i]))
4266       continue;
4267     resultDims[i] =
4268         getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
4269   }
4270 
4271   return resultDims;
4272 }
4273 
4274 /// Get the expected packed type based on source type, tile factors, position of
4275 /// the inner tiles and permutation of the outer tiled loop.
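     /// For example, packing tensor<20x16xf32> with innerDimsPos = [0, 1],
     /// innerTileSizes = [8, 16] and an empty outerDimsPerm infers the packed
     /// type tensor<3x1x8x16xf32>.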
4276 RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
4277                                          ArrayRef<int64_t> innerTileSizes,
4278                                          ArrayRef<int64_t> innerDimsPos,
4279                                          ArrayRef<int64_t> outerDimsPerm) {
4280   SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
4281       sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
4282   return RankedTensorType::get(resultShape, sourceType.getElementType());
4283 }
4284 
4285 Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
4286                                       ArrayRef<OpFoldResult> innerTileSizes,
4287                                       ArrayRef<int64_t> innerDimsPos,
4288                                       ArrayRef<int64_t> outerDimsPerm) {
4289   AffineExpr dim0, dim1;
4290   bindDims(b.getContext(), dim0, dim1);
4291   auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
4292     return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
4293                                                  {v1, v2});
4294   };
4295 
4296   SmallVector<OpFoldResult> mixedSizes;
4297   for (auto [index, value] : llvm::enumerate(
4298            llvm::cast<RankedTensorType>(source.getType()).getShape())) {
4299     if (ShapedType::isDynamic(value))
4300       mixedSizes.push_back(b.create<DimOp>(loc, source, index).getResult());
4301     else
4302       mixedSizes.push_back(b.getIndexAttr(value));
4303   }
4304   for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
4305     int64_t dimPos = std::get<0>(it);
4306     OpFoldResult tileSize = std::get<1>(it);
4307     mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
4308   }
4309   if (!outerDimsPerm.empty())
4310     applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
4311 
4312   mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
4313   auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4314   return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
4315 }
4316 
4317 PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
4318                                      ArrayRef<int64_t> innerPermutation,
4319                                      ArrayRef<int64_t> outerPermutation) {
4320   PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
4321       *this, innerPermutation, outerPermutation);
4322   Value transposedDest =
4323       createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
4324                               metadata.innerDimsPos, metadata.outerDimsPerm);
4325   return b.create<PackOp>(loc, getSource(), transposedDest,
4326                           metadata.innerDimsPos, metadata.innerTiles,
4327                           getPaddingValue(), metadata.outerDimsPerm);
4328 }
4329 
4330 /// Returns true if the tiles and the tiled dims are constant.
4331 template <typename OpTy>
4332 bool areTilesAndTiledDimsAllConstant(OpTy op) {
4333   static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4334                 "applies to only pack or unpack operations");
4335   ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4336                               ? op.getDestType()
4337                               : op.getSourceType();
4338   SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
4339   for (auto [dimDest, tile] : llvm::zip(
4340            packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
4341     std::optional<int64_t> constTileSize = getConstantIntValue(tile);
4342     if (!constTileSize || ShapedType::isDynamic(dimDest))
4343       return false;
4344   }
4345   return true;
4346 }
4347 
4348 Speculation::Speculatability PackOp::getSpeculatability() {
4349   if (getPaddingValue())
4350     return Speculation::Speculatable;
4351 
4352   // The verifier already rejects operations where we can statically prove that
4353   // the tile sizes do not divide the dimension evenly; thus, only check that
4354   // the tiles and the tiled inner dimensions are constant.
4355   if (!areTilesAndTiledDimsAllConstant(*this))
4356     return Speculation::NotSpeculatable;
4357 
4358   return Speculation::Speculatable;
4359 }
4360 
4361 // Return true if `inner_dims_pos` and `outer_dims_perm` target the same
4362 // dimensions for pack and unpack.
4363 static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
4364   if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
4365     return false;
4366   if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
4367     return true;
4368   // Outer dims permutation is optional.
4369   // To compare unbalanced pack-unpack pair, treat no permutation as equal to
4370   // identity permutation.
4371   return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
4372          isIdentityPermutation(unPackOp.getOuterDimsPerm());
4373 }
4374 
4375 // Return true if pack and unpack have the same tiles.
4376 // Same SSA values or same integer constants.
4377 static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
4378   auto packTiles = packOp.getMixedTiles();
4379   auto unPackTiles = unPackOp.getMixedTiles();
4380   if (packTiles.size() != unPackTiles.size())
4381     return false;
4382   for (size_t i = 0, e = packTiles.size(); i < e; i++) {
4383     if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
4384       return false;
4385   }
4386   return true;
4387 }
4388 
4389 /// Returns true if the pack op does not need a padding value.
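     /// For example, a static source dim of size 8 tiled by 4 needs no padding,
     /// whereas a dim of size 6 tiled by 4 leaves a partial tile and does.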
4390 static bool paddingIsNotNeeded(PackOp op) {
4391   auto srcType = op.getSourceType();
4392   if (llvm::any_of(op.getInnerDimsPos(),
4393                    [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
4394     return false;
4395   if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
4396     return false;
4397   return !PackOp::requirePaddingValue(
4398       srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
4399       op.getOuterDimsPerm(), op.getMixedTiles());
4400 }
4401 
4402 /// Returns true if the `srcShape` or `destShape` is different from the one in
4403 /// `packOp` and populates each with the inferred static shape.
4404 static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
4405                              SmallVectorImpl<int64_t> &destShape) {
4406   bool changeNeeded = false;
4407   srcShape.assign(packOp.getSourceType().getShape().begin(),
4408                   packOp.getSourceType().getShape().end());
4409   destShape.assign(packOp.getDestType().getShape().begin(),
4410                    packOp.getDestType().getShape().end());
4411   llvm::SmallSetVector<int64_t, 4> innerDims;
4412   innerDims.insert(packOp.getInnerDimsPos().begin(),
4413                    packOp.getInnerDimsPos().end());
4414   SmallVector<int64_t> inverseOuterDimsPerm;
4415   if (!packOp.getOuterDimsPerm().empty())
4416     inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
4417   int srcRank = packOp.getSourceRank();
4418   for (auto i : llvm::seq<int64_t>(0, srcRank)) {
4419     if (innerDims.contains(i))
4420       continue;
4421     int64_t srcPos = i;
4422     int64_t destPos = i;
4423     if (!inverseOuterDimsPerm.empty())
4424       destPos = inverseOuterDimsPerm[srcPos];
4425     if (ShapedType::isDynamic(srcShape[srcPos]) ==
4426         ShapedType::isDynamic(destShape[destPos])) {
4427       continue;
4428     }
4429     int64_t size = srcShape[srcPos];
4430     if (ShapedType::isDynamic(size))
4431       size = destShape[destPos];
4432     srcShape[srcPos] = size;
4433     destShape[destPos] = size;
4434     changeNeeded = true;
4435   }
4436   return changeNeeded;
4437 }
4438 
4439 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
4440   // Fold a pack(unpack(x)) to x.
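       // For example (assuming matching inner_dims_pos, tiles, and no padding):
       //   %0 = tensor.unpack %x ... : tensor<4x2x8x8xf32> -> tensor<32x16xf32>
       //   %1 = tensor.pack %0 ... : tensor<32x16xf32> -> tensor<4x2x8x8xf32>
       // folds %1 to %x.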
4441   if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
4442     if (unPackOp.getSourceType() != packOp.getDestType())
4443       return failure();
4444     if (packOp.getPaddingValue() ||
4445         !hasSameInnerOuterAttribute(packOp, unPackOp) ||
4446         !haveSameTiles(packOp, unPackOp))
4447       return failure();
4448     rewriter.replaceOp(packOp, unPackOp.getSource());
4449     return success();
4450   }
4451 
4452   // Fold optional PaddingValue operand away if padding is not needed.
4453   if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
4454     rewriter.startOpModification(packOp);
4455     packOp.getPaddingValueMutable().clear();
4456     rewriter.finalizeOpModification(packOp);
4457     return success();
4458   }
4459 
4460   // Insert tensor.cast ops if static shape inference is available.
4461   SmallVector<int64_t> srcShape, destShape;
4462   if (inferStaticShape(packOp, srcShape, destShape)) {
4463     Location loc = packOp.getLoc();
4464     Value source = packOp.getSource();
4465     if (srcShape != packOp.getSourceType().getShape()) {
4466       auto newSrcType = packOp.getSourceType().clone(srcShape);
4467       source =
4468           rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
4469     }
4470     Value dest = packOp.getDest();
4471     RankedTensorType originalResultType = packOp.getDestType();
4472     bool needUpdateDestType = (destShape != originalResultType.getShape());
4473     if (needUpdateDestType) {
4474       auto newDestType = packOp.getDestType().clone(destShape);
4475       dest =
4476           rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
4477     }
4478     rewriter.modifyOpInPlace(packOp, [&] {
4479       packOp.getSourceMutable().assign(source);
4480       packOp.getDestMutable().assign(dest);
4481       packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
4482     });
4483     // Insert a cast if needed
4484     if (needUpdateDestType) {
4485       rewriter.setInsertionPointAfter(packOp);
4486       auto castOp =
4487           rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
4488       rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
4489     }
4490     return success();
4491   }
4492 
4493   return failure();
4494 }
4495 
4496 template <typename PackOrUnpackOp>
4497 static bool isLikePadUnPad(PackOrUnpackOp packOp,
4498                            RankedTensorType packedTensorType) {
4499   static_assert(std::is_same<PackOrUnpackOp, tensor::PackOp>::value ||
4500                     std::is_same<PackOrUnpackOp, tensor::UnPackOp>::value,
4501                 "Function meant for pack/unpack");
4502   // This is a pad if packing only adds ones and we don't transpose dimensions.
4503 
4504   // Check that we are not transposing any dimensions.
4505   ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
4506   int64_t numPackedDims = innerDimsPos.size();
4507   auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
4508   if (orderedDims != innerDimsPos) {
4509     // Dimensions don't happen in order.
4510     return false;
4511   }
4512 
4513   ArrayRef<int64_t> packedShape = packedTensorType.getShape();
4514   int64_t packedRank = packedTensorType.getRank();
4515   // At this point we know that we are taking numPackedDims outer
4516   // dimensions and pushing them all the way as the innermost dimensions.
4517   // What's left on the outermost dimensions is, in this order:
4518   // - the factors of the packed dimensions, then
4519   // - the untouched dimensions.
4520   // This shifting inward of dimensions is a no-op (as opposed to a transpose)
4521   // if all the dimensions that bubble outward are ones.
4522   // Therefore, check that all the dimensions but the numPackedDims innermost
4523   // ones are ones.
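       // For instance, packing tensor<8x16xf32> with inner_dims_pos = [0, 1] and
       // tiles [8, 16] into tensor<1x1x8x16xf32> only introduces leading unit
       // dimensions, so it acts like a pad (resp. unpad for unpack).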
4524   return llvm::all_of(
4525       llvm::seq<int64_t>(0, packedRank - numPackedDims),
4526       [&packedShape](int64_t i) { return packedShape[i] == 1; });
4527 }
4528 
4529 bool PackOp::isLikePad() {
4530   auto packedTensorType =
4531       llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
4532   return isLikePadUnPad(*this, packedTensorType);
4533 }
4534 
4535 OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
4536   std::optional<Attribute> paddingValue;
4537   if (auto pad = adaptor.getPaddingValue())
4538     paddingValue = pad;
4539   if (OpFoldResult reshapedSource = reshapeConstantSource(
4540           llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
4541           getDestType(), paddingValue))
4542     return reshapedSource;
4543   return {};
4544 }
4545 
4546 //===----------------------------------------------------------------------===//
4547 // UnPackOp
4548 //===----------------------------------------------------------------------===//
4549 
4550 void UnPackOp::getAsmResultNames(
4551     function_ref<void(Value, StringRef)> setNameFn) {
4552   setNameFn(getResult(), "unpack");
4553 }
4554 
4555 LogicalResult
4556 UnPackOp::reifyResultShapes(OpBuilder &builder,
4557                             ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4558   return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
4559 }
4560 
4561 DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
4562   return getDimAndTileMappingImpl(*this);
4563 }
4564 
4565 SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
4566   return getMixedTilesImpl(*this);
4567 }
4568 
4569 SmallVector<int64_t> UnPackOp::getStaticTiles() {
4570   return getStaticTilesImpl(*this);
4571 }
4572 
4573 ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
4574   ShapedType destType = getDestType();
4575   int64_t destRank = destType.getRank();
4576   return getSourceType().getShape().take_front(destRank);
4577 }
4578 
4579 SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
4580   auto innerDimsPos = getInnerDimsPos();
4581   auto packedShape = getSourceType().getShape();
4582   SmallVector<int64_t> res;
4583 
4584   for (auto index : innerDimsPos)
4585     res.push_back(packedShape[index]);
4586 
4587   return res;
4588 }
4589 
4590 LogicalResult UnPackOp::verify() {
4591   return commonVerifierPackAndUnPackOp(*this);
4592 }
4593 
4594 Speculation::Speculatability UnPackOp::getSpeculatability() {
4595   // See PackOp::getSpeculatability.
4596   if (!areTilesAndTiledDimsAllConstant(*this))
4597     return Speculation::NotSpeculatable;
4598 
4599   return Speculation::Speculatable;
4600 }
4601 
4602 void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
4603                      Value dest, ArrayRef<int64_t> innerDimsPos,
4604                      ArrayRef<OpFoldResult> innerTiles,
4605                      ArrayRef<int64_t> outerDimsPerm) {
4606   assert(innerDimsPos.size() == innerTiles.size() &&
4607          "number of tile sizes specified must match the specified number of "
4608          "original dimensions to be tiled");
4609   SmallVector<int64_t> staticTileSizes;
4610   SmallVector<Value> dynamicTileSizes;
4611   dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
4612   build(builder, state, dest.getType(), source, dest,
4613         outerDimsPerm.empty() ? nullptr
4614                               : builder.getDenseI64ArrayAttr(outerDimsPerm),
4615         builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
4616         builder.getDenseI64ArrayAttr(staticTileSizes));
4617 }
4618 
4619 Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
4620                                         Value source,
4621                                         ArrayRef<OpFoldResult> innerTileSizes,
4622                                         ArrayRef<int64_t> innerDimsPos,
4623                                         ArrayRef<int64_t> outerDimsPerm) {
4624   AffineExpr sym0, sym1;
4625   bindSymbols(b.getContext(), sym0, sym1);
4626   auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
4627     return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
4628   };
4629 
4630   SmallVector<OpFoldResult> mixedSizes;
4631   auto srcType = llvm::cast<RankedTensorType>(source.getType());
4632   for (auto i :
4633        llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
4634     if (srcType.isDynamicDim(i))
4635       mixedSizes.push_back(b.create<DimOp>(loc, source, i).getResult());
4636     else
4637       mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
4638   }
4639   if (!outerDimsPerm.empty()) {
4640     applyPermutationToVector<OpFoldResult>(
4641         mixedSizes, invertPermutationVector(outerDimsPerm));
4642   }
4643 
4644   for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
4645     mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
4646 
4647   auto elemType = srcType.getElementType();
4648   return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
4649 }
4650 
4651 UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
4652                                          Value transposedSource,
4653                                          ArrayRef<int64_t> innerPermutation,
4654                                          ArrayRef<int64_t> outerPermutation) {
4655   PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
4656       *this, innerPermutation, outerPermutation);
4657   return b.create<UnPackOp>(loc, transposedSource, getDest(),
4658                             metadata.innerDimsPos, metadata.innerTiles,
4659                             metadata.outerDimsPerm);
4660 }
4661 
4662 /// Returns true if the `srcShape` or `destShape` is different from the one in
4663 /// `op` and populates each with the inferred static shape.
4664 static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
4665                              SmallVectorImpl<int64_t> &destShape) {
4666   bool changeNeeded = false;
4667   srcShape.assign(op.getSourceType().getShape().begin(),
4668                   op.getSourceType().getShape().end());
4669   destShape.assign(op.getDestType().getShape().begin(),
4670                    op.getDestType().getShape().end());
4671   llvm::SmallSetVector<int64_t, 4> innerDims;
4672   innerDims.insert(op.getInnerDimsPos().begin(), op.getInnerDimsPos().end());
4673   SmallVector<int64_t> inverseOuterDimsPerm;
4674   if (!op.getOuterDimsPerm().empty())
4675     inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
4676   int destRank = op.getDestRank();
4677   for (auto i : llvm::seq<int64_t>(0, destRank)) {
4678     if (innerDims.contains(i))
4679       continue;
4680     int64_t srcPos = i;
4681     int64_t destPos = i;
4682     if (!inverseOuterDimsPerm.empty())
4683       srcPos = inverseOuterDimsPerm[destPos];
4684     if (ShapedType::isDynamic(srcShape[srcPos]) ==
4685         ShapedType::isDynamic(destShape[destPos])) {
4686       continue;
4687     }
4688     int64_t size = srcShape[srcPos];
4689     if (ShapedType::isDynamic(size))
4690       size = destShape[destPos];
4691     srcShape[srcPos] = size;
4692     destShape[destPos] = size;
4693     changeNeeded = true;
4694   }
4695   return changeNeeded;
4696 }
4697 
4698 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
4699                                      PatternRewriter &rewriter) {
4700   /// unpack(pack(x)) -> x
4701   if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
4702     if (packOp.getSourceType() != unPackOp.getDestType())
4703       return failure();
4704     if (packOp.getPaddingValue() ||
4705         !hasSameInnerOuterAttribute(packOp, unPackOp) ||
4706         !haveSameTiles(packOp, unPackOp))
4707       return failure();
4708     rewriter.replaceOp(unPackOp, packOp.getSource());
4709     return success();
4710   }
4711   /// unpack(destinationStyleOp(x)) -> unpack(x)
4712   if (auto dstStyleOp =
4713           unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
4714     auto destValue = cast<OpResult>(unPackOp.getDest());
4715     Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
4716     rewriter.modifyOpInPlace(unPackOp,
4717                              [&]() { unPackOp.setDpsInitOperand(0, newDest); });
4718     return success();
4719   }
4720 
4721   // Insert tensor.cast ops if static shape inference is available.
4722   SmallVector<int64_t> srcShape, destShape;
4723   if (inferStaticShape(unPackOp, srcShape, destShape)) {
4724     Location loc = unPackOp.getLoc();
4725     Value source = unPackOp.getSource();
4726     if (srcShape != unPackOp.getSourceType().getShape()) {
4727       auto newSrcType = unPackOp.getSourceType().clone(srcShape);
4728       source = rewriter.create<tensor::CastOp>(loc, newSrcType,
4729                                                unPackOp.getSource());
4730     }
4731     Value dest = unPackOp.getDest();
4732     if (destShape != unPackOp.getDestType().getShape()) {
4733       auto newDestType = unPackOp.getDestType().clone(destShape);
4734       dest =
4735           rewriter.create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
4736     }
4737     Value newOp = rewriter.create<tensor::UnPackOp>(
4738         loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
4739         unPackOp.getOuterDimsPerm());
4740     rewriter.replaceOpWithNewOp<tensor::CastOp>(
4741         unPackOp, unPackOp.getResult().getType(), newOp);
4742     return success();
4743   }
4744 
4745   return failure();
4746 }
4747 
4748 bool UnPackOp::isLikeUnPad() {
4749   RankedTensorType packedTensorType = getSourceType();
4750   return isLikePadUnPad(*this, packedTensorType);
4751 }
4752 
4753 OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
4754   if (OpFoldResult reshapedSource = reshapeConstantSource(
4755           llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
4756           getResult().getType()))
4757     return reshapedSource;
4758   return {};
4759 }
4760 
4761 //===----------------------------------------------------------------------===//
4762 // Common Canonicalizers and Folders.
4763 //===----------------------------------------------------------------------===//
4764 bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
4765   // 1. InsertSliceOp has its own logic about folding tensor.cast ops.
4766   // 2. Exclude DPS ops that are also LoopLike from this interface as they
4767   // might need special handling of attached regions.
4768   if (isa<InsertSliceOp>(op.getOperation()) ||
4769       isa<LoopLikeOpInterface>(op.getOperation()))
4770     return false;
4771 
4772   // Fail if no operand comes from a tensor::CastOp that can be folded.
4773   bool hasTensorCastOperand =
4774       llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
4775         if (llvm::isa<BlockArgument>(opOperand.get()))
4776           return false;
4777         auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
4778         return castOp && canFoldIntoConsumerOp(castOp);
4779       });
4780 
4781   return hasTensorCastOperand;
4782 }
4783 
4784 static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
4785                                          SmallVector<Type> &newResTy) {
4786   SmallVector<Value> newOperands;
4787   newOperands.reserve(op->getNumOperands());
4788 
4789   // Assumes that the result has dpsInits followed by nonDpsInits.
4790   int64_t dpsInitIdx = 0;
4791   for (OpOperand &opOperand : op->getOpOperands()) {
4792     auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
4793     bool fold = canFoldIntoConsumerOp(tensorCastOp);
4794     newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
4795     if (op.isDpsInit(&opOperand) &&
4796         !llvm::isa<MemRefType>(newOperands.back().getType()))
4797       newResTy[dpsInitIdx++] = newOperands.back().getType();
4798   }
4799   return newOperands;
4800 }
4801 
4802 // Given the (potentially) updated packed type, `newPackedTy`, generates an
4803 // updated mixed-tile-sizes attribute. A tile size is updated only
4804 // when:
4805 //  * a dim from newPackedTy is static, and
4806 //  * the corresponding size from mixedTiles is still dynamic.
4807 // Otherwise, the original tile size is preserved.
4808 // Note - packed-type-dim and mixed-tile-size should always match!
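     // For example, if `newPackedTy` has trailing dims [8, 16] and `mixedTiles`
     // is {%tile, 16} (where %tile is a Value known to equal 8), the result is
     // the static sizes {8, 16}.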
4809 static SmallVector<OpFoldResult>
4810 getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
4811                      SmallVector<OpFoldResult> mixedTiles) {
4812   SmallVector<OpFoldResult> newMixedTileSizes;
4813   for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
4814                                .getShape()
4815                                .take_back(mixedTiles.size()),
4816                            mixedTiles)) {
4817     int64_t shape = std::get<0>(it);
4818     if (shape == ShapedType::kDynamic) {
4819       newMixedTileSizes.push_back(std::get<1>(it));
4820       continue;
4821     }
4822 
4823     // If the current result dim is static, update the dynamic mixed-size
4824     // (provided the original value is dynamic).
4825     OpFoldResult tile = std::get<1>(it);
4826     if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
4827       // Already a constant
4828       newMixedTileSizes.push_back(tile);
4829     } else {
4830       assert(getConstantIntValue(tile).value() == shape &&
4831              "tile size and dim size don't match!");
4832       newMixedTileSizes.push_back(
4833           (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
4834     }
4835   }
4836 
4837   return newMixedTileSizes;
4838 }
4839 
4840 /// Folds a tensor.cast op into a consuming tensor::PackOp op if the
4841 /// `tensor.cast` has a source that is more static than the consuming op.
4842 ///
4843 /// Example:
4844 /// ```mlir
4845 ///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
4846 ///   %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
4847 /// ```
4848 ///
4849 /// folds into:
4850 ///
4851 /// ```mlir
4852 ///   %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
4853 /// ```
4854 struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
4855   using OpRewritePattern<PackOp>::OpRewritePattern;
4856 
4857   LogicalResult matchAndRewrite(PackOp op,
4858                                 PatternRewriter &rewriter) const override {
4859     if (!foldTensorCastPrecondition(op))
4860       return failure();
4861 
4862     SmallVector<Type> newResultTypes(op->getResultTypes());
4863     SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
4864 
4865     // Get the updated mixed-tile-sizes attribute.
4866     SmallVector<OpFoldResult> newMixedTileSizes =
4867         getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());
4868 
4869     // Clone op.
4870     // TODO: Strictly speaking, discardable attributes should be _discarded_ at
4871     // this point. However, in practice, we use them for things that we'd like
4872     // to preserve. Implement a better abstraction.
4873     PackOp newOp = rewriter.create<PackOp>(
4874         op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
4875         newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
4876     newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
4877 
4878     // Replace op.
4879     Value oldResult = op.getResult();
4880     Value newResult = newOp.getResult();
4881     Value replacement = (newResult.getType() != oldResult.getType())
4882                             ? rewriter.create<tensor::CastOp>(
4883                                   op->getLoc(), oldResult.getType(), newResult)
4884                             : newResult;
4885 
4886     rewriter.replaceOp(op, {replacement});
4887 
4888     return success();
4889   }
4890 };
4891 
4892 /// Folds a tensor.cast op into a consuming tensor::UnPackOp op if the
4893 /// `tensor.cast` has a source that is more static than the consuming op.
4894 ///
4895 /// Example:
4896 /// ```mlir
4897 ///   %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
4898 ///   %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
4899 /// ```
4900 ///
4901 /// folds into:
4902 ///
4903 /// ```mlir
4904 ///   %2 = tensor.unpack %0 ... : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
4905 /// ```
4906 struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
4907   using OpRewritePattern<UnPackOp>::OpRewritePattern;
4908 
4909   LogicalResult matchAndRewrite(UnPackOp op,
4910                                 PatternRewriter &rewriter) const override {
4911     if (!foldTensorCastPrecondition(op))
4912       return failure();
4913 
4914     SmallVector<Type> newResultTypes(op->getResultTypes());
4915     SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
4916     Value sourceTensor = newOperands[0];
4917 
4918     // Get the updated mixed-tile-sizes attribute.
4919     SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
4920         rewriter, sourceTensor.getType(), op.getMixedTiles());
4921 
4922     // Clone op.
4923     // TODO: Strictly speaking, discardable attributes should be _discarded_ at
4924     // this point. However, in practice, we use them for things that we'd like
4925     // to preserve. Implement a better abstraction.
4926     UnPackOp newOp = rewriter.create<UnPackOp>(
4927         op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(),
4928         newMixedTileSizes, op.getOuterDimsPerm());
4929     newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
4930 
4931     // Replace op.
4932     Value oldResult = op.getResult();
4933     Value newResult = newOp.getResult();
4934     Value replacement = (newResult.getType() != oldResult.getType())
4935                             ? rewriter.create<tensor::CastOp>(
4936                                   op->getLoc(), oldResult.getType(), newResult)
4937                             : newResult;
4938 
4939     rewriter.replaceOp(op, {replacement});
4940 
4941     return success();
4942   }
4943 };
4944 
4945 /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
4946 /// the `tensor.cast` has a source that is more static than the consuming op.
4947 ///
4948 /// Example:
4949 /// ```mlir
4950 ///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
4951 ///   %2 = consumer %1 ... : tensor<?x?xf32> ...
4952 /// ```
4953 ///
4954 /// folds into:
4955 ///
4956 /// ```mlir
4957 ///   %2 = consumer %0 ... : tensor<8x16xf32> ...
4958 /// ```
4959 /// TODO: Move the pattern to a proper place, so all other DestinationStyleOp
4960 /// can add the pattern to their canonicalizers.
4961 struct FoldTensorCastProducerOp
4962     : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
4963   using OpInterfaceRewritePattern<
4964       DestinationStyleOpInterface>::OpInterfaceRewritePattern;
4965 
4966   LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
4967                                 PatternRewriter &rewriter) const override {
4968 
4969     // Reject tensor::PackOp and tensor::UnPackOp - there are dedicated patterns for those.
4970     if (!foldTensorCastPrecondition(op) ||
4971         isa<tensor::PackOp, tensor::UnPackOp>(*op))
4972       return failure();
4973 
4974     SmallVector<Type> newResultTypes(op->getResultTypes());
4975     SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
4976 
4977     // Clone op
4978     auto newOp = clone(rewriter, op, newResultTypes, newOperands);
4979 
4980     SmallVector<Value, 4> replacements;
4981     replacements.reserve(newOp->getNumResults());
4982     for (auto [oldResult, newResult] :
4983          llvm::zip(op->getResults(), newOp->getResults())) {
4984       if (newResult.getType() != oldResult.getType()) {
4985         replacements.push_back(rewriter.create<tensor::CastOp>(
4986             op->getLoc(), oldResult.getType(), newResult));
4987       } else {
4988         replacements.push_back(newResult);
4989       }
4990     }
4991     rewriter.replaceOp(op, replacements);
4992 
4993     return success();
4994   }
4995 };
4996 
4997 //===----------------------------------------------------------------------===//
4998 // TensorDialect
4999 //===----------------------------------------------------------------------===//
5000 
5001 void TensorDialect::getCanonicalizationPatterns(
5002     RewritePatternSet &results) const {
5003   results.add<FoldTensorCastPackOp>(getContext());
5004   results.add<FoldTensorCastUnPackOp>(getContext());
5005   results.add<FoldTensorCastProducerOp>(getContext());
5006 }
5007 
5008 //===----------------------------------------------------------------------===//
5009 // TableGen'd op method definitions
5010 //===----------------------------------------------------------------------===//
5011 
5012 #define GET_OP_CLASSES
5013 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
5014