//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::vector;

namespace mlir {
namespace vector {
namespace {

/// Bufferization of vector.transfer_read. Replaced with a new
/// vector.transfer_read that operates on a memref.
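///
/// Example (sketch; One-Shot Bufferize may add attributes such as in_bounds,
/// and the exact memref layout depends on the bufferization options):
///
///   %r = vector.transfer_read %t[%c0], %pad : tensor<?xf32>, vector<4xf32>
///
/// is rewritten to read from the buffer of %t instead:
///
///   %r = vector.transfer_read %buf[%c0], %pad : memref<?xf32>, vector<4xf32>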
struct TransferReadOpInterface
    : public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
                                                    vector::TransferReadOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    assert(isa<RankedTensorType>(opOperand.get().getType()) &&
           "only tensor types expected");
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    assert(isa<RankedTensorType>(opOperand.get().getType()) &&
           "only tensor types expected");
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto readOp = cast<vector::TransferReadOp>(op);
    assert(isa<TensorType>(readOp.getShapedType()) &&
           "only tensor types expected");
    FailureOr<Value> buffer = getBuffer(rewriter, readOp.getSource(), options);
    if (failed(buffer))
      return failure();
    replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
        rewriter, readOp, readOp.getVectorType(), *buffer, readOp.getIndices(),
        readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(),
        readOp.getInBoundsAttr());
    return success();
  }
};

/// Bufferization of vector.transfer_write. Replaced with a new
/// vector.transfer_write that operates on a memref.
///
/// Note: DstBufferizableOpInterfaceExternalModel provides many default method
/// implementations for DestinationStyle ops.
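///
/// Example (sketch; attributes and the exact buffer type depend on the
/// bufferization options):
///
///   %0 = vector.transfer_write %v, %t[%c0] : vector<4xf32>, tensor<?xf32>
///
/// becomes a write into the buffer of %t; the new op has no tensor result and
/// all uses of %0 are replaced with that buffer:
///
///   vector.transfer_write %v, %buf[%c0] : vector<4xf32>, memref<?xf32>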
struct TransferWriteOpInterface
    : public DstBufferizableOpInterfaceExternalModel<TransferWriteOpInterface,
                                                     vector::TransferWriteOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto writeOp = cast<vector::TransferWriteOp>(op);

    // Does not bufferize to a memory read if the vector completely overwrites
    // the buffer.
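    // For example, an unmasked write of a vector<4x8xf32> at offsets [0, 0]
    // into a tensor<4x8xf32> destination overwrites every element, so the
    // previous contents of the destination are never read.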

    // Destination must have static shape.
    if (!writeOp.getShapedType().hasStaticShape())
      return true;

    // All offsets must be 0.
    for (Value offset : writeOp.getIndices()) {
      if (getConstantIntValue(offset) != 0)
        return true;
    }

    // There is no mask.
    if (writeOp.isMasked())
      return true;

    // Must write at least the full dimension size.
    for (auto [d0, d1] : llvm::zip(writeOp.getShapedType().getShape(),
                                   writeOp.getVectorType().getShape())) {
      if (d0 > d1)
        return true;
    }

    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto writeOp = cast<vector::TransferWriteOp>(op);
    assert(isa<TensorType>(writeOp.getShapedType()) &&
           "only tensor types expected");

    // Create a new transfer_write on the buffer. It does not have a return
    // value.
    FailureOr<Value> resultBuffer =
        getBuffer(rewriter, writeOp.getSource(), options);
    if (failed(resultBuffer))
      return failure();
    rewriter.create<vector::TransferWriteOp>(
        writeOp.getLoc(), writeOp.getVector(), *resultBuffer,
        writeOp.getIndices(), writeOp.getPermutationMapAttr(),
        writeOp.getMask(), writeOp.getInBoundsAttr());
    replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);

    return success();
  }
};

/// Bufferization of vector.gather. Replaced with a new vector.gather that
/// operates on a memref.
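///
/// Example (sketch; types abbreviated, the buffer layout depends on the
/// bufferization options):
///
///   %0 = vector.gather %t[%c0] [%ids], %mask, %pass
///       : tensor<64xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
///         into vector<16xf32>
///
/// is rewritten to gather from the buffer of %t (a memref<64xf32>) with the
/// same indices, mask, and pass-through vector.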
struct GatherOpInterface
    : public BufferizableOpInterface::ExternalModel<GatherOpInterface,
                                                    vector::GatherOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    assert(isa<RankedTensorType>(opOperand.get().getType()) &&
           "only tensor types expected");
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    assert(isa<RankedTensorType>(opOperand.get().getType()) &&
           "only tensor types expected");
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto gatherOp = cast<vector::GatherOp>(op);
    assert(isa<TensorType>(gatherOp.getBaseType()) &&
           "only tensor types expected");
    FailureOr<Value> buffer = getBuffer(rewriter, gatherOp.getBase(), options);
    if (failed(buffer))
      return failure();
    replaceOpWithNewBufferizedOp<vector::GatherOp>(
        rewriter, gatherOp, gatherOp.getVectorType(), *buffer,
        gatherOp.getIndices(), gatherOp.getIndexVec(), gatherOp.getMask(),
        gatherOp.getPassThru());
    return success();
  }
};

/// Bufferization of vector.mask. Replaced with a new vector.mask that
/// operates on a memref.
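///
/// Example (sketch; the analysis may insert copies that are not shown here):
///
///   %0 = vector.mask %m {
///     vector.transfer_write %v, %t[%c0] : vector<16xf32>, tensor<?xf32>
///   } : vector<16xi1> -> tensor<?xf32>
///
/// The masked write is bufferized to operate on the buffer of %t, the
/// vector.mask op loses its tensor result, and uses of %0 are replaced with
/// that buffer:
///
///   vector.mask %m {
///     vector.transfer_write %v, %buf[%c0] : vector<16xf32>, memref<?xf32>
///   } : vector<16xi1>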
struct MaskOpInterface
    : public BufferizableOpInterface::ExternalModel<MaskOpInterface,
                                                    vector::MaskOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // MaskOps do not have tensor OpOperands. The yielded values are the result
    // of the wrapped op.
    auto maskOp = cast<vector::MaskOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), value));
    auto yieldOp =
        cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
    return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    // TODO: Remove this function when vector.mask bodies can bufferize
    // out-of-place. This is currently not supported because yielding allocs
    // from a block leads to a memory leak and because vector.mask supports
    // only a single op in its body.
    auto maskOp = cast<vector::MaskOp>(op);
    if (!maskOp.getMaskRegion()
             .front()
             .getOps<bufferization::AllocTensorOp>()
             .empty())
      return op->emitOpError("body must bufferize in-place");

    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto maskOp = cast<vector::MaskOp>(op);

    // Do not bufferize if the masked op is not bufferizable.
    Operation *maskedOp = maskOp.getMaskableOp();
    if (!options.dynCastBufferizableOp(maskedOp))
      return success();

    // Update the terminator: Drop all operands that are not results of the
    // masked op.
    auto yieldOp =
        cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
    SmallVector<Value> newReturnValues(maskOp->getNumResults(), Value());
    SmallVector<Value> newYieldedValues;
    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      if (llvm::is_contained(maskedOp->getOpResults(), it.value())) {
        newYieldedValues.push_back(it.value());
      } else {
        // This used to be a tensor result of the masked op, but is now a
        // memref that is defined outside of the vector.mask op.
        newReturnValues[it.index()] = it.value();
      }
    }
    rewriter.modifyOpInPlace(yieldOp, [&]() {
      yieldOp.getOperandsMutable().assign(newYieldedValues);
    });

    // Create a new vector.mask op.
    ValueRange newYieldedValuesRange(newYieldedValues);
    TypeRange newResultTypes(newYieldedValuesRange);
    auto newOp = rewriter.create<vector::MaskOp>(
        op->getLoc(), newResultTypes, maskOp.getMask(), maskOp.getPassthru(),
        /*maskableOp=*/nullptr,
        /*maskRegionBuilder=*/[](OpBuilder &b, Operation *) {});
    newOp.getRegion().takeBody(maskOp.getMaskRegion());

    // Replace all uses of the old vector.mask op.
    int idx = 0;
    for (int i = 0; i < static_cast<int>(maskOp->getNumResults()); ++i) {
      if (!newReturnValues[i])
        newReturnValues[i] = newOp->getResult(idx++);
    }
    replaceOpWithBufferizedValues(rewriter, maskOp, newReturnValues);
    return success();
  }
};

/// Bufferization of vector.yield. Replaced with a new vector.yield that
/// operates on a memref.
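///
/// Note: vector.yield is only bufferized as the terminator of a vector.mask
/// region (see MaskOpInterface above). Tensor operands are replaced with
/// their buffers; yield operands that do not stem from the masked op are
/// dropped when the enclosing vector.mask is bufferized.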
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    vector::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
             BufferRelation::Equivalent}};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. If possible, we should avoid
    // returning/yielding allocations from a block.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto yieldOp = cast<vector::YieldOp>(op);

    // Only supported as a vector.mask terminator.
    auto maskOp = dyn_cast<vector::MaskOp>(yieldOp->getParentOp());
    if (!maskOp)
      return yieldOp->emitError("unsupported vector::YieldOp parent");

    // Do not bufferize if the masked op is not bufferizable.
    Operation *maskedOp = &maskOp.getMaskRegion().front().front();
    if (!options.dynCastBufferizableOp(maskedOp))
      return success();

    // Create a new terminator with the same number of operands. Some of these
    // may get dropped during the bufferization of vector.mask.
    SmallVector<Value> newResults;
    for (Value value : yieldOp.getOperands()) {
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
        if (failed(maybeBuffer))
          return failure();
        newResults.push_back(*maybeBuffer);
      } else {
        newResults.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<vector::YieldOp>(rewriter, op, newResults);
    return success();
  }
};

} // namespace
} // namespace vector
} // namespace mlir

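/// A minimal usage sketch: these external models are typically registered on
/// the DialectRegistry that is attached to the context before running
/// One-Shot Bufferize, e.g.
///
///   DialectRegistry registry;
///   vector::registerBufferizableOpInterfaceExternalModels(registry);
///   context.appendDialectRegistry(registry);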
void mlir::vector::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
    TransferReadOp::attachInterface<TransferReadOpInterface>(*ctx);
    TransferWriteOp::attachInterface<TransferWriteOpInterface>(*ctx);
    GatherOp::attachInterface<GatherOpInterface>(*ctx);
    MaskOp::attachInterface<MaskOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}