xref: /llvm-project/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp (revision 84cc1865ef9202af39404ff4524a9b13df80cfc1)
//===- BufferizationTransformOps.cpp - Bufferization transform ops --------===//
2461dafd2SMatthias Springer //
3461dafd2SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4461dafd2SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
5461dafd2SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6461dafd2SMatthias Springer //
7461dafd2SMatthias Springer //===----------------------------------------------------------------------===//
8461dafd2SMatthias Springer 
9461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
10461dafd2SMatthias Springer 
11461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
13461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
1455e38579SXiaolei Shi #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
151ccd8cd6SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
1688f4292aSMatthias Springer #include "mlir/Dialect/Linalg/IR/Linalg.h"
17461dafd2SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
18e7d0cc76SLorenzo Chelini #include "mlir/Dialect/Tensor/IR/Tensor.h"
19461dafd2SMatthias Springer #include "mlir/Dialect/Transform/IR/TransformDialect.h"
2034a35a8bSMartin Erhart #include "mlir/Interfaces/FunctionInterfaces.h"
21461dafd2SMatthias Springer 
22461dafd2SMatthias Springer using namespace mlir;
23461dafd2SMatthias Springer using namespace mlir::bufferization;
24461dafd2SMatthias Springer using namespace mlir::transform;
25461dafd2SMatthias Springer 
26461dafd2SMatthias Springer //===----------------------------------------------------------------------===//
2798770ecdSMatthias Springer // BufferLoopHoistingOp
2898770ecdSMatthias Springer //===----------------------------------------------------------------------===//
2998770ecdSMatthias Springer 
3098770ecdSMatthias Springer DiagnosedSilenceableFailure transform::BufferLoopHoistingOp::applyToOne(
3198770ecdSMatthias Springer     TransformRewriter &rewriter, Operation *target,
3298770ecdSMatthias Springer     ApplyToEachResultList &results, TransformState &state) {
3398770ecdSMatthias Springer   bufferization::hoistBuffersFromLoops(target);
3498770ecdSMatthias Springer   return DiagnosedSilenceableFailure::success();
3598770ecdSMatthias Springer }
3698770ecdSMatthias Springer 
3798770ecdSMatthias Springer void transform::BufferLoopHoistingOp::getEffects(
3898770ecdSMatthias Springer     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
392c1ae801Sdonald chen   onlyReadsHandle(getTargetMutable(), effects);
4098770ecdSMatthias Springer   modifiesPayload(effects);
4198770ecdSMatthias Springer }
4298770ecdSMatthias Springer 
4398770ecdSMatthias Springer //===----------------------------------------------------------------------===//
44461dafd2SMatthias Springer // OneShotBufferizeOp
45461dafd2SMatthias Springer //===----------------------------------------------------------------------===//
46461dafd2SMatthias Springer 
4788f4292aSMatthias Springer LogicalResult transform::OneShotBufferizeOp::verify() {
4888f4292aSMatthias Springer   if (getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
4988f4292aSMatthias Springer     return emitOpError() << "unsupported memcpy op";
505958043eSMatthias Springer   if (getPrintConflicts() && !getTestAnalysisOnly())
515958043eSMatthias Springer     return emitOpError() << "'print_conflicts' requires 'test_analysis_only'";
525958043eSMatthias Springer   if (getDumpAliasSets() && !getTestAnalysisOnly())
535958043eSMatthias Springer     return emitOpError() << "'dump_alias_sets' requires 'test_analysis_only'";
5488f4292aSMatthias Springer   return success();
5588f4292aSMatthias Springer }
5688f4292aSMatthias Springer 
571d45282aSAlex Zinenko DiagnosedSilenceableFailure
58c63d2b2cSMatthias Springer transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
59c63d2b2cSMatthias Springer                                      TransformResults &transformResults,
60461dafd2SMatthias Springer                                      TransformState &state) {
61461dafd2SMatthias Springer   OneShotBufferizationOptions options;
626bf043e7SMartin Erhart   options.allowReturnAllocsFromLoops = getAllowReturnAllocsFromLoops();
63461dafd2SMatthias Springer   options.allowUnknownOps = getAllowUnknownOps();
64461dafd2SMatthias Springer   options.bufferizeFunctionBoundaries = getBufferizeFunctionBoundaries();
655958043eSMatthias Springer   options.dumpAliasSets = getDumpAliasSets();
66461dafd2SMatthias Springer   options.testAnalysisOnly = getTestAnalysisOnly();
67461dafd2SMatthias Springer   options.printConflicts = getPrintConflicts();
68c780184aSLorenzo Chelini   if (getFunctionBoundaryTypeConversion().has_value())
6975ef84bfSOleg Shyshkov     options.setFunctionBoundaryTypeConversion(
7075ef84bfSOleg Shyshkov         *getFunctionBoundaryTypeConversion());
7188f4292aSMatthias Springer   if (getMemcpyOp() == "memref.copy") {
7288f4292aSMatthias Springer     options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) {
7388f4292aSMatthias Springer       b.create<memref::CopyOp>(loc, from, to);
7488f4292aSMatthias Springer       return success();
7588f4292aSMatthias Springer     };
7688f4292aSMatthias Springer   } else if (getMemcpyOp() == "linalg.copy") {
7788f4292aSMatthias Springer     options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) {
7888f4292aSMatthias Springer       b.create<linalg::CopyOp>(loc, from, to);
7988f4292aSMatthias Springer       return success();
8088f4292aSMatthias Springer     };
8188f4292aSMatthias Springer   } else {
8288f4292aSMatthias Springer     llvm_unreachable("invalid copy op");
8388f4292aSMatthias Springer   }
84461dafd2SMatthias Springer 
850e37ef08SMatthias Springer   auto payloadOps = state.getPayloadOps(getTarget());
86461dafd2SMatthias Springer   for (Operation *target : payloadOps) {
873f7959eaSMatthias Springer     if (!isa<ModuleOp, FunctionOpInterface>(target))
883f7959eaSMatthias Springer       return emitSilenceableError() << "expected module or function target";
89461dafd2SMatthias Springer     auto moduleOp = dyn_cast<ModuleOp>(target);
90461dafd2SMatthias Springer     if (options.bufferizeFunctionBoundaries) {
91461dafd2SMatthias Springer       if (!moduleOp)
923f7959eaSMatthias Springer         return emitSilenceableError() << "expected module target";
93461dafd2SMatthias Springer       if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options)))
941d45282aSAlex Zinenko         return emitSilenceableError() << "bufferization failed";
95461dafd2SMatthias Springer     } else {
96461dafd2SMatthias Springer       if (failed(bufferization::runOneShotBufferize(target, options)))
971d45282aSAlex Zinenko         return emitSilenceableError() << "bufferization failed";
98461dafd2SMatthias Springer     }
99461dafd2SMatthias Springer   }
100461dafd2SMatthias Springer 
1013f7959eaSMatthias Springer   // This transform op is currently restricted to ModuleOps and function ops.
1023f7959eaSMatthias Springer   // Such ops are modified in-place.
1035550c821STres Popp   transformResults.set(cast<OpResult>(getTransformed()), payloadOps);
1041d45282aSAlex Zinenko   return DiagnosedSilenceableFailure::success();
105461dafd2SMatthias Springer }
106461dafd2SMatthias Springer 
107e7d0cc76SLorenzo Chelini //===----------------------------------------------------------------------===//
1081ccd8cd6SMatthias Springer // EliminateEmptyTensorsOp
1091ccd8cd6SMatthias Springer //===----------------------------------------------------------------------===//
1101ccd8cd6SMatthias Springer 
1111ccd8cd6SMatthias Springer void transform::EliminateEmptyTensorsOp::getEffects(
1121ccd8cd6SMatthias Springer     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1132c1ae801Sdonald chen   onlyReadsHandle(getTargetMutable(), effects);
1141ccd8cd6SMatthias Springer   modifiesPayload(effects);
1151ccd8cd6SMatthias Springer }
1161ccd8cd6SMatthias Springer 
117c63d2b2cSMatthias Springer DiagnosedSilenceableFailure transform::EliminateEmptyTensorsOp::apply(
118c63d2b2cSMatthias Springer     transform::TransformRewriter &rewriter, TransformResults &transformResults,
1191ccd8cd6SMatthias Springer     TransformState &state) {
1200e37ef08SMatthias Springer   for (Operation *target : state.getPayloadOps(getTarget())) {
12187633432SMatthias Springer     if (failed(bufferization::eliminateEmptyTensors(rewriter, target)))
1221ccd8cd6SMatthias Springer       return mlir::emitSilenceableFailure(target->getLoc())
12387633432SMatthias Springer              << "empty tensor elimination failed";
1241ccd8cd6SMatthias Springer   }
1251ccd8cd6SMatthias Springer   return DiagnosedSilenceableFailure::success();
1261ccd8cd6SMatthias Springer }
1271ccd8cd6SMatthias Springer 
1281ccd8cd6SMatthias Springer //===----------------------------------------------------------------------===//
129e7d0cc76SLorenzo Chelini // EmptyTensorToAllocTensorOp
130e7d0cc76SLorenzo Chelini //===----------------------------------------------------------------------===//
131e7d0cc76SLorenzo Chelini 
132c63d2b2cSMatthias Springer DiagnosedSilenceableFailure EmptyTensorToAllocTensorOp::applyToOne(
133c63d2b2cSMatthias Springer     transform::TransformRewriter &rewriter, tensor::EmptyOp target,
134c63d2b2cSMatthias Springer     ApplyToEachResultList &results, transform::TransformState &state) {
135e7d0cc76SLorenzo Chelini   rewriter.setInsertionPoint(target);
136e7d0cc76SLorenzo Chelini   auto alloc = rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>(
137e7d0cc76SLorenzo Chelini       target, target.getType(), target.getDynamicSizes());
138e7d0cc76SLorenzo Chelini   results.push_back(alloc);
139e7d0cc76SLorenzo Chelini   return DiagnosedSilenceableFailure::success();
140e7d0cc76SLorenzo Chelini }
141e7d0cc76SLorenzo Chelini 
142461dafd2SMatthias Springer //===----------------------------------------------------------------------===//
143461dafd2SMatthias Springer // Transform op registration
144461dafd2SMatthias Springer //===----------------------------------------------------------------------===//
145461dafd2SMatthias Springer 
146461dafd2SMatthias Springer namespace {
147461dafd2SMatthias Springer /// Registers new ops and declares PDL as dependent dialect since the additional
148461dafd2SMatthias Springer /// ops are using PDL types for operands and results.
149461dafd2SMatthias Springer class BufferizationTransformDialectExtension
150461dafd2SMatthias Springer     : public transform::TransformDialectExtension<
151461dafd2SMatthias Springer           BufferizationTransformDialectExtension> {
152461dafd2SMatthias Springer public:
153*84cc1865SNikhil Kalra   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
154*84cc1865SNikhil Kalra       BufferizationTransformDialectExtension)
155*84cc1865SNikhil Kalra 
156333ee218SAlex Zinenko   using Base::Base;
157333ee218SAlex Zinenko 
158333ee218SAlex Zinenko   void init() {
159333ee218SAlex Zinenko     declareGeneratedDialect<bufferization::BufferizationDialect>();
160333ee218SAlex Zinenko     declareGeneratedDialect<memref::MemRefDialect>();
161333ee218SAlex Zinenko 
162461dafd2SMatthias Springer     registerTransformOps<
163461dafd2SMatthias Springer #define GET_OP_LIST
164461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
165461dafd2SMatthias Springer         >();
166461dafd2SMatthias Springer   }
167461dafd2SMatthias Springer };
168461dafd2SMatthias Springer } // namespace
169461dafd2SMatthias Springer 
170461dafd2SMatthias Springer #define GET_OP_CLASSES
171461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
172461dafd2SMatthias Springer 
173c780184aSLorenzo Chelini #include "mlir/Dialect/Bufferization/IR/BufferizationEnums.cpp.inc"
174c780184aSLorenzo Chelini 
175461dafd2SMatthias Springer void mlir::bufferization::registerTransformDialectExtension(
176461dafd2SMatthias Springer     DialectRegistry &registry) {
177461dafd2SMatthias Springer   registry.addExtensions<BufferizationTransformDialectExtension>();
178461dafd2SMatthias Springer }
179