xref: /llvm-project/mlir/lib/Dialect/Transform/Transforms/InferEffects.cpp (revision 5a9bdd85ee4d8527e2cedf44f3ce26ff414f9b6a)
1 //===- InferEffects.cpp - Infer memory effects for named symbols ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
10 #include "mlir/Dialect/Transform/Transforms/Passes.h"
11 
12 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
13 #include "mlir/IR/Visitors.h"
14 #include "mlir/Interfaces/FunctionInterfaces.h"
15 #include "mlir/Interfaces/SideEffectInterfaces.h"
16 #include "llvm/ADT/DenseSet.h"
17 
18 using namespace mlir;
19 
20 namespace mlir {
21 namespace transform {
22 #define GEN_PASS_DEF_INFEREFFECTSPASS
23 #include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
24 } // namespace transform
25 } // namespace mlir
26 
inferSideEffectAnnotations(Operation * op)27 static LogicalResult inferSideEffectAnnotations(Operation *op) {
28   if (!isa<transform::TransformOpInterface>(op))
29     return success();
30 
31   auto func = dyn_cast<FunctionOpInterface>(op);
32   if (!func || func.isExternal())
33     return success();
34 
35   if (!func.getFunctionBody().hasOneBlock()) {
36     return op->emitError()
37            << "only single-block operations are currently supported";
38   }
39 
40   // Note that there can't be an inclusion of an unannotated symbol because it
41   // wouldn't have passed the verifier, so recursion isn't necessary here.
42   llvm::SmallDenseSet<unsigned> consumedArguments;
43   transform::getConsumedBlockArguments(func.getFunctionBody().front(),
44                                        consumedArguments);
45 
46   for (unsigned i = 0, e = func.getNumArguments(); i < e; ++i) {
47     func.setArgAttr(i,
48                     consumedArguments.contains(i)
49                         ? transform::TransformDialect::kArgConsumedAttrName
50                         : transform::TransformDialect::kArgReadOnlyAttrName,
51                     UnitAttr::get(op->getContext()));
52   }
53   return success();
54 }
55 
56 namespace {
57 class InferEffectsPass
58     : public transform::impl::InferEffectsPassBase<InferEffectsPass> {
59 public:
runOnOperation()60   void runOnOperation() override {
61     WalkResult result = getOperation()->walk([](Operation *op) {
62       return failed(inferSideEffectAnnotations(op)) ? WalkResult::interrupt()
63                                                     : WalkResult::advance();
64     });
65     if (result.wasInterrupted())
66       return signalPassFailure();
67   }
68 };
69 } // namespace
70