145ccff17SMatthias Springer //===- ConvertToDestinationStyle.cpp - Convert non-DPS to DPS ops ---------===// 245ccff17SMatthias Springer // 345ccff17SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 445ccff17SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 545ccff17SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 645ccff17SMatthias Springer // 745ccff17SMatthias Springer //===----------------------------------------------------------------------===// 845ccff17SMatthias Springer // 945ccff17SMatthias Springer // This file contains patterns to convert non-DPS ops to DPS ops. New 1045ccff17SMatthias Springer // tensor.empty ops are inserted as a destination. Such tensor.empty can be 1145ccff17SMatthias Springer // eliminated with "empty tensor elimination", allowing them to bufferize 1245ccff17SMatthias Springer // without an allocation (assuming there are no further conflicts). 1345ccff17SMatthias Springer // 1445ccff17SMatthias Springer //===----------------------------------------------------------------------===// 1545ccff17SMatthias Springer // 1645ccff17SMatthias Springer #include "mlir/Dialect/Arith/IR/Arith.h" 172a5b13e7SMatthias Springer #include "mlir/Dialect/Arith/Utils/Utils.h" 1801581e28SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 1901581e28SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 2045ccff17SMatthias Springer #include "mlir/Dialect/Linalg/IR/Linalg.h" 2145ccff17SMatthias Springer #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 2245ccff17SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h" 2396179dffSNicolas Vasilache #include "mlir/Dialect/Utils/StaticValueUtils.h" 247b3c662dSMatthias Springer #include "mlir/IR/Matchers.h" 2545ccff17SMatthias Springer #include "mlir/IR/PatternMatch.h" 2696179dffSNicolas Vasilache #include "llvm/ADT/STLExtras.h" 
2745ccff17SMatthias Springer #include "llvm/Support/Debug.h" 2845ccff17SMatthias Springer 2945ccff17SMatthias Springer using namespace mlir; 3045ccff17SMatthias Springer using namespace mlir::tensor; 3145ccff17SMatthias Springer 32124fce09SMatthias Springer // Implements backtracking to traverse indices of the output buffer while 33124fce09SMatthias Springer // iterating over op.elements(). 34124fce09SMatthias Springer static Value createInserts(RewriterBase &rewriter, Location loc, int dim, 35124fce09SMatthias Springer Value destination, ArrayRef<int64_t> shape, 36124fce09SMatthias Springer ArrayRef<Value> constants, 37124fce09SMatthias Springer OperandRange::iterator &elementIt, 38124fce09SMatthias Springer SmallVectorImpl<Value> &indices) { 39124fce09SMatthias Springer if (dim == static_cast<int>(shape.size()) - 1) { 40124fce09SMatthias Springer for (int i = 0; i < shape.back(); ++i) { 41124fce09SMatthias Springer indices.back() = constants[i]; 42124fce09SMatthias Springer destination = rewriter.create<tensor::InsertOp>(loc, *elementIt, 43124fce09SMatthias Springer destination, indices); 44124fce09SMatthias Springer ++elementIt; 45124fce09SMatthias Springer } 46124fce09SMatthias Springer return destination; 47124fce09SMatthias Springer } 48124fce09SMatthias Springer for (int i = 0; i < shape[dim]; ++i) { 49124fce09SMatthias Springer indices[dim] = constants[i]; 50124fce09SMatthias Springer destination = createInserts(rewriter, loc, dim + 1, destination, shape, 51124fce09SMatthias Springer constants, elementIt, indices); 52124fce09SMatthias Springer } 53124fce09SMatthias Springer return destination; 54124fce09SMatthias Springer } 55124fce09SMatthias Springer 56579bca12SMatthias Springer /// Create a memcpy from the given source tensor to the given destination 57579bca12SMatthias Springer /// memref. The copy op type can be specified in the `options`. 
58579bca12SMatthias Springer static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource, 59579bca12SMatthias Springer Value memrefDest, 60579bca12SMatthias Springer const linalg::BufferizeToAllocationOptions &options) { 61579bca12SMatthias Springer auto tensorType = dyn_cast<RankedTensorType>(tensorSource.getType()); 62579bca12SMatthias Springer assert(tensorType && "expected ranked tensor"); 63a5757c5bSChristian Sigg assert(isa<MemRefType>(memrefDest.getType()) && "expected ranked memref"); 64579bca12SMatthias Springer 65579bca12SMatthias Springer switch (options.memcpyOp) { 66437c6217SMatthias Springer case linalg::BufferizeToAllocationOptions::MemcpyOp:: 67437c6217SMatthias Springer MaterializeInDestination: { 68579bca12SMatthias Springer // Note: This is the preferred way of memcpy'ing because no layout map 69579bca12SMatthias Springer // and/or memory space must be specified for the source. 70437c6217SMatthias Springer auto materializeOp = b.create<bufferization::MaterializeInDestinationOp>( 71437c6217SMatthias Springer loc, tensorSource, memrefDest); 72437c6217SMatthias Springer materializeOp.setWritable(true); 73437c6217SMatthias Springer } break; 74579bca12SMatthias Springer case linalg::BufferizeToAllocationOptions::MemcpyOp::MemrefCopy: { 75579bca12SMatthias Springer // TODO: Support custom memory space on source. 76579bca12SMatthias Springer // We do not know the layout map of the source yet, so use a fully dynamic 77579bca12SMatthias Springer // layout for best compatibility. 
78579bca12SMatthias Springer Value toMemref = b.create<bufferization::ToMemrefOp>( 79579bca12SMatthias Springer loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType), 80579bca12SMatthias Springer tensorSource, /*readOnly=*/true); 81579bca12SMatthias Springer b.create<memref::CopyOp>(loc, toMemref, memrefDest); 82579bca12SMatthias Springer } break; 83579bca12SMatthias Springer case linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy: { 84579bca12SMatthias Springer // TODO: Support custom memory space on source. 85579bca12SMatthias Springer // We do not know the layout map of the source yet, so use a fully dynamic 86579bca12SMatthias Springer // layout for best compatibility. 87579bca12SMatthias Springer Value toMemref = b.create<bufferization::ToMemrefOp>( 88579bca12SMatthias Springer loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType), 89579bca12SMatthias Springer tensorSource, /*readOnly=*/true); 90579bca12SMatthias Springer b.create<linalg::CopyOp>(loc, toMemref, memrefDest); 91579bca12SMatthias Springer } break; 92579bca12SMatthias Springer }; 93579bca12SMatthias Springer } 94579bca12SMatthias Springer 9501581e28SMatthias Springer static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter, 9601581e28SMatthias Springer Location loc, PadOp padOp, 9701581e28SMatthias Springer Value dest) { 9801581e28SMatthias Springer OpBuilder::InsertionGuard g(rewriter); 9901581e28SMatthias Springer RankedTensorType resultType = padOp.getResultType(); 10001581e28SMatthias Springer 10101581e28SMatthias Springer // Examine the yielded value to decide if a linalg.generic is neede or a 10201581e28SMatthias Springer // linalg.fill is sufficient. 
10301581e28SMatthias Springer Value yieldedValue = 10401581e28SMatthias Springer cast<tensor::YieldOp>(padOp.getBody()->getTerminator()).getValue(); 10501581e28SMatthias Springer Attribute constYieldedValue; 10601581e28SMatthias Springer // Is the yielded value a bbArg defined outside of the PadOp? 10701581e28SMatthias Springer bool outsideBbArg = 1085550c821STres Popp isa<BlockArgument>(yieldedValue) && 1095550c821STres Popp cast<BlockArgument>(yieldedValue).getOwner()->getParentOp() != 11001581e28SMatthias Springer padOp.getOperation(); 11101581e28SMatthias Springer // Is the yielded value an OpResult defined outside of the PadOp? 11201581e28SMatthias Springer bool outsideOpResult = 1135550c821STres Popp isa<OpResult>(yieldedValue) && 11401581e28SMatthias Springer yieldedValue.getDefiningOp()->getParentOp() != padOp.getOperation(); 11501581e28SMatthias Springer bool invariantYieldedValue = outsideBbArg || outsideOpResult; 11601581e28SMatthias Springer if (matchPattern(yieldedValue, m_Constant(&constYieldedValue))) { 11701581e28SMatthias Springer // Padding with a constant: Create linalg.fill. 11801581e28SMatthias Springer Dialect *arithDialect = 11901581e28SMatthias Springer rewriter.getContext()->getLoadedDialect<arith::ArithDialect>(); 12001581e28SMatthias Springer Value fillValue = 12101581e28SMatthias Springer arithDialect 12201581e28SMatthias Springer ->materializeConstant(rewriter, constYieldedValue, 12301581e28SMatthias Springer yieldedValue.getType(), yieldedValue.getLoc()) 12401581e28SMatthias Springer ->getResult(0); 12501581e28SMatthias Springer auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange(fillValue), 12601581e28SMatthias Springer ValueRange(dest)); 12701581e28SMatthias Springer return fillOp; 12801581e28SMatthias Springer } 12901581e28SMatthias Springer 13001581e28SMatthias Springer if (invariantYieldedValue) { 13101581e28SMatthias Springer // Padding with an invariant value. 
13201581e28SMatthias Springer auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange(yieldedValue), 13301581e28SMatthias Springer ValueRange(dest)); 13401581e28SMatthias Springer return fillOp; 13501581e28SMatthias Springer } 13601581e28SMatthias Springer 13701581e28SMatthias Springer // Create linalg.generic. 13801581e28SMatthias Springer SmallVector<utils::IteratorType> iteratorTypes(resultType.getRank(), 13901581e28SMatthias Springer utils::IteratorType::parallel); 14001581e28SMatthias Springer SmallVector<AffineMap> indexingMaps( 14101581e28SMatthias Springer 1, rewriter.getMultiDimIdentityMap(resultType.getRank())); 14201581e28SMatthias Springer auto genericOp = rewriter.create<linalg::GenericOp>( 14301581e28SMatthias Springer loc, resultType, /*inputs=*/ValueRange(), 14401581e28SMatthias Springer /*outputs=*/ValueRange{dest}, /*indexingMaps=*/ 14501581e28SMatthias Springer indexingMaps, iteratorTypes); 14601581e28SMatthias Springer Block *body = rewriter.createBlock(&genericOp->getRegion(0), {}, 14701581e28SMatthias Springer resultType.getElementType(), loc); 14801581e28SMatthias Springer rewriter.setInsertionPointToStart(body); 14901581e28SMatthias Springer SmallVector<Value> bbArgReplacements; 15001581e28SMatthias Springer for (int64_t i = 0; i < resultType.getRank(); ++i) 15101581e28SMatthias Springer bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i)); 15201581e28SMatthias Springer rewriter.mergeBlocks(padOp.getBody(), body, bbArgReplacements); 15301581e28SMatthias Springer 15401581e28SMatthias Springer // Update terminator. 
15501581e28SMatthias Springer auto yieldOp = cast<tensor::YieldOp>(body->getTerminator()); 15601581e28SMatthias Springer rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue()); 15701581e28SMatthias Springer return genericOp; 15801581e28SMatthias Springer } 15901581e28SMatthias Springer 16001581e28SMatthias Springer static SmallVector<Value> reifyOrComputeDynamicSizes(OpBuilder &b, 16101581e28SMatthias Springer Value value) { 1625550c821STres Popp auto tensorType = cast<RankedTensorType>(value.getType()); 16301581e28SMatthias Springer if (tensorType.hasStaticShape()) 16401581e28SMatthias Springer return {}; 16501581e28SMatthias Springer 16601581e28SMatthias Springer // Try to reify dynamic sizes. 16701581e28SMatthias Springer ReifiedRankedShapedTypeDims reifiedShape; 1685550c821STres Popp if (isa<OpResult>(value) && 169758329dcSMatthias Springer succeeded(reifyResultShapes(b, value.getDefiningOp(), reifiedShape))) { 17001581e28SMatthias Springer SmallVector<Value> dynSizes; 17101581e28SMatthias Springer for (int64_t i = 0; i < tensorType.getRank(); ++i) { 17201581e28SMatthias Springer if (tensorType.isDynamicDim(i)) 1734f279a57SKazu Hirata dynSizes.push_back(cast<Value>( 1744f279a57SKazu Hirata reifiedShape[cast<OpResult>(value).getResultNumber()][i])); 17501581e28SMatthias Springer } 17601581e28SMatthias Springer return dynSizes; 17701581e28SMatthias Springer } 17801581e28SMatthias Springer 17901581e28SMatthias Springer // Create tensor.dim ops. 
18001581e28SMatthias Springer SmallVector<Value> dynSizes; 18101581e28SMatthias Springer for (int64_t i = 0; i < tensorType.getRank(); ++i) { 18201581e28SMatthias Springer if (tensorType.isDynamicDim(i)) 18301581e28SMatthias Springer dynSizes.push_back( 18401581e28SMatthias Springer b.create<DimOp>(value.getLoc(), value, 18501581e28SMatthias Springer b.create<arith::ConstantIndexOp>(value.getLoc(), i))); 18601581e28SMatthias Springer } 18701581e28SMatthias Springer return dynSizes; 18801581e28SMatthias Springer } 18901581e28SMatthias Springer 1901a5aa77fSMatthias Springer static Value 1911a5aa77fSMatthias Springer createAllocationForTensor(RewriterBase &rewriter, Location loc, Value value, 1921a5aa77fSMatthias Springer const linalg::BufferizeToAllocationOptions &options, 19301581e28SMatthias Springer Attribute memorySpace = {}) { 19401581e28SMatthias Springer OpBuilder::InsertionGuard g(rewriter); 1955550c821STres Popp auto tensorType = cast<RankedTensorType>(value.getType()); 19601581e28SMatthias Springer 19701581e28SMatthias Springer // Create buffer allocation. 1985550c821STres Popp auto memrefType = 1995550c821STres Popp cast<MemRefType>(bufferization::getMemRefTypeWithStaticIdentityLayout( 2005550c821STres Popp tensorType, memorySpace)); 20101581e28SMatthias Springer SmallVector<Value> dynamicSizes = reifyOrComputeDynamicSizes(rewriter, value); 20201581e28SMatthias Springer 2031a5aa77fSMatthias Springer Value alloc; 2041a5aa77fSMatthias Springer if (options.allocOp == 2051a5aa77fSMatthias Springer linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloc) { 2061a5aa77fSMatthias Springer alloc = rewriter.create<memref::AllocOp>(loc, memrefType, dynamicSizes); 207412c2fd2SMartin Erhart if (options.emitDealloc) { 20801581e28SMatthias Springer // Place deallocation at the end of the block. 
20901581e28SMatthias Springer rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator()); 21001581e28SMatthias Springer rewriter.create<memref::DeallocOp>(loc, alloc); 211412c2fd2SMartin Erhart } 2121a5aa77fSMatthias Springer } else if (options.allocOp == 2131a5aa77fSMatthias Springer linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloca) { 2141a5aa77fSMatthias Springer alloc = rewriter.create<memref::AllocaOp>(loc, memrefType, dynamicSizes); 2151a5aa77fSMatthias Springer // No dealloc is needed. 2161a5aa77fSMatthias Springer } 21701581e28SMatthias Springer 21801581e28SMatthias Springer return alloc; 21901581e28SMatthias Springer } 22001581e28SMatthias Springer 221579bca12SMatthias Springer Value linalg::bufferizeToAllocation( 222579bca12SMatthias Springer RewriterBase &rewriter, const linalg::BufferizeToAllocationOptions &options, 223579bca12SMatthias Springer PadOp padOp, Attribute memorySpace, Operation *insertionPoint) { 224a5bba98aSMatthias Springer // tensor.pad does not have a destination operand. 225a5bba98aSMatthias Springer assert(!options.bufferizeDestinationOnly && "invalid options"); 226a5bba98aSMatthias Springer 22701581e28SMatthias Springer OpBuilder::InsertionGuard g(rewriter); 228eb74eff9SMatthias Springer rewriter.setInsertionPoint(insertionPoint ? insertionPoint : padOp); 22901581e28SMatthias Springer Location loc = padOp.getLoc(); 23001581e28SMatthias Springer 23101581e28SMatthias Springer // Create buffer allocation. 2321a5aa77fSMatthias Springer Value alloc = createAllocationForTensor(rewriter, loc, padOp.getResult(), 2331a5aa77fSMatthias Springer options, memorySpace); 234eb74eff9SMatthias Springer rewriter.setInsertionPoint(padOp); 23501581e28SMatthias Springer 236bb566b65SMatthias Springer if (!padOp.hasZeroLowPad() || !padOp.hasZeroHighPad()) { 237bb566b65SMatthias Springer // Create linalg.fill or linalg.generic. Not needed if there is no padding. 
238bb566b65SMatthias Springer Operation *fillOp = 239bb566b65SMatthias Springer movePaddingToFillOrGenericOp(rewriter, loc, padOp, alloc); 24001581e28SMatthias Springer rewriter.setInsertionPointAfter(fillOp); 241bb566b65SMatthias Springer } 24201581e28SMatthias Springer 243437c6217SMatthias Springer // Create memcpy. 24401581e28SMatthias Springer SmallVector<OpFoldResult> sizes = 24501581e28SMatthias Springer getMixedSizes(rewriter, loc, padOp.getSource()); 24601581e28SMatthias Springer SmallVector<OpFoldResult> strides(padOp.getResultType().getRank(), 24701581e28SMatthias Springer rewriter.getIndexAttr(1)); 24801581e28SMatthias Springer Value subview = rewriter.create<memref::SubViewOp>( 24901581e28SMatthias Springer loc, alloc, /*offsets=*/padOp.getMixedLowPad(), sizes, strides); 250579bca12SMatthias Springer createMemcpy(rewriter, loc, padOp.getSource(), subview, options); 25101581e28SMatthias Springer 25201581e28SMatthias Springer // Create bufferization.to_tensor with "restrict" and "writable". The returned 25301581e28SMatthias Springer // tensor is a new buffer allocation, so it does not alias with any buffer. 
25401581e28SMatthias Springer Value toTensorOp = rewriter.create<bufferization::ToTensorOp>( 25501581e28SMatthias Springer loc, alloc, /*restrict=*/true, /*writable=*/true); 25601581e28SMatthias Springer rewriter.replaceOp(padOp, toTensorOp); 25703301be0SMatthias Springer return alloc; 25801581e28SMatthias Springer } 25901581e28SMatthias Springer 260579bca12SMatthias Springer Value linalg::bufferizeToAllocation( 261579bca12SMatthias Springer RewriterBase &rewriter, const linalg::BufferizeToAllocationOptions &options, 262579bca12SMatthias Springer vector::MaskOp maskOp, Attribute memorySpace, Operation *insertionPoint) { 263eb74eff9SMatthias Springer assert(llvm::range_size(maskOp.getMaskBlock()->without_terminator()) == 1 && 264eb74eff9SMatthias Springer "expected single masked op"); 265eb74eff9SMatthias Springer OpBuilder::InsertionGuard g(rewriter); 266579bca12SMatthias Springer bufferization::BufferizationOptions bufferizationOptions; 267eb74eff9SMatthias Springer Operation *yieldOp = maskOp.getMaskRegion().front().getTerminator(); 268eb74eff9SMatthias Springer assert(isa<vector::YieldOp>(yieldOp) && "expected yield op terminator"); 269eb74eff9SMatthias Springer 270eb74eff9SMatthias Springer // Bufferize maskable op. By default, place the buffer allocation right before 271eb74eff9SMatthias Springer // the mask op. 272eb74eff9SMatthias Springer Value alloc = bufferizeToAllocation( 273579bca12SMatthias Springer rewriter, options, maskOp.getMaskableOp(), memorySpace, 274eb74eff9SMatthias Springer /*insertionPoint=*/insertionPoint ? insertionPoint : maskOp); 275eb74eff9SMatthias Springer 276a5bba98aSMatthias Springer if (options.bufferizeDestinationOnly) 277a5bba98aSMatthias Springer return alloc; 278a5bba98aSMatthias Springer 279eb74eff9SMatthias Springer // Bufferize terminator. 
280eb74eff9SMatthias Springer rewriter.setInsertionPoint(yieldOp); 281eb74eff9SMatthias Springer if (failed(cast<bufferization::BufferizableOpInterface>(yieldOp).bufferize( 282579bca12SMatthias Springer rewriter, bufferizationOptions))) 283eb74eff9SMatthias Springer return nullptr; 284eb74eff9SMatthias Springer 285eb74eff9SMatthias Springer // Erase dead to_tensor ops inside of the mask op. This is necessary because 286eb74eff9SMatthias Springer // there only be one op (apart from the terminator) inside the mask op. 287eb74eff9SMatthias Springer // TODO: Remove dead to_tensor ops more aggressively during bufferization. 288eb74eff9SMatthias Springer SmallVector<Operation *> toTensorOps; 289eb74eff9SMatthias Springer maskOp.walk([&](bufferization::ToTensorOp toTensorOp) { 290eb74eff9SMatthias Springer if (toTensorOp->getUses().empty()) 291eb74eff9SMatthias Springer toTensorOps.push_back(toTensorOp.getOperation()); 292eb74eff9SMatthias Springer }); 293eb74eff9SMatthias Springer for (Operation *op : toTensorOps) 294eb74eff9SMatthias Springer rewriter.eraseOp(op); 295eb74eff9SMatthias Springer 296eb74eff9SMatthias Springer // Bufferize mask op. 297eb74eff9SMatthias Springer SmallVector<OpOperand *> resultUses; 298eb74eff9SMatthias Springer for (Value result : maskOp.getResults()) 299eb74eff9SMatthias Springer if (isa<TensorType>(result.getType())) 300eb74eff9SMatthias Springer for (OpOperand &use : result.getUses()) 301eb74eff9SMatthias Springer resultUses.push_back(&use); 302eb74eff9SMatthias Springer rewriter.setInsertionPoint(maskOp); 303eb74eff9SMatthias Springer if (failed(cast<bufferization::BufferizableOpInterface>(maskOp.getOperation()) 304579bca12SMatthias Springer .bufferize(rewriter, bufferizationOptions))) 305eb74eff9SMatthias Springer return nullptr; 306eb74eff9SMatthias Springer 307eb74eff9SMatthias Springer // Set "restrict" attribute, indicating that no other tensor aliases with 308eb74eff9SMatthias Springer // this tensor. 
That is because we just allocated a new buffer for the tensor. 309eb74eff9SMatthias Springer for (OpOperand *resultUse : resultUses) { 310eb74eff9SMatthias Springer auto toTensorOp = 311eb74eff9SMatthias Springer resultUse->get().getDefiningOp<bufferization::ToTensorOp>(); 312eb74eff9SMatthias Springer assert(toTensorOp && "expected to_tensor op"); 3135fcf907bSMatthias Springer rewriter.modifyOpInPlace(toTensorOp, [&]() { 314eb74eff9SMatthias Springer toTensorOp.setRestrict(true); 315eb74eff9SMatthias Springer toTensorOp.setWritable(true); 316eb74eff9SMatthias Springer }); 317eb74eff9SMatthias Springer } 318eb74eff9SMatthias Springer 319eb74eff9SMatthias Springer return alloc; 320eb74eff9SMatthias Springer } 321eb74eff9SMatthias Springer 3223a223f44SNicolas Vasilache Value linalg::bufferizeToAllocation( 3233a223f44SNicolas Vasilache RewriterBase &rewriter, const linalg::BufferizeToAllocationOptions &options, 3243a223f44SNicolas Vasilache bufferization::AllocTensorOp allocTensorOp, Attribute memorySpace, 3253a223f44SNicolas Vasilache Operation *insertionPoint) { 3263a223f44SNicolas Vasilache Location loc = allocTensorOp.getLoc(); 3273a223f44SNicolas Vasilache OpBuilder::InsertionGuard g(rewriter); 3283a223f44SNicolas Vasilache rewriter.setInsertionPoint(insertionPoint ? insertionPoint : allocTensorOp); 3293a223f44SNicolas Vasilache bufferization::BufferizationOptions bufferizationOptions; 3303a223f44SNicolas Vasilache 3313a223f44SNicolas Vasilache // Create buffer allocation. 3323a223f44SNicolas Vasilache Value alloc = createAllocationForTensor( 3333a223f44SNicolas Vasilache rewriter, loc, allocTensorOp.getResult(), options, memorySpace); 3343a223f44SNicolas Vasilache 3353a223f44SNicolas Vasilache // Create bufferization.to_tensor with "restrict" and "writable". The returned 3363a223f44SNicolas Vasilache // tensor is a new buffer allocation, so it does not alias with any buffer. 
3373a223f44SNicolas Vasilache Value toTensorOp = rewriter.create<bufferization::ToTensorOp>( 3383a223f44SNicolas Vasilache loc, alloc, /*restrict=*/true, /*writable=*/true); 3393a223f44SNicolas Vasilache rewriter.replaceOp(allocTensorOp, toTensorOp); 3403a223f44SNicolas Vasilache return alloc; 3413a223f44SNicolas Vasilache } 3423a223f44SNicolas Vasilache 34396179dffSNicolas Vasilache /// Lower tensor.from_elements to a sequence of chained tensor.insert. 34496179dffSNicolas Vasilache FailureOr<Operation *> mlir::linalg::rewriteInDestinationPassingStyle( 34596179dffSNicolas Vasilache RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp) { 34696179dffSNicolas Vasilache Location loc = fromElementsOp.getLoc(); 34796179dffSNicolas Vasilache RankedTensorType tensorType = 3485550c821STres Popp cast<RankedTensorType>(fromElementsOp.getType()); 34996179dffSNicolas Vasilache auto shape = tensorType.getShape(); 3507b3c662dSMatthias Springer 35196179dffSNicolas Vasilache // Create tensor.empty. 35296179dffSNicolas Vasilache auto emptyOp = rewriter.create<EmptyOp>(loc, tensorType, ValueRange()); 35396179dffSNicolas Vasilache 35496179dffSNicolas Vasilache // Case: tensor<elem_type>. 35596179dffSNicolas Vasilache if (shape.empty()) { 35696179dffSNicolas Vasilache Operation *res = rewriter.replaceOpWithNewOp<tensor::InsertOp>( 35796179dffSNicolas Vasilache fromElementsOp, fromElementsOp.getElements().front(), 35896179dffSNicolas Vasilache emptyOp.getResult(), ValueRange()); 35996179dffSNicolas Vasilache return res; 36096179dffSNicolas Vasilache } 36196179dffSNicolas Vasilache 36296179dffSNicolas Vasilache // Create constants for the range of possible indices [0, max{shape_i}). 
363fab2bb8bSJustin Lebar auto maxDim = *llvm::max_element(shape); 36496179dffSNicolas Vasilache SmallVector<Value, 2> constants; 36596179dffSNicolas Vasilache constants.reserve(maxDim); 36696179dffSNicolas Vasilache for (int i = 0; i < maxDim; ++i) 36796179dffSNicolas Vasilache constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i)); 36896179dffSNicolas Vasilache 36996179dffSNicolas Vasilache // Traverse all elements and create tensor.insert ops. 37096179dffSNicolas Vasilache auto elementIt = fromElementsOp.getElements().begin(); 37196179dffSNicolas Vasilache SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]); 37296179dffSNicolas Vasilache Value result = createInserts(rewriter, loc, /*dim=*/0, emptyOp.getResult(), 37396179dffSNicolas Vasilache shape, constants, elementIt, indices); 37496179dffSNicolas Vasilache 37596179dffSNicolas Vasilache // Replace tensor.from_elements. 37696179dffSNicolas Vasilache rewriter.replaceOp(fromElementsOp, result); 37796179dffSNicolas Vasilache return result.getDefiningOp(); 37896179dffSNicolas Vasilache } 37996179dffSNicolas Vasilache 38096179dffSNicolas Vasilache /// Lower tensor.generate to linalg.generic. 38196179dffSNicolas Vasilache FailureOr<Operation *> 38296179dffSNicolas Vasilache mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter, 38396179dffSNicolas Vasilache tensor::GenerateOp generateOp) { 38496179dffSNicolas Vasilache // Only ops with exactly one block are supported. 38596179dffSNicolas Vasilache if (!generateOp.getBody().hasOneBlock()) 38696179dffSNicolas Vasilache return failure(); 38796179dffSNicolas Vasilache 38896179dffSNicolas Vasilache Location loc = generateOp.getLoc(); 3895550c821STres Popp RankedTensorType tensorType = cast<RankedTensorType>(generateOp.getType()); 39096179dffSNicolas Vasilache 39196179dffSNicolas Vasilache // Create tensor.empty. 
39296179dffSNicolas Vasilache auto emptyOp = 39396179dffSNicolas Vasilache rewriter.create<EmptyOp>(loc, tensorType, generateOp.getDynamicExtents()); 39496179dffSNicolas Vasilache 39596179dffSNicolas Vasilache // Create linalg.generic. 39696179dffSNicolas Vasilache SmallVector<utils::IteratorType> iteratorTypes(tensorType.getRank(), 39796179dffSNicolas Vasilache utils::IteratorType::parallel); 39896179dffSNicolas Vasilache SmallVector<AffineMap> indexingMaps( 39996179dffSNicolas Vasilache 1, rewriter.getMultiDimIdentityMap(tensorType.getRank())); 40096179dffSNicolas Vasilache auto genericOp = rewriter.create<linalg::GenericOp>( 40196179dffSNicolas Vasilache loc, tensorType, /*inputs=*/ValueRange(), 40296179dffSNicolas Vasilache /*outputs=*/ValueRange{emptyOp.getResult()}, /*indexingMaps=*/ 40396179dffSNicolas Vasilache indexingMaps, iteratorTypes); 40496179dffSNicolas Vasilache Block *body = rewriter.createBlock(&genericOp->getRegion(0), {}, 40596179dffSNicolas Vasilache tensorType.getElementType(), loc); 40696179dffSNicolas Vasilache rewriter.setInsertionPointToStart(body); 40796179dffSNicolas Vasilache SmallVector<Value> bbArgReplacements; 40896179dffSNicolas Vasilache for (int64_t i = 0; i < tensorType.getRank(); ++i) 40996179dffSNicolas Vasilache bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i)); 41096179dffSNicolas Vasilache rewriter.mergeBlocks(&generateOp.getBody().front(), body, bbArgReplacements); 41196179dffSNicolas Vasilache 41296179dffSNicolas Vasilache // Update terminator. 41396179dffSNicolas Vasilache auto yieldOp = cast<tensor::YieldOp>(body->getTerminator()); 41496179dffSNicolas Vasilache rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue()); 41596179dffSNicolas Vasilache 41696179dffSNicolas Vasilache // Replace tensor.generate. 
41796179dffSNicolas Vasilache rewriter.replaceOp(generateOp, genericOp->getResult(0)); 41896179dffSNicolas Vasilache return genericOp.getOperation(); 41996179dffSNicolas Vasilache } 42096179dffSNicolas Vasilache 42196179dffSNicolas Vasilache /// Lower tensor.pad to linalg.generic + tensor.insert_slice. 42296179dffSNicolas Vasilache FailureOr<Operation *> 42396179dffSNicolas Vasilache mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter, 42496179dffSNicolas Vasilache tensor::PadOp padOp) { 4257b3c662dSMatthias Springer // Only ops with exactly one block are supported. 4267b3c662dSMatthias Springer if (!padOp.getBodyRegion().hasOneBlock()) 4277b3c662dSMatthias Springer return failure(); 4287b3c662dSMatthias Springer 4297b3c662dSMatthias Springer // Create tensor.empty. 4307b3c662dSMatthias Springer Location loc = padOp.getLoc(); 4317b3c662dSMatthias Springer RankedTensorType resultType = padOp.getResultType(); 4327b3c662dSMatthias Springer ReifiedRankedShapedTypeDims reifiedShape; 433758329dcSMatthias Springer if (failed(reifyResultShapes(rewriter, padOp, reifiedShape))) 4347b3c662dSMatthias Springer return rewriter.notifyMatchFailure( 4357b3c662dSMatthias Springer padOp, "failed to reify tensor.pad op result shape"); 4367b3c662dSMatthias Springer SmallVector<Value> dynamicSizes; 4377b3c662dSMatthias Springer for (int64_t i = 0; i < resultType.getRank(); ++i) 4387b3c662dSMatthias Springer if (resultType.isDynamicDim(i)) 4394f279a57SKazu Hirata dynamicSizes.push_back(cast<Value>(reifiedShape[0][i])); 4407b3c662dSMatthias Springer 44196179dffSNicolas Vasilache // If the `padOp` has a nofold attribute and all paddings are known to be 0, 44296179dffSNicolas Vasilache // explicitly insert a `linalg.copy`. 
  // Fast path: a NOFOLD pad whose low and high padding are all zero is a
  // plain copy. Materialize it as alloc_tensor + linalg.copy so that the
  // allocation is guaranteed (NOFOLD means the pad must not be elided).
  if (padOp.getNofoldAttr() &&
      llvm::all_of(padOp.getMixedLowPad(), isZeroIndex) &&
      llvm::all_of(padOp.getMixedHighPad(), isZeroIndex)) {
    using bufferization::AllocTensorOp;
    Value allocated =
        rewriter.create<AllocTensorOp>(loc, resultType, dynamicSizes);
    auto copyOp = rewriter.replaceOpWithNewOp<linalg::CopyOp>(
        padOp, padOp.getSource(), allocated);
    return copyOp.getOperation();
  }

  Value empty = rewriter.create<EmptyOp>(loc, resultType, dynamicSizes);
  // Create linalg.fill or linalg.generic.
  Operation *fillOp = movePaddingToFillOrGenericOp(rewriter, loc, padOp, empty);
  rewriter.setInsertionPointAfter(fillOp);

  // Create tensor::InsertSliceOp: insert the original source into the filled
  // destination, offset by the low padding.
  SmallVector<OpFoldResult> sliceSizes =
      getMixedSizes(rewriter, loc, padOp.getSource());
  SmallVector<OpFoldResult> sliceStrides(resultType.getRank(),
                                         rewriter.getIndexAttr(1));
  auto insertSliceOp = rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.getSource(), fillOp->getResult(0),
      /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);
  return insertSliceOp.getOperation();
}

