xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp (revision d9111f19d2ea53d8ce105b3d09425394ccf37969)
145ccff17SMatthias Springer //===- ConvertToDestinationStyle.cpp - Convert non-DPS to DPS ops ---------===//
245ccff17SMatthias Springer //
345ccff17SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
445ccff17SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
545ccff17SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
645ccff17SMatthias Springer //
745ccff17SMatthias Springer //===----------------------------------------------------------------------===//
845ccff17SMatthias Springer //
945ccff17SMatthias Springer // This file contains patterns to convert non-DPS ops to DPS ops. New
1045ccff17SMatthias Springer // tensor.empty ops are inserted as a destination. Such tensor.empty can be
1145ccff17SMatthias Springer // eliminated with "empty tensor elimination", allowing them to bufferize
1245ccff17SMatthias Springer // without an allocation (assuming there are no further conflicts).
1345ccff17SMatthias Springer //
1445ccff17SMatthias Springer //===----------------------------------------------------------------------===//
1545ccff17SMatthias Springer //
1645ccff17SMatthias Springer #include "mlir/Dialect/Arith/IR/Arith.h"
172a5b13e7SMatthias Springer #include "mlir/Dialect/Arith/Utils/Utils.h"
1801581e28SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1901581e28SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
2045ccff17SMatthias Springer #include "mlir/Dialect/Linalg/IR/Linalg.h"
2145ccff17SMatthias Springer #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
2245ccff17SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
2396179dffSNicolas Vasilache #include "mlir/Dialect/Utils/StaticValueUtils.h"
247b3c662dSMatthias Springer #include "mlir/IR/Matchers.h"
2545ccff17SMatthias Springer #include "mlir/IR/PatternMatch.h"
2696179dffSNicolas Vasilache #include "llvm/ADT/STLExtras.h"
2745ccff17SMatthias Springer #include "llvm/Support/Debug.h"
2845ccff17SMatthias Springer 
2945ccff17SMatthias Springer using namespace mlir;
3045ccff17SMatthias Springer using namespace mlir::tensor;
3145ccff17SMatthias Springer 
32124fce09SMatthias Springer // Implements backtracking to traverse indices of the output buffer while
33124fce09SMatthias Springer // iterating over op.elements().
34124fce09SMatthias Springer static Value createInserts(RewriterBase &rewriter, Location loc, int dim,
35124fce09SMatthias Springer                            Value destination, ArrayRef<int64_t> shape,
36124fce09SMatthias Springer                            ArrayRef<Value> constants,
37124fce09SMatthias Springer                            OperandRange::iterator &elementIt,
38124fce09SMatthias Springer                            SmallVectorImpl<Value> &indices) {
39124fce09SMatthias Springer   if (dim == static_cast<int>(shape.size()) - 1) {
40124fce09SMatthias Springer     for (int i = 0; i < shape.back(); ++i) {
41124fce09SMatthias Springer       indices.back() = constants[i];
42124fce09SMatthias Springer       destination = rewriter.create<tensor::InsertOp>(loc, *elementIt,
43124fce09SMatthias Springer                                                       destination, indices);
44124fce09SMatthias Springer       ++elementIt;
45124fce09SMatthias Springer     }
46124fce09SMatthias Springer     return destination;
47124fce09SMatthias Springer   }
48124fce09SMatthias Springer   for (int i = 0; i < shape[dim]; ++i) {
49124fce09SMatthias Springer     indices[dim] = constants[i];
50124fce09SMatthias Springer     destination = createInserts(rewriter, loc, dim + 1, destination, shape,
51124fce09SMatthias Springer                                 constants, elementIt, indices);
52124fce09SMatthias Springer   }
53124fce09SMatthias Springer   return destination;
54124fce09SMatthias Springer }
55124fce09SMatthias Springer 
56579bca12SMatthias Springer /// Create a memcpy from the given source tensor to the given destination
57579bca12SMatthias Springer /// memref. The copy op type can be specified in the `options`.
58579bca12SMatthias Springer static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource,
59579bca12SMatthias Springer                          Value memrefDest,
60579bca12SMatthias Springer                          const linalg::BufferizeToAllocationOptions &options) {
61579bca12SMatthias Springer   auto tensorType = dyn_cast<RankedTensorType>(tensorSource.getType());
62579bca12SMatthias Springer   assert(tensorType && "expected ranked tensor");
63a5757c5bSChristian Sigg   assert(isa<MemRefType>(memrefDest.getType()) && "expected ranked memref");
64579bca12SMatthias Springer 
65579bca12SMatthias Springer   switch (options.memcpyOp) {
66437c6217SMatthias Springer   case linalg::BufferizeToAllocationOptions::MemcpyOp::
67437c6217SMatthias Springer       MaterializeInDestination: {
68579bca12SMatthias Springer     // Note: This is the preferred way of memcpy'ing because no layout map
69579bca12SMatthias Springer     // and/or memory space must be specified for the source.
70437c6217SMatthias Springer     auto materializeOp = b.create<bufferization::MaterializeInDestinationOp>(
71437c6217SMatthias Springer         loc, tensorSource, memrefDest);
72437c6217SMatthias Springer     materializeOp.setWritable(true);
73437c6217SMatthias Springer   } break;
74579bca12SMatthias Springer   case linalg::BufferizeToAllocationOptions::MemcpyOp::MemrefCopy: {
75579bca12SMatthias Springer     // TODO: Support custom memory space on source.
76579bca12SMatthias Springer     // We do not know the layout map of the source yet, so use a fully dynamic
77579bca12SMatthias Springer     // layout for best compatibility.
78579bca12SMatthias Springer     Value toMemref = b.create<bufferization::ToMemrefOp>(
79579bca12SMatthias Springer         loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType),
80579bca12SMatthias Springer         tensorSource, /*readOnly=*/true);
81579bca12SMatthias Springer     b.create<memref::CopyOp>(loc, toMemref, memrefDest);
82579bca12SMatthias Springer   } break;
83579bca12SMatthias Springer   case linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy: {
84579bca12SMatthias Springer     // TODO: Support custom memory space on source.
85579bca12SMatthias Springer     // We do not know the layout map of the source yet, so use a fully dynamic
86579bca12SMatthias Springer     // layout for best compatibility.
87579bca12SMatthias Springer     Value toMemref = b.create<bufferization::ToMemrefOp>(
88579bca12SMatthias Springer         loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType),
89579bca12SMatthias Springer         tensorSource, /*readOnly=*/true);
90579bca12SMatthias Springer     b.create<linalg::CopyOp>(loc, toMemref, memrefDest);
91579bca12SMatthias Springer   } break;
92579bca12SMatthias Springer   };
93579bca12SMatthias Springer }
94579bca12SMatthias Springer 
/// Rewrite the padding computation of the given tensor.pad op into a
/// linalg.fill (when the yielded value is a constant or is defined outside of
/// the pad op) or into a linalg.generic (otherwise), writing into `dest`.
/// Returns the newly created fill or generic op.
static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter,
                                               Location loc, PadOp padOp,
                                               Value dest) {
  OpBuilder::InsertionGuard g(rewriter);
  RankedTensorType resultType = padOp.getResultType();

  // Examine the yielded value to decide if a linalg.generic is needed or a
  // linalg.fill is sufficient.
  Value yieldedValue =
      cast<tensor::YieldOp>(padOp.getBody()->getTerminator()).getValue();
  Attribute constYieldedValue;
  // Is the yielded value a bbArg defined outside of the PadOp?
  bool outsideBbArg =
      isa<BlockArgument>(yieldedValue) &&
      cast<BlockArgument>(yieldedValue).getOwner()->getParentOp() !=
          padOp.getOperation();
  // Is the yielded value an OpResult defined outside of the PadOp?
  bool outsideOpResult =
      isa<OpResult>(yieldedValue) &&
      yieldedValue.getDefiningOp()->getParentOp() != padOp.getOperation();
  bool invariantYieldedValue = outsideBbArg || outsideOpResult;
  if (matchPattern(yieldedValue, m_Constant(&constYieldedValue))) {
    // Padding with a constant: Create linalg.fill.
    // Materialize the matched constant attribute as an arith-dialect SSA value
    // so it can serve as the fill value.
    Dialect *arithDialect =
        rewriter.getContext()->getLoadedDialect<arith::ArithDialect>();
    Value fillValue =
        arithDialect
            ->materializeConstant(rewriter, constYieldedValue,
                                  yieldedValue.getType(), yieldedValue.getLoc())
            ->getResult(0);
    auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange(fillValue),
                                                  ValueRange(dest));
    return fillOp;
  }

  if (invariantYieldedValue) {
    // Padding with an invariant value (defined outside the pad op):
    // linalg.fill is still sufficient.
    auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange(yieldedValue),
                                                  ValueRange(dest));
    return fillOp;
  }

  // Fallback: Create a linalg.generic whose body computes the padding value
  // per output element.
  SmallVector<utils::IteratorType> iteratorTypes(resultType.getRank(),
                                                 utils::IteratorType::parallel);
  SmallVector<AffineMap> indexingMaps(
      1, rewriter.getMultiDimIdentityMap(resultType.getRank()));
  auto genericOp = rewriter.create<linalg::GenericOp>(
      loc, resultType, /*inputs=*/ValueRange(),
      /*outputs=*/ValueRange{dest}, /*indexingMaps=*/
      indexingMaps, iteratorTypes);
  Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
                                     resultType.getElementType(), loc);
  rewriter.setInsertionPointToStart(body);
  // Replace the pad body's index bbArgs with linalg.index ops before merging.
  SmallVector<Value> bbArgReplacements;
  for (int64_t i = 0; i < resultType.getRank(); ++i)
    bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
  rewriter.mergeBlocks(padOp.getBody(), body, bbArgReplacements);

  // Update terminator: tensor.yield becomes linalg.yield.
  auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
  rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
  return genericOp;
}
15901581e28SMatthias Springer 
16001581e28SMatthias Springer static SmallVector<Value> reifyOrComputeDynamicSizes(OpBuilder &b,
16101581e28SMatthias Springer                                                      Value value) {
1625550c821STres Popp   auto tensorType = cast<RankedTensorType>(value.getType());
16301581e28SMatthias Springer   if (tensorType.hasStaticShape())
16401581e28SMatthias Springer     return {};
16501581e28SMatthias Springer 
16601581e28SMatthias Springer   // Try to reify dynamic sizes.
16701581e28SMatthias Springer   ReifiedRankedShapedTypeDims reifiedShape;
1685550c821STres Popp   if (isa<OpResult>(value) &&
169758329dcSMatthias Springer       succeeded(reifyResultShapes(b, value.getDefiningOp(), reifiedShape))) {
17001581e28SMatthias Springer     SmallVector<Value> dynSizes;
17101581e28SMatthias Springer     for (int64_t i = 0; i < tensorType.getRank(); ++i) {
17201581e28SMatthias Springer       if (tensorType.isDynamicDim(i))
1734f279a57SKazu Hirata         dynSizes.push_back(cast<Value>(
1744f279a57SKazu Hirata             reifiedShape[cast<OpResult>(value).getResultNumber()][i]));
17501581e28SMatthias Springer     }
17601581e28SMatthias Springer     return dynSizes;
17701581e28SMatthias Springer   }
17801581e28SMatthias Springer 
17901581e28SMatthias Springer   // Create tensor.dim ops.
18001581e28SMatthias Springer   SmallVector<Value> dynSizes;
18101581e28SMatthias Springer   for (int64_t i = 0; i < tensorType.getRank(); ++i) {
18201581e28SMatthias Springer     if (tensorType.isDynamicDim(i))
18301581e28SMatthias Springer       dynSizes.push_back(
18401581e28SMatthias Springer           b.create<DimOp>(value.getLoc(), value,
18501581e28SMatthias Springer                           b.create<arith::ConstantIndexOp>(value.getLoc(), i)));
18601581e28SMatthias Springer   }
18701581e28SMatthias Springer   return dynSizes;
18801581e28SMatthias Springer }
18901581e28SMatthias Springer 
/// Allocate a buffer (memref.alloc or memref.alloca, depending on
/// `options.allocOp`) whose type mirrors `value`'s ranked tensor type with a
/// static identity layout and the given memory space. For memref.alloc, a
/// matching memref.dealloc is placed before the block terminator if
/// `options.emitDealloc` is set; alloca needs no dealloc.
static Value
createAllocationForTensor(RewriterBase &rewriter, Location loc, Value value,
                          const linalg::BufferizeToAllocationOptions &options,
                          Attribute memorySpace = {}) {
  OpBuilder::InsertionGuard g(rewriter);
  auto tensorType = cast<RankedTensorType>(value.getType());

  // Create buffer allocation.
  auto memrefType =
      cast<MemRefType>(bufferization::getMemRefTypeWithStaticIdentityLayout(
          tensorType, memorySpace));
  // Dynamic extents of the tensor become dynamic sizes of the allocation.
  SmallVector<Value> dynamicSizes = reifyOrComputeDynamicSizes(rewriter, value);

  Value alloc;
  if (options.allocOp ==
      linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloc) {
    alloc = rewriter.create<memref::AllocOp>(loc, memrefType, dynamicSizes);
    if (options.emitDealloc) {
      // Place deallocation at the end of the block.
      rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator());
      rewriter.create<memref::DeallocOp>(loc, alloc);
    }
  } else if (options.allocOp ==
             linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloca) {
    alloc = rewriter.create<memref::AllocaOp>(loc, memrefType, dynamicSizes);
    // No dealloc is needed for stack allocations.
  }

  return alloc;
}
22001581e28SMatthias Springer 
/// Bufferize the given tensor.pad op to a new allocation: allocate a buffer of
/// the padded shape, fill it with the padding value (if any padding is
/// present), copy the source into the low-pad-offset subview, and replace the
/// pad op with a restrict+writable bufferization.to_tensor of the allocation.
/// Returns the allocated buffer.
Value linalg::bufferizeToAllocation(
    RewriterBase &rewriter, const linalg::BufferizeToAllocationOptions &options,
    PadOp padOp, Attribute memorySpace, Operation *insertionPoint) {
  // tensor.pad does not have a destination operand.
  assert(!options.bufferizeDestinationOnly && "invalid options");

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(insertionPoint ? insertionPoint : padOp);
  Location loc = padOp.getLoc();

  // Create buffer allocation.
  Value alloc = createAllocationForTensor(rewriter, loc, padOp.getResult(),
                                          options, memorySpace);
  rewriter.setInsertionPoint(padOp);

  if (!padOp.hasZeroLowPad() || !padOp.hasZeroHighPad()) {
    // Create linalg.fill or linalg.generic. Not needed if there is no padding.
    Operation *fillOp =
        movePaddingToFillOrGenericOp(rewriter, loc, padOp, alloc);
    rewriter.setInsertionPointAfter(fillOp);
  }

  // Create memcpy: copy the source into the subview of the allocation that
  // corresponds to the unpadded region (offset by the low padding).
  SmallVector<OpFoldResult> sizes =
      getMixedSizes(rewriter, loc, padOp.getSource());
  SmallVector<OpFoldResult> strides(padOp.getResultType().getRank(),
                                    rewriter.getIndexAttr(1));
  Value subview = rewriter.create<memref::SubViewOp>(
      loc, alloc, /*offsets=*/padOp.getMixedLowPad(), sizes, strides);
  createMemcpy(rewriter, loc, padOp.getSource(), subview, options);

  // Create bufferization.to_tensor with "restrict" and "writable". The returned
  // tensor is a new buffer allocation, so it does not alias with any buffer.
  Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
      loc, alloc, /*restrict=*/true, /*writable=*/true);
  rewriter.replaceOp(padOp, toTensorOp);
  return alloc;
}
25901581e28SMatthias Springer 
/// Bufferize the single maskable op inside the given vector.mask op to a new
/// allocation, then bufferize the mask op itself (terminator first, then the
/// mask op) and mark the resulting to_tensor ops restrict+writable. Returns
/// the allocated buffer, or nullptr on bufferization failure.
Value linalg::bufferizeToAllocation(
    RewriterBase &rewriter, const linalg::BufferizeToAllocationOptions &options,
    vector::MaskOp maskOp, Attribute memorySpace, Operation *insertionPoint) {
  assert(llvm::range_size(maskOp.getMaskBlock()->without_terminator()) == 1 &&
         "expected single masked op");
  OpBuilder::InsertionGuard g(rewriter);
  bufferization::BufferizationOptions bufferizationOptions;
  Operation *yieldOp = maskOp.getMaskRegion().front().getTerminator();
  assert(isa<vector::YieldOp>(yieldOp) && "expected yield op terminator");

  // Bufferize maskable op. By default, place the buffer allocation right before
  // the mask op.
  Value alloc = bufferizeToAllocation(
      rewriter, options, maskOp.getMaskableOp(), memorySpace,
      /*insertionPoint=*/insertionPoint ? insertionPoint : maskOp);

  if (options.bufferizeDestinationOnly)
    return alloc;

  // Bufferize terminator.
  rewriter.setInsertionPoint(yieldOp);
  if (failed(cast<bufferization::BufferizableOpInterface>(yieldOp).bufferize(
          rewriter, bufferizationOptions)))
    return nullptr;

  // Erase dead to_tensor ops inside of the mask op. This is necessary because
  // there may only be one op (apart from the terminator) inside the mask op.
  // TODO: Remove dead to_tensor ops more aggressively during bufferization.
  SmallVector<Operation *> toTensorOps;
  maskOp.walk([&](bufferization::ToTensorOp toTensorOp) {
    if (toTensorOp->getUses().empty())
      toTensorOps.push_back(toTensorOp.getOperation());
  });
  for (Operation *op : toTensorOps)
    rewriter.eraseOp(op);

  // Bufferize mask op. Collect all uses of its tensor results first, so the
  // replacement to_tensor ops can be found and annotated afterwards.
  SmallVector<OpOperand *> resultUses;
  for (Value result : maskOp.getResults())
    if (isa<TensorType>(result.getType()))
      for (OpOperand &use : result.getUses())
        resultUses.push_back(&use);
  rewriter.setInsertionPoint(maskOp);
  if (failed(cast<bufferization::BufferizableOpInterface>(maskOp.getOperation())
                 .bufferize(rewriter, bufferizationOptions)))
    return nullptr;

  // Set "restrict" attribute, indicating that no other tensor aliases with
  // this tensor. That is because we just allocated a new buffer for the tensor.
  for (OpOperand *resultUse : resultUses) {
    auto toTensorOp =
        resultUse->get().getDefiningOp<bufferization::ToTensorOp>();
    assert(toTensorOp && "expected to_tensor op");
    rewriter.modifyOpInPlace(toTensorOp, [&]() {
      toTensorOp.setRestrict(true);
      toTensorOp.setWritable(true);
    });
  }

  return alloc;
}
321eb74eff9SMatthias Springer 
3223a223f44SNicolas Vasilache Value linalg::bufferizeToAllocation(
3233a223f44SNicolas Vasilache     RewriterBase &rewriter, const linalg::BufferizeToAllocationOptions &options,
3243a223f44SNicolas Vasilache     bufferization::AllocTensorOp allocTensorOp, Attribute memorySpace,
3253a223f44SNicolas Vasilache     Operation *insertionPoint) {
3263a223f44SNicolas Vasilache   Location loc = allocTensorOp.getLoc();
3273a223f44SNicolas Vasilache   OpBuilder::InsertionGuard g(rewriter);
3283a223f44SNicolas Vasilache   rewriter.setInsertionPoint(insertionPoint ? insertionPoint : allocTensorOp);
3293a223f44SNicolas Vasilache   bufferization::BufferizationOptions bufferizationOptions;
3303a223f44SNicolas Vasilache 
3313a223f44SNicolas Vasilache   // Create buffer allocation.
3323a223f44SNicolas Vasilache   Value alloc = createAllocationForTensor(
3333a223f44SNicolas Vasilache       rewriter, loc, allocTensorOp.getResult(), options, memorySpace);
3343a223f44SNicolas Vasilache 
3353a223f44SNicolas Vasilache   // Create bufferization.to_tensor with "restrict" and "writable". The returned
3363a223f44SNicolas Vasilache   // tensor is a new buffer allocation, so it does not alias with any buffer.
3373a223f44SNicolas Vasilache   Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
3383a223f44SNicolas Vasilache       loc, alloc, /*restrict=*/true, /*writable=*/true);
3393a223f44SNicolas Vasilache   rewriter.replaceOp(allocTensorOp, toTensorOp);
3403a223f44SNicolas Vasilache   return alloc;
3413a223f44SNicolas Vasilache }
3423a223f44SNicolas Vasilache 
34396179dffSNicolas Vasilache /// Lower tensor.from_elements to a sequence of chained tensor.insert.
34496179dffSNicolas Vasilache FailureOr<Operation *> mlir::linalg::rewriteInDestinationPassingStyle(
34596179dffSNicolas Vasilache     RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp) {
34696179dffSNicolas Vasilache   Location loc = fromElementsOp.getLoc();
34796179dffSNicolas Vasilache   RankedTensorType tensorType =
3485550c821STres Popp       cast<RankedTensorType>(fromElementsOp.getType());
34996179dffSNicolas Vasilache   auto shape = tensorType.getShape();
3507b3c662dSMatthias Springer 
35196179dffSNicolas Vasilache   // Create tensor.empty.
35296179dffSNicolas Vasilache   auto emptyOp = rewriter.create<EmptyOp>(loc, tensorType, ValueRange());
35396179dffSNicolas Vasilache 
35496179dffSNicolas Vasilache   // Case: tensor<elem_type>.
35596179dffSNicolas Vasilache   if (shape.empty()) {
35696179dffSNicolas Vasilache     Operation *res = rewriter.replaceOpWithNewOp<tensor::InsertOp>(
35796179dffSNicolas Vasilache         fromElementsOp, fromElementsOp.getElements().front(),
35896179dffSNicolas Vasilache         emptyOp.getResult(), ValueRange());
35996179dffSNicolas Vasilache     return res;
36096179dffSNicolas Vasilache   }
36196179dffSNicolas Vasilache 
36296179dffSNicolas Vasilache   // Create constants for the range of possible indices [0, max{shape_i}).
363fab2bb8bSJustin Lebar   auto maxDim = *llvm::max_element(shape);
36496179dffSNicolas Vasilache   SmallVector<Value, 2> constants;
36596179dffSNicolas Vasilache   constants.reserve(maxDim);
36696179dffSNicolas Vasilache   for (int i = 0; i < maxDim; ++i)
36796179dffSNicolas Vasilache     constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
36896179dffSNicolas Vasilache 
36996179dffSNicolas Vasilache   // Traverse all elements and create tensor.insert ops.
37096179dffSNicolas Vasilache   auto elementIt = fromElementsOp.getElements().begin();
37196179dffSNicolas Vasilache   SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
37296179dffSNicolas Vasilache   Value result = createInserts(rewriter, loc, /*dim=*/0, emptyOp.getResult(),
37396179dffSNicolas Vasilache                                shape, constants, elementIt, indices);
37496179dffSNicolas Vasilache 
37596179dffSNicolas Vasilache   // Replace tensor.from_elements.
37696179dffSNicolas Vasilache   rewriter.replaceOp(fromElementsOp, result);
37796179dffSNicolas Vasilache   return result.getDefiningOp();
37896179dffSNicolas Vasilache }
37996179dffSNicolas Vasilache 
/// Lower tensor.generate to linalg.generic: a parallel generic op over a new
/// tensor.empty destination whose body is the (merged) generate body, with the
/// index bbArgs replaced by linalg.index ops.
FailureOr<Operation *>
mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
                                               tensor::GenerateOp generateOp) {
  // Only ops with exactly one block are supported.
  if (!generateOp.getBody().hasOneBlock())
    return failure();

  Location loc = generateOp.getLoc();
  RankedTensorType tensorType = cast<RankedTensorType>(generateOp.getType());

  // Create tensor.empty, forwarding the generate op's dynamic extents.
  auto emptyOp =
      rewriter.create<EmptyOp>(loc, tensorType, generateOp.getDynamicExtents());

  // Create linalg.generic: all-parallel iterators, identity indexing map on
  // the single output.
  SmallVector<utils::IteratorType> iteratorTypes(tensorType.getRank(),
                                                 utils::IteratorType::parallel);
  SmallVector<AffineMap> indexingMaps(
      1, rewriter.getMultiDimIdentityMap(tensorType.getRank()));
  auto genericOp = rewriter.create<linalg::GenericOp>(
      loc, tensorType, /*inputs=*/ValueRange(),
      /*outputs=*/ValueRange{emptyOp.getResult()}, /*indexingMaps=*/
      indexingMaps, iteratorTypes);
  Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
                                     tensorType.getElementType(), loc);
  rewriter.setInsertionPointToStart(body);
  // Replace the generate body's index bbArgs with linalg.index ops.
  SmallVector<Value> bbArgReplacements;
  for (int64_t i = 0; i < tensorType.getRank(); ++i)
    bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
  rewriter.mergeBlocks(&generateOp.getBody().front(), body, bbArgReplacements);

  // Update terminator: tensor.yield becomes linalg.yield.
  auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
  rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());

  // Replace tensor.generate.
  rewriter.replaceOp(generateOp, genericOp->getResult(0));
  return genericOp.getOperation();
}
42096179dffSNicolas Vasilache 
/// Lower tensor.pad to linalg.generic + tensor.insert_slice.
///
/// The result buffer is materialized as a tensor.empty (or, in the nofold
/// zero-padding special case, a bufferization.alloc_tensor), the padding
/// value is written into it via linalg.fill/linalg.generic, and the source
/// tensor is inserted at the low-pad offsets with tensor.insert_slice.
///
/// Returns the final op that produces the replacement value (the
/// tensor.insert_slice, or the linalg.copy in the special case), or failure
/// if the pad op cannot be matched.
FailureOr<Operation *>
mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
                                               tensor::PadOp padOp) {
  // Only ops with exactly one block are supported.
  if (!padOp.getBodyRegion().hasOneBlock())
    return failure();

  // Create tensor.empty.
  Location loc = padOp.getLoc();
  RankedTensorType resultType = padOp.getResultType();
  // Reify the result shape first: the dynamic extents of the padded result
  // are needed as operands of the destination tensor.
  ReifiedRankedShapedTypeDims reifiedShape;
  if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
    return rewriter.notifyMatchFailure(
        padOp, "failed to reify tensor.pad op result shape");
  // Collect SSA values only for the dynamic dimensions (static dims are
  // encoded in `resultType`).
  SmallVector<Value> dynamicSizes;
  for (int64_t i = 0; i < resultType.getRank(); ++i)
    if (resultType.isDynamicDim(i))
      dynamicSizes.push_back(cast<Value>(reifiedShape[0][i]));

  // If the `padOp` has a nofold attribute and all paddings are known to be 0,
  // explicitly insert a `linalg.copy`.
  if (padOp.getNofoldAttr() &&
      llvm::all_of(padOp.getMixedLowPad(), isZeroIndex) &&
      llvm::all_of(padOp.getMixedHighPad(), isZeroIndex)) {
    using bufferization::AllocTensorOp;
    // alloc_tensor (instead of tensor.empty) guarantees a fresh buffer,
    // honoring the nofold semantics of the pad.
    Value allocated =
        rewriter.create<AllocTensorOp>(loc, resultType, dynamicSizes);
    auto copyOp = rewriter.replaceOpWithNewOp<linalg::CopyOp>(
        padOp, padOp.getSource(), allocated);
    return copyOp.getOperation();
  }

  // Destination tensor for the fill + insert_slice sequence.
  Value empty = rewriter.create<EmptyOp>(loc, resultType, dynamicSizes);
  // Create linalg.fill or linalg.generic.
  Operation *fillOp = movePaddingToFillOrGenericOp(rewriter, loc, padOp, empty);
  rewriter.setInsertionPointAfter(fillOp);

  // Create tensor::InsertSliceOp.
  // The slice covers the original (unpadded) source extents, offset by the
  // low padding, with unit strides.
  SmallVector<OpFoldResult> sliceSizes =
      getMixedSizes(rewriter, loc, padOp.getSource());
  SmallVector<OpFoldResult> sliceStrides(resultType.getRank(),
                                         rewriter.getIndexAttr(1));
  auto insertSliceOp = rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.getSource(), fillOp->getResult(0),
      /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);
  return insertSliceOp.getOperation();
}
46945ccff17SMatthias Springer 
/// Bufferize the given op into a newly allocated buffer.
///
/// Dispatches to specialized overloads for tensor.pad, vector.mask and
/// bufferization.alloc_tensor; otherwise, the op must implement
/// BufferizableOpInterface. Exactly one tensor operand is expected to
/// bufferize out-of-place; a buffer is allocated for it (optionally in
/// `memorySpace`, at `insertionPoint` if provided) and the op is rewired to
/// read/write that buffer through a to_tensor op.
///
/// Returns the allocated buffer value, or nullptr if the op is unsupported.
Value linalg::bufferizeToAllocation(
    RewriterBase &rewriter, const linalg::BufferizeToAllocationOptions &options,
    Operation *op, Attribute memorySpace, Operation *insertionPoint) {
  using namespace bufferization;

  // Call specialized overload for certain ops.
  if (auto padOp = dyn_cast<tensor::PadOp>(op))
    return bufferizeToAllocation(rewriter, options, padOp, memorySpace);
  if (auto maskOp = dyn_cast<vector::MaskOp>(op))
    return bufferizeToAllocation(rewriter, options, maskOp, memorySpace);
  if (auto allocTensorOp = dyn_cast<bufferization::AllocTensorOp>(op))
    return bufferizeToAllocation(rewriter, options, allocTensorOp, memorySpace);

  // Only bufferizable ops are supported.
  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
  if (!bufferizableOp)
    return nullptr;
  // Default-constructed options: the alias analysis below runs without any
  // user-provided bufferization configuration.
  BufferizationOptions bufferizationOptions;
  AnalysisState state(bufferizationOptions);

#ifndef NDEBUG
  if (!options.bufferizeDestinationOnly) {
    // Ops with nested tensor ops are not supported yet. At the moment, this
    // function just bufferizes the given op itself, but not its body.
    op->walk([&](Operation *nestedOp) {
      if (op == nestedOp)
        return;
      if (llvm::any_of(nestedOp->getOperands(),
                       [](Value v) { return isa<TensorType>(v.getType()); }))
        llvm_unreachable("ops with nested tensor ops are not supported yet");
      if (llvm::any_of(nestedOp->getResults(),
                       [](Value v) { return isa<TensorType>(v.getType()); }))
        llvm_unreachable("ops with nested tensor ops are not supported yet");
    });
  }
#endif // NDEBUG

  // Gather tensor results.
  SmallVector<OpResult> tensorResults;
  for (OpResult result : op->getResults()) {
    if (!isa<TensorType>(result.getType()))
      continue;
    // Unranked tensors are not supported.
    if (!isa<RankedTensorType>(result.getType()))
      return nullptr;
    // Ops that bufferize to an allocation are not supported.
    if (bufferizableOp.bufferizesToAllocation(result))
      return nullptr;
    tensorResults.push_back(result);
  }

  // Gather all operands that should bufferize to a new allocation. I.e.,
  // bufferize out-of-place.
  SmallVector<OpOperand *> outOfPlaceOperands, resultUses;
  auto addOutOfPlaceOperand = [&](OpOperand *operand) {
    // Deduplicate: the same operand may alias multiple results.
    if (!llvm::is_contained(outOfPlaceOperands, operand))
      outOfPlaceOperands.push_back(operand);
  };
  // Operands that alias a tensor result bufferize out-of-place; also record
  // all uses of such results so "restrict" can be set on them later.
  for (OpResult result : tensorResults) {
    AliasingOpOperandList aliasingOperands =
        state.getAliasingOpOperands(result);
    for (const AliasingOpOperand &operand : aliasingOperands) {
      addOutOfPlaceOperand(operand.opOperand);
      for (OpOperand &resultUse : result.getUses())
        resultUses.push_back(&resultUse);
    }
  }
  // Ranked tensor operands that are written to also bufferize out-of-place.
  for (OpOperand &operand : op->getOpOperands()) {
    if (!state.bufferizesToMemoryWrite(operand))
      continue;
    if (!isa<RankedTensorType>(operand.get().getType()))
      continue;
    addOutOfPlaceOperand(&operand);
  }
  // TODO: Support multiple buffers.
  if (outOfPlaceOperands.size() != 1)
    return nullptr;

  // Allocate buffers.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(insertionPoint ? insertionPoint : op);
  SmallVector<Value> allocs;
  for (OpOperand *operand : outOfPlaceOperands) {
    Value alloc = createAllocationForTensor(
        rewriter, op->getLoc(), operand->get(), options, memorySpace);
    allocs.push_back(alloc);
    if (!state.findDefinitions(operand).empty()) {
      // Initialize buffer with a copy of the operand data. Not needed if the
      // tensor is uninitialized.
      createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options);
    }
    // Redirect the operand to read from the new buffer via a to_tensor op.
    rewriter.modifyOpInPlace(op, [&]() {
      auto toTensorOp = rewriter.create<ToTensorOp>(op->getLoc(), alloc);
      operand->set(toTensorOp);
      if (options.bufferizeDestinationOnly) {
        // The op itself is not bufferized in this mode, so mark the
        // to_tensor op here (the result-use loop below is skipped).
        rewriter.modifyOpInPlace(toTensorOp, [&]() {
          toTensorOp.setRestrict(true);
          toTensorOp.setWritable(true);
        });
      }
    });
  }

  // In destination-only mode, stop after rewiring the destination operand.
  if (options.bufferizeDestinationOnly)
    return allocs.front();

  // Bufferize the op.
  rewriter.setInsertionPoint(op);
  if (failed(bufferizableOp.bufferize(rewriter, bufferizationOptions)))
    return nullptr;

  // Set "restrict" attribute, indicating that no other tensor aliases with
  // this tensor. That is because we just allocated a new buffer for the tensor.
  for (OpOperand *resultUse : resultUses) {
    auto toTensorOp = resultUse->get().getDefiningOp<ToTensorOp>();
    assert(toTensorOp && "expected to_tensor op");
    rewriter.modifyOpInPlace(toTensorOp, [&]() {
      toTensorOp.setRestrict(true);
      toTensorOp.setWritable(true);
    });
  }
  return allocs.front();
}
59301581e28SMatthias Springer 
59496179dffSNicolas Vasilache namespace {
59596179dffSNicolas Vasilache 
5966f3baf43SAlexander Belyaev template <typename OpTy>
5976f3baf43SAlexander Belyaev LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,
5986f3baf43SAlexander Belyaev                                                  PatternRewriter &rewriter) {
5996f3baf43SAlexander Belyaev   return linalg::rewriteInDestinationPassingStyle(rewriter, op);
60096179dffSNicolas Vasilache }
60196179dffSNicolas Vasilache 
60296179dffSNicolas Vasilache } // namespace
60396179dffSNicolas Vasilache 
/// Populate `patterns` with the DPS-conversion rewrites for the supported
/// non-DPS tensor ops: tensor.from_elements, tensor.generate and tensor.pad.
void linalg::populateConvertToDestinationStylePatterns(
    RewritePatternSet &patterns) {
  patterns.add(rewriteOpInDestinationPassingStyle<tensor::FromElementsOp>);
  patterns.add(rewriteOpInDestinationPassingStyle<tensor::GenerateOp>);
  patterns.add(rewriteOpInDestinationPassingStyle<tensor::PadOp>);
}
610