xref: /llvm-project/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp (revision 4b56345895729fda3bc3c094bc3f237ba3a49686)
1aa2a96a2SMaheshRavishankar //===- TestTilingInterfaceTransformOps.cpp - Test `TilingInterface` ------===//
2aa2a96a2SMaheshRavishankar //
3aa2a96a2SMaheshRavishankar // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4aa2a96a2SMaheshRavishankar // See https://llvm.org/LICENSE.txt for license information.
5aa2a96a2SMaheshRavishankar // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6aa2a96a2SMaheshRavishankar //
7aa2a96a2SMaheshRavishankar //===----------------------------------------------------------------------===//
8aa2a96a2SMaheshRavishankar //
9aa2a96a2SMaheshRavishankar // This file defines transform dialect operations used for testing
10aa2a96a2SMaheshRavishankar // TilingInterface
11aa2a96a2SMaheshRavishankar //
12aa2a96a2SMaheshRavishankar //===----------------------------------------------------------------------===//
13aa2a96a2SMaheshRavishankar 
14aa2a96a2SMaheshRavishankar #include "mlir/Dialect/Affine/IR/AffineOps.h"
15aa2a96a2SMaheshRavishankar #include "mlir/Dialect/Index/IR/IndexDialect.h"
16aa2a96a2SMaheshRavishankar #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
17aa2a96a2SMaheshRavishankar #include "mlir/Dialect/Transform/IR/TransformAttrs.h"
18aa2a96a2SMaheshRavishankar #include "mlir/Dialect/Transform/IR/TransformDialect.h"
195a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
20aa2a96a2SMaheshRavishankar #include "mlir/Dialect/Utils/StaticValueUtils.h"
21aa2a96a2SMaheshRavishankar #include "mlir/IR/Dominance.h"
22aa2a96a2SMaheshRavishankar #include "mlir/IR/OpImplementation.h"
23aa2a96a2SMaheshRavishankar #include "mlir/Interfaces/TilingInterface.h"
24aa2a96a2SMaheshRavishankar 
25aa2a96a2SMaheshRavishankar #define GET_OP_CLASSES
26aa2a96a2SMaheshRavishankar #include "TestTilingInterfaceTransformOps.h.inc"
27aa2a96a2SMaheshRavishankar 
28aa2a96a2SMaheshRavishankar using namespace mlir;
29aa2a96a2SMaheshRavishankar using namespace mlir::transform;
30aa2a96a2SMaheshRavishankar 
31aa2a96a2SMaheshRavishankar //===----------------------------------------------------------------------===//
32aa2a96a2SMaheshRavishankar // TestFuseAndYieldOp
33aa2a96a2SMaheshRavishankar //===----------------------------------------------------------------------===//
34aa2a96a2SMaheshRavishankar 
35aa2a96a2SMaheshRavishankar static llvm::SmallDenseSet<Operation *> collectTiledAndFusedOps(Operation *op) {
36aa2a96a2SMaheshRavishankar   SmallVector<Operation *> worklist;
37aa2a96a2SMaheshRavishankar   llvm::SmallDenseSet<Operation *> producers;
38aa2a96a2SMaheshRavishankar   worklist.push_back(op);
39aa2a96a2SMaheshRavishankar   producers.insert(op);
40aa2a96a2SMaheshRavishankar   while (!worklist.empty()) {
41aa2a96a2SMaheshRavishankar     Operation *current = worklist.pop_back_val();
42aa2a96a2SMaheshRavishankar     for (OpOperand &operand : current->getOpOperands()) {
43aa2a96a2SMaheshRavishankar       Operation *producer = operand.get().getDefiningOp();
44aa2a96a2SMaheshRavishankar       if (!producer || !isa<TilingInterface>(producer) ||
45aa2a96a2SMaheshRavishankar           producers.contains(producer))
46aa2a96a2SMaheshRavishankar         continue;
47aa2a96a2SMaheshRavishankar       worklist.push_back(producer);
48aa2a96a2SMaheshRavishankar       producers.insert(producer);
49aa2a96a2SMaheshRavishankar     }
50aa2a96a2SMaheshRavishankar   }
51aa2a96a2SMaheshRavishankar   return producers;
52aa2a96a2SMaheshRavishankar }
53aa2a96a2SMaheshRavishankar 
54aa2a96a2SMaheshRavishankar /// Apply a tile and fuse transformation to all payload ops and store both the
55aa2a96a2SMaheshRavishankar /// tiled operation as well as the created tile loops.
56aa2a96a2SMaheshRavishankar template <typename Range>
57aa2a96a2SMaheshRavishankar static LogicalResult
58aa2a96a2SMaheshRavishankar applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
59aa2a96a2SMaheshRavishankar                       Range &&payloadOps, unsigned numLoops,
60aa2a96a2SMaheshRavishankar                       ArrayRef<OpFoldResult> tileSizes,
6176ead96cSMaheshRavishankar                       ArrayRef<int64_t> interchange, bool useForall,
6276ead96cSMaheshRavishankar                       TransformResults &transformResults) {
63aa2a96a2SMaheshRavishankar   SmallVector<Operation *> tiledOps;
64aa2a96a2SMaheshRavishankar   SmallVector<SmallVector<Operation *>> loopOps(numLoops);
65aa2a96a2SMaheshRavishankar 
66aa2a96a2SMaheshRavishankar   for (Operation *target : payloadOps) {
67aa2a96a2SMaheshRavishankar     auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
68aa2a96a2SMaheshRavishankar     if (!tilingInterfaceOp)
69aa2a96a2SMaheshRavishankar       return transformOp->emitError("only TilingInterface ops are supported");
70aa2a96a2SMaheshRavishankar     DominanceInfo dominanceInfo(tilingInterfaceOp);
71aa2a96a2SMaheshRavishankar 
72aa2a96a2SMaheshRavishankar     llvm::SmallDenseSet<Operation *> tiledAndFusedOps =
73aa2a96a2SMaheshRavishankar         collectTiledAndFusedOps(tilingInterfaceOp);
74aa2a96a2SMaheshRavishankar     llvm::DenseSet<Operation *> yieldReplacementsFor;
75aa2a96a2SMaheshRavishankar     for (auto op : tiledAndFusedOps) {
76aa2a96a2SMaheshRavishankar       if (llvm::any_of(op->getUsers(), [&](Operation *user) {
77aa2a96a2SMaheshRavishankar             return dominanceInfo.properlyDominates(tilingInterfaceOp, user);
78aa2a96a2SMaheshRavishankar           })) {
79aa2a96a2SMaheshRavishankar         yieldReplacementsFor.insert(op);
80aa2a96a2SMaheshRavishankar       }
81aa2a96a2SMaheshRavishankar     }
82aa2a96a2SMaheshRavishankar 
83aa2a96a2SMaheshRavishankar     scf::SCFTilingOptions tilingOptions;
84aa2a96a2SMaheshRavishankar     tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
8576ead96cSMaheshRavishankar     if (useForall) {
8676ead96cSMaheshRavishankar       tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
8776ead96cSMaheshRavishankar     }
88aa2a96a2SMaheshRavishankar 
89aa2a96a2SMaheshRavishankar     scf::SCFTileAndFuseOptions tileAndFuseOptions;
90aa2a96a2SMaheshRavishankar     tileAndFuseOptions.setTilingOptions(tilingOptions);
91aa2a96a2SMaheshRavishankar 
92aa2a96a2SMaheshRavishankar     scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
93aa2a96a2SMaheshRavishankar         [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
94d5f0969cSMaheshRavishankar             bool isDestinationOperand)
95d5f0969cSMaheshRavishankar         -> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> {
96aa2a96a2SMaheshRavishankar       Operation *owner = originalProducer.getOwner();
97aa2a96a2SMaheshRavishankar       bool yieldProducerReplacement = yieldReplacementsFor.contains(owner);
98d5f0969cSMaheshRavishankar       return scf::SCFTileAndFuseOptions::ControlFnResult{
99d5f0969cSMaheshRavishankar           yieldProducerReplacement};
100aa2a96a2SMaheshRavishankar     };
101aa2a96a2SMaheshRavishankar     tileAndFuseOptions.setFusionControlFn(controlFn);
102aa2a96a2SMaheshRavishankar 
103aa2a96a2SMaheshRavishankar     rewriter.setInsertionPoint(target);
104aa2a96a2SMaheshRavishankar     FailureOr<scf::SCFTileAndFuseResult> tiledResults =
10576ead96cSMaheshRavishankar         scf::tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
10676ead96cSMaheshRavishankar                                                   tileAndFuseOptions);
107aa2a96a2SMaheshRavishankar     if (failed(tiledResults))
108aa2a96a2SMaheshRavishankar       return failure();
109aa2a96a2SMaheshRavishankar 
110aa2a96a2SMaheshRavishankar     // Perform the replacement of tiled and fused values.
111aa2a96a2SMaheshRavishankar     SmallVector<Operation *> opsToReplace{target};
112aa2a96a2SMaheshRavishankar     llvm::append_range(opsToReplace, tiledResults->fusedProducers);
113aa2a96a2SMaheshRavishankar     for (Operation *toReplace : opsToReplace) {
114aa2a96a2SMaheshRavishankar       for (OpResult res : toReplace->getResults())
115aa2a96a2SMaheshRavishankar         if (auto replacement = tiledResults->replacements.lookup(res)) {
116aa2a96a2SMaheshRavishankar           Operation *replacementOp = replacement.getDefiningOp();
11776ead96cSMaheshRavishankar           rewriter.replaceUsesWithIf(res, replacement, [&](OpOperand &use) {
118aa2a96a2SMaheshRavishankar             Operation *user = use.getOwner();
119aa2a96a2SMaheshRavishankar             return dominanceInfo.properlyDominates(replacementOp, user) &&
120aa2a96a2SMaheshRavishankar                    user->getParentOp() == replacementOp->getParentOp();
121aa2a96a2SMaheshRavishankar           });
122aa2a96a2SMaheshRavishankar         }
123aa2a96a2SMaheshRavishankar 
124aa2a96a2SMaheshRavishankar       if (toReplace->use_empty()) {
125aa2a96a2SMaheshRavishankar         rewriter.eraseOp(toReplace);
126aa2a96a2SMaheshRavishankar       }
127aa2a96a2SMaheshRavishankar     }
128aa2a96a2SMaheshRavishankar 
129aa2a96a2SMaheshRavishankar     // Report back the relevant handles to the transform op.
130aa2a96a2SMaheshRavishankar     tiledOps.push_back(tiledResults->tiledAndFusedOps.front());
131aa2a96a2SMaheshRavishankar     assert(tiledResults->loops.size() == numLoops &&
132aa2a96a2SMaheshRavishankar            "Mismatched number of loops, tile and fuse transform should have "
133aa2a96a2SMaheshRavishankar            "failed");
134aa2a96a2SMaheshRavishankar     for (unsigned int i = 0; i < numLoops; ++i)
135aa2a96a2SMaheshRavishankar       loopOps[i].push_back(tiledResults->loops[i]);
136aa2a96a2SMaheshRavishankar   }
137aa2a96a2SMaheshRavishankar 
138aa2a96a2SMaheshRavishankar   transformResults.set(transformOp->getOpResult(0), tiledOps);
139aa2a96a2SMaheshRavishankar   for (unsigned int i = 0; i < numLoops; ++i)
140aa2a96a2SMaheshRavishankar     transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
141aa2a96a2SMaheshRavishankar 
142aa2a96a2SMaheshRavishankar   return success();
143aa2a96a2SMaheshRavishankar }
144aa2a96a2SMaheshRavishankar 
14576ead96cSMaheshRavishankar DiagnosedSilenceableFailure
14676ead96cSMaheshRavishankar transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
14776ead96cSMaheshRavishankar                                      TransformResults &transformResults,
14876ead96cSMaheshRavishankar                                      TransformState &state) {
149aa2a96a2SMaheshRavishankar   SmallVector<int64_t> tileSizes =
150aa2a96a2SMaheshRavishankar       extractFromIntegerArrayAttr<int64_t>(getTileSizes());
151aa2a96a2SMaheshRavishankar   SmallVector<int64_t> tileInterchange =
152aa2a96a2SMaheshRavishankar       extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
153aa2a96a2SMaheshRavishankar 
154aa2a96a2SMaheshRavishankar   SmallVector<OpFoldResult> tileSizesOfr =
155aa2a96a2SMaheshRavishankar       getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
156aa2a96a2SMaheshRavishankar 
157aa2a96a2SMaheshRavishankar   LogicalResult result = applyTileAndFuseToAll(
158aa2a96a2SMaheshRavishankar       rewriter, getOperation(), state.getPayloadOps(getTarget()),
159aa2a96a2SMaheshRavishankar       tileSizes.size() - llvm::count(tileSizes, 0), tileSizesOfr,
16076ead96cSMaheshRavishankar       tileInterchange, getUseForall(), transformResults);
161aa2a96a2SMaheshRavishankar   return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
162aa2a96a2SMaheshRavishankar                         : DiagnosedSilenceableFailure::success();
163aa2a96a2SMaheshRavishankar }
164aa2a96a2SMaheshRavishankar 
165aa2a96a2SMaheshRavishankar //===----------------------------------------------------------------------===//
1662b2ce50fSAbhishek Varma // TestFuseConsumerOp
1672b2ce50fSAbhishek Varma //===----------------------------------------------------------------------===//
1682b2ce50fSAbhishek Varma 
1692b2ce50fSAbhishek Varma /// Apply fusing of consumer transformation to all payload ops and store both
1702b2ce50fSAbhishek Varma /// the original consumer operation as well as the fused consumer operation.
1712b2ce50fSAbhishek Varma template <typename Range>
1722b2ce50fSAbhishek Varma static LogicalResult
1732b2ce50fSAbhishek Varma applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp,
1749bc3102bSYun-Fly                   Range &&payloadOps, uint32_t numConsumerToFuse,
1759bc3102bSYun-Fly                   TransformResults &transformResults) {
1762b2ce50fSAbhishek Varma   SmallVector<Operation *> originalConsumerOps;
1772b2ce50fSAbhishek Varma   SmallVector<Operation *> fusedConsumerOps;
1782b2ce50fSAbhishek Varma 
1792b2ce50fSAbhishek Varma   for (Operation *target : payloadOps) {
1802b2ce50fSAbhishek Varma     rewriter.setInsertionPoint(target);
1812b2ce50fSAbhishek Varma 
1829bc3102bSYun-Fly     while (numConsumerToFuse--) {
1832b2ce50fSAbhishek Varma       FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
1842b2ce50fSAbhishek Varma           scf::tileAndFuseConsumerOfSlice(rewriter, target);
1852b2ce50fSAbhishek Varma 
1862b2ce50fSAbhishek Varma       if (failed(fuseConsumerResults))
1872b2ce50fSAbhishek Varma         return failure();
1882b2ce50fSAbhishek Varma 
1892b2ce50fSAbhishek Varma       // Report back the relevant handles to the transform op.
1902b2ce50fSAbhishek Varma       originalConsumerOps.push_back(
1912b2ce50fSAbhishek Varma           fuseConsumerResults->origConsumerOperand->getOwner());
1922b2ce50fSAbhishek Varma       fusedConsumerOps.push_back(
1932b2ce50fSAbhishek Varma           fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner());
1942b2ce50fSAbhishek Varma     }
1959bc3102bSYun-Fly   }
1962b2ce50fSAbhishek Varma 
1972b2ce50fSAbhishek Varma   transformResults.set(transformOp->getOpResult(0), originalConsumerOps);
1982b2ce50fSAbhishek Varma   transformResults.set(transformOp->getOpResult(1), fusedConsumerOps);
1992b2ce50fSAbhishek Varma   return success();
2002b2ce50fSAbhishek Varma }
2012b2ce50fSAbhishek Varma 
2022b2ce50fSAbhishek Varma DiagnosedSilenceableFailure
2032b2ce50fSAbhishek Varma transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
2042b2ce50fSAbhishek Varma                                      TransformResults &transformResults,
2052b2ce50fSAbhishek Varma                                      TransformState &state) {
2069bc3102bSYun-Fly   LogicalResult result = applyFuseConsumer(
2079bc3102bSYun-Fly       rewriter, getOperation(), state.getPayloadOps(getTarget()),
2089bc3102bSYun-Fly       getNumConsumerToFuse(), transformResults);
2092b2ce50fSAbhishek Varma   return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
2102b2ce50fSAbhishek Varma                         : DiagnosedSilenceableFailure::success();
2112b2ce50fSAbhishek Varma }
2122b2ce50fSAbhishek Varma 
2132b2ce50fSAbhishek Varma void transform::TestFuseConsumerOp::getEffects(
2142b2ce50fSAbhishek Varma     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2152c1ae801Sdonald chen   consumesHandle(getTargetMutable(), effects);
2162c1ae801Sdonald chen   producesHandle(getOperation()->getOpResults(), effects);
2172b2ce50fSAbhishek Varma   modifiesPayload(effects);
2182b2ce50fSAbhishek Varma }
2192b2ce50fSAbhishek Varma 
2202b2ce50fSAbhishek Varma //===----------------------------------------------------------------------===//
221aa2a96a2SMaheshRavishankar // TestTileUsingForallOp
222aa2a96a2SMaheshRavishankar //===----------------------------------------------------------------------===//
223aa2a96a2SMaheshRavishankar 
224aa2a96a2SMaheshRavishankar /// Apply a tiling transformation to all payload ops and store both the
225aa2a96a2SMaheshRavishankar /// tiled operation as well as the created tile loops.
226aa2a96a2SMaheshRavishankar template <typename Range>
227aa2a96a2SMaheshRavishankar static LogicalResult
228aa2a96a2SMaheshRavishankar applyTileToAll(RewriterBase &rewriter, Operation *transformOp,
229aa2a96a2SMaheshRavishankar                Range &&payloadOps, ArrayRef<OpFoldResult> tileSizes,
230aa2a96a2SMaheshRavishankar                ArrayRef<int64_t> interchange, std::optional<ArrayAttr> mapping,
23176ead96cSMaheshRavishankar                TransformResults &transformResults) {
232aa2a96a2SMaheshRavishankar   SmallVector<Operation *> tiledOps;
233aa2a96a2SMaheshRavishankar   SmallVector<Operation *> loopOps;
234aa2a96a2SMaheshRavishankar 
235aa2a96a2SMaheshRavishankar   for (Operation *target : payloadOps) {
236aa2a96a2SMaheshRavishankar     auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
237aa2a96a2SMaheshRavishankar     if (!tilingInterfaceOp)
238aa2a96a2SMaheshRavishankar       return transformOp->emitError("only TilingInterface ops are supported");
239aa2a96a2SMaheshRavishankar     scf::SCFTilingOptions tilingOptions;
240aa2a96a2SMaheshRavishankar     tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
241aa2a96a2SMaheshRavishankar     if (mapping) {
2426740d701SMaheshRavishankar       tilingOptions.setMapping(mapping.value().getValue());
243aa2a96a2SMaheshRavishankar     }
24476ead96cSMaheshRavishankar     tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
245aa2a96a2SMaheshRavishankar 
246aa2a96a2SMaheshRavishankar     rewriter.setInsertionPoint(target);
247aa2a96a2SMaheshRavishankar     FailureOr<scf::SCFTilingResult> tiledResults =
24876ead96cSMaheshRavishankar         scf::tileUsingSCF(rewriter, tilingInterfaceOp, tilingOptions);
249aa2a96a2SMaheshRavishankar     if (failed(tiledResults))
250aa2a96a2SMaheshRavishankar       return failure();
251aa2a96a2SMaheshRavishankar 
252aa2a96a2SMaheshRavishankar     // Perform the replacement of tiled and fused values.
253*4b563458SKunwar Grover     rewriter.replaceOp(tilingInterfaceOp,
254*4b563458SKunwar Grover                        tiledResults->mergeResult.replacements);
255aa2a96a2SMaheshRavishankar 
256aa2a96a2SMaheshRavishankar     // Report back the relevant handles to the transform op.
257aa2a96a2SMaheshRavishankar     tiledOps.push_back(tiledResults->tiledOps.front());
258aa2a96a2SMaheshRavishankar     for (Operation *loop : tiledResults->loops)
259aa2a96a2SMaheshRavishankar       loopOps.push_back(loop);
260aa2a96a2SMaheshRavishankar   }
261aa2a96a2SMaheshRavishankar 
262aa2a96a2SMaheshRavishankar   transformResults.set(transformOp->getOpResult(0), tiledOps);
263aa2a96a2SMaheshRavishankar   for (auto [index, loop] : llvm::enumerate(loopOps))
264aa2a96a2SMaheshRavishankar     transformResults.set(transformOp->getOpResult(index + 1), {loop});
265aa2a96a2SMaheshRavishankar 
266aa2a96a2SMaheshRavishankar   return success();
267aa2a96a2SMaheshRavishankar }
268aa2a96a2SMaheshRavishankar 
26976ead96cSMaheshRavishankar DiagnosedSilenceableFailure
27076ead96cSMaheshRavishankar transform::TestTileUsingForallOp::apply(TransformRewriter &rewriter,
27176ead96cSMaheshRavishankar                                         TransformResults &transformResults,
27276ead96cSMaheshRavishankar                                         TransformState &state) {
273aa2a96a2SMaheshRavishankar   SmallVector<int64_t> tileSizes =
274aa2a96a2SMaheshRavishankar       extractFromIntegerArrayAttr<int64_t>(getTileSizes());
275aa2a96a2SMaheshRavishankar   SmallVector<int64_t> interchange =
276aa2a96a2SMaheshRavishankar       extractFromIntegerArrayAttr<int64_t>(getInterchange());
277aa2a96a2SMaheshRavishankar   SmallVector<OpFoldResult> tileSizesOfr =
278aa2a96a2SMaheshRavishankar       getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
279aa2a96a2SMaheshRavishankar 
280aa2a96a2SMaheshRavishankar   LogicalResult result =
281aa2a96a2SMaheshRavishankar       applyTileToAll(rewriter, getOperation(), state.getPayloadOps(getTarget()),
282aa2a96a2SMaheshRavishankar                      tileSizesOfr, interchange, getMapping(), transformResults);
283aa2a96a2SMaheshRavishankar   return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
284aa2a96a2SMaheshRavishankar                         : DiagnosedSilenceableFailure::success();
285aa2a96a2SMaheshRavishankar }
286aa2a96a2SMaheshRavishankar 
287aa2a96a2SMaheshRavishankar void transform::TestTileUsingForallOp::getEffects(
288aa2a96a2SMaheshRavishankar     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2892c1ae801Sdonald chen   consumesHandle(getTargetMutable(), effects);
2902c1ae801Sdonald chen   producesHandle(getOperation()->getOpResults(), effects);
291aa2a96a2SMaheshRavishankar   modifiesPayload(effects);
292aa2a96a2SMaheshRavishankar }
293aa2a96a2SMaheshRavishankar 
29476ead96cSMaheshRavishankar //===----------------------------------------------------------------------===//
29576ead96cSMaheshRavishankar // TestFuseUsingForallOp
29676ead96cSMaheshRavishankar //===----------------------------------------------------------------------===//
29776ead96cSMaheshRavishankar 
29876ead96cSMaheshRavishankar /// Apply a tiling transformation to all payload ops and store both the
29976ead96cSMaheshRavishankar /// tiled operation as well as the created tile loops.
30076ead96cSMaheshRavishankar template <typename Range>
30176ead96cSMaheshRavishankar static LogicalResult applyTilingToAll(
30276ead96cSMaheshRavishankar     RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
30376ead96cSMaheshRavishankar     unsigned numLoops, TransformResults &transformResults,
30476ead96cSMaheshRavishankar     function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
30576ead96cSMaheshRavishankar         applyFn) {
30676ead96cSMaheshRavishankar   SmallVector<Operation *> tiledLinalgOps;
30776ead96cSMaheshRavishankar   SmallVector<SmallVector<Operation *>> loopOps(1);
30876ead96cSMaheshRavishankar 
30976ead96cSMaheshRavishankar   for (Operation *target : payloadOps) {
31076ead96cSMaheshRavishankar     auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
31176ead96cSMaheshRavishankar     if (!tilingInterfaceOp)
31276ead96cSMaheshRavishankar       return transformOp->emitError("only TilingInterface ops are supported");
31376ead96cSMaheshRavishankar 
31476ead96cSMaheshRavishankar     rewriter.setInsertionPoint(target);
31576ead96cSMaheshRavishankar     FailureOr<scf::SCFTileAndFuseResult> tiledResults =
31676ead96cSMaheshRavishankar         applyFn(tilingInterfaceOp);
31776ead96cSMaheshRavishankar     if (failed(tiledResults))
31876ead96cSMaheshRavishankar       return failure();
31976ead96cSMaheshRavishankar 
32076ead96cSMaheshRavishankar     // Perform the replacement of tiled and fused values.
32176ead96cSMaheshRavishankar     SmallVector<Operation *> opsToReplace{target};
32276ead96cSMaheshRavishankar     llvm::append_range(opsToReplace, tiledResults->fusedProducers);
32376ead96cSMaheshRavishankar     for (Operation *toReplace : opsToReplace) {
32476ead96cSMaheshRavishankar       for (OpResult res : toReplace->getResults())
32576ead96cSMaheshRavishankar         if (auto replacement = tiledResults->replacements.lookup(res))
32676ead96cSMaheshRavishankar           rewriter.replaceAllUsesWith(res, replacement);
32776ead96cSMaheshRavishankar       if (toReplace->use_empty())
32876ead96cSMaheshRavishankar         rewriter.eraseOp(toReplace);
32976ead96cSMaheshRavishankar     }
33076ead96cSMaheshRavishankar 
33176ead96cSMaheshRavishankar     // Report back the relevant handles to the transform op.
33276ead96cSMaheshRavishankar     tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
33376ead96cSMaheshRavishankar     assert(tiledResults->loops.size() == 1 &&
33476ead96cSMaheshRavishankar            cast<scf::ForallOp>(tiledResults->loops[0]).getRank() == numLoops &&
33576ead96cSMaheshRavishankar            "Mismatched number of loops, tile and fuse transform should have "
33676ead96cSMaheshRavishankar            "failed");
33776ead96cSMaheshRavishankar     loopOps[0] = {tiledResults->loops[0]};
33876ead96cSMaheshRavishankar   }
33976ead96cSMaheshRavishankar 
34076ead96cSMaheshRavishankar   transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
34176ead96cSMaheshRavishankar   if (!loopOps.empty())
34276ead96cSMaheshRavishankar     transformResults.set(transformOp->getOpResult(1), loopOps[0]);
34376ead96cSMaheshRavishankar 
34476ead96cSMaheshRavishankar   return success();
34576ead96cSMaheshRavishankar }
34676ead96cSMaheshRavishankar 
34776ead96cSMaheshRavishankar DiagnosedSilenceableFailure
34876ead96cSMaheshRavishankar transform::TestFuseUsingForallOp::apply(TransformRewriter &rewriter,
34976ead96cSMaheshRavishankar                                         TransformResults &transformResults,
35076ead96cSMaheshRavishankar                                         TransformState &state) {
35176ead96cSMaheshRavishankar   SmallVector<int64_t> tileSizes =
35276ead96cSMaheshRavishankar       extractFromIntegerArrayAttr<int64_t>(getTileSizes());
35376ead96cSMaheshRavishankar   SmallVector<int64_t> tileInterchange =
35476ead96cSMaheshRavishankar       extractFromIntegerArrayAttr<int64_t>(getInterchange());
35576ead96cSMaheshRavishankar 
35676ead96cSMaheshRavishankar   scf::SCFTilingOptions tilingOptions;
35776ead96cSMaheshRavishankar   tilingOptions.interchangeVector = tileInterchange;
35876ead96cSMaheshRavishankar   SmallVector<OpFoldResult> tileSizesOfr =
35976ead96cSMaheshRavishankar       getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
36076ead96cSMaheshRavishankar   tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
36176ead96cSMaheshRavishankar   tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
36276ead96cSMaheshRavishankar   scf::SCFTileAndFuseOptions tileAndFuseOptions;
36376ead96cSMaheshRavishankar   tileAndFuseOptions.tilingOptions = tilingOptions;
36476ead96cSMaheshRavishankar   LogicalResult result = applyTilingToAll(
36576ead96cSMaheshRavishankar       rewriter, getOperation(), state.getPayloadOps(getRootOp()),
36676ead96cSMaheshRavishankar       tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
36776ead96cSMaheshRavishankar       [&](TilingInterface tilingInterfaceOp)
36876ead96cSMaheshRavishankar           -> FailureOr<scf::SCFTileAndFuseResult> {
36976ead96cSMaheshRavishankar         return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
37076ead96cSMaheshRavishankar                                                     tileAndFuseOptions);
37176ead96cSMaheshRavishankar       });
37276ead96cSMaheshRavishankar   return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
37376ead96cSMaheshRavishankar                         : DiagnosedSilenceableFailure::success();
37476ead96cSMaheshRavishankar }
37576ead96cSMaheshRavishankar 
37676ead96cSMaheshRavishankar void transform::TestFuseUsingForallOp::getEffects(
37776ead96cSMaheshRavishankar     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3782c1ae801Sdonald chen   consumesHandle(getRootOpMutable(), effects);
3792c1ae801Sdonald chen   producesHandle(getOperation()->getOpResults(), effects);
38076ead96cSMaheshRavishankar   modifiesPayload(effects);
38176ead96cSMaheshRavishankar }
38276ead96cSMaheshRavishankar 
383aa2a96a2SMaheshRavishankar #define GET_OP_CLASSES
384aa2a96a2SMaheshRavishankar #include "TestTilingInterfaceTransformOps.cpp.inc"
385aa2a96a2SMaheshRavishankar 
namespace {
/// Transform dialect extension that registers the test transform ops defined
/// in this file. The op list comes from the generated
/// `TestTilingInterfaceTransformOps.cpp.inc`.
class TestTilingInterfaceDialectExtension
    : public transform::TransformDialectExtension<
          TestTilingInterfaceDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestTilingInterfaceDialectExtension)

  using Base::Base;

  /// Declare the dialects this extension depends on and register every
  /// generated test op.
  void init() {
    // NOTE(review): presumably these are the dialects whose ops the
    // tiling/fusion transforms may create in payload IR; confirm.
    declareDependentDialect<affine::AffineDialect>();
    declareDependentDialect<index::IndexDialect>();
    declareDependentDialect<scf::SCFDialect>();
    declareDependentDialect<tensor::TensorDialect>();

    // The generated include expands to the comma-separated list of op classes
    // used as template arguments.
    registerTransformOps<
#define GET_OP_LIST
#include "TestTilingInterfaceTransformOps.cpp.inc"
        >();
  }
};
} // namespace
409aa2a96a2SMaheshRavishankar 
namespace test {
/// Add `TestTilingInterfaceDialectExtension` to `registry`, making the test
/// TilingInterface transform ops from this file available.
void registerTestTilingInterfaceTransformDialectExtension(
    DialectRegistry &registry) {
  registry.addExtensions<TestTilingInterfaceDialectExtension>();
}
} // namespace test
416