141109341SAlex Zinenko //===- InferEffects.cpp - Infer memory effects for named symbols ----------===// 241109341SAlex Zinenko // 341109341SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 441109341SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information. 541109341SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 641109341SAlex Zinenko // 741109341SAlex Zinenko //===----------------------------------------------------------------------===// 841109341SAlex Zinenko 941109341SAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformDialect.h" 1041109341SAlex Zinenko #include "mlir/Dialect/Transform/Transforms/Passes.h" 1141109341SAlex Zinenko 12*5a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" 1341109341SAlex Zinenko #include "mlir/IR/Visitors.h" 1434a35a8bSMartin Erhart #include "mlir/Interfaces/FunctionInterfaces.h" 1541109341SAlex Zinenko #include "mlir/Interfaces/SideEffectInterfaces.h" 1641109341SAlex Zinenko #include "llvm/ADT/DenseSet.h" 1741109341SAlex Zinenko 1841109341SAlex Zinenko using namespace mlir; 1941109341SAlex Zinenko 2041109341SAlex Zinenko namespace mlir { 2141109341SAlex Zinenko namespace transform { 2241109341SAlex Zinenko #define GEN_PASS_DEF_INFEREFFECTSPASS 2341109341SAlex Zinenko #include "mlir/Dialect/Transform/Transforms/Passes.h.inc" 2441109341SAlex Zinenko } // namespace transform 2541109341SAlex Zinenko } // namespace mlir 2641109341SAlex Zinenko inferSideEffectAnnotations(Operation * op)2741109341SAlex Zinenkostatic LogicalResult inferSideEffectAnnotations(Operation *op) { 2841109341SAlex Zinenko if (!isa<transform::TransformOpInterface>(op)) 2941109341SAlex Zinenko return success(); 3041109341SAlex Zinenko 3141109341SAlex Zinenko auto func = dyn_cast<FunctionOpInterface>(op); 3241109341SAlex Zinenko if (!func || func.isExternal()) 3341109341SAlex Zinenko return success(); 3441109341SAlex Zinenko 3541109341SAlex Zinenko if (!func.getFunctionBody().hasOneBlock()) { 3641109341SAlex Zinenko return op->emitError() 3741109341SAlex Zinenko << "only single-block operations are currently supported"; 3841109341SAlex Zinenko } 3941109341SAlex Zinenko 4041109341SAlex Zinenko // Note that there can't be an inclusion of an unannotated symbol because it 4141109341SAlex Zinenko // wouldn't have passed the verifier, so recursion isn't necessary here. 4241109341SAlex Zinenko llvm::SmallDenseSet<unsigned> consumedArguments; 4341109341SAlex Zinenko transform::getConsumedBlockArguments(func.getFunctionBody().front(), 4441109341SAlex Zinenko consumedArguments); 4541109341SAlex Zinenko 4641109341SAlex Zinenko for (unsigned i = 0, e = func.getNumArguments(); i < e; ++i) { 4741109341SAlex Zinenko func.setArgAttr(i, 4841109341SAlex Zinenko consumedArguments.contains(i) 4941109341SAlex Zinenko ? transform::TransformDialect::kArgConsumedAttrName 5041109341SAlex Zinenko : transform::TransformDialect::kArgReadOnlyAttrName, 5141109341SAlex Zinenko UnitAttr::get(op->getContext())); 5241109341SAlex Zinenko } 5341109341SAlex Zinenko return success(); 5441109341SAlex Zinenko } 5541109341SAlex Zinenko 5641109341SAlex Zinenko namespace { 5741109341SAlex Zinenko class InferEffectsPass 5841109341SAlex Zinenko : public transform::impl::InferEffectsPassBase<InferEffectsPass> { 5941109341SAlex Zinenko public: runOnOperation()6041109341SAlex Zinenko void runOnOperation() override { 6141109341SAlex Zinenko WalkResult result = getOperation()->walk([](Operation *op) { 6241109341SAlex Zinenko return failed(inferSideEffectAnnotations(op)) ? WalkResult::interrupt() 6341109341SAlex Zinenko : WalkResult::advance(); 6441109341SAlex Zinenko }); 6541109341SAlex Zinenko if (result.wasInterrupted()) 6641109341SAlex Zinenko return signalPassFailure(); 6741109341SAlex Zinenko } 6841109341SAlex Zinenko }; 6941109341SAlex Zinenko } // namespace 70