xref: /llvm-project/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp (revision 88f4292a165cf0b65aca8632840d73e2a094b05f)
1 //===- BufferizationTransformOps.cpp - Bufferization transform ops --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
10 
11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
13 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
14 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
15 #include "mlir/Dialect/Linalg/IR/Linalg.h"
16 #include "mlir/Dialect/MemRef/IR/MemRef.h"
17 #include "mlir/Dialect/Tensor/IR/Tensor.h"
18 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
19 #include "mlir/IR/FunctionInterfaces.h"
20 
21 using namespace mlir;
22 using namespace mlir::bufferization;
23 using namespace mlir::transform;
24 
25 //===----------------------------------------------------------------------===//
26 // OneShotBufferizeOp
27 //===----------------------------------------------------------------------===//
28 
29 LogicalResult transform::OneShotBufferizeOp::verify() {
30   if (getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
31     return emitOpError() << "unsupported memcpy op";
32   return success();
33 }
34 
35 DiagnosedSilenceableFailure
36 transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
37                                      TransformResults &transformResults,
38                                      TransformState &state) {
39   OneShotBufferizationOptions options;
40   options.allowReturnAllocs = getAllowReturnAllocs();
41   options.allowUnknownOps = getAllowUnknownOps();
42   options.bufferizeFunctionBoundaries = getBufferizeFunctionBoundaries();
43   options.createDeallocs = getCreateDeallocs();
44   options.testAnalysisOnly = getTestAnalysisOnly();
45   options.printConflicts = getPrintConflicts();
46   if (getFunctionBoundaryTypeConversion().has_value())
47     options.setFunctionBoundaryTypeConversion(
48         *getFunctionBoundaryTypeConversion());
49   if (getMemcpyOp() == "memref.copy") {
50     options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) {
51       b.create<memref::CopyOp>(loc, from, to);
52       return success();
53     };
54   } else if (getMemcpyOp() == "linalg.copy") {
55     options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) {
56       b.create<linalg::CopyOp>(loc, from, to);
57       return success();
58     };
59   } else {
60     llvm_unreachable("invalid copy op");
61   }
62 
63   auto payloadOps = state.getPayloadOps(getTarget());
64   for (Operation *target : payloadOps) {
65     if (!isa<ModuleOp, FunctionOpInterface>(target))
66       return emitSilenceableError() << "expected module or function target";
67     auto moduleOp = dyn_cast<ModuleOp>(target);
68     if (options.bufferizeFunctionBoundaries) {
69       if (!moduleOp)
70         return emitSilenceableError() << "expected module target";
71       if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options)))
72         return emitSilenceableError() << "bufferization failed";
73     } else {
74       if (failed(bufferization::runOneShotBufferize(target, options)))
75         return emitSilenceableError() << "bufferization failed";
76     }
77   }
78 
79   // This transform op is currently restricted to ModuleOps and function ops.
80   // Such ops are modified in-place.
81   transformResults.set(cast<OpResult>(getTransformed()), payloadOps);
82   return DiagnosedSilenceableFailure::success();
83 }
84 
85 //===----------------------------------------------------------------------===//
86 // EliminateEmptyTensorsOp
87 //===----------------------------------------------------------------------===//
88 
89 void transform::EliminateEmptyTensorsOp::getEffects(
90     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
91   onlyReadsHandle(getTarget(), effects);
92   modifiesPayload(effects);
93 }
94 
95 DiagnosedSilenceableFailure transform::EliminateEmptyTensorsOp::apply(
96     transform::TransformRewriter &rewriter, TransformResults &transformResults,
97     TransformState &state) {
98   OneShotBufferizationOptions options;
99   options.allowReturnAllocs = true;
100 
101   for (Operation *target : state.getPayloadOps(getTarget())) {
102     OneShotAnalysisState state(target, options);
103     if (failed(analyzeOp(target, state)))
104       return mlir::emitSilenceableFailure(target->getLoc())
105              << "failed to analyze op";
106     if (failed(bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
107             rewriter, target, state)))
108       return mlir::emitSilenceableFailure(target->getLoc())
109              << "failed to eliminate insert_slice anchored tensor.empty ops";
110   }
111   return DiagnosedSilenceableFailure::success();
112 }
113 
114 //===----------------------------------------------------------------------===//
115 // EmptyTensorToAllocTensorOp
116 //===----------------------------------------------------------------------===//
117 
118 DiagnosedSilenceableFailure EmptyTensorToAllocTensorOp::applyToOne(
119     transform::TransformRewriter &rewriter, tensor::EmptyOp target,
120     ApplyToEachResultList &results, transform::TransformState &state) {
121   rewriter.setInsertionPoint(target);
122   auto alloc = rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>(
123       target, target.getType(), target.getDynamicSizes());
124   results.push_back(alloc);
125   return DiagnosedSilenceableFailure::success();
126 }
127 
128 //===----------------------------------------------------------------------===//
129 // Transform op registration
130 //===----------------------------------------------------------------------===//
131 
132 namespace {
133 /// Registers new ops and declares PDL as dependent dialect since the additional
134 /// ops are using PDL types for operands and results.
135 class BufferizationTransformDialectExtension
136     : public transform::TransformDialectExtension<
137           BufferizationTransformDialectExtension> {
138 public:
139   using Base::Base;
140 
141   void init() {
142     declareGeneratedDialect<bufferization::BufferizationDialect>();
143     declareGeneratedDialect<memref::MemRefDialect>();
144 
145     registerTransformOps<
146 #define GET_OP_LIST
147 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
148         >();
149   }
150 };
151 } // namespace
152 
153 #define GET_OP_CLASSES
154 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
155 
156 #include "mlir/Dialect/Bufferization/IR/BufferizationEnums.cpp.inc"
157 
158 void mlir::bufferization::registerTransformDialectExtension(
159     DialectRegistry &registry) {
160   registry.addExtensions<BufferizationTransformDialectExtension>();
161 }
162