1 //===- TestMatchReduction.cpp - Test the match reduction utility ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a test pass for the match reduction utility.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Analysis/SliceAnalysis.h"
14 #include "mlir/Interfaces/FunctionInterfaces.h"
15 #include "mlir/Pass/Pass.h"
16
17 using namespace mlir;
18
19 namespace {
20
printReductionResult(Operation * redRegionOp,unsigned numOutput,Value reducedValue,ArrayRef<Operation * > combinerOps)21 void printReductionResult(Operation *redRegionOp, unsigned numOutput,
22 Value reducedValue,
23 ArrayRef<Operation *> combinerOps) {
24 if (reducedValue) {
25 redRegionOp->emitRemark("Reduction found in output #") << numOutput << "!";
26 redRegionOp->emitRemark("Reduced Value: ") << reducedValue;
27 for (Operation *combOp : combinerOps)
28 redRegionOp->emitRemark("Combiner Op: ") << *combOp;
29
30 return;
31 }
32
33 redRegionOp->emitRemark("Reduction NOT found in output #")
34 << numOutput << "!";
35 }
36
37 struct TestMatchReductionPass
38 : public PassWrapper<TestMatchReductionPass,
39 InterfacePass<FunctionOpInterface>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon69ee29380111::TestMatchReductionPass40 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMatchReductionPass)
41
42 StringRef getArgument() const final { return "test-match-reduction"; }
getDescription__anon69ee29380111::TestMatchReductionPass43 StringRef getDescription() const final {
44 return "Test the match reduction utility.";
45 }
46
runOnOperation__anon69ee29380111::TestMatchReductionPass47 void runOnOperation() override {
48 FunctionOpInterface func = getOperation();
49 func->emitRemark("Testing function");
50
51 func.walk<WalkOrder::PreOrder>([](Operation *op) {
52 if (isa<FunctionOpInterface>(op))
53 return;
54
55 // Limit testing to ops with only one region.
56 if (op->getNumRegions() != 1)
57 return;
58
59 Region ®ion = op->getRegion(0);
60 if (!region.hasOneBlock())
61 return;
62
63 // We expect all the tested region ops to have 1 input by default. The
64 // remaining arguments are assumed to be outputs/reductions and there must
65 // be at least one.
66 // TODO: Extend it to support more generic cases.
67 Block ®ionEntry = region.front();
68 auto args = regionEntry.getArguments();
69 if (args.size() < 2)
70 return;
71
72 auto outputs = args.drop_front();
73 for (int i = 0, size = outputs.size(); i < size; ++i) {
74 SmallVector<Operation *, 4> combinerOps;
75 Value reducedValue = matchReduction(outputs, i, combinerOps);
76 printReductionResult(op, i, reducedValue, combinerOps);
77 }
78 });
79 }
80 };
81
82 } // namespace
83
84 namespace mlir {
85 namespace test {
registerTestMatchReductionPass()86 void registerTestMatchReductionPass() {
87 PassRegistration<TestMatchReductionPass>();
88 }
89 } // namespace test
90 } // namespace mlir
91