// xref: /llvm-project/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp (revision 4b56345895729fda3bc3c094bc3f237ba3a49686)
//===- TestTilingInterfaceTransformOps.cpp - Test `TilingInterface` ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines transform dialect operations used for testing
// `TilingInterface`.
//
//===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Index/IR/IndexDialect.h"
16 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
17 #include "mlir/Dialect/Transform/IR/TransformAttrs.h"
18 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
19 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
20 #include "mlir/Dialect/Utils/StaticValueUtils.h"
21 #include "mlir/IR/Dominance.h"
22 #include "mlir/IR/OpImplementation.h"
23 #include "mlir/Interfaces/TilingInterface.h"
24 
25 #define GET_OP_CLASSES
26 #include "TestTilingInterfaceTransformOps.h.inc"
27 
28 using namespace mlir;
29 using namespace mlir::transform;
30 
//===----------------------------------------------------------------------===//
// TestFuseAndYieldOp
//===----------------------------------------------------------------------===//
34 
35 static llvm::SmallDenseSet<Operation *> collectTiledAndFusedOps(Operation *op) {
36   SmallVector<Operation *> worklist;
37   llvm::SmallDenseSet<Operation *> producers;
38   worklist.push_back(op);
39   producers.insert(op);
40   while (!worklist.empty()) {
41     Operation *current = worklist.pop_back_val();
42     for (OpOperand &operand : current->getOpOperands()) {
43       Operation *producer = operand.get().getDefiningOp();
44       if (!producer || !isa<TilingInterface>(producer) ||
45           producers.contains(producer))
46         continue;
47       worklist.push_back(producer);
48       producers.insert(producer);
49     }
50   }
51   return producers;
52 }
53 
54 /// Apply a tile and fuse transformation to all payload ops and store both the
55 /// tiled operation as well as the created tile loops.
56 template <typename Range>
57 static LogicalResult
58 applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
59                       Range &&payloadOps, unsigned numLoops,
60                       ArrayRef<OpFoldResult> tileSizes,
61                       ArrayRef<int64_t> interchange, bool useForall,
62                       TransformResults &transformResults) {
63   SmallVector<Operation *> tiledOps;
64   SmallVector<SmallVector<Operation *>> loopOps(numLoops);
65 
66   for (Operation *target : payloadOps) {
67     auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
68     if (!tilingInterfaceOp)
69       return transformOp->emitError("only TilingInterface ops are supported");
70     DominanceInfo dominanceInfo(tilingInterfaceOp);
71 
72     llvm::SmallDenseSet<Operation *> tiledAndFusedOps =
73         collectTiledAndFusedOps(tilingInterfaceOp);
74     llvm::DenseSet<Operation *> yieldReplacementsFor;
75     for (auto op : tiledAndFusedOps) {
76       if (llvm::any_of(op->getUsers(), [&](Operation *user) {
77             return dominanceInfo.properlyDominates(tilingInterfaceOp, user);
78           })) {
79         yieldReplacementsFor.insert(op);
80       }
81     }
82 
83     scf::SCFTilingOptions tilingOptions;
84     tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
85     if (useForall) {
86       tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
87     }
88 
89     scf::SCFTileAndFuseOptions tileAndFuseOptions;
90     tileAndFuseOptions.setTilingOptions(tilingOptions);
91 
92     scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
93         [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
94             bool isDestinationOperand)
95         -> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> {
96       Operation *owner = originalProducer.getOwner();
97       bool yieldProducerReplacement = yieldReplacementsFor.contains(owner);
98       return scf::SCFTileAndFuseOptions::ControlFnResult{
99           yieldProducerReplacement};
100     };
101     tileAndFuseOptions.setFusionControlFn(controlFn);
102 
103     rewriter.setInsertionPoint(target);
104     FailureOr<scf::SCFTileAndFuseResult> tiledResults =
105         scf::tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
106                                                   tileAndFuseOptions);
107     if (failed(tiledResults))
108       return failure();
109 
110     // Perform the replacement of tiled and fused values.
111     SmallVector<Operation *> opsToReplace{target};
112     llvm::append_range(opsToReplace, tiledResults->fusedProducers);
113     for (Operation *toReplace : opsToReplace) {
114       for (OpResult res : toReplace->getResults())
115         if (auto replacement = tiledResults->replacements.lookup(res)) {
116           Operation *replacementOp = replacement.getDefiningOp();
117           rewriter.replaceUsesWithIf(res, replacement, [&](OpOperand &use) {
118             Operation *user = use.getOwner();
119             return dominanceInfo.properlyDominates(replacementOp, user) &&
120                    user->getParentOp() == replacementOp->getParentOp();
121           });
122         }
123 
124       if (toReplace->use_empty()) {
125         rewriter.eraseOp(toReplace);
126       }
127     }
128 
129     // Report back the relevant handles to the transform op.
130     tiledOps.push_back(tiledResults->tiledAndFusedOps.front());
131     assert(tiledResults->loops.size() == numLoops &&
132            "Mismatched number of loops, tile and fuse transform should have "
133            "failed");
134     for (unsigned int i = 0; i < numLoops; ++i)
135       loopOps[i].push_back(tiledResults->loops[i]);
136   }
137 
138   transformResults.set(transformOp->getOpResult(0), tiledOps);
139   for (unsigned int i = 0; i < numLoops; ++i)
140     transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
141 
142   return success();
143 }
144 
145 DiagnosedSilenceableFailure
146 transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
147                                      TransformResults &transformResults,
148                                      TransformState &state) {
149   SmallVector<int64_t> tileSizes =
150       extractFromIntegerArrayAttr<int64_t>(getTileSizes());
151   SmallVector<int64_t> tileInterchange =
152       extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
153 
154   SmallVector<OpFoldResult> tileSizesOfr =
155       getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
156 
157   LogicalResult result = applyTileAndFuseToAll(
158       rewriter, getOperation(), state.getPayloadOps(getTarget()),
159       tileSizes.size() - llvm::count(tileSizes, 0), tileSizesOfr,
160       tileInterchange, getUseForall(), transformResults);
161   return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
162                         : DiagnosedSilenceableFailure::success();
163 }
164 
//===----------------------------------------------------------------------===//
// TestFuseConsumerOp
//===----------------------------------------------------------------------===//
168 
169 /// Apply fusing of consumer transformation to all payload ops and store both
170 /// the original consumer operation as well as the fused consumer operation.
171 template <typename Range>
172 static LogicalResult
173 applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp,
174                   Range &&payloadOps, uint32_t numConsumerToFuse,
175                   TransformResults &transformResults) {
176   SmallVector<Operation *> originalConsumerOps;
177   SmallVector<Operation *> fusedConsumerOps;
178 
179   for (Operation *target : payloadOps) {
180     rewriter.setInsertionPoint(target);
181 
182     while (numConsumerToFuse--) {
183       FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
184           scf::tileAndFuseConsumerOfSlice(rewriter, target);
185 
186       if (failed(fuseConsumerResults))
187         return failure();
188 
189       // Report back the relevant handles to the transform op.
190       originalConsumerOps.push_back(
191           fuseConsumerResults->origConsumerOperand->getOwner());
192       fusedConsumerOps.push_back(
193           fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner());
194     }
195   }
196 
197   transformResults.set(transformOp->getOpResult(0), originalConsumerOps);
198   transformResults.set(transformOp->getOpResult(1), fusedConsumerOps);
199   return success();
200 }
201 
202 DiagnosedSilenceableFailure
203 transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
204                                      TransformResults &transformResults,
205                                      TransformState &state) {
206   LogicalResult result = applyFuseConsumer(
207       rewriter, getOperation(), state.getPayloadOps(getTarget()),
208       getNumConsumerToFuse(), transformResults);
209   return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
210                         : DiagnosedSilenceableFailure::success();
211 }
212 
213 void transform::TestFuseConsumerOp::getEffects(
214     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
215   consumesHandle(getTargetMutable(), effects);
216   producesHandle(getOperation()->getOpResults(), effects);
217   modifiesPayload(effects);
218 }
219 
//===----------------------------------------------------------------------===//
// TestTileUsingForallOp
//===----------------------------------------------------------------------===//
223 
224 /// Apply a tiling transformation to all payload ops and store both the
225 /// tiled operation as well as the created tile loops.
226 template <typename Range>
227 static LogicalResult
228 applyTileToAll(RewriterBase &rewriter, Operation *transformOp,
229                Range &&payloadOps, ArrayRef<OpFoldResult> tileSizes,
230                ArrayRef<int64_t> interchange, std::optional<ArrayAttr> mapping,
231                TransformResults &transformResults) {
232   SmallVector<Operation *> tiledOps;
233   SmallVector<Operation *> loopOps;
234 
235   for (Operation *target : payloadOps) {
236     auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
237     if (!tilingInterfaceOp)
238       return transformOp->emitError("only TilingInterface ops are supported");
239     scf::SCFTilingOptions tilingOptions;
240     tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
241     if (mapping) {
242       tilingOptions.setMapping(mapping.value().getValue());
243     }
244     tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
245 
246     rewriter.setInsertionPoint(target);
247     FailureOr<scf::SCFTilingResult> tiledResults =
248         scf::tileUsingSCF(rewriter, tilingInterfaceOp, tilingOptions);
249     if (failed(tiledResults))
250       return failure();
251 
252     // Perform the replacement of tiled and fused values.
253     rewriter.replaceOp(tilingInterfaceOp,
254                        tiledResults->mergeResult.replacements);
255 
256     // Report back the relevant handles to the transform op.
257     tiledOps.push_back(tiledResults->tiledOps.front());
258     for (Operation *loop : tiledResults->loops)
259       loopOps.push_back(loop);
260   }
261 
262   transformResults.set(transformOp->getOpResult(0), tiledOps);
263   for (auto [index, loop] : llvm::enumerate(loopOps))
264     transformResults.set(transformOp->getOpResult(index + 1), {loop});
265 
266   return success();
267 }
268 
269 DiagnosedSilenceableFailure
270 transform::TestTileUsingForallOp::apply(TransformRewriter &rewriter,
271                                         TransformResults &transformResults,
272                                         TransformState &state) {
273   SmallVector<int64_t> tileSizes =
274       extractFromIntegerArrayAttr<int64_t>(getTileSizes());
275   SmallVector<int64_t> interchange =
276       extractFromIntegerArrayAttr<int64_t>(getInterchange());
277   SmallVector<OpFoldResult> tileSizesOfr =
278       getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
279 
280   LogicalResult result =
281       applyTileToAll(rewriter, getOperation(), state.getPayloadOps(getTarget()),
282                      tileSizesOfr, interchange, getMapping(), transformResults);
283   return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
284                         : DiagnosedSilenceableFailure::success();
285 }
286 
287 void transform::TestTileUsingForallOp::getEffects(
288     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
289   consumesHandle(getTargetMutable(), effects);
290   producesHandle(getOperation()->getOpResults(), effects);
291   modifiesPayload(effects);
292 }
293 
//===----------------------------------------------------------------------===//
// TestFuseUsingForallOp
//===----------------------------------------------------------------------===//
297 
298 /// Apply a tiling transformation to all payload ops and store both the
299 /// tiled operation as well as the created tile loops.
300 template <typename Range>
301 static LogicalResult applyTilingToAll(
302     RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
303     unsigned numLoops, TransformResults &transformResults,
304     function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
305         applyFn) {
306   SmallVector<Operation *> tiledLinalgOps;
307   SmallVector<SmallVector<Operation *>> loopOps(1);
308 
309   for (Operation *target : payloadOps) {
310     auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
311     if (!tilingInterfaceOp)
312       return transformOp->emitError("only TilingInterface ops are supported");
313 
314     rewriter.setInsertionPoint(target);
315     FailureOr<scf::SCFTileAndFuseResult> tiledResults =
316         applyFn(tilingInterfaceOp);
317     if (failed(tiledResults))
318       return failure();
319 
320     // Perform the replacement of tiled and fused values.
321     SmallVector<Operation *> opsToReplace{target};
322     llvm::append_range(opsToReplace, tiledResults->fusedProducers);
323     for (Operation *toReplace : opsToReplace) {
324       for (OpResult res : toReplace->getResults())
325         if (auto replacement = tiledResults->replacements.lookup(res))
326           rewriter.replaceAllUsesWith(res, replacement);
327       if (toReplace->use_empty())
328         rewriter.eraseOp(toReplace);
329     }
330 
331     // Report back the relevant handles to the transform op.
332     tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
333     assert(tiledResults->loops.size() == 1 &&
334            cast<scf::ForallOp>(tiledResults->loops[0]).getRank() == numLoops &&
335            "Mismatched number of loops, tile and fuse transform should have "
336            "failed");
337     loopOps[0] = {tiledResults->loops[0]};
338   }
339 
340   transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
341   if (!loopOps.empty())
342     transformResults.set(transformOp->getOpResult(1), loopOps[0]);
343 
344   return success();
345 }
346 
347 DiagnosedSilenceableFailure
348 transform::TestFuseUsingForallOp::apply(TransformRewriter &rewriter,
349                                         TransformResults &transformResults,
350                                         TransformState &state) {
351   SmallVector<int64_t> tileSizes =
352       extractFromIntegerArrayAttr<int64_t>(getTileSizes());
353   SmallVector<int64_t> tileInterchange =
354       extractFromIntegerArrayAttr<int64_t>(getInterchange());
355 
356   scf::SCFTilingOptions tilingOptions;
357   tilingOptions.interchangeVector = tileInterchange;
358   SmallVector<OpFoldResult> tileSizesOfr =
359       getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
360   tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
361   tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
362   scf::SCFTileAndFuseOptions tileAndFuseOptions;
363   tileAndFuseOptions.tilingOptions = tilingOptions;
364   LogicalResult result = applyTilingToAll(
365       rewriter, getOperation(), state.getPayloadOps(getRootOp()),
366       tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
367       [&](TilingInterface tilingInterfaceOp)
368           -> FailureOr<scf::SCFTileAndFuseResult> {
369         return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
370                                                     tileAndFuseOptions);
371       });
372   return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
373                         : DiagnosedSilenceableFailure::success();
374 }
375 
376 void transform::TestFuseUsingForallOp::getEffects(
377     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
378   consumesHandle(getRootOpMutable(), effects);
379   producesHandle(getOperation()->getOpResults(), effects);
380   modifiesPayload(effects);
381 }
382 
383 #define GET_OP_CLASSES
384 #include "TestTilingInterfaceTransformOps.cpp.inc"
385 
386 namespace {
387 class TestTilingInterfaceDialectExtension
388     : public transform::TransformDialectExtension<
389           TestTilingInterfaceDialectExtension> {
390 public:
391   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
392       TestTilingInterfaceDialectExtension)
393 
394   using Base::Base;
395 
396   void init() {
397     declareDependentDialect<affine::AffineDialect>();
398     declareDependentDialect<index::IndexDialect>();
399     declareDependentDialect<scf::SCFDialect>();
400     declareDependentDialect<tensor::TensorDialect>();
401 
402     registerTransformOps<
403 #define GET_OP_LIST
404 #include "TestTilingInterfaceTransformOps.cpp.inc"
405         >();
406   }
407 };
408 } // namespace
409 
410 namespace test {
411 void registerTestTilingInterfaceTransformDialectExtension(
412     DialectRegistry &registry) {
413   registry.addExtensions<TestTilingInterfaceDialectExtension>();
414 }
415 } // namespace test
416