1 //===- InferEffects.cpp - Infer memory effects for named symbols ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 10 #include "mlir/Dialect/Transform/Transforms/Passes.h" 11 12 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" 13 #include "mlir/IR/Visitors.h" 14 #include "mlir/Interfaces/FunctionInterfaces.h" 15 #include "mlir/Interfaces/SideEffectInterfaces.h" 16 #include "llvm/ADT/DenseSet.h" 17 18 using namespace mlir; 19 20 namespace mlir { 21 namespace transform { 22 #define GEN_PASS_DEF_INFEREFFECTSPASS 23 #include "mlir/Dialect/Transform/Transforms/Passes.h.inc" 24 } // namespace transform 25 } // namespace mlir 26 inferSideEffectAnnotations(Operation * op)27static LogicalResult inferSideEffectAnnotations(Operation *op) { 28 if (!isa<transform::TransformOpInterface>(op)) 29 return success(); 30 31 auto func = dyn_cast<FunctionOpInterface>(op); 32 if (!func || func.isExternal()) 33 return success(); 34 35 if (!func.getFunctionBody().hasOneBlock()) { 36 return op->emitError() 37 << "only single-block operations are currently supported"; 38 } 39 40 // Note that there can't be an inclusion of an unannotated symbol because it 41 // wouldn't have passed the verifier, so recursion isn't necessary here. 42 llvm::SmallDenseSet<unsigned> consumedArguments; 43 transform::getConsumedBlockArguments(func.getFunctionBody().front(), 44 consumedArguments); 45 46 for (unsigned i = 0, e = func.getNumArguments(); i < e; ++i) { 47 func.setArgAttr(i, 48 consumedArguments.contains(i) 49 ? transform::TransformDialect::kArgConsumedAttrName 50 : transform::TransformDialect::kArgReadOnlyAttrName, 51 UnitAttr::get(op->getContext())); 52 } 53 return success(); 54 } 55 56 namespace { 57 class InferEffectsPass 58 : public transform::impl::InferEffectsPassBase<InferEffectsPass> { 59 public: runOnOperation()60 void runOnOperation() override { 61 WalkResult result = getOperation()->walk([](Operation *op) { 62 return failed(inferSideEffectAnnotations(op)) ? WalkResult::interrupt() 63 : WalkResult::advance(); 64 }); 65 if (result.wasInterrupted()) 66 return signalPassFailure(); 67 } 68 }; 69 } // namespace 70