/// Bufferize the single out-of-place tensor operand of `op` into a newly
/// allocated buffer and then bufferize the op itself.
///
/// The allocation is created at `insertionPoint` if one is given, otherwise
/// directly before `op`. `memorySpace` is forwarded to the allocation helper.
/// Specialized overloads handle tensor.pad, vector.mask, and
/// bufferization.alloc_tensor.
///
/// Returns the allocated buffer on success. Returns a null Value when the op
/// is unsupported: not a BufferizableOpInterface, has unranked tensor
/// results, has results that bufferize to an allocation themselves, has more
/// than one (or zero) out-of-place operands, or fails to bufferize.
Value linalg::bufferizeToAllocation(
    RewriterBase &rewriter, const linalg::BufferizeToAllocationOptions &options,
    Operation *op, Attribute memorySpace, Operation *insertionPoint) {
  using namespace bufferization;

  // Call specialized overload for certain ops.
  if (auto padOp = dyn_cast<tensor::PadOp>(op))
    return bufferizeToAllocation(rewriter, options, padOp, memorySpace);
  if (auto maskOp = dyn_cast<vector::MaskOp>(op))
    return bufferizeToAllocation(rewriter, options, maskOp, memorySpace);
  if (auto allocTensorOp = dyn_cast<bufferization::AllocTensorOp>(op))
    return bufferizeToAllocation(rewriter, options, allocTensorOp, memorySpace);

  // Only bufferizable ops are supported.
  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
  if (!bufferizableOp)
    return nullptr;
  // Default bufferization options; the analysis state below is queried for
  // aliasing and memory-write information, not for a full one-shot analysis.
  BufferizationOptions bufferizationOptions;
  AnalysisState state(bufferizationOptions);

#ifndef NDEBUG
  if (!options.bufferizeDestinationOnly) {
    // Ops with nested tensor ops are not supported yet. At the moment, this
    // function just bufferizes the given op itself, but not its body.
    op->walk([&](Operation *nestedOp) {
      if (op == nestedOp)
        return;
      if (llvm::any_of(nestedOp->getOperands(),
                       [](Value v) { return isa<TensorType>(v.getType()); }))
        llvm_unreachable("ops with nested tensor ops are not supported yet");
      if (llvm::any_of(nestedOp->getResults(),
                       [](Value v) { return isa<TensorType>(v.getType()); }))
        llvm_unreachable("ops with nested tensor ops are not supported yet");
    });
  }
#endif // NDEBUG

  // Gather tensor results.
  SmallVector<OpResult> tensorResults;
  for (OpResult result : op->getResults()) {
    if (!isa<TensorType>(result.getType()))
      continue;
    // Unranked tensors are not supported
    if (!isa<RankedTensorType>(result.getType()))
      return nullptr;
    // Ops that bufferize to an allocation are not supported.
    if (bufferizableOp.bufferizesToAllocation(result))
      return nullptr;
    tensorResults.push_back(result);
  }

  // Gather all operands that should bufferize to a new allocation. I.e.,
  // bufferize out-of-place. Also collect all uses of the tensor results so
  // that their defining to_tensor ops can be marked "restrict" after
  // bufferization (see below).
  SmallVector<OpOperand *> outOfPlaceOperands, resultUses;
  auto addOutOfPlaceOperand = [&](OpOperand *operand) {
    // Deduplicate: an operand may alias multiple results.
    if (!llvm::is_contained(outOfPlaceOperands, operand))
      outOfPlaceOperands.push_back(operand);
  };
  for (OpResult result : tensorResults) {
    AliasingOpOperandList aliasingOperands =
        state.getAliasingOpOperands(result);
    for (const AliasingOpOperand &operand : aliasingOperands) {
      addOutOfPlaceOperand(operand.opOperand);
      for (OpOperand &resultUse : result.getUses())
        resultUses.push_back(&resultUse);
    }
  }
  for (OpOperand &operand : op->getOpOperands()) {
    // Operands that are written to must also go out-of-place, even if they
    // do not alias a result.
    if (!state.bufferizesToMemoryWrite(operand))
      continue;
    if (!isa<RankedTensorType>(operand.get().getType()))
      continue;
    addOutOfPlaceOperand(&operand);
  }
  // TODO: Support multiple buffers.
  if (outOfPlaceOperands.size() != 1)
    return nullptr;

  // Allocate buffers.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(insertionPoint ? insertionPoint : op);
  SmallVector<Value> allocs;
  for (OpOperand *operand : outOfPlaceOperands) {
    Value alloc = createAllocationForTensor(
        rewriter, op->getLoc(), operand->get(), options, memorySpace);
    allocs.push_back(alloc);
    if (!state.findDefinitions(operand).empty()) {
      // Initialize buffer with a copy of the operand data. Not needed if the
      // tensor is uninitialized.
      createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options);
    }
    // Replace the operand with a to_tensor of the new allocation so that the
    // op reads/writes the buffer.
    rewriter.modifyOpInPlace(op, [&]() {
      auto toTensorOp = rewriter.create<ToTensorOp>(op->getLoc(), alloc);
      operand->set(toTensorOp);
      if (options.bufferizeDestinationOnly) {
        // The op itself is not bufferized in this mode, so the to_tensor
        // must carry the restrict/writable attributes directly.
        rewriter.modifyOpInPlace(toTensorOp, [&]() {
          toTensorOp.setRestrict(true);
          toTensorOp.setWritable(true);
        });
      }
    });
  }

  if (options.bufferizeDestinationOnly)
    return allocs.front();

  // Bufferize the op.
  rewriter.setInsertionPoint(op);
  if (failed(bufferizableOp.bufferize(rewriter, bufferizationOptions)))
    return nullptr;

  // Set "restrict" attribute, indicating that no other tensor aliases with
  // this tensor. That is because we just allocated a new buffer for the tensor.
  for (OpOperand *resultUse : resultUses) {
    auto toTensorOp = resultUse->get().getDefiningOp<ToTensorOp>();
    assert(toTensorOp && "expected to_tensor op");
    rewriter.modifyOpInPlace(toTensorOp, [&]() {
      toTensorOp.setRestrict(true);
      toTensorOp.setWritable(true);
    });
  }
  return allocs.front();
}

namespace {

/// Adapter that turns the linalg::rewriteInDestinationPassingStyle overload
/// for `OpTy` into a function-style rewrite pattern.
template <typename OpTy>
LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,
                                                 PatternRewriter &rewriter) {
  return linalg::rewriteInDestinationPassingStyle(rewriter, op);
}

} // namespace

/// Populate `patterns` with rewrites that convert tensor.from_elements,
/// tensor.generate, and tensor.pad to destination-passing style.
void linalg::populateConvertToDestinationStylePatterns(
    RewritePatternSet &patterns) {
  patterns.add(rewriteOpInDestinationPassingStyle<tensor::FromElementsOp>);
  patterns.add(rewriteOpInDestinationPassingStyle<tensor::GenerateOp>);
  patterns.add(rewriteOpInDestinationPassingStyle<tensor::PadOp>);
}