1461dafd2SMatthias Springer //===- BufferizationTransformOps.h - Bufferization transform ops ----------===// 2461dafd2SMatthias Springer // 3461dafd2SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4461dafd2SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 5461dafd2SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6461dafd2SMatthias Springer // 7461dafd2SMatthias Springer //===----------------------------------------------------------------------===// 8461dafd2SMatthias Springer 9461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" 10461dafd2SMatthias Springer 11461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 12461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 13461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" 1455e38579SXiaolei Shi #include "mlir/Dialect/Bufferization/Transforms/Passes.h" 151ccd8cd6SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" 1688f4292aSMatthias Springer #include "mlir/Dialect/Linalg/IR/Linalg.h" 17461dafd2SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 18e7d0cc76SLorenzo Chelini #include "mlir/Dialect/Tensor/IR/Tensor.h" 19461dafd2SMatthias Springer #include "mlir/Dialect/Transform/IR/TransformDialect.h" 2034a35a8bSMartin Erhart #include "mlir/Interfaces/FunctionInterfaces.h" 21461dafd2SMatthias Springer 22461dafd2SMatthias Springer using namespace mlir; 23461dafd2SMatthias Springer using namespace mlir::bufferization; 24461dafd2SMatthias Springer using namespace mlir::transform; 25461dafd2SMatthias Springer 26461dafd2SMatthias Springer //===----------------------------------------------------------------------===// 2798770ecdSMatthias Springer // BufferLoopHoistingOp 2898770ecdSMatthias Springer //===----------------------------------------------------------------------===// 2998770ecdSMatthias Springer 3098770ecdSMatthias Springer DiagnosedSilenceableFailure transform::BufferLoopHoistingOp::applyToOne( 3198770ecdSMatthias Springer TransformRewriter &rewriter, Operation *target, 3298770ecdSMatthias Springer ApplyToEachResultList &results, TransformState &state) { 3398770ecdSMatthias Springer bufferization::hoistBuffersFromLoops(target); 3498770ecdSMatthias Springer return DiagnosedSilenceableFailure::success(); 3598770ecdSMatthias Springer } 3698770ecdSMatthias Springer 3798770ecdSMatthias Springer void transform::BufferLoopHoistingOp::getEffects( 3898770ecdSMatthias Springer SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 392c1ae801Sdonald chen onlyReadsHandle(getTargetMutable(), effects); 4098770ecdSMatthias Springer modifiesPayload(effects); 4198770ecdSMatthias Springer } 4298770ecdSMatthias Springer 4398770ecdSMatthias Springer //===----------------------------------------------------------------------===// 44461dafd2SMatthias Springer // OneShotBufferizeOp 45461dafd2SMatthias Springer //===----------------------------------------------------------------------===// 46461dafd2SMatthias Springer 4788f4292aSMatthias Springer LogicalResult transform::OneShotBufferizeOp::verify() { 4888f4292aSMatthias Springer if (getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy") 4988f4292aSMatthias Springer return emitOpError() << "unsupported memcpy op"; 505958043eSMatthias Springer if (getPrintConflicts() && !getTestAnalysisOnly()) 515958043eSMatthias Springer return emitOpError() << "'print_conflicts' requires 'test_analysis_only'"; 525958043eSMatthias Springer if (getDumpAliasSets() && !getTestAnalysisOnly()) 535958043eSMatthias Springer return emitOpError() << "'dump_alias_sets' requires 'test_analysis_only'"; 5488f4292aSMatthias Springer return success(); 5588f4292aSMatthias Springer } 5688f4292aSMatthias Springer 571d45282aSAlex Zinenko DiagnosedSilenceableFailure 58c63d2b2cSMatthias Springer transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter, 59c63d2b2cSMatthias Springer TransformResults &transformResults, 60461dafd2SMatthias Springer TransformState &state) { 61461dafd2SMatthias Springer OneShotBufferizationOptions options; 626bf043e7SMartin Erhart options.allowReturnAllocsFromLoops = getAllowReturnAllocsFromLoops(); 63461dafd2SMatthias Springer options.allowUnknownOps = getAllowUnknownOps(); 64461dafd2SMatthias Springer options.bufferizeFunctionBoundaries = getBufferizeFunctionBoundaries(); 655958043eSMatthias Springer options.dumpAliasSets = getDumpAliasSets(); 66461dafd2SMatthias Springer options.testAnalysisOnly = getTestAnalysisOnly(); 67461dafd2SMatthias Springer options.printConflicts = getPrintConflicts(); 68c780184aSLorenzo Chelini if (getFunctionBoundaryTypeConversion().has_value()) 6975ef84bfSOleg Shyshkov options.setFunctionBoundaryTypeConversion( 7075ef84bfSOleg Shyshkov *getFunctionBoundaryTypeConversion()); 7188f4292aSMatthias Springer if (getMemcpyOp() == "memref.copy") { 7288f4292aSMatthias Springer options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) { 7388f4292aSMatthias Springer b.create<memref::CopyOp>(loc, from, to); 7488f4292aSMatthias Springer return success(); 7588f4292aSMatthias Springer }; 7688f4292aSMatthias Springer } else if (getMemcpyOp() == "linalg.copy") { 7788f4292aSMatthias Springer options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) { 7888f4292aSMatthias Springer b.create<linalg::CopyOp>(loc, from, to); 7988f4292aSMatthias Springer return success(); 8088f4292aSMatthias Springer }; 8188f4292aSMatthias Springer } else { 8288f4292aSMatthias Springer llvm_unreachable("invalid copy op"); 8388f4292aSMatthias Springer } 84461dafd2SMatthias Springer 850e37ef08SMatthias Springer auto payloadOps = state.getPayloadOps(getTarget()); 86461dafd2SMatthias Springer for (Operation *target : payloadOps) { 873f7959eaSMatthias Springer if (!isa<ModuleOp, FunctionOpInterface>(target)) 883f7959eaSMatthias Springer return emitSilenceableError() << "expected module or function target"; 89461dafd2SMatthias Springer auto moduleOp = dyn_cast<ModuleOp>(target); 90461dafd2SMatthias Springer if (options.bufferizeFunctionBoundaries) { 91461dafd2SMatthias Springer if (!moduleOp) 923f7959eaSMatthias Springer return emitSilenceableError() << "expected module target"; 93461dafd2SMatthias Springer if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options))) 941d45282aSAlex Zinenko return emitSilenceableError() << "bufferization failed"; 95461dafd2SMatthias Springer } else { 96461dafd2SMatthias Springer if (failed(bufferization::runOneShotBufferize(target, options))) 971d45282aSAlex Zinenko return emitSilenceableError() << "bufferization failed"; 98461dafd2SMatthias Springer } 99461dafd2SMatthias Springer } 100461dafd2SMatthias Springer 1013f7959eaSMatthias Springer // This transform op is currently restricted to ModuleOps and function ops. 1023f7959eaSMatthias Springer // Such ops are modified in-place. 1035550c821STres Popp transformResults.set(cast<OpResult>(getTransformed()), payloadOps); 1041d45282aSAlex Zinenko return DiagnosedSilenceableFailure::success(); 105461dafd2SMatthias Springer } 106461dafd2SMatthias Springer 107e7d0cc76SLorenzo Chelini //===----------------------------------------------------------------------===// 1081ccd8cd6SMatthias Springer // EliminateEmptyTensorsOp 1091ccd8cd6SMatthias Springer //===----------------------------------------------------------------------===// 1101ccd8cd6SMatthias Springer 1111ccd8cd6SMatthias Springer void transform::EliminateEmptyTensorsOp::getEffects( 1121ccd8cd6SMatthias Springer SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1132c1ae801Sdonald chen onlyReadsHandle(getTargetMutable(), effects); 1141ccd8cd6SMatthias Springer modifiesPayload(effects); 1151ccd8cd6SMatthias Springer } 1161ccd8cd6SMatthias Springer 117c63d2b2cSMatthias Springer DiagnosedSilenceableFailure transform::EliminateEmptyTensorsOp::apply( 118c63d2b2cSMatthias Springer transform::TransformRewriter &rewriter, TransformResults &transformResults, 1191ccd8cd6SMatthias Springer TransformState &state) { 1200e37ef08SMatthias Springer for (Operation *target : state.getPayloadOps(getTarget())) { 12187633432SMatthias Springer if (failed(bufferization::eliminateEmptyTensors(rewriter, target))) 1221ccd8cd6SMatthias Springer return mlir::emitSilenceableFailure(target->getLoc()) 12387633432SMatthias Springer << "empty tensor elimination failed"; 1241ccd8cd6SMatthias Springer } 1251ccd8cd6SMatthias Springer return DiagnosedSilenceableFailure::success(); 1261ccd8cd6SMatthias Springer } 1271ccd8cd6SMatthias Springer 1281ccd8cd6SMatthias Springer //===----------------------------------------------------------------------===// 129e7d0cc76SLorenzo Chelini // EmptyTensorToAllocTensorOp 130e7d0cc76SLorenzo Chelini //===----------------------------------------------------------------------===// 131e7d0cc76SLorenzo Chelini 132c63d2b2cSMatthias Springer DiagnosedSilenceableFailure EmptyTensorToAllocTensorOp::applyToOne( 133c63d2b2cSMatthias Springer transform::TransformRewriter &rewriter, tensor::EmptyOp target, 134c63d2b2cSMatthias Springer ApplyToEachResultList &results, transform::TransformState &state) { 135e7d0cc76SLorenzo Chelini rewriter.setInsertionPoint(target); 136e7d0cc76SLorenzo Chelini auto alloc = rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>( 137e7d0cc76SLorenzo Chelini target, target.getType(), target.getDynamicSizes()); 138e7d0cc76SLorenzo Chelini results.push_back(alloc); 139e7d0cc76SLorenzo Chelini return DiagnosedSilenceableFailure::success(); 140e7d0cc76SLorenzo Chelini } 141e7d0cc76SLorenzo Chelini 142461dafd2SMatthias Springer //===----------------------------------------------------------------------===// 143461dafd2SMatthias Springer // Transform op registration 144461dafd2SMatthias Springer //===----------------------------------------------------------------------===// 145461dafd2SMatthias Springer 146461dafd2SMatthias Springer namespace { 147461dafd2SMatthias Springer /// Registers new ops and declares PDL as dependent dialect since the additional 148461dafd2SMatthias Springer /// ops are using PDL types for operands and results. 149461dafd2SMatthias Springer class BufferizationTransformDialectExtension 150461dafd2SMatthias Springer : public transform::TransformDialectExtension< 151461dafd2SMatthias Springer BufferizationTransformDialectExtension> { 152461dafd2SMatthias Springer public: 153*84cc1865SNikhil Kalra MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 154*84cc1865SNikhil Kalra BufferizationTransformDialectExtension) 155*84cc1865SNikhil Kalra 156333ee218SAlex Zinenko using Base::Base; 157333ee218SAlex Zinenko 158333ee218SAlex Zinenko void init() { 159333ee218SAlex Zinenko declareGeneratedDialect<bufferization::BufferizationDialect>(); 160333ee218SAlex Zinenko declareGeneratedDialect<memref::MemRefDialect>(); 161333ee218SAlex Zinenko 162461dafd2SMatthias Springer registerTransformOps< 163461dafd2SMatthias Springer #define GET_OP_LIST 164461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc" 165461dafd2SMatthias Springer >(); 166461dafd2SMatthias Springer } 167461dafd2SMatthias Springer }; 168461dafd2SMatthias Springer } // namespace 169461dafd2SMatthias Springer 170461dafd2SMatthias Springer #define GET_OP_CLASSES 171461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc" 172461dafd2SMatthias Springer 173c780184aSLorenzo Chelini #include "mlir/Dialect/Bufferization/IR/BufferizationEnums.cpp.inc" 174c780184aSLorenzo Chelini 175461dafd2SMatthias Springer void mlir::bufferization::registerTransformDialectExtension( 176461dafd2SMatthias Springer DialectRegistry ®istry) { 177461dafd2SMatthias Springer registry.addExtensions<BufferizationTransformDialectExtension>(); 178461dafd2SMatthias Springer } 179