149e37000SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 249e37000SMatthias Springer // 349e37000SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 449e37000SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 549e37000SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 649e37000SMatthias Springer // 749e37000SMatthias Springer //===----------------------------------------------------------------------===// 849e37000SMatthias Springer 949e37000SMatthias Springer #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" 108143307bSMatthias Springer 11c37ed776SMatthias Springer #include "mlir/Dialect/Affine/IR/AffineOps.h" 12abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 1349e37000SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 14b3ebe3beSMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 152d5edc64SMatthias Springer #include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h" 16c1f0a15cSMatthias Springer #include "mlir/Dialect/Linalg/IR/Linalg.h" 1749e37000SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 188b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h" 1949e37000SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h" 208143307bSMatthias Springer #include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h" 219ee12f47SMatthias Springer #include "mlir/Dialect/Utils/StaticValueUtils.h" 2298e838a8SMax191 #include "mlir/IR/BuiltinTypeInterfaces.h" 2349e37000SMatthias Springer #include "mlir/IR/Dialect.h" 2449e37000SMatthias Springer #include "mlir/IR/Operation.h" 2549e37000SMatthias Springer 2649e37000SMatthias Springer using namespace mlir; 2749e37000SMatthias Springer using namespace mlir::bufferization; 2849e37000SMatthias Springer using namespace mlir::tensor; 2949e37000SMatthias Springer 3049e37000SMatthias Springer namespace mlir { 3149e37000SMatthias Springer namespace tensor { 3249e37000SMatthias Springer namespace { 3349e37000SMatthias Springer 3449e37000SMatthias Springer struct CastOpInterface 3549e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<CastOpInterface, 3649e37000SMatthias Springer tensor::CastOp> { 3749e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 389597b16aSMatthias Springer const AnalysisState &state) const { 3949e37000SMatthias Springer return false; 4049e37000SMatthias Springer } 4149e37000SMatthias Springer 4249e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 439597b16aSMatthias Springer const AnalysisState &state) const { 4449e37000SMatthias Springer return false; 4549e37000SMatthias Springer } 4649e37000SMatthias Springer 47a02ad6c1SMatthias Springer AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, 489597b16aSMatthias Springer const AnalysisState &state) const { 499fa6b350SMatthias Springer return {{op->getResult(0), BufferRelation::Equivalent}}; 5049e37000SMatthias Springer } 5149e37000SMatthias Springer 52b6ae3f88SMatthias Springer FailureOr<BaseMemRefType> 53b6ae3f88SMatthias Springer getBufferType(Operation *op, Value value, const BufferizationOptions &options, 54878950b8SMatthias Springer SmallVector<Value> &invocationStack) const { 55b6ae3f88SMatthias Springer auto castOp = cast<tensor::CastOp>(op); 56878950b8SMatthias Springer auto maybeSrcBufferType = bufferization::getBufferType( 57878950b8SMatthias Springer castOp.getSource(), options, invocationStack); 58b6ae3f88SMatthias Springer if (failed(maybeSrcBufferType)) 59b6ae3f88SMatthias Springer return failure(); 60b6ae3f88SMatthias Springer Attribute memorySpace = maybeSrcBufferType->getMemorySpace(); 61b6ae3f88SMatthias Springer 62b6ae3f88SMatthias Springer // Note: `getMemRefTypeWithFullyDynamicLayout` returns an unranked memref 63b6ae3f88SMatthias Springer // type in case the input is an unranked tensor type. 64b6ae3f88SMatthias Springer 65b6ae3f88SMatthias Springer // Case 1: Casting an unranked tensor 665550c821STres Popp if (isa<UnrankedTensorType>(castOp.getSource().getType())) { 67b6ae3f88SMatthias Springer // When casting to a ranked tensor, we cannot infer any static offset or 68b6ae3f88SMatthias Springer // strides from the source. Assume fully dynamic. 69b6ae3f88SMatthias Springer return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace); 70b6ae3f88SMatthias Springer } 71b6ae3f88SMatthias Springer 72b6ae3f88SMatthias Springer // Case 2: Casting to an unranked tensor type 735550c821STres Popp if (isa<UnrankedTensorType>(castOp.getType())) { 74b6ae3f88SMatthias Springer return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace); 75b6ae3f88SMatthias Springer } 76b6ae3f88SMatthias Springer 77b6ae3f88SMatthias Springer // Case 3: Ranked tensor -> ranked tensor. The offsets and strides do not 78b6ae3f88SMatthias Springer // change. 795550c821STres Popp auto rankedResultType = cast<RankedTensorType>(castOp.getType()); 80b6ae3f88SMatthias Springer return MemRefType::get( 81b6ae3f88SMatthias Springer rankedResultType.getShape(), rankedResultType.getElementType(), 8268f58812STres Popp llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace); 83b6ae3f88SMatthias Springer } 84b6ae3f88SMatthias Springer 8549e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 86b55d55ecSMatthias Springer const BufferizationOptions &options) const { 8749e37000SMatthias Springer auto castOp = cast<tensor::CastOp>(op); 8849e37000SMatthias Springer 8949e37000SMatthias Springer // The result buffer still has the old (pre-cast) type. 905d50f51cSMatthias Springer FailureOr<Value> resultBuffer = 915d50f51cSMatthias Springer getBuffer(rewriter, castOp.getSource(), options); 925d50f51cSMatthias Springer if (failed(resultBuffer)) 935d50f51cSMatthias Springer return failure(); 9449e37000SMatthias Springer 95b6ae3f88SMatthias Springer // Compute the new type. 96b6ae3f88SMatthias Springer auto resultMemRefType = 97b6ae3f88SMatthias Springer bufferization::getBufferType(castOp.getResult(), options); 98b6ae3f88SMatthias Springer if (failed(resultMemRefType)) 99b6ae3f88SMatthias Springer return failure(); 1006cd7b655SKai Sasaki if (resultBuffer->getType() == *resultMemRefType) { 1016cd7b655SKai Sasaki // This cast is a no-op. 1026cd7b655SKai Sasaki replaceOpWithBufferizedValues(rewriter, op, *resultBuffer); 1036cd7b655SKai Sasaki return success(); 1046cd7b655SKai Sasaki } 10549e37000SMatthias Springer 10649e37000SMatthias Springer // Replace the op with a memref.cast. 1075d50f51cSMatthias Springer assert(memref::CastOp::areCastCompatible(resultBuffer->getType(), 108b6ae3f88SMatthias Springer *resultMemRefType) && 10949e37000SMatthias Springer "CallOp::bufferize: cast incompatible"); 110b6ae3f88SMatthias Springer replaceOpWithNewBufferizedOp<memref::CastOp>( 111b6ae3f88SMatthias Springer rewriter, op, *resultMemRefType, *resultBuffer); 11249e37000SMatthias Springer 11349e37000SMatthias Springer return success(); 11449e37000SMatthias Springer } 11549e37000SMatthias Springer }; 11649e37000SMatthias Springer 117e6f69161SMatthias Springer /// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape. 118e6f69161SMatthias Springer struct CollapseShapeOpInterface 119e6f69161SMatthias Springer : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface, 120e6f69161SMatthias Springer tensor::CollapseShapeOp> { 121e6f69161SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 1229597b16aSMatthias Springer const AnalysisState &state) const { 123ea71d2d0SMatthias Springer // tensor.collapse_shape may reallocate, at which point the source buffer is 124ea71d2d0SMatthias Springer // copied. I.e., there will be a memory read side effect on the bufferized 125ea71d2d0SMatthias Springer // source. This function conservatively returns "true" because whether a 126ea71d2d0SMatthias Springer // copy will be created or not is not known at this point. 127ea71d2d0SMatthias Springer return true; 128e6f69161SMatthias Springer } 129e6f69161SMatthias Springer 130e6f69161SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 1319597b16aSMatthias Springer const AnalysisState &state) const { 132e6f69161SMatthias Springer return false; 133e6f69161SMatthias Springer } 134e6f69161SMatthias Springer 135a02ad6c1SMatthias Springer AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, 1369597b16aSMatthias Springer const AnalysisState &state) const { 1371ac248e4SMatthias Springer // TODO: CollapseShapeOp may allocate at runtime. 1389fa6b350SMatthias Springer return {{op->getOpResult(0), BufferRelation::Equivalent}}; 139e6f69161SMatthias Springer } 140e6f69161SMatthias Springer 14104ff6009SMatthias Springer FailureOr<BaseMemRefType> 14204ff6009SMatthias Springer getBufferType(Operation *op, Value value, const BufferizationOptions &options, 143878950b8SMatthias Springer SmallVector<Value> &invocationStack) const { 14404ff6009SMatthias Springer auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op); 14504ff6009SMatthias Springer auto maybeSrcBufferType = bufferization::getBufferType( 146878950b8SMatthias Springer collapseShapeOp.getSrc(), options, invocationStack); 14704ff6009SMatthias Springer if (failed(maybeSrcBufferType)) 14804ff6009SMatthias Springer return failure(); 14968f58812STres Popp auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType); 15004ff6009SMatthias Springer bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible( 15104ff6009SMatthias Springer srcBufferType, collapseShapeOp.getReassociationIndices()); 15204ff6009SMatthias Springer 15304ff6009SMatthias Springer if (!canBeCollapsed) { 15404ff6009SMatthias Springer // If dims cannot be collapsed, this op bufferizes to a new allocation. 15504ff6009SMatthias Springer RankedTensorType tensorResultType = collapseShapeOp.getResultType(); 15604ff6009SMatthias Springer return bufferization::getMemRefTypeWithStaticIdentityLayout( 1579bb63374SLei Zhang tensorResultType, srcBufferType.getMemorySpace()); 15804ff6009SMatthias Springer } 15904ff6009SMatthias Springer 16004ff6009SMatthias Springer return memref::CollapseShapeOp::computeCollapsedType( 16104ff6009SMatthias Springer srcBufferType, collapseShapeOp.getReassociationIndices()); 16204ff6009SMatthias Springer } 16304ff6009SMatthias Springer 164e6f69161SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 165b55d55ecSMatthias Springer const BufferizationOptions &options) const { 166e6f69161SMatthias Springer auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op); 16751df6238SMatthias Springer RankedTensorType tensorResultType = collapseShapeOp.getResultType(); 1685d50f51cSMatthias Springer FailureOr<Value> maybeBuffer = 1695d50f51cSMatthias Springer getBuffer(rewriter, collapseShapeOp.getSrc(), options); 1705d50f51cSMatthias Springer if (failed(maybeBuffer)) 1715d50f51cSMatthias Springer return failure(); 1725d50f51cSMatthias Springer Value buffer = *maybeBuffer; 1735550c821STres Popp auto bufferType = cast<MemRefType>(buffer.getType()); 17451df6238SMatthias Springer 17551df6238SMatthias Springer if (tensorResultType.getRank() == 0) { 17651df6238SMatthias Springer // 0-d collapses must go through a different op builder. 17773c0333dSMatthias Springer MemRefType resultType; 17873c0333dSMatthias Springer 17973c0333dSMatthias Springer if (bufferType.getLayout().isIdentity()) { 18073c0333dSMatthias Springer // Standard layout: result type has no offset. 18151df6238SMatthias Springer MemRefLayoutAttrInterface layout; 18273c0333dSMatthias Springer resultType = MemRefType::get({}, tensorResultType.getElementType(), 18351df6238SMatthias Springer layout, bufferType.getMemorySpace()); 18473c0333dSMatthias Springer } else { 18573c0333dSMatthias Springer // Source memref has a layout map: result type has the same offset as 18673c0333dSMatthias Springer // the source type. 18773c0333dSMatthias Springer SmallVector<int64_t> strides; 18873c0333dSMatthias Springer int64_t offset; 189*6aaa8f25SMatthias Springer if (failed(bufferType.getStridesAndOffset(strides, offset))) 19073c0333dSMatthias Springer return failure(); 19146b90a7bSAlex Zinenko resultType = MemRefType::get( 19246b90a7bSAlex Zinenko {}, tensorResultType.getElementType(), 19346b90a7bSAlex Zinenko StridedLayoutAttr::get(op->getContext(), offset, {}), 19446b90a7bSAlex Zinenko bufferType.getMemorySpace()); 19573c0333dSMatthias Springer } 19673c0333dSMatthias Springer 197e6f69161SMatthias Springer replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>( 1988df54a6aSJacques Pienaar rewriter, op, resultType, buffer, collapseShapeOp.getReassociation()); 199e6f69161SMatthias Springer return success(); 200e6f69161SMatthias Springer } 20151df6238SMatthias Springer 202d7a9bf91SMatthias Springer // If the dims are not collapsible (due to an incompatible source layout 203d7a9bf91SMatthias Springer // map), force an out-of-place bufferization, i.e., a buffer copy. This 204d7a9bf91SMatthias Springer // newly allocated buffer will have no layout map and thus be collapsible. 205a74e5a89SAdrian Kuegel bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible( 206d7a9bf91SMatthias Springer bufferType, collapseShapeOp.getReassociationIndices()); 207b3ebe3beSMatthias Springer if (!canBeCollapsed) { 208b3ebe3beSMatthias Springer // TODO: Create alloc_tensor ops during TensorCopyInsertion. 209b55d55ecSMatthias Springer AnalysisState analysisState(options); 21045b995cdSMatthias Springer FailureOr<Value> tensorAlloc = allocateTensorForShapedValue( 2116bf043e7SMartin Erhart rewriter, op->getLoc(), collapseShapeOp.getSrc(), options); 21245b995cdSMatthias Springer if (failed(tensorAlloc)) 21345b995cdSMatthias Springer return failure(); 214b3ebe3beSMatthias Springer auto memrefType = 215b3ebe3beSMatthias Springer MemRefType::get(collapseShapeOp.getSrcType().getShape(), 216b3ebe3beSMatthias Springer collapseShapeOp.getSrcType().getElementType(), 2179bb63374SLei Zhang AffineMap(), bufferType.getMemorySpace()); 218b3ebe3beSMatthias Springer buffer = rewriter.create<bufferization::ToMemrefOp>( 21945b995cdSMatthias Springer op->getLoc(), memrefType, *tensorAlloc); 220b3ebe3beSMatthias Springer } 221d7a9bf91SMatthias Springer 22251df6238SMatthias Springer // Result type is inferred by the builder. 22351df6238SMatthias Springer replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>( 224b3ebe3beSMatthias Springer rewriter, op, buffer, collapseShapeOp.getReassociationIndices()); 22551df6238SMatthias Springer return success(); 22651df6238SMatthias Springer } 227e6f69161SMatthias Springer }; 228e6f69161SMatthias Springer 22949e37000SMatthias Springer /// Bufferization of tensor.dim. Replace with memref.dim. 23049e37000SMatthias Springer struct DimOpInterface 23149e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<DimOpInterface, 23249e37000SMatthias Springer tensor::DimOp> { 23349e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 2349597b16aSMatthias Springer const AnalysisState &state) const { 235e5dc99e6SMatthias Springer // The op reads the tensor's metadata but not its contents. 236e5dc99e6SMatthias Springer return false; 23749e37000SMatthias Springer } 23849e37000SMatthias Springer 23949e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 2409597b16aSMatthias Springer const AnalysisState &state) const { 24149e37000SMatthias Springer return false; 24249e37000SMatthias Springer } 24349e37000SMatthias Springer 244a02ad6c1SMatthias Springer AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, 2459597b16aSMatthias Springer const AnalysisState &state) const { 246585a8a32SMatthias Springer return {}; 24749e37000SMatthias Springer } 24849e37000SMatthias Springer 24949e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 250b55d55ecSMatthias Springer const BufferizationOptions &options) const { 25149e37000SMatthias Springer auto dimOp = cast<tensor::DimOp>(op); 2525d50f51cSMatthias Springer FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options); 2535d50f51cSMatthias Springer if (failed(v)) 2545d50f51cSMatthias Springer return failure(); 2555d50f51cSMatthias Springer replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v, 256136d746eSJacques Pienaar dimOp.getIndex()); 25749e37000SMatthias Springer return success(); 25849e37000SMatthias Springer } 25949e37000SMatthias Springer }; 26049e37000SMatthias Springer 261464dfebaSMatthias Springer /// Bufferization of "tensor.empty". Replace with "bufferization.alloc_tensor". 262be630f07SMatthias Springer struct EmptyOpInterface 263be630f07SMatthias Springer : public BufferizableOpInterface::ExternalModel<EmptyOpInterface, 264be630f07SMatthias Springer tensor::EmptyOp> { 26558678d3bSMatthias Springer bool bufferizesToAllocation(Operation *op, Value value) const { return true; } 26658678d3bSMatthias Springer 267330372f2SMatthias Springer bool resultBufferizesToMemoryWrite(Operation *op, OpResult opResult, 268330372f2SMatthias Springer const AnalysisState &state) const { 269330372f2SMatthias Springer // The returned tensor does not have specified contents. 270330372f2SMatthias Springer return false; 271330372f2SMatthias Springer } 272330372f2SMatthias Springer 273be630f07SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 274be630f07SMatthias Springer const BufferizationOptions &options) const { 275464dfebaSMatthias Springer auto emptyOp = cast<tensor::EmptyOp>(op); 276464dfebaSMatthias Springer 277464dfebaSMatthias Springer // Optimization: Fold away the op if it has no uses. 278ef4f5357SMatthias Springer if (op->getUses().empty()) { 279ef4f5357SMatthias Springer rewriter.eraseOp(op); 280ef4f5357SMatthias Springer return success(); 281ef4f5357SMatthias Springer } 282ef4f5357SMatthias Springer 283464dfebaSMatthias Springer // Allocate a tensor. This emits a "bufferization.alloc_tensor" op. 284464dfebaSMatthias Springer FailureOr<Value> allocTensor = allocateTensorForShapedValue( 285464dfebaSMatthias Springer rewriter, op->getLoc(), emptyOp.getResult(), options, /*copy=*/false); 286464dfebaSMatthias Springer if (failed(allocTensor)) 287464dfebaSMatthias Springer return failure(); 288464dfebaSMatthias Springer rewriter.replaceOp(op, *allocTensor); 289464dfebaSMatthias Springer return success(); 290be630f07SMatthias Springer } 291be630f07SMatthias Springer }; 292be630f07SMatthias Springer 293e6f69161SMatthias Springer /// Bufferization of tensor.expand_shape. Replace with memref.expand_shape. 294e6f69161SMatthias Springer struct ExpandShapeOpInterface 295e6f69161SMatthias Springer : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface, 296e6f69161SMatthias Springer tensor::ExpandShapeOp> { 297e6f69161SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 2989597b16aSMatthias Springer const AnalysisState &state) const { 299ea71d2d0SMatthias Springer // In contrast to tensor.collapse_shape, this op can always be bufferized 300ea71d2d0SMatthias Springer // without a copy. 301e6f69161SMatthias Springer return false; 302e6f69161SMatthias Springer } 303e6f69161SMatthias Springer 304e6f69161SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 3059597b16aSMatthias Springer const AnalysisState &state) const { 306e6f69161SMatthias Springer return false; 307e6f69161SMatthias Springer } 308e6f69161SMatthias Springer 309a02ad6c1SMatthias Springer AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, 3109597b16aSMatthias Springer const AnalysisState &state) const { 3119fa6b350SMatthias Springer return {{op->getOpResult(0), BufferRelation::Equivalent}}; 312e6f69161SMatthias Springer } 313e6f69161SMatthias Springer 31404ff6009SMatthias Springer FailureOr<BaseMemRefType> 31504ff6009SMatthias Springer getBufferType(Operation *op, Value value, const BufferizationOptions &options, 316878950b8SMatthias Springer SmallVector<Value> &invocationStack) const { 31704ff6009SMatthias Springer auto expandShapeOp = cast<tensor::ExpandShapeOp>(op); 31804ff6009SMatthias Springer auto maybeSrcBufferType = bufferization::getBufferType( 319878950b8SMatthias Springer expandShapeOp.getSrc(), options, invocationStack); 32004ff6009SMatthias Springer if (failed(maybeSrcBufferType)) 32104ff6009SMatthias Springer return failure(); 32268f58812STres Popp auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType); 32304ff6009SMatthias Springer auto maybeResultType = memref::ExpandShapeOp::computeExpandedType( 32404ff6009SMatthias Springer srcBufferType, expandShapeOp.getResultType().getShape(), 32504ff6009SMatthias Springer expandShapeOp.getReassociationIndices()); 32604ff6009SMatthias Springer if (failed(maybeResultType)) 32704ff6009SMatthias Springer return failure(); 32804ff6009SMatthias Springer return *maybeResultType; 32904ff6009SMatthias Springer } 33004ff6009SMatthias Springer 331e6f69161SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 332b55d55ecSMatthias Springer const BufferizationOptions &options) const { 333e6f69161SMatthias Springer auto expandShapeOp = cast<tensor::ExpandShapeOp>(op); 33451df6238SMatthias Springer auto tensorResultType = expandShapeOp.getResultType(); 3355d50f51cSMatthias Springer FailureOr<Value> buffer = 3365d50f51cSMatthias Springer getBuffer(rewriter, expandShapeOp.getSrc(), options); 3375d50f51cSMatthias Springer if (failed(buffer)) 3385d50f51cSMatthias Springer return failure(); 33951df6238SMatthias Springer 34051df6238SMatthias Springer // Memref result type is inferred by the builder based on reassociation 34151df6238SMatthias Springer // indices and result shape. 34297069a86SGaurav Shukla // TODO: Instead of inferring the output shape argument of 34397069a86SGaurav Shukla // memref.expand_shape op, use output_shape argument of tensor.expand_shape 34497069a86SGaurav Shukla // op. 345e6f69161SMatthias Springer replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>( 3465d50f51cSMatthias Springer rewriter, op, tensorResultType.getShape(), *buffer, 34751df6238SMatthias Springer expandShapeOp.getReassociationIndices()); 348e6f69161SMatthias Springer return success(); 349e6f69161SMatthias Springer } 350e6f69161SMatthias Springer }; 351e6f69161SMatthias Springer 35249e37000SMatthias Springer /// Bufferization of tensor.extract_slice. Replace with memref.subview. 35349e37000SMatthias Springer struct ExtractSliceOpInterface 35449e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface, 35549e37000SMatthias Springer tensor::ExtractSliceOp> { 35649e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 3579597b16aSMatthias Springer const AnalysisState &state) const { 35849e37000SMatthias Springer return false; 35949e37000SMatthias Springer } 36049e37000SMatthias Springer 36149e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 3629597b16aSMatthias Springer const AnalysisState &state) const { 36349e37000SMatthias Springer return false; 36449e37000SMatthias Springer } 36549e37000SMatthias Springer 366a02ad6c1SMatthias Springer AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, 3679597b16aSMatthias Springer const AnalysisState &state) const { 3689fa6b350SMatthias Springer return {{op->getOpResult(0), BufferRelation::Unknown}}; 36949e37000SMatthias Springer } 37049e37000SMatthias Springer 37149e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 372b55d55ecSMatthias Springer const BufferizationOptions &options) const { 37349e37000SMatthias Springer auto extractSliceOp = cast<tensor::ExtractSliceOp>(op); 3746c3c5f80SMatthias Springer SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets(); 3756c3c5f80SMatthias Springer SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes(); 3766c3c5f80SMatthias Springer SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides(); 37749e37000SMatthias Springer Location loc = extractSliceOp.getLoc(); 378d7a9bf91SMatthias Springer 3796c3c5f80SMatthias Springer // Get source buffer. 3805d50f51cSMatthias Springer FailureOr<Value> srcMemref = 3815d50f51cSMatthias Springer getBuffer(rewriter, extractSliceOp.getSource(), options); 3825d50f51cSMatthias Springer if (failed(srcMemref)) 3835d50f51cSMatthias Springer return failure(); 38449e37000SMatthias Springer 3856c3c5f80SMatthias Springer // Take a subview of the source buffer. 386111c9196SMatthias Springer auto resultMemrefType = 387123c4b02SMatthias Springer bufferization::getBufferType(extractSliceOp.getResult(), options); 388111c9196SMatthias Springer if (failed(resultMemrefType)) 389111c9196SMatthias Springer return failure(); 39049e37000SMatthias Springer Value subView = rewriter.create<memref::SubViewOp>( 391d69e9491Sdonald chen loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref, 392d69e9491Sdonald chen mixedOffsets, mixedSizes, mixedStrides); 39349e37000SMatthias Springer 39449e37000SMatthias Springer replaceOpWithBufferizedValues(rewriter, op, subView); 39549e37000SMatthias Springer return success(); 39649e37000SMatthias Springer } 397111c9196SMatthias Springer 398111c9196SMatthias Springer FailureOr<BaseMemRefType> 399123c4b02SMatthias Springer getBufferType(Operation *op, Value value, const BufferizationOptions &options, 400878950b8SMatthias Springer SmallVector<Value> &invocationStack) const { 401111c9196SMatthias Springer auto extractSliceOp = cast<tensor::ExtractSliceOp>(op); 402111c9196SMatthias Springer assert(value == extractSliceOp.getResult() && "invalid value"); 403123c4b02SMatthias Springer auto srcMemrefType = bufferization::getBufferType( 404878950b8SMatthias Springer extractSliceOp.getSource(), options, invocationStack); 405111c9196SMatthias Springer if (failed(srcMemrefType)) 406111c9196SMatthias Springer return failure(); 407111c9196SMatthias Springer SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets(); 408111c9196SMatthias Springer SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes(); 409111c9196SMatthias Springer SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides(); 4105550c821STres Popp return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType( 411d69e9491Sdonald chen extractSliceOp.getType().getShape(), 412d69e9491Sdonald chen llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes, 413d69e9491Sdonald chen mixedStrides)); 414111c9196SMatthias Springer } 41549e37000SMatthias Springer }; 41649e37000SMatthias Springer 41749e37000SMatthias Springer /// Bufferization of tensor.extract. Replace with memref.load. 41849e37000SMatthias Springer struct ExtractOpInterface 41949e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<ExtractOpInterface, 42049e37000SMatthias Springer tensor::ExtractOp> { 42149e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 4229597b16aSMatthias Springer const AnalysisState &state) const { 42349e37000SMatthias Springer return true; 42449e37000SMatthias Springer } 42549e37000SMatthias Springer 42649e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 4279597b16aSMatthias Springer const AnalysisState &state) const { 42849e37000SMatthias Springer return false; 42949e37000SMatthias Springer } 43049e37000SMatthias Springer 431a02ad6c1SMatthias Springer AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, 4329597b16aSMatthias Springer const AnalysisState &state) const { 433585a8a32SMatthias Springer return {}; 43449e37000SMatthias Springer } 43549e37000SMatthias Springer 43649e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 437b55d55ecSMatthias Springer const BufferizationOptions &options) const { 43849e37000SMatthias Springer auto extractOp = cast<tensor::ExtractOp>(op); 4395d50f51cSMatthias Springer FailureOr<Value> srcMemref = 4405d50f51cSMatthias Springer getBuffer(rewriter, extractOp.getTensor(), options); 4415d50f51cSMatthias Springer if (failed(srcMemref)) 4425d50f51cSMatthias Springer return failure(); 4435d50f51cSMatthias Springer replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref, 444136d746eSJacques Pienaar extractOp.getIndices()); 44549e37000SMatthias Springer return success(); 44649e37000SMatthias Springer } 44749e37000SMatthias Springer }; 44849e37000SMatthias Springer 449d581c94dSMatthias Springer // Implements backtracking to traverse indices of the output buffer while 450d581c94dSMatthias Springer // iterating over op.elements(). 451d581c94dSMatthias Springer static void createStores(RewriterBase &rewriter, Location loc, int dim, 452d581c94dSMatthias Springer Value buffer, ArrayRef<int64_t> shape, 453d581c94dSMatthias Springer ArrayRef<Value> constants, 454d581c94dSMatthias Springer OperandRange::iterator &elementIt, 455d581c94dSMatthias Springer SmallVectorImpl<Value> &indices) { 456d581c94dSMatthias Springer if (dim == static_cast<int>(shape.size()) - 1) { 457d581c94dSMatthias Springer for (int i = 0; i < shape.back(); ++i) { 458d581c94dSMatthias Springer indices.back() = constants[i]; 459d581c94dSMatthias Springer rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices); 460d581c94dSMatthias Springer ++elementIt; 461d581c94dSMatthias Springer } 462d581c94dSMatthias Springer return; 463d581c94dSMatthias Springer } 464d581c94dSMatthias Springer for (int i = 0; i < shape[dim]; ++i) { 465d581c94dSMatthias Springer indices[dim] = constants[i]; 466d581c94dSMatthias Springer createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt, 467d581c94dSMatthias Springer indices); 468d581c94dSMatthias Springer } 469d581c94dSMatthias Springer } 470d581c94dSMatthias Springer 471d581c94dSMatthias Springer /// Bufferization of tensor.from_elements. 472d581c94dSMatthias Springer struct FromElementsOpInterface 473d581c94dSMatthias Springer : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface, 474d581c94dSMatthias Springer tensor::FromElementsOp> { 475664ffa46SMatthias Springer 476a02ad6c1SMatthias Springer bool bufferizesToAllocation(Operation *op, Value value) const { return true; } 477664ffa46SMatthias Springer 478d581c94dSMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 479b55d55ecSMatthias Springer const BufferizationOptions &options) const { 480d581c94dSMatthias Springer auto fromElementsOp = cast<tensor::FromElementsOp>(op); 481067d2779Sian Bearman auto tensorType = cast<RankedTensorType>(fromElementsOp.getType()); 482d581c94dSMatthias Springer 483d581c94dSMatthias Springer // Allocate a buffer for the result. 484d581c94dSMatthias Springer Location loc = op->getLoc(); 485d581c94dSMatthias Springer auto shape = tensorType.getShape(); 486b3ebe3beSMatthias Springer // TODO: Create alloc_tensor ops during TensorCopyInsertion. 4876bf043e7SMartin Erhart FailureOr<Value> tensorAlloc = allocateTensorForShapedValue( 4886bf043e7SMartin Erhart rewriter, loc, fromElementsOp.getResult(), options, 489b3ebe3beSMatthias Springer /*copy=*/false); 49045b995cdSMatthias Springer if (failed(tensorAlloc)) 49145b995cdSMatthias Springer return failure(); 492ced2fc78SChristopher Bate FailureOr<BaseMemRefType> memrefType = 493ced2fc78SChristopher Bate bufferization::getBufferType(*tensorAlloc, options); 494ced2fc78SChristopher Bate if (failed(memrefType)) 495ced2fc78SChristopher Bate return failure(); 496b3ebe3beSMatthias Springer Value buffer = rewriter.create<bufferization::ToMemrefOp>( 497ced2fc78SChristopher Bate op->getLoc(), *memrefType, *tensorAlloc); 498d581c94dSMatthias Springer 499d581c94dSMatthias Springer // Case: tensor<0xelem_type>. 5008df54a6aSJacques Pienaar if (fromElementsOp.getElements().empty()) { 501d581c94dSMatthias Springer replaceOpWithBufferizedValues(rewriter, op, buffer); 502d581c94dSMatthias Springer return success(); 503d581c94dSMatthias Springer } 504d581c94dSMatthias Springer 505d581c94dSMatthias Springer // Case: tensor<elem_type>. 506d581c94dSMatthias Springer if (shape.empty()) { 5078df54a6aSJacques Pienaar rewriter.create<memref::StoreOp>( 5088df54a6aSJacques Pienaar loc, fromElementsOp.getElements().front(), buffer); 509d581c94dSMatthias Springer replaceOpWithBufferizedValues(rewriter, op, buffer); 510d581c94dSMatthias Springer return success(); 511d581c94dSMatthias Springer } 512d581c94dSMatthias Springer 513d581c94dSMatthias Springer // Create constants for the range of possible indices [0, max{shape_i}). 514fab2bb8bSJustin Lebar auto maxDim = *llvm::max_element(shape); 515d581c94dSMatthias Springer SmallVector<Value, 2> constants; 516d581c94dSMatthias Springer constants.reserve(maxDim); 517d581c94dSMatthias Springer for (int i = 0; i < maxDim; ++i) 518d581c94dSMatthias Springer constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i)); 519d581c94dSMatthias Springer 520d581c94dSMatthias Springer // Traverse all `elements` and create `memref.store` ops. 5218df54a6aSJacques Pienaar auto elementIt = fromElementsOp.getElements().begin(); 522d581c94dSMatthias Springer SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]); 523d581c94dSMatthias Springer createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt, 524d581c94dSMatthias Springer indices); 525d581c94dSMatthias Springer 526d581c94dSMatthias Springer replaceOpWithBufferizedValues(rewriter, op, buffer); 527664ffa46SMatthias Springer 528d581c94dSMatthias Springer return success(); 529d581c94dSMatthias Springer } 530d581c94dSMatthias Springer }; 531d581c94dSMatthias Springer 532c1f0a15cSMatthias Springer /// Lower the body of a tensor.generate like op (one index-typed bbArg per dim). 533c1f0a15cSMatthias Springer /// Such ops are lowered to linalg.map with the given tensor as a destination. 534c1f0a15cSMatthias Springer /// 535c1f0a15cSMatthias Springer /// Example: 536c1f0a15cSMatthias Springer /// ``` 537c1f0a15cSMatthias Springer /// %r = tensor.generate %x, %y { 538c1f0a15cSMatthias Springer /// ^bb0(%arg0: index, %arg1: index): 539c1f0a15cSMatthias Springer /// %0 = "some_op"(%arg0, %arg1) : (index, index) -> (index) 540c1f0a15cSMatthias Springer /// tensor.yield %0 : index 541c1f0a15cSMatthias Springer /// } : tensor<?x?xindex> 542c1f0a15cSMatthias Springer /// ``` 543c1f0a15cSMatthias Springer /// 544c1f0a15cSMatthias Springer /// Is lowered to: 545c1f0a15cSMatthias Springer /// ``` 546c1f0a15cSMatthias Springer /// linalg.map ins() outs(%dest) { 547c1f0a15cSMatthias Springer /// %d0 = linalg.index 0 : index 548c1f0a15cSMatthias Springer /// %d1 = linalg.index 1 : index 549c1f0a15cSMatthias Springer /// %0 = "some_op"(%d0, %d1) : (index, index) -> (index) 550c1f0a15cSMatthias Springer /// linalg.yield %0 : index 551c1f0a15cSMatthias Springer /// } 552c1f0a15cSMatthias Springer /// ``` 553c1f0a15cSMatthias Springer static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc, 554c1f0a15cSMatthias Springer Value tensorDestination, 555c1f0a15cSMatthias Springer ValueRange dynamicSizes, 556c1f0a15cSMatthias Springer Region &generateBody) { 557c1f0a15cSMatthias Springer assert(generateBody.hasOneBlock() && "expected body with single block"); 5585550c821STres Popp auto tensorType = cast<RankedTensorType>(tensorDestination.getType()); 559c1f0a15cSMatthias Springer assert(generateBody.getNumArguments() == tensorType.getRank() && 560c1f0a15cSMatthias Springer "rank mismatch"); 561c1f0a15cSMatthias Springer 562c1f0a15cSMatthias Springer // Create linalg::MapOp. 563c1f0a15cSMatthias Springer OpBuilder::InsertionGuard g(rewriter); 564c1f0a15cSMatthias Springer auto linalgOp = 565c1f0a15cSMatthias Springer rewriter.create<linalg::MapOp>(loc, tensorType, /*inputs=*/ValueRange(), 566c1f0a15cSMatthias Springer /*init=*/tensorDestination); 567c1f0a15cSMatthias Springer Block &linalgBody = linalgOp.getMapper().emplaceBlock(); 568c1f0a15cSMatthias Springer 569c1f0a15cSMatthias Springer // Create linalg::IndexOps. 570c1f0a15cSMatthias Springer rewriter.setInsertionPointToStart(&linalgBody); 571c1f0a15cSMatthias Springer SmallVector<Value> indices; 572c1f0a15cSMatthias Springer for (int64_t dim = 0; dim < tensorType.getRank(); ++dim) 573c1f0a15cSMatthias Springer indices.push_back(rewriter.create<linalg::IndexOp>(loc, dim)); 574c1f0a15cSMatthias Springer 575c1f0a15cSMatthias Springer // Move over body. 576c1f0a15cSMatthias Springer rewriter.mergeBlocks(&generateBody.front(), &linalgBody, indices); 577c1f0a15cSMatthias Springer auto yieldOp = cast<tensor::YieldOp>(linalgBody.getTerminator()); 578c1f0a15cSMatthias Springer rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue()); 579c1f0a15cSMatthias Springer 580c1f0a15cSMatthias Springer return linalgOp.getResult()[0]; 581c1f0a15cSMatthias Springer } 582c1f0a15cSMatthias Springer 58371bbb78bSMatthias Springer /// Bufferization of tensor.generate. 58471bbb78bSMatthias Springer struct GenerateOpInterface 58571bbb78bSMatthias Springer : public BufferizableOpInterface::ExternalModel<GenerateOpInterface, 58671bbb78bSMatthias Springer tensor::GenerateOp> { 587664ffa46SMatthias Springer 588a02ad6c1SMatthias Springer bool bufferizesToAllocation(Operation *op, Value value) const { return true; } 589664ffa46SMatthias Springer 59071bbb78bSMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 591b55d55ecSMatthias Springer const BufferizationOptions &options) const { 59271bbb78bSMatthias Springer auto generateOp = cast<tensor::GenerateOp>(op); 593c0b0b6a0SMatthias Springer 594067d2779Sian Bearman auto type = generateOp.getResult().getType(); 595067d2779Sian Bearman 596c0b0b6a0SMatthias Springer // TODO: Implement memory space for this op. 597067d2779Sian Bearman if (options.defaultMemorySpaceFn(type) != Attribute()) 598c0b0b6a0SMatthias Springer return op->emitError("memory space not implemented yet"); 599c0b0b6a0SMatthias Springer 60071bbb78bSMatthias Springer // Allocate memory. 60171bbb78bSMatthias Springer Location loc = op->getLoc(); 6026bf043e7SMartin Erhart FailureOr<Value> tensorAlloc = allocateTensorForShapedValue( 6036bf043e7SMartin Erhart rewriter, loc, generateOp.getResult(), options, 604b3ebe3beSMatthias Springer /*copy=*/false); 60545b995cdSMatthias Springer if (failed(tensorAlloc)) 60645b995cdSMatthias Springer return failure(); 60771bbb78bSMatthias Springer 608c1f0a15cSMatthias Springer Value result = lowerGenerateLikeOpBody(rewriter, loc, *tensorAlloc, 609c1f0a15cSMatthias Springer generateOp.getDynamicExtents(), 610c1f0a15cSMatthias Springer generateOp.getBody()); 611c1f0a15cSMatthias Springer rewriter.replaceOp(generateOp, result); 612664ffa46SMatthias Springer 61371bbb78bSMatthias Springer return success(); 61471bbb78bSMatthias Springer } 61571bbb78bSMatthias Springer }; 61671bbb78bSMatthias Springer 61749e37000SMatthias Springer /// Bufferization of tensor.insert. Replace with memref.store. 6182d5edc64SMatthias Springer /// 6192d5edc64SMatthias Springer /// Note: DstBufferizableOpInterfaceExternalModel provides many default method 6202d5edc64SMatthias Springer /// implementations for DestinationStyle ops. 62149e37000SMatthias Springer struct InsertOpInterface 6222d5edc64SMatthias Springer : public DstBufferizableOpInterfaceExternalModel<InsertOpInterface, 62349e37000SMatthias Springer tensor::InsertOp> { 62449e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 625b55d55ecSMatthias Springer const BufferizationOptions &options) const { 62649e37000SMatthias Springer auto insertOp = cast<tensor::InsertOp>(op); 6275d50f51cSMatthias Springer FailureOr<Value> destMemref = 6285d50f51cSMatthias Springer getBuffer(rewriter, insertOp.getDest(), options); 6295d50f51cSMatthias Springer if (failed(destMemref)) 6305d50f51cSMatthias Springer return failure(); 6318df54a6aSJacques Pienaar rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(), 6325d50f51cSMatthias Springer *destMemref, insertOp.getIndices()); 6335d50f51cSMatthias Springer replaceOpWithBufferizedValues(rewriter, op, *destMemref); 63449e37000SMatthias Springer return success(); 63549e37000SMatthias Springer } 63649e37000SMatthias Springer }; 63749e37000SMatthias Springer 63898e838a8SMax191 template <typename InsertOpTy> 63998e838a8SMax191 static bool insertSliceOpRequiresRead(InsertOpTy insertSliceOp, 64098e838a8SMax191 OpOperand &opOperand) { 64113593dc9SMatthias Springer // The source is always read. 64255585043SMatthias Springer if (opOperand == insertSliceOp.getSourceMutable()) 64313593dc9SMatthias Springer return true; 64413593dc9SMatthias Springer 64513593dc9SMatthias Springer // For the destination, it depends... 64655585043SMatthias Springer assert(opOperand == insertSliceOp.getDestMutable() && "expected dest"); 64713593dc9SMatthias Springer 64813593dc9SMatthias Springer // Dest is not read if it is entirely overwritten. E.g.: 64913593dc9SMatthias Springer // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32> 65013593dc9SMatthias Springer bool allOffsetsZero = 65198e838a8SMax191 llvm::all_of(insertSliceOp.getMixedOffsets(), isZeroIndex); 65298e838a8SMax191 RankedTensorType destType = insertSliceOp.getDestType(); 65398e838a8SMax191 bool sizesMatchDestSizes = 65498e838a8SMax191 areConstantIntValues(insertSliceOp.getMixedSizes(), destType.getShape()); 65513593dc9SMatthias Springer bool allStridesOne = 65698e838a8SMax191 areAllConstantIntValue(insertSliceOp.getMixedStrides(), 1); 65713593dc9SMatthias Springer return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne); 65813593dc9SMatthias Springer } 65913593dc9SMatthias Springer 66098e838a8SMax191 /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under 66198e838a8SMax191 /// certain circumstances, this op can also be a no-op. 66298e838a8SMax191 /// 66398e838a8SMax191 /// Note: DstBufferizableOpInterfaceExternalModel provides many default method 66498e838a8SMax191 /// implementations for DestinationStyle ops. 66598e838a8SMax191 struct InsertSliceOpInterface 66698e838a8SMax191 : public DstBufferizableOpInterfaceExternalModel<InsertSliceOpInterface, 66798e838a8SMax191 tensor::InsertSliceOp> { 66898e838a8SMax191 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 66998e838a8SMax191 const AnalysisState &state) const { 67098e838a8SMax191 return insertSliceOpRequiresRead(cast<tensor::InsertSliceOp>(op), 67198e838a8SMax191 opOperand); 67298e838a8SMax191 } 67398e838a8SMax191 67449e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 675b55d55ecSMatthias Springer const BufferizationOptions &options) const { 67649e37000SMatthias Springer // insert_slice ops arise from tiling and bufferizing them out-of-place is 67749e37000SMatthias Springer // generally a deal breaker. When used with loops, this ends up cloning the 67849e37000SMatthias Springer // whole tensor on every single iteration and is a symptom of a 67949e37000SMatthias Springer // catastrophically bad scheduling decision. 68049e37000SMatthias Springer // TODO: be very loud about it or even consider failing the pass. 68149e37000SMatthias Springer auto insertSliceOp = cast<tensor::InsertSliceOp>(op); 6826c3c5f80SMatthias Springer SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets(); 6836c3c5f80SMatthias Springer SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes(); 6846c3c5f80SMatthias Springer SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides(); 68549e37000SMatthias Springer Location loc = insertSliceOp.getLoc(); 6866c3c5f80SMatthias Springer 6876c3c5f80SMatthias Springer // Get destination buffer. 6885d50f51cSMatthias Springer FailureOr<Value> dstMemref = 6895d50f51cSMatthias Springer getBuffer(rewriter, insertSliceOp.getDest(), options); 6905d50f51cSMatthias Springer if (failed(dstMemref)) 6915d50f51cSMatthias Springer return failure(); 69249e37000SMatthias Springer 6936c3c5f80SMatthias Springer // Take a subview of the destination buffer. 6945550c821STres Popp auto dstMemrefType = cast<MemRefType>(dstMemref->getType()); 69549e37000SMatthias Springer auto subviewMemRefType = 6965550c821STres Popp cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType( 6976c3c5f80SMatthias Springer insertSliceOp.getSourceType().getShape(), dstMemrefType, 6985550c821STres Popp mixedOffsets, mixedSizes, mixedStrides)); 69949e37000SMatthias Springer Value subView = rewriter.create<memref::SubViewOp>( 7005d50f51cSMatthias Springer loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes, 70149e37000SMatthias Springer mixedStrides); 70249e37000SMatthias Springer 70349e37000SMatthias Springer // Copy tensor. If this tensor.insert_slice has a matching 70449e37000SMatthias Springer // tensor.extract_slice, the copy operation will eventually fold away. 7055d50f51cSMatthias Springer FailureOr<Value> srcMemref = 7065d50f51cSMatthias Springer getBuffer(rewriter, insertSliceOp.getSource(), options); 7075d50f51cSMatthias Springer if (failed(srcMemref)) 7085d50f51cSMatthias Springer return failure(); 7095d50f51cSMatthias Springer if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView))) 71049e37000SMatthias Springer return failure(); 71149e37000SMatthias Springer 7125d50f51cSMatthias Springer replaceOpWithBufferizedValues(rewriter, op, *dstMemref); 71349e37000SMatthias Springer return success(); 71449e37000SMatthias Springer } 71549e37000SMatthias Springer }; 71649e37000SMatthias Springer 71709dfb441SMatthias Springer /// Bufferization of tensor.pad. Replace with bufferization.alloc_tensor + 71809dfb441SMatthias Springer /// linalg.map + insert_slice. 7199ee12f47SMatthias Springer /// For best performance, vectorize before bufferization (better performance in 7209ee12f47SMatthias Springer /// case of padding with a constant). 7219ee12f47SMatthias Springer struct PadOpInterface 7229ee12f47SMatthias Springer : public BufferizableOpInterface::ExternalModel<PadOpInterface, 7239ee12f47SMatthias Springer tensor::PadOp> { 724a02ad6c1SMatthias Springer bool bufferizesToAllocation(Operation *op, Value value) const { return true; } 7259ee12f47SMatthias Springer 7269ee12f47SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 7279ee12f47SMatthias Springer const AnalysisState &state) const { 7289ee12f47SMatthias Springer return true; 7299ee12f47SMatthias Springer } 7309ee12f47SMatthias Springer 7319ee12f47SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 7329ee12f47SMatthias Springer const AnalysisState &state) const { 7339ee12f47SMatthias Springer return false; 7349ee12f47SMatthias Springer } 7359ee12f47SMatthias Springer 736a02ad6c1SMatthias Springer AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, 7379ee12f47SMatthias Springer const AnalysisState &state) const { 7389ee12f47SMatthias Springer return {}; 7399ee12f47SMatthias Springer } 7409ee12f47SMatthias Springer 74109dfb441SMatthias Springer FailureOr<BaseMemRefType> 74209dfb441SMatthias Springer getBufferType(Operation *op, Value value, const BufferizationOptions &options, 743878950b8SMatthias Springer SmallVector<Value> &invocationStack) const { 74409dfb441SMatthias Springer // Infer memory space from the source tensor. 74509dfb441SMatthias Springer auto padOp = cast<tensor::PadOp>(op); 746878950b8SMatthias Springer auto maybeSrcBufferType = bufferization::getBufferType( 747878950b8SMatthias Springer padOp.getSource(), options, invocationStack); 74809dfb441SMatthias Springer if (failed(maybeSrcBufferType)) 74909dfb441SMatthias Springer return failure(); 75009dfb441SMatthias Springer MemRefLayoutAttrInterface layout; 75109dfb441SMatthias Springer return MemRefType::get(padOp.getResultType().getShape(), 75209dfb441SMatthias Springer padOp.getResultType().getElementType(), layout, 75309dfb441SMatthias Springer maybeSrcBufferType->getMemorySpace()); 75409dfb441SMatthias Springer } 75509dfb441SMatthias Springer 7569ee12f47SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 7579ee12f47SMatthias Springer const BufferizationOptions &options) const { 7589ee12f47SMatthias Springer auto padOp = cast<tensor::PadOp>(op); 7599ee12f47SMatthias Springer Location loc = padOp.getLoc(); 7609ee12f47SMatthias Springer RankedTensorType resultType = padOp.getResultType(); 7619ee12f47SMatthias Springer RankedTensorType srcType = padOp.getSourceType(); 7629ee12f47SMatthias Springer 7639ee12f47SMatthias Springer auto toValue = [&](OpFoldResult ofr) { 764fecf1397SKazu Hirata if (auto value = dyn_cast<Value>(ofr)) 765fecf1397SKazu Hirata return value; 7669ee12f47SMatthias Springer return rewriter 7679ee12f47SMatthias Springer .create<arith::ConstantIndexOp>(loc, *getConstantIntValue(ofr)) 7689ee12f47SMatthias Springer .getResult(); 7699ee12f47SMatthias Springer }; 7709ee12f47SMatthias Springer 7719ee12f47SMatthias Springer // Compute dynamic result dimensions. 7729ee12f47SMatthias Springer SmallVector<OpFoldResult> mixedLowPad = padOp.getMixedLowPad(); 7739ee12f47SMatthias Springer SmallVector<OpFoldResult> mixedHighPad = padOp.getMixedHighPad(); 7749ee12f47SMatthias Springer SmallVector<Value> dynamicSizes; 7759ee12f47SMatthias Springer for (int64_t i = 0; i < resultType.getRank(); ++i) { 7769ee12f47SMatthias Springer if (!resultType.isDynamicDim(i)) 7779ee12f47SMatthias Springer continue; 7789ee12f47SMatthias Springer Value srcDim = rewriter.create<tensor::DimOp>(loc, padOp.getSource(), i); 7799ee12f47SMatthias Springer Value lowPad = toValue(mixedLowPad[i]); 7809ee12f47SMatthias Springer Value highPad = toValue(mixedHighPad[i]); 781c37ed776SMatthias Springer AffineExpr s0, s1, s2; 782c37ed776SMatthias Springer bindSymbols(op->getContext(), s0, s1, s2); 783c37ed776SMatthias Springer AffineExpr sumExpr = s0 + s1 + s2; 7844c48f016SMatthias Springer Value sum = rewriter.create<affine::AffineApplyOp>( 785c37ed776SMatthias Springer loc, sumExpr, ValueRange{srcDim, lowPad, highPad}); 786c37ed776SMatthias Springer dynamicSizes.push_back(sum); 7879ee12f47SMatthias Springer } 7889ee12f47SMatthias Springer 78909dfb441SMatthias Springer // Allocate a buffer for the padded result. 79009dfb441SMatthias Springer FailureOr<Value> tensorAlloc = 7916bf043e7SMartin Erhart allocateTensorForShapedValue(rewriter, loc, padOp.getResult(), options, 79209dfb441SMatthias Springer /*copy=*/false); 79309dfb441SMatthias Springer if (failed(tensorAlloc)) 79409dfb441SMatthias Springer return failure(); 79509dfb441SMatthias Springer 79609dfb441SMatthias Springer // tensor::PadOp is like tensor::GenerateOp: The only difference is that 79709dfb441SMatthias Springer // only a part of the generated tensor is needed. For simplicity, we reuse 79809dfb441SMatthias Springer // the same functionality here. 79909dfb441SMatthias Springer Value filledBuffer = lowerGenerateLikeOpBody( 80009dfb441SMatthias Springer rewriter, loc, *tensorAlloc, dynamicSizes, padOp.getBodyRegion()); 8019ee12f47SMatthias Springer 8029ee12f47SMatthias Springer // Create tensor::InsertSliceOp. 803ba95bf76SMatthias Springer SmallVector<OpFoldResult> sliceSizes = 804ba95bf76SMatthias Springer getMixedSizes(rewriter, loc, padOp.getSource()); 805ba95bf76SMatthias Springer SmallVector<OpFoldResult> sliceStrides(srcType.getRank(), 806ba95bf76SMatthias Springer rewriter.getIndexAttr(1)); 8079ee12f47SMatthias Springer rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>( 80809dfb441SMatthias Springer padOp, padOp.getSource(), filledBuffer, 8099ee12f47SMatthias Springer /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides); 8109ee12f47SMatthias Springer 8119ee12f47SMatthias Springer return success(); 8129ee12f47SMatthias Springer } 8139ee12f47SMatthias Springer }; 8149ee12f47SMatthias Springer 815fc08d1c2SMatthias Springer /// Bufferization of tensor.rank. Replace with memref.rank. 816fc08d1c2SMatthias Springer struct RankOpInterface 817fc08d1c2SMatthias Springer : public BufferizableOpInterface::ExternalModel<RankOpInterface, 818fc08d1c2SMatthias Springer tensor::RankOp> { 819fc08d1c2SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 8209597b16aSMatthias Springer const AnalysisState &state) const { 821e5dc99e6SMatthias Springer // The op reads the tensor's metadata but not its contents. 822e5dc99e6SMatthias Springer return false; 823fc08d1c2SMatthias Springer } 824fc08d1c2SMatthias Springer 825fc08d1c2SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 8269597b16aSMatthias Springer const AnalysisState &state) const { 827fc08d1c2SMatthias Springer return false; 828fc08d1c2SMatthias Springer } 829fc08d1c2SMatthias Springer 830a02ad6c1SMatthias Springer AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, 8319597b16aSMatthias Springer const AnalysisState &state) const { 832585a8a32SMatthias Springer return {}; 833fc08d1c2SMatthias Springer } 834fc08d1c2SMatthias Springer 835fc08d1c2SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 836b55d55ecSMatthias Springer const BufferizationOptions &options) const { 837fc08d1c2SMatthias Springer auto rankOp = cast<tensor::RankOp>(op); 8385d50f51cSMatthias Springer FailureOr<Value> v = getBuffer(rewriter, rankOp.getTensor(), options); 8395d50f51cSMatthias Springer if (failed(v)) 8405d50f51cSMatthias Springer return failure(); 841fc08d1c2SMatthias Springer replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(), 8425d50f51cSMatthias Springer *v); 843fc08d1c2SMatthias Springer return success(); 844fc08d1c2SMatthias Springer } 845fc08d1c2SMatthias Springer }; 846fc08d1c2SMatthias Springer 847e287d647SAshay Rane /// Bufferization of tensor.reshape. Replace with memref.reshape. 848e287d647SAshay Rane struct ReshapeOpInterface 849e287d647SAshay Rane : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface, 850e287d647SAshay Rane tensor::ReshapeOp> { 851e287d647SAshay Rane bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 852e287d647SAshay Rane const AnalysisState &state) const { 853ea71d2d0SMatthias Springer // Depending on the layout map, the source buffer may have to be copied. 8540f952cfeSMatthias Springer auto reshapeOp = cast<tensor::ReshapeOp>(op); 85555585043SMatthias Springer return opOperand == reshapeOp.getShapeMutable(); 856e287d647SAshay Rane } 857e287d647SAshay Rane 858e287d647SAshay Rane bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 859e287d647SAshay Rane const AnalysisState &state) const { 860e287d647SAshay Rane return false; 861e287d647SAshay Rane } 862e287d647SAshay Rane 863a02ad6c1SMatthias Springer AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, 864e287d647SAshay Rane const AnalysisState &state) const { 8659fa6b350SMatthias Springer return {{op->getOpResult(0), BufferRelation::Equivalent}}; 866e287d647SAshay Rane } 867e287d647SAshay Rane 868e287d647SAshay Rane LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 869b55d55ecSMatthias Springer const BufferizationOptions &options) const { 870e287d647SAshay Rane auto reshapeOp = cast<tensor::ReshapeOp>(op); 8715d50f51cSMatthias Springer FailureOr<Value> srcBuffer = 8725d50f51cSMatthias Springer getBuffer(rewriter, reshapeOp.getSource(), options); 8735d50f51cSMatthias Springer FailureOr<Value> shapeBuffer = 8745d50f51cSMatthias Springer getBuffer(rewriter, reshapeOp.getShape(), options); 8755d50f51cSMatthias Springer if (failed(srcBuffer) || failed(shapeBuffer)) 8765d50f51cSMatthias Springer return failure(); 8779dbb8eefSIngo Müller auto maybeResultMemRefType = 8789dbb8eefSIngo Müller bufferization::getBufferType(reshapeOp.getResult(), options); 8799dbb8eefSIngo Müller if (failed(maybeResultMemRefType)) 8809dbb8eefSIngo Müller return failure(); 8810a0c7e89SSpenser Bauman 8820a0c7e89SSpenser Bauman // memref.reshape requires the source buffer to have an identity layout. 883ea71d2d0SMatthias Springer // If the source memref does not have an identity layout, copy the source 8840a0c7e89SSpenser Bauman // into a new buffer with an identity layout. 8850a0c7e89SSpenser Bauman auto srcType = llvm::dyn_cast<MemRefType>(srcBuffer->getType()); 8860a0c7e89SSpenser Bauman if (srcType && !srcType.getLayout().isIdentity()) { 887ea71d2d0SMatthias Springer FailureOr<Value> tensorAlloc = allocateTensorForShapedValue( 888ea71d2d0SMatthias Springer rewriter, op->getLoc(), reshapeOp.getSource(), options); 889ea71d2d0SMatthias Springer if (failed(tensorAlloc)) 890ea71d2d0SMatthias Springer return failure(); 891ea71d2d0SMatthias Springer auto memrefType = MemRefType::get( 892ea71d2d0SMatthias Springer srcType.getShape(), srcType.getElementType(), AffineMap(), 893ea71d2d0SMatthias Springer cast<BaseMemRefType>(srcBuffer->getType()).getMemorySpace()); 8940a0c7e89SSpenser Bauman srcBuffer = rewriter 895ea71d2d0SMatthias Springer .create<bufferization::ToMemrefOp>( 896ea71d2d0SMatthias Springer op->getLoc(), memrefType, *tensorAlloc) 8970a0c7e89SSpenser Bauman .getResult(); 8980a0c7e89SSpenser Bauman } 8990a0c7e89SSpenser Bauman 900e287d647SAshay Rane replaceOpWithNewBufferizedOp<memref::ReshapeOp>( 9019dbb8eefSIngo Müller rewriter, op, maybeResultMemRefType.value(), *srcBuffer, *shapeBuffer); 902e287d647SAshay Rane return success(); 903e287d647SAshay Rane } 9049dbb8eefSIngo Müller 9059dbb8eefSIngo Müller FailureOr<BaseMemRefType> 9069dbb8eefSIngo Müller getBufferType(Operation *op, Value value, const BufferizationOptions &options, 907878950b8SMatthias Springer SmallVector<Value> &invocationStack) const { 9089dbb8eefSIngo Müller auto reshapeOp = cast<tensor::ReshapeOp>(op); 9099dbb8eefSIngo Müller assert(value == reshapeOp.getResult() && "unexpected value provided"); 9109dbb8eefSIngo Müller auto maybeSourceBufferType = bufferization::getBufferType( 911878950b8SMatthias Springer reshapeOp.getSource(), options, invocationStack); 9129dbb8eefSIngo Müller if (failed(maybeSourceBufferType)) 9139dbb8eefSIngo Müller return failure(); 9149dbb8eefSIngo Müller return getMemRefTypeWithStaticIdentityLayout( 9159dbb8eefSIngo Müller reshapeOp.getResult().getType(), 9169dbb8eefSIngo Müller cast<BaseMemRefType>(maybeSourceBufferType.value()).getMemorySpace()); 9179dbb8eefSIngo Müller } 918e287d647SAshay Rane }; 919e287d647SAshay Rane 9207fbf55c9SNicolas Vasilache /// Analysis of ParallelInsertSliceOp. 9217fbf55c9SNicolas Vasilache struct ParallelInsertSliceOpInterface 9227fbf55c9SNicolas Vasilache : public BufferizableOpInterface::ExternalModel< 9237fbf55c9SNicolas Vasilache ParallelInsertSliceOpInterface, ParallelInsertSliceOp> { 924a02ad6c1SMatthias Springer AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, 9257fbf55c9SNicolas Vasilache const AnalysisState &state) const { 9267fbf55c9SNicolas Vasilache return {}; 9277fbf55c9SNicolas Vasilache } 9287fbf55c9SNicolas Vasilache 9297fbf55c9SNicolas Vasilache bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 9307fbf55c9SNicolas Vasilache const AnalysisState &state) const { 93198e838a8SMax191 return insertSliceOpRequiresRead(cast<tensor::ParallelInsertSliceOp>(op), 93298e838a8SMax191 opOperand); 9337fbf55c9SNicolas Vasilache } 9347fbf55c9SNicolas Vasilache 9357fbf55c9SNicolas Vasilache bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 9367fbf55c9SNicolas Vasilache const AnalysisState &state) const { 9370f952cfeSMatthias Springer auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op); 93855585043SMatthias Springer return opOperand == parallelInsertSliceOp.getDestMutable(); 9397fbf55c9SNicolas Vasilache } 9407fbf55c9SNicolas Vasilache 9417fbf55c9SNicolas Vasilache LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 9427fbf55c9SNicolas Vasilache const BufferizationOptions &options) const { 9437fbf55c9SNicolas Vasilache OpBuilder::InsertionGuard g(rewriter); 9447fbf55c9SNicolas Vasilache auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op); 9457fbf55c9SNicolas Vasilache ParallelCombiningOpInterface parallelCombiningParent = 9467fbf55c9SNicolas Vasilache parallelInsertSliceOp.getParallelCombiningParent(); 9477fbf55c9SNicolas Vasilache 9484cd73620SMatthias Springer // Bufferize the op outside of the parallel combining terminator. 9494cd73620SMatthias Springer rewriter.setInsertionPoint(parallelCombiningParent); 9504cd73620SMatthias Springer 9514cd73620SMatthias Springer // Get source and destination buffers. 9527fbf55c9SNicolas Vasilache FailureOr<Value> destBuffer = 9537fbf55c9SNicolas Vasilache getBuffer(rewriter, parallelInsertSliceOp.getDest(), options); 9547fbf55c9SNicolas Vasilache if (failed(destBuffer)) 9557fbf55c9SNicolas Vasilache return failure(); 9567fbf55c9SNicolas Vasilache FailureOr<Value> srcBuffer = 9577fbf55c9SNicolas Vasilache getBuffer(rewriter, parallelInsertSliceOp.getSource(), options); 9587fbf55c9SNicolas Vasilache if (failed(srcBuffer)) 9597fbf55c9SNicolas Vasilache return failure(); 9606c3c5f80SMatthias Springer 9616c3c5f80SMatthias Springer // Take a subview of the destination buffer. 9625550c821STres Popp auto destBufferType = cast<MemRefType>(destBuffer->getType()); 9636c3c5f80SMatthias Springer auto subviewMemRefType = 9645550c821STres Popp cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType( 9656c3c5f80SMatthias Springer parallelInsertSliceOp.getSourceType().getShape(), destBufferType, 9666c3c5f80SMatthias Springer parallelInsertSliceOp.getMixedOffsets(), 9676c3c5f80SMatthias Springer parallelInsertSliceOp.getMixedSizes(), 9685550c821STres Popp parallelInsertSliceOp.getMixedStrides())); 9697fbf55c9SNicolas Vasilache Value subview = rewriter.create<memref::SubViewOp>( 9706c3c5f80SMatthias Springer parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer, 9717fbf55c9SNicolas Vasilache parallelInsertSliceOp.getMixedOffsets(), 9727fbf55c9SNicolas Vasilache parallelInsertSliceOp.getMixedSizes(), 9737fbf55c9SNicolas Vasilache parallelInsertSliceOp.getMixedStrides()); 9746c3c5f80SMatthias Springer 9757fbf55c9SNicolas Vasilache // This memcpy will fold away if everything bufferizes in-place. 9767fbf55c9SNicolas Vasilache if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(), 9777fbf55c9SNicolas Vasilache *srcBuffer, subview))) 9787fbf55c9SNicolas Vasilache return failure(); 9797fbf55c9SNicolas Vasilache 9807c06f631SMatthias Springer // In case the source was allocated in the same block, make sure that the 9817c06f631SMatthias Springer // deallocation op (if any) appears after the memcpy. By default, deallocs 9827c06f631SMatthias Springer // are placed before the terminator, but this does not work for ForallOp 9837c06f631SMatthias Springer // because the terminator does more than just yielding a value. 9847c06f631SMatthias Springer // 9857c06f631SMatthias Springer // Note: This is not a problem for the destination buffer because these are 9867c06f631SMatthias Springer // assumed to always bufferize in-place. 9877c06f631SMatthias Springer for (Operation *user : srcBuffer->getUsers()) { 9887c06f631SMatthias Springer if (hasEffect<MemoryEffects::Free>(user)) { 9897c06f631SMatthias Springer if (user->getBlock() == parallelCombiningParent->getBlock()) 9905cc0f76dSMatthias Springer rewriter.moveOpBefore(user, user->getBlock()->getTerminator()); 9917c06f631SMatthias Springer break; 9927c06f631SMatthias Springer } 9937c06f631SMatthias Springer } 9947c06f631SMatthias Springer 9954cd73620SMatthias Springer // Delete the op. 9967fbf55c9SNicolas Vasilache rewriter.eraseOp(op); 9977fbf55c9SNicolas Vasilache return success(); 9987fbf55c9SNicolas Vasilache } 999d69e9491Sdonald chen 1000d69e9491Sdonald chen /// tensor.parallel_insert_slice op has implicit inplace behavior. We 1001d69e9491Sdonald chen /// shouldn't create copy to resolve conflict. 1002d69e9491Sdonald chen LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter, 1003d69e9491Sdonald chen const AnalysisState &state) const { 1004d69e9491Sdonald chen return success(); 1005d69e9491Sdonald chen } 10067fbf55c9SNicolas Vasilache }; 10077fbf55c9SNicolas Vasilache 1008481b254eSMatthias Springer /// Bufferization of tensor.splat. Bufferizes to a new allocation that is filled 1009481b254eSMatthias Springer /// with a linalg.map. Similar to tensor.generate. 1010481b254eSMatthias Springer struct SplatOpInterface 1011481b254eSMatthias Springer : public BufferizableOpInterface::ExternalModel<SplatOpInterface, 1012481b254eSMatthias Springer tensor::SplatOp> { 1013481b254eSMatthias Springer 1014a02ad6c1SMatthias Springer bool bufferizesToAllocation(Operation *op, Value value) const { return true; } 1015481b254eSMatthias Springer 1016481b254eSMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 1017481b254eSMatthias Springer const BufferizationOptions &options) const { 1018481b254eSMatthias Springer OpBuilder::InsertionGuard g(rewriter); 1019481b254eSMatthias Springer auto splatOp = cast<tensor::SplatOp>(op); 1020481b254eSMatthias Springer 1021481b254eSMatthias Springer // Allocate memory. 1022481b254eSMatthias Springer Location loc = op->getLoc(); 10236bf043e7SMartin Erhart FailureOr<Value> tensorAlloc = allocateTensorForShapedValue( 10246bf043e7SMartin Erhart rewriter, loc, splatOp.getResult(), options, 1025481b254eSMatthias Springer /*copy=*/false); 1026481b254eSMatthias Springer if (failed(tensorAlloc)) 1027481b254eSMatthias Springer return failure(); 1028481b254eSMatthias Springer 1029481b254eSMatthias Springer // Create linalg::MapOp. 1030481b254eSMatthias Springer auto tensorType = cast<RankedTensorType>(tensorAlloc->getType()); 1031067d2779Sian Bearman 1032067d2779Sian Bearman // TODO: Implement memory space for this op. 1033067d2779Sian Bearman if (options.defaultMemorySpaceFn(tensorType) != Attribute()) 1034067d2779Sian Bearman return op->emitError("memory space not implemented yet"); 1035067d2779Sian Bearman 1036481b254eSMatthias Springer auto linalgOp = 1037481b254eSMatthias Springer rewriter.create<linalg::MapOp>(loc, tensorType, /*inputs=*/ValueRange(), 1038481b254eSMatthias Springer /*init=*/*tensorAlloc); 1039481b254eSMatthias Springer Block &linalgBody = linalgOp.getMapper().emplaceBlock(); 1040481b254eSMatthias Springer 1041481b254eSMatthias Springer // Create linalg::IndexOps. 1042481b254eSMatthias Springer rewriter.setInsertionPointToStart(&linalgBody); 1043481b254eSMatthias Springer rewriter.create<linalg::YieldOp>(loc, splatOp.getInput()); 1044481b254eSMatthias Springer rewriter.replaceOp(splatOp, linalgOp.getResult()[0]); 1045481b254eSMatthias Springer 1046481b254eSMatthias Springer return success(); 1047481b254eSMatthias Springer } 1048481b254eSMatthias Springer }; 1049481b254eSMatthias Springer 105049e37000SMatthias Springer } // namespace 105149e37000SMatthias Springer } // namespace tensor 105249e37000SMatthias Springer } // namespace mlir 105349e37000SMatthias Springer 105449e37000SMatthias Springer void mlir::tensor::registerBufferizableOpInterfaceExternalModels( 105549e37000SMatthias Springer DialectRegistry ®istry) { 105677eee579SRiver Riddle registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) { 105777eee579SRiver Riddle CastOp::attachInterface<CastOpInterface>(*ctx); 105877eee579SRiver Riddle CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx); 105977eee579SRiver Riddle DimOp::attachInterface<DimOpInterface>(*ctx); 1060be630f07SMatthias Springer EmptyOp::attachInterface<EmptyOpInterface>(*ctx); 106177eee579SRiver Riddle ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx); 106277eee579SRiver Riddle ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx); 106377eee579SRiver Riddle ExtractOp::attachInterface<ExtractOpInterface>(*ctx); 106477eee579SRiver Riddle FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx); 106577eee579SRiver Riddle GenerateOp::attachInterface<GenerateOpInterface>(*ctx); 106677eee579SRiver Riddle InsertOp::attachInterface<InsertOpInterface>(*ctx); 106777eee579SRiver Riddle InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx); 10689ee12f47SMatthias Springer PadOp::attachInterface<PadOpInterface>(*ctx); 10697fbf55c9SNicolas Vasilache ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>( 10707fbf55c9SNicolas Vasilache *ctx); 107177eee579SRiver Riddle RankOp::attachInterface<RankOpInterface>(*ctx); 1072e287d647SAshay Rane ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx); 1073481b254eSMatthias Springer SplatOp::attachInterface<SplatOpInterface>(*ctx); 10745f5f71e7SMatthias Springer 10755f5f71e7SMatthias Springer // Load additional dialects of which ops may get created. 1076c1f0a15cSMatthias Springer ctx->loadDialect<arith::ArithDialect, linalg::LinalgDialect>(); 107777eee579SRiver Riddle }); 10788143307bSMatthias Springer 10798143307bSMatthias Springer // Bufferization requires SubsetInsertionOpInterface models. Make sure that 10808143307bSMatthias Springer // they are registered. 10811abd8d1aSMatthias Springer tensor::registerSubsetOpInterfaceExternalModels(registry); 108249e37000SMatthias Springer } 1083