// xref: /llvm-project/mlir/test/lib/IR/TestSlicing.cpp (revision d97bc388fd9ef8bc38353f93ff42d894ddc4a271)
//===- TestSlicing.cpp - Testing slice functionality ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a simple testing pass for slicing.
//
//===----------------------------------------------------------------------===//
12 
13 #include "mlir/Analysis/SliceAnalysis.h"
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/Dialect/Linalg/IR/Linalg.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/IRMapping.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Support/LLVM.h"
21 
22 using namespace mlir;
23 
24 /// Create a function with the same signature as the parent function of `op`
25 /// with name being the function name and a `suffix`.
26 static LogicalResult createBackwardSliceFunction(Operation *op,
27                                                  StringRef suffix,
28                                                  bool omitBlockArguments) {
29   func::FuncOp parentFuncOp = op->getParentOfType<func::FuncOp>();
30   OpBuilder builder(parentFuncOp);
31   Location loc = op->getLoc();
32   std::string clonedFuncOpName = parentFuncOp.getName().str() + suffix.str();
33   func::FuncOp clonedFuncOp = builder.create<func::FuncOp>(
34       loc, clonedFuncOpName, parentFuncOp.getFunctionType());
35   IRMapping mapper;
36   builder.setInsertionPointToEnd(clonedFuncOp.addEntryBlock());
37   for (const auto &arg : enumerate(parentFuncOp.getArguments()))
38     mapper.map(arg.value(), clonedFuncOp.getArgument(arg.index()));
39   SetVector<Operation *> slice;
40   BackwardSliceOptions options;
41   options.omitBlockArguments = omitBlockArguments;
42   // TODO: Make this default.
43   options.omitUsesFromAbove = false;
44   getBackwardSlice(op, &slice, options);
45   for (Operation *slicedOp : slice)
46     builder.clone(*slicedOp, mapper);
47   builder.create<func::ReturnOp>(loc);
48   return success();
49 }
50 
51 namespace {
52 /// Pass to test slice generated from slice analysis.
53 struct SliceAnalysisTestPass
54     : public PassWrapper<SliceAnalysisTestPass, OperationPass<ModuleOp>> {
55   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SliceAnalysisTestPass)
56 
57   StringRef getArgument() const final { return "slice-analysis-test"; }
58   StringRef getDescription() const final {
59     return "Test Slice analysis functionality.";
60   }
61 
62   Option<bool> omitBlockArguments{
63       *this, "omit-block-arguments",
64       llvm::cl::desc("Test Slice analysis with multiple blocks but slice "
65                      "omiting block arguments"),
66       llvm::cl::init(true)};
67 
68   void runOnOperation() override;
69   SliceAnalysisTestPass() = default;
70   SliceAnalysisTestPass(const SliceAnalysisTestPass &) {}
71 };
72 } // namespace
73 
74 void SliceAnalysisTestPass::runOnOperation() {
75   ModuleOp module = getOperation();
76   auto funcOps = module.getOps<func::FuncOp>();
77   unsigned opNum = 0;
78   for (auto funcOp : funcOps) {
79     // TODO: For now this is just looking for Linalg ops. It can be generalized
80     // to look for other ops using flags.
81     funcOp.walk([&](Operation *op) {
82       if (!isa<linalg::LinalgOp>(op))
83         return WalkResult::advance();
84       std::string append =
85           std::string("__backward_slice__") + std::to_string(opNum);
86       (void)createBackwardSliceFunction(op, append, omitBlockArguments);
87       opNum++;
88       return WalkResult::advance();
89     });
90   }
91 }
92 
93 namespace mlir {
94 void registerSliceAnalysisTestPass() {
95   PassRegistration<SliceAnalysisTestPass>();
96 }
97 } // namespace mlir
98