xref: /llvm-project/mlir/test/lib/Analysis/TestMatchReduction.cpp (revision 34a35a8b244243f5a4ad5d531007bccfeaa0b02e)
12a876a71SDiego Caballero //===- TestMatchReduction.cpp - Test the match reduction utility ----------===//
22a876a71SDiego Caballero //
32a876a71SDiego Caballero // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42a876a71SDiego Caballero // See https://llvm.org/LICENSE.txt for license information.
52a876a71SDiego Caballero // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
62a876a71SDiego Caballero //
72a876a71SDiego Caballero //===----------------------------------------------------------------------===//
82a876a71SDiego Caballero //
92a876a71SDiego Caballero // This file contains a test pass for the match reduction utility.
102a876a71SDiego Caballero //
112a876a71SDiego Caballero //===----------------------------------------------------------------------===//
122a876a71SDiego Caballero 
13755dc07dSRiver Riddle #include "mlir/Analysis/SliceAnalysis.h"
14*34a35a8bSMartin Erhart #include "mlir/Interfaces/FunctionInterfaces.h"
152a876a71SDiego Caballero #include "mlir/Pass/Pass.h"
162a876a71SDiego Caballero 
172a876a71SDiego Caballero using namespace mlir;
182a876a71SDiego Caballero 
192a876a71SDiego Caballero namespace {
202a876a71SDiego Caballero 
printReductionResult(Operation * redRegionOp,unsigned numOutput,Value reducedValue,ArrayRef<Operation * > combinerOps)212a876a71SDiego Caballero void printReductionResult(Operation *redRegionOp, unsigned numOutput,
222a876a71SDiego Caballero                           Value reducedValue,
232a876a71SDiego Caballero                           ArrayRef<Operation *> combinerOps) {
242a876a71SDiego Caballero   if (reducedValue) {
252a876a71SDiego Caballero     redRegionOp->emitRemark("Reduction found in output #") << numOutput << "!";
262a876a71SDiego Caballero     redRegionOp->emitRemark("Reduced Value: ") << reducedValue;
272a876a71SDiego Caballero     for (Operation *combOp : combinerOps)
282a876a71SDiego Caballero       redRegionOp->emitRemark("Combiner Op: ") << *combOp;
292a876a71SDiego Caballero 
302a876a71SDiego Caballero     return;
312a876a71SDiego Caballero   }
322a876a71SDiego Caballero 
332a876a71SDiego Caballero   redRegionOp->emitRemark("Reduction NOT found in output #")
342a876a71SDiego Caballero       << numOutput << "!";
352a876a71SDiego Caballero }
362a876a71SDiego Caballero 
372a876a71SDiego Caballero struct TestMatchReductionPass
3887d6bf37SRiver Riddle     : public PassWrapper<TestMatchReductionPass,
3987d6bf37SRiver Riddle                          InterfacePass<FunctionOpInterface>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon69ee29380111::TestMatchReductionPass405e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMatchReductionPass)
415e50dd04SRiver Riddle 
422a876a71SDiego Caballero   StringRef getArgument() const final { return "test-match-reduction"; }
getDescription__anon69ee29380111::TestMatchReductionPass432a876a71SDiego Caballero   StringRef getDescription() const final {
442a876a71SDiego Caballero     return "Test the match reduction utility.";
452a876a71SDiego Caballero   }
462a876a71SDiego Caballero 
runOnOperation__anon69ee29380111::TestMatchReductionPass4741574554SRiver Riddle   void runOnOperation() override {
4887d6bf37SRiver Riddle     FunctionOpInterface func = getOperation();
492a876a71SDiego Caballero     func->emitRemark("Testing function");
502a876a71SDiego Caballero 
512a876a71SDiego Caballero     func.walk<WalkOrder::PreOrder>([](Operation *op) {
5287d6bf37SRiver Riddle       if (isa<FunctionOpInterface>(op))
532a876a71SDiego Caballero         return;
542a876a71SDiego Caballero 
552a876a71SDiego Caballero       // Limit testing to ops with only one region.
562a876a71SDiego Caballero       if (op->getNumRegions() != 1)
572a876a71SDiego Caballero         return;
582a876a71SDiego Caballero 
592a876a71SDiego Caballero       Region &region = op->getRegion(0);
602a876a71SDiego Caballero       if (!region.hasOneBlock())
612a876a71SDiego Caballero         return;
622a876a71SDiego Caballero 
632a876a71SDiego Caballero       // We expect all the tested region ops to have 1 input by default. The
642a876a71SDiego Caballero       // remaining arguments are assumed to be outputs/reductions and there must
652a876a71SDiego Caballero       // be at least one.
662a876a71SDiego Caballero       // TODO: Extend it to support more generic cases.
672a876a71SDiego Caballero       Block &regionEntry = region.front();
682a876a71SDiego Caballero       auto args = regionEntry.getArguments();
692a876a71SDiego Caballero       if (args.size() < 2)
702a876a71SDiego Caballero         return;
712a876a71SDiego Caballero 
722a876a71SDiego Caballero       auto outputs = args.drop_front();
732a876a71SDiego Caballero       for (int i = 0, size = outputs.size(); i < size; ++i) {
742a876a71SDiego Caballero         SmallVector<Operation *, 4> combinerOps;
752a876a71SDiego Caballero         Value reducedValue = matchReduction(outputs, i, combinerOps);
762a876a71SDiego Caballero         printReductionResult(op, i, reducedValue, combinerOps);
772a876a71SDiego Caballero       }
782a876a71SDiego Caballero     });
792a876a71SDiego Caballero   }
802a876a71SDiego Caballero };
812a876a71SDiego Caballero 
82be0a7e9fSMehdi Amini } // namespace
832a876a71SDiego Caballero 
842a876a71SDiego Caballero namespace mlir {
852a876a71SDiego Caballero namespace test {
registerTestMatchReductionPass()862a876a71SDiego Caballero void registerTestMatchReductionPass() {
872a876a71SDiego Caballero   PassRegistration<TestMatchReductionPass>();
882a876a71SDiego Caballero }
892a876a71SDiego Caballero } // namespace test
902a876a71SDiego Caballero } // namespace mlir
91