//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

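/// Bufferization of tensor.cast. Replace with memref.cast, or forward the
/// source buffer unchanged if it already has the target buffer type.
///
/// Illustrative sketch (the layouts shown assume a fully dynamic source
/// layout; %m is the bufferized source):
/// ```
/// %0 = tensor.cast %t : tensor<?xf32> to tensor<10xf32>
/// ```
/// bufferizes to:
/// ```
/// %0 = memref.cast %m : memref<?xf32, strided<[?], offset: ?>>
///                    to memref<10xf32, strided<[?], offset: ?>>
/// ```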
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {{op->getResult(0), BufferRelation::Equivalent}};
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto castOp = cast<tensor::CastOp>(op);
    auto maybeSrcBufferType = bufferization::getBufferType(
        castOp.getSource(), options, invocationStack);
    if (failed(maybeSrcBufferType))
      return failure();
    Attribute memorySpace = maybeSrcBufferType->getMemorySpace();

    // Note: `getMemRefTypeWithFullyDynamicLayout` returns an unranked memref
    // type in case the input is an unranked tensor type.

    // Case 1: Casting an unranked tensor
    if (isa<UnrankedTensorType>(castOp.getSource().getType())) {
      // When casting to a ranked tensor, we cannot infer any static offset or
      // strides from the source. Assume fully dynamic.
      return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
    }

    // Case 2: Casting to an unranked tensor type
    if (isa<UnrankedTensorType>(castOp.getType())) {
      return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
    }

    // Case 3: Ranked tensor -> ranked tensor. The offsets and strides do not
    // change.
    auto rankedResultType = cast<RankedTensorType>(castOp.getType());
    return MemRefType::get(
        rankedResultType.getShape(), rankedResultType.getElementType(),
        llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        getBuffer(rewriter, castOp.getSource(), options);
    if (failed(resultBuffer))
      return failure();

    // Compute the new type.
    auto resultMemRefType =
        bufferization::getBufferType(castOp.getResult(), options);
    if (failed(resultMemRefType))
      return failure();
    if (resultBuffer->getType() == *resultMemRefType) {
      // This cast is a no-op.
      replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
      return success();
    }

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             *resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(
        rewriter, op, *resultMemRefType, *resultBuffer);

    return success();
  }
};

/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
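///
/// Illustrative sketch (assumes an identity source layout, so the dims are
/// guaranteed collapsible and no copy is needed; %m is the bufferized source):
/// ```
/// %0 = tensor.collapse_shape %t [[0, 1]] : tensor<2x3xf32> into tensor<6xf32>
/// ```
/// bufferizes to:
/// ```
/// %0 = memref.collapse_shape %m [[0, 1]] : memref<2x3xf32> into memref<6xf32>
/// ```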
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // tensor.collapse_shape may reallocate, at which point the source buffer is
    // copied. I.e., there will be a memory read side effect on the bufferized
    // source. This function conservatively returns "true" because whether a
    // copy will be created or not is not known at this point.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    // TODO: CollapseShapeOp may allocate at runtime.
    return {{op->getOpResult(0), BufferRelation::Equivalent}};
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    auto maybeSrcBufferType = bufferization::getBufferType(
        collapseShapeOp.getSrc(), options, invocationStack);
    if (failed(maybeSrcBufferType))
      return failure();
    auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        srcBufferType, collapseShapeOp.getReassociationIndices());

    if (!canBeCollapsed) {
      // If dims cannot be collapsed, this op bufferizes to a new allocation.
      RankedTensorType tensorResultType = collapseShapeOp.getResultType();
      return bufferization::getMemRefTypeWithStaticIdentityLayout(
          tensorResultType, srcBufferType.getMemorySpace());
    }

    return memref::CollapseShapeOp::computeCollapsedType(
        srcBufferType, collapseShapeOp.getReassociationIndices());
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    RankedTensorType tensorResultType = collapseShapeOp.getResultType();
    FailureOr<Value> maybeBuffer =
        getBuffer(rewriter, collapseShapeOp.getSrc(), options);
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;
    auto bufferType = cast<MemRefType>(buffer.getType());

    if (tensorResultType.getRank() == 0) {
      // 0-d collapses must go through a different op builder.
      MemRefType resultType;

      if (bufferType.getLayout().isIdentity()) {
        // Standard layout: result type has no offset.
        MemRefLayoutAttrInterface layout;
        resultType = MemRefType::get({}, tensorResultType.getElementType(),
                                     layout, bufferType.getMemorySpace());
      } else {
        // Source memref has a layout map: result type has the same offset as
        // the source type.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(bufferType.getStridesAndOffset(strides, offset)))
          return failure();
        resultType = MemRefType::get(
            {}, tensorResultType.getElementType(),
            StridedLayoutAttr::get(op->getContext(), offset, {}),
            bufferType.getMemorySpace());
      }

      replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
          rewriter, op, resultType, buffer, collapseShapeOp.getReassociation());
      return success();
    }

    // If the dims are not collapsible (due to an incompatible source layout
    // map), force an out-of-place bufferization, i.e., a buffer copy. This
    // newly allocated buffer will have no layout map and thus be collapsible.
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        bufferType, collapseShapeOp.getReassociationIndices());
    if (!canBeCollapsed) {
      // TODO: Create alloc_tensor ops during TensorCopyInsertion.
      AnalysisState analysisState(options);
      FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
          rewriter, op->getLoc(), collapseShapeOp.getSrc(), options);
      if (failed(tensorAlloc))
        return failure();
      auto memrefType =
          MemRefType::get(collapseShapeOp.getSrcType().getShape(),
                          collapseShapeOp.getSrcType().getElementType(),
                          AffineMap(), bufferType.getMemorySpace());
      buffer = rewriter.create<bufferization::ToMemrefOp>(
          op->getLoc(), memrefType, *tensorAlloc);
    }

    // Result type is inferred by the builder.
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
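///
/// Illustrative sketch (%m is the bufferized source):
/// ```
/// %d = tensor.dim %t, %c0 : tensor<?xf32>
/// ```
/// bufferizes to:
/// ```
/// %d = memref.dim %m, %c0 : memref<?xf32, strided<[?], offset: ?>>
/// ```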
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // The op reads the tensor's metadata but not its contents.
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto dimOp = cast<tensor::DimOp>(op);
    FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
                                                dimOp.getIndex());
    return success();
  }
};

/// Bufferization of tensor.empty. Replace with bufferization.alloc_tensor.
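///
/// Illustrative sketch:
/// ```
/// %0 = tensor.empty(%sz) : tensor<?xf32>
/// ```
/// bufferizes to:
/// ```
/// %0 = bufferization.alloc_tensor(%sz) : tensor<?xf32>
/// ```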
struct EmptyOpInterface
    : public BufferizableOpInterface::ExternalModel<EmptyOpInterface,
                                                    tensor::EmptyOp> {
  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  bool resultBufferizesToMemoryWrite(Operation *op, OpResult opResult,
                                     const AnalysisState &state) const {
    // The returned tensor does not have specified contents.
    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto emptyOp = cast<tensor::EmptyOp>(op);

    // Optimization: Fold away the op if it has no uses.
    if (op->getUses().empty()) {
      rewriter.eraseOp(op);
      return success();
    }

    // Allocate a tensor. This emits a "bufferization.alloc_tensor" op.
    FailureOr<Value> allocTensor = allocateTensorForShapedValue(
        rewriter, op->getLoc(), emptyOp.getResult(), options, /*copy=*/false);
    if (failed(allocTensor))
      return failure();
    rewriter.replaceOp(op, *allocTensor);
    return success();
  }
};

/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
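///
/// Illustrative sketch (static shapes and identity source layout; %m is the
/// bufferized source; assembly syntax abbreviated):
/// ```
/// %0 = tensor.expand_shape %t [[0, 1]] output_shape [2, 3]
///     : tensor<6xf32> into tensor<2x3xf32>
/// ```
/// bufferizes to:
/// ```
/// %0 = memref.expand_shape %m [[0, 1]] output_shape [2, 3]
///     : memref<6xf32> into memref<2x3xf32>
/// ```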
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // In contrast to tensor.collapse_shape, this op can always be bufferized
    // without a copy.
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {{op->getOpResult(0), BufferRelation::Equivalent}};
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto maybeSrcBufferType = bufferization::getBufferType(
        expandShapeOp.getSrc(), options, invocationStack);
    if (failed(maybeSrcBufferType))
      return failure();
    auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
    auto maybeResultType = memref::ExpandShapeOp::computeExpandedType(
        srcBufferType, expandShapeOp.getResultType().getShape(),
        expandShapeOp.getReassociationIndices());
    if (failed(maybeResultType))
      return failure();
    return *maybeResultType;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto tensorResultType = expandShapeOp.getResultType();
    FailureOr<Value> buffer =
        getBuffer(rewriter, expandShapeOp.getSrc(), options);
    if (failed(buffer))
      return failure();

    // Memref result type is inferred by the builder based on reassociation
    // indices and result shape.
    // TODO: Instead of inferring the output shape argument of
    // memref.expand_shape op, use output_shape argument of tensor.expand_shape
    // op.
    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
        rewriter, op, tensorResultType.getShape(), *buffer,
        expandShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
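///
/// Illustrative sketch (identity source layout; %m is the bufferized source;
/// the subview result layout encodes the slice offset):
/// ```
/// %0 = tensor.extract_slice %t[%off] [5] [1] : tensor<10xf32> to tensor<5xf32>
/// ```
/// bufferizes to:
/// ```
/// %0 = memref.subview %m[%off] [5] [1]
///     : memref<10xf32> to memref<5xf32, strided<[1], offset: ?>>
/// ```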
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {{op->getOpResult(0), BufferRelation::Unknown}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    Location loc = extractSliceOp.getLoc();

    // Get source buffer.
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, extractSliceOp.getSource(), options);
    if (failed(srcMemref))
      return failure();

    // Take a subview of the source buffer.
    auto resultMemrefType =
        bufferization::getBufferType(extractSliceOp.getResult(), options);
    if (failed(resultMemrefType))
      return failure();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref,
        mixedOffsets, mixedSizes, mixedStrides);

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    assert(value == extractSliceOp.getResult() && "invalid value");
    auto srcMemrefType = bufferization::getBufferType(
        extractSliceOp.getSource(), options, invocationStack);
    if (failed(srcMemrefType))
      return failure();
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
        extractSliceOp.getType().getShape(),
        llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
        mixedStrides));
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
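///
/// Illustrative sketch (%m is the bufferized source):
/// ```
/// %e = tensor.extract %t[%i] : tensor<?xf32>
/// ```
/// bufferizes to:
/// ```
/// %e = memref.load %m[%i] : memref<?xf32, strided<[?], offset: ?>>
/// ```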
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, extractOp.getTensor(), options);
    if (failed(srcMemref))
      return failure();
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
                                                 extractOp.getIndices());
    return success();
  }
};

// Creates one memref.store op per element: recursively traverses the indices
// of the output buffer in row-major order (one recursion level per dimension)
// while iterating over op.elements().
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
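///
/// Illustrative sketch (shown after the intermediate alloc_tensor has also
/// been bufferized; %c0 and %c1 are index constants):
/// ```
/// %0 = tensor.from_elements %a, %b : tensor<2xf32>
/// ```
/// bufferizes to roughly:
/// ```
/// %alloc = memref.alloc() : memref<2xf32>
/// memref.store %a, %alloc[%c0] : memref<2xf32>
/// memref.store %b, %alloc[%c1] : memref<2xf32>
/// ```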
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {

  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);
    auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto shape = tensorType.getShape();
    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, fromElementsOp.getResult(), options,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();
    FailureOr<BaseMemRefType> memrefType =
        bufferization::getBufferType(*tensorAlloc, options);
    if (failed(memrefType))
      return failure();
    Value buffer = rewriter.create<bufferization::ToMemrefOp>(
        op->getLoc(), *memrefType, *tensorAlloc);

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.getElements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(
          loc, fromElementsOp.getElements().front(), buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *llvm::max_element(shape);
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.getElements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);

    return success();
  }
};

/// Lower the body of a tensor.generate-like op (one index-typed bbArg per dim).
/// Such ops are lowered to linalg.map with the given tensor as a destination.
///
/// Example:
/// ```
/// %r = tensor.generate %x, %y {
///   ^bb0(%arg0: index, %arg1: index):
///   %0 = "some_op"(%arg0, %arg1) : (index, index) -> (index)
///   tensor.yield %0 : index
/// } : tensor<?x?xindex>
/// ```
///
/// Is lowered to:
/// ```
/// linalg.map ins() outs(%dest) {
///   %d0 = linalg.index 0 : index
///   %d1 = linalg.index 1 : index
///   %0 = "some_op"(%d0, %d1) : (index, index) -> (index)
///   linalg.yield %0 : index
/// }
/// ```
static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc,
                                     Value tensorDestination,
                                     ValueRange dynamicSizes,
                                     Region &generateBody) {
  assert(generateBody.hasOneBlock() && "expected body with single block");
  auto tensorType = cast<RankedTensorType>(tensorDestination.getType());
  assert(generateBody.getNumArguments() == tensorType.getRank() &&
         "rank mismatch");

  // Create linalg::MapOp.
  OpBuilder::InsertionGuard g(rewriter);
  auto linalgOp =
      rewriter.create<linalg::MapOp>(loc, tensorType, /*inputs=*/ValueRange(),
                                     /*init=*/tensorDestination);
  Block &linalgBody = linalgOp.getMapper().emplaceBlock();

  // Create linalg::IndexOps.
  rewriter.setInsertionPointToStart(&linalgBody);
  SmallVector<Value> indices;
  for (int64_t dim = 0; dim < tensorType.getRank(); ++dim)
    indices.push_back(rewriter.create<linalg::IndexOp>(loc, dim));

  // Move over body.
  rewriter.mergeBlocks(&generateBody.front(), &linalgBody, indices);
  auto yieldOp = cast<tensor::YieldOp>(linalgBody.getTerminator());
  rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());

  return linalgOp.getResult()[0];
}

/// Bufferization of tensor.generate.
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {

  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto generateOp = cast<tensor::GenerateOp>(op);

    auto type = generateOp.getResult().getType();

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpaceFn(type) != Attribute())
      return op->emitError("memory space not implemented yet");

    // Allocate memory.
    Location loc = op->getLoc();
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, generateOp.getResult(), options,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();

    Value result = lowerGenerateLikeOpBody(rewriter, loc, *tensorAlloc,
                                           generateOp.getDynamicExtents(),
                                           generateOp.getBody());
    rewriter.replaceOp(generateOp, result);

    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
///
/// Note: DstBufferizableOpInterfaceExternalModel provides many default method
/// implementations for DestinationStyle ops.
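///
/// Illustrative sketch (%m is the bufferized destination, which is written
/// in place):
/// ```
/// %0 = tensor.insert %f into %t[%i] : tensor<?xf32>
/// ```
/// bufferizes to:
/// ```
/// memref.store %f, %m[%i] : memref<?xf32, strided<[?], offset: ?>>
/// ```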
struct InsertOpInterface
    : public DstBufferizableOpInterfaceExternalModel<InsertOpInterface,
                                                     tensor::InsertOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        getBuffer(rewriter, insertOp.getDest(), options);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
                                     *destMemref, insertOp.getIndices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }
};

template <typename InsertOpTy>
static bool insertSliceOpRequiresRead(InsertOpTy insertSliceOp,
                                      OpOperand &opOperand) {
  // The source is always read.
  if (opOperand == insertSliceOp.getSourceMutable())
    return true;

  // For the destination, it depends...
  assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");

  // Dest is not read if it is entirely overwritten. E.g.:
  // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
  bool allOffsetsZero =
      llvm::all_of(insertSliceOp.getMixedOffsets(), isZeroIndex);
  RankedTensorType destType = insertSliceOp.getDestType();
  bool sizesMatchDestSizes =
      areConstantIntValues(insertSliceOp.getMixedSizes(), destType.getShape());
  bool allStridesOne =
      areAllConstantIntValue(insertSliceOp.getMixedStrides(), 1);
  return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
///
/// Note: DstBufferizableOpInterfaceExternalModel provides many default method
/// implementations for DestinationStyle ops.
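///
/// Illustrative sketch (identity destination layout; %src and %dest are the
/// bufferized operands):
/// ```
/// %0 = tensor.insert_slice %s into %t[%off] [5] [1]
///     : tensor<5xf32> into tensor<10xf32>
/// ```
/// bufferizes to:
/// ```
/// %sv = memref.subview %dest[%off] [5] [1]
///     : memref<10xf32> to memref<5xf32, strided<[1], offset: ?>>
/// memref.copy %src, %sv
///     : memref<5xf32> to memref<5xf32, strided<[1], offset: ?>>
/// ```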
struct InsertSliceOpInterface
    : public DstBufferizableOpInterfaceExternalModel<InsertSliceOpInterface,
                                                     tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return insertSliceOpRequiresRead(cast<tensor::InsertSliceOp>(op),
                                     opOperand);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    // insert_slice ops arise from tiling, and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    Location loc = insertSliceOp.getLoc();

    // Get destination buffer.
    FailureOr<Value> dstMemref =
        getBuffer(rewriter, insertSliceOp.getDest(), options);
    if (failed(dstMemref))
      return failure();

    // Take a subview of the destination buffer.
    auto dstMemrefType = cast<MemRefType>(dstMemref->getType());
    auto subviewMemRefType =
        cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getShape(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides));
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, insertSliceOp.getSource(), options);
    if (failed(srcMemref))
      return failure();
    if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView)))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

/// Bufferization of tensor.pad. Replace with bufferization.alloc_tensor +
/// linalg.map + insert_slice.
/// For best performance, vectorize before bufferization; this is particularly
/// beneficial when padding with a constant.
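///
/// Illustrative sketch (1-d, static sizes; the pad body becomes the
/// linalg.map body, as with tensor.generate):
/// ```
/// %0 = tensor.pad %t low[1] high[1] { ... } : tensor<8xf32> to tensor<10xf32>
/// ```
/// is lowered (still at the tensor level) to roughly:
/// ```
/// %alloc  = bufferization.alloc_tensor() : tensor<10xf32>
/// %filled = linalg.map ins() outs(%alloc) { ... }
/// %0 = tensor.insert_slice %t into %filled[1] [8] [1]
///     : tensor<8xf32> into tensor<10xf32>
/// ```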
struct PadOpInterface
    : public BufferizableOpInterface::ExternalModel<PadOpInterface,
                                                    tensor::PadOp> {
  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    // Infer memory space from the source tensor.
    auto padOp = cast<tensor::PadOp>(op);
    auto maybeSrcBufferType = bufferization::getBufferType(
        padOp.getSource(), options, invocationStack);
    if (failed(maybeSrcBufferType))
      return failure();
    MemRefLayoutAttrInterface layout;
    return MemRefType::get(padOp.getResultType().getShape(),
                           padOp.getResultType().getElementType(), layout,
                           maybeSrcBufferType->getMemorySpace());
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto padOp = cast<tensor::PadOp>(op);
    Location loc = padOp.getLoc();
    RankedTensorType resultType = padOp.getResultType();
    RankedTensorType srcType = padOp.getSourceType();

    auto toValue = [&](OpFoldResult ofr) {
      if (auto value = dyn_cast<Value>(ofr))
        return value;
      return rewriter
          .create<arith::ConstantIndexOp>(loc, *getConstantIntValue(ofr))
          .getResult();
    };

    // Compute dynamic result dimensions.
    SmallVector<OpFoldResult> mixedLowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> mixedHighPad = padOp.getMixedHighPad();
    SmallVector<Value> dynamicSizes;
    for (int64_t i = 0; i < resultType.getRank(); ++i) {
      if (!resultType.isDynamicDim(i))
        continue;
      Value srcDim = rewriter.create<tensor::DimOp>(loc, padOp.getSource(), i);
      Value lowPad = toValue(mixedLowPad[i]);
      Value highPad = toValue(mixedHighPad[i]);
      AffineExpr s0, s1, s2;
      bindSymbols(op->getContext(), s0, s1, s2);
      AffineExpr sumExpr = s0 + s1 + s2;
      Value sum = rewriter.create<affine::AffineApplyOp>(
          loc, sumExpr, ValueRange{srcDim, lowPad, highPad});
      dynamicSizes.push_back(sum);
    }

    // Allocate a buffer for the padded result.
    FailureOr<Value> tensorAlloc =
        allocateTensorForShapedValue(rewriter, loc, padOp.getResult(), options,
                                     /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();

    // tensor::PadOp is like tensor::GenerateOp: The only difference is that
    // only a part of the generated tensor is needed. For simplicity, we reuse
    // the same functionality here.
    Value filledBuffer = lowerGenerateLikeOpBody(
        rewriter, loc, *tensorAlloc, dynamicSizes, padOp.getBodyRegion());

    // Create tensor::InsertSliceOp.
    SmallVector<OpFoldResult> sliceSizes =
        getMixedSizes(rewriter, loc, padOp.getSource());
    SmallVector<OpFoldResult> sliceStrides(srcType.getRank(),
                                           rewriter.getIndexAttr(1));
    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        padOp, padOp.getSource(), filledBuffer,
        /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);

    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
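///
/// Illustrative sketch (%m is the bufferized source):
/// ```
/// %r = tensor.rank %t : tensor<*xf32>
/// ```
/// bufferizes to:
/// ```
/// %r = memref.rank %m : memref<*xf32>
/// ```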
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // The op reads the tensor's metadata but not its contents.
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto rankOp = cast<tensor::RankOp>(op);
    FailureOr<Value> v = getBuffer(rewriter, rankOp.getTensor(), options);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 *v);
    return success();
  }
};

/// Bufferization of tensor.reshape. Replace with memref.reshape.
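///
/// Illustrative sketch (identity layouts, so no source copy is needed; %m and
/// %shape are the bufferized operands):
/// ```
/// %0 = tensor.reshape %t(%s)
///     : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
/// ```
/// bufferizes to:
/// ```
/// %0 = memref.reshape %m(%shape)
///     : (memref<?xf32>, memref<2xindex>) -> memref<?x?xf32>
/// ```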
struct ReshapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
                                                    tensor::ReshapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Depending on the layout map, the source buffer may have to be copied.
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    return opOperand == reshapeOp.getShapeMutable();
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {{op->getOpResult(0), BufferRelation::Equivalent}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    FailureOr<Value> srcBuffer =
        getBuffer(rewriter, reshapeOp.getSource(), options);
    FailureOr<Value> shapeBuffer =
        getBuffer(rewriter, reshapeOp.getShape(), options);
    if (failed(srcBuffer) || failed(shapeBuffer))
      return failure();
    auto maybeResultMemRefType =
        bufferization::getBufferType(reshapeOp.getResult(), options);
    if (failed(maybeResultMemRefType))
      return failure();

    // memref.reshape requires the source buffer to have an identity layout.
    // If the source memref does not have an identity layout, copy the source
    // into a new buffer with an identity layout.
    auto srcType = llvm::dyn_cast<MemRefType>(srcBuffer->getType());
    if (srcType && !srcType.getLayout().isIdentity()) {
      FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
          rewriter, op->getLoc(), reshapeOp.getSource(), options);
      if (failed(tensorAlloc))
        return failure();
      auto memrefType = MemRefType::get(
          srcType.getShape(), srcType.getElementType(), AffineMap(),
          cast<BaseMemRefType>(srcBuffer->getType()).getMemorySpace());
      srcBuffer = rewriter
                      .create<bufferization::ToMemrefOp>(
                          op->getLoc(), memrefType, *tensorAlloc)
                      .getResult();
    }

    replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
        rewriter, op, maybeResultMemRefType.value(), *srcBuffer, *shapeBuffer);
    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    assert(value == reshapeOp.getResult() && "unexpected value provided");
    auto maybeSourceBufferType = bufferization::getBufferType(
        reshapeOp.getSource(), options, invocationStack);
    if (failed(maybeSourceBufferType))
      return failure();
    return getMemRefTypeWithStaticIdentityLayout(
        reshapeOp.getResult().getType(),
        cast<BaseMemRefType>(maybeSourceBufferType.value()).getMemorySpace());
  }
};

/// Analysis and bufferization of tensor.parallel_insert_slice. Replace with a
/// memory copy into a subview of the destination buffer.
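///
/// Illustrative sketch: the subview and the memcpy are inserted before the
/// parallel combining terminator (%src and %dest are the bufferized
/// operands):
/// ```
/// scf.forall.in_parallel {
///   tensor.parallel_insert_slice %s into %o[%off] [5] [1]
///       : tensor<5xf32> into tensor<10xf32>
/// }
/// ```
/// bufferizes to roughly:
/// ```
/// %sv = memref.subview %dest[%off] [5] [1]
///     : memref<10xf32> to memref<5xf32, strided<[1], offset: ?>>
/// memref.copy %src, %sv
///     : memref<5xf32> to memref<5xf32, strided<[1], offset: ?>>
/// scf.forall.in_parallel {}
/// ```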
struct ParallelInsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<
          ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return insertSliceOpRequiresRead(cast<tensor::ParallelInsertSliceOp>(op),
                                     opOperand);
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
    return opOperand == parallelInsertSliceOp.getDestMutable();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
    ParallelCombiningOpInterface parallelCombiningParent =
        parallelInsertSliceOp.getParallelCombiningParent();

    // Bufferize the op outside of the parallel combining terminator.
    rewriter.setInsertionPoint(parallelCombiningParent);

    // Get source and destination buffers.
    FailureOr<Value> destBuffer =
        getBuffer(rewriter, parallelInsertSliceOp.getDest(), options);
    if (failed(destBuffer))
      return failure();
    FailureOr<Value> srcBuffer =
        getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
    if (failed(srcBuffer))
      return failure();

    // Take a subview of the destination buffer.
    auto destBufferType = cast<MemRefType>(destBuffer->getType());
    auto subviewMemRefType =
        cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
            parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
            parallelInsertSliceOp.getMixedOffsets(),
            parallelInsertSliceOp.getMixedSizes(),
            parallelInsertSliceOp.getMixedStrides()));
    Value subview = rewriter.create<memref::SubViewOp>(
        parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer,
        parallelInsertSliceOp.getMixedOffsets(),
        parallelInsertSliceOp.getMixedSizes(),
        parallelInsertSliceOp.getMixedStrides());

    // This memcpy will fold away if everything bufferizes in-place.
    if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
                                    *srcBuffer, subview)))
      return failure();

    // In case the source was allocated in the same block, make sure that the
    // deallocation op (if any) appears after the memcpy. By default, deallocs
    // are placed before the terminator, but this does not work for ForallOp
    // because the terminator does more than just yielding a value.
    //
    // Note: This is not a problem for the destination buffer because these are
    // assumed to always bufferize in-place.
    for (Operation *user : srcBuffer->getUsers()) {
      if (hasEffect<MemoryEffects::Free>(user)) {
        if (user->getBlock() == parallelCombiningParent->getBlock())
          rewriter.moveOpBefore(user, user->getBlock()->getTerminator());
        break;
      }
    }

    // Delete the op.
    rewriter.eraseOp(op);
    return success();
  }

  /// The tensor.parallel_insert_slice op has implicit in-place behavior. We
  /// should not create a copy to resolve conflicts.
  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    return success();
  }
};

/// Bufferization of tensor.splat. Bufferizes to a new allocation that is filled
/// with a linalg.map. Similar to tensor.generate.
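///
/// Illustrative sketch (still at the tensor level; the alloc_tensor is
/// bufferized separately):
/// ```
/// %0 = tensor.splat %f : tensor<2x3xf32>
/// ```
/// is lowered to roughly:
/// ```
/// %alloc = bufferization.alloc_tensor() : tensor<2x3xf32>
/// %0 = linalg.map ins() outs(%alloc) {
///   linalg.yield %f : f32
/// }
/// ```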
struct SplatOpInterface
    : public BufferizableOpInterface::ExternalModel<SplatOpInterface,
                                                    tensor::SplatOp> {

  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto splatOp = cast<tensor::SplatOp>(op);

    // Allocate memory.
    Location loc = op->getLoc();
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, splatOp.getResult(), options,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();

    // Create linalg::MapOp.
    auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
      return op->emitError("memory space not implemented yet");

    auto linalgOp =
        rewriter.create<linalg::MapOp>(loc, tensorType, /*inputs=*/ValueRange(),
                                       /*init=*/*tensorAlloc);
    Block &linalgBody = linalgOp.getMapper().emplaceBlock();

    // The map body simply yields the splatted value; no linalg::IndexOps are
    // needed.
    rewriter.setInsertionPointToStart(&linalgBody);
    rewriter.create<linalg::YieldOp>(loc, splatOp.getInput());
    rewriter.replaceOp(splatOp, linalgOp.getResult()[0]);

    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
    EmptyOp::attachInterface<EmptyOpInterface>(*ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
    ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
    FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
    GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
    InsertOp::attachInterface<InsertOpInterface>(*ctx);
    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
    PadOp::attachInterface<PadOpInterface>(*ctx);
    ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
        *ctx);
    RankOp::attachInterface<RankOpInterface>(*ctx);
    ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
    SplatOp::attachInterface<SplatOpInterface>(*ctx);

    // Load additional dialects whose ops may be created during bufferization.
    ctx->loadDialect<arith::ArithDialect, linalg::LinalgDialect>();
  });

  // Bufferization requires SubsetInsertionOpInterface models. Make sure that
  // they are registered.
  tensor::registerSubsetOpInterfaceExternalModels(registry);
}