xref: /llvm-project/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp (revision 1ccd8cd6e68c1edfdfc0dbc21f4afb75f3a338e0)
//===- BufferizationTransformOps.cpp - Bufferization transform ops --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
10 
11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
13 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
14 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
15 #include "mlir/Dialect/MemRef/IR/MemRef.h"
16 #include "mlir/Dialect/PDL/IR/PDL.h"
17 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
18 #include "mlir/Dialect/Tensor/IR/Tensor.h"
19 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
20 #include "mlir/IR/FunctionInterfaces.h"
21 
22 using namespace mlir;
23 using namespace mlir::bufferization;
24 using namespace mlir::transform;
25 
26 //===----------------------------------------------------------------------===//
27 // OneShotBufferizeOp
28 //===----------------------------------------------------------------------===//
29 
30 DiagnosedSilenceableFailure
31 transform::OneShotBufferizeOp::apply(TransformResults &transformResults,
32                                      TransformState &state) {
33   OneShotBufferizationOptions options;
34   options.allowReturnAllocs = getAllowReturnAllocs();
35   options.allowUnknownOps = getAllowUnknownOps();
36   options.bufferizeFunctionBoundaries = getBufferizeFunctionBoundaries();
37   options.createDeallocs = getCreateDeallocs();
38   options.testAnalysisOnly = getTestAnalysisOnly();
39   options.printConflicts = getPrintConflicts();
40   if (getFunctionBoundaryTypeConversion().has_value())
41     options.functionBoundaryTypeConversion =
42         *getFunctionBoundaryTypeConversion();
43 
44   ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
45   for (Operation *target : payloadOps) {
46     if (!isa<ModuleOp, FunctionOpInterface>(target))
47       return emitSilenceableError() << "expected module or function target";
48     auto moduleOp = dyn_cast<ModuleOp>(target);
49     if (options.bufferizeFunctionBoundaries) {
50       if (!moduleOp)
51         return emitSilenceableError() << "expected module target";
52       if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options)))
53         return emitSilenceableError() << "bufferization failed";
54     } else {
55       if (failed(bufferization::runOneShotBufferize(target, options)))
56         return emitSilenceableError() << "bufferization failed";
57     }
58   }
59 
60   // This transform op is currently restricted to ModuleOps and function ops.
61   // Such ops are modified in-place.
62   transformResults.set(getTransformed().cast<OpResult>(), payloadOps);
63   return DiagnosedSilenceableFailure::success();
64 }
65 
66 //===----------------------------------------------------------------------===//
67 // EliminateEmptyTensorsOp
68 //===----------------------------------------------------------------------===//
69 
70 void transform::EliminateEmptyTensorsOp::getEffects(
71     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
72   onlyReadsHandle(getTarget(), effects);
73   modifiesPayload(effects);
74 }
75 
76 DiagnosedSilenceableFailure
77 transform::EliminateEmptyTensorsOp::apply(TransformResults &transformResults,
78                                           TransformState &state) {
79   IRRewriter rewriter(getContext());
80   OneShotBufferizationOptions options;
81   options.allowReturnAllocs = true;
82 
83   ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
84   for (Operation *target : payloadOps) {
85     OneShotAnalysisState state(target, options);
86     if (failed(analyzeOp(target, state)))
87       return mlir::emitSilenceableFailure(target->getLoc())
88              << "failed to analyze op";
89     if (failed(bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
90             rewriter, target, state)))
91       return mlir::emitSilenceableFailure(target->getLoc())
92              << "failed to eliminate insert_slice anchored tensor.empty ops";
93   }
94   return DiagnosedSilenceableFailure::success();
95 }
96 
97 //===----------------------------------------------------------------------===//
98 // EmptyTensorToAllocTensorOp
99 //===----------------------------------------------------------------------===//
100 
101 DiagnosedSilenceableFailure
102 EmptyTensorToAllocTensorOp::applyToOne(tensor::EmptyOp target,
103                                        ApplyToEachResultList &results,
104                                        transform::TransformState &state) {
105   IRRewriter rewriter(target->getContext());
106   rewriter.setInsertionPoint(target);
107   auto alloc = rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>(
108       target, target.getType(), target.getDynamicSizes());
109   results.push_back(alloc);
110   return DiagnosedSilenceableFailure::success();
111 }
112 
113 //===----------------------------------------------------------------------===//
114 // Transform op registration
115 //===----------------------------------------------------------------------===//
116 
namespace {
/// Transform dialect extension that makes the bufferization transform ops
/// available. It declares PDL as a dependent dialect because the additional
/// ops use PDL types for their operands and results. It also declares the
/// bufferization and memref dialects as generated dialects and registers
/// every op from the ODS-generated op list.
class BufferizationTransformDialectExtension
    : public transform::TransformDialectExtension<
          BufferizationTransformDialectExtension> {
public:
  using Base::Base;

  void init() {
    // The transform ops' operands/results are PDL-typed, so PDL has to be
    // available whenever this extension is.
    declareDependentDialect<pdl::PDLDialect>();

    // Dialects whose ops these transforms may create in the payload IR, e.g.
    // bufferization.alloc_tensor from EmptyTensorToAllocTensorOp. The memref
    // dialect is presumably needed for ops emitted by One-Shot Bufferize;
    // confirm if this list changes.
    declareGeneratedDialect<bufferization::BufferizationDialect>();
    declareGeneratedDialect<memref::MemRefDialect>();

    // Register every transform op from the ODS-generated op list; the
    // included file expands to the comma-separated op class names.
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
        >();
  }
};
} // namespace
139 
140 #define GET_OP_CLASSES
141 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
142 
143 #include "mlir/Dialect/Bufferization/IR/BufferizationEnums.cpp.inc"
144 
/// Adds the BufferizationTransformDialectExtension, which registers the
/// bufferization transform ops, to `registry`.
void mlir::bufferization::registerTransformDialectExtension(
    DialectRegistry &registry) {
  registry.addExtensions<BufferizationTransformDialectExtension>();
}
149