//===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"

#include "mlir/AsmParser/AsmParser.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h"
#include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/TypeID.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <type_traits>

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::transform;

#define DEBUG_TYPE "linalg-transforms"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")

/// Attempts to apply the pattern specified as template argument to the given
/// operation. The pattern is expected to have a `returningMatchAndRewrite`
/// function that returns the "main" result or failure. Returns failure if the
/// pattern failed to apply. Extra arguments are forwarded to the pattern
/// constructor.
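///
/// For example (illustrative), a decomposition pattern can be applied as:
///
///   FailureOr<LinalgOp> res = tryApply<DownscaleConv2DOp>(target);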
template <typename PatternTy, typename... Args>
static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
  // Check if the given operation has the type expected by the pattern.
  using OpTy = typename llvm::function_traits<
      decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
  auto op = dyn_cast<OpTy>(operation);
  if (!op)
    return failure();

  // Apply the pattern directly to the op.
  PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
  // We want to discourage direct use of PatternRewriter in APIs, but in this
  // very specific case, an IRRewriter is not enough.
  struct TrivialPatternRewriter : public PatternRewriter {
  public:
    explicit TrivialPatternRewriter(MLIRContext *context)
        : PatternRewriter(context) {}
  };
  TrivialPatternRewriter rewriter(operation->getContext());
  rewriter.setInsertionPoint(operation);
  auto result = pattern.returningMatchAndRewrite(op, rewriter);
  if (failed(result))
    return failure();
  return cast<LinalgOp>(result->getOperation());
}

/// For each `OpFoldResult` in `ofrs`, which must be an index attr, a param of
/// index type, or a transform dialect handle mapped to exactly one op with one
/// index result, append that value to `result`.
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
    transform::TransformState &state, TransformOpInterface transformOp,
    SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
  for (OpFoldResult ofr : ofrs) {
    if (auto attr = dyn_cast<Attribute>(ofr)) {
      if (!isa<IntegerAttr>(attr))
        return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
      result.push_back(ofr);
      continue;
    }

    Value transformValue = cast<Value>(ofr);
    if (isa<TransformParamTypeInterface>(transformValue.getType())) {
      ArrayRef<Attribute> params = state.getParams(transformValue);
      if (params.size() != 1)
        return transformOp.emitDefiniteFailure()
               << "requires exactly one parameter associated";
      result.push_back(params[0]);
      continue;
    }

    auto payloadOps = state.getPayloadOps(transformValue);
    if (!llvm::hasSingleElement(payloadOps)) {
      DiagnosedSilenceableFailure diag =
          transformOp.emitSilenceableError()
          << "handle must be mapped to exactly one payload op";
      diag.attachNote(transformValue.getLoc())
          << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
      return diag;
    }

    Operation *op = *payloadOps.begin();
    if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
      DiagnosedSilenceableFailure diag =
          transformOp.emitSilenceableError()
          << "payload op must have exactly 1 index result";
      diag.attachNote(op->getLoc())
          << "has " << op->getNumResults() << " results";
      return diag;
    }
    result.push_back(op->getResult(0));
  }

  return DiagnosedSilenceableFailure::success();
}

// Given a packed handle `packedHandle` that is either a param of index attrs
// or a handle mapped to payload ops, append to `result` either those index
// attrs or the first (and only) OpResult of each payload op. (Each parameter
// must be an integer attribute and each mapped payload op must have exactly
// one index result.)
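//
// A sketch of transform IR that can feed this helper (assuming the standard
// transform dialect ops; `transform.param.constant` produces a param):
//
//   %sz = transform.param.constant 4 : i64 -> !transform.param<i64>
//   // ... or a handle %h mapped to exactly one op with a single index result.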
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
    transform::TransformState &state, TransformOpInterface transformOp,
    SmallVector<OpFoldResult> &result, Value packedHandle) {
  if (isa<TransformParamTypeInterface>(packedHandle.getType())) {
    ArrayRef<Attribute> params = state.getParams(packedHandle);
    for (auto param : params) {
      if (!isa<IntegerAttr>(param))
        return transformOp.emitDefiniteFailure()
               << "expected the parameter to be associated with an integer "
                  "attribute";
      result.push_back(param);
    }
    return DiagnosedSilenceableFailure::success();
  }

  for (Operation *op : state.getPayloadOps(packedHandle)) {
    if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
      DiagnosedSilenceableFailure diag =
          transformOp.emitSilenceableError()
          << "payload op must have exactly 1 index result";
      diag.attachNote(op->getLoc())
          << "has " << op->getNumResults() << " results";
      return diag;
    }
    result.push_back(op->getResult(0));
  }

  return DiagnosedSilenceableFailure::success();
}

/// When possible, converts each `OpFoldResult` in `mixedResults` to
/// an integer if the value can be statically inferred. If a result
/// is a `Value` then it must be either a `ParamType` or a handle
/// to a constant-like op.
static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(
    TransformState &state, TransformOpInterface &transformOp,
    ArrayRef<OpFoldResult> mixedResults, SmallVectorImpl<int64_t> &reified) {
  for (OpFoldResult paramOrHandle : mixedResults) {
    if (auto attr = dyn_cast<Attribute>(paramOrHandle)) {
      reified.push_back(cast<IntegerAttr>(attr).getInt());
      continue;
    } else if (isa<ParamType>(cast<Value>(paramOrHandle).getType())) {
      ArrayRef<Attribute> params = state.getParams(cast<Value>(paramOrHandle));
      if (params.size() != 1)
        return transformOp.emitSilenceableError() << "expected a single param";
      reified.push_back(
          cast<IntegerAttr>(params.front()).getValue().getSExtValue());
      continue;
    }

    Value handle = cast<Value>(paramOrHandle);
    if (!isa<TransformHandleTypeInterface>(handle.getType()))
      return transformOp.emitSilenceableError() << "unexpected value handle";
    auto payload = state.getPayloadOps(handle);
    if (!llvm::hasSingleElement(payload))
      return transformOp.emitSilenceableError()
             << "requires param or handle that is mapped to 1 payload op";

    Operation *paramOrHandlePayloadOp = *payload.begin();
    if (paramOrHandlePayloadOp->getNumResults() != 1 ||
        !paramOrHandlePayloadOp->getResult(0).getType().isIndex()) {
      return transformOp.emitSilenceableError()
             << "requires param or handle to be result of op with 1 index "
                "result";
    }

    IntegerAttr attr;
    if (!matchPattern(paramOrHandlePayloadOp->getResult(0), m_Constant(&attr)))
      return transformOp.emitSilenceableError()
             << "requires param or handle to be the result of a constant like "
                "op";

    reified.push_back(attr.getInt());
  }
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// Apply...PatternsOp
//===----------------------------------------------------------------------===//

void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  linalg::populateEraseUnnecessaryInputsPatterns(patterns);
}

void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  linalg::populateDecomposePackUnpackPatterns(patterns);
}

void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  linalg::populateDecomposePadPatterns(patterns);
}

void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  linalg::ControlDropUnitDims options;
  linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
}

void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  linalg::ControlDropUnitDims options;
  options.rankReductionStrategy =
      linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
  linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
}

void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
}

void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  linalg::populateFoldAddIntoDestPatterns(patterns);
}

void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  linalg::populatePadOpVectorizationPatterns(patterns);
  linalg::populateInsertSliceVectorizationPatterns(patterns);
}

//===----------------------------------------------------------------------===//
// BufferizeToAllocationOp
//===----------------------------------------------------------------------===//

void transform::BufferizeToAllocationOp::build(OpBuilder &b,
                                               OperationState &result,
                                               Value target,
                                               Attribute memorySpace) {
  SmallVector<Type> resultTypes;
  resultTypes.push_back(b.getType<transform::AnyValueType>());
  resultTypes.push_back(b.getType<transform::AnyOpType>());
  return build(b, result,
               /*resultTypes=*/resultTypes,
               /*target=*/target,
               /*memorySpace=*/memorySpace);
}

void transform::BufferizeToAllocationOp::build(OpBuilder &b,
                                               OperationState &result,
                                               Value target,
                                               int64_t memorySpace) {
  SmallVector<Type> resultTypes;
  resultTypes.push_back(b.getType<transform::AnyValueType>());
  resultTypes.push_back(b.getType<transform::AnyOpType>());
  return build(b, result,
               /*resultTypes=*/resultTypes,
               /*target=*/target,
               /*memorySpace=*/b.getI64IntegerAttr(memorySpace));
}

namespace {
class NewOpsListener : public RewriterBase::ForwardingListener {
public:
  using RewriterBase::ForwardingListener::ForwardingListener;

  SmallVector<Operation *> getNewOps() const {
    return SmallVector<Operation *>(newOps.begin(), newOps.end());
  }

private:
  void notifyOperationInserted(Operation *op,
                               OpBuilder::InsertPoint previous) override {
    ForwardingListener::notifyOperationInserted(op, previous);
    // We only care about newly created ops.
    if (previous.isSet())
      return;
    auto inserted = newOps.insert(op);
    (void)inserted;
    assert(inserted.second && "expected newly created op");
  }

  void notifyOperationErased(Operation *op) override {
    ForwardingListener::notifyOperationErased(op);
    op->walk([&](Operation *op) { newOps.erase(op); });
  }

  DenseSet<Operation *> newOps;
};
} // namespace

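/// A typical use of this op from the transform dialect is sketched below (the
/// authoritative assembly format is defined in LinalgTransformOps.td):
///
///   %buffer, %new_ops = transform.structured.bufferize_to_allocation
///       %target {memory_space = 4, memcpy_op = "memref.copy",
///                alloc_op = "memref.alloc"} : !transform.any_op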
DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  // Attach listener to keep track of newly created ops.
  OpBuilder::Listener *previousListener = rewriter.getListener();
  auto resetListener =
      llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
  NewOpsListener newOpsListener(previousListener);
  rewriter.setListener(&newOpsListener);

  linalg::BufferizeToAllocationOptions options;
  if (getMemcpyOp() == "bufferization.materialize_in_destination") {
    options.memcpyOp = linalg::BufferizeToAllocationOptions::MemcpyOp::
        MaterializeInDestination;
  } else if (getMemcpyOp() == "memref.copy") {
    options.memcpyOp =
        linalg::BufferizeToAllocationOptions::MemcpyOp::MemrefCopy;
  } else if (getMemcpyOp() == "linalg.copy") {
    options.memcpyOp =
        linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy;
  } else {
    llvm_unreachable("invalid memcpy op");
  }
  if (getAllocOp() == "memref.alloc") {
    options.allocOp =
        linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloc;
  } else if (getAllocOp() == "memref.alloca") {
    options.allocOp =
        linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloca;
  } else {
    llvm_unreachable("invalid alloc op");
  }
  options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
  options.emitDealloc = getEmitDealloc();

  // Bufferize ops.
  Attribute memorySpace =
      getMemorySpace().has_value() ? getMemorySpace().value() : Attribute();
  SmallVector<Value> allocatedBuffers;
  for (Operation *op : state.getPayloadOps(getTarget())) {
    Value buffer =
        linalg::bufferizeToAllocation(rewriter, options, op, memorySpace);
    if (!buffer) {
      DiagnosedSilenceableFailure diag = emitSilenceableError()
                                         << "failed to bufferize operation";
      diag.attachNote(op->getLoc()) << "target payload op";
      return diag;
    }
    allocatedBuffers.push_back(buffer);
  }

  // Set results.
  results.setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
  results.set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
  return DiagnosedSilenceableFailure::success();
}

void transform::BufferizeToAllocationOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  if (getBufferizeDestinationOnly()) {
    // The destination is replaced with a newly allocated buffer, but the op
    // itself remains in place.
    onlyReadsHandle(getTargetMutable(), effects);
  } else {
    consumesHandle(getTargetMutable(), effects);
  }
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}

LogicalResult transform::BufferizeToAllocationOp::verify() {
  if (getMemcpyOp() != "bufferization.materialize_in_destination" &&
      getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
    return emitOpError() << "unsupported memcpy op";
  if (getAllocOp() != "memref.alloc" && getAllocOp() != "memref.alloca")
    return emitOpError() << "unsupported alloc op";
  return success();
}

//===----------------------------------------------------------------------===//
// DecomposeOp
//===----------------------------------------------------------------------===//

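/// Each DOWNSCALE entry below rewrites a 2-D convolution or pooling op with a
/// unit-size window dimension into its 1-D counterpart, e.g. a
/// linalg.conv_2d_nhwc_hwcf with a unit H window becomes a
/// linalg.conv_1d_nwc_wcf. Illustrative transform IR (the exact assembly
/// format is defined in LinalgTransformOps.td):
///
///   %1 = transform.structured.decompose %0
///       : (!transform.any_op) -> !transform.any_op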
DiagnosedSilenceableFailure
transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
                                   LinalgOp target,
                                   transform::ApplyToEachResultList &results,
                                   transform::TransformState &state) {
#define DOWNSCALE(trans)                                                       \
  {                                                                            \
    FailureOr<LinalgOp> res = tryApply<trans>(target);                         \
    if (succeeded(res)) {                                                      \
      results.push_back(*res);                                                 \
      return DiagnosedSilenceableFailure::success();                           \
    }                                                                          \
  }

#define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
#define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))

  DOWNSCALE_NORMAL(Conv2DNhwcHwcfOp, Conv1DNwcWcfOp)
  DOWNSCALE_NORMAL(Conv2DNchwFchwOp, Conv1DNcwFcwOp)
  DOWNSCALE_NORMAL(PoolingNhwcSumOp, PoolingNwcSumOp)
  DOWNSCALE_NORMAL(PoolingNchwSumOp, PoolingNcwSumOp)
  DOWNSCALE_NORMAL(PoolingNhwcMaxOp, PoolingNwcMaxOp)
  DOWNSCALE_NORMAL(PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp)
  DOWNSCALE_NORMAL(PoolingNhwcMinOp, PoolingNwcMinOp)
  DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp)
  DOWNSCALE_NORMAL(PoolingNchwMaxOp, PoolingNcwMaxOp)
  DOWNSCALE(DownscaleDepthwiseConv2DNhwcHwcOp)
  DOWNSCALE(DownscaleConv2DOp)
#undef DOWNSCALE_NORMAL
#undef DOWNSCALE_CALL
#undef DOWNSCALE
  return emitDefaultSilenceableFailure(target);
}

//===----------------------------------------------------------------------===//
// DecomposeInterfaceOp
//===----------------------------------------------------------------------===//

// Decompose the target operation if it implements the AggregatedOpInterface.
// Push the decomposed operations (the ones that replace the values produced by
// \p target) into `results`.
DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
  if (!decomposableOp) {
    failed(rewriter.notifyMatchFailure(target,
                                       "payload is not a decomposable op"));
    return emitDefaultSilenceableFailure(target);
  }

  FailureOr<SmallVector<Value>> maybeNewResults =
      decomposableOp.decomposeOperation(rewriter);
  if (failed(maybeNewResults))
    return emitDefaultSilenceableFailure(target);

  rewriter.replaceOp(decomposableOp, *maybeNewResults);
  for (Value val : *maybeNewResults) {
    Operation *definition = val.getDefiningOp();
    if (definition)
      results.push_back(definition);
  }
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// EliminateLinalgOpAnchoredEmptyTensorsOp
//===----------------------------------------------------------------------===//

void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getTargetMutable(), effects);
  modifiesPayload(effects);
}

DiagnosedSilenceableFailure
transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
    transform::TransformRewriter &rewriter, TransformResults &transformResults,
    TransformState &state) {
  bufferization::OneShotBufferizationOptions options;
  options.allowReturnAllocsFromLoops = true;

  for (Operation *target : state.getPayloadOps(getTarget())) {
    bufferization::OneShotAnalysisState state(target, options);
    if (failed(analyzeOp(target, state)))
      return mlir::emitSilenceableFailure(target->getLoc())
             << "failed to analyze op";
    if (failed(linalg::linalgOpAnchoredEmptyTensorEliminationStep(
            rewriter, target, state)))
      return mlir::emitSilenceableFailure(target->getLoc())
             << "failed to eliminate LinalgOp anchored tensor.empty ops";
  }
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// FuseOp
//===----------------------------------------------------------------------===//

/// Apply a tiling transformation to all payload ops and store both the
/// tiled operation as well as the created tile loops.
template <typename Range>
static LogicalResult applyTilingToAll(
    RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
    unsigned numLoops, transform::TransformResults &transformResults,
    function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
        applyFn) {
  SmallVector<Operation *> tiledLinalgOps;
  SmallVector<SmallVector<Operation *>> loopOps(numLoops);

  for (Operation *target : payloadOps) {
    auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
    if (!tilingInterfaceOp)
      return transformOp->emitError("only TilingInterface ops are supported");

    rewriter.setInsertionPoint(target);
    FailureOr<scf::SCFTileAndFuseResult> tiledResults =
        applyFn(tilingInterfaceOp);
    if (failed(tiledResults))
      return failure();

    // Perform the replacement of tiled and fused values.
    SmallVector<Operation *> opsToReplace{target};
    llvm::append_range(opsToReplace, tiledResults->fusedProducers);
    for (Operation *toReplace : opsToReplace) {
      for (OpResult res : toReplace->getResults())
        if (auto replacement = tiledResults->replacements.lookup(res))
          rewriter.replaceAllUsesWith(res, replacement);
      if (toReplace->use_empty()) {
        rewriter.eraseOp(toReplace);
      }
    }

    // Report back the relevant handles to the transform op.
    tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
    assert(tiledResults->loops.size() == numLoops &&
           "Mismatched number of loops, tile and fuse transform should have "
           "failed");
    for (unsigned int i = 0; i < numLoops; ++i)
      loopOps[i].push_back(tiledResults->loops[i]);
  }

  transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
  for (unsigned int i = 0; i < numLoops; ++i)
    transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);

  return success();
}

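/// Illustrative transform IR for this op, a sketch assuming the assembly
/// format defined in LinalgTransformOps.td. Two non-zero tile sizes produce
/// two loop results:
///
///   %tiled, %loops:2 = transform.structured.fuse %0
///       {tile_sizes = [8, 16], tile_interchange = [0, 1]}
///       : (!transform.any_op)
///       -> (!transform.any_op, !transform.any_op, !transform.any_op)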
DiagnosedSilenceableFailure
transform::FuseOp::apply(transform::TransformRewriter &rewriter,
                         mlir::transform::TransformResults &transformResults,
                         mlir::transform::TransformState &state) {
  SmallVector<int64_t> tileSizes =
      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
  SmallVector<int64_t> tileInterchange =
      extractFromIntegerArrayAttr<int64_t>(getTileInterchange());

  scf::SCFTilingOptions tilingOptions;
  tilingOptions.interchangeVector = tileInterchange;
  SmallVector<OpFoldResult> tileSizesOfr =
      getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
  tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
  scf::SCFTileAndFuseOptions tileAndFuseOptions;
  tileAndFuseOptions.tilingOptions = tilingOptions;

  if (getApplyCleanup()) {
    MLIRContext *context = rewriter.getContext();
    RewritePatternSet patterns(context);
    tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
    tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
    tileAndFuseOptions.cleanupPatterns = std::move(patterns);
  }

  LogicalResult result = applyTilingToAll(
      rewriter, getOperation(), state.getPayloadOps(getTarget()),
      tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
      [&](TilingInterface tilingInterfaceOp)
          -> FailureOr<scf::SCFTileAndFuseResult> {
        return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
                                                    tileAndFuseOptions);
      });
  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
                        : DiagnosedSilenceableFailure::success();
}

LogicalResult transform::FuseOp::verify() {
  SmallVector<int64_t> permutation =
      extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
  if (!std::is_permutation(sequence.begin(), sequence.end(),
                           permutation.begin(), permutation.end())) {
    return emitOpError() << "expects interchange to be a permutation, found "
                         << getTileInterchange();
  }

  SmallVector<int64_t> sizes =
      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
  size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
  if (numExpectedLoops != getNumResults() - 1)
    return emitOpError() << "expects " << numExpectedLoops << " loop results";

  return success();
}

//===----------------------------------------------------------------------===//
// FuseIntoContainingOp
//===----------------------------------------------------------------------===//

void transform::FuseIntoContainingOp::build(OpBuilder &builder,
                                            OperationState &result,
                                            Value producerOp,
                                            Value containingOp) {
  result.addOperands({producerOp, containingOp});
  auto resultType = transform::AnyOpType::get(builder.getContext());
  result.addTypes({resultType, resultType});
}

/// Add new operands to the forall op for users of the producerOp
/// that are dominated by the containing scf.forall op.
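/// A sketch of the rewrite (types and most operands elided):
///
///   %r = scf.forall ... shared_outs(%o = %init) { ... }
///   ... use of %producer_result dominated by the forall ...
///
/// becomes
///
///   %r:2 = scf.forall ... shared_outs(%o = %init, %p = %producer_out) {
///     ...
///     tensor.parallel_insert_slice %tiled into %p[...] ...
///   }
///   ... dominated uses now read %r#1 instead of %producer_result ...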
static Operation *replaceForAllWithNewSignature(
    RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
    Operation *containingOp, TilingResult &tileAndFuseResult,
    int64_t resultNumber, SmallVector<OpFoldResult> &offsets,
    SmallVector<OpFoldResult> &sizes) {

  // Count number of users not including the containing op
  SetVector<Operation *> dominatedUsers;
  DominanceInfo domInfo(containingOp);
  for (Operation *user : producerOp->getResult(resultNumber).getUsers()) {
    if (!containingOp->isAncestor(user) &&
        (domInfo.dominates(containingOp, user))) {
      dominatedUsers.insert(user);
    }
  }
  if (dominatedUsers.empty())
    return nullptr;

  // Create new scf.forall op
  auto forallOp = cast<scf::ForallOp>(containingOp);
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(forallOp);

  // Get new output
  Location loc = forallOp.getLoc();
  auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
  if (!genericOp)
    return nullptr;
  SmallVector<Value> outputs = genericOp.getOutputs();
  SmallVector<Value> newOuts(forallOp.getOutputs());
  newOuts.push_back(outputs[resultNumber]);

  // Create new scf.forall op
  auto newforallOp = rewriter.create<scf::ForallOp>(
      loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
      forallOp.getMixedStep(), newOuts, forallOp.getMapping());
  rewriter.eraseBlock(newforallOp.getBody());
  newforallOp.getRegion().takeBody(forallOp.getRegion());

  // Add an additional block argument for the new value being returned
  // and replace all uses of the new output with the corresponding bbArg
  // inside the scf.forall to enable fusion into this new scf.forall.
  newforallOp.getBody()->addArgument(newOuts.back().getType(),
                                     newOuts.back().getLoc());
  auto bbArgs = newforallOp.getBody()->getArguments();
  rewriter.replaceUsesWithIf(newOuts.back(), bbArgs.back(),
                             [&](OpOperand &use) {
                               Operation *op = use.getOwner();
                               return newforallOp->isProperAncestor(op);
                             });

  // Fix terminator
  scf::InParallelOp terminatorOp = newforallOp.getTerminator();
  SmallVector<Operation *> yieldingOps = llvm::to_vector<4>(llvm::map_range(
      terminatorOp.getYieldingOps(), [](Operation &op) { return &op; }));
  Operation *firstYieldOp = yieldingOps.front();
  rewriter.setInsertionPoint(firstYieldOp);
  Value src = tileAndFuseResult.tiledValues[0];
  Value dst = newforallOp.getRegionIterArgs().back();
  SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
  rewriter.create<tensor::ParallelInsertSliceOp>(firstYieldOp->getLoc(), src,
                                                 dst, offsets, sizes, strides);

  for (auto result : llvm::enumerate(forallOp.getResults())) {
    rewriter.replaceAllUsesWith(result.value(),
                                newforallOp->getResult(result.index()));
  }
  rewriter.replaceUsesWithIf(producerOp->getResult(resultNumber),
                             newforallOp->getResults().back(),
                             [&](OpOperand &use) {
                               Operation *user = use.getOwner();
                               return dominatedUsers.contains(user);
                             });
  return newforallOp;
}

/// Find the first "extract" user of `producerOp` and tile it right before its
/// use. The tiled op is fused under the `containingOp`.
/// Returns the tiled ops on success, or an empty vector if anything fails.
/// If the tiled op has uses that are dominated by `containingOp`, also returns
/// a new `containingOp` with the results of the fused op appended to the
/// results of the `containingOp`, or nullptr if there are no dominated uses.
static std::tuple<SmallVector<Operation *>, Operation *>
tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
                           Operation *producerOp, Operation *containingOp) {
  LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n");
  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
  if (!tileableProducer) {
    diag.attachNote(producerOp->getLoc())
        << "producer does not implement TilingInterface: " << *producerOp;
    return {};
  }

  // Search the producer slices accessed within the containing operation.
  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
  // evolve into an interface.
  auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {
    auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
    return sliceOp && containingOp->isProperAncestor(sliceOp);
  });

  // Find a fusion opportunity.
  if (it == tileableProducer->getUsers().end()) {
    diag.attachNote(tileableProducer->getLoc())
        << "could not find fusion opportunity for: " << *tileableProducer;
    return {};
  }
  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);

  // Try to fuse the producer in-place.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(sliceOpToTile);

  // Tile the producer.
  int64_t resultNumber =
      cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");

  SmallVector<OpFoldResult> offsets = sliceOpToTile.getMixedOffsets();
  SmallVector<OpFoldResult> sizes = sliceOpToTile.getMixedSizes();

  FailureOr<TilingResult> tileAndFuseResult =
      tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
                                               sizes);

  if (failed(tileAndFuseResult)) {
    diag.attachNote(tileableProducer->getLoc())
        << "failed to tile producer op: " << *tileableProducer;
    return {};
  }

#ifndef NDEBUG
  for (auto *tiledOp : tileAndFuseResult->tiledOps) {
    LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledOp << "\n");
  }
#endif

  // Replace the extract op.
  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
      rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
      cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
  if (failed(maybeRankReduced)) {
    diag.attachNote(producerOp->getLoc())
        << "shape types don't match (missing canonicalization?):\nTiledOp: "
        << tileAndFuseResult->tiledValues[0]
        << "\nSliceOp: " << sliceOpToTile.getOperation() << '\n';
    return {};
  }
  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);

  // Add new outputs to containing op, if required
  Operation *newContainingOp = replaceForAllWithNewSignature(
      rewriter, diag, producerOp, containingOp, *tileAndFuseResult,
      resultNumber, offsets, sizes);

  return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
}

/// First, find the first "scf::ForallOp" user of `producerOp` and ensure
/// it is exactly the `containingOp`, otherwise bail.
/// Then, find the first "extract" user of the tied block argument and tile it
/// right before its "extract" use. The tiled op is fused under the
/// `containingOp`.
/// Returns the tiled ops on success, or an empty vector if anything fails.
static SmallVector<Operation *>
tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
    RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
    Operation *containingOp) {
  LLVM_DEBUG(DBGS() << "Try to fuse an extract use through block argument\n");

  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
  if (!tileableProducer) {
    diag.attachNote(producerOp->getLoc())
        << "producer does not implement TilingInterface: " << *producerOp;
    return {};
  }

  // Search the first use by a "scf::ForallOp" user.
  scf::ForallOp forallOp;
  auto itProducerUses =
      llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) {
        forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
        return forallOp;
      });
  // If it's not from the containing op, return.
  if (!forallOp || forallOp != containingOp) {
    diag.attachNote(tileableProducer->getLoc())
        << "could not find a use by the containing op: " << *tileableProducer;
    return {};
  }

  OpOperand *pUse = &(*itProducerUses);
  BlockArgument bbArg = forallOp.getTiedBlockArgument(pUse);

  // Search the producer slices accessed within the containing operation.
  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
  // evolve into an interface.
  auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
    auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
    return sliceOp && containingOp->isProperAncestor(sliceOp);
  });

  // Find a fusion opportunity.
  if (itBBArgUsers == bbArg.getUsers().end()) {
    diag.attachNote(containingOp->getLoc())
        << "could not find fusion opportunity for bbArg: " << bbArg;
    return {};
  }
  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);

  // Try to fuse the producer in-place.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(sliceOpToTile);

  // Replace the use in the tileableProducer before tiling: clone, replace and
  // then tile.
  int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");

  // Gather destination tensors.
  SmallVector<Value> destinationTensors;
  if (failed(tensor::getOrCreateDestinations(
          rewriter, tileableProducer->getLoc(), tileableProducer,
          destinationTensors))) {
    diag.attachNote(tileableProducer->getLoc())
        << "failed to get destination tensors for: " << *tileableProducer;
    return {};
  }

  IRMapping bvm;
  bvm.map(destinationTensors[resultNumber], bbArg);
  auto tileableProducerClone =
      cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
  auto scopeGuard =
      llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });

  // Tile the producer.
  FailureOr<TilingResult> tileAndFuseResult =
      tileableProducerClone.generateResultTileValue(
          rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
          sliceOpToTile.getMixedSizes());
  if (failed(tileAndFuseResult)) {
    diag.attachNote(tileableProducer->getLoc())
        << "failed to tile producer op: " << *tileableProducer;
    return {};
  }

  // Replace the extract op.
  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
      rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
      cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
  assert(succeeded(maybeRankReduced) && "unexpected shape");
  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);

  // Replace the use in containingOp.
  rewriter.modifyOpInPlace(containingOp, [&]() {
    containingOp->setOperand(pUse->getOperandNumber(),
                             destinationTensors.front());
  });

  return tileAndFuseResult->tiledOps;
}

static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
                                       Operation *producerOp,
                                       Operation *containingOp) {
  LLVM_DEBUG(DBGS() << "Try to fuse a use by cloning\n");

  // Gather all uses inside the containing op.
  SmallVector<OpOperand *> uses;
  for (OpResult result : producerOp->getOpResults()) {
    for (OpOperand &use : result.getUses()) {
      if (containingOp->isProperAncestor(use.getOwner())) {
        uses.push_back(&use);
        continue;
      }
      // Cannot clone and fuse if the use is by the containing op itself: fail
      // immediately.
      if (containingOp == use.getOwner()) {
        diag.attachNote(producerOp->getLoc())
            << "producer op use by containing op cannot be fused by cloning";
        return nullptr;
      }
    }
  }

  // Check for a non-empty list of fusion opportunities.
  if (uses.empty()) {
    diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning";
    return nullptr;
  }

  // Clone and fuse inside the containing op.
  Operation *fusedOp = nullptr;
  OpOperand *use = uses.front();
  // Parallel insert slice is not a valid clone destination.
  // TODO: Generalize to other types of ops.
  assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
         "Parallel insert slice is not a valid clone destination");
  unsigned resultNumber = cast<OpResult>(use->get()).getResultNumber();
  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(use->getOwner());
  fusedOp = rewriter.clone(*producerOp);
  rewriter.modifyOpInPlace(
      use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });

  return fusedOp;
}

bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
  // Allow repeated handles since we are fusing everything anyway.
  return true;
}

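/// Illustrative transform IR (a sketch; the exact assembly format is defined
/// in LinalgTransformOps.td):
///
///   %fused, %new_containing =
///       transform.structured.fuse_into_containing_op %producer into %forall
///       : (!transform.any_op, !transform.any_op)
///       -> (!transform.any_op, !transform.any_op)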
DiagnosedSilenceableFailure
transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
                                       transform::TransformResults &results,
                                       transform::TransformState &state) {
  SmallVector<Operation *> fusedOps;
  auto producerOps = state.getPayloadOps(getProducerOp());
  auto containingOps = state.getPayloadOps(getContainingOp());
  if (!llvm::hasSingleElement(containingOps)) {
    return emitDefiniteFailure()
           << "requires exactly one containing_op handle (got "
           << llvm::range_size(containingOps) << ")";
  }
  Operation *containingOp = *containingOps.begin();

  // If nothing to fuse, propagate success.
  if (std::empty(producerOps)) {
    results.set(cast<OpResult>(getFusedOp()), SmallVector<mlir::Operation *>{});
    results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
    return DiagnosedSilenceableFailure::success();
  }

  // Helper function to find the next producer that should be fused. Take any
  // producer that has a use inside the containing op.
  SetVector<Operation *> remainingProducers(producerOps.begin(),
                                            producerOps.end());
  auto getNextProducer = [&]() -> FailureOr<Operation *> {
    for (const auto &it : enumerate(remainingProducers)) {
      Operation *producerOp = it.value();
      // The containing op may be a user of producerOp: use isAncestor.
      int64_t numUsesInContainingOp =
          llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
            return containingOp->isAncestor(op);
          });
      // TODO: When resolving the TODO below (no duplicate ops), take an op
      // that has no use among the remaining producers. This is a topological
      // sorting.
      if (numUsesInContainingOp > 0) {
        if (numUsesInContainingOp == 1)
          remainingProducers.erase(remainingProducers.begin() + it.index());
        return producerOp;
      }
    }
    return failure();
  };

  while (!remainingProducers.empty()) {
    auto nextProducer = getNextProducer();
    if (failed(nextProducer)) {
      auto diag = mlir::emitSilenceableFailure(getLoc())
                  << "could not find next producer to fuse into container";
      diag.attachNote(containingOp->getLoc()) << "containing op";
      return diag;
    }

    Operation *producerOp = *nextProducer;

    // Default diagnostic, to be complemented with more failure information.
    Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
    diag << "could not fuse " << *producerOp << " into " << *containingOp;

    // TODO: If there are multiple uses of the producer in the containing op,
    // we currently tile/clone the op multiple times (once per use). In some
    // cases, we can tile/clone once and reuse the value for each use.
    // Furthermore, producers should then be traversed according to a
    // topological sorting.
    auto [tiledOps, newContainingOp] =
        tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
    if (!tiledOps.empty()) {
      LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);
      fusedOps.append(tiledOps);
      if (newContainingOp) {
        // Update handles associated with the containing op so we don't need to
        // invalidate them. This is a hack to support better composability
        // between tiling and fusion while a proper mechanism is being
        // investigated.
        //
        // DO NOT replicate this elsewhere unless you understand what you are
        // doing.
        LogicalResult replacementStatus =
            rewriter.notifyPayloadOperationReplaced(containingOp,
                                                    newContainingOp);
        (void)replacementStatus;
        assert(succeeded(replacementStatus) &&
               "unable to update transform state mapping");
        rewriter.eraseOp(containingOp);
        containingOp = newContainingOp;
      }
      continue;
    }

    SmallVector<Operation *> tiledContainingOpOperand =
        tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
            rewriter, diag, producerOp, containingOp);
    if (!tiledContainingOpOperand.empty()) {
      LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n"
                        << *containingOp);
      fusedOps.append(tiledContainingOpOperand);
      continue;
    }

    Operation *cloned =
        cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
    if (cloned) {
      LLVM_DEBUG(DBGS() << "\nFused a use by cloning\n" << *containingOp);
      fusedOps.push_back(cloned);
      continue;
    }
    return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
  }

  results.set(cast<OpResult>(getFusedOp()), fusedOps);
  results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
  return DiagnosedSilenceableFailure::success();
}

void transform::FuseIntoContainingOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getProducerOpMutable(), effects);
  onlyReadsHandle(getContainingOpMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}

//===----------------------------------------------------------------------===//
// GeneralizeOp
//===----------------------------------------------------------------------===//

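/// Illustrative transform IR (sketch): replace a named op by its
/// linalg.generic equivalent.
///
///   %generic = transform.structured.generalize %0
///       : (!transform.any_op) -> !transform.any_op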
DiagnosedSilenceableFailure
transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter,
                                    LinalgOp target,
                                    transform::ApplyToEachResultList &results,
                                    transform::TransformState &state) {
  // Exit early if no transformation is needed.
  if (isa<GenericOp>(target)) {
    results.push_back(target);
    return DiagnosedSilenceableFailure::success();
  }
  rewriter.setInsertionPoint(target);
  FailureOr<LinalgOp> generic = generalizeNamedOp(rewriter, target);
  if (succeeded(generic)) {
    results.push_back(generic->getOperation());
    return DiagnosedSilenceableFailure::success();
  }
  return emitDefaultSilenceableFailure(target);
}

//===----------------------------------------------------------------------===//
// SpecializeOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter,
                                    LinalgOp target,
                                    transform::ApplyToEachResultList &results,
                                    transform::TransformState &state) {
  // Exit early if the operation is not a generic.
  if (!isa<GenericOp>(target)) {
    results.push_back(target);
    return DiagnosedSilenceableFailure::success();
  }
  rewriter.setInsertionPoint(target);
  FailureOr<LinalgOp> named =
      specializeGenericOp(rewriter, cast<GenericOp>(target));
  if (succeeded(named)) {
    results.push_back(named->getOperation());
    return DiagnosedSilenceableFailure::success();
  }
  return emitDefaultSilenceableFailure(target);
}

//===----------------------------------------------------------------------===//
// InterchangeOp
//===----------------------------------------------------------------------===//

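/// Illustrative transform IR (sketch): swap the two outermost iterators of a
/// three-loop linalg.generic.
///
///   %1 = transform.structured.interchange %0 iterator_interchange = [1, 0, 2]
///       : (!transform.any_op) -> !transform.any_op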
DiagnosedSilenceableFailure
transform::InterchangeOp::applyToOne(transform::TransformRewriter &rewriter,
                                     GenericOp target,
                                     transform::ApplyToEachResultList &results,
                                     transform::TransformState &state) {
  ArrayRef<int64_t> interchangeVector = getIteratorInterchange();
  // Exit early if no transformation is needed.
  if (interchangeVector.empty()) {
    results.push_back(target);
    return DiagnosedSilenceableFailure::success();
  }

  unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops();
  if (interchangeVector.size() != numLoops) {
    return emitSilenceableError()
           << getIteratorInterchangeAttrName() << " has length ("
           << interchangeVector.size()
           << ") different from the number of loops in the target operation ("
           << numLoops << ")";
  }
  FailureOr<GenericOp> res = interchangeGenericOp(
      rewriter, target, SmallVector<unsigned>(interchangeVector));
  if (failed(res))
    return emitDefiniteFailure() << "failed to apply";
  results.push_back(res->getOperation());
  return DiagnosedSilenceableFailure::success();
}

LogicalResult transform::InterchangeOp::verify() {
  ArrayRef<int64_t> permutation = getIteratorInterchange();
  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
  if (!std::is_permutation(sequence.begin(), sequence.end(),
                           permutation.begin(), permutation.end())) {
    return emitOpError()
           << "expects iterator_interchange to be a permutation, found "
           << getIteratorInterchange();
  }
  return success();
}

//===----------------------------------------------------------------------===//
// LowerPackOp
//===----------------------------------------------------------------------===//

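/// The lowering materializes the pack as pad + expand_shape + transpose,
/// matching the three result handles. A sketch (shapes and attributes
/// elided):
///
///   %packed = tensor.pack %src ... into %dest
///
/// becomes
///
///   %padded = tensor.pad %src ...
///   %expanded = tensor.expand_shape %padded ...
///   %packed = linalg.transpose ins(%expanded ...) outs(%dest ...)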
DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
    transform::TransformRewriter &rewriter, tensor::PackOp target,
    transform::ApplyToEachResultList &transformResults,
    transform::TransformState &state) {
  rewriter.setInsertionPoint(target);
  bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();
  FailureOr<LowerPackResult> res =
      lowerPack(rewriter, target, lowerPadLikeWithInsertSlice);
  if (failed(res)) {
    return mlir::emitSilenceableFailure(target->getLoc())
           << "cannot lower to pad + expand + transpose";
  }
  transformResults.push_back(res->padOp);
  transformResults.push_back(res->expandShapeOp);
  transformResults.push_back(res->transposeOp);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// LowerUnPackOp
//===----------------------------------------------------------------------===//

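/// The lowering materializes the unpack as empty + transpose + collapse_shape
/// + extract_slice, matching the four result handles. A sketch (shapes and
/// attributes elided):
///
///   %unpacked = tensor.unpack %src ... into %dest
///
/// becomes
///
///   %empty = tensor.empty(...)
///   %transposed = linalg.transpose ins(%src ...) outs(%empty ...)
///   %collapsed = tensor.collapse_shape %transposed ...
///   %unpacked = tensor.extract_slice %collapsed ...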
DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
    transform::TransformRewriter &rewriter, tensor::UnPackOp target,
    transform::ApplyToEachResultList &transformResults,
    transform::TransformState &state) {
  rewriter.setInsertionPoint(target);
  bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();
  FailureOr<LowerUnPackOpResult> res =
      lowerUnPack(rewriter, target, lowerUnpadLikeWithExtractSlice);
  if (failed(res)) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "cannot lower to transpose + collapse + extract";
    diag.attachNote(target->getLoc()) << "target payload op";
    return diag;
  }
  transformResults.push_back(res->emptyOp);
  transformResults.push_back(res->transposeOp);
  transformResults.push_back(res->collapseShapeOp);
  transformResults.push_back(res->extractSliceOp);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// MatchOp
//===----------------------------------------------------------------------===//

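/// Illustrative transform IR (the exact assembly format is defined in
/// LinalgTransformOps.td):
///
///   %matmuls = transform.structured.match ops{["linalg.matmul"]} in %func
///       : (!transform.any_op) -> !transform.any_op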
void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
                               Value target, ArrayRef<StringRef> opNames) {
  result.addOperands(target);
  result.addAttribute(MatchOp::getOpsAttrName(result.name),
                      builder.getStrArrayAttr(opNames));
  result.addTypes(transform::AnyOpType::get(builder.getContext()));
}

void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
                               TypeRange resultTypes, Value target,
                               ArrayRef<StringRef> opNames) {
  result.addOperands(target);
  result.addAttribute(MatchOp::getOpsAttrName(result.name),
                      builder.getStrArrayAttr(opNames));
  result.addTypes(resultTypes);
}

DiagnosedSilenceableFailure
transform::MatchOp::apply(transform::TransformRewriter &rewriter,
                          transform::TransformResults &results,
                          transform::TransformState &state) {
  llvm::StringSet<> strs;
  if (getOps().has_value())
    strs.insert(getOps()->getAsValueRange<StringAttr>().begin(),
                getOps()->getAsValueRange<StringAttr>().end());

  auto payloadOps = state.getPayloadOps(getTarget());
  if (!llvm::hasSingleElement(payloadOps)) {
    return emitDefiniteFailure("requires exactly one target handle");
  }

  SmallVector<Operation *> res;
  bool incorrectNumOperandTypes = false;
  auto matchFun = [&](Operation *op) {
    if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
      return;

    // Interfaces cannot be matched by name, just by ID.
    // So we specifically encode the interfaces we care about for this op.
    if (getInterface().has_value()) {
      auto iface = getInterface().value();
      if (iface == transform::MatchInterfaceEnum::LinalgOp &&
          !isa<LinalgOp>(op))
        return;
      if (iface == transform::MatchInterfaceEnum::TilingInterface &&
          !isa<TilingInterface>(op))
        return;
      if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
          !isa<LoopLikeOpInterface>(op))
        return;
    }

    // Check if all specified attributes match.
    if (getOpAttrs().has_value()) {
      DictionaryAttr opAttrs = getOpAttrs().value();
      for (NamedAttribute attr : opAttrs) {
        if (attr.getName() == getInterfaceAttrName() ||
            attr.getName() == getOpsAttrName())
          continue;
        if (!op->hasAttr(attr.getName()))
          return;
        if (op->getAttr(attr.getName()) != attr.getValue())
          return;
      }
    }

    if (getFilterResultType().has_value()) {
      Type t = getFilterResultType().value();
      if (op->getNumResults() != 1 || op->getResultTypes().front() != t)
        return;
    }

    if (getFilterOperandTypes().has_value()) {
      mlir::ArrayAttr types = getFilterOperandTypes().value();
      auto operandTypes = op->getOperandTypes();

      if (types.size() == 1) {
1298         // All the operands must must be equal to the specified type
1299         auto typeattr =
1300             dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
1301         Type t = cast<::mlir::Type>(typeattr.getValue());
1302         if (!llvm::all_of(op->getOperandTypes(),
1303                           [&](Type operandType) { return operandType == t; }))
1304           return;
1305       } else {
1306         // The operand types must match all the types in the list (in the same
1307         // order in with they are specified)
1308         if (types.size() != operandTypes.size()) {
1309           incorrectNumOperandTypes = true;
1310           return;
1311         }
1312 
1313         for (auto [attr, operandType] :
1314              llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1315           auto typeattr = cast<mlir::TypeAttr>(attr);
1316           Type type = cast<::mlir::Type>(typeattr.getValue());
1317 
1318           if (type != operandType)
1319             return;
1320         }
1321       }
1322     }
1323 
1324     // All constraints are satisfied.
1325     res.push_back(op);
1326     return;
1327   };
1328 
1329   (*payloadOps.begin())->walk(matchFun);
1330   if (incorrectNumOperandTypes)
1331     return emitDefiniteFailure("If filter_operand_types contains more than a "
1332                                "type, then it must contain as much types as "
1333                                "the number of operands in the target ops");
1334   results.set(cast<OpResult>(getResult()), res);
1335   return DiagnosedSilenceableFailure::success();
1336 }
1337 
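// Typical usage of the matcher above, filtering by op name and by the
// presence of an attribute (a sketch; names and the attribute are
// placeholders, and the assembly format is approximate):
//
//   %matmuls = transform.structured.match ops{["linalg.matmul"]}
//       attributes{some_attr} in %module
//       : (!transform.any_op) -> !transform.any_op
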
//===---------------------------------------------------------------------===//
// MultiTileSizesOp
//===---------------------------------------------------------------------===//

static void printMultitileSizesTypes(OpAsmPrinter &printer, Operation *op,
                                     Type targetType, Type lowSizeType, Type,
                                     Type) {
  printer.printFunctionalType(TypeRange{targetType}, TypeRange{lowSizeType});
}

static ParseResult parseMultitileSizesTypes(OpAsmParser &parser,
                                            Type &targetType, Type &lowSizeType,
                                            Type &highSizeType,
                                            Type &splitPointType) {
  FunctionType funcType;
  llvm::SMLoc typeLoc = parser.getCurrentLocation();
  if (failed(parser.parseType<FunctionType>(funcType)))
    return failure();

  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
    parser.emitError(typeLoc) << "expects a trailing functional type with one "
                                 "argument and one result";
    return failure();
  }
  targetType = funcType.getInput(0);
  lowSizeType = highSizeType = splitPointType = funcType.getResult(0);

  return success();
}

DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
    transform::TransformRewriter &rewriter, LinalgOp target,
    transform::ApplyToEachResultList &results, TransformState &state) {
  if (isa<TransformParamTypeInterface>(getLowSize().getType())) {
    if (target.hasDynamicShape()) {
      auto diag = emitSilenceableError()
                  << "cannot compute parametric tile sizes for dynamically "
                     "shaped payload op";
      diag.attachNote(target->getLoc()) << "payload op";
      return diag;
    }

    FailureOr<StaticMultiSizeSpecification> spec = computeStaticMultiTileSizes(
        target, getDimension(), getTargetSize(), getDivisor());
    if (failed(spec)) {
      return emitSilenceableError()
             << "failed to compute multi-size tiling sizes";
    }

    Builder builder(target.getContext());
    results.assign(llvm::map_range(
        ArrayRef<int64_t>({spec->lowTileSize, spec->highTileSize,
                           spec->lowTileSize * spec->lowTripCount}),
        [&builder, this](int64_t value) {
          return builder.getIntegerAttr(
              cast<ParamType>(getLowSize().getType()).getType(), value);
        }));
    return DiagnosedSilenceableFailure::success();
  }

  OpBuilder builder(target.getContext());
  builder.setInsertionPoint(target);
  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
  OpFoldResult divisor = builder.getIndexAttr(getDivisor());
  FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes(
      builder, target, getDimension(), targetSize, divisor);
  if (failed(spec)) {
    return emitSilenceableError() << "could not generate tile size computation";
  }

  AffineExpr s0 = builder.getAffineSymbolExpr(0);
  AffineExpr s1 = builder.getAffineSymbolExpr(1);
  Operation *splitPoint =
      affine::makeComposedAffineApply(builder, target.getLoc(), s0 * s1,
                                      {spec->lowTileSize, spec->lowTripCount});
  Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
  Operation *highTileSize = spec->highTileSize.getDefiningOp();
  assert(lowTileSize && highTileSize && splitPoint &&
         "expected tile sizes to be produced by operations");
  results.reserve(results.size() + 3);
  results.push_back(lowTileSize);
  results.push_back(highTileSize);
  results.push_back(splitPoint);
  return DiagnosedSilenceableFailure::success();
}

void transform::MultiTileSizesOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getTargetMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  if (isa<TransformParamTypeInterface>(getLowSize().getType()))
    onlyReadsPayload(effects);
  else
    modifiesPayload(effects);
}

LogicalResult transform::MultiTileSizesOp::verify() {
  if (getLowSize().getType() != getHighSize().getType() ||
      getLowSize().getType() != getSplitPoint().getType()) {
    return emitOpError() << "expects all result types to be the same";
  }
  return success();
}

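// For reference, both branches above produce a (low size, high size, split
// point) triple along `dimension`, with the split point equal to
// lowTileSize * lowTripCount as materialized by the affine apply. A sketch of
// the parametric form (assembly format approximate; values are placeholders):
//
//   %low, %high, %split = transform.structured.multitile_sizes %target
//       { dimension = 0, target_size = 32, divisor = 8 }
//       : (!transform.any_op) -> !transform.param<i64>
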
//===---------------------------------------------------------------------===//
// PackOp
//===---------------------------------------------------------------------===//

void transform::PackOp::build(OpBuilder &builder, OperationState &result,
                              Value target,
                              ArrayRef<OpFoldResult> mixedPackedSizes) {
  SmallVector<int64_t> staticPackedSizes;
  SmallVector<Value> dynamicPackedSizes;
  dispatchIndexOpFoldResults(mixedPackedSizes, dynamicPackedSizes,
                             staticPackedSizes);
  // Call the default builder which sets up the proper operands segment sizes
  // attributes for multiple variadic operands. In the absence of this, horrible
  // bugs ensue.
  Type linalgOpHType = transform::OperationType::get(
      builder.getContext(), GenericOp::getOperationName());
  build(builder, result,
        /*resultType=*/linalgOpHType,
        /*target=*/target,
        /*dynamic_sizes=*/dynamicPackedSizes,
        /*static_sizes=*/builder.getDenseI64ArrayAttr(staticPackedSizes));
}

SmallVector<OpFoldResult> transform::PackOp::getMixedPackedSizes() {
  Builder b(getContext());
  return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
}

DiagnosedSilenceableFailure
transform::PackOp::apply(transform::TransformRewriter &rewriter,
                         transform::TransformResults &transformResults,
                         transform::TransformState &state) {
  auto targetOps = state.getPayloadOps(getTarget());
  // If nothing to pack, propagate success.
  if (std::empty(targetOps)) {
    transformResults.set(cast<OpResult>(getPackedOp()),
                         ArrayRef<Operation *>({}));
    return DiagnosedSilenceableFailure::success();
  }
  // Fail on multi-op handles.
  auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
  if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
    return emitSilenceableError()
           << "requires target to map to exactly 1 LinalgOp (got "
           << llvm::range_size(targetOps) << ")";
  }
  // Fail on mismatched number of pack sizes.
  if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
    return emitSilenceableError()
           << "requires the number of packed sizes to match the number of "
              "loops ("
           << getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()
           << ")";
  }

  // Unpack handles to constants or actual SSA index values.
  SmallVector<OpFoldResult> packedSizes;
  DiagnosedSilenceableFailure status = unpackSingleIndexResultPayloadOperations(
      state, *this, packedSizes, getMixedPackedSizes());
  if (!status.succeeded())
    return status;

  rewriter.setInsertionPoint(linalgOp);
  FailureOr<PackResult> maybeResult = pack(rewriter, linalgOp, packedSizes);
  if (failed(maybeResult))
    return emitDefiniteFailure("data tiling failed");

  transformResults.set(cast<OpResult>(getPackedOp()),
                       {maybeResult->packedLinalgOp.getOperation()});
  return DiagnosedSilenceableFailure::success();
}

void transform::PackOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::consumesHandle(getTargetMutable(), effects);
  transform::onlyReadsHandle(getPackedSizesMutable(), effects);
  transform::producesHandle(getOperation()->getOpResults(), effects);
  transform::modifiesPayload(effects);
}

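// A sketch of packing a 3-loop matmul with the op above (assembly format
// approximate; the sizes are placeholders, one per loop as enforced above):
//
//   %packed = transform.structured.pack %matmul packed_sizes = [8, 16, 32]
//       : (!transform.any_op) -> !transform.op<"linalg.generic">
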
//===---------------------------------------------------------------------===//
// PackGreedilyOp
//===---------------------------------------------------------------------===//

LogicalResult transform::PackGreedilyOp::verify() {
  if (!isPermutationVector(getMatmulInnerDimsOrder())) {
    return emitOpError() << getMatmulInnerDimsOrderAttrName()
                         << " is not a valid permutation";
  }
  // TODO: relax to allow empty once we have another strategy than just matmul.
  if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
    for (auto [s, nmo] :
         llvm::zip_equal(getMixedMatmulPackedSizes(),
                         getMatmulPaddedSizesNextMultipleOf())) {
      std::optional<int64_t> maybeStaticPackedSize = getConstantIntValue(s);
      if (nmo != 0 &&
          (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
        return emitOpError() << "at most one of the packed_size and the "
                                "padded_sizes_next_multiple_of can be nonzero "
                                "for the matmul strategy";
      }
    }
  }
  return success();
}

DiagnosedSilenceableFailure
PackGreedilyOp::apply(transform::TransformRewriter &rewriter,
                      transform::TransformResults &transformResults,
                      transform::TransformState &state) {
  SmallVector<Operation *> results;
  for (Operation *op : state.getPayloadOps(getTarget())) {
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      continue;
    // linalgOp will be replaced and the insertion point may be invalidated if
    // it is set before the op, so set it after instead.
    rewriter.setInsertionPointAfter(linalgOp);
    // Failing to pack greedily is perfectly fine.
    // In the future we will want to order packings according to some metric.
    FailureOr<PackResult> packResult = packMatmulGreedily(
        /*rewriter=*/rewriter,
        /*linalgOp=*/linalgOp,
        /*mnkPackedSizes=*/getMixedMatmulPackedSizes(),
        /*mnkPaddedSizesNextMultipleOf=*/
        getMatmulPaddedSizesNextMultipleOf(),
        /*mnkOrder=*/getMatmulInnerDimsOrder());
    if (succeeded(packResult)) {
      results.push_back(packResult->packedLinalgOp);
      continue;
    }
    results.push_back(linalgOp);
  }
  transformResults.set(cast<OpResult>(getPackedOp()), results);
  return DiagnosedSilenceableFailure::success();
}

SmallVector<OpFoldResult> PackGreedilyOp::getMixedMatmulPackedSizes() {
  Builder b(getContext());
  return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
                        b);
}

void transform::PackGreedilyOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::consumesHandle(getTargetMutable(), effects);
  transform::onlyReadsHandle(getMatmulPackedSizesMutable(), effects);
  transform::producesHandle(getOperation()->getOpResults(), effects);
  transform::modifiesPayload(effects);
}

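// A sketch for the greedy variant above (assembly format approximate). Note
// that, per the loop in apply(), payload ops that fail to pack are passed
// through to the result handle unchanged rather than reported as errors:
//
//   %packed = transform.structured.pack_greedily %func
//       matmul_packed_sizes = [8, 16, 32]
//       matmul_inner_dims_order = [0, 1, 2]
//       : (!transform.any_op) -> !transform.op<"linalg.generic">
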
//===---------------------------------------------------------------------===//
// PackTransposeOp
//===---------------------------------------------------------------------===//

LogicalResult transform::PackTransposeOp::verify() {
  if (!isPermutationVector(getInnerPerm())) {
    return emitOpError() << getInnerPermAttrName()
                         << " is not a valid permutation";
  }
  if (!isPermutationVector(getOuterPerm())) {
    return emitOpError() << getOuterPermAttrName()
                         << " is not a valid permutation";
  }
  if (getInnerPerm().empty() && getOuterPerm().empty()) {
    return emitOpError() << " at least one of " << getInnerPermAttrName()
                         << " or " << getOuterPermAttrName()
                         << " must be specified";
  }
  return success();
}

namespace {
enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
} // namespace

/// Return true if `permutation` is a valid permutation of the
/// `outer_dims_perm` (case OuterOrInnerPerm::Outer) or `inner_dims_pos`
/// (case OuterOrInnerPerm::Inner) of the `tensor.pack` or `tensor.unpack`
/// `op`.
/// This is the case when the `permutation` rank matches the rank expected by
/// `op` and `permutation` is itself a permutation vector.
/// Return true if either `op` or `permutation` is empty to allow a simpler
/// polymorphic implementation.
template <typename RelayoutOpTy>
bool isValidPackingPermutation(
    RelayoutOpTy op, ArrayRef<int64_t> permutation,
    OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
  static_assert(
      llvm::is_one_of<RelayoutOpTy, tensor::PackOp, tensor::UnPackOp>::value,
      "applies to only pack or unpack operations");
  if (!op || permutation.empty())
    return true;
  size_t innerRank = op.getInnerDimsPos().size();
  if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
    return permutation.size() == innerRank && isPermutationVector(permutation);
  // op.getOuterDimsPerm() may be empty, in which case it is identity.
  // Don't rely on it.
  if (std::is_same<RelayoutOpTy, tensor::PackOp>::value) {
    return permutation.size() == op.getSourceRank() &&
           isPermutationVector(permutation);
  }
  return permutation.size() == op.getDestRank() &&
         isPermutationVector(permutation);
}

DiagnosedSilenceableFailure
transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter,
                                  transform::TransformResults &transformResults,
                                  transform::TransformState &state) {
  auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
  auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
  // Step 1. If nothing to pack, propagate success.
  if (std::empty(packOrUnpackOps)) {
    transformResults.set(cast<OpResult>(getPackedOp()), {});
    transformResults.set(cast<OpResult>(getPackOp()), {});
    transformResults.set(cast<OpResult>(getUnPackOp()), {});
    return DiagnosedSilenceableFailure::success();
  }

  // Step 2. A bunch of runtime sanity checks and error messages.
  // Step 2.1. Fail on multi-op handles.
  if (!llvm::hasSingleElement(packOrUnpackOps) ||
      !llvm::hasSingleElement(linalgOps)) {
    return emitSilenceableError()
           << "requires target to map to exactly 1 "
              "packing op and 1 packed op ("
           << "got " << llvm::range_size(packOrUnpackOps) << " and "
           << llvm::range_size(linalgOps) << ")";
  }

  // Step 2.2. Fail on wrong type.
  auto packOp = dyn_cast<tensor::PackOp>(*packOrUnpackOps.begin());
  auto unPackOp = dyn_cast<tensor::UnPackOp>(*packOrUnpackOps.begin());
  if (!packOp && !unPackOp) {
    return emitSilenceableError() << "requires target to map to a "
                                     "tensor.pack or tensor.unpack";
  }
  LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
  if (!linalgOpTarget)
    return emitSilenceableError() << "requires a LinalgOp target";

  // Step 2.3. Fail if we can't get the producer / consumer Linalg op.
  LinalgOp linalgOp;
  if (packOp && packOp.getResult().hasOneUse())
    linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
  else if (unPackOp)
    linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
  if (linalgOp != linalgOpTarget) {
    auto errorMsg =
        packOp ? StringLiteral{"not a single use by the LinalgOp target"}
               : StringLiteral{"not produced by the LinalgOp target"};
    return emitSilenceableError() << errorMsg;
  }

  // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical
  // PackOp.
  if (unPackOp) {
    assert(!packOp && "packOp must be null on entry when unPackOp is not null");
    OpOperand *packUse = linalgOp.getDpsInitOperand(
        cast<OpResult>(unPackOp.getSource()).getResultNumber());
    packOp = dyn_cast_or_null<tensor::PackOp>(packUse->get().getDefiningOp());
    if (!packOp || !packOp.getResult().hasOneUse())
      return emitSilenceableError() << "could not find matching pack op";
  }

  // Step 2.5. Fail if any permutation does not validate.
  for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
    ArrayRef<int64_t> perm =
        (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
    auto errorMsg = (permType == OuterOrInnerPerm::Outer)
                        ? StringLiteral{"invalid outer_perm"}
                        : StringLiteral{"invalid inner_perm"};
    if (!isValidPackingPermutation(packOp, perm, permType) ||
        !isValidPackingPermutation(unPackOp, perm, permType)) {
      Operation *packOrUnpackOp =
          unPackOp ? unPackOp.getOperation() : packOp.getOperation();
      return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp;
    }
  }

  // From here on, packOp and linalgOp are always present, unPackOp may or may
  // not be present.
  assert(packOp && linalgOp && "unexpected null op");

  // Step 3. Actually transpose the ops.
  FailureOr<PackTransposeResult> res = packTranspose(
      rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
  // Preconditions have been checked, it is an error to fail here.
  assert(succeeded(res) && "unexpected packTranspose failure");

  // Step 4. Return results.
  transformResults.set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
  transformResults.set(cast<OpResult>(getPackedOp()),
                       {res->transposedLinalgOp});
  if (unPackOp) {
    transformResults.set(cast<OpResult>(getUnPackOp()),
                         {res->transposedUnPackOp});
  } else {
    transformResults.set(cast<OpResult>(getUnPackOp()), {});
  }

  return DiagnosedSilenceableFailure::success();
}

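// Note on the contract implemented above: the first handle must map to
// exactly one tensor.pack or tensor.unpack, the second handle to the single
// linalg op that consumes (resp. produces) it, and at least one of
// outer_perm/inner_perm must be a non-empty permutation. Given an unpack, the
// symmetric pack is recovered through the matching init operand so that all
// three ops can be transposed consistently.
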
//===---------------------------------------------------------------------===//
// PadOp
//===---------------------------------------------------------------------===//

void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
                             ArrayRef<int64_t> paddingDimensions,
                             ArrayRef<int64_t> padToMultipleOf,
                             ArrayRef<int64_t> nofoldFlags,
                             ArrayRef<Attribute> transposePaddings,
                             StringRef copyBackOp) {
  auto resultType = transform::AnyOpType::get(b.getContext());
  return build(/*builder=*/b,
               /*result=*/result,
               /*types=*/TypeRange{resultType, resultType, resultType},
               /*target=*/target,
               /*paddingValues=*/ArrayAttr(), // let inference handle this
               /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
               /*padToMultipleOf=*/ValueRange{},
               /*staticPadToMultipleOf=*/
               (padToMultipleOf.empty()
                    ? DenseI64ArrayAttr()
                    : b.getDenseI64ArrayAttr(padToMultipleOf)),
               /*nofoldFlags=*/b.getI64ArrayAttr(nofoldFlags),
               /*transposePaddings=*/b.getArrayAttr(transposePaddings),
               /*copyBackOp=*/b.getStringAttr(copyBackOp));
}

void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
                             ArrayRef<int64_t> paddingDimensions,
                             ArrayRef<OpFoldResult> mixedPadToMultipleOf,
                             ArrayRef<int64_t> nofoldFlags,
                             ArrayRef<Attribute> transposePaddings,
                             StringRef copyBackOp) {
  auto resultType = transform::AnyOpType::get(b.getContext());
  SmallVector<int64_t> staticPadToMultipleOf;
  SmallVector<Value> dynamicPadToMultipleOf;
  dispatchIndexOpFoldResults(mixedPadToMultipleOf, dynamicPadToMultipleOf,
                             staticPadToMultipleOf);
  return build(/*builder=*/b,
               /*result=*/result,
               /*types=*/TypeRange{resultType, resultType, resultType},
               /*target=*/target,
               /*paddingValues=*/ArrayAttr(), // let inference handle this
               /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
               /*padToMultipleOf=*/dynamicPadToMultipleOf,
               /*staticPadToMultipleOf=*/staticPadToMultipleOf,
               /*nofoldFlags=*/b.getI64ArrayAttr(nofoldFlags),
               /*transposePaddings=*/b.getArrayAttr(transposePaddings),
               /*copyBackOp=*/b.getStringAttr(copyBackOp));
}

void PadOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getTargetMutable(), effects);
  onlyReadsHandle(getPadToMultipleOfMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}

SmallVector<OpFoldResult> PadOp::getMixedPadToMultipleOf() {
  Builder b(getContext());
  return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
}

DiagnosedSilenceableFailure
transform::PadOp::apply(transform::TransformRewriter &rewriter,
                        transform::TransformResults &results,
                        transform::TransformState &state) {
  auto transformOp = cast<TransformOpInterface>(getOperation());
  SmallVector<Operation *> paddedOps, padOps, copyBackOps;

  for (Operation *target : state.getPayloadOps(getTarget())) {
    auto linalgTarget = dyn_cast<LinalgOp>(target);
    if (!linalgTarget) {
      auto diag = emitSilenceableError() << "expected LinalgOp target";
      diag.attachNote(target->getLoc()) << "target op";
      return diag;
    }

    // Convert the integer nofold flags to booleans.
    SmallVector<bool> nofoldFlags;
    for (int64_t nofoldFlag :
         extractFromIntegerArrayAttr<int64_t>(getNofoldFlags()))
      nofoldFlags.push_back(static_cast<bool>(nofoldFlag));

    // Convert the padding values to attributes.
    SmallVector<Attribute> paddingValues;
    for (auto const &it :
         llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
      auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
      if (!attr) {
        emitOpError("expects padding values to be typed attributes");
        return DiagnosedSilenceableFailure::definiteFailure();
      }
      Type elementType = getElementTypeOrSelf(std::get<1>(it));
      // Try to parse string attributes to obtain an attribute of element type.
      if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
        auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
            stringAttr, getContext(), elementType,
            /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
        if (!parsedAttr || parsedAttr.getType() != elementType) {
          auto diag = this->emitOpError("expects a padding that parses to ")
                      << elementType << ", got " << std::get<0>(it);
          diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
          return DiagnosedSilenceableFailure::definiteFailure();
        }
        paddingValues.push_back(parsedAttr);
        continue;
      }
      // Otherwise, add the attribute directly.
      if (attr.getType() != elementType) {
        auto diag = this->emitOpError("expects a padding value of type ")
                    << elementType << ", got " << attr;
        diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
        return DiagnosedSilenceableFailure::definiteFailure();
      }
      paddingValues.push_back(attr);
    }

    // Extract the transpose vectors.
    SmallVector<SmallVector<int64_t>> transposePaddings;
    for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
      transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
          cast<ArrayAttr>(transposeVector)));

    LinalgOp paddedOp;
    LinalgPaddingOptions options;
    options.paddingDimensions =
        extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());

    SmallVector<int64_t> padToMultipleOf;
    DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
        state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
    if (!status.succeeded())
      return status;
    if (padToMultipleOf.empty())
      padToMultipleOf =
          SmallVector<int64_t>(options.paddingDimensions.size(), 1);

    options.padToMultipleOf = padToMultipleOf;
    options.paddingValues = paddingValues;
    options.nofoldFlags = nofoldFlags;
    if (getCopyBackOp() ==
        bufferization::MaterializeInDestinationOp::getOperationName()) {
      options.copyBackOp = LinalgPaddingOptions::CopyBackOp::
          BufferizationMaterializeInDestination;
    } else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
      options.copyBackOp = LinalgPaddingOptions::CopyBackOp::LinalgCopy;
    } else if (getCopyBackOp() == kCopyOpNone) {
      options.copyBackOp = LinalgPaddingOptions::CopyBackOp::None;
    } else {
      llvm_unreachable("unsupported copy_back op");
    }

    SmallVector<Value> replacements;
    SmallVector<tensor::PadOp> newPadOps;
    if (failed(rewriteAsPaddedOp(rewriter, linalgTarget, options, paddedOp,
                                 replacements, newPadOps))) {
      auto diag = emitSilenceableError() << "failed to pad op";
      diag.attachNote(target->getLoc()) << "target op";
      return diag;
    }

    // We need to perform our own replacement here because this API is still
    // used in patterns that "pad and hoist", for which the replacement values
    // need to be different.
    // TODO: clean this up and stop "pad and hoist" behavior more globally now
    // that we have more composable abstractions.
    rewriter.replaceOp(linalgTarget, replacements);
    paddedOps.push_back(paddedOp);
    padOps.append(newPadOps.begin(), newPadOps.end());
    if (options.copyBackOp != LinalgPaddingOptions::CopyBackOp::None) {
      for (Value v : replacements) {
        Operation *copyBackOp = v.getDefiningOp();
        if (!llvm::is_contained(copyBackOps, copyBackOp))
          copyBackOps.push_back(copyBackOp);
      }
    }
  }

  results.set(cast<OpResult>(getPadded()), paddedOps);
  results.set(cast<OpResult>(getPad()), padOps);
  results.set(cast<OpResult>(getCopy()), copyBackOps);
  return DiagnosedSilenceableFailure::success();
}

LogicalResult transform::PadOp::verify() {
  SmallVector<int64_t> nofoldFlags =
      extractFromIntegerArrayAttr<int64_t>(getNofoldFlags());
  if (any_of(nofoldFlags, [](int64_t nofoldFlag) {
        return nofoldFlag != 0 && nofoldFlag != 1;
      })) {
    return emitOpError()
           << "expects nofold_flags to contain booleans (0/1), found "
           << getNofoldFlags();
  }

  SmallVector<int64_t> paddingDimensions =
      extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
  if (any_of(paddingDimensions,
             [](int64_t paddingDimension) { return paddingDimension < 0; })) {
    return emitOpError() << "expects padding_dimensions to contain "
                            "non-negative integers, found "
                         << getPaddingDimensions();
  }
  if (!getMixedPadToMultipleOf().empty()) {
    if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
      return emitOpError() << "expects as many multiples as padding_dimensions";
    }
  }
  ArrayAttr transposes = getTransposePaddings();
  for (Attribute attr : transposes) {
    SmallVector<int64_t> transpose = extractFromIntegerArrayAttr<int64_t>(attr);
    auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
    if (!std::is_permutation(sequence.begin(), sequence.end(),
                             transpose.begin(), transpose.end())) {
      return emitOpError()
             << "expects transpose_paddings to be a permutation, found "
             << attr;
    }
  }
  if (getCopyBackOp() !=
          bufferization::MaterializeInDestinationOp::getOperationName() &&
      getCopyBackOp() != linalg::CopyOp::getOperationName() &&
      getCopyBackOp() != kCopyOpNone)
    return emitOpError() << "invalid copy_back_op";
  return success();
}

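// A sketch of a typical padding request handled by the op above (assembly
// format approximate; values and sizes are placeholders): pad all three
// matmul dimensions to multiples of 16 and copy back via linalg.copy.
//
//   %padded, %pad, %copy = transform.structured.pad %matmul {
//       padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32],
//       padding_dimensions = [0, 1, 2],
//       pad_to_multiple_of = [16, 16, 16],
//       copy_back_op = "linalg.copy"
//     } : (!transform.any_op)
//       -> (!transform.any_op, !transform.any_op, !transform.any_op)
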
//===---------------------------------------------------------------------===//
// HoistPadOp
//===---------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &transformResults,
    transform::TransformState &state) {
  auto targetOps = state.getPayloadOps(getTarget());
  auto loopOps = state.getPayloadOps(getLoop());
  if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
    return emitDefiniteFailure()
           << "requires exactly one target and one loop handle (got "
           << llvm::range_size(targetOps) << " and "
           << llvm::range_size(loopOps) << ")";
  }

  auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
  auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
  if (!padOp || !loopOp)
    return emitDefiniteFailure() << "requires exactly 2 non-null handles";

  FailureOr<linalg::detail::PackingResult> result =
      linalg::detail::buildPackingLoopNest(rewriter, padOp, loopOp,
                                           getTranspose());
  if (failed(result))
    return emitDefiniteFailure() << "could not build packing loop nest";

  if (result->clonedLoopIvs.empty()) {
    transformResults.set(cast<OpResult>(getPackingLoop()),
                         {result->hoistedPadOp.getOperation()});
    return DiagnosedSilenceableFailure::success();
  }
  auto outerPackedLoop =
      scf::getForInductionVarOwner(result->clonedLoopIvs.front());
  transformResults.set(cast<OpResult>(getPackingLoop()),
                       {outerPackedLoop.getOperation()});
  return DiagnosedSilenceableFailure::success();
}

LogicalResult transform::HoistPadBuildPackingLoopNestOp::verify() {
  ArrayRef<int64_t> transpose = getTranspose();
  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
  if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
                           transpose.end())) {
    return emitOpError() << "expects transpose to be a permutation, found "
                         << getTranspose();
  }
  return success();
}

void transform::HoistPadBuildPackingLoopNestOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::onlyReadsHandle(getTargetMutable(), effects);
  transform::onlyReadsHandle(getLoopMutable(), effects);
  transform::producesHandle(getOperation()->getOpResults(), effects);
  transform::modifiesPayload(effects);
}

DiagnosedSilenceableFailure
transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter,
                                  tensor::PadOp target,
                                  transform::ApplyToEachResultList &results,
                                  transform::TransformState &state) {
  tensor::PadOp hoistedPadOp;
  SmallVector<TransposeOp> transposeOps;
  FailureOr<Value> result =
      hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(),
                            hoistedPadOp, transposeOps);
  if (succeeded(result)) {
    // We need to perform our own replacement here because this API is still
    // used in patterns that "pad and hoist", for which the replacement values
    // need to be different.
    // TODO: clean this up and stop "pad and hoist" behavior more globally now
    // that we have more composable abstractions.
    rewriter.replaceOp(target, *result);
    results.push_back(hoistedPadOp);
    return DiagnosedSilenceableFailure::success();
  }
  return emitDefaultSilenceableFailure(target);
}

LogicalResult transform::HoistPadOp::verify() {
  ArrayRef<int64_t> transpose = getTranspose();
  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
  if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
                           transpose.end())) {
    return emitOpError() << "expects transpose to be a permutation, found "
                         << getTranspose();
  }
  return success();
}

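// Note on the two hoisting entry points above: HoistPadBuildPackingLoopNestOp
// only materializes the packing loop nest and leaves the original tensor.pad
// in place, whereas HoistPadOp also performs the replacement (with the "pad
// and hoist" caveat documented inline). Both verifiers enforce the same
// property on `transpose`: it must be a permutation of 0..n-1, with the empty
// vector trivially accepted and meaning "no transposition".
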
//===----------------------------------------------------------------------===//
// PromoteOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,
                                 LinalgOp target,
                                 transform::ApplyToEachResultList &results,
                                 transform::TransformState &state) {
  LinalgPromotionOptions promotionOptions;
  if (!getOperandsToPromote().empty())
    promotionOptions = promotionOptions.setOperandsToPromote(
        extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
  if (getUseFullTilesByDefault())
    promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
        getUseFullTilesByDefault());
  if (getUseAlloca())
    promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
  if (!getUseFullTileBuffers().empty())
    promotionOptions = promotionOptions.setUseFullTileBuffers(
        llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
  if (getAlignment().has_value())
    promotionOptions = promotionOptions.setAlignment(*getAlignment());
  if (getMemorySpace().has_value())
    promotionOptions = promotionOptions.setMemorySpace(*getMemorySpace());

  if (getMapping().has_value()) {
    // The mapping should contain at most one element.
    auto mapping = *getMapping();
    if (mapping.size() > 1)
      return emitDefaultDefiniteFailure(target);

    auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);

    if (addressSpace.getAddressSpace() ==
        mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
      promotionOptions =
          promotionOptions
              .setAllocationDeallocationFns(allocateWorkgroupMemory,
                                            deallocateWorkgroupMemory)
              .setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory)
              .setUseFullTileBuffers({false, false});
    } else if (addressSpace.getAddressSpace() ==
               mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
      promotionOptions =
          promotionOptions
              .setAllocationDeallocationFns(allocateGPUPrivateMemory,
                                            deallocateGPUPrivateMemory)
              .setCopyInOutFns(copyToGPUPrivateMemory, copyToGPUPrivateMemory)
              .setUseFullTileBuffers({false, false});
    } else {
      return emitDefaultDefiniteFailure(target);
    }
  }

  if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
    return emitDefaultDefiniteFailure(target);

  rewriter.setInsertionPoint(target);
  FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
  if (failed(res))
    return emitDefaultDefiniteFailure(target);
  results.push_back(target);
  return DiagnosedSilenceableFailure::success();
}

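// A sketch of promoting the first two operands of a tiled op with the logic
// above (assembly format approximate; attribute names follow the getters):
//
//   %promoted = transform.structured.promote %tiled {
//       operands_to_promote = [0, 1], use_full_tiles_by_default
//     } : (!transform.any_op) -> !transform.any_op
//
// When a gpu memory-space mapping attribute is also supplied, promotion
// routes through the workgroup- or private-memory callbacks selected above.
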
//===----------------------------------------------------------------------===//
// ReplaceOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::ReplaceOp::apply(transform::TransformRewriter &rewriter,
                            TransformResults &transformResults,
                            TransformState &state) {
  auto payload = state.getPayloadOps(getTarget());

  // Check for invalid targets.
  for (Operation *target : payload) {
    if (target->getNumOperands() > 0)
      return emitDefiniteFailure() << "expected target without operands";
    if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
        target->getNumRegions() > 0)
      return emitDefiniteFailure()
             << "expected target that is isolated from above";
  }

  // Clone and replace.
  Operation *pattern = &getBodyRegion().front().front();
  SmallVector<Operation *> replacements;
  for (Operation *target : payload) {
    if (getOperation()->isAncestor(target))
      continue;
    rewriter.setInsertionPoint(target);
    Operation *replacement = rewriter.clone(*pattern);
    rewriter.replaceOp(target, replacement->getResults());
    replacements.push_back(replacement);
  }
  transformResults.set(cast<OpResult>(getReplacement()), replacements);
  return DiagnosedSilenceableFailure::success();
}

void transform::ReplaceOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getTargetMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}

LogicalResult transform::ReplaceOp::verify() {
  if (!getBodyRegion().hasOneBlock())
    return emitOpError() << "expected one block";
  if (std::distance(getBodyRegion().front().begin(),
                    getBodyRegion().front().end()) != 1)
    return emitOpError() << "expected one operation in block";
  Operation *replacement = &getBodyRegion().front().front();
  if (replacement->getNumOperands() > 0)
    return replacement->emitOpError()
           << "expected replacement without operands";
  if (!replacement->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
      replacement->getNumRegions() > 0)
    return replacement->emitOpError()
           << "expected replacement that is isolated from above";
  return success();
}

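// Note: the verifier above mirrors the runtime checks in apply(); both the
// targets and the single op in the body region must have no operands and be
// effectively isolated from above, because the replacement is a plain clone
// with no operand or region remapping.
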
//===----------------------------------------------------------------------===//
// ScalarizeOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
                                   LinalgOp target,
                                   transform::ApplyToEachResultList &results,
                                   transform::TransformState &state) {
  scf::SCFTilingOptions tilingOptions;
  tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) {
    SmallVector<OpFoldResult> tileSizes;
    Location loc = target.getLoc();
    SmallVector<OpFoldResult> allShapeSizes =
        target.createFlatListOfOperandDims(b, loc);
    AffineMap map = target.getShapesToLoopsMap();
    if (!map)
      return tileSizes;
    SmallVector<OpFoldResult> shapeSizes =
        affine::makeComposedFoldedMultiResultAffineApply(rewriter, loc, map,
                                                         allShapeSizes);
    // If the shape size is dynamic, tile by 1.
    // Otherwise, do not tile (i.e. tile size 0).
    for (OpFoldResult shapeSize : shapeSizes) {
      tileSizes.push_back(getConstantIntValue(shapeSize) ? b.getIndexAttr(0)
                                                         : b.getIndexAttr(1));
    }
    return tileSizes;
  });
  SmallVector<int64_t> emptyTileSizes;
  rewriter.setInsertionPoint(target);
  FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF(
      rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
  if (failed(maybeTilingResult))
    return emitDefaultDefiniteFailure(target);

  if (target->getNumResults())
    rewriter.replaceOp(target, maybeTilingResult->mergeResult.replacements);
  else
    rewriter.eraseOp(target);

  results.reserve(maybeTilingResult->tiledOps.size());
  for (Operation *tiled : maybeTilingResult->tiledOps)
    results.push_back(tiled);
  return DiagnosedSilenceableFailure::success();
}

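// Note on the tile-size callback above: statically sized dimensions get tile
// size 0 (left untiled), while dynamic dimensions get tile size 1, turning
// each of them into a unit-step loop and thereby "scalarizing" the op.
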
//===----------------------------------------------------------------------===//
// ConvertToLoopsOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,
                                   transform::TransformResults &results,
                                   transform::TransformState &state) {
  SmallVector<Operation *> loops;
  for (Operation *target : state.getPayloadOps(getTarget())) {
    auto tilingOp = dyn_cast<TilingInterface>(*target);
    if (!tilingOp) {
      DiagnosedSilenceableFailure diag =
          emitSilenceableError()
          << "expected the payload to implement TilingInterface";
      diag.attachNote(target->getLoc()) << "payload op";
      return diag;
    }
    rewriter.setInsertionPoint(target);
    FailureOr<SmallVector<scf::ForOp>> generatedLoops =
        scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);
    if (failed(generatedLoops))
      return emitDefaultDefiniteFailure(target);
    for (scf::ForOp &loop : *generatedLoops) {
      loops.push_back(loop.getOperation());
    }
    rewriter.eraseOp(target);
  }
  results.set(cast<OpResult>(getResult()), loops);
  return DiagnosedSilenceableFailure::success();
}

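// A sketch of lowering a tileable op to explicit loops with the op above
// (assembly format approximate); all generated scf.for ops are returned
// through a single handle:
//
//   %loops = transform.structured.convert_to_loops %matmul
//       : (!transform.any_op) -> !transform.any_op
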
//===----------------------------------------------------------------------===//
// RewriteInDestinationPassingStyleOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::RewriteInDestinationPassingStyleOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  SmallVector<Operation *> res;
  rewriter.setInsertionPoint(target);
  FailureOr<Operation *> maybeResult =
      TypeSwitch<Operation *, FailureOr<Operation *>>(target)
          .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
              [&rewriter](auto op) {
                return rewriteInDestinationPassingStyle(rewriter, op);
              });
  if (failed(maybeResult))
    return emitDefaultSilenceableFailure(target);
  results.push_back(*maybeResult);
  return DiagnosedSilenceableFailure::success();
}

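// Note: only tensor.from_elements, tensor.generate, and tensor.pad are
// handled by the TypeSwitch above; a failed rewrite surfaces as the default
// silenceable failure rather than a definite error, so enclosing transform
// sequences can recover from it.
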
2291 //===----------------------------------------------------------------------===//
2292 // SplitOp
2293 //===----------------------------------------------------------------------===//
2294 
2295 DiagnosedSilenceableFailure
2296 SplitOp::apply(transform::TransformRewriter &rewriter,
2297                TransformResults &results, TransformState &state) {
2298   // Collect the dynamic split points if provided.
2299   SmallVector<Operation *> payload =
2300       llvm::to_vector(state.getPayloadOps(getTarget()));
2301 
2302   bool isMultiwaySplit = getMultiway();
2303 
2304   if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
2305     return mlir::emitSilenceableFailure(getLoc())
2306            << "requires exactly one target when "
2307               "multiway split is enabled (got "
2308            << llvm::range_size(payload) << ")";
2309   }
2310 
2311   SmallVector<OpFoldResult> chunkSizes;
2312 
2313   if (!isMultiwaySplit)
2314     chunkSizes.reserve(payload.size());
2315 
2316   if (getDynamicChunkSizes()) {
2317     auto diag = DiagnosedSilenceableFailure::success();
2318     if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().getType())) {
2319       chunkSizes = llvm::to_vector(llvm::map_range(
2320           state.getPayloadOps(getDynamicChunkSizes()), [&](Operation *op) {
2321             if (op->getNumResults() != 1 ||
2322                 !op->getResult(0).getType().isIndex()) {
2323               diag = emitSilenceableError()
2324                      << "expected dynamic split point handle to point to a "
2325                         "single-result index-typed op";
2326               diag.attachNote(op->getLoc()) << "dynamic split point";
2327             }
2328             return OpFoldResult(op->getResult(0));
2329           }));
2330     } else {
2331       chunkSizes = llvm::to_vector(
2332           llvm::map_range(state.getParams(getDynamicChunkSizes()),
2333                           [](Attribute attr) { return OpFoldResult(attr); }));
2334     }
2335     if (diag.isSilenceableFailure())
2336       return diag;
2337 
2338     // For multiway split, a single payload is expected to have multiple
2339     // split points.
2340     if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
2341       return emitDefiniteFailure()
2342              << "expected the dynamic split point handle to point to as "
2343                 "many operations ("
2344              << chunkSizes.size() << ") as the target handle ("
2345              << payload.size() << ")";
2346     }
2347   } else {
2348     chunkSizes.resize(payload.size(),
2349                       rewriter.getIndexAttr(getStaticChunkSizes()));
2350   }
2351 
2352   auto checkStructuredOpAndDimensions =
2353       [&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure {
2354     if (!linalgOp) {
2355       auto diag = emitSilenceableError() << "only applies to structured ops";
2356       diag.attachNote(loc) << "target op";
2357       return diag;
2358     }
2359 
2360     if (getDimension() >= linalgOp.getNumLoops()) {
2361       auto diag = emitSilenceableError() << "dimension " << getDimension()
2362                                          << " does not exist in target op";
2363       diag.attachNote(loc) << "target op";
2364       return diag;
2365     }
2366     return DiagnosedSilenceableFailure::success();
2367   };
2368 
2369   auto checkFailureInSplitting =
2370       [&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
2371     if (hasFailed) {
2372       auto diag = emitDefiniteFailure() << "internal failure in splitting";
2373       diag.attachNote(loc) << "target op";
2374       return diag;
2375     }
2376     return DiagnosedSilenceableFailure::success();
2377   };
2378 
2379   SmallVector<Operation *> opList;
2380   if (isMultiwaySplit) {
2381 
2382     // Split a single target operation at multiple points.
2383     TilingInterface head, tail;
2384     Operation *target = payload.front();
2385 
2386     LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2387 
2388     // Check that the target is a valid LinalgOp with correct dimensions.
2389     DiagnosedSilenceableFailure diag =
2390         checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2391     if (diag.isSilenceableFailure())
2392       return diag;
2393 
2394     for (auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {
2395 
2396       if (idx > 0)
2397         target = tail.getOperation();
2398 
2399       if (!target)
2400         break;
2401 
2402       linalgOp = cast<LinalgOp>(target);
2403       Location loc = target->getLoc();
2404 
2405       rewriter.setInsertionPoint(linalgOp);
2406       std::tie(head, tail) = linalg::splitOp(
2407           rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2408           getDimension(), chunkSize);
2409 
2410       // Propagate errors.
2411       DiagnosedSilenceableFailure diag =
2412           checkFailureInSplitting(!head && !tail, loc);
2413       if (diag.isDefiniteFailure())
2414         return diag;
2415 
2416       opList.push_back(head.getOperation());
2417     }
2418 
2419     // Append any leftover parts to the end of the result list.
2420     if (tail)
2421       opList.push_back(tail.getOperation());
2422 
2423   } else {
2424     // Split each target operation.
2425     SmallVector<Operation *> first, second;
2426     Operation *noSecondPart = nullptr;
2427     for (const auto &pair : llvm::zip(payload, chunkSizes)) {
2428       Operation *target = std::get<0>(pair);
2429       Location loc = target->getLoc();
2430       LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2431       DiagnosedSilenceableFailure diag =
2432           checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2433 
2434       if (diag.isSilenceableFailure())
2435         return diag;
2436 
2437       rewriter.setInsertionPoint(linalgOp);
2438       std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
2439           rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2440           getDimension(), std::get<1>(pair));
2441 
2442       // Propagate errors.
2443       DiagnosedSilenceableFailure diagSplit =
2444           checkFailureInSplitting(!first.back() && !second.back(), loc);
2445       if (diagSplit.isDefiniteFailure())
2446         return diag;
2447 
2448       // Do not add null second parts.
2449       if (!second.back()) {
2450         noSecondPart = target;
2451         second.pop_back();
2452       }
2453     }
2454 
2455     if (second.size() != first.size() && !second.empty()) {
2456       auto diag = emitSilenceableError()
2457                   << "splitting does not produce the second part for a subset "
2458                      "of targets";
2459       diag.attachNote()
2460           << "expected splitting to produce the second part of all "
2461              "or none of the targets";
2462       diag.attachNote(noSecondPart->getLoc())
2463           << "first target with no second part";
2464       return diag;
2465     }
2466 
2467     opList.append(first);
2468     if (!second.empty())
2469       opList.append(second);
2470   }
2471   results.set(cast<OpResult>(getSplitList()), opList);
2472   return DiagnosedSilenceableFailure::success();
2473 }
2474 
2475 void SplitOp::getEffects(
2476     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2477   consumesHandle(getTargetMutable(), effects);
2478   if (getDynamicChunkSizes())
2479     onlyReadsHandle(getDynamicChunkSizesMutable(), effects);
2480   producesHandle(getOperation()->getOpResults(), effects);
2481   modifiesPayload(effects);
2482 }
2483 
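// The parser/printer below implement, roughly, the format
// `$target after ($static_chunk_sizes | $dynamic_chunk_sizes) attr-dict
//  : type($target) (, type($dynamic_chunk_sizes))`.
// As an illustrative sketch (mnemonic and attribute spelling assumed here,
// not taken from the ODS definition):
//
//   %1 = transform.structured.split %0 after 16 {dimension = 0}
//       : !transform.any_op
//
// or, with a dynamic chunk size handle %sz:
//
//   %1 = transform.structured.split %0 after %sz {dimension = 0}
//       : !transform.any_op, !transform.param<i64>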
2484 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
2485   OpAsmParser::UnresolvedOperand target, dynamicChunkSizes;
2486   IntegerAttr staticChunkSizes;
2487   if (parser.parseOperand(target) || parser.parseKeyword("after"))
2488     return failure();
2489 
2490   OptionalParseResult dynamicPointParseResult =
2491       parser.parseOptionalOperand(dynamicChunkSizes);
2492   if (!dynamicPointParseResult.has_value()) {
2493     int64_t staticChunkSizesValue;
2494     if (failed(parser.parseInteger(staticChunkSizesValue)))
2495       return failure();
2496 
2497     staticChunkSizes =
2498         parser.getBuilder().getI64IntegerAttr(staticChunkSizesValue);
2499   }
2500 
2501   Type targetType;
2502   if (parser.parseOptionalAttrDict(result.attributes) ||
2503       parser.parseColonType(targetType) ||
2504       parser.resolveOperand(target, targetType, result.operands)) {
2505     return failure();
2506   }
2507   if (dynamicPointParseResult.has_value()) {
2508     Type chunkSizesType;
2509     if (failed(*dynamicPointParseResult) || parser.parseComma() ||
2510         parser.parseType(chunkSizesType) ||
2511         parser.resolveOperand(dynamicChunkSizes, chunkSizesType,
2512                               result.operands)) {
2513       return failure();
2514     }
2515 
2516     staticChunkSizes =
2517         parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic);
2518   }
2519 
2520   result.addAttribute(
2521       SplitOp::getStaticChunkSizesAttrName(result.name).getValue(),
2522       staticChunkSizes);
2523   result.addTypes(targetType);
2524   return success();
2525 }
2526 
2527 void SplitOp::print(OpAsmPrinter &printer) {
2528   printer << " " << getTarget() << " after ";
2529   int64_t staticChunkSize = static_cast<int64_t>(getStaticChunkSizes());
2530   if (staticChunkSize != ShapedType::kDynamic)
2531     printer << staticChunkSize;
2532   else
2533     printer << getDynamicChunkSizes();
2534   printer << " ";
2535   printer.printOptionalAttrDict(getOperation()->getAttrs(),
2536                                 {getStaticChunkSizesAttrName()});
2537   printer << " : " << getTarget().getType();
2538   if (staticChunkSize == ShapedType::kDynamic)
2539     printer << ", " << getDynamicChunkSizes().getType();
2540 }
2541 
2542 LogicalResult SplitOp::verify() {
2543   if ((static_cast<int64_t>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
2544       (getDynamicChunkSizes() == nullptr)) {
2545     return emitOpError() << "expects either a dynamic or a static chunk "
2546                             "size to be provided";
2547   }
2548   return success();
2549 }
2550 
2551 //===----------------------------------------------------------------------===//
2552 // SplitReductionOp
2553 //===----------------------------------------------------------------------===//
2554 
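// Hedged usage sketch (assembly spelling assumed from the accessor names
// below, not verified against the ODS definition). The four results are
// handles to the init-or-alloc op, the fill op, the split linalg op and the
// result-combining linalg op:
//
//   %init, %fill, %split, %combine = transform.structured.split_reduction %0
//       {split_factor = 4, insert_split_dimension = 2}
//       : (!transform.any_op) -> (!transform.any_op, !transform.any_op,
//                                 !transform.any_op, !transform.any_op)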
2555 void transform::SplitReductionOp::build(
2556     OpBuilder &builder, OperationState &result, Value target,
2557     int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel,
2558     bool useScalingAlgorithm, bool useAlloc) {
2559   MLIRContext *ctx = builder.getContext();
2560   result.addOperands(target);
2561   result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name),
2562                       builder.getI64IntegerAttr(splitFactor));
2563   result.addAttribute(
2564       SplitReductionOp::getInsertSplitDimensionAttrName(result.name),
2565       builder.getI64IntegerAttr(insertSplitDimension));
2566   if (innerParallel) {
2567     result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name),
2568                         builder.getUnitAttr());
2569   }
2570   if (useScalingAlgorithm) {
2571     result.addAttribute(
2572         SplitReductionOp::getUseScalingAlgorithmAttrName(result.name),
2573         builder.getUnitAttr());
2574   }
2575   if (useAlloc) {
2576     result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name),
2577                         builder.getUnitAttr());
2578   }
2579   auto resultType = transform::AnyOpType::get(ctx);
2580   result.addTypes({resultType, resultType, resultType, resultType});
2581 }
2582 
2583 DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
2584     transform::TransformRewriter &rewriter, LinalgOp target,
2585     transform::ApplyToEachResultList &results,
2586     transform::TransformState &state) {
2587   ControlSplitReductionFn splitFn = [&](LinalgOp) {
2588     return linalg::SplitReductionOptions{int64_t(getSplitFactor()),
2589                                          unsigned(getInsertSplitDimension()),
2590                                          bool(getInnerParallel())};
2591   };
2592   rewriter.setInsertionPoint(target);
2593   FailureOr<SplitReductionResult> splitResult =
2594       (getUseScalingAlgorithm())
2595           ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
2596           : splitReduction(rewriter, target, splitFn, getUseAlloc());
2597   if (failed(splitResult))
2598     return emitDefaultDefiniteFailure(target);
2599 
2600   results.push_back(splitResult->initOrAlloc);
2601   results.push_back(splitResult->fillOp);
2602   results.push_back(splitResult->splitLinalgOp);
2603   results.push_back(splitResult->resultCombiningLinalgOp);
2604   return DiagnosedSilenceableFailure::success();
2605 }
2606 
2607 //===----------------------------------------------------------------------===//
2608 // TileReductionUsingForOp
2609 //===----------------------------------------------------------------------===//
2610 
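// Hedged usage sketch (assembly spelling assumed, not taken from the ODS
// definition): tiling the reduction dimension of a matmul-like op by 8 could
// look like
//
//   %fill, %split, %combine, %for =
//       transform.structured.tile_reduction_using_for %0 by tile_sizes = [0, 8]
//       : (!transform.any_op) -> (!transform.any_op, !transform.any_op,
//                                 !transform.any_op, !transform.any_op)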
2611 void transform::TileReductionUsingForOp::build(
2612     OpBuilder &builder, OperationState &result, Value target,
2613     ArrayRef<int64_t> staticTileSizes) {
2614   // Call the default builder.
2615   // This is future-proof with respect to mixed static-dynamic sizes and sets
2616   // up the proper operand segment size attributes for multiple variadic
2617   // operands. In the absence of this, horrible bugs ensue.
2618   // TODO: support mixed static-dynamic (see TileUsingForallOp).
2619   MLIRContext *ctx = builder.getContext();
2620   auto opTy = transform::AnyOpType::get(ctx);
2621   auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2622   build(builder, result,
2623         /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
2624         /*target=*/target,
2625         /*tile_sizes=*/staticTileSizesAttr);
2626 }
2627 
2628 DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
2629     transform::TransformRewriter &rewriter, Operation *target,
2630     transform::ApplyToEachResultList &results,
2631     transform::TransformState &state) {
2632   rewriter.setInsertionPoint(target);
2633 
2634   auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
2635   if (!partialReductionOp) {
2636     return emitSilenceableFailure(
2637         target->getLoc(),
2638         "Operation should implement PartialReductionOpInterface");
2639   }
2640   FailureOr<scf::SCFTilingResult> result = scf::tileReductionUsingScf(
2641       rewriter, partialReductionOp,
2642       getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
2643 
2644   if (failed(result))
2645     return emitDefaultSilenceableFailure(target);
2646   rewriter.replaceOp(target, result->mergeResult.replacements);
2647   for (Value initValue : result->initialValues)
2648     results.push_back(initValue.getDefiningOp());
2649   for (auto parallelTiledOp : result->tiledOps)
2650     results.push_back(parallelTiledOp);
2651   for (auto mergeOp : result->mergeResult.mergeOps)
2652     results.push_back(mergeOp);
2653   results.push_back(result->loops.front());
2654   return DiagnosedSilenceableFailure::success();
2655 }
2656 
2657 //===----------------------------------------------------------------------===//
2658 // TileReductionUsingForallOp
2659 //===----------------------------------------------------------------------===//
2660 
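// Hedged usage sketch (assembly spelling assumed, not taken from the ODS
// definition): distributing the partial reductions over 8 threads could look
// like
//
//   %fill, %split, %combine, %forall =
//       transform.structured.tile_reduction_using_forall %0
//       by num_threads = [0, 8], tile_sizes = []
//       : (!transform.any_op) -> (!transform.any_op, !transform.any_op,
//                                 !transform.any_op, !transform.any_op)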
2661 void transform::TileReductionUsingForallOp::build(
2662     OpBuilder &builder, OperationState &result, Value target,
2663     ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes,
2664     ArrayAttr mapping) {
2665   // Call the default builder.
2666   // This is future-proof with respect to mixed static-dynamic sizes and sets
2667   // up the proper operand segment size attributes for multiple variadic
2668   // operands. In the absence of this, horrible bugs ensue.
2669   // TODO: support mixed static-dynamic (see TileUsingForallOp).
2670   MLIRContext *ctx = builder.getContext();
2671   auto opTy = transform::AnyOpType::get(ctx);
2672   auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
2673   auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2674   build(builder, result,
2675         /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
2676         /*target=*/target,
2677         /*num_threads=*/staticNumThreadsAttr,
2678         /*tile_sizes=*/staticTileSizesAttr,
2679         /*mapping=*/mapping);
2680 }
2681 
2682 DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
2683     transform::TransformRewriter &rewriter, LinalgOp target,
2684     transform::ApplyToEachResultList &results,
2685     transform::TransformState &state) {
2686   rewriter.setInsertionPoint(target);
2687   SmallVector<OpFoldResult> numThreads =
2688       getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
2689   SmallVector<OpFoldResult> tileSizes =
2690       getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes()));
2691   FailureOr<linalg::ForallReductionTilingResult> result =
2692       linalg::tileReductionUsingForall(
2693           rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
2694           numThreads, tileSizes, getMapping());
2695 
2696   if (failed(result)) {
2697     auto diag = emitSilenceableError() << "could not tile reduction";
2698     diag.attachNote(target.getLoc()) << "target operation";
2699     return diag;
2700   }
2701   for (Value initValue : result->initialValues)
2702     results.push_back(initValue.getDefiningOp());
2703   for (auto parallelTiledOp : result->parallelTiledOps)
2704     results.push_back(parallelTiledOp);
2705   for (auto mergeOp : result->mergeOps)
2706     results.push_back(mergeOp);
2707   results.push_back(result->loops);
2708   return DiagnosedSilenceableFailure::success();
2709 }
2710 
2711 //===----------------------------------------------------------------------===//
2712 // ContinuousTileSizesOp
2713 //===----------------------------------------------------------------------===//
2714 
2715 DiagnosedSilenceableFailure
2716 transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
2717                                         TransformResults &transformResults,
2718                                         TransformState &state) {
2719 
2720   SmallVector<Operation *> targetOps =
2721       llvm::to_vector(state.getPayloadOps(getTarget()));
2722 
2723   if (!llvm::hasSingleElement(targetOps)) {
2724     return mlir::emitSilenceableFailure(getLoc())
2725            << "requires exactly one target (got " << llvm::range_size(targetOps)
2726            << ")";
2727   }
2728 
2729   Operation *target = *targetOps.begin();
2730   auto linalgOp = dyn_cast<LinalgOp>(target);
2731   auto tileableOp = dyn_cast<TilingInterface>(target);
2732 
2733   if (!linalgOp)
2734     return emitDefiniteFailure() << "expected Linalg Op";
2735 
2736   OpBuilder builder(linalgOp.getContext());
2737 
2738   if (isa<TransformParamTypeInterface>(getChunkSizes().getType())) {
2739     if (linalgOp.hasDynamicShape()) {
2740       auto diag = emitSilenceableError()
2741                   << "cannot compute parametric tile sizes for dynamically "
2742                      "shaped payload op";
2743       diag.attachNote(linalgOp->getLoc()) << "payload op";
2744       return diag;
2745     }
2746 
2747     FailureOr<StaticContinuousTileSizeSpecification> spec =
2748         computeStaticContinuousTileSizes(linalgOp, getDimension(),
2749                                          getTargetSize());
2750     if (failed(spec)) {
2751       return emitSilenceableError()
2752              << "failed to compute multi-size tiling sizes";
2753     }
2754 
2755     SmallVector<int64_t> chunkSizes;
2756 
2757     for (auto &&[tileSize, tripCount] :
2758          llvm::zip_equal(spec->tileSizes, spec->tripCounts))
2759       chunkSizes.push_back(tileSize * tripCount);
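    // Illustrative (hypothetical) numbers: tileSizes = {4, 2, 1} with
    // tripCounts = {2, 1, 1} yields chunkSizes = {8, 2, 1}, i.e. each chunk
    // is the extent of the sub-range that is covered with one tile size.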
2760 
2761     auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
2762       return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
2763         return builder.getI64IntegerAttr(value);
2764       });
2765     };
2766     transformResults.setParams(cast<OpResult>(getTileSizes()),
2767                                getI64AttrsFromI64(spec->tileSizes));
2768     transformResults.setParams(cast<OpResult>(getChunkSizes()),
2769                                getI64AttrsFromI64(chunkSizes));
2770 
2771     return DiagnosedSilenceableFailure::success();
2772   }
2773 
2774   builder.setInsertionPoint(linalgOp);
2775 
2776   OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
2777   unsigned dimension = getDimension();
2778 
2779   FailureOr<ContinuousTileSizeSpecification> spec = computeContinuousTileSizes(
2780       builder, tileableOp, dimension, targetSize, true);
2781   if (failed(spec)) {
2782     return emitSilenceableError() << "could not generate tile size computation";
2783   }
2784 
2785   AffineExpr s0 = builder.getAffineSymbolExpr(0);
2786   AffineExpr s1 = builder.getAffineSymbolExpr(1);
2787   auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
2788     return affine::makeComposedAffineApply(builder, linalgOp->getLoc(), expr,
2789                                            ofrs);
2790   };
2791 
2792   SmallVector<Value> chunkSizes;
2793   for (auto &&[tileSize, tripCount] :
2794        llvm::zip_equal(spec->tileSizes, spec->tripCounts))
2795     chunkSizes.push_back(apply(s0 * s1, {tileSize, tripCount}));
2799 
2800   auto getDefiningOps = [&](ArrayRef<Value> values) {
2801     return llvm::map_to_vector(values, [&](Value value) -> Operation * {
2802       return value.getDefiningOp();
2803     });
2804   };
2805 
2806   transformResults.set(cast<OpResult>(getTileSizes()),
2807                        getDefiningOps(spec->tileSizes));
2808   transformResults.set(cast<OpResult>(getChunkSizes()),
2809                        getDefiningOps(chunkSizes));
2810 
2811   return DiagnosedSilenceableFailure::success();
2812 }
2813 
2814 LogicalResult transform::ContinuousTileSizesOp::verify() {
2815 
2816   if (getTileSizes().getType() != getChunkSizes().getType()) {
2817     return emitOpError() << "expects all result types to be the same";
2818   }
2819 
2820   return success();
2821 }
2822 
2823 void transform::ContinuousTileSizesOp::getEffects(
2824     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2825   if (isa<TransformParamTypeInterface>(getTileSizes().getType()))
2826     onlyReadsPayload(effects);
2827   else
2828     modifiesPayload(effects);
2829   onlyReadsHandle(getTargetMutable(), effects);
2830   producesHandle(getOperation()->getOpResults(), effects);
2831 }
2832 
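// The helpers below print/parse the trailing functional type of
// ContinuousTileSizesOp, which has one input and one result; both result
// handles share the single result type. E.g. (a syntax sketch):
//
//   ... : (!transform.any_op) -> !transform.param<i64>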
2833 static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op,
2834                                          Type targetType, Type tileSizesType,
2835                                          Type) {
2836   printer.printFunctionalType(TypeRange{targetType}, TypeRange{tileSizesType});
2837 }
2838 
2839 static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser,
2840                                                 Type &targetType,
2841                                                 Type &tileSizesType,
2842                                                 Type &chunkSizesType) {
2843   FunctionType funcType;
2844   llvm::SMLoc typeLoc = parser.getCurrentLocation();
2845   if (failed(parser.parseType<FunctionType>(funcType)))
2846     return failure();
2847 
2848   if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
2849     return parser.emitError(typeLoc) << "expects a trailing functional type "
2850                                         "with one argument and one result";
2851   }
2852   targetType = funcType.getInput(0);
2853   tileSizesType = chunkSizesType = funcType.getResult(0);
2854 
2855   return success();
2856 }
2857 
2858 //===----------------------------------------------------------------------===//
2859 // TileUsingForOp
2860 //===----------------------------------------------------------------------===//
2861 
2862 void transform::TileUsingForOp::build(
2863     OpBuilder &builder, OperationState &result, TypeRange loopTypes,
2864     Value target, ArrayRef<int64_t> staticTileSizes,
2865     ArrayRef<int64_t> interchange,
2866     std::optional<ArrayRef<bool>> scalableSizes) {
2867   return build(builder, result, loopTypes,
2868                /*target=*/target,
2869                /*mixedTileSizes=*/
2870                getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
2871                interchange, scalableSizes);
2872 }
2873 
2874 void transform::TileUsingForOp::build(
2875     OpBuilder &builder, OperationState &result, Value target,
2876     ArrayRef<int64_t> staticTileSizes, ArrayRef<int64_t> interchange,
2877     std::optional<ArrayRef<bool>> scalableSizes) {
2878   build(builder, result, target,
2879         getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
2880         interchange, scalableSizes);
2881 }
2882 
2883 void transform::TileUsingForOp::build(
2884     OpBuilder &builder, OperationState &result, Value target,
2885     ArrayRef<OpFoldResult> mixedTileSizes, ArrayRef<int64_t> interchange,
2886     std::optional<ArrayRef<bool>> scalableSizes) {
2887   // Loop types are automatically splat by the callee, so setting up one is
2888   // enough.
2889   SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
2890   build(builder, result, loopTypes, target, mixedTileSizes, interchange,
2891         scalableSizes);
2892 }
2893 
2894 void transform::TileUsingForOp::build(
2895     OpBuilder &builder, OperationState &result, TypeRange loopTypes,
2896     Value target, ArrayRef<OpFoldResult> mixedTileSizes,
2897     ArrayRef<int64_t> interchange,
2898     std::optional<ArrayRef<bool>> scalableSizes) {
2899   SmallVector<int64_t> staticTileSizes;
2900   SmallVector<Value> dynamicTileSizes;
2901   dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
2902   // Call the default builder, which sets up the proper operand segment size
2903   // attributes for multiple variadic operands. In the absence of this,
2904   // horrible bugs ensue.
2905   auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2906   unsigned numExpectedLoops =
2907       staticTileSizes.size() - llvm::count(staticTileSizes, 0);
2908   SmallVector<Type> resultTypes;
2909   resultTypes.reserve(numExpectedLoops);
2910   assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
2911          "expected one loop type or as many as loops");
2912   if (loopTypes.size() == 1)
2913     resultTypes.append(numExpectedLoops, loopTypes[0]);
2914   else
2915     llvm::append_range(resultTypes, loopTypes);
2916   SmallVector<bool> expandedScalableSizes(mixedTileSizes.size(), false);
2917   if (scalableSizes.has_value())
2918     expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
2919   build(builder, result, /*tiled_linalg_op=*/target.getType(),
2920         /*loops=*/resultTypes,
2921         /*target=*/target,
2922         /*dynamic_sizes=*/dynamicTileSizes,
2923         /*static_sizes=*/staticTileSizesAttr,
2924         /*interchange=*/builder.getDenseI64ArrayAttr(interchange),
2925         /*scalable_sizes=*/expandedScalableSizes);
2926 }
2927 
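// Note that a static tile size of 0 means the corresponding loop is not
// tiled and thus produces no loop handle: e.g. static sizes [4, 0, 8] are
// expected to come with exactly two `loops` results.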
2928 LogicalResult transform::TileUsingForOp::verify() {
2929   if (getMixedSizes().size() != getScalableSizes().size())
2930     return emitOpError("expected same number of sizes (")
2931            << getMixedSizes().size() << ") and scalable sizes ("
2932            << getScalableSizes().size() << ")";
2933   ArrayRef<int64_t> staticSizes = getStaticSizes();
2934   unsigned numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);
2935   if (getLoops().size() != numExpectedLoops)
2936     return emitOpError("expected number of loops to tile (")
2937            << numExpectedLoops << ") to match number of `loops` results ("
2938            << getLoops().size() << ")";
2939   return success();
2940 }
2941 
2942 DiagnosedSilenceableFailure
2943 transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
2944                                  TransformResults &transformResults,
2945                                  TransformState &state) {
2946   ArrayRef<int64_t> tileSizes = getStaticSizes();
2947 
2948   SmallVector<Operation *> targets =
2949       llvm::to_vector(state.getPayloadOps(getTarget()));
2950   SmallVector<SmallVector<Operation *>> dynamicSizeProducers;
2951   SmallVector<SmallVector<int64_t>> paramSizes;
2952   dynamicSizeProducers.reserve(getDynamicSizes().size());
2953   paramSizes.reserve(getDynamicSizes().size());
2954   for (Value transformValue : getDynamicSizes()) {
2955     if (isa<ParamType>(transformValue.getType())) {
2956       dynamicSizeProducers.push_back({});
2957       ArrayRef<Attribute> params = state.getParams(transformValue);
2958       paramSizes.push_back(
2959           llvm::to_vector(llvm::map_range(params, [](Attribute attr) {
2960             return cast<IntegerAttr>(attr).getValue().getSExtValue();
2961           })));
2962 
2963       if (paramSizes.back().size() != targets.size()) {
2964         DiagnosedSilenceableFailure diag =
2965             emitSilenceableError()
2966             << "expected as many parameter values ("
2967             << paramSizes.back().size() << ") as target ops ("
2968             << targets.size() << ")";
2969         diag.attachNote(transformValue.getLoc()) << "for this parameter";
2970         return diag;
2971       }
2972 
2973       continue;
2974     }
2975     paramSizes.push_back({});
2976     dynamicSizeProducers.push_back(
2977         llvm::to_vector(state.getPayloadOps(transformValue)));
2978 
2979     if (dynamicSizeProducers.back().size() != targets.size()) {
2980       DiagnosedSilenceableFailure diag =
2981           emitSilenceableError()
2982           << "expected as many dynamic size-producing operations ("
2983           << dynamicSizeProducers.back().size() << ") as target ops ("
2984           << targets.size() << ")";
2985       diag.attachNote(transformValue.getLoc()) << "for this handle";
2986       return diag;
2987     }
2988 
2989     for (Operation *op : dynamicSizeProducers.back()) {
2990       if (op->getNumResults() == 1 &&
2991           isa<IndexType>(op->getResult(0).getType())) {
2992         continue;
2993       }
2994 
2995       DiagnosedSilenceableFailure diag =
2996           emitSilenceableError() << "expected sizes to be produced by ops "
2997                                     "with a single index-type result";
2998       diag.attachNote(op->getLoc()) << "size producer op";
2999       diag.attachNote(transformValue.getLoc()) << "for this handle";
3000       return diag;
3001     }
3002   }
3003 
3004   SmallVector<Operation *> tiled;
3005   SmallVector<SmallVector<Operation *, 4>, 4> loops;
3006   loops.resize(getLoops().size());
3007   auto scalableSizes = getScalableSizes();
3008   for (auto [i, op] : llvm::enumerate(targets)) {
3009     auto tilingInterface = dyn_cast<TilingInterface>(op);
3010     if (!tilingInterface) {
3011       DiagnosedSilenceableFailure diag =
3012           emitSilenceableError()
3013           << "only ops implementing TilingInterface are supported";
3014       diag.attachNote(op->getLoc()) << "target op";
3015       return diag;
3016     }
3017     if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
3018       DiagnosedSilenceableFailure diag =
3019           emitSilenceableError()
3020           << "too many tile sizes provided, expected at most "
3021           << tilingInterface.getLoopIteratorTypes().size() << ", found "
3022           << tileSizes.size();
3023       diag.attachNote(op->getLoc()) << "target op";
3024       return diag;
3025     }
3026 
3027     scf::SCFTilingOptions tilingOptions;
3028     if (tileSizes.empty()) {
3029       tilingOptions.setTileSizeComputationFunction(
3030           [](OpBuilder &, Operation *) -> SmallVector<OpFoldResult> {
3031             return {};
3032           });
3033     } else {
3034       tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
3035                                                                   Operation *) {
3036         SmallVector<OpFoldResult> sizes;
3037         sizes.reserve(tileSizes.size());
3038         unsigned dynamicIdx = 0;
3039 
3040         for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) {
3041           if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
3042             if (scalableSizes[ofrIdx]) {
3043               auto val = b.create<arith::ConstantIndexOp>(
3044                   getLoc(), cast<IntegerAttr>(attr).getInt());
3045               Value vscale =
3046                   b.create<vector::VectorScaleOp>(getLoc(), b.getIndexType());
3047               sizes.push_back(
3048                   b.create<arith::MulIOp>(getLoc(), val, vscale).getResult());
3049             } else {
3050               sizes.push_back(attr);
3051             }
3052             continue;
3053           }
3054           ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
3055           ArrayRef<int64_t> params = paramSizes[dynamicIdx];
3056           ++dynamicIdx;
3057           assert((dynamicSizes.empty() ^ params.empty()) &&
3058                  "expected either dynamic sizes or parameters");
3059           if (!params.empty()) {
3060             sizes.push_back(b.getIndexAttr(params[index]));
3061           } else {
3062             sizes.push_back(dynamicSizes[index]->getResult(0));
3063           }
3064         }
3065         return sizes;
3066       });
3067     }
3068 
3069     tilingOptions.setInterchange(getInterchange());
3070     FailureOr<scf::SCFTilingResult> maybeTilingResult =
3071         tileUsingSCF(rewriter, tilingInterface, tilingOptions);
3072     if (failed(maybeTilingResult))
3073       return DiagnosedSilenceableFailure::definiteFailure();
3074 
3075     rewriter.replaceOp(op, maybeTilingResult->mergeResult.replacements);
3076 
3077     tiled.append(maybeTilingResult->tiledOps);
3078     for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
3079       loops[en2.index()].push_back(en2.value());
3080   }
3081 
3082   transformResults.set(cast<OpResult>(getTiledLinalgOp()), tiled);
3083   for (const auto &en : llvm::enumerate(loops))
3084     transformResults.set(cast<OpResult>(getLoops()[en.index()]), en.value());
3085 
3086   return DiagnosedSilenceableFailure::success();
3087 }
3088 
3089 SmallVector<OpFoldResult> transform::TileUsingForOp::getMixedSizes() {
3090   ValueRange dynamic = getDynamicSizes();
3091   ArrayRef<int64_t> tileSizes = getStaticSizes();
3092   SmallVector<OpFoldResult> results;
3093   results.reserve(tileSizes.size());
3094   unsigned dynamicPos = 0;
3095   Builder builder(getContext());
3096   for (int64_t size : tileSizes) {
3097     if (size == ShapedType::kDynamic) {
3098       results.push_back(dynamic[dynamicPos++]);
3099     } else {
3100       results.push_back(builder.getIndexAttr(size));
3101     }
3102   }
3103   return results;
3104 }
3105 
3106 void transform::TileUsingForOp::getEffects(
3107     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3108   consumesHandle(getTargetMutable(), effects);
3109   onlyReadsHandle(getDynamicSizesMutable(), effects);
3110   producesHandle(getOperation()->getOpResults(), effects);
3111   modifiesPayload(effects);
3112 }
3113 
3114 //===----------------------------------------------------------------------===//
3115 // TileUsingForallOp
3116 //===----------------------------------------------------------------------===//
3117 
3118 void transform::TileUsingForallOp::build(OpBuilder &builder,
3119                                          OperationState &result, Value target,
3120                                          ArrayRef<int64_t> staticTileSizes,
3121                                          transform::TileSizesSpec,
3122                                          ArrayAttr mapping) {
3123   return build(builder, result,
3124                /*target=*/target,
3125                /*mixedTileSizes=*/
3126                getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3127                /*_=*/TileSizesSpec(),
3128                /*mapping=*/mapping);
3129 }
3130 
3131 void transform::TileUsingForallOp::build(OpBuilder &builder,
3132                                          OperationState &result, Value target,
3133                                          ArrayRef<OpFoldResult> mixedTileSizes,
3134                                          transform::TileSizesSpec,
3135                                          ArrayAttr mapping) {
3136   SmallVector<int64_t> staticTileSizes;
3137   SmallVector<Value> dynamicTileSizes;
3138   dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
3139   // Call the default builder, which sets up the proper operand segment size
3140   // attributes for multiple variadic operands. In the absence of this,
3141   // horrible bugs ensue.
3142   MLIRContext *ctx = builder.getContext();
3143   auto operationType = transform::AnyOpType::get(ctx);
3144   auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3145   build(builder, result,
3146         /*resultTypes=*/TypeRange{operationType, operationType},
3147         /*target=*/target,
3148         /*num_threads=*/ValueRange{},
3149         /*tile_sizes=*/dynamicTileSizes,
3150         /*packed_num_threads=*/Value(),
3151         /*packed_tile_sizes=*/Value(),
3152         /*static_num_threads=*/builder.getDenseI64ArrayAttr({}),
3153         /*static_tile_sizes=*/staticTileSizesAttr,
3154         /*mapping=*/mapping);
3155 }
3156 
3157 void transform::TileUsingForallOp::build(OpBuilder &builder,
3158                                          OperationState &result, Value target,
3159                                          ArrayRef<int64_t> staticNumThreads,
3160                                          transform::NumThreadsSpec,
3161                                          ArrayAttr mapping) {
3162   return build(builder, result, target,
3163                getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)),
3164                NumThreadsSpec(), mapping);
3165 }
3166 
3167 void transform::TileUsingForallOp::build(OpBuilder &builder,
3168                                          OperationState &result, Value target,
3169                                          ArrayRef<OpFoldResult> mixedNumThreads,
3170                                          transform::NumThreadsSpec,
3171                                          ArrayAttr mapping) {
3172   SmallVector<int64_t> staticNumThreads;
3173   SmallVector<Value> dynamicNumThreads;
3174   dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
3175                              staticNumThreads);
3176   // Call the default builder, which sets up the proper operand segment size
3177   // attributes for multiple variadic operands. In the absence of this,
3178   // horrible bugs ensue.
3179   MLIRContext *ctx = builder.getContext();
3180   auto operationType = transform::AnyOpType::get(ctx);
3181   auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
3182   build(builder, result,
3183         /*resultTypes=*/TypeRange{operationType, operationType},
3184         /*target=*/target,
3185         /*num_threads=*/dynamicNumThreads,
3186         /*tile_sizes=*/ValueRange{},
3187         /*packed_num_threads=*/Value(),
3188         /*packed_tile_sizes=*/Value(),
3189         /*static_num_threads=*/staticNumThreadsAttr,
3190         /*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}),
3191         /*mapping=*/mapping);
3192 }
3193 
3194 /// Given the `lbs`, `ubs` and `steps` of loops, return, for each loop, the
3195 /// normalized upper bound.
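/// For example, for lb = 4, ub = 20 and step = 3 the normalized upper bound
/// is ceildiv(20 - 4, 3) = 6, matching the 6 iterations 4, 7, ..., 19.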
3196 static SmallVector<OpFoldResult>
3197 normalizeUpperBounds(RewriterBase &rewriter, Location loc,
3198                      ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
3199                      ArrayRef<OpFoldResult> steps) {
3200   AffineExpr s0, s1, s2;
3201   bindSymbols(rewriter.getContext(), s0, s1, s2);
3202   AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(s2);
3203   SmallVector<OpFoldResult> normalizedUbs;
3204   for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
3205     OpFoldResult normalizedUb = affine::makeComposedFoldedAffineApply(
3206         rewriter, loc, normalizedUbExpr, {lb, ub, step});
3207     normalizedUbs.push_back(normalizedUb);
3208   }
3209   return normalizedUbs;
3210 }
3211 
3212 /// When a loop is normalized, the uses of the induction variable within the
3213 /// loop need to be replaced with `original_lb + old_iv * original_step`.
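/// For example, with original_lb = 4 and original_step = 3, the normalized
/// IVs 0, 1, 2, ... map back to 4, 7, 10, ...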
3214 static SmallVector<Value> denormalizeIndVar(RewriterBase &rewriter,
3215                                             Location loc, ValueRange ivs,
3216                                             ArrayRef<OpFoldResult> lbs,
3217                                             ArrayRef<OpFoldResult> steps) {
3218   AffineExpr s0, s1;
3219   AffineExpr d0;
3220   bindSymbols(rewriter.getContext(), s0, s1);
3221   bindDims(rewriter.getContext(), d0);
3222   AffineExpr denormExpr = s0 + d0 * s1;
3223   SmallVector<Value> denormalizedIvs;
3224 
3225   for (auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
3226     OpFoldResult denormValue = affine::makeComposedFoldedAffineApply(
3227         rewriter, loc, denormExpr, ArrayRef<OpFoldResult>{iv, lb, step});
3228     denormalizedIvs.push_back(
3229         getValueOrCreateConstantIndexOp(rewriter, loc, denormValue));
3230   }
3231   return denormalizedIvs;
3232 }
3233 
3234 /// Given a `scf.forall` loop return a loop op with the loop bounds
3235 /// normalized.
3236 /// TODO: Replace this with a general utility to normalize `scf.forall`.
3237 /// At the time of writing, this wasn't done since adding it to the `scf`
3238 /// dialect would disallow use of `affine.apply` operations due to cyclic
3239 /// dependencies. To avoid churn in the lit tests from the change this was
3240 /// added with, that is deferred to a follow-up.
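/// As a simplified sketch, a loop of the form
///   scf.forall (%i) = (%lb) to (%ub) step (%s) { ... %i ... }
/// is rewritten into
///   scf.forall (%j) in (ceildiv(%ub - %lb, %s)) { %i = %lb + %j * %s ... }
/// with the denormalized `%i` materialized via `affine.apply`.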
3241 static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter,
3242                                            scf::ForallOp loop) {
3243   SmallVector<OpFoldResult> lbs = loop.getMixedLowerBound();
3244   SmallVector<OpFoldResult> ubs = loop.getMixedUpperBound();
3245   SmallVector<OpFoldResult> steps = loop.getMixedStep();
3246 
3247   if (llvm::all_of(
3248           lbs, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) &&
3249       llvm::all_of(
3250           steps, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) {
3251     return loop;
3252   }
3253 
3254   Location loc = loop.getLoc();
3255   SmallVector<OpFoldResult> normalizedUbs =
3256       normalizeUpperBounds(rewriter, loc, lbs, ubs, steps);
3257   SmallVector<OpFoldResult> normalizedLbs(normalizedUbs.size(),
3258                                           rewriter.getIndexAttr(0));
3259   SmallVector<OpFoldResult> normalizedSteps(normalizedUbs.size(),
3260                                             rewriter.getIndexAttr(1));
3261 
3262   auto normalizedForallOp = rewriter.create<scf::ForallOp>(
3263       loc, normalizedLbs, normalizedUbs, normalizedSteps, loop.getOutputs(),
3264       loop.getMapping(), [](OpBuilder &, Location, ValueRange) {});
3265 
3266   auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
3267   OpBuilder::InsertionGuard g(rewriter);
3268   Block *normalizedLoopBlock = normalizedForallOp.getBody();
3269   rewriter.setInsertionPointToStart(normalizedLoopBlock);
3270 
3271   SmallVector<Value> argValues =
3272       denormalizeIndVar(rewriter, loc, normalizedLoopIvs, lbs, steps);
3273   argValues.append(normalizedForallOp.getRegionIterArgs().begin(),
3274                    normalizedForallOp.getRegionIterArgs().end());
3275   Block *origLoopBlock = loop.getBody();
3276   rewriter.mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);
3277 
3278   rewriter.replaceOp(loop, normalizedForallOp);
3279   return normalizedForallOp;
3280 }
3281 
3282 DiagnosedSilenceableFailure transform::tileToForallOpImpl(
3283     RewriterBase &rewriter, transform::TransformState &state,
3284     TransformOpInterface transformOp, Operation *target,
3285     ArrayRef<OpFoldResult> mixedNumThreads,
3286     ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
3287     scf::SCFTilingResult &tilingResult) {
3288   // Transform all targets one by one.
3289   auto tileableOp = dyn_cast<TilingInterface>(target);
3290   if (!tileableOp) {
3291     DiagnosedSilenceableFailure diag =
3292         transformOp.emitSilenceableError()
3293         << "only TilingInterface ops are supported";
3294     diag.attachNote(target->getLoc()) << "target op";
3295     return diag;
3296   }
3297   rewriter.setInsertionPoint(tileableOp);
3298   scf::SCFTilingOptions options;
3299   options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
3300   if (!mixedNumThreads.empty()) {
3301     options.setNumThreads(mixedNumThreads);
3302   } else {
3303     options.setTileSizes(mixedTileSizes);
3304   }
3305   if (mapping) {
3306     options.setMapping(mapping.value().getValue());
3307   }
3308   FailureOr<scf::SCFTilingResult> maybeTilingResult =
3309       scf::tileUsingSCF(rewriter, tileableOp, options);
3310 
3311   if (failed(maybeTilingResult))
3312     return transformOp.emitDefaultSilenceableFailure(tileableOp);
3313 
3314   rewriter.replaceOp(tileableOp, maybeTilingResult->mergeResult.replacements);
3315 
3316   tilingResult = *maybeTilingResult;
3317 
3318   if (mixedNumThreads.empty()) {
3319     auto generatedForallOp = cast<scf::ForallOp>(tilingResult.loops.front());
3320     OpBuilder::InsertionGuard g(rewriter);
3321     rewriter.setInsertionPoint(generatedForallOp);
3322     scf::ForallOp normalizedForallOp =
3323         normalizeForallLoopOp(rewriter, generatedForallOp);
3324     tilingResult.loops.front() = normalizedForallOp;
3325   }
3326 
3327   return DiagnosedSilenceableFailure::success();
3328 }
3329 
3330 DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
3331     transform::TransformRewriter &rewriter,
3332     transform::TransformResults &transformResults,
3333     transform::TransformState &state) {
3334   auto transformOp = cast<TransformOpInterface>(getOperation());
3335 
3336   // Result payload ops.
3337   SmallVector<Operation *> tileOps;
3338   SmallVector<Operation *> tiledOps;
3339 
3340   // Unpack handles.
3341   SmallVector<OpFoldResult> mixedNumThreads;
3342   DiagnosedSilenceableFailure status =
3343       getPackedNumThreads()
3344           ? unpackSingleIndexResultPayloadOperations(
3345                 state, transformOp, mixedNumThreads, getPackedNumThreads())
3346           : unpackSingleIndexResultPayloadOperations(
3347                 state, transformOp, mixedNumThreads, getMixedNumThreads());
3348   if (!status.succeeded())
3349     return status;
3350   SmallVector<OpFoldResult> mixedTileSizes;
3351   status = getPackedTileSizes()
3352                ? unpackSingleIndexResultPayloadOperations(
3353                      state, transformOp, mixedTileSizes, getPackedTileSizes())
3354                : unpackSingleIndexResultPayloadOperations(
3355                      state, transformOp, mixedTileSizes, getMixedTileSizes());
3356   if (!status.succeeded())
3357     return status;
3358 
3359   for (Operation *target : state.getPayloadOps(getTarget())) {
3360     scf::SCFTilingResult tilingResult;
3361     DiagnosedSilenceableFailure diag = tileToForallOpImpl(
3362         rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
3363         getMapping(), tilingResult);
3364     if (!diag.succeeded())
3365       return diag;
3366     tileOps.push_back(tilingResult.loops.front());
3367     tiledOps.append(tilingResult.tiledOps);
3368   }
3369 
3370   transformResults.set(cast<OpResult>(getForallOp()), tileOps);
3371   transformResults.set(cast<OpResult>(getTiledOp()), tiledOps);
3372 
3373   return DiagnosedSilenceableFailure::success();
3374 }
3375 
3376 void transform::TileUsingForallOp::getEffects(
3377     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3378   consumesHandle(getTargetMutable(), effects);
3379   onlyReadsHandle(getTileSizesMutable(), effects);
3380   onlyReadsHandle(getNumThreadsMutable(), effects);
3381   onlyReadsHandle(getPackedNumThreadsMutable(), effects);
3382   onlyReadsHandle(getPackedTileSizesMutable(), effects);
3383   producesHandle(getOperation()->getOpResults(), effects);
3384   modifiesPayload(effects);
3385 }
3386 
3387 SmallVector<OpFoldResult> TileUsingForallOp::getMixedNumThreads() {
3388   Builder b(getContext());
3389   return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
3390 }
3391 
3392 SmallVector<OpFoldResult> TileUsingForallOp::getMixedTileSizes() {
3393   Builder b(getContext());
3394   return getMixedValues(getStaticTileSizes(), getTileSizes(), b);
3395 }
3396 
3397 LogicalResult TileUsingForallOp::verify() {
3398   int numThreadsSpec = static_cast<int>(!getMixedNumThreads().empty()) +
3399                        static_cast<int>(getPackedNumThreads() != Value());
3400   if (numThreadsSpec > 1)
3401     return emitOpError(
3402         "num_threads and packed_num_threads are mutually exclusive");
3403   int tileSizesSpec = static_cast<int>(!getMixedTileSizes().empty()) +
3404                       static_cast<int>(getPackedTileSizes() != Value());
3405   if (tileSizesSpec > 1)
3406     return emitOpError(
3407         "tile_sizes and packed_tile_sizes are mutually exclusive");
3408   if (numThreadsSpec == 0 && tileSizesSpec == 0)
3409     return emitOpError("either (packed_)num_threads or (packed_)tile_sizes "
3410                        "must be specified");
3411   return success();
3412 }
3413 
3414 //===----------------------------------------------------------------------===//
3415 // VectorizeChildrenAndApplyPatternsOp
3416 //===----------------------------------------------------------------------===//
3417 
3418 void transform::VectorizeChildrenAndApplyPatternsOp::build(
3419     OpBuilder &builder, OperationState &result, Value target,
3420     bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
3421   result.addOperands(target);
3422   if (vectorizePadding) {
3423     result.addAttribute(
3424         VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
3425             result.name),
3426         builder.getUnitAttr());
3427   }
3428   if (vectorizeExtract) {
3429     result.addAttribute(
3430         VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
3431             result.name),
3432         builder.getUnitAttr());
3433   }
3434   if (flatten1DDepthwiseConv) {
3435     result.addAttribute(
3436         VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
3437             result.name),
3438         builder.getUnitAttr());
3439   }
3440   result.addTypes(transform::AnyOpType::get(builder.getContext()));
3441 }
3442 
3443 namespace {
3444 /// This is a helper used only to call vectorize via a pattern inside of
3445 /// VectorizeChildrenAndApplyPatternsOp::applyToOne.
3446 struct VectorizationPattern : public RewritePattern {
3447   explicit VectorizationPattern(MLIRContext *context,
3448                                 bool vectorizeExtract = false,
3449                                 bool flattenConv = false)
3450       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
3451         vectorizeNDExtract(vectorizeExtract),
3452         flatten1DDepthwiseConv(flattenConv) {}
3453   LogicalResult matchAndRewrite(Operation *op,
3454                                 PatternRewriter &rewriter) const override {
3455     if (!linalg::hasVectorizationImpl(op))
3456       return rewriter.notifyMatchFailure(op,
3457                                          "unsupported op, cannot vectorize");
3458     return vectorize(rewriter, op, /*inputVectorSizes=*/{},
3459                      /*inputScalableVecDims=*/{}, vectorizeNDExtract,
3460                      flatten1DDepthwiseConv);
3461   }
3462 
3463 private:
3464   /// Controls whether to vectorize `tensor.extract` when the input tensor is
3465   /// rank >= 2.
3466   bool vectorizeNDExtract = false;
3467   /// Controls whether to "flatten" the channel dimension when vectorizing 1D
3468   /// depthwise convolutions. This should lead to better vectorization for
3469   /// tensors with a low number of channel dimensions.
3470   bool flatten1DDepthwiseConv = false;
3471 };
3472 } // namespace
3473 
3474 DiagnosedSilenceableFailure
3475 transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
3476     transform::TransformRewriter &rewriter, Operation *target,
3477     transform::ApplyToEachResultList &results,
3478     transform::TransformState &state) {
3479   if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
3480     auto diag = this->emitOpError("requires isolated-from-above targets");
3481     diag.attachNote(target->getLoc()) << "non-isolated target";
3482     return DiagnosedSilenceableFailure::definiteFailure();
3483   }
3484 
3485   MLIRContext *ctx = getContext();
3486   RewritePatternSet patterns(ctx);
3487   patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
3488                                      getFlatten_1dDepthwiseConv());
3489 
3490   if (!getDisableTransferPermutationMapLoweringPatterns())
3491     vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
3492 
3493   if (!getDisableMultiReductionToContractPatterns())
3494     vector::populateVectorReductionToContractPatterns(patterns);
3495 
3496   vector::populateSinkVectorOpsPatterns(patterns);
3497 
3498   patterns.add<linalg::LinalgCopyVTRForwardingPattern,
3499                linalg::LinalgCopyVTWForwardingPattern>(ctx,
3500                                                        /*benefit=*/2);
3501   vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
3502   vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
3503   tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
3504 
3505   patterns.add<CopyVectorizationPattern>(ctx);
3506 
3507   // Add misc. vectorization patterns (e.g. for tensor.insert_slice).
3508   linalg::populateInsertSliceVectorizationPatterns(patterns);
3509 
3510   if (getVectorizePadding()) {
3511     linalg::populatePadOpVectorizationPatterns(patterns);
3512     // This creates an alternative path for lowering tensor.pad - by
3513     // decomposing it into e.g. linalg.fill.
3514     linalg::populateDecomposePadPatterns(patterns);
3515   }
3516   vector::populateVectorStepLoweringPatterns(patterns);
3517 
3518   TrackingListener listener(state, *this);
3519   GreedyRewriteConfig config;
3520   config.listener = &listener;
3521   if (failed(applyPatternsGreedily(target, std::move(patterns), config)))
3522     return emitDefaultDefiniteFailure(target);
3523 
3524   results.push_back(target);
3525   return DiagnosedSilenceableFailure::success();
3526 }
3527 
3528 //===----------------------------------------------------------------------===//
3529 // VectorizeOp
3530 //===----------------------------------------------------------------------===//
3531 
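// Hedged usage sketch (assembly spelling assumed from the accessors below,
// not taken from the ODS definition): vectorizing with one fixed and one
// scalable vector size could look like
//
//   transform.structured.vectorize %0 vector_sizes [8, [16]]
//       : !transform.any_op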
3532 DiagnosedSilenceableFailure transform::VectorizeOp::apply(
3533     transform::TransformRewriter &rewriter,
3534     mlir::transform::TransformResults &transformResults,
3535     mlir::transform::TransformState &state) {
3536   auto targets = state.getPayloadOps(getTarget());
3537   if (std::empty(targets))
3538     return DiagnosedSilenceableFailure::success();
3539   auto transformOp = cast<TransformOpInterface>(getOperation());
3540   SmallVector<int64_t> vectorSizes;
3541   DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
3542       state, transformOp, getMixedVectorSizes(), vectorSizes);
3543   if (!status.succeeded())
3544     return status;
3545 
3546   // TODO: Check that the correct number of vectorSizes was provided.
3547   for (Operation *target : targets) {
3548     if (!linalg::hasVectorizationImpl(target)) {
3549       return mlir::emitSilenceableFailure(target->getLoc())
3550              << "unsupported op, cannot vectorize";
3551     }
3552 
3553     if (failed(linalg::vectorize(rewriter, target, vectorSizes,
3554                                  getScalableSizes(),
3555                                  getVectorizeNdExtract().value_or(false)))) {
3556       return mlir::emitSilenceableFailure(target->getLoc())
3557              << "attempted to vectorize, but failed";
3558     }
3559   }
3560 
3561   return DiagnosedSilenceableFailure::success();
3562 }
3563 
3564 void transform::VectorizeOp::getEffects(
3565     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3566   consumesHandle(getTargetMutable(), effects);
3567   onlyReadsHandle(getVectorSizesMutable(), effects);
3568   modifiesPayload(effects);
3569 }
3570 
3571 SmallVector<OpFoldResult> VectorizeOp::getMixedVectorSizes() {
3572   OpBuilder b(getContext());
3573   return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
3574 }
3575 
3576 LogicalResult transform::VectorizeOp::verify() {
3577   if (getStaticVectorSizes().size() != getScalableSizes().size())
3578     return emitOpError("expected same number of vector sizes (")
3579            << getStaticVectorSizes().size() << ") and scalable sizes ("
3580            << getScalableSizes().size() << ")";
3581   return success();
3582 }
3583 
3584 //===----------------------------------------------------------------------===//
3585 // HoistRedundantVectorTransfersOp
3586 //===----------------------------------------------------------------------===//
3587 
3588 DiagnosedSilenceableFailure
3589 transform::HoistRedundantVectorTransfersOp::applyToOne(
3590     transform::TransformRewriter &rewriter, func::FuncOp target,
3591     transform::ApplyToEachResultList &results,
3592     transform::TransformState &state) {
3593   // WARNING: This hoisting does not model parallelism and is generally
3594   // incorrect when used on distributed loops with memref semantics!
3595   // TODO: obsolete and should be retired.
3596   linalg::hoistRedundantVectorTransfers(target, getVerifyNonZeroTrip());
3597   results.push_back(target);
3598   return DiagnosedSilenceableFailure::success();
3599 }
3600 
3601 //===----------------------------------------------------------------------===//
3602 // HoistRedundantVectorBroadcastsOp
3603 //===----------------------------------------------------------------------===//
3604 
3605 DiagnosedSilenceableFailure
3606 transform::HoistRedundantVectorBroadcastsOp::applyToOne(
3607     transform::TransformRewriter &rewriter, mlir::Operation *target,
3608     transform::ApplyToEachResultList &results,
3609     transform::TransformState &state) {
3610   rewriter.setInsertionPoint(target);
3611   linalg::hoistRedundantVectorBroadcasts(rewriter, target);
3612   results.push_back(target);
3613   return DiagnosedSilenceableFailure::success();
3614 }
3615 
3616 //===----------------------------------------------------------------------===//
3617 // ConvertConv2DToImg2ColOp.
3618 //===----------------------------------------------------------------------===//
3619 
3620 DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
3621     transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3622     transform::ApplyToEachResultList &results,
3623     transform::TransformState &state) {
3624   rewriter.setInsertionPoint(target);
3625   auto maybeTransformed =
3626       TypeSwitch<Operation *, FailureOr<std::pair<Operation *, Operation *>>>(
3627           target)
3628           .Case([&](linalg::Conv2DNhwcHwcfOp op) {
3629             return rewriteInIm2Col(rewriter, op);
3630           })
3631           .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3632             return rewriteInIm2Col(rewriter, op);
3633           })
3634           .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
3635             return rewriteInIm2Col(rewriter, op);
3636           })
3637           .Case([&](linalg::Conv2DNchwFchwOp op) {
3638             return rewriteInIm2Col(rewriter, op);
3639           })
3640           .Default([&](Operation *op) {
3641             return rewriter.notifyMatchFailure(op, "not supported");
3642           });
3643   if (failed(maybeTransformed))
3644     return emitDefaultSilenceableFailure(target);
3645   // Handle to the operation producing the img2col tensor.
3646   results.push_back(maybeTransformed->first);
3647   // Handle to the operation that replaces the original convolution.
3648   results.push_back(maybeTransformed->second);
3649   return DiagnosedSilenceableFailure::success();
3650 }
3651 
3652 //===----------------------------------------------------------------------===//
3653 // FlattenElementwiseLinalgOp.
3654 //===----------------------------------------------------------------------===//
3655 
3656 DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
3657     transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3658     transform::ApplyToEachResultList &results,
3659     transform::TransformState &state) {
3660   rewriter.setInsertionPoint(target);
3661   if (!isElementwise(target))
3662     return mlir::emitSilenceableFailure(target->getLoc())
3663            << "only elementwise flattening is supported";
3664 
3665   // If rank <= 1, do nothing
3666   if (target.getNumLoops() <= 1) {
3667     results.push_back(target);
3668     return DiagnosedSilenceableFailure::success();
3669   }
3670 
3671   // Attempt to flatten all dims to one.
3672   ReassociationIndices reassociation(target.getNumLoops());
3673   std::iota(reassociation.begin(), reassociation.end(), 0);
3674   auto maybeFlattened =
3675       collapseOpIterationDims(target, reassociation, rewriter);
3676   if (failed(maybeFlattened))
3677     return mlir::emitSilenceableFailure(target->getLoc())
3678            << "attempted to flatten, but failed";
3679   results.push_back(maybeFlattened->collapsedOp);
3680   rewriter.replaceOp(target, maybeFlattened->results);
3681   return DiagnosedSilenceableFailure::success();
3682 }
3683 
3684 //===----------------------------------------------------------------------===//
3685 // TransposeConv2DOp
3686 //===----------------------------------------------------------------------===//
3687 
3688 DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
3689     transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3690     transform::ApplyToEachResultList &results,
3691     transform::TransformState &state) {
3692   rewriter.setInsertionPoint(target);
3693   auto maybeTransformed =
3694       TypeSwitch<Operation *, FailureOr<Operation *>>(target)
3695           .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3696             return transposeConv2D(rewriter, op);
3697           })
3698           .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
3699             return transposeConv2D(rewriter, op);
3700           })
3701           .Default([&](Operation *op) {
3702             return rewriter.notifyMatchFailure(op, "not supported");
3703           });
3704   if (failed(maybeTransformed))
3705     return emitDefaultSilenceableFailure(target);
3706   // Handle to the new Conv2D operation with transposed filters.
3707   results.push_back(*maybeTransformed);
3708   return DiagnosedSilenceableFailure::success();
3709 }
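
// For example (illustrative): transposeConv2D materializes a linalg.transpose
// of the filter so that a linalg.conv_2d_nhwc_fhwc (filter in FHWC layout)
// becomes a linalg.conv_2d_nhwc_hwcf consuming the transposed filter; the
// quantized variant is handled analogously.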
3710 
3711 //===----------------------------------------------------------------------===//
3712 // TransposeMatmulOp
3713 //===----------------------------------------------------------------------===//
3714 
3715 DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
3716     transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3717     transform::ApplyToEachResultList &results,
3718     transform::TransformState &state) {
3719   rewriter.setInsertionPoint(target);
3720   bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
3721   auto maybeTransformed =
3722       TypeSwitch<Operation *, FailureOr<Operation *>>(target)
3723           .Case([&](linalg::MatmulOp op) {
3724             return transposeMatmul(rewriter, op, transposeLHS);
3725           })
3726           .Case([&](linalg::BatchMatmulOp op) {
3727             return transposeBatchMatmul(rewriter, op, transposeLHS);
3728           })
3729           .Default([&](Operation *op) { return failure(); });
3730   if (failed(maybeTransformed))
3731     return emitSilenceableFailure(target->getLoc()) << "not supported";
3732   // Handle to the new Matmul operation with the transposed operand.
3733   results.push_back(*maybeTransformed);
3734   return DiagnosedSilenceableFailure::success();
3735 }
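
// For example (illustrative): with lhs selected, C = matmul(A, B) is
// rewritten into a transpose-A matmul variant consuming a materialized
// linalg.transpose of A (and symmetrically for rhs), which can expose a more
// favorable memory access pattern on some targets.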
3736 
3737 //===----------------------------------------------------------------------===//
3738 // InsertSliceToCopyOp
3739 //===----------------------------------------------------------------------===//
3740 template <typename OpTy>
3741 DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target,
3742                                  transform::ApplyToEachResultList &results,
3743                                  transform::TransformState &state) {
3744   static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
3745                                 tensor::ParallelInsertSliceOp>() &&
3746                 "wrong op type");
3747 
3748   if (auto copySource =
3749           target.getSource().template getDefiningOp<linalg::CopyOp>()) {
3750     results.push_back(copySource);
3751     return DiagnosedSilenceableFailure::success();
3752   }
3753 
3754   // If we are inside an InParallel region, temporarily set the insertion point
3755   // outside: only tensor.parallel_insert_slice ops are allowed in there.
3756   if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
3757     rewriter.setInsertionPoint(
3758         target->template getParentOfType<scf::InParallelOp>());
3759   }
3760 
3761   Value extracted = rewriter.create<tensor::ExtractSliceOp>(
3762       target.getLoc(), target.getDest(), target.getMixedOffsets(),
3763       target.getMixedSizes(), target.getMixedStrides());
3764   Value copied = rewriter
3765                      .create<linalg::CopyOp>(target.getLoc(),
3766                                              target.getSource(), extracted)
3767                      .getResult(0);
3768   // Reset the insertion point.
3769   rewriter.setInsertionPoint(target);
3770   rewriter.replaceOpWithNewOp<OpTy>(
3771       target, copied, target.getDest(), target.getMixedOffsets(),
3772       target.getMixedSizes(), target.getMixedStrides());
3773 
3774   results.push_back(copied.getDefiningOp());
3775   return DiagnosedSilenceableFailure::success();
3776 }
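
// An illustrative sketch of the rewrite above, with types elided:
//
//   %r = tensor.insert_slice %src into %dest[%o] [%sz] [1] : ...
//
// becomes
//
//   %e = tensor.extract_slice %dest[%o] [%sz] [1] : ...
//   %c = linalg.copy ins(%src : ...) outs(%e : ...) -> ...
//   %r = tensor.insert_slice %c into %dest[%o] [%sz] [1] : ...
//
// making the copy an explicit linalg op that later transformations (e.g.
// mapping to threads) can target.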
3777 
3778 DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
3779     transform::TransformRewriter &rewriter, Operation *targetOp,
3780     transform::ApplyToEachResultList &results,
3781     transform::TransformState &state) {
3782 
3783   rewriter.setInsertionPoint(targetOp);
3784   if (auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
3785     return doit(rewriter, target, results, state);
3786   if (auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
3787     return doit(rewriter, target, results, state);
3788 
3789   DiagnosedSilenceableFailure diag =
3790       emitSilenceableError()
3791       << "only InsertSliceOp and ParallelInsertSliceOp ops are supported";
3792   diag.attachNote(targetOp->getLoc()) << "target op";
3793   return diag;
3794 }
3795 
3796 //===----------------------------------------------------------------------===//
3797 // MapCopyToThreadsOp
3798 //===----------------------------------------------------------------------===//
3799 
3800 DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
3801     transform::TransformRewriter &rewriter, Operation *target,
3802     transform::ApplyToEachResultList &results,
3803     transform::TransformState &state) {
3804   // Check if the op is supported.
3805   if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
3806     DiagnosedSilenceableFailure diag =
3807         emitSilenceableError()
3808         << "only linalg.copy and tensor.pad target ops are supported";
3809     diag.attachNote(target->getLoc()) << "target op";
3810     return diag;
3811   }
3812   assert(target->getNumResults() == 1 && "expected single result");
3813   auto resultShapedType = cast<ShapedType>(target->getResult(0).getType());
3814   if (!resultShapedType.hasStaticShape()) {
3815     DiagnosedSilenceableFailure diag =
3816         emitSilenceableError()
3817         << "only statically sized ops of rank <= 3 are supported";
3818     diag.attachNote(target->getLoc()) << "target op";
3819     return diag;
3820   }
3821 
3822   // Conservatively set the minimum viable desired bitwidth alignment.
3823   int64_t desiredBitAlignment = getDesiredBitAlignment();
3824   int64_t eltBitwidth =
3825       resultShapedType.getElementType().getIntOrFloatBitWidth();
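  // For example (illustrative): a desired 32-bit alignment with f64 elements
  // is not evenly divisible by the 64-bit element width, so the alignment
  // falls back to the element width itself.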
3826   if (desiredBitAlignment % eltBitwidth != 0) {
3827     desiredBitAlignment = eltBitwidth;
3828   }
3829 
3830   gpu::CopyMappingInfo mapping(
3831       /*ctx=*/getContext(),
3832       /*totalNumThreads=*/getTotalNumThreads(),
3833       /*alignment=*/desiredBitAlignment,
3834       /*sizes=*/resultShapedType.getShape(),
3835       /*favorPredication=*/false,
3836       /*elementalBitwidth=*/
3837       resultShapedType.getElementType().getIntOrFloatBitWidth());
3838   if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
3839     DiagnosedSilenceableFailure diag =
3840         emitSilenceableError()
3841         << "too few threads to map copy op to threads on the most minor "
3842            "dimension, given alignment and vector size constraints; try a "
3843            "smaller tile size or map to more threads";
3844     diag.attachNote(target->getLoc()) << "target op";
3845     return diag;
3846   }
3847 
3848   // The OpBuilder is only used to compute attributes.
3849   OpBuilder b(getContext());
3850   scf::SCFTilingResult tilingResult;
3851   DiagnosedSilenceableFailure diag = tileToForallOpImpl(
3852       /*rewriter=*/rewriter,
3853       /*state=*/state,
3854       /*transformOp=*/*this,
3855       /*target=*/target,
3856       /*mixedNumThreads=*/getMixedValues(mapping.numThreads, {}, b),
3857       /*mixedTileSizes=*/ArrayRef<OpFoldResult>{},
3858       /*mapping=*/b.getArrayAttr(mapping.threadMapping),
3859       /*tilingResult=*/tilingResult);
3860   if (!diag.succeeded())
3861     return diag;
3862 
3863   results.push_back(tilingResult.loops.front());
3864   for (auto op : tilingResult.tiledOps)
3865     results.push_back(op);
3866   return DiagnosedSilenceableFailure::success();
3867 }
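
// Example usage (an illustrative sketch; the mnemonic and attribute names are
// assumed from this op's tablegen definition): map a statically shaped copy
// onto 128 threads, preferring 128-bit aligned vector accesses.
//
//   %forall, %tiled = transform.structured.gpu.map_copy_to_threads %copy
//       total_num_threads = 128 desired_bit_alignment = 128
//       : (!transform.any_op) -> (!transform.any_op, !transform.any_op)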
3868 
3869 //===----------------------------------------------------------------------===//
3870 // WinogradConv2DOp
3871 //===----------------------------------------------------------------------===//
3872 
3873 DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
3874     transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3875     transform::ApplyToEachResultList &results,
3876     transform::TransformState &state) {
3877   rewriter.setInsertionPoint(target);
3878   FailureOr<Operation *> maybeTransformed = failure();
3879   bool supported = TypeSwitch<Operation *, bool>(target)
3880                        .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3881                          maybeTransformed =
3882                              winogradConv2D(rewriter, op, getM(), getR());
3883                          return true;
3884                        })
3885                        .Default([&](Operation *op) { return false; });
3886 
3887   if (!supported) {
3888     return emitSilenceableError()
3889            << "this operation cannot be converted to Winograd Conv2D";
3890   }
3891 
3892   if (failed(maybeTransformed)) {
3893     return emitSilenceableError() << "failed to apply Winograd Conv2D";
3894   }
3895 
3896   results.push_back(*maybeTransformed);
3897   return DiagnosedSilenceableFailure::success();
3898 }
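
// For example (illustrative): m and r select the Winograd algorithm
// F(m x m, r x r); with m = 4 and r = 3, each 4x4 output tile of a 3x3
// convolution is computed through the Winograd filter/input/output
// transforms, trading multiplications for additions.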
3899 
3900 DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
3901     transform::TransformRewriter &rewriter, Operation *target,
3902     transform::ApplyToEachResultList &results,
3903     transform::TransformState &state) {
3904   rewriter.setInsertionPoint(target);
3905   FailureOr<Operation *> maybeTransformed = failure();
3906   bool supported =
3907       TypeSwitch<Operation *, bool>(target)
3908           .Case([&](linalg::WinogradFilterTransformOp op) {
3909             maybeTransformed = decomposeWinogradFilterTransformOp(rewriter, op);
3910             return true;
3911           })
3912           .Case([&](linalg::WinogradInputTransformOp op) {
3913             maybeTransformed = decomposeWinogradInputTransformOp(rewriter, op);
3914             return true;
3915           })
3916           .Case([&](linalg::WinogradOutputTransformOp op) {
3917             maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, op);
3918             return true;
3919           })
3920           .Default([&](Operation *op) { return false; });
3921 
3922   if (!supported) {
3923     DiagnosedSilenceableFailure diag =
3924         emitSilenceableError()
3925         << "decomposition of this operation is not supported";
3926     diag.attachNote(target->getLoc()) << "target op";
3927     return diag;
3928   }
3929 
3930   if (failed(maybeTransformed)) {
3931     DiagnosedSilenceableFailure diag =
3932         emitSilenceableError() << "failed to decompose the Winograd op";
3933     diag.attachNote(target->getLoc()) << "target op";
3934     return diag;
3935   }
3936 
3937   results.push_back(*maybeTransformed);
3938   return DiagnosedSilenceableFailure::success();
3939 }
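
// For example (illustrative): decomposition lowers each Winograd transform op
// into explicit loops over tiles whose bodies multiply by the corresponding
// constant transform matrices, leaving only ordinary linalg and loop ops
// behind.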
3940 
3941 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
3942 
3943 #define GET_OP_CLASSES
3944 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
3945