xref: /llvm-project/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (revision 6aaa8f25b66dc1fef4e465f274ee40b82d632988)
149e37000SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
249e37000SMatthias Springer //
349e37000SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
449e37000SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
549e37000SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
649e37000SMatthias Springer //
749e37000SMatthias Springer //===----------------------------------------------------------------------===//
849e37000SMatthias Springer 
949e37000SMatthias Springer #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
108143307bSMatthias Springer 
11c37ed776SMatthias Springer #include "mlir/Dialect/Affine/IR/AffineOps.h"
12abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
1349e37000SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
14b3ebe3beSMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
152d5edc64SMatthias Springer #include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
16c1f0a15cSMatthias Springer #include "mlir/Dialect/Linalg/IR/Linalg.h"
1749e37000SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
188b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
1949e37000SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
208143307bSMatthias Springer #include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
219ee12f47SMatthias Springer #include "mlir/Dialect/Utils/StaticValueUtils.h"
2298e838a8SMax191 #include "mlir/IR/BuiltinTypeInterfaces.h"
2349e37000SMatthias Springer #include "mlir/IR/Dialect.h"
2449e37000SMatthias Springer #include "mlir/IR/Operation.h"
2549e37000SMatthias Springer 
2649e37000SMatthias Springer using namespace mlir;
2749e37000SMatthias Springer using namespace mlir::bufferization;
2849e37000SMatthias Springer using namespace mlir::tensor;
2949e37000SMatthias Springer 
3049e37000SMatthias Springer namespace mlir {
3149e37000SMatthias Springer namespace tensor {
3249e37000SMatthias Springer namespace {
3349e37000SMatthias Springer 
3449e37000SMatthias Springer struct CastOpInterface
3549e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<CastOpInterface,
3649e37000SMatthias Springer                                                     tensor::CastOp> {
3749e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
389597b16aSMatthias Springer                               const AnalysisState &state) const {
3949e37000SMatthias Springer     return false;
4049e37000SMatthias Springer   }
4149e37000SMatthias Springer 
4249e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
439597b16aSMatthias Springer                                const AnalysisState &state) const {
4449e37000SMatthias Springer     return false;
4549e37000SMatthias Springer   }
4649e37000SMatthias Springer 
47a02ad6c1SMatthias Springer   AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
489597b16aSMatthias Springer                                       const AnalysisState &state) const {
499fa6b350SMatthias Springer     return {{op->getResult(0), BufferRelation::Equivalent}};
5049e37000SMatthias Springer   }
5149e37000SMatthias Springer 
52b6ae3f88SMatthias Springer   FailureOr<BaseMemRefType>
53b6ae3f88SMatthias Springer   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
54878950b8SMatthias Springer                 SmallVector<Value> &invocationStack) const {
55b6ae3f88SMatthias Springer     auto castOp = cast<tensor::CastOp>(op);
56878950b8SMatthias Springer     auto maybeSrcBufferType = bufferization::getBufferType(
57878950b8SMatthias Springer         castOp.getSource(), options, invocationStack);
58b6ae3f88SMatthias Springer     if (failed(maybeSrcBufferType))
59b6ae3f88SMatthias Springer       return failure();
60b6ae3f88SMatthias Springer     Attribute memorySpace = maybeSrcBufferType->getMemorySpace();
61b6ae3f88SMatthias Springer 
62b6ae3f88SMatthias Springer     // Note: `getMemRefTypeWithFullyDynamicLayout` returns an unranked memref
63b6ae3f88SMatthias Springer     // type in case the input is an unranked tensor type.
64b6ae3f88SMatthias Springer 
65b6ae3f88SMatthias Springer     // Case 1: Casting an unranked tensor
665550c821STres Popp     if (isa<UnrankedTensorType>(castOp.getSource().getType())) {
67b6ae3f88SMatthias Springer       // When casting to a ranked tensor, we cannot infer any static offset or
68b6ae3f88SMatthias Springer       // strides from the source. Assume fully dynamic.
69b6ae3f88SMatthias Springer       return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
70b6ae3f88SMatthias Springer     }
71b6ae3f88SMatthias Springer 
72b6ae3f88SMatthias Springer     // Case 2: Casting to an unranked tensor type
735550c821STres Popp     if (isa<UnrankedTensorType>(castOp.getType())) {
74b6ae3f88SMatthias Springer       return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
75b6ae3f88SMatthias Springer     }
76b6ae3f88SMatthias Springer 
77b6ae3f88SMatthias Springer     // Case 3: Ranked tensor -> ranked tensor. The offsets and strides do not
78b6ae3f88SMatthias Springer     // change.
795550c821STres Popp     auto rankedResultType = cast<RankedTensorType>(castOp.getType());
80b6ae3f88SMatthias Springer     return MemRefType::get(
81b6ae3f88SMatthias Springer         rankedResultType.getShape(), rankedResultType.getElementType(),
8268f58812STres Popp         llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace);
83b6ae3f88SMatthias Springer   }
84b6ae3f88SMatthias Springer 
8549e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
86b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
8749e37000SMatthias Springer     auto castOp = cast<tensor::CastOp>(op);
8849e37000SMatthias Springer 
8949e37000SMatthias Springer     // The result buffer still has the old (pre-cast) type.
905d50f51cSMatthias Springer     FailureOr<Value> resultBuffer =
915d50f51cSMatthias Springer         getBuffer(rewriter, castOp.getSource(), options);
925d50f51cSMatthias Springer     if (failed(resultBuffer))
935d50f51cSMatthias Springer       return failure();
9449e37000SMatthias Springer 
95b6ae3f88SMatthias Springer     // Compute the new type.
96b6ae3f88SMatthias Springer     auto resultMemRefType =
97b6ae3f88SMatthias Springer         bufferization::getBufferType(castOp.getResult(), options);
98b6ae3f88SMatthias Springer     if (failed(resultMemRefType))
99b6ae3f88SMatthias Springer       return failure();
1006cd7b655SKai Sasaki     if (resultBuffer->getType() == *resultMemRefType) {
1016cd7b655SKai Sasaki       // This cast is a no-op.
1026cd7b655SKai Sasaki       replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
1036cd7b655SKai Sasaki       return success();
1046cd7b655SKai Sasaki     }
10549e37000SMatthias Springer 
10649e37000SMatthias Springer     // Replace the op with a memref.cast.
1075d50f51cSMatthias Springer     assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
108b6ae3f88SMatthias Springer                                              *resultMemRefType) &&
10949e37000SMatthias Springer            "CallOp::bufferize: cast incompatible");
110b6ae3f88SMatthias Springer     replaceOpWithNewBufferizedOp<memref::CastOp>(
111b6ae3f88SMatthias Springer         rewriter, op, *resultMemRefType, *resultBuffer);
11249e37000SMatthias Springer 
11349e37000SMatthias Springer     return success();
11449e37000SMatthias Springer   }
11549e37000SMatthias Springer };
11649e37000SMatthias Springer 
117e6f69161SMatthias Springer /// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
118e6f69161SMatthias Springer struct CollapseShapeOpInterface
119e6f69161SMatthias Springer     : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
120e6f69161SMatthias Springer                                                     tensor::CollapseShapeOp> {
121e6f69161SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
1229597b16aSMatthias Springer                               const AnalysisState &state) const {
123ea71d2d0SMatthias Springer     // tensor.collapse_shape may reallocate, at which point the source buffer is
124ea71d2d0SMatthias Springer     // copied. I.e., there will be a memory read side effect on the bufferized
125ea71d2d0SMatthias Springer     // source. This function conservatively returns "true" because whether a
126ea71d2d0SMatthias Springer     // copy will be created or not is not known at this point.
127ea71d2d0SMatthias Springer     return true;
128e6f69161SMatthias Springer   }
129e6f69161SMatthias Springer 
130e6f69161SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
1319597b16aSMatthias Springer                                const AnalysisState &state) const {
132e6f69161SMatthias Springer     return false;
133e6f69161SMatthias Springer   }
134e6f69161SMatthias Springer 
135a02ad6c1SMatthias Springer   AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
1369597b16aSMatthias Springer                                       const AnalysisState &state) const {
1371ac248e4SMatthias Springer     // TODO: CollapseShapeOp may allocate at runtime.
1389fa6b350SMatthias Springer     return {{op->getOpResult(0), BufferRelation::Equivalent}};
139e6f69161SMatthias Springer   }
140e6f69161SMatthias Springer 
14104ff6009SMatthias Springer   FailureOr<BaseMemRefType>
14204ff6009SMatthias Springer   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
143878950b8SMatthias Springer                 SmallVector<Value> &invocationStack) const {
14404ff6009SMatthias Springer     auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
14504ff6009SMatthias Springer     auto maybeSrcBufferType = bufferization::getBufferType(
146878950b8SMatthias Springer         collapseShapeOp.getSrc(), options, invocationStack);
14704ff6009SMatthias Springer     if (failed(maybeSrcBufferType))
14804ff6009SMatthias Springer       return failure();
14968f58812STres Popp     auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
15004ff6009SMatthias Springer     bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
15104ff6009SMatthias Springer         srcBufferType, collapseShapeOp.getReassociationIndices());
15204ff6009SMatthias Springer 
15304ff6009SMatthias Springer     if (!canBeCollapsed) {
15404ff6009SMatthias Springer       // If dims cannot be collapsed, this op bufferizes to a new allocation.
15504ff6009SMatthias Springer       RankedTensorType tensorResultType = collapseShapeOp.getResultType();
15604ff6009SMatthias Springer       return bufferization::getMemRefTypeWithStaticIdentityLayout(
1579bb63374SLei Zhang           tensorResultType, srcBufferType.getMemorySpace());
15804ff6009SMatthias Springer     }
15904ff6009SMatthias Springer 
16004ff6009SMatthias Springer     return memref::CollapseShapeOp::computeCollapsedType(
16104ff6009SMatthias Springer         srcBufferType, collapseShapeOp.getReassociationIndices());
16204ff6009SMatthias Springer   }
16304ff6009SMatthias Springer 
164e6f69161SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
165b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
166e6f69161SMatthias Springer     auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
16751df6238SMatthias Springer     RankedTensorType tensorResultType = collapseShapeOp.getResultType();
1685d50f51cSMatthias Springer     FailureOr<Value> maybeBuffer =
1695d50f51cSMatthias Springer         getBuffer(rewriter, collapseShapeOp.getSrc(), options);
1705d50f51cSMatthias Springer     if (failed(maybeBuffer))
1715d50f51cSMatthias Springer       return failure();
1725d50f51cSMatthias Springer     Value buffer = *maybeBuffer;
1735550c821STres Popp     auto bufferType = cast<MemRefType>(buffer.getType());
17451df6238SMatthias Springer 
17551df6238SMatthias Springer     if (tensorResultType.getRank() == 0) {
17651df6238SMatthias Springer       // 0-d collapses must go through a different op builder.
17773c0333dSMatthias Springer       MemRefType resultType;
17873c0333dSMatthias Springer 
17973c0333dSMatthias Springer       if (bufferType.getLayout().isIdentity()) {
18073c0333dSMatthias Springer         // Standard layout: result type has no offset.
18151df6238SMatthias Springer         MemRefLayoutAttrInterface layout;
18273c0333dSMatthias Springer         resultType = MemRefType::get({}, tensorResultType.getElementType(),
18351df6238SMatthias Springer                                      layout, bufferType.getMemorySpace());
18473c0333dSMatthias Springer       } else {
18573c0333dSMatthias Springer         // Source memref has a layout map: result type has the same offset as
18673c0333dSMatthias Springer         // the source type.
18773c0333dSMatthias Springer         SmallVector<int64_t> strides;
18873c0333dSMatthias Springer         int64_t offset;
189*6aaa8f25SMatthias Springer         if (failed(bufferType.getStridesAndOffset(strides, offset)))
19073c0333dSMatthias Springer           return failure();
19146b90a7bSAlex Zinenko         resultType = MemRefType::get(
19246b90a7bSAlex Zinenko             {}, tensorResultType.getElementType(),
19346b90a7bSAlex Zinenko             StridedLayoutAttr::get(op->getContext(), offset, {}),
19446b90a7bSAlex Zinenko             bufferType.getMemorySpace());
19573c0333dSMatthias Springer       }
19673c0333dSMatthias Springer 
197e6f69161SMatthias Springer       replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
1988df54a6aSJacques Pienaar           rewriter, op, resultType, buffer, collapseShapeOp.getReassociation());
199e6f69161SMatthias Springer       return success();
200e6f69161SMatthias Springer     }
20151df6238SMatthias Springer 
202d7a9bf91SMatthias Springer     // If the dims are not collapsible (due to an incompatible source layout
203d7a9bf91SMatthias Springer     // map), force an out-of-place bufferization, i.e., a buffer copy. This
204d7a9bf91SMatthias Springer     // newly allocated buffer will have no layout map and thus be collapsible.
205a74e5a89SAdrian Kuegel     bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
206d7a9bf91SMatthias Springer         bufferType, collapseShapeOp.getReassociationIndices());
207b3ebe3beSMatthias Springer     if (!canBeCollapsed) {
208b3ebe3beSMatthias Springer       // TODO: Create alloc_tensor ops during TensorCopyInsertion.
209b55d55ecSMatthias Springer       AnalysisState analysisState(options);
21045b995cdSMatthias Springer       FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
2116bf043e7SMartin Erhart           rewriter, op->getLoc(), collapseShapeOp.getSrc(), options);
21245b995cdSMatthias Springer       if (failed(tensorAlloc))
21345b995cdSMatthias Springer         return failure();
214b3ebe3beSMatthias Springer       auto memrefType =
215b3ebe3beSMatthias Springer           MemRefType::get(collapseShapeOp.getSrcType().getShape(),
216b3ebe3beSMatthias Springer                           collapseShapeOp.getSrcType().getElementType(),
2179bb63374SLei Zhang                           AffineMap(), bufferType.getMemorySpace());
218b3ebe3beSMatthias Springer       buffer = rewriter.create<bufferization::ToMemrefOp>(
21945b995cdSMatthias Springer           op->getLoc(), memrefType, *tensorAlloc);
220b3ebe3beSMatthias Springer     }
221d7a9bf91SMatthias Springer 
22251df6238SMatthias Springer     // Result type is inferred by the builder.
22351df6238SMatthias Springer     replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
224b3ebe3beSMatthias Springer         rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
22551df6238SMatthias Springer     return success();
22651df6238SMatthias Springer   }
227e6f69161SMatthias Springer };
228e6f69161SMatthias Springer 
22949e37000SMatthias Springer /// Bufferization of tensor.dim. Replace with memref.dim.
23049e37000SMatthias Springer struct DimOpInterface
23149e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<DimOpInterface,
23249e37000SMatthias Springer                                                     tensor::DimOp> {
23349e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
2349597b16aSMatthias Springer                               const AnalysisState &state) const {
235e5dc99e6SMatthias Springer     // The op reads the tensor's metadata but not its contents.
236e5dc99e6SMatthias Springer     return false;
23749e37000SMatthias Springer   }
23849e37000SMatthias Springer 
23949e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
2409597b16aSMatthias Springer                                const AnalysisState &state) const {
24149e37000SMatthias Springer     return false;
24249e37000SMatthias Springer   }
24349e37000SMatthias Springer 
244a02ad6c1SMatthias Springer   AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
2459597b16aSMatthias Springer                                       const AnalysisState &state) const {
246585a8a32SMatthias Springer     return {};
24749e37000SMatthias Springer   }
24849e37000SMatthias Springer 
24949e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
250b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
25149e37000SMatthias Springer     auto dimOp = cast<tensor::DimOp>(op);
2525d50f51cSMatthias Springer     FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options);
2535d50f51cSMatthias Springer     if (failed(v))
2545d50f51cSMatthias Springer       return failure();
2555d50f51cSMatthias Springer     replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
256136d746eSJacques Pienaar                                                 dimOp.getIndex());
25749e37000SMatthias Springer     return success();
25849e37000SMatthias Springer   }
25949e37000SMatthias Springer };
26049e37000SMatthias Springer 
261464dfebaSMatthias Springer /// Bufferization of "tensor.empty". Replace with "bufferization.alloc_tensor".
262be630f07SMatthias Springer struct EmptyOpInterface
263be630f07SMatthias Springer     : public BufferizableOpInterface::ExternalModel<EmptyOpInterface,
264be630f07SMatthias Springer                                                     tensor::EmptyOp> {
26558678d3bSMatthias Springer   bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
26658678d3bSMatthias Springer 
267330372f2SMatthias Springer   bool resultBufferizesToMemoryWrite(Operation *op, OpResult opResult,
268330372f2SMatthias Springer                                      const AnalysisState &state) const {
269330372f2SMatthias Springer     // The returned tensor does not have specified contents.
270330372f2SMatthias Springer     return false;
271330372f2SMatthias Springer   }
272330372f2SMatthias Springer 
273be630f07SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
274be630f07SMatthias Springer                           const BufferizationOptions &options) const {
275464dfebaSMatthias Springer     auto emptyOp = cast<tensor::EmptyOp>(op);
276464dfebaSMatthias Springer 
277464dfebaSMatthias Springer     // Optimization: Fold away the op if it has no uses.
278ef4f5357SMatthias Springer     if (op->getUses().empty()) {
279ef4f5357SMatthias Springer       rewriter.eraseOp(op);
280ef4f5357SMatthias Springer       return success();
281ef4f5357SMatthias Springer     }
282ef4f5357SMatthias Springer 
283464dfebaSMatthias Springer     // Allocate a tensor. This emits a "bufferization.alloc_tensor" op.
284464dfebaSMatthias Springer     FailureOr<Value> allocTensor = allocateTensorForShapedValue(
285464dfebaSMatthias Springer         rewriter, op->getLoc(), emptyOp.getResult(), options, /*copy=*/false);
286464dfebaSMatthias Springer     if (failed(allocTensor))
287464dfebaSMatthias Springer       return failure();
288464dfebaSMatthias Springer     rewriter.replaceOp(op, *allocTensor);
289464dfebaSMatthias Springer     return success();
290be630f07SMatthias Springer   }
291be630f07SMatthias Springer };
292be630f07SMatthias Springer 
293e6f69161SMatthias Springer /// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
294e6f69161SMatthias Springer struct ExpandShapeOpInterface
295e6f69161SMatthias Springer     : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
296e6f69161SMatthias Springer                                                     tensor::ExpandShapeOp> {
297e6f69161SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
2989597b16aSMatthias Springer                               const AnalysisState &state) const {
299ea71d2d0SMatthias Springer     // In contrast to tensor.collapse_shape, this op can always be bufferized
300ea71d2d0SMatthias Springer     // without a copy.
301e6f69161SMatthias Springer     return false;
302e6f69161SMatthias Springer   }
303e6f69161SMatthias Springer 
304e6f69161SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
3059597b16aSMatthias Springer                                const AnalysisState &state) const {
306e6f69161SMatthias Springer     return false;
307e6f69161SMatthias Springer   }
308e6f69161SMatthias Springer 
309a02ad6c1SMatthias Springer   AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
3109597b16aSMatthias Springer                                       const AnalysisState &state) const {
3119fa6b350SMatthias Springer     return {{op->getOpResult(0), BufferRelation::Equivalent}};
312e6f69161SMatthias Springer   }
313e6f69161SMatthias Springer 
31404ff6009SMatthias Springer   FailureOr<BaseMemRefType>
31504ff6009SMatthias Springer   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
316878950b8SMatthias Springer                 SmallVector<Value> &invocationStack) const {
31704ff6009SMatthias Springer     auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
31804ff6009SMatthias Springer     auto maybeSrcBufferType = bufferization::getBufferType(
319878950b8SMatthias Springer         expandShapeOp.getSrc(), options, invocationStack);
32004ff6009SMatthias Springer     if (failed(maybeSrcBufferType))
32104ff6009SMatthias Springer       return failure();
32268f58812STres Popp     auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
32304ff6009SMatthias Springer     auto maybeResultType = memref::ExpandShapeOp::computeExpandedType(
32404ff6009SMatthias Springer         srcBufferType, expandShapeOp.getResultType().getShape(),
32504ff6009SMatthias Springer         expandShapeOp.getReassociationIndices());
32604ff6009SMatthias Springer     if (failed(maybeResultType))
32704ff6009SMatthias Springer       return failure();
32804ff6009SMatthias Springer     return *maybeResultType;
32904ff6009SMatthias Springer   }
33004ff6009SMatthias Springer 
331e6f69161SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
332b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
333e6f69161SMatthias Springer     auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
33451df6238SMatthias Springer     auto tensorResultType = expandShapeOp.getResultType();
3355d50f51cSMatthias Springer     FailureOr<Value> buffer =
3365d50f51cSMatthias Springer         getBuffer(rewriter, expandShapeOp.getSrc(), options);
3375d50f51cSMatthias Springer     if (failed(buffer))
3385d50f51cSMatthias Springer       return failure();
33951df6238SMatthias Springer 
34051df6238SMatthias Springer     // Memref result type is inferred by the builder based on reassociation
34151df6238SMatthias Springer     // indices and result shape.
34297069a86SGaurav Shukla     // TODO: Instead of inferring the output shape argument of
34397069a86SGaurav Shukla     // memref.expand_shape op, use output_shape argument of tensor.expand_shape
34497069a86SGaurav Shukla     // op.
345e6f69161SMatthias Springer     replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
3465d50f51cSMatthias Springer         rewriter, op, tensorResultType.getShape(), *buffer,
34751df6238SMatthias Springer         expandShapeOp.getReassociationIndices());
348e6f69161SMatthias Springer     return success();
349e6f69161SMatthias Springer   }
350e6f69161SMatthias Springer };
351e6f69161SMatthias Springer 
35249e37000SMatthias Springer /// Bufferization of tensor.extract_slice. Replace with memref.subview.
35349e37000SMatthias Springer struct ExtractSliceOpInterface
35449e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
35549e37000SMatthias Springer                                                     tensor::ExtractSliceOp> {
35649e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
3579597b16aSMatthias Springer                               const AnalysisState &state) const {
35849e37000SMatthias Springer     return false;
35949e37000SMatthias Springer   }
36049e37000SMatthias Springer 
36149e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
3629597b16aSMatthias Springer                                const AnalysisState &state) const {
36349e37000SMatthias Springer     return false;
36449e37000SMatthias Springer   }
36549e37000SMatthias Springer 
366a02ad6c1SMatthias Springer   AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
3679597b16aSMatthias Springer                                       const AnalysisState &state) const {
3689fa6b350SMatthias Springer     return {{op->getOpResult(0), BufferRelation::Unknown}};
36949e37000SMatthias Springer   }
37049e37000SMatthias Springer 
37149e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
372b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
37349e37000SMatthias Springer     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
3746c3c5f80SMatthias Springer     SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
3756c3c5f80SMatthias Springer     SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
3766c3c5f80SMatthias Springer     SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
37749e37000SMatthias Springer     Location loc = extractSliceOp.getLoc();
378d7a9bf91SMatthias Springer 
3796c3c5f80SMatthias Springer     // Get source buffer.
3805d50f51cSMatthias Springer     FailureOr<Value> srcMemref =
3815d50f51cSMatthias Springer         getBuffer(rewriter, extractSliceOp.getSource(), options);
3825d50f51cSMatthias Springer     if (failed(srcMemref))
3835d50f51cSMatthias Springer       return failure();
38449e37000SMatthias Springer 
3856c3c5f80SMatthias Springer     // Take a subview of the source buffer.
386111c9196SMatthias Springer     auto resultMemrefType =
387123c4b02SMatthias Springer         bufferization::getBufferType(extractSliceOp.getResult(), options);
388111c9196SMatthias Springer     if (failed(resultMemrefType))
389111c9196SMatthias Springer       return failure();
39049e37000SMatthias Springer     Value subView = rewriter.create<memref::SubViewOp>(
391d69e9491Sdonald chen         loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref,
392d69e9491Sdonald chen         mixedOffsets, mixedSizes, mixedStrides);
39349e37000SMatthias Springer 
39449e37000SMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, subView);
39549e37000SMatthias Springer     return success();
39649e37000SMatthias Springer   }
397111c9196SMatthias Springer 
398111c9196SMatthias Springer   FailureOr<BaseMemRefType>
399123c4b02SMatthias Springer   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
400878950b8SMatthias Springer                 SmallVector<Value> &invocationStack) const {
401111c9196SMatthias Springer     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
402111c9196SMatthias Springer     assert(value == extractSliceOp.getResult() && "invalid value");
403123c4b02SMatthias Springer     auto srcMemrefType = bufferization::getBufferType(
404878950b8SMatthias Springer         extractSliceOp.getSource(), options, invocationStack);
405111c9196SMatthias Springer     if (failed(srcMemrefType))
406111c9196SMatthias Springer       return failure();
407111c9196SMatthias Springer     SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
408111c9196SMatthias Springer     SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
409111c9196SMatthias Springer     SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
4105550c821STres Popp     return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
411d69e9491Sdonald chen         extractSliceOp.getType().getShape(),
412d69e9491Sdonald chen         llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
413d69e9491Sdonald chen         mixedStrides));
414111c9196SMatthias Springer   }
41549e37000SMatthias Springer };
41649e37000SMatthias Springer 
41749e37000SMatthias Springer /// Bufferization of tensor.extract. Replace with memref.load.
41849e37000SMatthias Springer struct ExtractOpInterface
41949e37000SMatthias Springer     : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
42049e37000SMatthias Springer                                                     tensor::ExtractOp> {
42149e37000SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
4229597b16aSMatthias Springer                               const AnalysisState &state) const {
42349e37000SMatthias Springer     return true;
42449e37000SMatthias Springer   }
42549e37000SMatthias Springer 
42649e37000SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
4279597b16aSMatthias Springer                                const AnalysisState &state) const {
42849e37000SMatthias Springer     return false;
42949e37000SMatthias Springer   }
43049e37000SMatthias Springer 
431a02ad6c1SMatthias Springer   AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
4329597b16aSMatthias Springer                                       const AnalysisState &state) const {
433585a8a32SMatthias Springer     return {};
43449e37000SMatthias Springer   }
43549e37000SMatthias Springer 
43649e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
437b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
43849e37000SMatthias Springer     auto extractOp = cast<tensor::ExtractOp>(op);
4395d50f51cSMatthias Springer     FailureOr<Value> srcMemref =
4405d50f51cSMatthias Springer         getBuffer(rewriter, extractOp.getTensor(), options);
4415d50f51cSMatthias Springer     if (failed(srcMemref))
4425d50f51cSMatthias Springer       return failure();
4435d50f51cSMatthias Springer     replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
444136d746eSJacques Pienaar                                                  extractOp.getIndices());
44549e37000SMatthias Springer     return success();
44649e37000SMatthias Springer   }
44749e37000SMatthias Springer };
44849e37000SMatthias Springer 
449d581c94dSMatthias Springer // Implements backtracking to traverse indices of the output buffer while
450d581c94dSMatthias Springer // iterating over op.elements().
451d581c94dSMatthias Springer static void createStores(RewriterBase &rewriter, Location loc, int dim,
452d581c94dSMatthias Springer                          Value buffer, ArrayRef<int64_t> shape,
453d581c94dSMatthias Springer                          ArrayRef<Value> constants,
454d581c94dSMatthias Springer                          OperandRange::iterator &elementIt,
455d581c94dSMatthias Springer                          SmallVectorImpl<Value> &indices) {
456d581c94dSMatthias Springer   if (dim == static_cast<int>(shape.size()) - 1) {
457d581c94dSMatthias Springer     for (int i = 0; i < shape.back(); ++i) {
458d581c94dSMatthias Springer       indices.back() = constants[i];
459d581c94dSMatthias Springer       rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
460d581c94dSMatthias Springer       ++elementIt;
461d581c94dSMatthias Springer     }
462d581c94dSMatthias Springer     return;
463d581c94dSMatthias Springer   }
464d581c94dSMatthias Springer   for (int i = 0; i < shape[dim]; ++i) {
465d581c94dSMatthias Springer     indices[dim] = constants[i];
466d581c94dSMatthias Springer     createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
467d581c94dSMatthias Springer                  indices);
468d581c94dSMatthias Springer   }
469d581c94dSMatthias Springer }
470d581c94dSMatthias Springer 
471d581c94dSMatthias Springer /// Bufferization of tensor.from_elements.
472d581c94dSMatthias Springer struct FromElementsOpInterface
473d581c94dSMatthias Springer     : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
474d581c94dSMatthias Springer                                                     tensor::FromElementsOp> {
475664ffa46SMatthias Springer 
476a02ad6c1SMatthias Springer   bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
477664ffa46SMatthias Springer 
478d581c94dSMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
479b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
480d581c94dSMatthias Springer     auto fromElementsOp = cast<tensor::FromElementsOp>(op);
481067d2779Sian Bearman     auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
482d581c94dSMatthias Springer 
483d581c94dSMatthias Springer     // Allocate a buffer for the result.
484d581c94dSMatthias Springer     Location loc = op->getLoc();
485d581c94dSMatthias Springer     auto shape = tensorType.getShape();
486b3ebe3beSMatthias Springer     // TODO: Create alloc_tensor ops during TensorCopyInsertion.
4876bf043e7SMartin Erhart     FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
4886bf043e7SMartin Erhart         rewriter, loc, fromElementsOp.getResult(), options,
489b3ebe3beSMatthias Springer         /*copy=*/false);
49045b995cdSMatthias Springer     if (failed(tensorAlloc))
49145b995cdSMatthias Springer       return failure();
492ced2fc78SChristopher Bate     FailureOr<BaseMemRefType> memrefType =
493ced2fc78SChristopher Bate         bufferization::getBufferType(*tensorAlloc, options);
494ced2fc78SChristopher Bate     if (failed(memrefType))
495ced2fc78SChristopher Bate       return failure();
496b3ebe3beSMatthias Springer     Value buffer = rewriter.create<bufferization::ToMemrefOp>(
497ced2fc78SChristopher Bate         op->getLoc(), *memrefType, *tensorAlloc);
498d581c94dSMatthias Springer 
499d581c94dSMatthias Springer     // Case: tensor<0xelem_type>.
5008df54a6aSJacques Pienaar     if (fromElementsOp.getElements().empty()) {
501d581c94dSMatthias Springer       replaceOpWithBufferizedValues(rewriter, op, buffer);
502d581c94dSMatthias Springer       return success();
503d581c94dSMatthias Springer     }
504d581c94dSMatthias Springer 
505d581c94dSMatthias Springer     // Case: tensor<elem_type>.
506d581c94dSMatthias Springer     if (shape.empty()) {
5078df54a6aSJacques Pienaar       rewriter.create<memref::StoreOp>(
5088df54a6aSJacques Pienaar           loc, fromElementsOp.getElements().front(), buffer);
509d581c94dSMatthias Springer       replaceOpWithBufferizedValues(rewriter, op, buffer);
510d581c94dSMatthias Springer       return success();
511d581c94dSMatthias Springer     }
512d581c94dSMatthias Springer 
513d581c94dSMatthias Springer     // Create constants for the range of possible indices [0, max{shape_i}).
514fab2bb8bSJustin Lebar     auto maxDim = *llvm::max_element(shape);
515d581c94dSMatthias Springer     SmallVector<Value, 2> constants;
516d581c94dSMatthias Springer     constants.reserve(maxDim);
517d581c94dSMatthias Springer     for (int i = 0; i < maxDim; ++i)
518d581c94dSMatthias Springer       constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
519d581c94dSMatthias Springer 
520d581c94dSMatthias Springer     // Traverse all `elements` and create `memref.store` ops.
5218df54a6aSJacques Pienaar     auto elementIt = fromElementsOp.getElements().begin();
522d581c94dSMatthias Springer     SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
523d581c94dSMatthias Springer     createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
524d581c94dSMatthias Springer                  indices);
525d581c94dSMatthias Springer 
526d581c94dSMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, buffer);
527664ffa46SMatthias Springer 
528d581c94dSMatthias Springer     return success();
529d581c94dSMatthias Springer   }
530d581c94dSMatthias Springer };
531d581c94dSMatthias Springer 
532c1f0a15cSMatthias Springer /// Lower the body of a tensor.generate like op (one index-typed bbArg per dim).
533c1f0a15cSMatthias Springer /// Such ops are lowered to linalg.map with the given tensor as a destination.
534c1f0a15cSMatthias Springer ///
535c1f0a15cSMatthias Springer /// Example:
536c1f0a15cSMatthias Springer /// ```
537c1f0a15cSMatthias Springer /// %r = tensor.generate %x, %y {
538c1f0a15cSMatthias Springer ///   ^bb0(%arg0: index, %arg1: index):
539c1f0a15cSMatthias Springer ///   %0 = "some_op"(%arg0, %arg1) : (index, index) -> (index)
540c1f0a15cSMatthias Springer ///   tensor.yield %0 : index
541c1f0a15cSMatthias Springer /// } : tensor<?x?xindex>
542c1f0a15cSMatthias Springer /// ```
543c1f0a15cSMatthias Springer ///
544c1f0a15cSMatthias Springer /// Is lowered to:
545c1f0a15cSMatthias Springer /// ```
546c1f0a15cSMatthias Springer /// linalg.map ins() outs(%dest) {
547c1f0a15cSMatthias Springer ///   %d0 = linalg.index 0 : index
548c1f0a15cSMatthias Springer ///   %d1 = linalg.index 1 : index
549c1f0a15cSMatthias Springer ///   %0 = "some_op"(%d0, %d1) : (index, index) -> (index)
550c1f0a15cSMatthias Springer ///   linalg.yield %0 : index
551c1f0a15cSMatthias Springer /// }
552c1f0a15cSMatthias Springer /// ```
553c1f0a15cSMatthias Springer static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc,
554c1f0a15cSMatthias Springer                                      Value tensorDestination,
555c1f0a15cSMatthias Springer                                      ValueRange dynamicSizes,
556c1f0a15cSMatthias Springer                                      Region &generateBody) {
557c1f0a15cSMatthias Springer   assert(generateBody.hasOneBlock() && "expected body with single block");
5585550c821STres Popp   auto tensorType = cast<RankedTensorType>(tensorDestination.getType());
559c1f0a15cSMatthias Springer   assert(generateBody.getNumArguments() == tensorType.getRank() &&
560c1f0a15cSMatthias Springer          "rank mismatch");
561c1f0a15cSMatthias Springer 
562c1f0a15cSMatthias Springer   // Create linalg::MapOp.
563c1f0a15cSMatthias Springer   OpBuilder::InsertionGuard g(rewriter);
564c1f0a15cSMatthias Springer   auto linalgOp =
565c1f0a15cSMatthias Springer       rewriter.create<linalg::MapOp>(loc, tensorType, /*inputs=*/ValueRange(),
566c1f0a15cSMatthias Springer                                      /*init=*/tensorDestination);
567c1f0a15cSMatthias Springer   Block &linalgBody = linalgOp.getMapper().emplaceBlock();
568c1f0a15cSMatthias Springer 
569c1f0a15cSMatthias Springer   // Create linalg::IndexOps.
570c1f0a15cSMatthias Springer   rewriter.setInsertionPointToStart(&linalgBody);
571c1f0a15cSMatthias Springer   SmallVector<Value> indices;
572c1f0a15cSMatthias Springer   for (int64_t dim = 0; dim < tensorType.getRank(); ++dim)
573c1f0a15cSMatthias Springer     indices.push_back(rewriter.create<linalg::IndexOp>(loc, dim));
574c1f0a15cSMatthias Springer 
575c1f0a15cSMatthias Springer   // Move over body.
576c1f0a15cSMatthias Springer   rewriter.mergeBlocks(&generateBody.front(), &linalgBody, indices);
577c1f0a15cSMatthias Springer   auto yieldOp = cast<tensor::YieldOp>(linalgBody.getTerminator());
578c1f0a15cSMatthias Springer   rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
579c1f0a15cSMatthias Springer 
580c1f0a15cSMatthias Springer   return linalgOp.getResult()[0];
581c1f0a15cSMatthias Springer }
582c1f0a15cSMatthias Springer 
58371bbb78bSMatthias Springer /// Bufferization of tensor.generate.
58471bbb78bSMatthias Springer struct GenerateOpInterface
58571bbb78bSMatthias Springer     : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
58671bbb78bSMatthias Springer                                                     tensor::GenerateOp> {
587664ffa46SMatthias Springer 
588a02ad6c1SMatthias Springer   bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
589664ffa46SMatthias Springer 
59071bbb78bSMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
591b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
59271bbb78bSMatthias Springer     auto generateOp = cast<tensor::GenerateOp>(op);
593c0b0b6a0SMatthias Springer 
594067d2779Sian Bearman     auto type = generateOp.getResult().getType();
595067d2779Sian Bearman 
596c0b0b6a0SMatthias Springer     // TODO: Implement memory space for this op.
597067d2779Sian Bearman     if (options.defaultMemorySpaceFn(type) != Attribute())
598c0b0b6a0SMatthias Springer       return op->emitError("memory space not implemented yet");
599c0b0b6a0SMatthias Springer 
60071bbb78bSMatthias Springer     // Allocate memory.
60171bbb78bSMatthias Springer     Location loc = op->getLoc();
6026bf043e7SMartin Erhart     FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
6036bf043e7SMartin Erhart         rewriter, loc, generateOp.getResult(), options,
604b3ebe3beSMatthias Springer         /*copy=*/false);
60545b995cdSMatthias Springer     if (failed(tensorAlloc))
60645b995cdSMatthias Springer       return failure();
60771bbb78bSMatthias Springer 
608c1f0a15cSMatthias Springer     Value result = lowerGenerateLikeOpBody(rewriter, loc, *tensorAlloc,
609c1f0a15cSMatthias Springer                                            generateOp.getDynamicExtents(),
610c1f0a15cSMatthias Springer                                            generateOp.getBody());
611c1f0a15cSMatthias Springer     rewriter.replaceOp(generateOp, result);
612664ffa46SMatthias Springer 
61371bbb78bSMatthias Springer     return success();
61471bbb78bSMatthias Springer   }
61571bbb78bSMatthias Springer };
61671bbb78bSMatthias Springer 
61749e37000SMatthias Springer /// Bufferization of tensor.insert. Replace with memref.store.
6182d5edc64SMatthias Springer ///
6192d5edc64SMatthias Springer /// Note: DstBufferizableOpInterfaceExternalModel provides many default method
6202d5edc64SMatthias Springer /// implementations for DestinationStyle ops.
62149e37000SMatthias Springer struct InsertOpInterface
6222d5edc64SMatthias Springer     : public DstBufferizableOpInterfaceExternalModel<InsertOpInterface,
62349e37000SMatthias Springer                                                      tensor::InsertOp> {
62449e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
625b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
62649e37000SMatthias Springer     auto insertOp = cast<tensor::InsertOp>(op);
6275d50f51cSMatthias Springer     FailureOr<Value> destMemref =
6285d50f51cSMatthias Springer         getBuffer(rewriter, insertOp.getDest(), options);
6295d50f51cSMatthias Springer     if (failed(destMemref))
6305d50f51cSMatthias Springer       return failure();
6318df54a6aSJacques Pienaar     rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
6325d50f51cSMatthias Springer                                      *destMemref, insertOp.getIndices());
6335d50f51cSMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, *destMemref);
63449e37000SMatthias Springer     return success();
63549e37000SMatthias Springer   }
63649e37000SMatthias Springer };
63749e37000SMatthias Springer 
63898e838a8SMax191 template <typename InsertOpTy>
63998e838a8SMax191 static bool insertSliceOpRequiresRead(InsertOpTy insertSliceOp,
64098e838a8SMax191                                       OpOperand &opOperand) {
64113593dc9SMatthias Springer   // The source is always read.
64255585043SMatthias Springer   if (opOperand == insertSliceOp.getSourceMutable())
64313593dc9SMatthias Springer     return true;
64413593dc9SMatthias Springer 
64513593dc9SMatthias Springer   // For the destination, it depends...
64655585043SMatthias Springer   assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");
64713593dc9SMatthias Springer 
64813593dc9SMatthias Springer   // Dest is not read if it is entirely overwritten. E.g.:
64913593dc9SMatthias Springer   // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
65013593dc9SMatthias Springer   bool allOffsetsZero =
65198e838a8SMax191       llvm::all_of(insertSliceOp.getMixedOffsets(), isZeroIndex);
65298e838a8SMax191   RankedTensorType destType = insertSliceOp.getDestType();
65398e838a8SMax191   bool sizesMatchDestSizes =
65498e838a8SMax191       areConstantIntValues(insertSliceOp.getMixedSizes(), destType.getShape());
65513593dc9SMatthias Springer   bool allStridesOne =
65698e838a8SMax191       areAllConstantIntValue(insertSliceOp.getMixedStrides(), 1);
65713593dc9SMatthias Springer   return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
65813593dc9SMatthias Springer }
65913593dc9SMatthias Springer 
66098e838a8SMax191 /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
66198e838a8SMax191 /// certain circumstances, this op can also be a no-op.
66298e838a8SMax191 ///
66398e838a8SMax191 /// Note: DstBufferizableOpInterfaceExternalModel provides many default method
66498e838a8SMax191 /// implementations for DestinationStyle ops.
66598e838a8SMax191 struct InsertSliceOpInterface
66698e838a8SMax191     : public DstBufferizableOpInterfaceExternalModel<InsertSliceOpInterface,
66798e838a8SMax191                                                      tensor::InsertSliceOp> {
66898e838a8SMax191   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
66998e838a8SMax191                               const AnalysisState &state) const {
67098e838a8SMax191     return insertSliceOpRequiresRead(cast<tensor::InsertSliceOp>(op),
67198e838a8SMax191                                      opOperand);
67298e838a8SMax191   }
67398e838a8SMax191 
67449e37000SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
675b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
67649e37000SMatthias Springer     // insert_slice ops arise from tiling and bufferizing them out-of-place is
67749e37000SMatthias Springer     // generally a deal breaker. When used with loops, this ends up cloning the
67849e37000SMatthias Springer     // whole tensor on every single iteration and is a symptom of a
67949e37000SMatthias Springer     // catastrophically bad scheduling decision.
68049e37000SMatthias Springer     // TODO: be very loud about it or even consider failing the pass.
68149e37000SMatthias Springer     auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
6826c3c5f80SMatthias Springer     SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
6836c3c5f80SMatthias Springer     SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
6846c3c5f80SMatthias Springer     SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
68549e37000SMatthias Springer     Location loc = insertSliceOp.getLoc();
6866c3c5f80SMatthias Springer 
6876c3c5f80SMatthias Springer     // Get destination buffer.
6885d50f51cSMatthias Springer     FailureOr<Value> dstMemref =
6895d50f51cSMatthias Springer         getBuffer(rewriter, insertSliceOp.getDest(), options);
6905d50f51cSMatthias Springer     if (failed(dstMemref))
6915d50f51cSMatthias Springer       return failure();
69249e37000SMatthias Springer 
6936c3c5f80SMatthias Springer     // Take a subview of the destination buffer.
6945550c821STres Popp     auto dstMemrefType = cast<MemRefType>(dstMemref->getType());
69549e37000SMatthias Springer     auto subviewMemRefType =
6965550c821STres Popp         cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
6976c3c5f80SMatthias Springer             insertSliceOp.getSourceType().getShape(), dstMemrefType,
6985550c821STres Popp             mixedOffsets, mixedSizes, mixedStrides));
69949e37000SMatthias Springer     Value subView = rewriter.create<memref::SubViewOp>(
7005d50f51cSMatthias Springer         loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
70149e37000SMatthias Springer         mixedStrides);
70249e37000SMatthias Springer 
70349e37000SMatthias Springer     // Copy tensor. If this tensor.insert_slice has a matching
70449e37000SMatthias Springer     // tensor.extract_slice, the copy operation will eventually fold away.
7055d50f51cSMatthias Springer     FailureOr<Value> srcMemref =
7065d50f51cSMatthias Springer         getBuffer(rewriter, insertSliceOp.getSource(), options);
7075d50f51cSMatthias Springer     if (failed(srcMemref))
7085d50f51cSMatthias Springer       return failure();
7095d50f51cSMatthias Springer     if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView)))
71049e37000SMatthias Springer       return failure();
71149e37000SMatthias Springer 
7125d50f51cSMatthias Springer     replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
71349e37000SMatthias Springer     return success();
71449e37000SMatthias Springer   }
71549e37000SMatthias Springer };
71649e37000SMatthias Springer 
71709dfb441SMatthias Springer /// Bufferization of tensor.pad. Replace with bufferization.alloc_tensor +
71809dfb441SMatthias Springer /// linalg.map + insert_slice.
7199ee12f47SMatthias Springer /// For best performance, vectorize before bufferization (better performance in
7209ee12f47SMatthias Springer /// case of padding with a constant).
7219ee12f47SMatthias Springer struct PadOpInterface
7229ee12f47SMatthias Springer     : public BufferizableOpInterface::ExternalModel<PadOpInterface,
7239ee12f47SMatthias Springer                                                     tensor::PadOp> {
724a02ad6c1SMatthias Springer   bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
7259ee12f47SMatthias Springer 
7269ee12f47SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
7279ee12f47SMatthias Springer                               const AnalysisState &state) const {
7289ee12f47SMatthias Springer     return true;
7299ee12f47SMatthias Springer   }
7309ee12f47SMatthias Springer 
7319ee12f47SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
7329ee12f47SMatthias Springer                                const AnalysisState &state) const {
7339ee12f47SMatthias Springer     return false;
7349ee12f47SMatthias Springer   }
7359ee12f47SMatthias Springer 
736a02ad6c1SMatthias Springer   AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
7379ee12f47SMatthias Springer                                       const AnalysisState &state) const {
7389ee12f47SMatthias Springer     return {};
7399ee12f47SMatthias Springer   }
7409ee12f47SMatthias Springer 
74109dfb441SMatthias Springer   FailureOr<BaseMemRefType>
74209dfb441SMatthias Springer   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
743878950b8SMatthias Springer                 SmallVector<Value> &invocationStack) const {
74409dfb441SMatthias Springer     // Infer memory space from the source tensor.
74509dfb441SMatthias Springer     auto padOp = cast<tensor::PadOp>(op);
746878950b8SMatthias Springer     auto maybeSrcBufferType = bufferization::getBufferType(
747878950b8SMatthias Springer         padOp.getSource(), options, invocationStack);
74809dfb441SMatthias Springer     if (failed(maybeSrcBufferType))
74909dfb441SMatthias Springer       return failure();
75009dfb441SMatthias Springer     MemRefLayoutAttrInterface layout;
75109dfb441SMatthias Springer     return MemRefType::get(padOp.getResultType().getShape(),
75209dfb441SMatthias Springer                            padOp.getResultType().getElementType(), layout,
75309dfb441SMatthias Springer                            maybeSrcBufferType->getMemorySpace());
75409dfb441SMatthias Springer   }
75509dfb441SMatthias Springer 
7569ee12f47SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
7579ee12f47SMatthias Springer                           const BufferizationOptions &options) const {
7589ee12f47SMatthias Springer     auto padOp = cast<tensor::PadOp>(op);
7599ee12f47SMatthias Springer     Location loc = padOp.getLoc();
7609ee12f47SMatthias Springer     RankedTensorType resultType = padOp.getResultType();
7619ee12f47SMatthias Springer     RankedTensorType srcType = padOp.getSourceType();
7629ee12f47SMatthias Springer 
7639ee12f47SMatthias Springer     auto toValue = [&](OpFoldResult ofr) {
764fecf1397SKazu Hirata       if (auto value = dyn_cast<Value>(ofr))
765fecf1397SKazu Hirata         return value;
7669ee12f47SMatthias Springer       return rewriter
7679ee12f47SMatthias Springer           .create<arith::ConstantIndexOp>(loc, *getConstantIntValue(ofr))
7689ee12f47SMatthias Springer           .getResult();
7699ee12f47SMatthias Springer     };
7709ee12f47SMatthias Springer 
7719ee12f47SMatthias Springer     // Compute dynamic result dimensions.
7729ee12f47SMatthias Springer     SmallVector<OpFoldResult> mixedLowPad = padOp.getMixedLowPad();
7739ee12f47SMatthias Springer     SmallVector<OpFoldResult> mixedHighPad = padOp.getMixedHighPad();
7749ee12f47SMatthias Springer     SmallVector<Value> dynamicSizes;
7759ee12f47SMatthias Springer     for (int64_t i = 0; i < resultType.getRank(); ++i) {
7769ee12f47SMatthias Springer       if (!resultType.isDynamicDim(i))
7779ee12f47SMatthias Springer         continue;
7789ee12f47SMatthias Springer       Value srcDim = rewriter.create<tensor::DimOp>(loc, padOp.getSource(), i);
7799ee12f47SMatthias Springer       Value lowPad = toValue(mixedLowPad[i]);
7809ee12f47SMatthias Springer       Value highPad = toValue(mixedHighPad[i]);
781c37ed776SMatthias Springer       AffineExpr s0, s1, s2;
782c37ed776SMatthias Springer       bindSymbols(op->getContext(), s0, s1, s2);
783c37ed776SMatthias Springer       AffineExpr sumExpr = s0 + s1 + s2;
7844c48f016SMatthias Springer       Value sum = rewriter.create<affine::AffineApplyOp>(
785c37ed776SMatthias Springer           loc, sumExpr, ValueRange{srcDim, lowPad, highPad});
786c37ed776SMatthias Springer       dynamicSizes.push_back(sum);
7879ee12f47SMatthias Springer     }
7889ee12f47SMatthias Springer 
78909dfb441SMatthias Springer     // Allocate a buffer for the padded result.
79009dfb441SMatthias Springer     FailureOr<Value> tensorAlloc =
7916bf043e7SMartin Erhart         allocateTensorForShapedValue(rewriter, loc, padOp.getResult(), options,
79209dfb441SMatthias Springer                                      /*copy=*/false);
79309dfb441SMatthias Springer     if (failed(tensorAlloc))
79409dfb441SMatthias Springer       return failure();
79509dfb441SMatthias Springer 
79609dfb441SMatthias Springer     // tensor::PadOp is like tensor::GenerateOp: The only difference is that
79709dfb441SMatthias Springer     // only a part of the generated tensor is needed. For simplicity, we reuse
79809dfb441SMatthias Springer     // the same functionality here.
79909dfb441SMatthias Springer     Value filledBuffer = lowerGenerateLikeOpBody(
80009dfb441SMatthias Springer         rewriter, loc, *tensorAlloc, dynamicSizes, padOp.getBodyRegion());
8019ee12f47SMatthias Springer 
8029ee12f47SMatthias Springer     // Create tensor::InsertSliceOp.
803ba95bf76SMatthias Springer     SmallVector<OpFoldResult> sliceSizes =
804ba95bf76SMatthias Springer         getMixedSizes(rewriter, loc, padOp.getSource());
805ba95bf76SMatthias Springer     SmallVector<OpFoldResult> sliceStrides(srcType.getRank(),
806ba95bf76SMatthias Springer                                            rewriter.getIndexAttr(1));
8079ee12f47SMatthias Springer     rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
80809dfb441SMatthias Springer         padOp, padOp.getSource(), filledBuffer,
8099ee12f47SMatthias Springer         /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);
8109ee12f47SMatthias Springer 
8119ee12f47SMatthias Springer     return success();
8129ee12f47SMatthias Springer   }
8139ee12f47SMatthias Springer };
8149ee12f47SMatthias Springer 
815fc08d1c2SMatthias Springer /// Bufferization of tensor.rank. Replace with memref.rank.
816fc08d1c2SMatthias Springer struct RankOpInterface
817fc08d1c2SMatthias Springer     : public BufferizableOpInterface::ExternalModel<RankOpInterface,
818fc08d1c2SMatthias Springer                                                     tensor::RankOp> {
819fc08d1c2SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
8209597b16aSMatthias Springer                               const AnalysisState &state) const {
821e5dc99e6SMatthias Springer     // The op reads the tensor's metadata but not its contents.
822e5dc99e6SMatthias Springer     return false;
823fc08d1c2SMatthias Springer   }
824fc08d1c2SMatthias Springer 
825fc08d1c2SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
8269597b16aSMatthias Springer                                const AnalysisState &state) const {
827fc08d1c2SMatthias Springer     return false;
828fc08d1c2SMatthias Springer   }
829fc08d1c2SMatthias Springer 
830a02ad6c1SMatthias Springer   AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
8319597b16aSMatthias Springer                                       const AnalysisState &state) const {
832585a8a32SMatthias Springer     return {};
833fc08d1c2SMatthias Springer   }
834fc08d1c2SMatthias Springer 
835fc08d1c2SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
836b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
837fc08d1c2SMatthias Springer     auto rankOp = cast<tensor::RankOp>(op);
8385d50f51cSMatthias Springer     FailureOr<Value> v = getBuffer(rewriter, rankOp.getTensor(), options);
8395d50f51cSMatthias Springer     if (failed(v))
8405d50f51cSMatthias Springer       return failure();
841fc08d1c2SMatthias Springer     replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
8425d50f51cSMatthias Springer                                                  *v);
843fc08d1c2SMatthias Springer     return success();
844fc08d1c2SMatthias Springer   }
845fc08d1c2SMatthias Springer };
846fc08d1c2SMatthias Springer 
847e287d647SAshay Rane /// Bufferization of tensor.reshape. Replace with memref.reshape.
848e287d647SAshay Rane struct ReshapeOpInterface
849e287d647SAshay Rane     : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
850e287d647SAshay Rane                                                     tensor::ReshapeOp> {
851e287d647SAshay Rane   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
852e287d647SAshay Rane                               const AnalysisState &state) const {
853ea71d2d0SMatthias Springer     // Depending on the layout map, the source buffer may have to be copied.
8540f952cfeSMatthias Springer     auto reshapeOp = cast<tensor::ReshapeOp>(op);
85555585043SMatthias Springer     return opOperand == reshapeOp.getShapeMutable();
856e287d647SAshay Rane   }
857e287d647SAshay Rane 
858e287d647SAshay Rane   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
859e287d647SAshay Rane                                const AnalysisState &state) const {
860e287d647SAshay Rane     return false;
861e287d647SAshay Rane   }
862e287d647SAshay Rane 
863a02ad6c1SMatthias Springer   AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
864e287d647SAshay Rane                                       const AnalysisState &state) const {
8659fa6b350SMatthias Springer     return {{op->getOpResult(0), BufferRelation::Equivalent}};
866e287d647SAshay Rane   }
867e287d647SAshay Rane 
868e287d647SAshay Rane   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
869b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
870e287d647SAshay Rane     auto reshapeOp = cast<tensor::ReshapeOp>(op);
8715d50f51cSMatthias Springer     FailureOr<Value> srcBuffer =
8725d50f51cSMatthias Springer         getBuffer(rewriter, reshapeOp.getSource(), options);
8735d50f51cSMatthias Springer     FailureOr<Value> shapeBuffer =
8745d50f51cSMatthias Springer         getBuffer(rewriter, reshapeOp.getShape(), options);
8755d50f51cSMatthias Springer     if (failed(srcBuffer) || failed(shapeBuffer))
8765d50f51cSMatthias Springer       return failure();
8779dbb8eefSIngo Müller     auto maybeResultMemRefType =
8789dbb8eefSIngo Müller         bufferization::getBufferType(reshapeOp.getResult(), options);
8799dbb8eefSIngo Müller     if (failed(maybeResultMemRefType))
8809dbb8eefSIngo Müller       return failure();
8810a0c7e89SSpenser Bauman 
8820a0c7e89SSpenser Bauman     // memref.reshape requires the source buffer to have an identity layout.
883ea71d2d0SMatthias Springer     // If the source memref does not have an identity layout, copy the source
8840a0c7e89SSpenser Bauman     // into a new buffer with an identity layout.
8850a0c7e89SSpenser Bauman     auto srcType = llvm::dyn_cast<MemRefType>(srcBuffer->getType());
8860a0c7e89SSpenser Bauman     if (srcType && !srcType.getLayout().isIdentity()) {
887ea71d2d0SMatthias Springer       FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
888ea71d2d0SMatthias Springer           rewriter, op->getLoc(), reshapeOp.getSource(), options);
889ea71d2d0SMatthias Springer       if (failed(tensorAlloc))
890ea71d2d0SMatthias Springer         return failure();
891ea71d2d0SMatthias Springer       auto memrefType = MemRefType::get(
892ea71d2d0SMatthias Springer           srcType.getShape(), srcType.getElementType(), AffineMap(),
893ea71d2d0SMatthias Springer           cast<BaseMemRefType>(srcBuffer->getType()).getMemorySpace());
8940a0c7e89SSpenser Bauman       srcBuffer = rewriter
895ea71d2d0SMatthias Springer                       .create<bufferization::ToMemrefOp>(
896ea71d2d0SMatthias Springer                           op->getLoc(), memrefType, *tensorAlloc)
8970a0c7e89SSpenser Bauman                       .getResult();
8980a0c7e89SSpenser Bauman     }
8990a0c7e89SSpenser Bauman 
900e287d647SAshay Rane     replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
9019dbb8eefSIngo Müller         rewriter, op, maybeResultMemRefType.value(), *srcBuffer, *shapeBuffer);
902e287d647SAshay Rane     return success();
903e287d647SAshay Rane   }
9049dbb8eefSIngo Müller 
9059dbb8eefSIngo Müller   FailureOr<BaseMemRefType>
9069dbb8eefSIngo Müller   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
907878950b8SMatthias Springer                 SmallVector<Value> &invocationStack) const {
9089dbb8eefSIngo Müller     auto reshapeOp = cast<tensor::ReshapeOp>(op);
9099dbb8eefSIngo Müller     assert(value == reshapeOp.getResult() && "unexpected value provided");
9109dbb8eefSIngo Müller     auto maybeSourceBufferType = bufferization::getBufferType(
911878950b8SMatthias Springer         reshapeOp.getSource(), options, invocationStack);
9129dbb8eefSIngo Müller     if (failed(maybeSourceBufferType))
9139dbb8eefSIngo Müller       return failure();
9149dbb8eefSIngo Müller     return getMemRefTypeWithStaticIdentityLayout(
9159dbb8eefSIngo Müller         reshapeOp.getResult().getType(),
9169dbb8eefSIngo Müller         cast<BaseMemRefType>(maybeSourceBufferType.value()).getMemorySpace());
9179dbb8eefSIngo Müller   }
918e287d647SAshay Rane };
919e287d647SAshay Rane 
9207fbf55c9SNicolas Vasilache /// Analysis of ParallelInsertSliceOp.
9217fbf55c9SNicolas Vasilache struct ParallelInsertSliceOpInterface
9227fbf55c9SNicolas Vasilache     : public BufferizableOpInterface::ExternalModel<
9237fbf55c9SNicolas Vasilache           ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
924a02ad6c1SMatthias Springer   AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
9257fbf55c9SNicolas Vasilache                                       const AnalysisState &state) const {
9267fbf55c9SNicolas Vasilache     return {};
9277fbf55c9SNicolas Vasilache   }
9287fbf55c9SNicolas Vasilache 
9297fbf55c9SNicolas Vasilache   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
9307fbf55c9SNicolas Vasilache                               const AnalysisState &state) const {
93198e838a8SMax191     return insertSliceOpRequiresRead(cast<tensor::ParallelInsertSliceOp>(op),
93298e838a8SMax191                                      opOperand);
9337fbf55c9SNicolas Vasilache   }
9347fbf55c9SNicolas Vasilache 
9357fbf55c9SNicolas Vasilache   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
9367fbf55c9SNicolas Vasilache                                const AnalysisState &state) const {
9370f952cfeSMatthias Springer     auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
93855585043SMatthias Springer     return opOperand == parallelInsertSliceOp.getDestMutable();
9397fbf55c9SNicolas Vasilache   }
9407fbf55c9SNicolas Vasilache 
9417fbf55c9SNicolas Vasilache   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
9427fbf55c9SNicolas Vasilache                           const BufferizationOptions &options) const {
9437fbf55c9SNicolas Vasilache     OpBuilder::InsertionGuard g(rewriter);
9447fbf55c9SNicolas Vasilache     auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
9457fbf55c9SNicolas Vasilache     ParallelCombiningOpInterface parallelCombiningParent =
9467fbf55c9SNicolas Vasilache         parallelInsertSliceOp.getParallelCombiningParent();
9477fbf55c9SNicolas Vasilache 
9484cd73620SMatthias Springer     // Bufferize the op outside of the parallel combining terminator.
9494cd73620SMatthias Springer     rewriter.setInsertionPoint(parallelCombiningParent);
9504cd73620SMatthias Springer 
9514cd73620SMatthias Springer     // Get source and destination buffers.
9527fbf55c9SNicolas Vasilache     FailureOr<Value> destBuffer =
9537fbf55c9SNicolas Vasilache         getBuffer(rewriter, parallelInsertSliceOp.getDest(), options);
9547fbf55c9SNicolas Vasilache     if (failed(destBuffer))
9557fbf55c9SNicolas Vasilache       return failure();
9567fbf55c9SNicolas Vasilache     FailureOr<Value> srcBuffer =
9577fbf55c9SNicolas Vasilache         getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
9587fbf55c9SNicolas Vasilache     if (failed(srcBuffer))
9597fbf55c9SNicolas Vasilache       return failure();
9606c3c5f80SMatthias Springer 
9616c3c5f80SMatthias Springer     // Take a subview of the destination buffer.
9625550c821STres Popp     auto destBufferType = cast<MemRefType>(destBuffer->getType());
9636c3c5f80SMatthias Springer     auto subviewMemRefType =
9645550c821STres Popp         cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
9656c3c5f80SMatthias Springer             parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
9666c3c5f80SMatthias Springer             parallelInsertSliceOp.getMixedOffsets(),
9676c3c5f80SMatthias Springer             parallelInsertSliceOp.getMixedSizes(),
9685550c821STres Popp             parallelInsertSliceOp.getMixedStrides()));
9697fbf55c9SNicolas Vasilache     Value subview = rewriter.create<memref::SubViewOp>(
9706c3c5f80SMatthias Springer         parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer,
9717fbf55c9SNicolas Vasilache         parallelInsertSliceOp.getMixedOffsets(),
9727fbf55c9SNicolas Vasilache         parallelInsertSliceOp.getMixedSizes(),
9737fbf55c9SNicolas Vasilache         parallelInsertSliceOp.getMixedStrides());
9746c3c5f80SMatthias Springer 
9757fbf55c9SNicolas Vasilache     // This memcpy will fold away if everything bufferizes in-place.
9767fbf55c9SNicolas Vasilache     if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
9777fbf55c9SNicolas Vasilache                                     *srcBuffer, subview)))
9787fbf55c9SNicolas Vasilache       return failure();
9797fbf55c9SNicolas Vasilache 
9807c06f631SMatthias Springer     // In case the source was allocated in the same block, make sure that the
9817c06f631SMatthias Springer     // deallocation op (if any) appears after the memcpy. By default, deallocs
9827c06f631SMatthias Springer     // are placed before the terminator, but this does not work for ForallOp
9837c06f631SMatthias Springer     // because the terminator does more than just yielding a value.
9847c06f631SMatthias Springer     //
9857c06f631SMatthias Springer     // Note: This is not a problem for the destination buffer because these are
9867c06f631SMatthias Springer     // assumed to always bufferize in-place.
9877c06f631SMatthias Springer     for (Operation *user : srcBuffer->getUsers()) {
9887c06f631SMatthias Springer       if (hasEffect<MemoryEffects::Free>(user)) {
9897c06f631SMatthias Springer         if (user->getBlock() == parallelCombiningParent->getBlock())
9905cc0f76dSMatthias Springer           rewriter.moveOpBefore(user, user->getBlock()->getTerminator());
9917c06f631SMatthias Springer         break;
9927c06f631SMatthias Springer       }
9937c06f631SMatthias Springer     }
9947c06f631SMatthias Springer 
9954cd73620SMatthias Springer     // Delete the op.
9967fbf55c9SNicolas Vasilache     rewriter.eraseOp(op);
9977fbf55c9SNicolas Vasilache     return success();
9987fbf55c9SNicolas Vasilache   }
999d69e9491Sdonald chen 
1000d69e9491Sdonald chen   /// tensor.parallel_insert_slice op has implicit inplace behavior. We
1001d69e9491Sdonald chen   /// shouldn't create copy to resolve conflict.
1002d69e9491Sdonald chen   LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
1003d69e9491Sdonald chen                                  const AnalysisState &state) const {
1004d69e9491Sdonald chen     return success();
1005d69e9491Sdonald chen   }
10067fbf55c9SNicolas Vasilache };
10077fbf55c9SNicolas Vasilache 
1008481b254eSMatthias Springer /// Bufferization of tensor.splat. Bufferizes to a new allocation that is filled
1009481b254eSMatthias Springer /// with a linalg.map. Similar to tensor.generate.
1010481b254eSMatthias Springer struct SplatOpInterface
1011481b254eSMatthias Springer     : public BufferizableOpInterface::ExternalModel<SplatOpInterface,
1012481b254eSMatthias Springer                                                     tensor::SplatOp> {
1013481b254eSMatthias Springer 
1014a02ad6c1SMatthias Springer   bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
1015481b254eSMatthias Springer 
1016481b254eSMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
1017481b254eSMatthias Springer                           const BufferizationOptions &options) const {
1018481b254eSMatthias Springer     OpBuilder::InsertionGuard g(rewriter);
1019481b254eSMatthias Springer     auto splatOp = cast<tensor::SplatOp>(op);
1020481b254eSMatthias Springer 
1021481b254eSMatthias Springer     // Allocate memory.
1022481b254eSMatthias Springer     Location loc = op->getLoc();
10236bf043e7SMartin Erhart     FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
10246bf043e7SMartin Erhart         rewriter, loc, splatOp.getResult(), options,
1025481b254eSMatthias Springer         /*copy=*/false);
1026481b254eSMatthias Springer     if (failed(tensorAlloc))
1027481b254eSMatthias Springer       return failure();
1028481b254eSMatthias Springer 
1029481b254eSMatthias Springer     // Create linalg::MapOp.
1030481b254eSMatthias Springer     auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
1031067d2779Sian Bearman 
1032067d2779Sian Bearman     // TODO: Implement memory space for this op.
1033067d2779Sian Bearman     if (options.defaultMemorySpaceFn(tensorType) != Attribute())
1034067d2779Sian Bearman       return op->emitError("memory space not implemented yet");
1035067d2779Sian Bearman 
1036481b254eSMatthias Springer     auto linalgOp =
1037481b254eSMatthias Springer         rewriter.create<linalg::MapOp>(loc, tensorType, /*inputs=*/ValueRange(),
1038481b254eSMatthias Springer                                        /*init=*/*tensorAlloc);
1039481b254eSMatthias Springer     Block &linalgBody = linalgOp.getMapper().emplaceBlock();
1040481b254eSMatthias Springer 
1041481b254eSMatthias Springer     // Create linalg::IndexOps.
1042481b254eSMatthias Springer     rewriter.setInsertionPointToStart(&linalgBody);
1043481b254eSMatthias Springer     rewriter.create<linalg::YieldOp>(loc, splatOp.getInput());
1044481b254eSMatthias Springer     rewriter.replaceOp(splatOp, linalgOp.getResult()[0]);
1045481b254eSMatthias Springer 
1046481b254eSMatthias Springer     return success();
1047481b254eSMatthias Springer   }
1048481b254eSMatthias Springer };
1049481b254eSMatthias Springer 
105049e37000SMatthias Springer } // namespace
105149e37000SMatthias Springer } // namespace tensor
105249e37000SMatthias Springer } // namespace mlir
105349e37000SMatthias Springer 
105449e37000SMatthias Springer void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
105549e37000SMatthias Springer     DialectRegistry &registry) {
105677eee579SRiver Riddle   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
105777eee579SRiver Riddle     CastOp::attachInterface<CastOpInterface>(*ctx);
105877eee579SRiver Riddle     CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
105977eee579SRiver Riddle     DimOp::attachInterface<DimOpInterface>(*ctx);
1060be630f07SMatthias Springer     EmptyOp::attachInterface<EmptyOpInterface>(*ctx);
106177eee579SRiver Riddle     ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
106277eee579SRiver Riddle     ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
106377eee579SRiver Riddle     ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
106477eee579SRiver Riddle     FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
106577eee579SRiver Riddle     GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
106677eee579SRiver Riddle     InsertOp::attachInterface<InsertOpInterface>(*ctx);
106777eee579SRiver Riddle     InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
10689ee12f47SMatthias Springer     PadOp::attachInterface<PadOpInterface>(*ctx);
10697fbf55c9SNicolas Vasilache     ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
10707fbf55c9SNicolas Vasilache         *ctx);
107177eee579SRiver Riddle     RankOp::attachInterface<RankOpInterface>(*ctx);
1072e287d647SAshay Rane     ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
1073481b254eSMatthias Springer     SplatOp::attachInterface<SplatOpInterface>(*ctx);
10745f5f71e7SMatthias Springer 
10755f5f71e7SMatthias Springer     // Load additional dialects of which ops may get created.
1076c1f0a15cSMatthias Springer     ctx->loadDialect<arith::ArithDialect, linalg::LinalgDialect>();
107777eee579SRiver Riddle   });
10788143307bSMatthias Springer 
10798143307bSMatthias Springer   // Bufferization requires SubsetInsertionOpInterface models. Make sure that
10808143307bSMatthias Springer   // they are registered.
10811abd8d1aSMatthias Springer   tensor::registerSubsetOpInterfaceExternalModels(registry);
108249e37000SMatthias Springer }
1083