xref: /llvm-project/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp (revision 2f3ac28cb2f7fc24c6ff742af571b58419c0adaa)
//===- BufferizationTransformOps.cpp - Bufferization transform ops --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
10 
11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
13 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
14 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
15 #include "mlir/Dialect/MemRef/IR/MemRef.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
18 #include "mlir/IR/FunctionInterfaces.h"
19 
20 using namespace mlir;
21 using namespace mlir::bufferization;
22 using namespace mlir::transform;
23 
24 //===----------------------------------------------------------------------===//
25 // OneShotBufferizeOp
26 //===----------------------------------------------------------------------===//
27 
28 DiagnosedSilenceableFailure
29 transform::OneShotBufferizeOp::apply(TransformResults &transformResults,
30                                      TransformState &state) {
31   OneShotBufferizationOptions options;
32   options.allowReturnAllocs = getAllowReturnAllocs();
33   options.allowUnknownOps = getAllowUnknownOps();
34   options.bufferizeFunctionBoundaries = getBufferizeFunctionBoundaries();
35   options.createDeallocs = getCreateDeallocs();
36   options.testAnalysisOnly = getTestAnalysisOnly();
37   options.printConflicts = getPrintConflicts();
38   if (getFunctionBoundaryTypeConversion().has_value())
39     options.setFunctionBoundaryTypeConversion(
40         *getFunctionBoundaryTypeConversion());
41 
42   auto payloadOps = state.getPayloadOps(getTarget());
43   for (Operation *target : payloadOps) {
44     if (!isa<ModuleOp, FunctionOpInterface>(target))
45       return emitSilenceableError() << "expected module or function target";
46     auto moduleOp = dyn_cast<ModuleOp>(target);
47     if (options.bufferizeFunctionBoundaries) {
48       if (!moduleOp)
49         return emitSilenceableError() << "expected module target";
50       if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options)))
51         return emitSilenceableError() << "bufferization failed";
52     } else {
53       if (failed(bufferization::runOneShotBufferize(target, options)))
54         return emitSilenceableError() << "bufferization failed";
55     }
56   }
57 
58   // This transform op is currently restricted to ModuleOps and function ops.
59   // Such ops are modified in-place.
60   transformResults.set(cast<OpResult>(getTransformed()), payloadOps);
61   return DiagnosedSilenceableFailure::success();
62 }
63 
64 //===----------------------------------------------------------------------===//
65 // EliminateEmptyTensorsOp
66 //===----------------------------------------------------------------------===//
67 
68 void transform::EliminateEmptyTensorsOp::getEffects(
69     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
70   onlyReadsHandle(getTarget(), effects);
71   modifiesPayload(effects);
72 }
73 
74 DiagnosedSilenceableFailure
75 transform::EliminateEmptyTensorsOp::apply(TransformResults &transformResults,
76                                           TransformState &state) {
77   IRRewriter rewriter(getContext());
78   OneShotBufferizationOptions options;
79   options.allowReturnAllocs = true;
80 
81   for (Operation *target : state.getPayloadOps(getTarget())) {
82     OneShotAnalysisState state(target, options);
83     if (failed(analyzeOp(target, state)))
84       return mlir::emitSilenceableFailure(target->getLoc())
85              << "failed to analyze op";
86     if (failed(bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
87             rewriter, target, state)))
88       return mlir::emitSilenceableFailure(target->getLoc())
89              << "failed to eliminate insert_slice anchored tensor.empty ops";
90   }
91   return DiagnosedSilenceableFailure::success();
92 }
93 
94 //===----------------------------------------------------------------------===//
95 // EmptyTensorToAllocTensorOp
96 //===----------------------------------------------------------------------===//
97 
98 DiagnosedSilenceableFailure
99 EmptyTensorToAllocTensorOp::applyToOne(tensor::EmptyOp target,
100                                        ApplyToEachResultList &results,
101                                        transform::TransformState &state) {
102   IRRewriter rewriter(target->getContext());
103   rewriter.setInsertionPoint(target);
104   auto alloc = rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>(
105       target, target.getType(), target.getDynamicSizes());
106   results.push_back(alloc);
107   return DiagnosedSilenceableFailure::success();
108 }
109 
110 //===----------------------------------------------------------------------===//
111 // Transform op registration
112 //===----------------------------------------------------------------------===//
113 
114 namespace {
115 /// Registers new ops and declares PDL as dependent dialect since the additional
116 /// ops are using PDL types for operands and results.
117 class BufferizationTransformDialectExtension
118     : public transform::TransformDialectExtension<
119           BufferizationTransformDialectExtension> {
120 public:
121   using Base::Base;
122 
123   void init() {
124     declareGeneratedDialect<bufferization::BufferizationDialect>();
125     declareGeneratedDialect<memref::MemRefDialect>();
126 
127     registerTransformOps<
128 #define GET_OP_LIST
129 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
130         >();
131   }
132 };
133 } // namespace
134 
135 #define GET_OP_CLASSES
136 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
137 
138 #include "mlir/Dialect/Bufferization/IR/BufferizationEnums.cpp.inc"
139 
/// Adds the bufferization transform dialect extension (which registers the
/// bufferization transform ops) to `registry`.
void mlir::bufferization::registerTransformDialectExtension(
    DialectRegistry &registry) {
  registry.addExtensions<BufferizationTransformDialectExtension>();
}
143 }
144