1 //===- TestClone.cpp - Pass to test operation cloning --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestOps.h" 10 #include "mlir/IR/BuiltinOps.h" 11 #include "mlir/Pass/Pass.h" 12 13 using namespace mlir; 14 15 namespace { 16 17 struct DumpNotifications : public OpBuilder::Listener { notifyOperationInserted__anon7794c1660111::DumpNotifications18 void notifyOperationInserted(Operation *op, 19 OpBuilder::InsertPoint previous) override { 20 llvm::outs() << "notifyOperationInserted: " << op->getName() << "\n"; 21 } 22 }; 23 24 /// This is a test pass which clones the body of a function. Specifically 25 /// this pass replaces f(x) to instead return f(f(x)) in which the cloned body 26 /// takes the result of the first operation return as an input. 27 struct ClonePass 28 : public PassWrapper<ClonePass, InterfacePass<FunctionOpInterface>> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon7794c1660111::ClonePass29 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ClonePass) 30 31 StringRef getArgument() const final { return "test-clone"; } getDescription__anon7794c1660111::ClonePass32 StringRef getDescription() const final { return "Test clone of op"; } runOnOperation__anon7794c1660111::ClonePass33 void runOnOperation() override { 34 FunctionOpInterface op = getOperation(); 35 36 // Limit testing to ops with only one region. 37 if (op->getNumRegions() != 1) 38 return; 39 40 Region ®ion = op->getRegion(0); 41 if (!region.hasOneBlock()) 42 return; 43 44 Block ®ionEntry = region.front(); 45 Operation *terminator = regionEntry.getTerminator(); 46 47 // Only handle functions whose returns match the inputs. 48 if (terminator->getNumOperands() != regionEntry.getNumArguments()) 49 return; 50 51 IRMapping map; 52 for (auto tup : 53 llvm::zip(terminator->getOperands(), regionEntry.getArguments())) { 54 if (std::get<0>(tup).getType() != std::get<1>(tup).getType()) 55 return; 56 map.map(std::get<1>(tup), std::get<0>(tup)); 57 } 58 59 OpBuilder builder(op->getContext()); 60 DumpNotifications dumpNotifications; 61 builder.setListener(&dumpNotifications); 62 builder.setInsertionPointToEnd(®ionEntry); 63 SmallVector<Operation *> toClone; 64 for (Operation &inst : regionEntry) 65 toClone.push_back(&inst); 66 for (Operation *inst : toClone) 67 builder.clone(*inst, map); 68 terminator->erase(); 69 } 70 }; 71 } // namespace 72 73 namespace mlir { registerCloneTestPasses()74void registerCloneTestPasses() { PassRegistration<ClonePass>(); } 75 } // namespace mlir 76