1 //===- TestSlicing.cpp - Testing slice functionality ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a simple testing pass for slicing. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Analysis/SliceAnalysis.h" 14 #include "mlir/Dialect/Func/IR/FuncOps.h" 15 #include "mlir/Dialect/Linalg/IR/Linalg.h" 16 #include "mlir/IR/BuiltinOps.h" 17 #include "mlir/IR/IRMapping.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Support/LLVM.h" 21 22 using namespace mlir; 23 24 /// Create a function with the same signature as the parent function of `op` 25 /// with name being the function name and a `suffix`. 26 static LogicalResult createBackwardSliceFunction(Operation *op, 27 StringRef suffix, 28 bool omitBlockArguments) { 29 func::FuncOp parentFuncOp = op->getParentOfType<func::FuncOp>(); 30 OpBuilder builder(parentFuncOp); 31 Location loc = op->getLoc(); 32 std::string clonedFuncOpName = parentFuncOp.getName().str() + suffix.str(); 33 func::FuncOp clonedFuncOp = builder.create<func::FuncOp>( 34 loc, clonedFuncOpName, parentFuncOp.getFunctionType()); 35 IRMapping mapper; 36 builder.setInsertionPointToEnd(clonedFuncOp.addEntryBlock()); 37 for (const auto &arg : enumerate(parentFuncOp.getArguments())) 38 mapper.map(arg.value(), clonedFuncOp.getArgument(arg.index())); 39 SetVector<Operation *> slice; 40 BackwardSliceOptions options; 41 options.omitBlockArguments = omitBlockArguments; 42 // TODO: Make this default. 43 options.omitUsesFromAbove = false; 44 getBackwardSlice(op, &slice, options); 45 for (Operation *slicedOp : slice) 46 builder.clone(*slicedOp, mapper); 47 builder.create<func::ReturnOp>(loc); 48 return success(); 49 } 50 51 namespace { 52 /// Pass to test slice generated from slice analysis. 53 struct SliceAnalysisTestPass 54 : public PassWrapper<SliceAnalysisTestPass, OperationPass<ModuleOp>> { 55 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SliceAnalysisTestPass) 56 57 StringRef getArgument() const final { return "slice-analysis-test"; } 58 StringRef getDescription() const final { 59 return "Test Slice analysis functionality."; 60 } 61 62 Option<bool> omitBlockArguments{ 63 *this, "omit-block-arguments", 64 llvm::cl::desc("Test Slice analysis with multiple blocks but slice " 65 "omiting block arguments"), 66 llvm::cl::init(true)}; 67 68 void runOnOperation() override; 69 SliceAnalysisTestPass() = default; 70 SliceAnalysisTestPass(const SliceAnalysisTestPass &) {} 71 }; 72 } // namespace 73 74 void SliceAnalysisTestPass::runOnOperation() { 75 ModuleOp module = getOperation(); 76 auto funcOps = module.getOps<func::FuncOp>(); 77 unsigned opNum = 0; 78 for (auto funcOp : funcOps) { 79 // TODO: For now this is just looking for Linalg ops. It can be generalized 80 // to look for other ops using flags. 81 funcOp.walk([&](Operation *op) { 82 if (!isa<linalg::LinalgOp>(op)) 83 return WalkResult::advance(); 84 std::string append = 85 std::string("__backward_slice__") + std::to_string(opNum); 86 (void)createBackwardSliceFunction(op, append, omitBlockArguments); 87 opNum++; 88 return WalkResult::advance(); 89 }); 90 } 91 } 92 93 namespace mlir { 94 void registerSliceAnalysisTestPass() { 95 PassRegistration<SliceAnalysisTestPass>(); 96 } 97 } // namespace mlir 98