xref: /llvm-project/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp (revision 98770ecd76bb9cb1067c78c502be3899a28e5f92)
1 //===- BufferizationTransformOps.cpp - Bufferization transform ops --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
10 
11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
13 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
14 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
15 #include "mlir/Dialect/Linalg/IR/Linalg.h"
16 #include "mlir/Dialect/MemRef/IR/MemRef.h"
17 #include "mlir/Dialect/Tensor/IR/Tensor.h"
18 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
19 #include "mlir/IR/FunctionInterfaces.h"
20 
21 using namespace mlir;
22 using namespace mlir::bufferization;
23 using namespace mlir::transform;
24 
25 //===----------------------------------------------------------------------===//
26 // BufferLoopHoistingOp
27 //===----------------------------------------------------------------------===//
28 
29 DiagnosedSilenceableFailure transform::BufferLoopHoistingOp::applyToOne(
30     TransformRewriter &rewriter, Operation *target,
31     ApplyToEachResultList &results, TransformState &state) {
32   bufferization::hoistBuffersFromLoops(target);
33   return DiagnosedSilenceableFailure::success();
34 }
35 
36 void transform::BufferLoopHoistingOp::getEffects(
37     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
38   onlyReadsHandle(getTarget(), effects);
39   modifiesPayload(effects);
40 }
41 
42 //===----------------------------------------------------------------------===//
43 // OneShotBufferizeOp
44 //===----------------------------------------------------------------------===//
45 
46 LogicalResult transform::OneShotBufferizeOp::verify() {
47   if (getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
48     return emitOpError() << "unsupported memcpy op";
49   return success();
50 }
51 
52 DiagnosedSilenceableFailure
53 transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
54                                      TransformResults &transformResults,
55                                      TransformState &state) {
56   OneShotBufferizationOptions options;
57   options.allowReturnAllocs = getAllowReturnAllocs();
58   options.allowUnknownOps = getAllowUnknownOps();
59   options.bufferizeFunctionBoundaries = getBufferizeFunctionBoundaries();
60   options.createDeallocs = getCreateDeallocs();
61   options.testAnalysisOnly = getTestAnalysisOnly();
62   options.printConflicts = getPrintConflicts();
63   if (getFunctionBoundaryTypeConversion().has_value())
64     options.setFunctionBoundaryTypeConversion(
65         *getFunctionBoundaryTypeConversion());
66   if (getMemcpyOp() == "memref.copy") {
67     options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) {
68       b.create<memref::CopyOp>(loc, from, to);
69       return success();
70     };
71   } else if (getMemcpyOp() == "linalg.copy") {
72     options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) {
73       b.create<linalg::CopyOp>(loc, from, to);
74       return success();
75     };
76   } else {
77     llvm_unreachable("invalid copy op");
78   }
79 
80   auto payloadOps = state.getPayloadOps(getTarget());
81   for (Operation *target : payloadOps) {
82     if (!isa<ModuleOp, FunctionOpInterface>(target))
83       return emitSilenceableError() << "expected module or function target";
84     auto moduleOp = dyn_cast<ModuleOp>(target);
85     if (options.bufferizeFunctionBoundaries) {
86       if (!moduleOp)
87         return emitSilenceableError() << "expected module target";
88       if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options)))
89         return emitSilenceableError() << "bufferization failed";
90     } else {
91       if (failed(bufferization::runOneShotBufferize(target, options)))
92         return emitSilenceableError() << "bufferization failed";
93     }
94   }
95 
96   // This transform op is currently restricted to ModuleOps and function ops.
97   // Such ops are modified in-place.
98   transformResults.set(cast<OpResult>(getTransformed()), payloadOps);
99   return DiagnosedSilenceableFailure::success();
100 }
101 
102 //===----------------------------------------------------------------------===//
103 // EliminateEmptyTensorsOp
104 //===----------------------------------------------------------------------===//
105 
106 void transform::EliminateEmptyTensorsOp::getEffects(
107     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
108   onlyReadsHandle(getTarget(), effects);
109   modifiesPayload(effects);
110 }
111 
112 DiagnosedSilenceableFailure transform::EliminateEmptyTensorsOp::apply(
113     transform::TransformRewriter &rewriter, TransformResults &transformResults,
114     TransformState &state) {
115   OneShotBufferizationOptions options;
116   options.allowReturnAllocs = true;
117 
118   for (Operation *target : state.getPayloadOps(getTarget())) {
119     OneShotAnalysisState state(target, options);
120     if (failed(analyzeOp(target, state)))
121       return mlir::emitSilenceableFailure(target->getLoc())
122              << "failed to analyze op";
123     if (failed(bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
124             rewriter, target, state)))
125       return mlir::emitSilenceableFailure(target->getLoc())
126              << "failed to eliminate insert_slice anchored tensor.empty ops";
127   }
128   return DiagnosedSilenceableFailure::success();
129 }
130 
131 //===----------------------------------------------------------------------===//
132 // EmptyTensorToAllocTensorOp
133 //===----------------------------------------------------------------------===//
134 
135 DiagnosedSilenceableFailure EmptyTensorToAllocTensorOp::applyToOne(
136     transform::TransformRewriter &rewriter, tensor::EmptyOp target,
137     ApplyToEachResultList &results, transform::TransformState &state) {
138   rewriter.setInsertionPoint(target);
139   auto alloc = rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>(
140       target, target.getType(), target.getDynamicSizes());
141   results.push_back(alloc);
142   return DiagnosedSilenceableFailure::success();
143 }
144 
145 //===----------------------------------------------------------------------===//
146 // Transform op registration
147 //===----------------------------------------------------------------------===//
148 
149 namespace {
150 /// Registers new ops and declares PDL as dependent dialect since the additional
151 /// ops are using PDL types for operands and results.
152 class BufferizationTransformDialectExtension
153     : public transform::TransformDialectExtension<
154           BufferizationTransformDialectExtension> {
155 public:
156   using Base::Base;
157 
158   void init() {
159     declareGeneratedDialect<bufferization::BufferizationDialect>();
160     declareGeneratedDialect<memref::MemRefDialect>();
161 
162     registerTransformOps<
163 #define GET_OP_LIST
164 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
165         >();
166   }
167 };
168 } // namespace
169 
170 #define GET_OP_CLASSES
171 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
172 
173 #include "mlir/Dialect/Bufferization/IR/BufferizationEnums.cpp.inc"
174 
175 void mlir::bufferization::registerTransformDialectExtension(
176     DialectRegistry &registry) {
177   registry.addExtensions<BufferizationTransformDialectExtension>();
178 }
179