xref: /llvm-project/mlir/test/lib/Analysis/TestMatchReduction.cpp (revision 34a35a8b244243f5a4ad5d531007bccfeaa0b02e)
1 //===- TestMatchReduction.cpp - Test the match reduction utility ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a test pass for the match reduction utility.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Analysis/SliceAnalysis.h"
14 #include "mlir/Interfaces/FunctionInterfaces.h"
15 #include "mlir/Pass/Pass.h"
16 
17 using namespace mlir;
18 
19 namespace {
20 
printReductionResult(Operation * redRegionOp,unsigned numOutput,Value reducedValue,ArrayRef<Operation * > combinerOps)21 void printReductionResult(Operation *redRegionOp, unsigned numOutput,
22                           Value reducedValue,
23                           ArrayRef<Operation *> combinerOps) {
24   if (reducedValue) {
25     redRegionOp->emitRemark("Reduction found in output #") << numOutput << "!";
26     redRegionOp->emitRemark("Reduced Value: ") << reducedValue;
27     for (Operation *combOp : combinerOps)
28       redRegionOp->emitRemark("Combiner Op: ") << *combOp;
29 
30     return;
31   }
32 
33   redRegionOp->emitRemark("Reduction NOT found in output #")
34       << numOutput << "!";
35 }
36 
37 struct TestMatchReductionPass
38     : public PassWrapper<TestMatchReductionPass,
39                          InterfacePass<FunctionOpInterface>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon69ee29380111::TestMatchReductionPass40   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMatchReductionPass)
41 
42   StringRef getArgument() const final { return "test-match-reduction"; }
getDescription__anon69ee29380111::TestMatchReductionPass43   StringRef getDescription() const final {
44     return "Test the match reduction utility.";
45   }
46 
runOnOperation__anon69ee29380111::TestMatchReductionPass47   void runOnOperation() override {
48     FunctionOpInterface func = getOperation();
49     func->emitRemark("Testing function");
50 
51     func.walk<WalkOrder::PreOrder>([](Operation *op) {
52       if (isa<FunctionOpInterface>(op))
53         return;
54 
55       // Limit testing to ops with only one region.
56       if (op->getNumRegions() != 1)
57         return;
58 
59       Region &region = op->getRegion(0);
60       if (!region.hasOneBlock())
61         return;
62 
63       // We expect all the tested region ops to have 1 input by default. The
64       // remaining arguments are assumed to be outputs/reductions and there must
65       // be at least one.
66       // TODO: Extend it to support more generic cases.
67       Block &regionEntry = region.front();
68       auto args = regionEntry.getArguments();
69       if (args.size() < 2)
70         return;
71 
72       auto outputs = args.drop_front();
73       for (int i = 0, size = outputs.size(); i < size; ++i) {
74         SmallVector<Operation *, 4> combinerOps;
75         Value reducedValue = matchReduction(outputs, i, combinerOps);
76         printReductionResult(op, i, reducedValue, combinerOps);
77       }
78     });
79   }
80 };
81 
82 } // namespace
83 
84 namespace mlir {
85 namespace test {
registerTestMatchReductionPass()86 void registerTestMatchReductionPass() {
87   PassRegistration<TestMatchReductionPass>();
88 }
89 } // namespace test
90 } // namespace mlir
91