1 //===- BufferizableOpInterface.h - Bufferizable Ops -------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_ 10 #define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_ 11 12 #include "mlir/IR/Operation.h" 13 #include "mlir/IR/PatternMatch.h" 14 #include "mlir/Support/LLVM.h" 15 #include "llvm/ADT/DenseMapInfoVariant.h" 16 #include "llvm/ADT/SetVector.h" 17 #include <optional> 18 19 #include "mlir/Dialect/Bufferization/IR/BufferizationEnums.h.inc" 20 21 namespace mlir { 22 class OpBuilder; 23 namespace func { 24 class FuncOp; 25 } 26 27 namespace bufferization { 28 29 class AnalysisState; 30 class BufferizableOpInterface; 31 32 /// Specifies a fine-grain relationship between buffers to enable more analysis. 33 enum class BufferRelation { 34 Unknown, 35 // TODO: ResultContainsOperand, 36 // TODO: OperandContainsResult, 37 Equivalent 38 }; 39 40 /// A maybe aliasing OpOperand. If `isDefinite` is `true`, the OpOperand is 41 /// guaranteed to alias at runtime. 42 struct AliasingOpOperand { 43 AliasingOpOperand(OpOperand *opOperand, BufferRelation relation, 44 bool isDefinite = true) 45 : opOperand(opOperand), relation(relation), isDefinite(isDefinite) {} 46 47 OpOperand *opOperand; 48 BufferRelation relation; 49 bool isDefinite; 50 }; 51 52 /// A maybe aliasing Value. If `isDefinite` is `true`, the Value is guaranteed 53 /// to alias at runtime. 54 struct AliasingValue { 55 AliasingValue(Value value, BufferRelation relation, bool isDefinite = true) 56 : value(value), relation(relation), isDefinite(isDefinite) {} 57 58 Value value; 59 BufferRelation relation; 60 bool isDefinite; 61 }; 62 63 template <typename T> 64 class AliasList { 65 public: 66 /// Create an empty list of aliases. 67 AliasList() = default; 68 69 /// Create a list of aliases. 70 AliasList(std::initializer_list<T> elems) { 71 for (T alias : elems) 72 addAlias(alias); 73 } 74 75 /// Create a list of aliases. 76 AliasList(SmallVector<T> &&aliases) : aliases(std::move(aliases)) {} 77 78 ArrayRef<T> getAliases() const { return aliases; } 79 80 size_t getNumAliases() const { return aliases.size(); } 81 82 void addAlias(T alias) { aliases.push_back(alias); } 83 84 auto begin() const { return aliases.begin(); } 85 auto end() const { return aliases.end(); } 86 87 private: 88 /// The list of aliases. 89 SmallVector<T> aliases; 90 }; 91 92 /// A list of possible aliasing OpOperands. This list models the runtime 93 /// aliasing relationship for a Value. 94 using AliasingOpOperandList = AliasList<AliasingOpOperand>; 95 96 /// A list of possible aliasing Values. This list models the runtime aliasing 97 /// relationship for an OpOperand. 98 using AliasingValueList = AliasList<AliasingValue>; 99 100 class OpFilter { 101 public: 102 /// An op filter entry. Filters can be used to specify which ops should be 103 /// processed by the bufferization. 104 struct Entry { 105 /// If the filter function evaluates to `true`, the filter matches. 106 using FilterFn = std::function<bool(Operation *)>; 107 108 /// Filter type: A filter can either be a DENY filter or an ALLOW filter. 109 enum FilterType : int8_t { DENY = 0, ALLOW = 1 }; 110 111 FilterFn fn; 112 FilterType type; 113 }; 114 115 /// Return whether the op is allowed or not. 116 /// 117 /// If the filter does not have an ALLOW rule, ops are allowed by default, 118 /// unless they are explicitly marked as DENY. If the filter has at least one 119 /// ALLOW rule, ops are denied by default and only allowed if they match 120 /// an ALLOW rule and no DENY rule. 121 bool isOpAllowed(Operation *op) const; 122 123 /// Allow the given dialects. 124 /// 125 /// This function adds one or multiple ALLOW entries. 126 template <typename... DialectTs> 127 void allowDialect() { 128 // The following expands a call to allowDialectImpl for each dialect 129 // in 'DialectTs'. 130 (allowDialectImpl<DialectTs>(), ...); 131 } 132 133 /// Deny the given dialects. 134 /// 135 /// This function adds one or multiple DENY entries. 136 template <typename... DialectTs> 137 void denyDialect() { 138 (denyDialectImpl<DialectTs>(), ...); 139 } 140 141 /// Allow the given dialect. 142 /// 143 /// This function adds an ALLOW entry. 144 void allowDialect(StringRef dialectNamespace) { 145 Entry::FilterFn filterFn = [=](Operation *op) { 146 return op->getName().getDialectNamespace() == dialectNamespace; 147 }; 148 entries.push_back(Entry{filterFn, Entry::FilterType::ALLOW}); 149 } 150 151 /// Deny the given dialect. 152 /// 153 /// This function adds a DENY entry. 154 void denyDialect(StringRef dialectNamespace) { 155 Entry::FilterFn filterFn = [=](Operation *op) { 156 return op->getName().getDialectNamespace() == dialectNamespace; 157 }; 158 entries.push_back(Entry{filterFn, Entry::FilterType::DENY}); 159 } 160 161 /// Allow the given ops. 162 /// 163 /// This function adds one or multiple ALLOW entries. 164 template <typename... OpTys> 165 void allowOperation() { 166 (allowOperationImpl<OpTys>(), ...); 167 } 168 169 /// Deny the given ops. 170 /// 171 /// This function adds one or multiple DENY entries. 172 template <typename... OpTys> 173 void denyOperation() { 174 (denyOperationImpl<OpTys>(), ...); 175 } 176 177 /// Allow the given op. 178 /// 179 /// This function adds an ALLOW entry. 180 void allowOperation(StringRef opName) { 181 Entry::FilterFn filterFn = [=](Operation *op) { 182 return op->getName().getStringRef() == opName; 183 }; 184 allowOperation(filterFn); 185 } 186 187 /// Deny the given op. 188 /// 189 /// This function adds a DENY entry. 190 void denyOperation(StringRef opName) { 191 Entry::FilterFn filterFn = [=](Operation *op) { 192 return op->getName().getStringRef() == opName; 193 }; 194 denyOperation(filterFn); 195 } 196 197 /// Allow ops that are matched by `fn`. 198 /// 199 /// This function adds an ALLOW entry. 200 void allowOperation(Entry::FilterFn fn) { 201 entries.push_back(Entry{fn, Entry::FilterType::ALLOW}); 202 } 203 204 /// Deny ops that are matched by `fn`. 205 /// 206 /// This function adds a DENY entry. 207 void denyOperation(Entry::FilterFn fn) { 208 entries.push_back(Entry{fn, Entry::FilterType::DENY}); 209 } 210 211 private: 212 /// Return `true` if the filter has at least one ALLOW rule. 213 bool hasAllowRule() const { 214 for (const Entry &e : entries) 215 if (e.type == Entry::FilterType::ALLOW) 216 return true; 217 return false; 218 } 219 220 /// Allow a dialect. 221 template <typename DialectT> 222 void allowDialectImpl() { 223 allowDialect(DialectT::getDialectNamespace()); 224 } 225 226 /// Deny a dialect. 227 template <typename DialectT> 228 void denyDialectImpl() { 229 denyDialect(DialectT::getDialectNamespace()); 230 } 231 232 /// Allow an op. 233 template <typename OpTy> 234 void allowOperationImpl() { 235 allowOperation(OpTy::getOperationName()); 236 } 237 238 /// Deny an op. 239 template <typename OpTy> 240 void denyOperationImpl() { 241 denyOperation(OpTy::getOperationName()); 242 } 243 244 /// A list of filter entries that determine whether an op should be allowed or 245 /// denied. If the filter has an ALLOW rule, only ops that are allowed and not 246 /// denied are allowed. If the filter does not have an ALLOW rule, only ops 247 /// that are not denied are allowed. 248 SmallVector<Entry> entries; 249 }; 250 251 /// Options for BufferizableOpInterface-based bufferization. 252 struct BufferizationOptions { 253 /// Allocator function: Generate a memref allocation with the given type, 254 /// dynamic extents and alignment. 255 using AllocationFn = std::function<FailureOr<Value>( 256 OpBuilder &, Location, MemRefType, ValueRange, unsigned int)>; 257 /// Memcpy function: Generate a memcpy between two buffers. 258 using MemCpyFn = 259 std::function<LogicalResult(OpBuilder &, Location, Value, Value)>; 260 /// Initializer function for analysis state. 261 using AnalysisStateInitFn = std::function<void(AnalysisState &)>; 262 /// Tensor -> MemRef type converter. 263 /// Parameters: tensor type, memory space, func op, bufferization options 264 using FunctionArgTypeConverterFn = 265 std::function<BaseMemRefType(TensorType, Attribute memorySpace, 266 func::FuncOp, const BufferizationOptions &)>; 267 /// Tensor -> MemRef type converter. 268 /// Parameters: Value, memory space, bufferization options 269 using UnknownTypeConverterFn = std::function<BaseMemRefType( 270 Value, Attribute memorySpace, const BufferizationOptions &)>; 271 // Produce a MemorySpace attribute from a tensor type 272 using DefaultMemorySpaceFn = 273 std::function<std::optional<Attribute>(TensorType t)>; 274 275 BufferizationOptions(); 276 277 /// Try to cast the given op to BufferizableOpInterface if the op is allow 278 /// listed. 279 BufferizableOpInterface dynCastBufferizableOp(Operation *op) const; 280 281 /// Try to cast the given value to BufferizableOpInterface if the op is allow 282 /// listed. 283 BufferizableOpInterface dynCastBufferizableOp(Value value) const; 284 285 /// A filter that specifies which ops should be bufferized and which ops 286 /// should be ignored. 287 OpFilter opFilter; 288 289 /// Return `true` if the given op should be bufferized. 290 bool isOpAllowed(Operation *op) const; 291 292 /// Helper functions for allocation and memory copying. 293 std::optional<AllocationFn> allocationFn; 294 std::optional<MemCpyFn> memCpyFn; 295 296 /// Create a memref allocation with the given type and dynamic extents. 297 FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type, 298 ValueRange dynShape) const; 299 300 /// Creates a memcpy between two given buffers. 301 LogicalResult createMemCpy(OpBuilder &b, Location loc, Value from, 302 Value to) const; 303 304 /// Specifies whether not bufferizable ops are allowed in the input. If so, 305 /// bufferization.to_memref and bufferization.to_tensor ops are inserted at 306 /// the boundaries. 307 bool allowUnknownOps = false; 308 309 /// Specifies whether function boundaries (ops in the func dialect) should be 310 /// bufferized or not. 311 bool bufferizeFunctionBoundaries = false; 312 313 // Specifies whether to account for parallel regions in RaW analysis. If true, 314 // then writes inside of parallel regions that write to buffers defined 315 // outside of the parallel region will be given a new buffer. 316 bool checkParallelRegions = true; 317 318 /// Certain ops have aliasing OpOperand/OpResult invariants (e.g., scf.for). 319 /// If this flag is set to `false`, those invariants are no longer enforced 320 /// with buffer copies. 321 /// 322 /// Note: Deactivating this flag can lead to incorrect bufferization results 323 /// when used incorrectly. This flag is useful with 324 /// `AlwaysCopyAnalysisState` which bufferizes all writing tensor 325 /// OpOperands out-of-place. 326 bool enforceAliasingInvariants = true; 327 328 /// This function controls buffer types on function signatures. Sets 329 /// `functionArgTypeConverterFn` and `inferFunctionResultLayout` accordingly. 330 /// 331 /// * InferLayoutMap: All function parameter types have a fully dynamic layout 332 /// map, but function result types are inferred from the body of the 333 /// function. 334 /// * FullyDynamicLayoutMap: All function parameter types and result types 335 /// have a fully dynamic layout map. This option is most efficient because 336 /// any layout map can be casted to a fully dynamic one. 337 /// * IdentityLayoutMap: All function parameter types and result types have a 338 /// static identity layout (i.e., no layout map). This option may introduce 339 /// additional buffer allocs and copies because layout maps cannot be casted 340 /// away. 341 /// 342 /// Note: Inferred layout maps may not be desireable when interacting with 343 /// external functions, because the generated function signatures will be less 344 /// predictable. 345 void setFunctionBoundaryTypeConversion(LayoutMapOption layoutMapOption); 346 347 /// Type converter from tensors to memrefs. This type converter is used to 348 /// determine bufferized function argument and result types. By default, a 349 /// type converter that returns a memref type with a fully dynamic layout map 350 /// is used. 351 /// 352 /// If `bufferizeFunctionBoundaries` is not set, this function isn't used. 353 FunctionArgTypeConverterFn functionArgTypeConverterFn = nullptr; 354 355 /// If true, function result types are inferred from the body of the function. 356 /// Otherwise, function result type is determined by 357 /// `functionArgTypeConverterFn`. 358 /// 359 /// If `bufferizeFunctionBoundaries` is not set, this flag has no effect. 360 bool inferFunctionResultLayout = true; 361 362 /// Type converter from tensors to memrefs. This type converter is used if no 363 /// memref type could be inferred during bufferization. By default, a type 364 /// converter that returns a memref type with a fully dynamic layout map is 365 /// used. 366 UnknownTypeConverterFn unknownTypeConverterFn = nullptr; 367 368 // Use during type conversion to determine the memory space for memref based 369 // on the original tensor type if the memory space cannot be inferred. 370 // Returning std::nullopt will cause bufferization to fail (useful to indicate 371 // failure to determine memory space for a tensor type). 372 DefaultMemorySpaceFn defaultMemorySpaceFn = 373 [](TensorType t) -> std::optional<Attribute> { return Attribute(); }; 374 375 /// If set to `true`, the analysis is skipped. A buffer is copied before every 376 /// write. This flag cannot be used together with `testAnalysisOnly = true`. 377 bool copyBeforeWrite = false; 378 379 /// If set to `true`, does not modify the IR apart from adding attributes (for 380 /// checking the results of the analysis) and post analysis steps. 381 bool testAnalysisOnly = false; 382 383 /// If set to `true`, the IR is annotated with details about RaW conflicts. 384 /// For debugging only. Should be used together with `testAnalysisOnly`. 385 bool printConflicts = false; 386 387 /// Buffer alignment for new memory allocations. 388 unsigned int bufferAlignment = 64; 389 390 /// Initializer functions for analysis state. These can be used to 391 /// initialize dialect-specific analysis state. 392 SmallVector<AnalysisStateInitFn> stateInitializers; 393 }; 394 395 /// Traversal parameters for `findValueInReverseUseDefChain`. 396 struct TraversalConfig { 397 /// Specifies if leaves (that do not have further OpOperands to follow) 398 /// should be returned even if they do not match the specified filter. 399 bool alwaysIncludeLeaves = true; 400 401 /// Specifies whether out-of-place/undecided OpOperands should be followed. 402 bool followInPlaceOnly = false; 403 404 /// Specifies whether non-equivalent OpOperands should be followed. 405 bool followEquivalentOnly = false; 406 407 /// Specifies whether unknown/non-bufferizable/ops not included in the 408 /// OpFilter of BufferizationOptions should be followed. 409 bool followUnknownOps = false; 410 411 /// Specifies whether OpOperands with a different type that are not the result 412 /// of a CastOpInterface op should be followed. 413 bool followSameTypeOrCastsOnly = false; 414 415 /// Specifies whether already visited values should be visited again. 416 /// (Note: This can result in infinite looping.) 417 bool revisitAlreadyVisitedValues = false; 418 }; 419 420 /// AnalysisState provides a variety of helper functions for dealing with 421 /// tensor values. 422 class AnalysisState { 423 public: 424 /// Determine which OpOperand* will alias with `value` if the op is 425 /// bufferized in place. Return all tensor OpOperand* if the op is not 426 /// bufferizable. 427 AliasingOpOperandList getAliasingOpOperands(Value value) const; 428 429 /// Determine which Value will alias with `opOperand` if the op is bufferized 430 /// in place. Return all tensor Values if the op is not bufferizable. 431 AliasingValueList getAliasingValues(OpOperand &opOperand) const; 432 433 /// Return true if `opOperand` bufferizes to a memory read. Return `true` if 434 /// the op is not bufferizable. 435 bool bufferizesToMemoryRead(OpOperand &opOperand) const; 436 437 /// Return true if `opOperand` bufferizes to a memory write. Return true` if 438 /// the op is not bufferizable. 439 bool bufferizesToMemoryWrite(OpOperand &opOperand) const; 440 441 /// Return true if the given `value` bufferizes to a memory write. Return 442 /// true if the value is a block argument. Return `true` if the defining op is 443 /// not bufferizable. Otherwise, consult the BufferizableOpInterface. 444 bool bufferizesToMemoryWrite(Value value) const; 445 446 /// Return true if `opOperand` does neither read nor write but bufferizes to 447 /// an alias. Return false if the op is not bufferizable. 448 bool bufferizesToAliasOnly(OpOperand &opOperand) const; 449 450 /// Return true if a copy can always be avoided when allocating a new tensor 451 /// for the given OpOperand. 452 bool canOmitTensorCopy(OpOperand &opOperand) const; 453 454 /// Return true if the given value is read by an op that bufferizes to a 455 /// memory read. Also takes into account ops that create an alias but do not 456 /// read by themselves (e.g., ExtractSliceOp). 457 bool isValueRead(Value value) const; 458 459 /// Starting from `opOperand`, follow the use-def chain in reverse, always 460 /// selecting the aliasing OpOperands. Find and return Values for which 461 /// `condition` evaluates to true. OpOperands of such matching Values are not 462 /// traversed any further, the visited aliasing opOperands will be preserved 463 /// through `visitedOpOperands`. 464 /// 465 /// When reaching the end of a chain, also return the last Value of that 466 /// chain if `config.alwaysIncludeLeaves` is set. 467 /// 468 /// Example: 469 /// 470 /// 8 471 /// | 472 /// 6* 7* +-----+----+ 473 /// | | | | 474 /// 2* 3 4* 5 475 /// | | | | 476 /// +----------+----------+----------+ 477 /// | 478 /// 1 479 /// 480 /// In the above example, Values with a star satisfy the condition. When 481 /// starting the traversal from Value 1, the resulting SetVector is: 482 /// { 2, 7, 8, 5 } 483 /// 484 /// Additional stopping conditions for the traversal can be specified in 485 /// `config`. 486 SetVector<Value> findValueInReverseUseDefChain( 487 OpOperand *opOperand, llvm::function_ref<bool(Value)> condition, 488 TraversalConfig config = TraversalConfig(), 489 llvm::DenseSet<OpOperand *> *visitedOpOperands = nullptr) const; 490 491 /// Find the values that may define the contents of the given value at 492 /// runtime. A block argument is always a definition. An OpResult is a 493 /// definition if it bufferizes to memory write. If it does not bufferize to 494 /// a memory write but has aliasing operands, we continue the lookup on these 495 /// values. 496 /// 497 /// Example: %r = tensor.insert %f into %t[%c0] : tensor<?xf32> 498 /// findDefinitions(%r) = {%r} because %r bufferizes to memory write. 499 /// 500 /// Example: %r = tensor.empty() : tensor<10xf32> 501 /// findDefinitions(%r) = {} because tensor.empty does not the define the 502 /// contents of its result (i.e., it does not bufferize to a memory write) 503 /// and it has no aliasing OpOperands. 504 /// 505 /// Example: 506 /// %a = arith.constant ... : tensor<10xf32> 507 /// %b1 = tensor.insert %f into %t : tensor<50xf32> 508 /// %b2 = tensor.extract_slice %b1[0][10][1] : tensor<50xf32> tensor<10xf32> 509 /// %r = arith.select %cond, %a, %b : tensor<10xf32> 510 /// findDefinitions(%r) = {%a, %b1}. %r and %b2 are skipped (lookup continues 511 /// in the operands) because their defining ops do not define the contents of 512 /// the tensor. 513 /// 514 /// Example: 515 /// %a = tensor.empty() : tensor<10xf32> 516 /// %b = arith.constant ... : tensor<10xf32> 517 /// %r = arith.select %cond, %a, %b : tensor<10xf32> 518 /// findDefinitions(%r) = {%b}. %a is excluded because it does not define the 519 /// contents of the tensor. 520 /// 521 /// Note: OpResults of unknown ops are handled conservatively and assumed to 522 /// be definitions. 523 SetVector<Value> findDefinitions(OpOperand *opOperand) const; 524 525 /// Return `true` if the given OpResult has been decided to bufferize inplace. 526 virtual bool isInPlace(OpOperand &opOperand) const; 527 528 /// Return true if `v1` and `v2` bufferize to equivalent buffers. 529 virtual bool areEquivalentBufferizedValues(Value v1, Value v2) const; 530 531 /// Return true if `v1` and `v2` may bufferize to aliasing buffers. 532 virtual bool areAliasingBufferizedValues(Value v1, Value v2) const; 533 534 /// Return `true` if the given tensor has undefined contents. 535 virtual bool hasUndefinedContents(OpOperand *opOperand) const; 536 537 /// Return a reference to the BufferizationOptions. 538 const BufferizationOptions &getOptions() const { return options; } 539 540 AnalysisState(const BufferizationOptions &options); 541 542 // AnalysisState should be passed as a reference. 543 AnalysisState(const AnalysisState &) = delete; 544 545 virtual ~AnalysisState() = default; 546 547 static bool classof(const AnalysisState *base) { return true; } 548 549 TypeID getType() const { return type; } 550 551 /// Return the closest enclosing repetitive region around the given op. 552 Region *getEnclosingRepetitiveRegion(Operation *op, 553 const BufferizationOptions &options); 554 555 /// Return the closest enclosing repetitive region around the place where the 556 /// given value is defined. 557 Region *getEnclosingRepetitiveRegion(Value value, 558 const BufferizationOptions &options); 559 560 /// Return the closest enclosing repetitive region around the given block. 561 Region *getEnclosingRepetitiveRegion(Block *block, 562 const BufferizationOptions &options); 563 564 virtual void resetCache(); 565 566 protected: 567 AnalysisState(const BufferizationOptions &options, TypeID type); 568 569 private: 570 /// A reference to current bufferization options. 571 const BufferizationOptions &options; 572 573 /// The type of analysis. 574 TypeID type; 575 576 /// Cache containing closest ancestor repetitive Region. 577 DenseMap<std::variant<Operation *, Block *, Region *, Value>, Region *> 578 enclosingRepetitiveRegionCache; 579 }; 580 581 /// Create an AllocTensorOp for the given shaped value (memref or tensor). 582 /// If `copy` is set, the shaped value is copied. Otherwise, a tensor with 583 /// undefined contents is allocated. 584 FailureOr<Value> 585 allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue, 586 const BufferizationOptions &options, 587 bool copy = true); 588 589 /// Lookup the buffer for the given value. If the value was not bufferized 590 /// yet, wrap it in a ToMemrefOp. Otherwise, it is the result of a ToTensorOp, 591 /// from which the memref operand is returned. 592 FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value, 593 const BufferizationOptions &options); 594 595 /// Return the buffer type for a given Value (tensor) after bufferization 596 /// without bufferizing any IR. 597 /// 598 /// Note: It should be sufficient to call `getBuffer()->getType()` in most 599 /// cases. However, when a buffer type should be predicted without modifying any 600 /// IR, this function can be used. 601 /// 602 /// This function is a wrapper around BufferizableOpInterface::getBufferType. 603 FailureOr<BaseMemRefType> getBufferType(Value value, 604 const BufferizationOptions &options); 605 606 /// Return the buffer type for a given Value (tensor) after bufferization 607 /// without bufferizing any IR. This function (and not the other overload 608 /// without `invocationStack`) can be used from `getBufferType` implementations 609 /// of the `BufferizableOpInterface`. 610 /// 611 /// Note: It should be sufficient to call `getBuffer()->getType()` in most 612 /// cases. However, when a buffer type should be predicted without modifying any 613 /// IR, this function can be used. 614 /// 615 /// This function is a wrapper around `BufferizableOpInterface::getBufferType`. 616 FailureOr<BaseMemRefType> getBufferType(Value value, 617 const BufferizationOptions &options, 618 SmallVector<Value> &invocationStack); 619 620 /// Return "true" if the given op has tensor semantics and should be bufferized. 621 /// If the op is bufferizable, the BufferizableOpInterface is queried. 622 /// Otherwise, an op has tensor semantics if it has tensor operands, tensor 623 /// op results and/or tensor block arguments. 624 bool hasTensorSemantics(Operation *op); 625 626 /// Replace an op with replacement values. The op is deleted. Tensor OpResults 627 /// must be replaced with memref values. 628 void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, 629 ValueRange values); 630 631 /// Replace an op with a new op. The new op must have the same number of 632 /// results as the replaced op. The new op may not return any tensor values. 633 template <typename OpTy, typename... Args> 634 OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op, 635 Args &&...args) { 636 auto newOp = rewriter.create<OpTy>(op->getLoc(), std::forward<Args>(args)...); 637 replaceOpWithBufferizedValues(rewriter, op, newOp->getResults()); 638 return newOp; 639 } 640 641 /// Return a MemRefType to which the type of the given value can be bufferized. 642 /// 643 /// If possible, op bufferization implementations should not use this function 644 /// and instead infer precise memref types for tensor results by themselves. 645 /// 646 /// Unless a layout map was specified, `options.unknownTypeConverterFn` 647 /// determines what kind of layout map will be used. For best composability 648 /// (without copies), the fully dynamic layout map is used by default. 649 /// 650 /// Note: Canonicalization patterns could clean up layout maps and infer more 651 /// precise layout maps after bufferization. However, many possible 652 /// canonicalizations are currently not implemented. 653 BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, 654 MemRefLayoutAttrInterface layout = {}, 655 Attribute memorySpace = nullptr); 656 657 /// Return a MemRef type with fully dynamic layout. If the given tensor type 658 /// is unranked, return an unranked MemRef type. 659 BaseMemRefType 660 getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, 661 Attribute memorySpace = nullptr); 662 663 /// Return a MemRef type with a static identity layout (i.e., no layout map). If 664 /// the given tensor type is unranked, return an unranked MemRef type. 665 BaseMemRefType 666 getMemRefTypeWithStaticIdentityLayout(TensorType tensorType, 667 Attribute memorySpace = nullptr); 668 669 /// Return the owner of the given value. In case of a BlockArgument that is the 670 /// owner of the block. In case of an OpResult that is the defining op. 671 Operation *getOwnerOfValue(Value value); 672 673 /// Assuming that the given region is repetitive, find the next enclosing 674 /// repetitive region. 675 Region *getNextEnclosingRepetitiveRegion(Region *region, 676 const BufferizationOptions &options); 677 678 /// If `region` is a parallel region, return `region`. Otherwise, find the first 679 /// enclosing parallel region of `region`. If there is no such region, return 680 /// "nullptr". 681 /// 682 /// Note: Whether a region is parallel or sequential is queried from the 683 /// `BufferizableOpInterface`. 684 Region *getParallelRegion(Region *region, const BufferizationOptions &options); 685 686 namespace detail { 687 /// This is the default implementation of 688 /// BufferizableOpInterface::getAliasingOpOperands. Should not be called from 689 /// other places. 690 AliasingOpOperandList defaultGetAliasingOpOperands(Value value, 691 const AnalysisState &state); 692 693 /// This is the default implementation of 694 /// BufferizableOpInterface::getBufferType. Should not be called from other 695 /// places. 696 FailureOr<BaseMemRefType> 697 defaultGetBufferType(Value value, const BufferizationOptions &options, 698 SmallVector<Value> &invocationStack); 699 700 /// This is the default implementation of 701 /// BufferizableOpInterface::resultBufferizesToMemoryWrite. Should not be called 702 /// from other places. 703 bool defaultResultBufferizesToMemoryWrite(OpResult opResult, 704 const AnalysisState &state); 705 706 /// This is the default implementation of 707 /// BufferizableOpInterface::isRepetitiveRegion. Should not be called from other 708 /// places. 709 bool defaultIsRepetitiveRegion(BufferizableOpInterface bufferizableOp, 710 unsigned index); 711 712 /// This is the default implementation of getAliasingOpOperands in case the 713 /// defining op does not implement the BufferizableOpInterface. 714 AliasingOpOperandList unknownGetAliasingOpOperands(Value value); 715 716 /// This is the default implementation of getAliasingValues in case the owner 717 /// op does not implement the BufferizableOpInterface. 718 AliasingValueList unknownGetAliasingValues(OpOperand &opOperand); 719 720 /// This is the default implementation of 721 /// BufferizableOpInterface::hasTensorSemantics 722 bool defaultHasTensorSemantics(Operation *op); 723 } // namespace detail 724 725 } // namespace bufferization 726 } // namespace mlir 727 728 MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::bufferization::AnalysisState) 729 730 //===----------------------------------------------------------------------===// 731 // Bufferization Interfaces 732 //===----------------------------------------------------------------------===// 733 734 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h.inc" 735 736 #endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_ 737