//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {{op->getResult(0), BufferRelation::Equivalent}};
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto castOp = cast<tensor::CastOp>(op);
    auto maybeSrcBufferType = bufferization::getBufferType(
        castOp.getSource(), options, invocationStack);
    if (failed(maybeSrcBufferType))
      return failure();
    Attribute memorySpace = maybeSrcBufferType->getMemorySpace();

    // Note: `getMemRefTypeWithFullyDynamicLayout` returns an unranked memref
    // type in case the input is an unranked tensor type.

    // Case 1: Casting an unranked tensor.
    if (isa<UnrankedTensorType>(castOp.getSource().getType())) {
      // When casting to a ranked tensor, we cannot infer any static offset or
      // strides from the source. Assume fully dynamic.
      return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
    }

    // Case 2: Casting to an unranked tensor type.
    if (isa<UnrankedTensorType>(castOp.getType())) {
      return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
    }

    // Case 3: Ranked tensor -> ranked tensor. The offsets and strides do not
    // change.
    auto rankedResultType = cast<RankedTensorType>(castOp.getType());
    return MemRefType::get(
        rankedResultType.getShape(), rankedResultType.getElementType(),
        llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
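    // For illustration, a rough sketch of the rewrite performed below (the
    // exact buffer types depend on the bufferization options):
    //   %1 = tensor.cast %0 : tensor<?xf32> to tensor<4xf32>
    // becomes a memref.cast on the source buffer, or folds away entirely when
    // the source buffer already has the computed result buffer type.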
    FailureOr<Value> resultBuffer =
        getBuffer(rewriter, castOp.getSource(), options);
    if (failed(resultBuffer))
      return failure();

    // Compute the new type.
    auto resultMemRefType =
        bufferization::getBufferType(castOp.getResult(), options);
    if (failed(resultMemRefType))
      return failure();
    if (resultBuffer->getType() == *resultMemRefType) {
      // This cast is a no-op.
      replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
      return success();
    }

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             *resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(
        rewriter, op, *resultMemRefType, *resultBuffer);

    return success();
  }
};

/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // tensor.collapse_shape may reallocate, at which point the source buffer
    // is copied. I.e., there will be a memory read side effect on the
    // bufferized source. This function conservatively returns "true" because
    // whether a copy will be created or not is not known at this point.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    // TODO: CollapseShapeOp may allocate at runtime.
    return {{op->getOpResult(0), BufferRelation::Equivalent}};
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    auto maybeSrcBufferType = bufferization::getBufferType(
        collapseShapeOp.getSrc(), options, invocationStack);
    if (failed(maybeSrcBufferType))
      return failure();
    auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        srcBufferType, collapseShapeOp.getReassociationIndices());

    if (!canBeCollapsed) {
      // If dims cannot be collapsed, this op bufferizes to a new allocation.
      RankedTensorType tensorResultType = collapseShapeOp.getResultType();
      return bufferization::getMemRefTypeWithStaticIdentityLayout(
          tensorResultType, srcBufferType.getMemorySpace());
    }

    return memref::CollapseShapeOp::computeCollapsedType(
        srcBufferType, collapseShapeOp.getReassociationIndices());
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    RankedTensorType tensorResultType = collapseShapeOp.getResultType();
    FailureOr<Value> maybeBuffer =
        getBuffer(rewriter, collapseShapeOp.getSrc(), options);
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;
    auto bufferType = cast<MemRefType>(buffer.getType());

    if (tensorResultType.getRank() == 0) {
      // 0-d collapses must go through a different op builder.
      MemRefType resultType;

      if (bufferType.getLayout().isIdentity()) {
        // Standard layout: result type has no offset.
        MemRefLayoutAttrInterface layout;
        resultType = MemRefType::get({}, tensorResultType.getElementType(),
                                     layout, bufferType.getMemorySpace());
      } else {
        // Source memref has a layout map: result type has the same offset as
        // the source type.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(bufferType.getStridesAndOffset(strides, offset)))
          return failure();
        resultType = MemRefType::get(
            {}, tensorResultType.getElementType(),
            StridedLayoutAttr::get(op->getContext(), offset, {}),
            bufferType.getMemorySpace());
      }

      replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
          rewriter, op, resultType, buffer, collapseShapeOp.getReassociation());
      return success();
    }

    // If the dims are not collapsible (due to an incompatible source layout
    // map), force an out-of-place bufferization, i.e., a buffer copy. This
    // newly allocated buffer will have no layout map and thus be collapsible.
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        bufferType, collapseShapeOp.getReassociationIndices());
    if (!canBeCollapsed) {
      // TODO: Create alloc_tensor ops during TensorCopyInsertion.
      AnalysisState analysisState(options);
      FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
          rewriter, op->getLoc(), collapseShapeOp.getSrc(), options);
      if (failed(tensorAlloc))
        return failure();
      auto memrefType =
          MemRefType::get(collapseShapeOp.getSrcType().getShape(),
                          collapseShapeOp.getSrcType().getElementType(),
                          AffineMap(), bufferType.getMemorySpace());
      buffer = rewriter.create<bufferization::ToMemrefOp>(
          op->getLoc(), memrefType, *tensorAlloc);
    }

    // Result type is inferred by the builder.
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // The op reads the tensor's metadata but not its contents.
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto dimOp = cast<tensor::DimOp>(op);
    FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
                                                dimOp.getIndex());
    return success();
  }
};

/// Bufferization of "tensor.empty". Replace with "bufferization.alloc_tensor".
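///
/// For illustration, a rough sketch (the exact form depends on the
/// bufferization options):
///   %0 = tensor.empty(%sz) : tensor<?x5xf32>
/// becomes
///   %0 = bufferization.alloc_tensor(%sz) : tensor<?x5xf32>
/// (or the op is simply erased if its result has no uses).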
struct EmptyOpInterface
    : public BufferizableOpInterface::ExternalModel<EmptyOpInterface,
                                                    tensor::EmptyOp> {
  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  bool resultBufferizesToMemoryWrite(Operation *op, OpResult opResult,
                                     const AnalysisState &state) const {
    // The returned tensor does not have specified contents.
    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto emptyOp = cast<tensor::EmptyOp>(op);

    // Optimization: Fold away the op if it has no uses.
    if (op->getUses().empty()) {
      rewriter.eraseOp(op);
      return success();
    }

    // Allocate a tensor. This emits a "bufferization.alloc_tensor" op.
    FailureOr<Value> allocTensor = allocateTensorForShapedValue(
        rewriter, op->getLoc(), emptyOp.getResult(), options, /*copy=*/false);
    if (failed(allocTensor))
      return failure();
    rewriter.replaceOp(op, *allocTensor);
    return success();
  }
};

/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // In contrast to tensor.collapse_shape, this op can always be bufferized
    // without a copy.
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {{op->getOpResult(0), BufferRelation::Equivalent}};
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto maybeSrcBufferType = bufferization::getBufferType(
        expandShapeOp.getSrc(), options, invocationStack);
    if (failed(maybeSrcBufferType))
      return failure();
    auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
    auto maybeResultType = memref::ExpandShapeOp::computeExpandedType(
        srcBufferType, expandShapeOp.getResultType().getShape(),
        expandShapeOp.getReassociationIndices());
    if (failed(maybeResultType))
      return failure();
    return *maybeResultType;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto tensorResultType = expandShapeOp.getResultType();
    FailureOr<Value> buffer =
        getBuffer(rewriter, expandShapeOp.getSrc(), options);
    if (failed(buffer))
      return failure();

    // Memref result type is inferred by the builder based on reassociation
    // indices and result shape.
    // TODO: Instead of inferring the output shape argument of
    // memref.expand_shape op, use output_shape argument of tensor.expand_shape
    // op.
    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
        rewriter, op, tensorResultType.getShape(), *buffer,
        expandShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
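///
/// For illustration, a rough sketch:
///   %0 = tensor.extract_slice %t[4] [10] [1] : tensor<?xf32> to tensor<10xf32>
/// becomes a memref.subview of the source buffer, e.g.:
///   %0 = memref.subview %m[4] [10] [1]
///          : memref<?xf32> to memref<10xf32, strided<[1], offset: 4>>
/// (the exact layouts depend on the chosen bufferization options).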
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {{op->getOpResult(0), BufferRelation::Unknown}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    Location loc = extractSliceOp.getLoc();

    // Get source buffer.
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, extractSliceOp.getSource(), options);
    if (failed(srcMemref))
      return failure();

    // Take a subview of the source buffer.
    auto resultMemrefType =
        bufferization::getBufferType(extractSliceOp.getResult(), options);
    if (failed(resultMemrefType))
      return failure();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref,
        mixedOffsets, mixedSizes, mixedStrides);

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    assert(value == extractSliceOp.getResult() && "invalid value");
    auto srcMemrefType = bufferization::getBufferType(
        extractSliceOp.getSource(), options, invocationStack);
    if (failed(srcMemrefType))
      return failure();
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
        extractSliceOp.getType().getShape(),
        llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
        mixedStrides));
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
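///
/// For illustration, a rough sketch:
///   %0 = tensor.extract %t[%i, %j] : tensor<?x?xf32>
/// becomes
///   %0 = memref.load %m[%i, %j] : memref<?x?xf32, strided<[?, ?], offset: ?>>
/// (the exact buffer layout depends on the chosen bufferization options).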
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, extractOp.getTensor(), options);
    if (failed(srcMemref))
      return failure();
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
                                                 extractOp.getIndices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {

  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);
    auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto shape = tensorType.getShape();
    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, fromElementsOp.getResult(), options,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();
    FailureOr<BaseMemRefType> memrefType =
        bufferization::getBufferType(*tensorAlloc, options);
    if (failed(memrefType))
      return failure();
    Value buffer = rewriter.create<bufferization::ToMemrefOp>(
        op->getLoc(), *memrefType, *tensorAlloc);

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.getElements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(
          loc, fromElementsOp.getElements().front(), buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
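    // For illustration: for a tensor<2x3xf32>, the code below creates index
    // constants 0, 1, 2 and then emits six memref.store ops at indices
    // (0,0), (0,1), ..., (1,2), consuming `elements` in row-major order.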
    auto maxDim = *llvm::max_element(shape);
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.getElements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);

    return success();
  }
};

/// Lower the body of a tensor.generate-like op (one index-typed bbArg per
/// dim). Such ops are lowered to linalg.map with the given tensor as a
/// destination.
///
/// Example:
/// ```
/// %r = tensor.generate %x, %y {
///   ^bb0(%arg0: index, %arg1: index):
///   %0 = "some_op"(%arg0, %arg1) : (index, index) -> (index)
///   tensor.yield %0 : index
/// } : tensor<?x?xindex>
/// ```
///
/// Is lowered to:
/// ```
/// linalg.map ins() outs(%dest) {
///   %d0 = linalg.index 0 : index
///   %d1 = linalg.index 1 : index
///   %0 = "some_op"(%d0, %d1) : (index, index) -> (index)
///   linalg.yield %0 : index
/// }
/// ```
static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc,
                                     Value tensorDestination,
                                     ValueRange dynamicSizes,
                                     Region &generateBody) {
  assert(generateBody.hasOneBlock() && "expected body with single block");
  auto tensorType = cast<RankedTensorType>(tensorDestination.getType());
  assert(generateBody.getNumArguments() == tensorType.getRank() &&
         "rank mismatch");

  // Create linalg::MapOp.
  OpBuilder::InsertionGuard g(rewriter);
  auto linalgOp =
      rewriter.create<linalg::MapOp>(loc, tensorType, /*inputs=*/ValueRange(),
                                     /*init=*/tensorDestination);
  Block &linalgBody = linalgOp.getMapper().emplaceBlock();

  // Create linalg::IndexOps.
  rewriter.setInsertionPointToStart(&linalgBody);
  SmallVector<Value> indices;
  for (int64_t dim = 0; dim < tensorType.getRank(); ++dim)
    indices.push_back(rewriter.create<linalg::IndexOp>(loc, dim));

  // Move over body.
  rewriter.mergeBlocks(&generateBody.front(), &linalgBody, indices);
  auto yieldOp = cast<tensor::YieldOp>(linalgBody.getTerminator());
  rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());

  return linalgOp.getResult()[0];
}

/// Bufferization of tensor.generate.
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {

  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto generateOp = cast<tensor::GenerateOp>(op);

    auto type = generateOp.getResult().getType();

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpaceFn(type) != Attribute())
      return op->emitError("memory space not implemented yet");

    // Allocate memory.
    Location loc = op->getLoc();
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, generateOp.getResult(), options,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();

    Value result = lowerGenerateLikeOpBody(rewriter, loc, *tensorAlloc,
                                           generateOp.getDynamicExtents(),
                                           generateOp.getBody());
    rewriter.replaceOp(generateOp, result);

    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
///
/// Note: DstBufferizableOpInterfaceExternalModel provides many default method
/// implementations for DestinationStyle ops.
struct InsertOpInterface
    : public DstBufferizableOpInterfaceExternalModel<InsertOpInterface,
                                                     tensor::InsertOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        getBuffer(rewriter, insertOp.getDest(), options);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
                                     *destMemref, insertOp.getIndices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }
};

template <typename InsertOpTy>
static bool insertSliceOpRequiresRead(InsertOpTy insertSliceOp,
                                      OpOperand &opOperand) {
  // The source is always read.
  if (opOperand == insertSliceOp.getSourceMutable())
    return true;

  // For the destination, it depends...
  assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");

  // Dest is not read if it is entirely overwritten. E.g.:
  // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
  bool allOffsetsZero =
      llvm::all_of(insertSliceOp.getMixedOffsets(), isZeroIndex);
  RankedTensorType destType = insertSliceOp.getDestType();
  bool sizesMatchDestSizes =
      areConstantIntValues(insertSliceOp.getMixedSizes(), destType.getShape());
  bool allStridesOne =
      areAllConstantIntValue(insertSliceOp.getMixedStrides(), 1);
  return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
///
/// Note: DstBufferizableOpInterfaceExternalModel provides many default method
/// implementations for DestinationStyle ops.
struct InsertSliceOpInterface
    : public DstBufferizableOpInterfaceExternalModel<InsertSliceOpInterface,
                                                     tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return insertSliceOpRequiresRead(cast<tensor::InsertSliceOp>(op),
                                     opOperand);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    // insert_slice ops arise from tiling, and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
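    //
    // Conceptually, the rewrite below is (a sketch, assuming the destination
    // bufferizes in-place):
    //   %view = memref.subview %dst_buffer [offsets] [sizes] [strides]
    //   memref.copy %src_buffer, %view
    // and the op is replaced by the destination buffer.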
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    Location loc = insertSliceOp.getLoc();

    // Get destination buffer.
    FailureOr<Value> dstMemref =
        getBuffer(rewriter, insertSliceOp.getDest(), options);
    if (failed(dstMemref))
      return failure();

    // Take a subview of the destination buffer.
    auto dstMemrefType = cast<MemRefType>(dstMemref->getType());
    auto subviewMemRefType =
        cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getShape(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides));
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, insertSliceOp.getSource(), options);
    if (failed(srcMemref))
      return failure();
    if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView)))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

/// Bufferization of tensor.pad. Replace with bufferization.alloc_tensor +
/// linalg.map + insert_slice.
/// For best performance, vectorize before bufferization; this is especially
/// beneficial when padding with a constant.
struct PadOpInterface
    : public BufferizableOpInterface::ExternalModel<PadOpInterface,
                                                    tensor::PadOp> {
  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    // Infer memory space from the source tensor.
    auto padOp = cast<tensor::PadOp>(op);
    auto maybeSrcBufferType = bufferization::getBufferType(
        padOp.getSource(), options, invocationStack);
    if (failed(maybeSrcBufferType))
      return failure();
    MemRefLayoutAttrInterface layout;
    return MemRefType::get(padOp.getResultType().getShape(),
                           padOp.getResultType().getElementType(), layout,
                           maybeSrcBufferType->getMemorySpace());
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto padOp = cast<tensor::PadOp>(op);
    Location loc = padOp.getLoc();
    RankedTensorType resultType = padOp.getResultType();
    RankedTensorType srcType = padOp.getSourceType();

    auto toValue = [&](OpFoldResult ofr) {
      if (auto value = dyn_cast<Value>(ofr))
        return value;
      return rewriter
          .create<arith::ConstantIndexOp>(loc, *getConstantIntValue(ofr))
          .getResult();
    };

    // Compute dynamic result dimensions.
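    // For each dynamic result dimension i, the padded size is
    //   low_pad[i] + src_size[i] + high_pad[i],
    // computed below with an affine.apply on the three operands.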
    SmallVector<OpFoldResult> mixedLowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> mixedHighPad = padOp.getMixedHighPad();
    SmallVector<Value> dynamicSizes;
    for (int64_t i = 0; i < resultType.getRank(); ++i) {
      if (!resultType.isDynamicDim(i))
        continue;
      Value srcDim = rewriter.create<tensor::DimOp>(loc, padOp.getSource(), i);
      Value lowPad = toValue(mixedLowPad[i]);
      Value highPad = toValue(mixedHighPad[i]);
      AffineExpr s0, s1, s2;
      bindSymbols(op->getContext(), s0, s1, s2);
      AffineExpr sumExpr = s0 + s1 + s2;
      Value sum = rewriter.create<affine::AffineApplyOp>(
          loc, sumExpr, ValueRange{srcDim, lowPad, highPad});
      dynamicSizes.push_back(sum);
    }

    // Allocate a buffer for the padded result.
    FailureOr<Value> tensorAlloc =
        allocateTensorForShapedValue(rewriter, loc, padOp.getResult(), options,
                                     /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();

    // tensor::PadOp is like tensor::GenerateOp: The only difference is that
    // only a part of the generated tensor is needed. For simplicity, we reuse
    // the same functionality here.
    Value filledBuffer = lowerGenerateLikeOpBody(
        rewriter, loc, *tensorAlloc, dynamicSizes, padOp.getBodyRegion());

    // Create tensor::InsertSliceOp.
    SmallVector<OpFoldResult> sliceSizes =
        getMixedSizes(rewriter, loc, padOp.getSource());
    SmallVector<OpFoldResult> sliceStrides(srcType.getRank(),
                                           rewriter.getIndexAttr(1));
    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        padOp, padOp.getSource(), filledBuffer,
        /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);

    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // The op reads the tensor's metadata but not its contents.
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto rankOp = cast<tensor::RankOp>(op);
    FailureOr<Value> v = getBuffer(rewriter, rankOp.getTensor(), options);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 *v);
    return success();
  }
};

/// Bufferization of tensor.reshape. Replace with memref.reshape.
struct ReshapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
                                                    tensor::ReshapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Depending on the layout map, the source buffer may have to be copied.
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    return opOperand == reshapeOp.getShapeMutable();
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {{op->getOpResult(0), BufferRelation::Equivalent}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    FailureOr<Value> srcBuffer =
        getBuffer(rewriter, reshapeOp.getSource(), options);
    FailureOr<Value> shapeBuffer =
        getBuffer(rewriter, reshapeOp.getShape(), options);
    if (failed(srcBuffer) || failed(shapeBuffer))
      return failure();
    auto maybeResultMemRefType =
        bufferization::getBufferType(reshapeOp.getResult(), options);
    if (failed(maybeResultMemRefType))
      return failure();

    // memref.reshape requires the source buffer to have an identity layout.
    // If the source memref does not have an identity layout, copy the source
    // into a new buffer with an identity layout.
    auto srcType = llvm::dyn_cast<MemRefType>(srcBuffer->getType());
    if (srcType && !srcType.getLayout().isIdentity()) {
      FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
          rewriter, op->getLoc(), reshapeOp.getSource(), options);
      if (failed(tensorAlloc))
        return failure();
      auto memrefType = MemRefType::get(
          srcType.getShape(), srcType.getElementType(), AffineMap(),
          cast<BaseMemRefType>(srcBuffer->getType()).getMemorySpace());
      srcBuffer = rewriter
                      .create<bufferization::ToMemrefOp>(
                          op->getLoc(), memrefType, *tensorAlloc)
                      .getResult();
    }

    replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
        rewriter, op, maybeResultMemRefType.value(), *srcBuffer, *shapeBuffer);
    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    assert(value == reshapeOp.getResult() && "unexpected value provided");
    auto maybeSourceBufferType = bufferization::getBufferType(
        reshapeOp.getSource(), options, invocationStack);
    if (failed(maybeSourceBufferType))
      return failure();
    return getMemRefTypeWithStaticIdentityLayout(
        reshapeOp.getResult().getType(),
        cast<BaseMemRefType>(maybeSourceBufferType.value()).getMemorySpace());
  }
};

/// Analysis of ParallelInsertSliceOp.
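///
/// This op only appears inside a parallel combining terminator (e.g., the
/// scf.forall.in_parallel region of an scf.forall op). A rough sketch:
///   scf.forall ... {
///     ...
///     scf.forall.in_parallel {
///       tensor.parallel_insert_slice %src into %dest[offsets] [sizes] [strides]
///     }
///   }
/// It bufferizes to a memref.subview of the destination buffer plus a copy,
/// inserted right before the combining terminator.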
struct ParallelInsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<
          ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return insertSliceOpRequiresRead(cast<tensor::ParallelInsertSliceOp>(op),
                                     opOperand);
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
    return opOperand == parallelInsertSliceOp.getDestMutable();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
    ParallelCombiningOpInterface parallelCombiningParent =
        parallelInsertSliceOp.getParallelCombiningParent();

    // Bufferize the op outside of the parallel combining terminator.
    rewriter.setInsertionPoint(parallelCombiningParent);

    // Get source and destination buffers.
    FailureOr<Value> destBuffer =
        getBuffer(rewriter, parallelInsertSliceOp.getDest(), options);
    if (failed(destBuffer))
      return failure();
    FailureOr<Value> srcBuffer =
        getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
    if (failed(srcBuffer))
      return failure();

    // Take a subview of the destination buffer.
    auto destBufferType = cast<MemRefType>(destBuffer->getType());
    auto subviewMemRefType =
        cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
            parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
            parallelInsertSliceOp.getMixedOffsets(),
            parallelInsertSliceOp.getMixedSizes(),
            parallelInsertSliceOp.getMixedStrides()));
    Value subview = rewriter.create<memref::SubViewOp>(
        parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer,
        parallelInsertSliceOp.getMixedOffsets(),
        parallelInsertSliceOp.getMixedSizes(),
        parallelInsertSliceOp.getMixedStrides());

    // This memcpy will fold away if everything bufferizes in-place.
    if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
                                    *srcBuffer, subview)))
      return failure();

    // In case the source was allocated in the same block, make sure that the
    // deallocation op (if any) appears after the memcpy. By default, deallocs
    // are placed before the terminator, but this does not work for ForallOp
    // because the terminator does more than just yield a value.
    //
    // Note: This is not a problem for the destination buffer, which is
    // assumed to always bufferize in-place.
    for (Operation *user : srcBuffer->getUsers()) {
      if (hasEffect<MemoryEffects::Free>(user)) {
        if (user->getBlock() == parallelCombiningParent->getBlock())
          rewriter.moveOpBefore(user, user->getBlock()->getTerminator());
        break;
      }
    }

    // Delete the op.
    rewriter.eraseOp(op);
    return success();
  }

  /// tensor.parallel_insert_slice op has implicit in-place behavior. We
  /// should not create a copy to resolve conflicts.
  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    return success();
  }
};

/// Bufferization of tensor.splat. Bufferizes to a new allocation that is
/// filled with a linalg.map. Similar to tensor.generate.
struct SplatOpInterface
    : public BufferizableOpInterface::ExternalModel<SplatOpInterface,
                                                    tensor::SplatOp> {

  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto splatOp = cast<tensor::SplatOp>(op);

    // Allocate memory.
    Location loc = op->getLoc();
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, splatOp.getResult(), options,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();

    // Create linalg::MapOp.
    auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
      return op->emitError("memory space not implemented yet");

    auto linalgOp =
        rewriter.create<linalg::MapOp>(loc, tensorType, /*inputs=*/ValueRange(),
                                       /*init=*/*tensorAlloc);
    Block &linalgBody = linalgOp.getMapper().emplaceBlock();

    // The map body simply yields the splat value for every element.
    rewriter.setInsertionPointToStart(&linalgBody);
    rewriter.create<linalg::YieldOp>(loc, splatOp.getInput());
    rewriter.replaceOp(splatOp, linalgOp.getResult()[0]);

    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
    EmptyOp::attachInterface<EmptyOpInterface>(*ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
    ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
    FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
    GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
    InsertOp::attachInterface<InsertOpInterface>(*ctx);
    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
    PadOp::attachInterface<PadOpInterface>(*ctx);
    ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
        *ctx);
    RankOp::attachInterface<RankOpInterface>(*ctx);
    ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
    SplatOp::attachInterface<SplatOpInterface>(*ctx);

    // Load additional dialects whose ops may get created during bufferization.
    ctx->loadDialect<arith::ArithDialect, linalg::LinalgDialect>();
  });

  // Bufferization requires SubsetInsertionOpInterface models. Make sure that
  // they are registered.
  tensor::registerSubsetOpInterfaceExternalModels(registry);
}
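
// Example usage (a sketch, not part of the upstream file): clients typically
// register these external models when populating their DialectRegistry, e.g.:
//
//   DialectRegistry registry;
//   registry.insert<tensor::TensorDialect>();
//   tensor::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext ctx(registry);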