xref: /llvm-project/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===- TestTensorTransforms.cpp - Test Tensor transformation patterns -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Tensor transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Arith/IR/Arith.h"
14 #include "mlir/Dialect/Linalg/IR/Linalg.h"
15 #include "mlir/Dialect/SCF/IR/SCF.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/Dialect/Tensor/Transforms/TransformUtils.h"
18 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
19 #include "mlir/Dialect/Transform/IR/TransformOps.h"
20 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 
24 using namespace mlir;
25 
26 namespace {
27 struct TestTensorTransforms
28     : public PassWrapper<TestTensorTransforms, OperationPass<>> {
29   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTensorTransforms)
30 
31   TestTensorTransforms() = default;
32   TestTensorTransforms(const TestTensorTransforms &pass) : PassWrapper(pass) {}
33 
34   void getDependentDialects(DialectRegistry &registry) const override {
35     registry.insert<arith::ArithDialect, scf::SCFDialect, linalg::LinalgDialect,
36                     transform::TransformDialect>();
37   }
38 
39   StringRef getArgument() const final {
40     return "test-tensor-transform-patterns";
41   }
42   StringRef getDescription() const final {
43     return "Test Tensor transformation patterns by applying them greedily.";
44   }
45 
46   void runOnOperation() override;
47 
48   Option<bool> testFoldConstantExtractSlice{
49       *this, "test-fold-constant-extract-slice",
50       llvm::cl::desc("Test folding arith.constant and tensor.extract_slice"),
51       llvm::cl::init(false)};
52 
53   Option<bool> testFoldConsecutiveInsertExtractSlice{
54       *this, "test-fold-consecutive-insert-extract-slice",
55       llvm::cl::desc(
56           "Test folding consecutive tensor.insert_slice/tensor.extract_slice"),
57       llvm::cl::init(false)};
58 
59   Option<bool> testRewriteExtractSliceWithTiledCollapseShape{
60       *this, "test-rewrite-extract-slice-from-collapse-shape",
61       llvm::cl::desc("Test swapping tensor.extract_slice of a collapse_shape "
62                      "with loop nest"),
63       llvm::cl::init(false)};
64 
65   Option<bool> testDropRedundantInsertSliceRankExpansion{
66       *this, "test-drop-redundant-insert-slice-rank-expansion",
67       llvm::cl::desc("Test dropping redundant insert_slice rank expansions"),
68       llvm::cl::init(false)};
69 
70   Option<bool> testReassociativeReshapeFolding{
71       *this, "test-reassociative-reshape-folding",
72       llvm::cl::desc("Test folding of expand_shape/collapse_shape"),
73       llvm::cl::init(false)};
74 
75   Option<bool> testBubbleUpExpandShapePatterns{
76       *this, "test-expand-shape-bubbling",
77       llvm::cl::desc("Test folding of expand_shape/collapse_shape"),
78       llvm::cl::init(false)};
79 
80   Option<bool> testFoldIntoPackAndUnpack{
81       *this, "test-fold-into-pack-and-unpack",
82       llvm::cl::desc("Test folding ops into tensor.pack and tensor.unpack"),
83       llvm::cl::init(false)};
84 
85   Option<bool> useForeach{
86       *this, "use-foreach",
87       llvm::cl::desc(
88           "Use the scf.forall operation when generating loop nests for "
89           "the extract_slice of collapse_shape pattern"),
90       llvm::cl::init(false)};
91 
92   Option<bool> testSimplifyPackUnpackPatterns{
93       *this, "test-simplify-pack-unpack-patterns",
94       llvm::cl::desc("Test patterns to simplify tensor.pack and tensor.unpack"),
95       llvm::cl::init(false)};
96 
97   Option<bool> testTrackingListener{
98       *this, "test-tracking-listener",
99       llvm::cl::desc("Test tensor TrackingListener for the transform dialect"),
100       llvm::cl::init(false)};
101 };
102 } // namespace
103 
104 static void applyReassociativeReshapeFoldingPatterns(Operation *rootOp) {
105   RewritePatternSet patterns(rootOp->getContext());
106   tensor::populateReassociativeReshapeFoldingPatterns(patterns);
107   (void)applyPatternsGreedily(rootOp, std::move(patterns));
108 }
109 
110 static void applyBubbleUpExpandShapePatterns(Operation *rootOp) {
111   RewritePatternSet patterns(rootOp->getContext());
112   tensor::populateBubbleUpExpandShapePatterns(patterns);
113   (void)applyPatternsGreedily(rootOp, std::move(patterns));
114 }
115 
116 static void applyFoldIntoPackAndUnpackPatterns(Operation *rootOp) {
117   RewritePatternSet patterns(rootOp->getContext());
118   tensor::populateFoldIntoPackAndUnpackPatterns(patterns);
119   (void)applyPatternsGreedily(rootOp, std::move(patterns));
120 }
121 
122 static void applyFoldConstantExtractSlicePatterns(Operation *rootOp) {
123   RewritePatternSet patterns(rootOp->getContext());
124   tensor::ControlConstantExtractSliceFusionFn controlFn =
125       [](tensor::ExtractSliceOp op) {
126         if (!op.getSource().hasOneUse())
127           return false;
128 
129         auto resultType = cast<ShapedType>(op.getResult().getType());
130         constexpr int64_t kConstantFoldingMaxNumElements = 1024;
131         return resultType.getNumElements() <= kConstantFoldingMaxNumElements;
132       };
133 
134   tensor::populateFoldConstantExtractSlicePatterns(patterns, controlFn);
135   (void)applyPatternsGreedily(rootOp, std::move(patterns));
136 }
137 
138 static void applyFoldConsecutiveInsertExtractSlicePatterns(Operation *rootOp) {
139   RewritePatternSet patterns(rootOp->getContext());
140   tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
141   (void)applyPatternsGreedily(rootOp, std::move(patterns));
142 }
143 
144 static void
145 applyDropRedundantInsertSliceRankExpansionPatterns(Operation *rootOp) {
146   RewritePatternSet patterns(rootOp->getContext());
147   tensor::populateDropRedundantInsertSliceRankExpansionPatterns(patterns);
148   (void)applyPatternsGreedily(rootOp, std::move(patterns));
149 }
150 
151 static void applySimplifyPackUnpackPatterns(Operation *rootOp) {
152   RewritePatternSet patterns(rootOp->getContext());
153   tensor::populateSimplifyPackAndUnpackPatterns(patterns);
154   (void)applyPatternsGreedily(rootOp, std::move(patterns));
155 }
156 
157 namespace {
158 /// Base pattern to rewrite  a `tensor.collapse_shape -> tensor.extract_slice`.
159 /// The `tensor.extract_slice` is replaced by a loop or gather operation that
160 /// stitches together the desired tile from slices of the source of the collapse
161 /// shape op.
struct RewriteExtractSliceFromCollapseShapeBase
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  RewriteExtractSliceFromCollapseShapeBase(MLIRContext *context)
      : mlir::OpRewritePattern<tensor::ExtractSliceOp>(context) {}

  /// Emit a loop or gather operation that uses `helper` to take each point in
  /// the parallel iteration space bounds, extract a slice from the source
  /// tensor and insert it into `dest`. For examples, see below for `scf.for`
  /// and `scf.foreach`.
  virtual LogicalResult
  emitReplacement(tensor::ExtractSliceOp op, Value dest,
                  tensor::ExtractSliceFromCollapseHelper &helper,
                  PatternRewriter &rewriter) const = 0;

  /// Match `tensor.collapse_shape -> tensor.extract_slice`, build the
  /// destination tensor and tiling helper, then delegate the actual loop
  /// emission to the derived class via `emitReplacement`.
  LogicalResult matchAndRewrite(tensor::ExtractSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto collapseOp = op.getSource().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseOp)
      return rewriter.notifyMatchFailure(
          op, "producer is not a tensor.collapse_shape op");

    // Try to simplify the collapse shape using a rank-reducing slice, if
    // possible.
    FailureOr<Operation *> simplifiedCollapseShapeResult =
        tensor::simplifyCollapseShapeWithRankReducingExtractSlice(collapseOp,
                                                                  rewriter);
    if (succeeded(simplifiedCollapseShapeResult)) {
      auto newCollapseOp =
          dyn_cast<tensor::CollapseShapeOp>(*simplifiedCollapseShapeResult);
      // The collapse shape op might have been simplified away, so we can just
      // return.
      if (!newCollapseOp)
        return success();
      // Continue the rewrite against the simplified collapse_shape op.
      collapseOp = newCollapseOp;
    }

    // Materialize the output shape values of the slice operation.
    ReifiedRankedShapedTypeDims reifiedShapes;
    if (failed(reifyResultShapes(rewriter, op, reifiedShapes)))
      return rewriter.notifyMatchFailure(op, "failed to reify result shapes");

    // Create the destination tensor using the above values.
    Type elementType = op.getSourceType().getElementType();
    SmallVector<OpFoldResult> outputShape = reifiedShapes[0];
    Value dest = rewriter.create<tensor::EmptyOp>(op->getLoc(), outputShape,
                                                  elementType);

    // Calculate the parameters for the tile loop nest.
    FailureOr<tensor::ExtractSliceFromCollapseHelper> params =
        tensor::ExtractSliceFromCollapseHelper::create(rewriter, collapseOp,
                                                       op);
    if (failed(params))
      return rewriter.notifyMatchFailure(
          op, "could not calculate tiling parameters");
    return emitReplacement(op, dest, *params, rewriter);
  }
};
219 
/// Variant of the base pattern that emits an `scf.for` loop nest: one loop per
/// tiled dimension, iterating from 0 to the iteration-space size with step 1,
/// carrying the destination tensor as the loop iter_arg.
struct RewriteExtractSliceFromCollapseShapeUsingScfFor
    : public RewriteExtractSliceFromCollapseShapeBase {
  RewriteExtractSliceFromCollapseShapeUsingScfFor(MLIRContext *context)
      : RewriteExtractSliceFromCollapseShapeBase(context) {}
  LogicalResult emitReplacement(tensor::ExtractSliceOp op, Value dest,
                                tensor::ExtractSliceFromCollapseHelper &helper,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const unsigned numTiledDims = helper.getIterationSpaceSizes().size();
    // All loops run [0, size) with unit step.
    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value> lbs(numTiledDims, zero);
    SmallVector<Value> steps(numTiledDims, one);

    scf::LoopNest nest = scf::buildLoopNest(
        rewriter, loc, lbs, helper.getIterationSpaceSizes(), steps, dest,
        [&](OpBuilder &nestedBuilder, Location loc, ValueRange outputIvs,
            ValueRange iterArgs) -> scf::ValueVector {
          // `tile` is the slice extracted for this iteration point;
          // `insertParams` describes where it goes in the destination.
          auto [tile, insertParams] =
              helper.emitLoopNestBody(nestedBuilder, loc, outputIvs);

          // Insert the slice into the destination.
          // iterArgs[0] is the loop-carried destination tensor.
          return {nestedBuilder.create<tensor::InsertSliceOp>(
              loc, tile, iterArgs[0], insertParams)};
        });
    // The extract_slice is replaced by the final value of the loop nest.
    rewriter.replaceOp(op, nest.results);

    return success();
  }
};
250 
/// Variant of the base pattern that emits a single `scf.forall` op over the
/// iteration space, writing each tile into the shared output via
/// `tensor.parallel_insert_slice` inside the `scf.forall.in_parallel`
/// terminator.
struct RewriteExtractSliceFromCollapseShapeUsingScfForeach
    : public RewriteExtractSliceFromCollapseShapeBase {
  RewriteExtractSliceFromCollapseShapeUsingScfForeach(MLIRContext *context)
      : RewriteExtractSliceFromCollapseShapeBase(context) {}
  LogicalResult emitReplacement(tensor::ExtractSliceOp op, Value dest,
                                tensor::ExtractSliceFromCollapseHelper &helper,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto forallOp = rewriter.create<scf::ForallOp>(
        loc, /*numThreads=*/getAsOpFoldResult(helper.getIterationSpaceSizes()),
        /*outputs=*/dest,
        /*mapping=*/std::nullopt,
        [&](OpBuilder &nestedBuilder, Location loc, ValueRange regionArgs) {
          // The region arguments are the thread ids (one per iteration-space
          // dimension) followed by the shared output tensors.
          unsigned numThreadIdRegionArgs =
              helper.getIterationSpaceSizes().size();
          unsigned numOutputRegionArgs =
              regionArgs.size() - numThreadIdRegionArgs;
          ValueRange outputIvs = regionArgs.take_front(numThreadIdRegionArgs);
          ValueRange outputArgs = regionArgs.take_back(numOutputRegionArgs);
          assert(outputArgs.size() == 1 &&
                 "there should only be one output region argument");
          auto [tile, insertParams] =
              helper.emitLoopNestBody(nestedBuilder, loc, outputIvs);
          // Insert the slice into the destination.
          // Parallel inserts must live inside the in_parallel terminator.
          auto term = nestedBuilder.create<scf::InParallelOp>(loc);
          nestedBuilder.setInsertionPointToStart(term.getBody());
          nestedBuilder.create<tensor::ParallelInsertSliceOp>(
              loc, tile, outputArgs[0], insertParams);
        });
    // The single forall result is the fully-assembled destination tensor.
    rewriter.replaceOp(op, forallOp->getResult(0));
    return success();
  }
};
284 } // namespace
285 
286 static LogicalResult
287 applyRewriteExtractFromCollapseShapePatterns(Operation *rootOp,
288                                              bool useForeach) {
289   RewritePatternSet patterns(rootOp->getContext());
290   if (useForeach)
291     patterns.add<RewriteExtractSliceFromCollapseShapeUsingScfForeach>(
292         rootOp->getContext());
293   else
294     patterns.add<RewriteExtractSliceFromCollapseShapeUsingScfFor>(
295         rootOp->getContext());
296   return applyPatternsGreedily(rootOp, std::move(patterns));
297 }
298 
299 namespace {
300 class DummyTrackingListener : public transform::TrackingListener {
301 public:
302   using transform::TrackingListener::TrackingListener;
303 
304   // Expose `findReplacementOp` as a public function, so that it can be tested.
305   Operation *getReplacementOp(Operation *op, ValueRange newValues) const {
306     Operation *replacementOp;
307     if (!findReplacementOp(replacementOp, op, newValues).succeeded())
308       return nullptr;
309     return replacementOp;
310   }
311 };
312 } // namespace
313 
/// Drives the TrackingListener test: locates the single op annotated with a
/// "replaced" attribute and, for each of its results i, the value annotated
/// with a "replacement_<i>" attribute, then asks a DummyTrackingListener to
/// find the replacement op and reports the outcome as a remark (found) or an
/// error (not found / malformed test case).
static LogicalResult testTrackingListenerReplacements(Operation *rootOp) {
  // Find replaced op.
  Operation *replaced = nullptr;
  WalkResult status = rootOp->walk([&](Operation *op) {
    if (op->hasAttr("replaced")) {
      // Exactly one op may be marked "replaced" per test case.
      if (replaced) {
        op->emitError("only one 'replaced' op is allowed per test case");
        replaced->emitRemark("other 'replaced' op");
        return WalkResult::interrupt();
      }
      replaced = op;
    }
    return WalkResult::advance();
  });
  if (status.wasInterrupted())
    return failure();
  if (!replaced) {
    rootOp->emitError("could not find 'replaced' op");
    return failure();
  }

  // Find replacements.
  // replacements[i] is the value replacing the i-th result of `replaced`.
  SmallVector<Value> replacements(replaced->getNumResults(), Value());
  status = rootOp->walk([&](Operation *op) {
    for (int64_t i = 0; i < replaced->getNumResults(); ++i) {
      // The "replacement_<i>" attribute's integer payload selects which
      // result of the annotated op serves as replacement value i.
      if (auto attr = op->getAttrOfType<IntegerAttr>("replacement_" +
                                                     std::to_string(i))) {
        if (replacements[i]) {
          op->emitError("only one 'replacement_" + std::to_string(i) +
                        "' is allowed per test case");
          replacements[i].getDefiningOp()->emitRemark("other 'replacement_" +
                                                      std::to_string(i) + "'");
          return WalkResult::interrupt();
        }
        replacements[i] = op->getResult(attr.getInt());
      }
    }
    return WalkResult::advance();
  });
  if (status.wasInterrupted())
    return failure();

  // Every result of the replaced op must have exactly one replacement value.
  if (!llvm::all_of(replacements,
                    [](Value v) { return static_cast<bool>(v); })) {
    replaced->emitError("insufficient replacement values");
    return failure();
  }

  // Find the replacement op (if any) and emit a remark/error.
  // The listener needs a TransformState and an owning transform op; build
  // throwaway instances since no real transform interpreter is running.
  transform::TransformState transformState =
      transform::detail::makeTransformStateForTesting(/*region=*/nullptr,
                                                      /*payloadRoot=*/nullptr);
  MLIRContext *context = rootOp->getContext();
  OpBuilder builder(context);
  OwningOpRef<transform::NamedSequenceOp> transformOp =
      builder.create<transform::NamedSequenceOp>(
          rootOp->getLoc(),
          /*sym_name=*/"test_sequence",
          /*function_type=*/
          TypeAttr::get(FunctionType::get(context, TypeRange{}, TypeRange{})),
          /*sym_visibility*/ StringAttr::get(context, "public"),
          /*arg_attrs=*/ArrayAttr::get(context, ArrayRef<Attribute>()),
          /*res_attrs=*/ArrayAttr::get(context, ArrayRef<Attribute>()));
  DummyTrackingListener listener(transformState, transformOp.get());
  Operation *replacement = listener.getReplacementOp(replaced, replacements);
  if (!replacement) {
    replaced->emitError("listener could not find replacement op");
    return failure();
  }

  replacement->emitRemark("replacement found");
  return success();
}
387 
388 void TestTensorTransforms::runOnOperation() {
389   Operation *rootOp = getOperation();
390   if (testSimplifyPackUnpackPatterns)
391     applySimplifyPackUnpackPatterns(rootOp);
392   if (testFoldConstantExtractSlice)
393     applyFoldConstantExtractSlicePatterns(rootOp);
394   if (testFoldConsecutiveInsertExtractSlice)
395     applyFoldConsecutiveInsertExtractSlicePatterns(rootOp);
396   if (testDropRedundantInsertSliceRankExpansion)
397     applyDropRedundantInsertSliceRankExpansionPatterns(rootOp);
398   if (testReassociativeReshapeFolding)
399     applyReassociativeReshapeFoldingPatterns(rootOp);
400   if (testBubbleUpExpandShapePatterns)
401     applyBubbleUpExpandShapePatterns(rootOp);
402   if (testFoldIntoPackAndUnpack)
403     applyFoldIntoPackAndUnpackPatterns(rootOp);
404   if (testRewriteExtractSliceWithTiledCollapseShape) {
405     if (failed(
406             applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))
407       return signalPassFailure();
408   }
409   if (testTrackingListener)
410     if (failed(testTrackingListenerReplacements(rootOp)))
411       return signalPassFailure();
412 }
413 
414 namespace mlir {
415 namespace test {
416 void registerTestTensorTransforms() {
417   PassRegistration<TestTensorTransforms>();
418 }
419 } // namespace test
420 } // namespace mlir
421