//===- ConvertToDestinationStyle.cpp - Convert non-DPS to DPS ops ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains patterns to convert non-DPS ops to DPS ops. New
// tensor.empty ops are inserted as destinations. Such tensor.empty ops can be
// eliminated with "empty tensor elimination", allowing them to bufferize
// without an allocation (assuming there are no further conflicts).
//
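// In addition to the rewrite patterns, this file provides the
// linalg::bufferizeToAllocation entry points, which materialize a new buffer
// allocation for selected ops (tensor.pad, vector.mask,
// bufferization.alloc_tensor, and other bufferizable ops) and rewrite them to
// use that allocation.
//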
//===----------------------------------------------------------------------===//
//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
using namespace mlir::tensor;

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
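// For example, for a shape of [2, 3], the elements are inserted at indices
// (0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), i.e., in row-major order:
// the innermost dimension is handled by the base case below, the outer
// dimensions by the recursive calls.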
static Value createInserts(RewriterBase &rewriter, Location loc, int dim,
                           Value destination, ArrayRef<int64_t> shape,
                           ArrayRef<Value> constants,
                           OperandRange::iterator &elementIt,
                           SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      destination = rewriter.create<tensor::InsertOp>(loc, *elementIt,
                                                      destination, indices);
      ++elementIt;
    }
    return destination;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    destination = createInserts(rewriter, loc, dim + 1, destination, shape,
                                constants, elementIt, indices);
  }
  return destination;
}

/// Create a memcpy from the given source tensor to the given destination
/// memref. The copy op type can be specified in the `options`.
static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource,
                         Value memrefDest,
                         const linalg::BufferizeToAllocationOptions &options) {
  auto tensorType = dyn_cast<RankedTensorType>(tensorSource.getType());
  assert(tensorType && "expected ranked tensor");
  assert(isa<MemRefType>(memrefDest.getType()) && "expected ranked memref");

  switch (options.memcpyOp) {
  case linalg::BufferizeToAllocationOptions::MemcpyOp::
      MaterializeInDestination: {
    // Note: This is the preferred way of memcpy'ing because no layout map
    // and/or memory space must be specified for the source.
    auto materializeOp = b.create<bufferization::MaterializeInDestinationOp>(
        loc, tensorSource, memrefDest);
    materializeOp.setWritable(true);
  } break;
  case linalg::BufferizeToAllocationOptions::MemcpyOp::MemrefCopy: {
    // TODO: Support custom memory space on source.
    // We do not know the layout map of the source yet, so use a fully dynamic
    // layout for best compatibility.
    Value toMemref = b.create<bufferization::ToMemrefOp>(
        loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType),
        tensorSource, /*readOnly=*/true);
    b.create<memref::CopyOp>(loc, toMemref, memrefDest);
  } break;
  case linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy: {
    // TODO: Support custom memory space on source.
    // We do not know the layout map of the source yet, so use a fully dynamic
    // layout for best compatibility.
    Value toMemref = b.create<bufferization::ToMemrefOp>(
        loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType),
        tensorSource, /*readOnly=*/true);
    b.create<linalg::CopyOp>(loc, toMemref, memrefDest);
  } break;
  };
}

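/// Rewrite the body of the given tensor.pad op in terms of the given
/// destination `dest`: if the yielded padding value is a constant or is
/// defined outside of the pad op, create a linalg.fill on `dest`; otherwise,
/// create a linalg.generic on `dest` and move the pad body into it. Return
/// the newly created fill or generic op.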
static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter,
                                               Location loc, PadOp padOp,
                                               Value dest) {
  OpBuilder::InsertionGuard g(rewriter);
  RankedTensorType resultType = padOp.getResultType();

  // Examine the yielded value to decide if a linalg.generic is needed or a
  // linalg.fill is sufficient.
  Value yieldedValue =
      cast<tensor::YieldOp>(padOp.getBody()->getTerminator()).getValue();
  Attribute constYieldedValue;
  // Is the yielded value a bbArg defined outside of the PadOp?
  bool outsideBbArg =
      isa<BlockArgument>(yieldedValue) &&
      cast<BlockArgument>(yieldedValue).getOwner()->getParentOp() !=
          padOp.getOperation();
  // Is the yielded value an OpResult defined outside of the PadOp?
  bool outsideOpResult =
      isa<OpResult>(yieldedValue) &&
      yieldedValue.getDefiningOp()->getParentOp() != padOp.getOperation();
  bool invariantYieldedValue = outsideBbArg || outsideOpResult;
  if (matchPattern(yieldedValue, m_Constant(&constYieldedValue))) {
    // Padding with a constant: Create linalg.fill.
    Dialect *arithDialect =
        rewriter.getContext()->getLoadedDialect<arith::ArithDialect>();
    Value fillValue =
        arithDialect
            ->materializeConstant(rewriter, constYieldedValue,
                                  yieldedValue.getType(), yieldedValue.getLoc())
            ->getResult(0);
    auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange(fillValue),
                                                  ValueRange(dest));
    return fillOp;
  }

  if (invariantYieldedValue) {
    // Padding with an invariant value.
    auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange(yieldedValue),
                                                  ValueRange(dest));
    return fillOp;
  }

  // Create linalg.generic.
  SmallVector<utils::IteratorType> iteratorTypes(resultType.getRank(),
                                                 utils::IteratorType::parallel);
  SmallVector<AffineMap> indexingMaps(
      1, rewriter.getMultiDimIdentityMap(resultType.getRank()));
  auto genericOp = rewriter.create<linalg::GenericOp>(
      loc, resultType, /*inputs=*/ValueRange(),
      /*outputs=*/ValueRange{dest}, /*indexingMaps=*/
      indexingMaps, iteratorTypes);
  Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
                                     resultType.getElementType(), loc);
  rewriter.setInsertionPointToStart(body);
  SmallVector<Value> bbArgReplacements;
  for (int64_t i = 0; i < resultType.getRank(); ++i)
    bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
  rewriter.mergeBlocks(padOp.getBody(), body, bbArgReplacements);

  // Update terminator.
  auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
  rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
  return genericOp;
}

static SmallVector<Value> reifyOrComputeDynamicSizes(OpBuilder &b,
                                                     Value value) {
  auto tensorType = cast<RankedTensorType>(value.getType());
  if (tensorType.hasStaticShape())
    return {};

  // Try to reify dynamic sizes.
  ReifiedRankedShapedTypeDims reifiedShape;
  if (isa<OpResult>(value) &&
      succeeded(reifyResultShapes(b, value.getDefiningOp(), reifiedShape))) {
    SmallVector<Value> dynSizes;
    for (int64_t i = 0; i < tensorType.getRank(); ++i) {
      if (tensorType.isDynamicDim(i))
        dynSizes.push_back(cast<Value>(
            reifiedShape[cast<OpResult>(value).getResultNumber()][i]));
    }
    return dynSizes;
  }

  // Create tensor.dim ops.
  SmallVector<Value> dynSizes;
  for (int64_t i = 0; i < tensorType.getRank(); ++i) {
    if (tensorType.isDynamicDim(i))
      dynSizes.push_back(
          b.create<DimOp>(value.getLoc(), value,
                          b.create<arith::ConstantIndexOp>(value.getLoc(), i)));
  }
  return dynSizes;
}

static Value
createAllocationForTensor(RewriterBase &rewriter, Location loc, Value value,
                          const linalg::BufferizeToAllocationOptions &options,
                          Attribute memorySpace = {}) {
  OpBuilder::InsertionGuard g(rewriter);
  auto tensorType = cast<RankedTensorType>(value.getType());

  // Create buffer allocation.
  auto memrefType =
      cast<MemRefType>(bufferization::getMemRefTypeWithStaticIdentityLayout(
          tensorType, memorySpace));
  SmallVector<Value> dynamicSizes = reifyOrComputeDynamicSizes(rewriter, value);

  Value alloc;
  if (options.allocOp ==
      linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloc) {
    alloc = rewriter.create<memref::AllocOp>(loc, memrefType, dynamicSizes);
    if (options.emitDealloc) {
      // Place deallocation at the end of the block.
      rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator());
      rewriter.create<memref::DeallocOp>(loc, alloc);
    }
  } else if (options.allocOp ==
             linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloca) {
    alloc = rewriter.create<memref::AllocaOp>(loc, memrefType, dynamicSizes);
    // No dealloc is needed.
  }

  return alloc;
}

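/// Lower the given tensor.pad op to a buffer allocation: if the op has any
/// padding, the padding value is written into the allocation with linalg.fill
/// (or linalg.generic); the pad source is then copied into a subview of the
/// allocation, and the pad op is replaced with a bufferization.to_tensor of
/// the allocation. Depending on `options`, the result looks roughly as
/// follows (illustrative):
///
///   %alloc = memref.alloc(...) : memref<...>
///   linalg.fill ins(%padding_value : ...) outs(%alloc : ...)
///   %subview = memref.subview %alloc[<low pad>] [<source sizes>] [1, ...]
///   <memcpy op>(%source, %subview)   // as selected by options.memcpyOp
///   %0 = bufferization.to_tensor %alloc restrict writable
///
/// The newly allocated buffer is returned.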
Value linalg::bufferizeToAllocation(
    RewriterBase &rewriter, const linalg::BufferizeToAllocationOptions &options,
    PadOp padOp, Attribute memorySpace, Operation *insertionPoint) {
  // tensor.pad does not have a destination operand.
  assert(!options.bufferizeDestinationOnly && "invalid options");

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(insertionPoint ? insertionPoint : padOp);
  Location loc = padOp.getLoc();

  // Create buffer allocation.
  Value alloc = createAllocationForTensor(rewriter, loc, padOp.getResult(),
                                          options, memorySpace);
  rewriter.setInsertionPoint(padOp);

  if (!padOp.hasZeroLowPad() || !padOp.hasZeroHighPad()) {
    // Create linalg.fill or linalg.generic. Not needed if there is no padding.
    Operation *fillOp =
        movePaddingToFillOrGenericOp(rewriter, loc, padOp, alloc);
    rewriter.setInsertionPointAfter(fillOp);
  }

  // Create memcpy.
  SmallVector<OpFoldResult> sizes =
      getMixedSizes(rewriter, loc, padOp.getSource());
  SmallVector<OpFoldResult> strides(padOp.getResultType().getRank(),
                                    rewriter.getIndexAttr(1));
  Value subview = rewriter.create<memref::SubViewOp>(
      loc, alloc, /*offsets=*/padOp.getMixedLowPad(), sizes, strides);
  createMemcpy(rewriter, loc, padOp.getSource(), subview, options);

  // Create bufferization.to_tensor with "restrict" and "writable". The returned
  // tensor is a new buffer allocation, so it does not alias with any buffer.
  Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
      loc, alloc, /*restrict=*/true, /*writable=*/true);
  rewriter.replaceOp(padOp, toTensorOp);
  return alloc;
}

Value linalg::bufferizeToAllocation(
    RewriterBase &rewriter, const linalg::BufferizeToAllocationOptions &options,
    vector::MaskOp maskOp, Attribute memorySpace, Operation *insertionPoint) {
  assert(llvm::range_size(maskOp.getMaskBlock()->without_terminator()) == 1 &&
         "expected single masked op");
  OpBuilder::InsertionGuard g(rewriter);
  bufferization::BufferizationOptions bufferizationOptions;
  Operation *yieldOp = maskOp.getMaskRegion().front().getTerminator();
  assert(isa<vector::YieldOp>(yieldOp) && "expected yield op terminator");

  // Bufferize maskable op. By default, place the buffer allocation right before
  // the mask op.
  Value alloc = bufferizeToAllocation(
      rewriter, options, maskOp.getMaskableOp(), memorySpace,
      /*insertionPoint=*/insertionPoint ? insertionPoint : maskOp);

  if (options.bufferizeDestinationOnly)
    return alloc;

  // Bufferize terminator.
  rewriter.setInsertionPoint(yieldOp);
  if (failed(cast<bufferization::BufferizableOpInterface>(yieldOp).bufferize(
          rewriter, bufferizationOptions)))
    return nullptr;

  // Erase dead to_tensor ops inside of the mask op. This is necessary because
  // there may only be one op (apart from the terminator) inside the mask op.
  // TODO: Remove dead to_tensor ops more aggressively during bufferization.
  SmallVector<Operation *> toTensorOps;
  maskOp.walk([&](bufferization::ToTensorOp toTensorOp) {
    if (toTensorOp->getUses().empty())
      toTensorOps.push_back(toTensorOp.getOperation());
  });
  for (Operation *op : toTensorOps)
    rewriter.eraseOp(op);

  // Bufferize mask op.
  SmallVector<OpOperand *> resultUses;
  for (Value result : maskOp.getResults())
    if (isa<TensorType>(result.getType()))
      for (OpOperand &use : result.getUses())
        resultUses.push_back(&use);
  rewriter.setInsertionPoint(maskOp);
  if (failed(cast<bufferization::BufferizableOpInterface>(maskOp.getOperation())
                 .bufferize(rewriter, bufferizationOptions)))
    return nullptr;

  // Set "restrict" attribute, indicating that no other tensor aliases with
  // this tensor. That is because we just allocated a new buffer for the tensor.
  for (OpOperand *resultUse : resultUses) {
    auto toTensorOp =
        resultUse->get().getDefiningOp<bufferization::ToTensorOp>();
    assert(toTensorOp && "expected to_tensor op");
    rewriter.modifyOpInPlace(toTensorOp, [&]() {
      toTensorOp.setRestrict(true);
      toTensorOp.setWritable(true);
    });
  }

  return alloc;
}

Value linalg::bufferizeToAllocation(
    RewriterBase &rewriter, const linalg::BufferizeToAllocationOptions &options,
    bufferization::AllocTensorOp allocTensorOp, Attribute memorySpace,
    Operation *insertionPoint) {
  Location loc = allocTensorOp.getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(insertionPoint ? insertionPoint : allocTensorOp);
  bufferization::BufferizationOptions bufferizationOptions;

  // Create buffer allocation.
  Value alloc = createAllocationForTensor(
      rewriter, loc, allocTensorOp.getResult(), options, memorySpace);

  // Create bufferization.to_tensor with "restrict" and "writable". The returned
  // tensor is a new buffer allocation, so it does not alias with any buffer.
  Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
      loc, alloc, /*restrict=*/true, /*writable=*/true);
  rewriter.replaceOp(allocTensorOp, toTensorOp);
  return alloc;
}

/// Lower tensor.from_elements to a sequence of chained tensor.insert.
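///
/// For example (illustrative):
///
///   %0 = tensor.from_elements %a, %b, %c : tensor<3xindex>
///
/// is rewritten to:
///
///   %empty = tensor.empty() : tensor<3xindex>
///   %c0 = arith.constant 0 : index
///   %c1 = arith.constant 1 : index
///   %c2 = arith.constant 2 : index
///   %1 = tensor.insert %a into %empty[%c0] : tensor<3xindex>
///   %2 = tensor.insert %b into %1[%c1] : tensor<3xindex>
///   %3 = tensor.insert %c into %2[%c2] : tensor<3xindex>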
FailureOr<Operation *> mlir::linalg::rewriteInDestinationPassingStyle(
    RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp) {
  Location loc = fromElementsOp.getLoc();
  RankedTensorType tensorType =
      cast<RankedTensorType>(fromElementsOp.getType());
  auto shape = tensorType.getShape();

  // Create tensor.empty.
  auto emptyOp = rewriter.create<EmptyOp>(loc, tensorType, ValueRange());

  // Case: tensor<elem_type>.
  if (shape.empty()) {
    Operation *res = rewriter.replaceOpWithNewOp<tensor::InsertOp>(
        fromElementsOp, fromElementsOp.getElements().front(),
        emptyOp.getResult(), ValueRange());
    return res;
  }

  // Create constants for the range of possible indices [0, max{shape_i}).
  auto maxDim = *llvm::max_element(shape);
  SmallVector<Value, 2> constants;
  constants.reserve(maxDim);
  for (int i = 0; i < maxDim; ++i)
    constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

  // Traverse all elements and create tensor.insert ops.
  auto elementIt = fromElementsOp.getElements().begin();
  SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
  Value result = createInserts(rewriter, loc, /*dim=*/0, emptyOp.getResult(),
                               shape, constants, elementIt, indices);

  // Replace tensor.from_elements.
  rewriter.replaceOp(fromElementsOp, result);
  return result.getDefiningOp();
}

/// Lower tensor.generate to linalg.generic.
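///
/// For example (illustrative):
///
///   %0 = tensor.generate %size {
///   ^bb0(%i: index, %j: index):
///     ...
///     tensor.yield %value : f32
///   } : tensor<?x4xf32>
///
/// becomes a tensor.empty of the same shape and a linalg.generic with an
/// identity indexing map on that output, whose body recreates the indices
/// with linalg.index ops and produces the element via linalg.yield.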
FailureOr<Operation *>
mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
                                               tensor::GenerateOp generateOp) {
  // Only ops with exactly one block are supported.
  if (!generateOp.getBody().hasOneBlock())
    return failure();

  Location loc = generateOp.getLoc();
  RankedTensorType tensorType = cast<RankedTensorType>(generateOp.getType());

  // Create tensor.empty.
  auto emptyOp =
      rewriter.create<EmptyOp>(loc, tensorType, generateOp.getDynamicExtents());

  // Create linalg.generic.
  SmallVector<utils::IteratorType> iteratorTypes(tensorType.getRank(),
                                                 utils::IteratorType::parallel);
  SmallVector<AffineMap> indexingMaps(
      1, rewriter.getMultiDimIdentityMap(tensorType.getRank()));
  auto genericOp = rewriter.create<linalg::GenericOp>(
      loc, tensorType, /*inputs=*/ValueRange(),
      /*outputs=*/ValueRange{emptyOp.getResult()}, /*indexingMaps=*/
      indexingMaps, iteratorTypes);
  Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
                                     tensorType.getElementType(), loc);
  rewriter.setInsertionPointToStart(body);
  SmallVector<Value> bbArgReplacements;
  for (int64_t i = 0; i < tensorType.getRank(); ++i)
    bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
  rewriter.mergeBlocks(&generateOp.getBody().front(), body, bbArgReplacements);

  // Update terminator.
  auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
  rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());

  // Replace tensor.generate.
  rewriter.replaceOp(generateOp, genericOp->getResult(0));
  return genericOp.getOperation();
}

/// Lower tensor.pad to linalg.generic + tensor.insert_slice.
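///
/// For example (illustrative), a pad with a constant padding value is
/// rewritten roughly to:
///
///   %empty = tensor.empty(...) : tensor<...>
///   %filled = linalg.fill ins(%cst : f32) outs(%empty : tensor<...>)
///   %0 = tensor.insert_slice %source into %filled
///            [<low pad>] [<source sizes>] [1, ...]
///
/// If the pad op has the "nofold" attribute and all paddings are zero, a
/// bufferization.alloc_tensor + linalg.copy is emitted instead.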
FailureOr<Operation *>
mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
                                               tensor::PadOp padOp) {
  // Only ops with exactly one block are supported.
  if (!padOp.getBodyRegion().hasOneBlock())
    return failure();

  // Create tensor.empty.
  Location loc = padOp.getLoc();
  RankedTensorType resultType = padOp.getResultType();
  ReifiedRankedShapedTypeDims reifiedShape;
  if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
    return rewriter.notifyMatchFailure(
        padOp, "failed to reify tensor.pad op result shape");
  SmallVector<Value> dynamicSizes;
  for (int64_t i = 0; i < resultType.getRank(); ++i)
    if (resultType.isDynamicDim(i))
      dynamicSizes.push_back(cast<Value>(reifiedShape[0][i]));

  // If the `padOp` has a nofold attribute and all paddings are known to be 0,
  // explicitly insert a `linalg.copy`.
  if (padOp.getNofoldAttr() &&
      llvm::all_of(padOp.getMixedLowPad(), isZeroIndex) &&
      llvm::all_of(padOp.getMixedHighPad(), isZeroIndex)) {
    using bufferization::AllocTensorOp;
    Value allocated =
        rewriter.create<AllocTensorOp>(loc, resultType, dynamicSizes);
    auto copyOp = rewriter.replaceOpWithNewOp<linalg::CopyOp>(
        padOp, padOp.getSource(), allocated);
    return copyOp.getOperation();
  }

  Value empty = rewriter.create<EmptyOp>(loc, resultType, dynamicSizes);
  // Create linalg.fill or linalg.generic.
  Operation *fillOp = movePaddingToFillOrGenericOp(rewriter, loc, padOp, empty);
  rewriter.setInsertionPointAfter(fillOp);

  // Create tensor::InsertSliceOp.
  SmallVector<OpFoldResult> sliceSizes =
      getMixedSizes(rewriter, loc, padOp.getSource());
  SmallVector<OpFoldResult> sliceStrides(resultType.getRank(),
                                         rewriter.getIndexAttr(1));
  auto insertSliceOp = rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.getSource(), fillOp->getResult(0),
      /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);
  return insertSliceOp.getOperation();
}

Value linalg::bufferizeToAllocation(
    RewriterBase &rewriter, const linalg::BufferizeToAllocationOptions &options,
    Operation *op, Attribute memorySpace, Operation *insertionPoint) {
  using namespace bufferization;

  // Call specialized overload for certain ops.
  if (auto padOp = dyn_cast<tensor::PadOp>(op))
    return bufferizeToAllocation(rewriter, options, padOp, memorySpace);
  if (auto maskOp = dyn_cast<vector::MaskOp>(op))
    return bufferizeToAllocation(rewriter, options, maskOp, memorySpace);
  if (auto allocTensorOp = dyn_cast<bufferization::AllocTensorOp>(op))
    return bufferizeToAllocation(rewriter, options, allocTensorOp, memorySpace);

  // Only bufferizable ops are supported.
  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
  if (!bufferizableOp)
    return nullptr;
  BufferizationOptions bufferizationOptions;
  AnalysisState state(bufferizationOptions);

#ifndef NDEBUG
  if (!options.bufferizeDestinationOnly) {
    // Ops with nested tensor ops are not supported yet. At the moment, this
    // function just bufferizes the given op itself, but not its body.
    op->walk([&](Operation *nestedOp) {
      if (op == nestedOp)
        return;
      if (llvm::any_of(nestedOp->getOperands(),
                       [](Value v) { return isa<TensorType>(v.getType()); }))
        llvm_unreachable("ops with nested tensor ops are not supported yet");
      if (llvm::any_of(nestedOp->getResults(),
                       [](Value v) { return isa<TensorType>(v.getType()); }))
        llvm_unreachable("ops with nested tensor ops are not supported yet");
    });
  }
#endif // NDEBUG

  // Gather tensor results.
  SmallVector<OpResult> tensorResults;
  for (OpResult result : op->getResults()) {
    if (!isa<TensorType>(result.getType()))
      continue;
    // Unranked tensors are not supported.
    if (!isa<RankedTensorType>(result.getType()))
      return nullptr;
    // Ops that bufferize to an allocation are not supported.
    if (bufferizableOp.bufferizesToAllocation(result))
      return nullptr;
    tensorResults.push_back(result);
  }

  // Gather all operands that should bufferize to a new allocation, i.e.,
  // bufferize out-of-place.
  SmallVector<OpOperand *> outOfPlaceOperands, resultUses;
  auto addOutOfPlaceOperand = [&](OpOperand *operand) {
    if (!llvm::is_contained(outOfPlaceOperands, operand))
      outOfPlaceOperands.push_back(operand);
  };
  for (OpResult result : tensorResults) {
    AliasingOpOperandList aliasingOperands =
        state.getAliasingOpOperands(result);
    for (const AliasingOpOperand &operand : aliasingOperands) {
      addOutOfPlaceOperand(operand.opOperand);
      for (OpOperand &resultUse : result.getUses())
        resultUses.push_back(&resultUse);
    }
  }
  for (OpOperand &operand : op->getOpOperands()) {
    if (!state.bufferizesToMemoryWrite(operand))
      continue;
    if (!isa<RankedTensorType>(operand.get().getType()))
      continue;
    addOutOfPlaceOperand(&operand);
  }
  // TODO: Support multiple buffers.
  if (outOfPlaceOperands.size() != 1)
    return nullptr;

  // Allocate buffers.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(insertionPoint ? insertionPoint : op);
  SmallVector<Value> allocs;
  for (OpOperand *operand : outOfPlaceOperands) {
    Value alloc = createAllocationForTensor(
        rewriter, op->getLoc(), operand->get(), options, memorySpace);
    allocs.push_back(alloc);
    if (!state.findDefinitions(operand).empty()) {
      // Initialize buffer with a copy of the operand data. Not needed if the
      // tensor is uninitialized.
      createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options);
    }
    rewriter.modifyOpInPlace(op, [&]() {
      auto toTensorOp = rewriter.create<ToTensorOp>(op->getLoc(), alloc);
      operand->set(toTensorOp);
      if (options.bufferizeDestinationOnly) {
        rewriter.modifyOpInPlace(toTensorOp, [&]() {
          toTensorOp.setRestrict(true);
          toTensorOp.setWritable(true);
        });
      }
    });
  }

  if (options.bufferizeDestinationOnly)
    return allocs.front();

  // Bufferize the op.
  rewriter.setInsertionPoint(op);
  if (failed(bufferizableOp.bufferize(rewriter, bufferizationOptions)))
    return nullptr;

  // Set "restrict" attribute, indicating that no other tensor aliases with
  // this tensor. That is because we just allocated a new buffer for the tensor.
  for (OpOperand *resultUse : resultUses) {
    auto toTensorOp = resultUse->get().getDefiningOp<ToTensorOp>();
    assert(toTensorOp && "expected to_tensor op");
    rewriter.modifyOpInPlace(toTensorOp, [&]() {
      toTensorOp.setRestrict(true);
      toTensorOp.setWritable(true);
    });
  }
  return allocs.front();
}

namespace {

template <typename OpTy>
LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,
                                                 PatternRewriter &rewriter) {
  return linalg::rewriteInDestinationPassingStyle(rewriter, op);
}

} // namespace

void linalg::populateConvertToDestinationStylePatterns(
    RewritePatternSet &patterns) {
  patterns.add(rewriteOpInDestinationPassingStyle<tensor::FromElementsOp>);
  patterns.add(rewriteOpInDestinationPassingStyle<tensor::GenerateOp>);
  patterns.add(rewriteOpInDestinationPassingStyle<tensor::PadOp>);
}