15523c145SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
25523c145SMatthias Springer //
35523c145SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
45523c145SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
55523c145SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
65523c145SMatthias Springer //
75523c145SMatthias Springer //===----------------------------------------------------------------------===//
85523c145SMatthias Springer
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
175523c145SMatthias Springer
185523c145SMatthias Springer using namespace mlir;
195523c145SMatthias Springer using namespace mlir::bufferization;
205523c145SMatthias Springer using namespace mlir::vector;
215523c145SMatthias Springer
225523c145SMatthias Springer namespace mlir {
235523c145SMatthias Springer namespace vector {
245523c145SMatthias Springer namespace {
255523c145SMatthias Springer
265523c145SMatthias Springer /// Bufferization of vector.transfer_read. Replaced with a new
275523c145SMatthias Springer /// vector.transfer_read that operates on a memref.
285523c145SMatthias Springer struct TransferReadOpInterface
295523c145SMatthias Springer : public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
305523c145SMatthias Springer vector::TransferReadOp> {
bufferizesToMemoryReadmlir::vector::__anonf19189740111::TransferReadOpInterface315523c145SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
329597b16aSMatthias Springer const AnalysisState &state) const {
335550c821STres Popp assert(isa<RankedTensorType>(opOperand.get().getType()) &&
345523c145SMatthias Springer "only tensor types expected");
355523c145SMatthias Springer return true;
365523c145SMatthias Springer }
375523c145SMatthias Springer
bufferizesToMemoryWritemlir::vector::__anonf19189740111::TransferReadOpInterface385523c145SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
399597b16aSMatthias Springer const AnalysisState &state) const {
405550c821STres Popp assert(isa<RankedTensorType>(opOperand.get().getType()) &&
415523c145SMatthias Springer "only tensor types expected");
425523c145SMatthias Springer return false;
435523c145SMatthias Springer }
445523c145SMatthias Springer
getAliasingValuesmlir::vector::__anonf19189740111::TransferReadOpInterface45a02ad6c1SMatthias Springer AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
469597b16aSMatthias Springer const AnalysisState &state) const {
47585a8a32SMatthias Springer return {};
485523c145SMatthias Springer }
495523c145SMatthias Springer
bufferizemlir::vector::__anonf19189740111::TransferReadOpInterface505523c145SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
51b55d55ecSMatthias Springer const BufferizationOptions &options) const {
525523c145SMatthias Springer auto readOp = cast<vector::TransferReadOp>(op);
535550c821STres Popp assert(isa<TensorType>(readOp.getShapedType()) &&
545523c145SMatthias Springer "only tensor types expected");
555d50f51cSMatthias Springer FailureOr<Value> buffer = getBuffer(rewriter, readOp.getSource(), options);
565d50f51cSMatthias Springer if (failed(buffer))
575d50f51cSMatthias Springer return failure();
585523c145SMatthias Springer replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
595d50f51cSMatthias Springer rewriter, readOp, readOp.getVectorType(), *buffer, readOp.getIndices(),
607c38fd60SJacques Pienaar readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(),
617c38fd60SJacques Pienaar readOp.getInBoundsAttr());
625523c145SMatthias Springer return success();
635523c145SMatthias Springer }
645523c145SMatthias Springer };
655523c145SMatthias Springer
665523c145SMatthias Springer /// Bufferization of vector.transfer_write. Replace with a new
675523c145SMatthias Springer /// vector.transfer_write that operates on a memref.
68bf531f28SMatthias Springer ///
69bf531f28SMatthias Springer /// Note: DstBufferizableOpInterfaceExternalModel provides many default method
70bf531f28SMatthias Springer /// implementations for DestinationStyle ops.
715523c145SMatthias Springer struct TransferWriteOpInterface
72bf531f28SMatthias Springer : public DstBufferizableOpInterfaceExternalModel<TransferWriteOpInterface,
735523c145SMatthias Springer vector::TransferWriteOp> {
bufferizesToMemoryReadmlir::vector::__anonf19189740111::TransferWriteOpInterface7480853a16SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
7580853a16SMatthias Springer const AnalysisState &state) const {
7680853a16SMatthias Springer auto writeOp = cast<vector::TransferWriteOp>(op);
7780853a16SMatthias Springer
7880853a16SMatthias Springer // Does not bufferize to a memory read if the vector completely overwrites
7980853a16SMatthias Springer // the buffer.
8080853a16SMatthias Springer
8180853a16SMatthias Springer // Destination must have static shape.
8280853a16SMatthias Springer if (!writeOp.getShapedType().hasStaticShape())
8380853a16SMatthias Springer return true;
8480853a16SMatthias Springer
8580853a16SMatthias Springer // All offsets must be 0.
8680853a16SMatthias Springer for (Value offset : writeOp.getIndices()) {
8780853a16SMatthias Springer if (getConstantIntValue(offset) != 0)
8880853a16SMatthias Springer return true;
8980853a16SMatthias Springer }
9080853a16SMatthias Springer
9180853a16SMatthias Springer // There is no mask.
9280853a16SMatthias Springer if (writeOp.isMasked())
9380853a16SMatthias Springer return true;
9480853a16SMatthias Springer
9580853a16SMatthias Springer // Must write at least the full dimension size.
9680853a16SMatthias Springer for (auto [d0, d1] : llvm::zip(writeOp.getShapedType().getShape(),
9780853a16SMatthias Springer writeOp.getVectorType().getShape())) {
9880853a16SMatthias Springer if (d0 > d1)
9980853a16SMatthias Springer return true;
10080853a16SMatthias Springer }
10180853a16SMatthias Springer
10280853a16SMatthias Springer return false;
10380853a16SMatthias Springer }
10480853a16SMatthias Springer
bufferizemlir::vector::__anonf19189740111::TransferWriteOpInterface1055523c145SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
106b55d55ecSMatthias Springer const BufferizationOptions &options) const {
1075523c145SMatthias Springer auto writeOp = cast<vector::TransferWriteOp>(op);
1085550c821STres Popp assert(isa<TensorType>(writeOp.getShapedType()) &&
1095523c145SMatthias Springer "only tensor types expected");
1105523c145SMatthias Springer
1115523c145SMatthias Springer // Create a new transfer_write on buffer that doesn't have a return value.
1125d50f51cSMatthias Springer FailureOr<Value> resultBuffer =
1135d50f51cSMatthias Springer getBuffer(rewriter, writeOp.getSource(), options);
1145d50f51cSMatthias Springer if (failed(resultBuffer))
1155d50f51cSMatthias Springer return failure();
1165523c145SMatthias Springer rewriter.create<vector::TransferWriteOp>(
1175d50f51cSMatthias Springer writeOp.getLoc(), writeOp.getVector(), *resultBuffer,
1187c38fd60SJacques Pienaar writeOp.getIndices(), writeOp.getPermutationMapAttr(),
119a28ce1a4SMatthias Springer writeOp.getMask(), writeOp.getInBoundsAttr());
1205d50f51cSMatthias Springer replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
1215523c145SMatthias Springer
1225523c145SMatthias Springer return success();
1235523c145SMatthias Springer }
1245523c145SMatthias Springer };
1255523c145SMatthias Springer
12666c2b768SJerry Wu /// Bufferization of vector.gather. Replaced with a new vector.gather that
12766c2b768SJerry Wu /// operates on a memref.
12866c2b768SJerry Wu struct GatherOpInterface
12966c2b768SJerry Wu : public BufferizableOpInterface::ExternalModel<GatherOpInterface,
13066c2b768SJerry Wu vector::GatherOp> {
bufferizesToMemoryReadmlir::vector::__anonf19189740111::GatherOpInterface13166c2b768SJerry Wu bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
13266c2b768SJerry Wu const AnalysisState &state) const {
1335550c821STres Popp assert(isa<RankedTensorType>(opOperand.get().getType()) &&
13466c2b768SJerry Wu "only tensor types expected");
13566c2b768SJerry Wu return true;
13666c2b768SJerry Wu }
13766c2b768SJerry Wu
bufferizesToMemoryWritemlir::vector::__anonf19189740111::GatherOpInterface13866c2b768SJerry Wu bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
13966c2b768SJerry Wu const AnalysisState &state) const {
1405550c821STres Popp assert(isa<RankedTensorType>(opOperand.get().getType()) &&
14166c2b768SJerry Wu "only tensor types expected");
14266c2b768SJerry Wu return false;
14366c2b768SJerry Wu }
14466c2b768SJerry Wu
getAliasingValuesmlir::vector::__anonf19189740111::GatherOpInterface145a02ad6c1SMatthias Springer AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
14666c2b768SJerry Wu const AnalysisState &state) const {
14766c2b768SJerry Wu return {};
14866c2b768SJerry Wu }
14966c2b768SJerry Wu
bufferizemlir::vector::__anonf19189740111::GatherOpInterface15066c2b768SJerry Wu LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
15166c2b768SJerry Wu const BufferizationOptions &options) const {
15266c2b768SJerry Wu auto gatherOp = cast<vector::GatherOp>(op);
1535550c821STres Popp assert(isa<TensorType>(gatherOp.getBaseType()) &&
15466c2b768SJerry Wu "only tensor types expected");
15566c2b768SJerry Wu FailureOr<Value> buffer = getBuffer(rewriter, gatherOp.getBase(), options);
15666c2b768SJerry Wu if (failed(buffer))
15766c2b768SJerry Wu return failure();
15866c2b768SJerry Wu replaceOpWithNewBufferizedOp<vector::GatherOp>(
15966c2b768SJerry Wu rewriter, gatherOp, gatherOp.getVectorType(), *buffer,
16066c2b768SJerry Wu gatherOp.getIndices(), gatherOp.getIndexVec(), gatherOp.getMask(),
16166c2b768SJerry Wu gatherOp.getPassThru());
16266c2b768SJerry Wu return success();
16366c2b768SJerry Wu }
16466c2b768SJerry Wu };
16566c2b768SJerry Wu
166199f368eSMatthias Springer /// Bufferization of vector.mask. Replaced with a new vector.mask that
167199f368eSMatthias Springer /// operates on a memref.
168199f368eSMatthias Springer struct MaskOpInterface
169199f368eSMatthias Springer : public BufferizableOpInterface::ExternalModel<MaskOpInterface,
170199f368eSMatthias Springer vector::MaskOp> {
1711ac248e4SMatthias Springer AliasingOpOperandList
getAliasingOpOperandsmlir::vector::__anonf19189740111::MaskOpInterface172a02ad6c1SMatthias Springer getAliasingOpOperands(Operation *op, Value value,
173199f368eSMatthias Springer const AnalysisState &state) const {
174199f368eSMatthias Springer // MaskOps do not have tensor OpOperands. The yielded values are the result
175199f368eSMatthias Springer // of the wrapped op.
176199f368eSMatthias Springer auto maskOp = cast<vector::MaskOp>(op);
177199f368eSMatthias Springer size_t resultNum = std::distance(op->getOpResults().begin(),
178a02ad6c1SMatthias Springer llvm::find(op->getOpResults(), value));
179199f368eSMatthias Springer auto yieldOp =
180199f368eSMatthias Springer cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
1819fa6b350SMatthias Springer return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
182199f368eSMatthias Springer }
183199f368eSMatthias Springer
resolveConflictsmlir::vector::__anonf19189740111::MaskOpInterface184199f368eSMatthias Springer LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
185199f368eSMatthias Springer const AnalysisState &state) const {
186199f368eSMatthias Springer auto bufferizableOp = cast<BufferizableOpInterface>(op);
187199f368eSMatthias Springer if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
188199f368eSMatthias Springer return failure();
189199f368eSMatthias Springer
190199f368eSMatthias Springer // TODO: Remove this function when vector.mask bodies can bufferize
191199f368eSMatthias Springer // out-of-place. This is currently not supported because yielding allocs
192199f368eSMatthias Springer // from a block leads to a memory leak and because vector.mask supports only
193199f368eSMatthias Springer // a single op in its body.
194199f368eSMatthias Springer auto maskOp = cast<vector::MaskOp>(op);
195199f368eSMatthias Springer if (!maskOp.getMaskRegion()
196199f368eSMatthias Springer .front()
197199f368eSMatthias Springer .getOps<bufferization::AllocTensorOp>()
198199f368eSMatthias Springer .empty())
199199f368eSMatthias Springer return op->emitOpError("body must bufferize in-place");
200199f368eSMatthias Springer
201199f368eSMatthias Springer return success();
202199f368eSMatthias Springer }
203199f368eSMatthias Springer
bufferizemlir::vector::__anonf19189740111::MaskOpInterface204199f368eSMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
205199f368eSMatthias Springer const BufferizationOptions &options) const {
206199f368eSMatthias Springer auto maskOp = cast<vector::MaskOp>(op);
207199f368eSMatthias Springer
208199f368eSMatthias Springer // Do not bufferize if the masked op is not bufferizable.
209199f368eSMatthias Springer Operation *maskedOp = maskOp.getMaskableOp();
210199f368eSMatthias Springer if (!options.dynCastBufferizableOp(maskedOp))
211199f368eSMatthias Springer return success();
212199f368eSMatthias Springer
213199f368eSMatthias Springer // Update the terminator: Drop all operands that are not results of the
214199f368eSMatthias Springer // masked op.
215199f368eSMatthias Springer auto yieldOp =
216199f368eSMatthias Springer cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
217199f368eSMatthias Springer SmallVector<Value> newReturnValues(maskOp->getNumResults(), Value());
218199f368eSMatthias Springer SmallVector<Value> newYieldedValues;
219199f368eSMatthias Springer for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
2208e8bbbd4SKazu Hirata if (llvm::is_contained(maskedOp->getOpResults(), it.value())) {
221199f368eSMatthias Springer newYieldedValues.push_back(it.value());
222199f368eSMatthias Springer } else {
223199f368eSMatthias Springer // This used to be a tensor result of the masked op, but is now a memref
224199f368eSMatthias Springer // that is defined outside of the vector.mask op.
225199f368eSMatthias Springer newReturnValues[it.index()] = it.value();
226199f368eSMatthias Springer }
227199f368eSMatthias Springer }
228*5fcf907bSMatthias Springer rewriter.modifyOpInPlace(yieldOp, [&]() {
229199f368eSMatthias Springer yieldOp.getOperandsMutable().assign(newYieldedValues);
230199f368eSMatthias Springer });
231199f368eSMatthias Springer
232199f368eSMatthias Springer // Create a new vector.mask op.
233017be821SMatthias Springer ValueRange newYieldedValuesRange(newYieldedValues);
234017be821SMatthias Springer TypeRange newResultTypes(newYieldedValuesRange);
235199f368eSMatthias Springer auto newOp = rewriter.create<vector::MaskOp>(
236199f368eSMatthias Springer op->getLoc(), newResultTypes, maskOp.getMask(), maskOp.getPassthru(),
237199f368eSMatthias Springer /*maskableOp=*/nullptr,
238199f368eSMatthias Springer /*maskRegionBuilder=*/[](OpBuilder &b, Operation *) {});
239199f368eSMatthias Springer newOp.getRegion().takeBody(maskOp.getMaskRegion());
240199f368eSMatthias Springer
241199f368eSMatthias Springer // Replace all uses of the old vector.mask op.
242199f368eSMatthias Springer int idx = 0;
243199f368eSMatthias Springer for (int i = 0; i < static_cast<int>(maskOp->getNumResults()); ++i) {
244199f368eSMatthias Springer if (!newReturnValues[i])
245199f368eSMatthias Springer newReturnValues[i] = newOp->getResult(idx++);
246199f368eSMatthias Springer }
247199f368eSMatthias Springer replaceOpWithBufferizedValues(rewriter, maskOp, newReturnValues);
248199f368eSMatthias Springer return success();
249199f368eSMatthias Springer }
250199f368eSMatthias Springer };
251199f368eSMatthias Springer
252199f368eSMatthias Springer /// Bufferization of vector.yield. Replaced with a new vector.yield that
253199f368eSMatthias Springer /// operates on a memref.
254199f368eSMatthias Springer struct YieldOpInterface
255199f368eSMatthias Springer : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
256199f368eSMatthias Springer vector::YieldOp> {
bufferizesToMemoryReadmlir::vector::__anonf19189740111::YieldOpInterface257199f368eSMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
258199f368eSMatthias Springer const AnalysisState &state) const {
259199f368eSMatthias Springer return true;
260199f368eSMatthias Springer }
261199f368eSMatthias Springer
bufferizesToMemoryWritemlir::vector::__anonf19189740111::YieldOpInterface262199f368eSMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
263199f368eSMatthias Springer const AnalysisState &state) const {
264199f368eSMatthias Springer return false;
265199f368eSMatthias Springer }
266199f368eSMatthias Springer
getAliasingValuesmlir::vector::__anonf19189740111::YieldOpInterface267a02ad6c1SMatthias Springer AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
268199f368eSMatthias Springer const AnalysisState &state) const {
2699fa6b350SMatthias Springer return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
2709fa6b350SMatthias Springer BufferRelation::Equivalent}};
271199f368eSMatthias Springer }
272199f368eSMatthias Springer
mustBufferizeInPlacemlir::vector::__anonf19189740111::YieldOpInterface273199f368eSMatthias Springer bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
274199f368eSMatthias Springer const AnalysisState &state) const {
275199f368eSMatthias Springer // Yield operands always bufferize inplace. Otherwise, an alloc + copy
276199f368eSMatthias Springer // may be generated inside the block. We should not return/yield allocations
277199f368eSMatthias Springer // when possible.
278199f368eSMatthias Springer return true;
279199f368eSMatthias Springer }
280199f368eSMatthias Springer
bufferizemlir::vector::__anonf19189740111::YieldOpInterface281199f368eSMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
282199f368eSMatthias Springer const BufferizationOptions &options) const {
283199f368eSMatthias Springer auto yieldOp = cast<vector::YieldOp>(op);
284199f368eSMatthias Springer
285199f368eSMatthias Springer // Only supported as a vector.mask terminator.
286199f368eSMatthias Springer auto maskOp = dyn_cast<vector::MaskOp>(yieldOp->getParentOp());
287199f368eSMatthias Springer if (!maskOp)
288199f368eSMatthias Springer return yieldOp->emitError("unsupported vector::YieldOp parent");
289199f368eSMatthias Springer
290199f368eSMatthias Springer // Do not bufferize if the masked op is not bufferizable.
291199f368eSMatthias Springer Operation *maskedOp = &maskOp.getMaskRegion().front().front();
292199f368eSMatthias Springer if (!options.dynCastBufferizableOp(maskedOp))
293199f368eSMatthias Springer return success();
294199f368eSMatthias Springer
295199f368eSMatthias Springer // Create a new terminator with the same number of operands. Some of these
296199f368eSMatthias Springer // may get dropped during the bufferization of vector.mask.
297199f368eSMatthias Springer SmallVector<Value> newResults;
298199f368eSMatthias Springer for (Value value : yieldOp.getOperands()) {
2995550c821STres Popp if (isa<TensorType>(value.getType())) {
300199f368eSMatthias Springer FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
301199f368eSMatthias Springer if (failed(maybeBuffer))
302199f368eSMatthias Springer return failure();
303199f368eSMatthias Springer newResults.push_back(*maybeBuffer);
304199f368eSMatthias Springer } else {
305199f368eSMatthias Springer newResults.push_back(value);
306199f368eSMatthias Springer }
307199f368eSMatthias Springer }
308199f368eSMatthias Springer
309199f368eSMatthias Springer replaceOpWithNewBufferizedOp<vector::YieldOp>(rewriter, op, newResults);
310199f368eSMatthias Springer return success();
311199f368eSMatthias Springer }
312199f368eSMatthias Springer };
313199f368eSMatthias Springer
3145523c145SMatthias Springer } // namespace
3155523c145SMatthias Springer } // namespace vector
3165523c145SMatthias Springer } // namespace mlir
3175523c145SMatthias Springer
registerBufferizableOpInterfaceExternalModels(DialectRegistry & registry)3185523c145SMatthias Springer void mlir::vector::registerBufferizableOpInterfaceExternalModels(
3195523c145SMatthias Springer DialectRegistry ®istry) {
32077eee579SRiver Riddle registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
32177eee579SRiver Riddle TransferReadOp::attachInterface<TransferReadOpInterface>(*ctx);
32277eee579SRiver Riddle TransferWriteOp::attachInterface<TransferWriteOpInterface>(*ctx);
32366c2b768SJerry Wu GatherOp::attachInterface<GatherOpInterface>(*ctx);
324199f368eSMatthias Springer MaskOp::attachInterface<MaskOpInterface>(*ctx);
325199f368eSMatthias Springer YieldOp::attachInterface<YieldOpInterface>(*ctx);
32677eee579SRiver Riddle });
3275523c145SMatthias Springer }
328