xref: /llvm-project/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp (revision 84cc1865ef9202af39404ff4524a9b13df80cfc1)
//===- BufferizationTransformOps.cpp - Bufferization transform ops --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
10 
11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
13 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
14 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
15 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
16 #include "mlir/Dialect/Linalg/IR/Linalg.h"
17 #include "mlir/Dialect/MemRef/IR/MemRef.h"
18 #include "mlir/Dialect/Tensor/IR/Tensor.h"
19 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
20 #include "mlir/Interfaces/FunctionInterfaces.h"
21 
22 using namespace mlir;
23 using namespace mlir::bufferization;
24 using namespace mlir::transform;
25 
26 //===----------------------------------------------------------------------===//
27 // BufferLoopHoistingOp
28 //===----------------------------------------------------------------------===//
29 
30 DiagnosedSilenceableFailure transform::BufferLoopHoistingOp::applyToOne(
31     TransformRewriter &rewriter, Operation *target,
32     ApplyToEachResultList &results, TransformState &state) {
33   bufferization::hoistBuffersFromLoops(target);
34   return DiagnosedSilenceableFailure::success();
35 }
36 
37 void transform::BufferLoopHoistingOp::getEffects(
38     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
39   onlyReadsHandle(getTargetMutable(), effects);
40   modifiesPayload(effects);
41 }
42 
43 //===----------------------------------------------------------------------===//
44 // OneShotBufferizeOp
45 //===----------------------------------------------------------------------===//
46 
47 LogicalResult transform::OneShotBufferizeOp::verify() {
48   if (getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
49     return emitOpError() << "unsupported memcpy op";
50   if (getPrintConflicts() && !getTestAnalysisOnly())
51     return emitOpError() << "'print_conflicts' requires 'test_analysis_only'";
52   if (getDumpAliasSets() && !getTestAnalysisOnly())
53     return emitOpError() << "'dump_alias_sets' requires 'test_analysis_only'";
54   return success();
55 }
56 
57 DiagnosedSilenceableFailure
58 transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
59                                      TransformResults &transformResults,
60                                      TransformState &state) {
61   OneShotBufferizationOptions options;
62   options.allowReturnAllocsFromLoops = getAllowReturnAllocsFromLoops();
63   options.allowUnknownOps = getAllowUnknownOps();
64   options.bufferizeFunctionBoundaries = getBufferizeFunctionBoundaries();
65   options.dumpAliasSets = getDumpAliasSets();
66   options.testAnalysisOnly = getTestAnalysisOnly();
67   options.printConflicts = getPrintConflicts();
68   if (getFunctionBoundaryTypeConversion().has_value())
69     options.setFunctionBoundaryTypeConversion(
70         *getFunctionBoundaryTypeConversion());
71   if (getMemcpyOp() == "memref.copy") {
72     options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) {
73       b.create<memref::CopyOp>(loc, from, to);
74       return success();
75     };
76   } else if (getMemcpyOp() == "linalg.copy") {
77     options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) {
78       b.create<linalg::CopyOp>(loc, from, to);
79       return success();
80     };
81   } else {
82     llvm_unreachable("invalid copy op");
83   }
84 
85   auto payloadOps = state.getPayloadOps(getTarget());
86   for (Operation *target : payloadOps) {
87     if (!isa<ModuleOp, FunctionOpInterface>(target))
88       return emitSilenceableError() << "expected module or function target";
89     auto moduleOp = dyn_cast<ModuleOp>(target);
90     if (options.bufferizeFunctionBoundaries) {
91       if (!moduleOp)
92         return emitSilenceableError() << "expected module target";
93       if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options)))
94         return emitSilenceableError() << "bufferization failed";
95     } else {
96       if (failed(bufferization::runOneShotBufferize(target, options)))
97         return emitSilenceableError() << "bufferization failed";
98     }
99   }
100 
101   // This transform op is currently restricted to ModuleOps and function ops.
102   // Such ops are modified in-place.
103   transformResults.set(cast<OpResult>(getTransformed()), payloadOps);
104   return DiagnosedSilenceableFailure::success();
105 }
106 
107 //===----------------------------------------------------------------------===//
108 // EliminateEmptyTensorsOp
109 //===----------------------------------------------------------------------===//
110 
111 void transform::EliminateEmptyTensorsOp::getEffects(
112     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
113   onlyReadsHandle(getTargetMutable(), effects);
114   modifiesPayload(effects);
115 }
116 
117 DiagnosedSilenceableFailure transform::EliminateEmptyTensorsOp::apply(
118     transform::TransformRewriter &rewriter, TransformResults &transformResults,
119     TransformState &state) {
120   for (Operation *target : state.getPayloadOps(getTarget())) {
121     if (failed(bufferization::eliminateEmptyTensors(rewriter, target)))
122       return mlir::emitSilenceableFailure(target->getLoc())
123              << "empty tensor elimination failed";
124   }
125   return DiagnosedSilenceableFailure::success();
126 }
127 
128 //===----------------------------------------------------------------------===//
129 // EmptyTensorToAllocTensorOp
130 //===----------------------------------------------------------------------===//
131 
132 DiagnosedSilenceableFailure EmptyTensorToAllocTensorOp::applyToOne(
133     transform::TransformRewriter &rewriter, tensor::EmptyOp target,
134     ApplyToEachResultList &results, transform::TransformState &state) {
135   rewriter.setInsertionPoint(target);
136   auto alloc = rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>(
137       target, target.getType(), target.getDynamicSizes());
138   results.push_back(alloc);
139   return DiagnosedSilenceableFailure::success();
140 }
141 
142 //===----------------------------------------------------------------------===//
143 // Transform op registration
144 //===----------------------------------------------------------------------===//
145 
146 namespace {
147 /// Registers new ops and declares PDL as dependent dialect since the additional
148 /// ops are using PDL types for operands and results.
149 class BufferizationTransformDialectExtension
150     : public transform::TransformDialectExtension<
151           BufferizationTransformDialectExtension> {
152 public:
153   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
154       BufferizationTransformDialectExtension)
155 
156   using Base::Base;
157 
158   void init() {
159     declareGeneratedDialect<bufferization::BufferizationDialect>();
160     declareGeneratedDialect<memref::MemRefDialect>();
161 
162     registerTransformOps<
163 #define GET_OP_LIST
164 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
165         >();
166   }
167 };
168 } // namespace
169 
170 #define GET_OP_CLASSES
171 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
172 
173 #include "mlir/Dialect/Bufferization/IR/BufferizationEnums.cpp.inc"
174 
175 void mlir::bufferization::registerTransformDialectExtension(
176     DialectRegistry &registry) {
177   registry.addExtensions<BufferizationTransformDialectExtension>();
178 }
179