xref: /llvm-project/mlir/test/lib/IR/TestClone.cpp (revision e95e94adc6bb748de015ac3053e7f0786b65f351)
//===- TestClone.cpp - Pass to test operation cloning  --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "TestOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

15ed499ddcSWilliam S. Moses namespace {
16ed499ddcSWilliam S. Moses 
175a4ca51aSjeanPerier struct DumpNotifications : public OpBuilder::Listener {
notifyOperationInserted__anon7794c1660111::DumpNotifications185cc0f76dSMatthias Springer   void notifyOperationInserted(Operation *op,
195cc0f76dSMatthias Springer                                OpBuilder::InsertPoint previous) override {
205a4ca51aSjeanPerier     llvm::outs() << "notifyOperationInserted: " << op->getName() << "\n";
215a4ca51aSjeanPerier   }
225a4ca51aSjeanPerier };
235a4ca51aSjeanPerier 
24ed499ddcSWilliam S. Moses /// This is a test pass which clones the body of a function. Specifically
25ed499ddcSWilliam S. Moses /// this pass replaces f(x) to instead return f(f(x)) in which the cloned body
26ed499ddcSWilliam S. Moses /// takes the result of the first operation return as an input.
27ed499ddcSWilliam S. Moses struct ClonePass
28ed499ddcSWilliam S. Moses     : public PassWrapper<ClonePass, InterfacePass<FunctionOpInterface>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon7794c1660111::ClonePass290df963e8SWilliam S. Moses   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ClonePass)
30ac860240SRiver Riddle 
31ed499ddcSWilliam S. Moses   StringRef getArgument() const final { return "test-clone"; }
getDescription__anon7794c1660111::ClonePass32ed499ddcSWilliam S. Moses   StringRef getDescription() const final { return "Test clone of op"; }
runOnOperation__anon7794c1660111::ClonePass33ed499ddcSWilliam S. Moses   void runOnOperation() override {
34ed499ddcSWilliam S. Moses     FunctionOpInterface op = getOperation();
35ed499ddcSWilliam S. Moses 
36ed499ddcSWilliam S. Moses     // Limit testing to ops with only one region.
37ed499ddcSWilliam S. Moses     if (op->getNumRegions() != 1)
38ed499ddcSWilliam S. Moses       return;
39ed499ddcSWilliam S. Moses 
40ed499ddcSWilliam S. Moses     Region &region = op->getRegion(0);
41ed499ddcSWilliam S. Moses     if (!region.hasOneBlock())
42ed499ddcSWilliam S. Moses       return;
43ed499ddcSWilliam S. Moses 
44ed499ddcSWilliam S. Moses     Block &regionEntry = region.front();
45ac860240SRiver Riddle     Operation *terminator = regionEntry.getTerminator();
46ed499ddcSWilliam S. Moses 
47ed499ddcSWilliam S. Moses     // Only handle functions whose returns match the inputs.
48ed499ddcSWilliam S. Moses     if (terminator->getNumOperands() != regionEntry.getNumArguments())
49ed499ddcSWilliam S. Moses       return;
50ed499ddcSWilliam S. Moses 
514d67b278SJeff Niu     IRMapping map;
52ed499ddcSWilliam S. Moses     for (auto tup :
53ed499ddcSWilliam S. Moses          llvm::zip(terminator->getOperands(), regionEntry.getArguments())) {
54ed499ddcSWilliam S. Moses       if (std::get<0>(tup).getType() != std::get<1>(tup).getType())
55ed499ddcSWilliam S. Moses         return;
56ed499ddcSWilliam S. Moses       map.map(std::get<1>(tup), std::get<0>(tup));
57ed499ddcSWilliam S. Moses     }
58ed499ddcSWilliam S. Moses 
59ac860240SRiver Riddle     OpBuilder builder(op->getContext());
605a4ca51aSjeanPerier     DumpNotifications dumpNotifications;
615a4ca51aSjeanPerier     builder.setListener(&dumpNotifications);
62ac860240SRiver Riddle     builder.setInsertionPointToEnd(&regionEntry);
63ed499ddcSWilliam S. Moses     SmallVector<Operation *> toClone;
64ed499ddcSWilliam S. Moses     for (Operation &inst : regionEntry)
65ed499ddcSWilliam S. Moses       toClone.push_back(&inst);
66ed499ddcSWilliam S. Moses     for (Operation *inst : toClone)
67ac860240SRiver Riddle       builder.clone(*inst, map);
68ed499ddcSWilliam S. Moses     terminator->erase();
69ed499ddcSWilliam S. Moses   }
70ed499ddcSWilliam S. Moses };
71ed499ddcSWilliam S. Moses } // namespace
72ed499ddcSWilliam S. Moses 
73ed499ddcSWilliam S. Moses namespace mlir {
registerCloneTestPasses()74ed499ddcSWilliam S. Moses void registerCloneTestPasses() { PassRegistration<ClonePass>(); }
75ed499ddcSWilliam S. Moses } // namespace mlir
76