1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h" 10 11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 13 #include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h" 14 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 15 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 16 #include "mlir/Dialect/MemRef/IR/MemRef.h" 17 #include "mlir/Dialect/SCF/IR/SCF.h" 18 #include "mlir/Dialect/Tensor/IR/Tensor.h" 19 #include "mlir/Dialect/Utils/StaticValueUtils.h" 20 #include "mlir/IR/Dialect.h" 21 #include "mlir/IR/Operation.h" 22 #include "mlir/IR/PatternMatch.h" 23 24 using namespace mlir; 25 using namespace mlir::bufferization; 26 using namespace mlir::scf; 27 28 namespace mlir { 29 namespace scf { 30 namespace { 31 32 /// Helper function for loop bufferization. Cast the given buffer to the given 33 /// memref type. 34 static Value castBuffer(OpBuilder &b, Value buffer, Type type) { 35 assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType"); 36 assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType"); 37 // If the buffer already has the correct type, no cast is needed. 38 if (buffer.getType() == type) 39 return buffer; 40 // TODO: In case `type` has a layout map that is not the fully dynamic 41 // one, we may not be able to cast the buffer. In that case, the loop 42 // iter_arg's layout map must be changed (see uses of `castBuffer`). 43 assert(memref::CastOp::areCastCompatible(buffer.getType(), type) && 44 "scf.while op bufferization: cast incompatible"); 45 return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult(); 46 } 47 48 /// Helper function for loop bufferization. Return "true" if the given value 49 /// is guaranteed to not alias with an external tensor apart from values in 50 /// `exceptions`. A value is external if it is defined outside of the given 51 /// region or if it is an entry block argument of the region. 52 static bool doesNotAliasExternalValue(Value value, Region *region, 53 ValueRange exceptions, 54 const OneShotAnalysisState &state) { 55 assert(region->getBlocks().size() == 1 && 56 "expected region with single block"); 57 bool result = true; 58 state.applyOnAliases(value, [&](Value alias) { 59 if (llvm::is_contained(exceptions, alias)) 60 return; 61 Region *aliasRegion = alias.getParentRegion(); 62 if (isa<BlockArgument>(alias) && !region->isProperAncestor(aliasRegion)) 63 result = false; 64 if (isa<OpResult>(alias) && !region->isAncestor(aliasRegion)) 65 result = false; 66 }); 67 return result; 68 } 69 70 /// Bufferization of scf.condition. 
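/// A hedged sketch of the rewrite (the types below are hypothetical, for
/// illustration only): the tensor operands of scf.condition are replaced by
/// their buffers, cast to the buffer type of the matching "after" block
/// argument of the enclosing scf.while. For example,
///
///   scf.condition(%cond) %t : tensor<5xf32>
///
/// becomes roughly
///
///   scf.condition(%cond) %buf : memref<5xf32, strided<[?], offset: ?>>
///
/// where %buf is the bufferized form of %t.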
71 struct ConditionOpInterface 72 : public BufferizableOpInterface::ExternalModel<ConditionOpInterface, 73 scf::ConditionOp> { 74 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 75 const AnalysisState &state) const { 76 return true; 77 } 78 79 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 80 const AnalysisState &state) const { 81 return false; 82 } 83 84 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, 85 const AnalysisState &state) const { 86 return {}; 87 } 88 89 bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand, 90 const AnalysisState &state) const { 91 // Condition operands always bufferize inplace. Otherwise, an alloc + copy 92 // may be generated inside the block. We should not return/yield allocations 93 // when possible. 94 return true; 95 } 96 97 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 98 const BufferizationOptions &options) const { 99 auto conditionOp = cast<scf::ConditionOp>(op); 100 auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp()); 101 102 SmallVector<Value> newArgs; 103 for (const auto &it : llvm::enumerate(conditionOp.getArgs())) { 104 Value value = it.value(); 105 if (isa<TensorType>(value.getType())) { 106 FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options); 107 if (failed(maybeBuffer)) 108 return failure(); 109 FailureOr<BaseMemRefType> resultType = bufferization::getBufferType( 110 whileOp.getAfterArguments()[it.index()], options); 111 if (failed(resultType)) 112 return failure(); 113 Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType); 114 newArgs.push_back(buffer); 115 } else { 116 newArgs.push_back(value); 117 } 118 } 119 120 replaceOpWithNewBufferizedOp<scf::ConditionOp>( 121 rewriter, op, conditionOp.getCondition(), newArgs); 122 return success(); 123 } 124 }; 125 126 /// Return the unique scf.yield op. If there are multiple or no scf.yield ops, 127 /// return an empty op. 128 static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) { 129 scf::YieldOp result; 130 for (Block &block : executeRegionOp.getRegion()) { 131 if (auto yieldOp = dyn_cast<scf::YieldOp>(block.getTerminator())) { 132 if (result) 133 return {}; 134 result = yieldOp; 135 } 136 } 137 return result; 138 } 139 140 /// Bufferization of scf.execute_region. Can be analyzed, but bufferization not 141 /// fully implemented at the moment. 142 struct ExecuteRegionOpInterface 143 : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel< 144 ExecuteRegionOpInterface, scf::ExecuteRegionOp> { 145 146 static bool supportsUnstructuredControlFlow() { return true; } 147 148 bool isWritable(Operation *op, Value value, 149 const AnalysisState &state) const { 150 return true; 151 } 152 153 LogicalResult verifyAnalysis(Operation *op, 154 const AnalysisState &state) const { 155 auto executeRegionOp = cast<scf::ExecuteRegionOp>(op); 156 // TODO: scf.execute_region with multiple yields are not supported. 157 if (!getUniqueYieldOp(executeRegionOp)) 158 return op->emitOpError("op without unique scf.yield is not supported"); 159 return success(); 160 } 161 162 AliasingOpOperandList 163 getAliasingOpOperands(Operation *op, Value value, 164 const AnalysisState &state) const { 165 if (auto bbArg = dyn_cast<BlockArgument>(value)) 166 return getAliasingBranchOpOperands(op, bbArg, state); 167 168 // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be 169 // any SSA value that is in scope. 
To allow for use-def chain traversal 170 // through ExecuteRegionOps in the analysis, the corresponding yield value 171 // is considered to be aliasing with the result. 172 auto executeRegionOp = cast<scf::ExecuteRegionOp>(op); 173 auto it = llvm::find(op->getOpResults(), value); 174 assert(it != op->getOpResults().end() && "invalid value"); 175 size_t resultNum = std::distance(op->getOpResults().begin(), it); 176 auto yieldOp = getUniqueYieldOp(executeRegionOp); 177 // Note: If there is no unique scf.yield op, `verifyAnalysis` will fail. 178 if (!yieldOp) 179 return {}; 180 return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}}; 181 } 182 183 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 184 const BufferizationOptions &options) const { 185 auto executeRegionOp = cast<scf::ExecuteRegionOp>(op); 186 auto yieldOp = getUniqueYieldOp(executeRegionOp); 187 TypeRange newResultTypes(yieldOp.getResults()); 188 189 // Create new op and move over region. 190 auto newOp = 191 rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes); 192 newOp.getRegion().takeBody(executeRegionOp.getRegion()); 193 194 // Bufferize every block. 195 for (Block &block : newOp.getRegion()) 196 if (failed(bufferization::bufferizeBlockSignature(&block, rewriter, 197 options))) 198 return failure(); 199 200 // Update all uses of the old op. 201 rewriter.setInsertionPointAfter(newOp); 202 SmallVector<Value> newResults; 203 for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) { 204 if (isa<TensorType>(it.value())) { 205 newResults.push_back(rewriter.create<bufferization::ToTensorOp>( 206 executeRegionOp.getLoc(), it.value(), 207 newOp->getResult(it.index()))); 208 } else { 209 newResults.push_back(newOp->getResult(it.index())); 210 } 211 } 212 213 // Replace old op. 214 rewriter.replaceOp(executeRegionOp, newResults); 215 216 return success(); 217 } 218 }; 219 220 /// Bufferization of scf.if. Replace with a new scf.if that yields memrefs. 221 struct IfOpInterface 222 : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> { 223 AliasingOpOperandList 224 getAliasingOpOperands(Operation *op, Value value, 225 const AnalysisState &state) const { 226 // IfOps do not have tensor OpOperands. The yielded value can be any SSA 227 // value that is in scope. To allow for use-def chain traversal through 228 // IfOps in the analysis, both corresponding yield values from the then/else 229 // branches are considered to be aliasing with the result. 230 auto ifOp = cast<scf::IfOp>(op); 231 size_t resultNum = std::distance(op->getOpResults().begin(), 232 llvm::find(op->getOpResults(), value)); 233 OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(resultNum); 234 OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(resultNum); 235 return {{thenOperand, BufferRelation::Equivalent, /*isDefinite=*/false}, 236 {elseOperand, BufferRelation::Equivalent, /*isDefinite=*/false}}; 237 } 238 239 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 240 const BufferizationOptions &options) const { 241 OpBuilder::InsertionGuard g(rewriter); 242 auto ifOp = cast<scf::IfOp>(op); 243 244 // Compute bufferized result types. 
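// Tensor results are converted to their memref counterparts via
// bufferization::getBufferType; non-tensor result types are kept unchanged.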
245 SmallVector<Type> newTypes; 246 for (Value result : ifOp.getResults()) { 247 if (!isa<TensorType>(result.getType())) { 248 newTypes.push_back(result.getType()); 249 continue; 250 } 251 auto bufferType = bufferization::getBufferType(result, options); 252 if (failed(bufferType)) 253 return failure(); 254 newTypes.push_back(*bufferType); 255 } 256 257 // Create new op. 258 rewriter.setInsertionPoint(ifOp); 259 auto newIfOp = 260 rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(), 261 /*withElseRegion=*/true); 262 263 // Move over then/else blocks. 264 rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock()); 265 rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock()); 266 267 // Replace op results. 268 replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults()); 269 270 return success(); 271 } 272 273 FailureOr<BaseMemRefType> 274 getBufferType(Operation *op, Value value, const BufferizationOptions &options, 275 SmallVector<Value> &invocationStack) const { 276 auto ifOp = cast<scf::IfOp>(op); 277 auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator()); 278 auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator()); 279 assert(value.getDefiningOp() == op && "invalid value"); 280 281 // Determine buffer types of the true/false branches. 282 auto opResult = cast<OpResult>(value); 283 auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber()); 284 auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber()); 285 BaseMemRefType thenBufferType, elseBufferType; 286 if (isa<BaseMemRefType>(thenValue.getType())) { 287 // True branch was already bufferized. 288 thenBufferType = cast<BaseMemRefType>(thenValue.getType()); 289 } else { 290 auto maybeBufferType = 291 bufferization::getBufferType(thenValue, options, invocationStack); 292 if (failed(maybeBufferType)) 293 return failure(); 294 thenBufferType = *maybeBufferType; 295 } 296 if (isa<BaseMemRefType>(elseValue.getType())) { 297 // False branch was already bufferized. 298 elseBufferType = cast<BaseMemRefType>(elseValue.getType()); 299 } else { 300 auto maybeBufferType = 301 bufferization::getBufferType(elseValue, options, invocationStack); 302 if (failed(maybeBufferType)) 303 return failure(); 304 elseBufferType = *maybeBufferType; 305 } 306 307 // Best case: Both branches have the exact same buffer type. 308 if (thenBufferType == elseBufferType) 309 return thenBufferType; 310 311 // Memory space mismatch. 312 if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace()) 313 return op->emitError("inconsistent memory space on then/else branches"); 314 315 // Layout maps are different: Promote to fully dynamic layout map. 316 return getMemRefTypeWithFullyDynamicLayout( 317 cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace()); 318 } 319 }; 320 321 /// Bufferization of scf.index_switch. Replace with a new scf.index_switch that 322 /// yields memrefs. 323 struct IndexSwitchOpInterface 324 : public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface, 325 scf::IndexSwitchOp> { 326 AliasingOpOperandList 327 getAliasingOpOperands(Operation *op, Value value, 328 const AnalysisState &state) const { 329 // IndexSwitchOps do not have tensor OpOperands. The yielded value can be 330 // any SSA value that is in scope. This is similar to IfOps.
331 auto switchOp = cast<scf::IndexSwitchOp>(op); 332 int64_t resultNum = cast<OpResult>(value).getResultNumber(); 333 AliasingOpOperandList result; 334 for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) { 335 auto yieldOp = 336 cast<scf::YieldOp>(switchOp.getCaseBlock(i).getTerminator()); 337 result.addAlias(AliasingOpOperand(&yieldOp->getOpOperand(resultNum), 338 BufferRelation::Equivalent, 339 /*isDefinite=*/false)); 340 } 341 auto defaultYieldOp = 342 cast<scf::YieldOp>(switchOp.getDefaultBlock().getTerminator()); 343 result.addAlias(AliasingOpOperand(&defaultYieldOp->getOpOperand(resultNum), 344 BufferRelation::Equivalent, 345 /*isDefinite=*/false)); 346 return result; 347 } 348 349 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 350 const BufferizationOptions &options) const { 351 OpBuilder::InsertionGuard g(rewriter); 352 auto switchOp = cast<scf::IndexSwitchOp>(op); 353 354 // Compute bufferized result types. 355 SmallVector<Type> newTypes; 356 for (Value result : switchOp.getResults()) { 357 if (!isa<TensorType>(result.getType())) { 358 newTypes.push_back(result.getType()); 359 continue; 360 } 361 auto bufferType = bufferization::getBufferType(result, options); 362 if (failed(bufferType)) 363 return failure(); 364 newTypes.push_back(*bufferType); 365 } 366 367 // Create new op. 368 rewriter.setInsertionPoint(switchOp); 369 auto newSwitchOp = rewriter.create<scf::IndexSwitchOp>( 370 switchOp.getLoc(), newTypes, switchOp.getArg(), switchOp.getCases(), 371 switchOp.getCases().size()); 372 373 // Move over blocks. 374 for (auto [src, dest] : 375 llvm::zip(switchOp.getCaseRegions(), newSwitchOp.getCaseRegions())) 376 rewriter.inlineRegionBefore(src, dest, dest.begin()); 377 rewriter.inlineRegionBefore(switchOp.getDefaultRegion(), 378 newSwitchOp.getDefaultRegion(), 379 newSwitchOp.getDefaultRegion().begin()); 380 381 // Replace op results. 382 replaceOpWithBufferizedValues(rewriter, op, newSwitchOp->getResults()); 383 384 return success(); 385 } 386 387 FailureOr<BaseMemRefType> 388 getBufferType(Operation *op, Value value, const BufferizationOptions &options, 389 SmallVector<Value> &invocationStack) const { 390 auto switchOp = cast<scf::IndexSwitchOp>(op); 391 assert(value.getDefiningOp() == op && "invalid value"); 392 int64_t resultNum = cast<OpResult>(value).getResultNumber(); 393 394 // Helper function to get buffer type of a case. 395 SmallVector<BaseMemRefType> yieldedTypes; 396 auto getYieldedBufferType = [&](Block &b) -> FailureOr<BaseMemRefType> { 397 auto yieldOp = cast<scf::YieldOp>(b.getTerminator()); 398 Value yieldedValue = yieldOp->getOperand(resultNum); 399 if (auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.getType())) 400 return bufferType; 401 auto maybeBufferType = 402 bufferization::getBufferType(yieldedValue, options, invocationStack); 403 if (failed(maybeBufferType)) 404 return failure(); 405 return maybeBufferType; 406 }; 407 408 // Compute buffer type of the default case. 409 auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock()); 410 if (failed(maybeBufferType)) 411 return failure(); 412 BaseMemRefType bufferType = *maybeBufferType; 413 414 // Compute buffer types of all other cases. 415 for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) { 416 auto yieldedBufferType = getYieldedBufferType(switchOp.getCaseBlock(i)); 417 if (failed(yieldedBufferType)) 418 return failure(); 419 420 // Best case: Both branches have the exact same buffer type. 
421 if (bufferType == *yieldedBufferType) 422 continue; 423 424 // Memory space mismatch. 425 if (bufferType.getMemorySpace() != yieldedBufferType->getMemorySpace()) 426 return op->emitError("inconsistent memory space on switch cases"); 427 428 // Layout maps are different: Promote to fully dynamic layout map. 429 bufferType = getMemRefTypeWithFullyDynamicLayout( 430 cast<TensorType>(value.getType()), bufferType.getMemorySpace()); 431 } 432 433 return bufferType; 434 } 435 }; 436 437 /// Helper function for loop bufferization. Return the indices of all values 438 /// that have a tensor type. 439 static DenseSet<int64_t> getTensorIndices(ValueRange values) { 440 DenseSet<int64_t> result; 441 for (const auto &it : llvm::enumerate(values)) 442 if (isa<TensorType>(it.value().getType())) 443 result.insert(it.index()); 444 return result; 445 } 446 447 /// Helper function for loop bufferization. Return the indices of all 448 /// bbArg/yielded value pairs whose buffer relation is "Equivalent". 449 DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs, 450 ValueRange yieldedValues, 451 const AnalysisState &state) { 452 unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size()); 453 DenseSet<int64_t> result; 454 for (unsigned int i = 0; i < minSize; ++i) { 455 if (!isa<TensorType>(bbArgs[i].getType()) || 456 !isa<TensorType>(yieldedValues[i].getType())) 457 continue; 458 if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i])) 459 result.insert(i); 460 } 461 return result; 462 } 463 464 /// Helper function for loop bufferization. Return the bufferized values of the 465 /// given OpOperands. If an operand is not a tensor, return the original value. 466 static FailureOr<SmallVector<Value>> 467 getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands, 468 const BufferizationOptions &options) { 469 SmallVector<Value> result; 470 for (OpOperand &opOperand : operands) { 471 if (isa<TensorType>(opOperand.get().getType())) { 472 FailureOr<Value> resultBuffer = 473 getBuffer(rewriter, opOperand.get(), options); 474 if (failed(resultBuffer)) 475 return failure(); 476 result.push_back(*resultBuffer); 477 } else { 478 result.push_back(opOperand.get()); 479 } 480 } 481 return result; 482 } 483 484 /// Helper function for loop bufferization. Given a list of bbArgs of the new 485 /// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into 486 /// ToTensorOps, so that the block body can be moved over to the new op. 487 static SmallVector<Value> 488 getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs, 489 Block::BlockArgListType oldBbArgs, 490 const DenseSet<int64_t> &tensorIndices) { 491 SmallVector<Value> result; 492 for (const auto &it : llvm::enumerate(bbArgs)) { 493 size_t idx = it.index(); 494 Value val = it.value(); 495 if (tensorIndices.contains(idx)) { 496 result.push_back(rewriter 497 .create<bufferization::ToTensorOp>( 498 val.getLoc(), oldBbArgs[idx].getType(), val) 499 .getResult()); 500 } else { 501 result.push_back(val); 502 } 503 } 504 return result; 505 } 506 507 /// Compute the bufferized type of a loop iter_arg. This type must be equal to 508 /// the bufferized type of the corresponding init_arg and the bufferized type 509 /// of the corresponding yielded value. 510 /// 511 /// This function uses bufferization::getBufferType to compute the bufferized 512 /// type of the init_arg and of the yielded value.
(The computation of the 513 /// bufferized yielded value type usually requires computing the bufferized type 514 /// of the iter_arg again; the implementation of getBufferType traces back the 515 /// use-def chain of the given value and computes a buffer type along the way.) 516 /// If both buffer types are equal, no casts are needed and the computed buffer type 517 /// can be used directly. Otherwise, the buffer types can only differ in their 518 /// layout map and a cast must be inserted. 519 static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType( 520 Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue, 521 const BufferizationOptions &options, SmallVector<Value> &invocationStack) { 522 // Determine the buffer type of the init_arg. 523 auto initArgBufferType = 524 bufferization::getBufferType(initArg, options, invocationStack); 525 if (failed(initArgBufferType)) 526 return failure(); 527 528 if (llvm::count(invocationStack, iterArg) >= 2) { 529 // If the iter_arg is already twice on the invocation stack, just take the 530 // type of the init_arg. This is to avoid infinite loops when calculating 531 // the buffer type. This will most likely result in computing a memref type 532 // with a fully dynamic layout map. 533 534 // Note: For more precise layout map computation, a fixpoint iteration could 535 // be done (i.e., re-computing the yielded buffer type until the bufferized 536 // iter_arg type no longer changes). The current implementation immediately 537 // switches to a fully dynamic layout map when a mismatch between bufferized 538 // init_arg type and bufferized yield value type is detected. 539 return *initArgBufferType; 540 } 541 542 // Compute the buffer type of the yielded value. 543 BaseMemRefType yieldedValueBufferType; 544 if (isa<BaseMemRefType>(yieldedValue.getType())) { 545 // scf.yield was already bufferized. 546 yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType()); 547 } else { 548 // Note: This typically triggers a recursive call for the buffer type of 549 // the iter_arg. 550 auto maybeBufferType = 551 bufferization::getBufferType(yieldedValue, options, invocationStack); 552 if (failed(maybeBufferType)) 553 return failure(); 554 yieldedValueBufferType = *maybeBufferType; 555 } 556 557 // If yielded type and init_arg type are the same, use that type directly. 558 if (*initArgBufferType == yieldedValueBufferType) 559 return yieldedValueBufferType; 560 561 // If there is a mismatch between the yielded buffer type and the init_arg 562 // buffer type, the buffer type must be promoted to a fully dynamic layout 563 // map. 564 auto yieldedBufferType = cast<BaseMemRefType>(yieldedValueBufferType); 565 auto iterTensorType = cast<TensorType>(iterArg.getType()); 566 auto initBufferType = llvm::cast<BaseMemRefType>(*initArgBufferType); 567 if (initBufferType.getMemorySpace() != yieldedBufferType.getMemorySpace()) 568 return loopOp->emitOpError( 569 "init_arg and yielded value bufferize to inconsistent memory spaces"); 570 #ifndef NDEBUG 571 if (auto yieldedRankedBufferType = dyn_cast<MemRefType>(yieldedBufferType)) { 572 assert( 573 llvm::all_equal({yieldedRankedBufferType.getShape(), 574 cast<MemRefType>(initBufferType).getShape(), 575 cast<RankedTensorType>(iterTensorType).getShape()}) && 576 "expected same shape"); 577 } 578 #endif // NDEBUG 579 return getMemRefTypeWithFullyDynamicLayout( 580 iterTensorType, yieldedBufferType.getMemorySpace()); 581 } 582 583 /// Return `true` if the given loop may have 0 iterations.
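/// For example, `scf.for %i = %c0 to %c0 step %c1` runs zero times. If either
/// bound is not a compile-time constant, conservatively return "true".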
584 bool mayHaveZeroIterations(scf::ForOp forOp) { 585 std::optional<int64_t> lb = getConstantIntValue(forOp.getLowerBound()); 586 std::optional<int64_t> ub = getConstantIntValue(forOp.getUpperBound()); 587 if (!lb.has_value() || !ub.has_value()) 588 return true; 589 return *ub <= *lb; 590 } 591 592 /// Bufferization of scf.for. Replace with a new scf.for that operates on 593 /// memrefs. 594 struct ForOpInterface 595 : public BufferizableOpInterface::ExternalModel<ForOpInterface, 596 scf::ForOp> { 597 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 598 const AnalysisState &state) const { 599 auto forOp = cast<scf::ForOp>(op); 600 601 // If the loop has zero iterations, the results of the op are their 602 // corresponding init_args, meaning that the init_args bufferize to a read. 603 if (mayHaveZeroIterations(forOp)) 604 return true; 605 606 // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of 607 // its matching bbArg may. 608 return state.isValueRead(forOp.getTiedLoopRegionIterArg(&opOperand)); 609 } 610 611 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 612 const AnalysisState &state) const { 613 // Tensor iter_args of scf::ForOps are always considered as a write. 614 return true; 615 } 616 617 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, 618 const AnalysisState &state) const { 619 auto forOp = cast<scf::ForOp>(op); 620 OpResult opResult = forOp.getTiedLoopResult(&opOperand); 621 BufferRelation relation = bufferRelation(op, opResult, state); 622 return {{opResult, relation, 623 /*isDefinite=*/relation == BufferRelation::Equivalent}}; 624 } 625 626 BufferRelation bufferRelation(Operation *op, OpResult opResult, 627 const AnalysisState &state) const { 628 // ForOp results are equivalent to their corresponding init_args if the 629 // corresponding iter_args and yield values are equivalent. 630 auto forOp = cast<scf::ForOp>(op); 631 BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult); 632 bool equivalentYield = state.areEquivalentBufferizedValues( 633 bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get()); 634 return equivalentYield ? BufferRelation::Equivalent 635 : BufferRelation::Unknown; 636 } 637 638 bool isWritable(Operation *op, Value value, 639 const AnalysisState &state) const { 640 // Interestingly, scf::ForOp's bbArg can **always** be viewed 641 // inplace from the perspective of ops nested under: 642 // 1. Either the matching iter operand is not bufferized inplace and an 643 // alloc + optional copy makes the bbArg itself inplaceable. 644 // 2. Or the matching iter operand is bufferized inplace and bbArg just 645 // bufferizes to that too. 646 return true; 647 } 648 649 LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter, 650 const AnalysisState &state) const { 651 auto bufferizableOp = cast<BufferizableOpInterface>(op); 652 if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state))) 653 return failure(); 654 655 if (!state.getOptions().enforceAliasingInvariants || 656 state.getOptions().copyBeforeWrite) 657 return success(); 658 659 // According to the `getAliasing...` implementations, a bufferized OpResult 660 // may alias only with the corresponding bufferized init_arg (or with a 661 // newly allocated buffer) and not with other buffers defined outside of the 662 // loop. I.e., the i-th OpResult may alias with the i-th init_arg; 663 // but not with any other OpOperand. 
664 auto forOp = cast<scf::ForOp>(op); 665 auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator()); 666 OpBuilder::InsertionGuard g(rewriter); 667 rewriter.setInsertionPoint(yieldOp); 668 669 // Indices of all iter_args that have tensor type. These are the ones that 670 // are bufferized. 671 DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs()); 672 // For every yielded value, does it alias with something defined outside of 673 // the loop? 674 SmallVector<Value> yieldValues; 675 for (const auto it : llvm::enumerate(yieldOp.getResults())) { 676 // Note: `state` is guaranteed to be a `OneShotAnalysisState`, but this 677 // type cannot be used in the signature of `resolveConflicts` because the 678 // op interface is in the "IR" build unit and the `OneShotAnalysisState` 679 // is defined in the "Transforms" build unit. 680 if (!indices.contains(it.index()) || 681 doesNotAliasExternalValue( 682 it.value(), &forOp.getRegion(), 683 /*exceptions=*/forOp.getRegionIterArg(it.index()), 684 static_cast<const OneShotAnalysisState &>(state))) { 685 yieldValues.push_back(it.value()); 686 continue; 687 } 688 FailureOr<Value> alloc = allocateTensorForShapedValue( 689 rewriter, yieldOp.getLoc(), it.value(), state.getOptions()); 690 if (failed(alloc)) 691 return failure(); 692 yieldValues.push_back(*alloc); 693 } 694 695 rewriter.modifyOpInPlace( 696 yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); }); 697 return success(); 698 } 699 700 FailureOr<BaseMemRefType> 701 getBufferType(Operation *op, Value value, const BufferizationOptions &options, 702 SmallVector<Value> &invocationStack) const { 703 auto forOp = cast<scf::ForOp>(op); 704 assert(getOwnerOfValue(value) == op && "invalid value"); 705 assert(isa<TensorType>(value.getType()) && "expected tensor type"); 706 707 if (auto opResult = dyn_cast<OpResult>(value)) { 708 // The type of an OpResult must match the corresponding iter_arg type. 709 BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult); 710 return bufferization::getBufferType(bbArg, options, invocationStack); 711 } 712 713 // Compute result/argument number. 714 BlockArgument bbArg = cast<BlockArgument>(value); 715 unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber(); 716 717 // Compute the bufferized type. 718 auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator()); 719 Value yieldedValue = yieldOp.getOperand(resultNum); 720 BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum]; 721 Value initArg = forOp.getInitArgs()[resultNum]; 722 return computeLoopRegionIterArgBufferType( 723 op, iterArg, initArg, yieldedValue, options, invocationStack); 724 } 725 726 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 727 const BufferizationOptions &options) const { 728 auto forOp = cast<scf::ForOp>(op); 729 Block *oldLoopBody = forOp.getBody(); 730 731 // Indices of all iter_args that have tensor type. These are the ones that 732 // are bufferized. 733 DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs()); 734 735 // The new memref init_args of the loop. 736 FailureOr<SmallVector<Value>> maybeInitArgs = 737 getBuffers(rewriter, forOp.getInitArgsMutable(), options); 738 if (failed(maybeInitArgs)) 739 return failure(); 740 SmallVector<Value> initArgs = *maybeInitArgs; 741 742 // Cast init_args if necessary. 
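// The buffer type of an init_arg may differ from the buffer type computed for
// the corresponding loop result/iter_arg (e.g., in its layout map). In that
// case, castBuffer inserts a memref.cast to the result's buffer type.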
743 SmallVector<Value> castedInitArgs; 744 for (const auto &it : llvm::enumerate(initArgs)) { 745 Value initArg = it.value(); 746 Value result = forOp->getResult(it.index()); 747 // If the type is not a tensor, bufferization doesn't need to touch it. 748 if (!isa<TensorType>(result.getType())) { 749 castedInitArgs.push_back(initArg); 750 continue; 751 } 752 auto targetType = bufferization::getBufferType(result, options); 753 if (failed(targetType)) 754 return failure(); 755 castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType)); 756 } 757 758 // Construct a new scf.for op with memref instead of tensor values. 759 auto newForOp = rewriter.create<scf::ForOp>( 760 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), 761 forOp.getStep(), castedInitArgs); 762 newForOp->setAttrs(forOp->getAttrs()); 763 Block *loopBody = newForOp.getBody(); 764 765 // Set up new iter_args. The loop body uses tensors, so wrap the (memref) 766 // iter_args of the new loop in ToTensorOps. 767 rewriter.setInsertionPointToStart(loopBody); 768 SmallVector<Value> iterArgs = 769 getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), 770 forOp.getRegionIterArgs(), indices); 771 iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar()); 772 773 // Move loop body to new loop. 774 rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs); 775 776 // Replace loop results. 777 replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults()); 778 779 return success(); 780 } 781 782 /// Assert that yielded values of an scf.for op are equivalent to their 783 /// corresponding bbArgs. In that case, the buffer relations of the 784 /// corresponding OpResults are "Equivalent". 785 /// 786 /// If this is not the case, allocs+copies are inserted and yielded from 787 /// the loop. This could be a performance problem, so it must be explicitly 788 /// activated with `allow-return-allocs-from-loops`. 789 LogicalResult verifyAnalysis(Operation *op, 790 const AnalysisState &state) const { 791 const auto &options = 792 static_cast<const OneShotBufferizationOptions &>(state.getOptions()); 793 if (options.allowReturnAllocsFromLoops) 794 return success(); 795 796 auto forOp = cast<scf::ForOp>(op); 797 auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator()); 798 for (OpResult opResult : op->getOpResults()) { 799 if (!isa<TensorType>(opResult.getType())) 800 continue; 801 802 // Note: This is overly strict. We should check for aliasing bufferized 803 // values. But we don't have a "must-alias" analysis yet. 804 if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent) 805 return yieldOp->emitError() 806 << "Yield operand #" << opResult.getResultNumber() 807 << " is not equivalent to the corresponding iter bbArg"; 808 } 809 810 return success(); 811 } 812 }; 813 814 /// Bufferization of scf.while. Replace with a new scf.while that operates on 815 /// memrefs. 816 struct WhileOpInterface 817 : public BufferizableOpInterface::ExternalModel<WhileOpInterface, 818 scf::WhileOp> { 819 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 820 const AnalysisState &state) const { 821 // Tensor iter_args of scf::WhileOps are always considered as a read. 822 return true; 823 } 824 825 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 826 const AnalysisState &state) const { 827 // Tensor iter_args of scf::WhileOps are always considered as a write.
828 return true; 829 } 830 831 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, 832 const AnalysisState &state) const { 833 auto whileOp = cast<scf::WhileOp>(op); 834 unsigned int idx = opOperand.getOperandNumber(); 835 836 // The OpResults and OpOperands may not match. They may not even have the 837 // same type. The number of OpResults and OpOperands can also differ. 838 if (idx >= op->getNumResults() || 839 opOperand.get().getType() != op->getResult(idx).getType()) 840 return {}; 841 842 // The only aliasing OpResult may be the one at the same index. 843 OpResult opResult = whileOp->getResult(idx); 844 BufferRelation relation = bufferRelation(op, opResult, state); 845 return {{opResult, relation, 846 /*isDefinite=*/relation == BufferRelation::Equivalent}}; 847 } 848 849 BufferRelation bufferRelation(Operation *op, OpResult opResult, 850 const AnalysisState &state) const { 851 // WhileOp results are equivalent to their corresponding init_args if the 852 // corresponding iter_args and yield values are equivalent (for both the 853 // "before" and the "after" block). 854 unsigned int resultNumber = opResult.getResultNumber(); 855 auto whileOp = cast<scf::WhileOp>(op); 856 857 // The "before" region bbArgs and the OpResults may not match. 858 if (resultNumber >= whileOp.getBeforeArguments().size()) 859 return BufferRelation::Unknown; 860 if (opResult.getType() != 861 whileOp.getBeforeArguments()[resultNumber].getType()) 862 return BufferRelation::Unknown; 863 864 auto conditionOp = whileOp.getConditionOp(); 865 BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber]; 866 Value conditionOperand = conditionOp.getArgs()[resultNumber]; 867 bool equivCondition = 868 state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand); 869 870 auto yieldOp = whileOp.getYieldOp(); 871 BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber]; 872 Value yieldOperand = yieldOp.getOperand(resultNumber); 873 bool equivYield = 874 state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand); 875 876 return equivCondition && equivYield ? BufferRelation::Equivalent 877 : BufferRelation::Unknown; 878 } 879 880 bool isWritable(Operation *op, Value value, 881 const AnalysisState &state) const { 882 // Interestingly, scf::WhileOp's bbArg can **always** be viewed 883 // inplace from the perspective of ops nested under: 884 // 1. Either the matching iter operand is not bufferized inplace and an 885 // alloc + optional copy makes the bbArg itself inplaceable. 886 // 2. Or the matching iter operand is bufferized inplace and bbArg just 887 // bufferizes to that too. 888 return true; 889 } 890 891 LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter, 892 const AnalysisState &state) const { 893 auto bufferizableOp = cast<BufferizableOpInterface>(op); 894 if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state))) 895 return failure(); 896 897 if (!state.getOptions().enforceAliasingInvariants || 898 state.getOptions().copyBeforeWrite) 899 return success(); 900 901 // According to the `getAliasing...` implementations, a bufferized OpResult 902 // may alias only with the corresponding bufferized init_arg and with no 903 // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg; 904 // but not with any other OpOperand. If a corresponding OpResult/init_arg 905 // pair bufferizes to equivalent buffers, this aliasing requirement is 906 // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy. 
907 // (New buffer copies do not alias with any buffer.) 908 OpBuilder::InsertionGuard g(rewriter); 909 auto whileOp = cast<scf::WhileOp>(op); 910 auto conditionOp = whileOp.getConditionOp(); 911 912 // For every yielded value, is the value equivalent to its corresponding 913 // bbArg? 914 DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers( 915 whileOp.getBeforeArguments(), conditionOp.getArgs(), state); 916 DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers( 917 whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state); 918 919 // Update "before" region. 920 rewriter.setInsertionPoint(conditionOp); 921 SmallVector<Value> beforeYieldValues; 922 for (int64_t idx = 0; 923 idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) { 924 Value value = conditionOp.getArgs()[idx]; 925 if (!isa<TensorType>(value.getType()) || 926 (equivalentYieldsAfter.contains(idx) && 927 equivalentYieldsBefore.contains(idx))) { 928 beforeYieldValues.push_back(value); 929 continue; 930 } 931 FailureOr<Value> alloc = allocateTensorForShapedValue( 932 rewriter, conditionOp.getLoc(), value, state.getOptions()); 933 if (failed(alloc)) 934 return failure(); 935 beforeYieldValues.push_back(*alloc); 936 } 937 rewriter.modifyOpInPlace(conditionOp, [&]() { 938 conditionOp.getArgsMutable().assign(beforeYieldValues); 939 }); 940 941 return success(); 942 } 943 944 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 945 const BufferizationOptions &options) const { 946 auto whileOp = cast<scf::WhileOp>(op); 947 948 // Indices of all bbArgs that have tensor type. These are the ones that 949 // are bufferized. The "before" and "after" regions may have different args. 950 DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits()); 951 DenseSet<int64_t> indicesAfter = 952 getTensorIndices(whileOp.getAfterArguments()); 953 954 // The new memref init_args of the loop. 955 FailureOr<SmallVector<Value>> maybeInitArgs = 956 getBuffers(rewriter, whileOp.getInitsMutable(), options); 957 if (failed(maybeInitArgs)) 958 return failure(); 959 SmallVector<Value> initArgs = *maybeInitArgs; 960 961 // Cast init_args if necessary. 962 SmallVector<Value> castedInitArgs; 963 for (const auto &it : llvm::enumerate(initArgs)) { 964 Value initArg = it.value(); 965 Value beforeArg = whileOp.getBeforeArguments()[it.index()]; 966 // If the type is not a tensor, bufferization doesn't need to touch it. 967 if (!isa<TensorType>(beforeArg.getType())) { 968 castedInitArgs.push_back(initArg); 969 continue; 970 } 971 auto targetType = bufferization::getBufferType(beforeArg, options); 972 if (failed(targetType)) 973 return failure(); 974 castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType)); 975 } 976 977 // The result types of a WhileOp are the same as the "after" bbArg types. 978 SmallVector<Type> argsTypesAfter = llvm::to_vector( 979 llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) { 980 if (!isa<TensorType>(bbArg.getType())) 981 return bbArg.getType(); 982 // TODO: error handling 983 return llvm::cast<Type>( 984 *bufferization::getBufferType(bbArg, options)); 985 })); 986 987 // Construct a new scf.while op with memref instead of tensor values. 988 ValueRange argsRangeBefore(castedInitArgs); 989 TypeRange argsTypesBefore(argsRangeBefore); 990 auto newWhileOp = rewriter.create<scf::WhileOp>( 991 whileOp.getLoc(), argsTypesAfter, castedInitArgs); 992 993 // Add before/after regions to the new op. 
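// Each region gets a fresh entry block whose arguments carry the bufferized
// "before"/"after" argument types computed above.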
994 SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(), 995 whileOp.getLoc()); 996 SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(), 997 whileOp.getLoc()); 998 Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock(); 999 newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore); 1000 Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock(); 1001 newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter); 1002 1003 // Set up new iter_args and move the loop condition block to the new op. 1004 // The old block uses tensors, so wrap the (memref) bbArgs of the new block 1005 // in ToTensorOps. 1006 rewriter.setInsertionPointToStart(newBeforeBody); 1007 SmallVector<Value> newBeforeArgs = 1008 getBbArgReplacements(rewriter, newWhileOp.getBeforeArguments(), 1009 whileOp.getBeforeArguments(), indicesBefore); 1010 rewriter.mergeBlocks(whileOp.getBeforeBody(), newBeforeBody, newBeforeArgs); 1011 1012 // Set up new iter_args and move the loop body block to the new op. 1013 // The old block uses tensors, so wrap the (memref) bbArgs of the new block 1014 // in ToTensorOps. 1015 rewriter.setInsertionPointToStart(newAfterBody); 1016 SmallVector<Value> newAfterArgs = 1017 getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(), 1018 whileOp.getAfterArguments(), indicesAfter); 1019 rewriter.mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs); 1020 1021 // Replace loop results. 1022 replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults()); 1023 1024 return success(); 1025 } 1026 1027 FailureOr<BaseMemRefType> 1028 getBufferType(Operation *op, Value value, const BufferizationOptions &options, 1029 SmallVector<Value> &invocationStack) const { 1030 auto whileOp = cast<scf::WhileOp>(op); 1031 assert(getOwnerOfValue(value) == op && "invalid value"); 1032 assert(isa<TensorType>(value.getType()) && "expected tensor type"); 1033 1034 // Case 1: Block argument of the "before" region. 1035 if (auto bbArg = dyn_cast<BlockArgument>(value)) { 1036 if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) { 1037 Value initArg = whileOp.getInits()[bbArg.getArgNumber()]; 1038 auto yieldOp = whileOp.getYieldOp(); 1039 Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber()); 1040 return computeLoopRegionIterArgBufferType( 1041 op, bbArg, initArg, yieldedValue, options, invocationStack); 1042 } 1043 } 1044 1045 // Case 2: OpResult of the loop or block argument of the "after" region. 1046 // The bufferized "after" bbArg type can be directly computed from the 1047 // bufferized "before" bbArg type. 1048 unsigned resultNum; 1049 if (auto opResult = dyn_cast<OpResult>(value)) { 1050 resultNum = opResult.getResultNumber(); 1051 } else if (cast<BlockArgument>(value).getOwner()->getParent() == 1052 &whileOp.getAfter()) { 1053 resultNum = cast<BlockArgument>(value).getArgNumber(); 1054 } else { 1055 llvm_unreachable("invalid value"); 1056 } 1057 Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum]; 1058 if (!isa<TensorType>(conditionYieldedVal.getType())) { 1059 // scf.condition was already bufferized. 1060 return cast<BaseMemRefType>(conditionYieldedVal.getType()); 1061 } 1062 return bufferization::getBufferType(conditionYieldedVal, options, 1063 invocationStack); 1064 } 1065 1066 /// Assert that yielded values of an scf.while op are equivalent to their 1067 /// corresponding bbArgs. In that case, the buffer relations of the 1068 /// corresponding OpResults are "Equivalent". 
1069 /// 1070 /// If this is not the case, allocs+copies are inserted and yielded from 1071 /// the loop. This could be a performance problem, so it must be explicitly 1072 /// activated with `allow-return-allocs-from-loops`. 1073 /// 1074 /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the 1075 /// equivalence condition must be checked for both. 1076 LogicalResult verifyAnalysis(Operation *op, 1077 const AnalysisState &state) const { 1078 auto whileOp = cast<scf::WhileOp>(op); 1079 const auto &options = 1080 static_cast<const OneShotBufferizationOptions &>(state.getOptions()); 1081 if (options.allowReturnAllocsFromLoops) 1082 return success(); 1083 1084 auto conditionOp = whileOp.getConditionOp(); 1085 for (const auto &it : llvm::enumerate(conditionOp.getArgs())) { 1086 Block *block = conditionOp->getBlock(); 1087 if (!isa<TensorType>(it.value().getType())) 1088 continue; 1089 if (it.index() >= block->getNumArguments() || 1090 !state.areEquivalentBufferizedValues(it.value(), 1091 block->getArgument(it.index()))) 1092 return conditionOp->emitError() 1093 << "Condition arg #" << it.index() 1094 << " is not equivalent to the corresponding iter bbArg"; 1095 } 1096 1097 auto yieldOp = whileOp.getYieldOp(); 1098 for (const auto &it : llvm::enumerate(yieldOp.getResults())) { 1099 Block *block = yieldOp->getBlock(); 1100 if (!isa<TensorType>(it.value().getType())) 1101 continue; 1102 if (it.index() >= block->getNumArguments() || 1103 !state.areEquivalentBufferizedValues(it.value(), 1104 block->getArgument(it.index()))) 1105 return yieldOp->emitError() 1106 << "Yield operand #" << it.index() 1107 << " is not equivalent to the corresponding iter bbArg"; 1108 } 1109 1110 return success(); 1111 } 1112 }; 1113 1114 /// Bufferization of scf.yield. Bufferized as part of its enclosing op, so 1115 /// this is for analysis only. 1116 struct YieldOpInterface 1117 : public BufferizableOpInterface::ExternalModel<YieldOpInterface, 1118 scf::YieldOp> { 1119 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 1120 const AnalysisState &state) const { 1121 return true; 1122 } 1123 1124 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 1125 const AnalysisState &state) const { 1126 return false; 1127 } 1128 1129 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, 1130 const AnalysisState &state) const { 1131 if (auto ifOp = dyn_cast<scf::IfOp>(op->getParentOp())) { 1132 return {{op->getParentOp()->getResult(opOperand.getOperandNumber()), 1133 BufferRelation::Equivalent, /*isDefinite=*/false}}; 1134 } 1135 if (isa<scf::ExecuteRegionOp>(op->getParentOp())) 1136 return {{op->getParentOp()->getResult(opOperand.getOperandNumber()), 1137 BufferRelation::Equivalent}}; 1138 return {}; 1139 } 1140 1141 bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand, 1142 const AnalysisState &state) const { 1143 // Yield operands always bufferize inplace. Otherwise, an alloc + copy 1144 // may be generated inside the block. We should not return/yield allocations 1145 // when possible.
1146 return true; 1147 } 1148 1149 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 1150 const BufferizationOptions &options) const { 1151 auto yieldOp = cast<scf::YieldOp>(op); 1152 if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp, 1153 scf::WhileOp>(yieldOp->getParentOp())) 1154 return yieldOp->emitError("unsupported scf::YieldOp parent"); 1155 1156 SmallVector<Value> newResults; 1157 for (const auto &it : llvm::enumerate(yieldOp.getResults())) { 1158 Value value = it.value(); 1159 if (isa<TensorType>(value.getType())) { 1160 FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options); 1161 if (failed(maybeBuffer)) 1162 return failure(); 1163 Value buffer = *maybeBuffer; 1164 // We may have to cast the value before yielding it. 1165 if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>( 1166 yieldOp->getParentOp())) { 1167 FailureOr<BaseMemRefType> resultType = bufferization::getBufferType( 1168 yieldOp->getParentOp()->getResult(it.index()), options); 1169 if (failed(resultType)) 1170 return failure(); 1171 buffer = castBuffer(rewriter, buffer, *resultType); 1172 } else if (auto whileOp = 1173 dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) { 1174 FailureOr<BaseMemRefType> resultType = bufferization::getBufferType( 1175 whileOp.getBeforeArguments()[it.index()], options); 1176 if (failed(resultType)) 1177 return failure(); 1178 buffer = castBuffer(rewriter, buffer, *resultType); 1179 } 1180 newResults.push_back(buffer); 1181 } else { 1182 newResults.push_back(value); 1183 } 1184 } 1185 1186 replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults); 1187 return success(); 1188 } 1189 }; 1190 1191 /// Return `true` if the given loop may have 0 iterations. 1192 bool mayHaveZeroIterations(scf::ForallOp forallOp) { 1193 for (auto [lb, ub] : llvm::zip(forallOp.getMixedLowerBound(), 1194 forallOp.getMixedUpperBound())) { 1195 std::optional<int64_t> lbConst = getConstantIntValue(lb); 1196 std::optional<int64_t> ubConst = getConstantIntValue(ub); 1197 if (!lbConst.has_value() || !ubConst.has_value() || *lbConst >= *ubConst) 1198 return true; 1199 } 1200 return false; 1201 } 1202 1203 /// Bufferization of ForallOp. This also bufferizes the terminator of the 1204 /// region. There are op interfaces for the terminators (InParallelOp 1205 /// and ParallelInsertSliceOp), but these are only used during analysis. Not 1206 /// for bufferization. 1207 struct ForallOpInterface 1208 : public BufferizableOpInterface::ExternalModel<ForallOpInterface, 1209 ForallOp> { 1210 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 1211 const AnalysisState &state) const { 1212 auto forallOp = cast<ForallOp>(op); 1213 1214 // If the loop has zero iterations, the results of the op are their 1215 // corresponding shared_outs, meaning that the shared_outs bufferize to a 1216 // read. 1217 if (mayHaveZeroIterations(forallOp)) 1218 return true; 1219 1220 // scf::ForallOp alone doesn't bufferize to a memory read, one of the 1221 // uses of its matching bbArg may. 1222 return state.isValueRead(forallOp.getTiedBlockArgument(&opOperand)); 1223 } 1224 1225 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 1226 const AnalysisState &state) const { 1227 // Outputs of scf::ForallOps are always considered as a write. 
1228 return true; 1229 } 1230 1231 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, 1232 const AnalysisState &state) const { 1233 auto forallOp = cast<ForallOp>(op); 1234 return { 1235 {{forallOp.getTiedOpResult(&opOperand), BufferRelation::Equivalent}}}; 1236 } 1237 1238 bool isWritable(Operation *op, Value value, 1239 const AnalysisState &state) const { 1240 return true; 1241 } 1242 1243 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 1244 const BufferizationOptions &options) const { 1245 OpBuilder::InsertionGuard guard(rewriter); 1246 auto forallOp = cast<ForallOp>(op); 1247 int64_t rank = forallOp.getRank(); 1248 1249 // Get buffers for all output operands. 1250 SmallVector<Value> buffers; 1251 for (Value out : forallOp.getOutputs()) { 1252 FailureOr<Value> buffer = getBuffer(rewriter, out, options); 1253 if (failed(buffer)) 1254 return failure(); 1255 buffers.push_back(*buffer); 1256 } 1257 1258 // Use buffers instead of block arguments. 1259 rewriter.setInsertionPointToStart(forallOp.getBody()); 1260 for (const auto &it : llvm::zip( 1261 forallOp.getBody()->getArguments().drop_front(rank), buffers)) { 1262 BlockArgument bbArg = std::get<0>(it); 1263 Value buffer = std::get<1>(it); 1264 Value bufferAsTensor = rewriter.create<ToTensorOp>( 1265 forallOp.getLoc(), bbArg.getType(), buffer); 1266 bbArg.replaceAllUsesWith(bufferAsTensor); 1267 } 1268 1269 // Create new ForallOp without any results and drop the automatically 1270 // introduced terminator. 1271 rewriter.setInsertionPoint(forallOp); 1272 ForallOp newForallOp; 1273 newForallOp = rewriter.create<ForallOp>( 1274 forallOp.getLoc(), forallOp.getMixedLowerBound(), 1275 forallOp.getMixedUpperBound(), forallOp.getMixedStep(), 1276 /*outputs=*/ValueRange(), forallOp.getMapping()); 1277 1278 // Keep discardable attributes from the original op. 1279 newForallOp->setDiscardableAttrs(op->getDiscardableAttrDictionary()); 1280 1281 rewriter.eraseOp(newForallOp.getBody()->getTerminator()); 1282 1283 // Move over block contents of the old op. 1284 SmallVector<Value> replacementBbArgs; 1285 replacementBbArgs.append(newForallOp.getBody()->getArguments().begin(), 1286 newForallOp.getBody()->getArguments().end()); 1287 replacementBbArgs.append(forallOp.getOutputs().size(), Value()); 1288 rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(), 1289 replacementBbArgs); 1290 1291 // Remove the old op and replace all of its uses. 1292 replaceOpWithBufferizedValues(rewriter, op, buffers); 1293 1294 return success(); 1295 } 1296 1297 FailureOr<BaseMemRefType> 1298 getBufferType(Operation *op, Value value, const BufferizationOptions &options, 1299 SmallVector<Value> &invocationStack) const { 1300 auto forallOp = cast<ForallOp>(op); 1301 1302 if (auto bbArg = dyn_cast<BlockArgument>(value)) 1303 // A tensor block argument has the same bufferized type as the 1304 // corresponding output operand. 1305 return bufferization::getBufferType( 1306 forallOp.getTiedOpOperand(bbArg)->get(), options, invocationStack); 1307 1308 // The bufferized result type is the same as the bufferized type of the 1309 // corresponding output operand. 1310 return bufferization::getBufferType( 1311 forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options, 1312 invocationStack); 1313 } 1314 1315 bool isRepetitiveRegion(Operation *op, unsigned index) const { 1316 auto forallOp = cast<ForallOp>(op); 1317 1318 // This op is repetitive if it has 1 or more steps. 
1319 // If the control variables are dynamic, it is also considered so. 1320 for (auto [lb, ub, step] : 1321 llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(), 1322 forallOp.getMixedStep())) { 1323 std::optional<int64_t> lbConstant = getConstantIntValue(lb); 1324 if (!lbConstant) 1325 return true; 1326 1327 std::optional<int64_t> ubConstant = getConstantIntValue(ub); 1328 if (!ubConstant) 1329 return true; 1330 1331 std::optional<int64_t> stepConstant = getConstantIntValue(step); 1332 if (!stepConstant) 1333 return true; 1334 1335 if (*lbConstant + *stepConstant < *ubConstant) 1336 return true; 1337 } 1338 return false; 1339 } 1340 1341 bool isParallelRegion(Operation *op, unsigned index) const { 1342 return isRepetitiveRegion(op, index); 1343 } 1344 }; 1345 1346 /// Nothing to do for InParallelOp. 1347 struct InParallelOpInterface 1348 : public BufferizableOpInterface::ExternalModel<InParallelOpInterface, 1349 InParallelOp> { 1350 LogicalResult bufferize(Operation *op, RewriterBase &b, 1351 const BufferizationOptions &options) const { 1352 llvm_unreachable("op does not have any tensor OpOperands / OpResults"); 1353 return failure(); 1354 } 1355 }; 1356 1357 } // namespace 1358 } // namespace scf 1359 } // namespace mlir 1360 1361 void mlir::scf::registerBufferizableOpInterfaceExternalModels( 1362 DialectRegistry ®istry) { 1363 registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) { 1364 ConditionOp::attachInterface<ConditionOpInterface>(*ctx); 1365 ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx); 1366 ForOp::attachInterface<ForOpInterface>(*ctx); 1367 IfOp::attachInterface<IfOpInterface>(*ctx); 1368 IndexSwitchOp::attachInterface<IndexSwitchOpInterface>(*ctx); 1369 ForallOp::attachInterface<ForallOpInterface>(*ctx); 1370 InParallelOp::attachInterface<InParallelOpInterface>(*ctx); 1371 WhileOp::attachInterface<WhileOpInterface>(*ctx); 1372 YieldOp::attachInterface<YieldOpInterface>(*ctx); 1373 }); 1374 } 1375
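// Hedged usage sketch (assumptions: a client pass or tool with an MLIRContext
// named `context`; not part of this file's API): these external models are
// typically registered before running One-Shot Bufferize, e.g.
//
//   DialectRegistry registry;
//   scf::registerBufferizableOpInterfaceExternalModels(registry);
//   context.appendDialectRegistry(registry);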