//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::vector;

namespace mlir {
namespace vector {
namespace {

/// Bufferization of vector.transfer_read. Replaced with a new
/// vector.transfer_read that operates on a memref.
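///
/// A sketch of the rewrite, with illustrative types and value names:
///
///   %v = vector.transfer_read %t[%c0], %pad
///       : tensor<?xf32>, vector<16xf32>
///
/// becomes, with %m denoting the buffer computed for %t:
///
///   %v = vector.transfer_read %m[%c0], %pad
///       : memref<?xf32>, vector<16xf32>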
struct TransferReadOpInterface
    : public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
                                                    vector::TransferReadOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    assert(isa<RankedTensorType>(opOperand.get().getType()) &&
           "only tensor types expected");
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    assert(isa<RankedTensorType>(opOperand.get().getType()) &&
           "only tensor types expected");
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto readOp = cast<vector::TransferReadOp>(op);
    assert(isa<TensorType>(readOp.getShapedType()) &&
           "only tensor types expected");
    FailureOr<Value> buffer = getBuffer(rewriter, readOp.getSource(), options);
    if (failed(buffer))
      return failure();
    replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
        rewriter, readOp, readOp.getVectorType(), *buffer, readOp.getIndices(),
        readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(),
        readOp.getInBoundsAttr());
    return success();
  }
};

/// Bufferization of vector.transfer_write. Replaced with a new
/// vector.transfer_write that operates on a memref.
///
/// Note: DstBufferizableOpInterfaceExternalModel provides many default method
/// implementations for DestinationStyle ops.
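///
/// A sketch of the rewrite, with illustrative types and value names:
///
///   %r = vector.transfer_write %v, %t[%c0]
///       : vector<16xf32>, tensor<?xf32>
///
/// becomes a write into %m, the buffer computed for %t, and all uses of %r
/// are replaced with that buffer:
///
///   vector.transfer_write %v, %m[%c0]
///       : vector<16xf32>, memref<?xf32>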
struct TransferWriteOpInterface
    : public DstBufferizableOpInterfaceExternalModel<TransferWriteOpInterface,
                                                     vector::TransferWriteOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto writeOp = cast<vector::TransferWriteOp>(op);

    // Does not bufferize to a memory read if the vector completely overwrites
    // the buffer.
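    //
    // For example (an illustrative sketch, not taken from a specific test),
    // the following write covers all of %t, so no prior contents of %t can
    // be observed through the result:
    //
    //   vector.transfer_write %v, %t[%c0, %c0]
    //       : vector<4x4xf32>, tensor<4x4xf32>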

    // Destination must have static shape.
    if (!writeOp.getShapedType().hasStaticShape())
      return true;

    // All offsets must be 0.
    for (Value offset : writeOp.getIndices()) {
      if (getConstantIntValue(offset) != 0)
        return true;
    }

    // There must be no mask.
    if (writeOp.isMasked())
      return true;

    // Must write at least the full dimension size.
    for (auto [d0, d1] : llvm::zip(writeOp.getShapedType().getShape(),
                                   writeOp.getVectorType().getShape())) {
      if (d0 > d1)
        return true;
    }

    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto writeOp = cast<vector::TransferWriteOp>(op);
    assert(isa<TensorType>(writeOp.getShapedType()) &&
           "only tensor types expected");

    // Create a new transfer_write on the buffer that doesn't have a return
    // value.
    FailureOr<Value> resultBuffer =
        getBuffer(rewriter, writeOp.getSource(), options);
    if (failed(resultBuffer))
      return failure();
    rewriter.create<vector::TransferWriteOp>(
        writeOp.getLoc(), writeOp.getVector(), *resultBuffer,
        writeOp.getIndices(), writeOp.getPermutationMapAttr(),
        writeOp.getMask(), writeOp.getInBoundsAttr());
    replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);

    return success();
  }
};

/// Bufferization of vector.gather. Replaced with a new vector.gather that
/// operates on a memref.
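///
/// A sketch of the rewrite, with illustrative types and value names:
///
///   %g = vector.gather %t[%c0] [%idxs], %mask, %passthru
///       : tensor<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
///       into vector<16xf32>
///
/// becomes the same gather with the buffer computed for %t as its base:
///
///   %g = vector.gather %m[%c0] [%idxs], %mask, %passthru
///       : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
///       into vector<16xf32>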
struct GatherOpInterface
    : public BufferizableOpInterface::ExternalModel<GatherOpInterface,
                                                    vector::GatherOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    assert(isa<RankedTensorType>(opOperand.get().getType()) &&
           "only tensor types expected");
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    assert(isa<RankedTensorType>(opOperand.get().getType()) &&
           "only tensor types expected");
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto gatherOp = cast<vector::GatherOp>(op);
    assert(isa<TensorType>(gatherOp.getBaseType()) &&
           "only tensor types expected");
    FailureOr<Value> buffer = getBuffer(rewriter, gatherOp.getBase(), options);
    if (failed(buffer))
      return failure();
    replaceOpWithNewBufferizedOp<vector::GatherOp>(
        rewriter, gatherOp, gatherOp.getVectorType(), *buffer,
        gatherOp.getIndices(), gatherOp.getIndexVec(), gatherOp.getMask(),
        gatherOp.getPassThru());
    return success();
  }
};

/// Bufferization of vector.mask. Replaced with a new vector.mask that
/// operates on a memref.
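///
/// A sketch of the rewrite (illustrative; the wrapped op must itself be
/// bufferizable):
///
///   %r = vector.mask %mask {
///     vector.transfer_write %v, %t[%c0] : vector<16xf32>, tensor<?xf32>
///   } : vector<16xi1> -> tensor<?xf32>
///
/// becomes a vector.mask without the tensor result, whose body writes into
/// the buffer computed for %t.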
struct MaskOpInterface
    : public BufferizableOpInterface::ExternalModel<MaskOpInterface,
                                                    vector::MaskOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // MaskOps do not have tensor OpOperands. The yielded values are the
    // results of the wrapped op.
    auto maskOp = cast<vector::MaskOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), value));
    auto yieldOp =
        cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
    return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    // TODO: Remove this function when vector.mask bodies can bufferize
    // out-of-place. This is currently not supported because yielding allocs
    // from a block leads to a memory leak and because vector.mask supports
    // only a single op in its body.
    auto maskOp = cast<vector::MaskOp>(op);
    if (!maskOp.getMaskRegion()
             .front()
             .getOps<bufferization::AllocTensorOp>()
             .empty())
      return op->emitOpError("body must bufferize in-place");

    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto maskOp = cast<vector::MaskOp>(op);

    // Do not bufferize if the masked op is not bufferizable.
    Operation *maskedOp = maskOp.getMaskableOp();
    if (!options.dynCastBufferizableOp(maskedOp))
      return success();

    // Update the terminator: Drop all operands that are not results of the
    // masked op.
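    //
    // For example (illustrative), given the terminator
    //   vector.yield %a, %b : memref<?xf32>, vector<16xf32>
    // where %b is a result of the masked op but %a is defined above the
    // vector.mask op, only %b remains yielded; %a is recorded directly as a
    // replacement value below.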
    auto yieldOp =
        cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
    SmallVector<Value> newReturnValues(maskOp->getNumResults(), Value());
    SmallVector<Value> newYieldedValues;
    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      if (llvm::is_contained(maskedOp->getOpResults(), it.value())) {
        newYieldedValues.push_back(it.value());
      } else {
        // This used to be a tensor result of the masked op, but is now a
        // memref that is defined outside of the vector.mask op.
        newReturnValues[it.index()] = it.value();
      }
    }
    rewriter.modifyOpInPlace(yieldOp, [&]() {
      yieldOp.getOperandsMutable().assign(newYieldedValues);
    });

    // Create a new vector.mask op.
    ValueRange newYieldedValuesRange(newYieldedValues);
    TypeRange newResultTypes(newYieldedValuesRange);
    auto newOp = rewriter.create<vector::MaskOp>(
        op->getLoc(), newResultTypes, maskOp.getMask(), maskOp.getPassthru(),
        /*maskableOp=*/nullptr,
        /*maskRegionBuilder=*/[](OpBuilder &b, Operation *) {});
    newOp.getRegion().takeBody(maskOp.getMaskRegion());

    // Replace all uses of the old vector.mask op.
    int idx = 0;
    for (int i = 0; i < static_cast<int>(maskOp->getNumResults()); ++i) {
      if (!newReturnValues[i])
        newReturnValues[i] = newOp->getResult(idx++);
    }
    replaceOpWithBufferizedValues(rewriter, maskOp, newReturnValues);
    return success();
  }
};

/// Bufferization of vector.yield. Replaced with a new vector.yield that
/// operates on a memref.
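///
/// A sketch of the rewrite (illustrative): inside a vector.mask region,
///
///   vector.yield %t : tensor<?xf32>
///
/// becomes a yield of the buffer computed for %t, which may later be
/// dropped when the enclosing vector.mask is bufferized:
///
///   vector.yield %m : memref<?xf32>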
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    vector::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
             BufferRelation::Equivalent}};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize in-place. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield
    // allocations when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto yieldOp = cast<vector::YieldOp>(op);

    // Only supported as a vector.mask terminator.
    auto maskOp = dyn_cast<vector::MaskOp>(yieldOp->getParentOp());
    if (!maskOp)
      return yieldOp->emitError("unsupported vector::YieldOp parent");

    // Do not bufferize if the masked op is not bufferizable.
    Operation *maskedOp = &maskOp.getMaskRegion().front().front();
    if (!options.dynCastBufferizableOp(maskedOp))
      return success();

    // Create a new terminator with the same number of operands. Some of these
    // may get dropped during the bufferization of vector.mask.
    SmallVector<Value> newResults;
    for (Value value : yieldOp.getOperands()) {
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
        if (failed(maybeBuffer))
          return failure();
        newResults.push_back(*maybeBuffer);
      } else {
        newResults.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<vector::YieldOp>(rewriter, op, newResults);
    return success();
  }
};

} // namespace
} // namespace vector
} // namespace mlir

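/// A typical registration sequence (a sketch; where this is called depends
/// on the client):
///
///   DialectRegistry registry;
///   registry.insert<vector::VectorDialect>();
///   vector::registerBufferizableOpInterfaceExternalModels(registry);
///   MLIRContext ctx(registry);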
void mlir::vector::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
    TransferReadOp::attachInterface<TransferReadOpInterface>(*ctx);
    TransferWriteOp::attachInterface<TransferWriteOpInterface>(*ctx);
    GatherOp::attachInterface<GatherOpInterface>(*ctx);
    MaskOp::attachInterface<MaskOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}