1 //===- TestSlicing.cpp - Testing slice functionality ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a simple testing pass for slicing. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Analysis/SliceAnalysis.h" 14 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/IR/BlockAndValueMapping.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Support/LLVM.h" 21 22 using namespace mlir; 23 24 /// Create a function with the same signature as the parent function of `op` 25 /// with name being the function name and a `suffix`. 26 static LogicalResult createBackwardSliceFunction(Operation *op, 27 StringRef suffix) { 28 FuncOp parentFuncOp = op->getParentOfType<FuncOp>(); 29 OpBuilder builder(parentFuncOp); 30 Location loc = op->getLoc(); 31 std::string clonedFuncOpName = parentFuncOp.getName().str() + suffix.str(); 32 FuncOp clonedFuncOp = 33 builder.create<FuncOp>(loc, clonedFuncOpName, parentFuncOp.getType()); 34 BlockAndValueMapping mapper; 35 builder.setInsertionPointToEnd(clonedFuncOp.addEntryBlock()); 36 for (auto arg : enumerate(parentFuncOp.getArguments())) 37 mapper.map(arg.value(), clonedFuncOp.getArgument(arg.index())); 38 llvm::SetVector<Operation *> slice; 39 getBackwardSlice(op, &slice); 40 for (Operation *slicedOp : slice) 41 builder.clone(*slicedOp, mapper); 42 builder.create<ReturnOp>(loc); 43 return success(); 44 } 45 46 namespace { 47 /// Pass to test slice generated from slice analysis. 48 struct SliceAnalysisTestPass 49 : public PassWrapper<SliceAnalysisTestPass, OperationPass<ModuleOp>> { 50 void runOnOperation() override; 51 SliceAnalysisTestPass() = default; 52 SliceAnalysisTestPass(const SliceAnalysisTestPass &) {} 53 }; 54 } // namespace 55 56 void SliceAnalysisTestPass::runOnOperation() { 57 ModuleOp module = getOperation(); 58 auto funcOps = module.getOps<FuncOp>(); 59 unsigned opNum = 0; 60 for (auto funcOp : funcOps) { 61 // TODO: For now this is just looking for Linalg ops. It can be generalized 62 // to look for other ops using flags. 63 funcOp.walk([&](Operation *op) { 64 if (!isa<linalg::LinalgOp>(op)) 65 return WalkResult::advance(); 66 std::string append = 67 std::string("__backward_slice__") + std::to_string(opNum); 68 createBackwardSliceFunction(op, append); 69 opNum++; 70 return WalkResult::advance(); 71 }); 72 } 73 } 74 75 namespace mlir { 76 void registerSliceAnalysisTestPass() { 77 PassRegistration<SliceAnalysisTestPass> pass( 78 "slice-analysis-test", "Test Slice analysis functionality."); 79 } 80 } // namespace mlir 81