19a8bb4bcSWilliam S. Moses //===- TestClone.cpp - Pass to test operation cloning --------------------===// 2ed499ddcSWilliam S. Moses // 3ed499ddcSWilliam S. Moses // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4ed499ddcSWilliam S. Moses // See https://llvm.org/LICENSE.txt for license information. 5ed499ddcSWilliam S. Moses // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6ed499ddcSWilliam S. Moses // 7ed499ddcSWilliam S. Moses //===----------------------------------------------------------------------===// 8ed499ddcSWilliam S. Moses 9*e95e94adSJeff Niu #include "TestOps.h" 10ed499ddcSWilliam S. Moses #include "mlir/IR/BuiltinOps.h" 11ed499ddcSWilliam S. Moses #include "mlir/Pass/Pass.h" 12ed499ddcSWilliam S. Moses 13ed499ddcSWilliam S. Moses using namespace mlir; 14ed499ddcSWilliam S. Moses 15ed499ddcSWilliam S. Moses namespace { 16ed499ddcSWilliam S. Moses 175a4ca51aSjeanPerier struct DumpNotifications : public OpBuilder::Listener { notifyOperationInserted__anon7794c1660111::DumpNotifications185cc0f76dSMatthias Springer void notifyOperationInserted(Operation *op, 195cc0f76dSMatthias Springer OpBuilder::InsertPoint previous) override { 205a4ca51aSjeanPerier llvm::outs() << "notifyOperationInserted: " << op->getName() << "\n"; 215a4ca51aSjeanPerier } 225a4ca51aSjeanPerier }; 235a4ca51aSjeanPerier 24ed499ddcSWilliam S. Moses /// This is a test pass which clones the body of a function. Specifically 25ed499ddcSWilliam S. Moses /// this pass replaces f(x) to instead return f(f(x)) in which the cloned body 26ed499ddcSWilliam S. Moses /// takes the result of the first operation return as an input. 27ed499ddcSWilliam S. Moses struct ClonePass 28ed499ddcSWilliam S. Moses : public PassWrapper<ClonePass, InterfacePass<FunctionOpInterface>> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon7794c1660111::ClonePass290df963e8SWilliam S. Moses MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ClonePass) 30ac860240SRiver Riddle 31ed499ddcSWilliam S. Moses StringRef getArgument() const final { return "test-clone"; } getDescription__anon7794c1660111::ClonePass32ed499ddcSWilliam S. Moses StringRef getDescription() const final { return "Test clone of op"; } runOnOperation__anon7794c1660111::ClonePass33ed499ddcSWilliam S. Moses void runOnOperation() override { 34ed499ddcSWilliam S. Moses FunctionOpInterface op = getOperation(); 35ed499ddcSWilliam S. Moses 36ed499ddcSWilliam S. Moses // Limit testing to ops with only one region. 37ed499ddcSWilliam S. Moses if (op->getNumRegions() != 1) 38ed499ddcSWilliam S. Moses return; 39ed499ddcSWilliam S. Moses 40ed499ddcSWilliam S. Moses Region ®ion = op->getRegion(0); 41ed499ddcSWilliam S. Moses if (!region.hasOneBlock()) 42ed499ddcSWilliam S. Moses return; 43ed499ddcSWilliam S. Moses 44ed499ddcSWilliam S. Moses Block ®ionEntry = region.front(); 45ac860240SRiver Riddle Operation *terminator = regionEntry.getTerminator(); 46ed499ddcSWilliam S. Moses 47ed499ddcSWilliam S. Moses // Only handle functions whose returns match the inputs. 48ed499ddcSWilliam S. Moses if (terminator->getNumOperands() != regionEntry.getNumArguments()) 49ed499ddcSWilliam S. Moses return; 50ed499ddcSWilliam S. Moses 514d67b278SJeff Niu IRMapping map; 52ed499ddcSWilliam S. Moses for (auto tup : 53ed499ddcSWilliam S. Moses llvm::zip(terminator->getOperands(), regionEntry.getArguments())) { 54ed499ddcSWilliam S. Moses if (std::get<0>(tup).getType() != std::get<1>(tup).getType()) 55ed499ddcSWilliam S. Moses return; 56ed499ddcSWilliam S. Moses map.map(std::get<1>(tup), std::get<0>(tup)); 57ed499ddcSWilliam S. Moses } 58ed499ddcSWilliam S. Moses 59ac860240SRiver Riddle OpBuilder builder(op->getContext()); 605a4ca51aSjeanPerier DumpNotifications dumpNotifications; 615a4ca51aSjeanPerier builder.setListener(&dumpNotifications); 62ac860240SRiver Riddle builder.setInsertionPointToEnd(®ionEntry); 63ed499ddcSWilliam S. Moses SmallVector<Operation *> toClone; 64ed499ddcSWilliam S. Moses for (Operation &inst : regionEntry) 65ed499ddcSWilliam S. Moses toClone.push_back(&inst); 66ed499ddcSWilliam S. Moses for (Operation *inst : toClone) 67ac860240SRiver Riddle builder.clone(*inst, map); 68ed499ddcSWilliam S. Moses terminator->erase(); 69ed499ddcSWilliam S. Moses } 70ed499ddcSWilliam S. Moses }; 71ed499ddcSWilliam S. Moses } // namespace 72ed499ddcSWilliam S. Moses 73ed499ddcSWilliam S. Moses namespace mlir { registerCloneTestPasses()74ed499ddcSWilliam S. Mosesvoid registerCloneTestPasses() { PassRegistration<ClonePass>(); } 75ed499ddcSWilliam S. Moses } // namespace mlir 76