1 //===- MLIRContext.cpp - MLIR Type Classes --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/MLIRContext.h" 10 #include "AffineExprDetail.h" 11 #include "AffineMapDetail.h" 12 #include "AttributeDetail.h" 13 #include "IntegerSetDetail.h" 14 #include "TypeDetail.h" 15 #include "mlir/IR/Action.h" 16 #include "mlir/IR/AffineExpr.h" 17 #include "mlir/IR/AffineMap.h" 18 #include "mlir/IR/Attributes.h" 19 #include "mlir/IR/BuiltinAttributes.h" 20 #include "mlir/IR/BuiltinDialect.h" 21 #include "mlir/IR/Diagnostics.h" 22 #include "mlir/IR/Dialect.h" 23 #include "mlir/IR/ExtensibleDialect.h" 24 #include "mlir/IR/IntegerSet.h" 25 #include "mlir/IR/Location.h" 26 #include "mlir/IR/OpImplementation.h" 27 #include "mlir/IR/OperationSupport.h" 28 #include "mlir/IR/Types.h" 29 #include "llvm/ADT/DenseMap.h" 30 #include "llvm/ADT/DenseSet.h" 31 #include "llvm/ADT/SmallString.h" 32 #include "llvm/ADT/StringSet.h" 33 #include "llvm/ADT/Twine.h" 34 #include "llvm/Support/Allocator.h" 35 #include "llvm/Support/CommandLine.h" 36 #include "llvm/Support/Compiler.h" 37 #include "llvm/Support/Debug.h" 38 #include "llvm/Support/Mutex.h" 39 #include "llvm/Support/RWMutex.h" 40 #include "llvm/Support/ThreadPool.h" 41 #include "llvm/Support/raw_ostream.h" 42 #include <memory> 43 #include <optional> 44 45 #define DEBUG_TYPE "mlircontext" 46 47 using namespace mlir; 48 using namespace mlir::detail; 49 50 //===----------------------------------------------------------------------===// 51 // MLIRContext CommandLine Options 52 //===----------------------------------------------------------------------===// 53 54 namespace { 55 /// This struct contains command line options that can be used to initialize 56 /// various bits of an MLIRContext. This uses a struct wrapper to avoid the need 57 /// for global command line options. 58 struct MLIRContextOptions { 59 llvm::cl::opt<bool> disableThreading{ 60 "mlir-disable-threading", 61 llvm::cl::desc("Disable multi-threading within MLIR, overrides any " 62 "further call to MLIRContext::enableMultiThreading()")}; 63 64 llvm::cl::opt<bool> printOpOnDiagnostic{ 65 "mlir-print-op-on-diagnostic", 66 llvm::cl::desc("When a diagnostic is emitted on an operation, also print " 67 "the operation as an attached note"), 68 llvm::cl::init(true)}; 69 70 llvm::cl::opt<bool> printStackTraceOnDiagnostic{ 71 "mlir-print-stacktrace-on-diagnostic", 72 llvm::cl::desc("When a diagnostic is emitted, also print the stack trace " 73 "as an attached note")}; 74 }; 75 } // namespace 76 77 static llvm::ManagedStatic<MLIRContextOptions> clOptions; 78 79 static bool isThreadingGloballyDisabled() { 80 #if LLVM_ENABLE_THREADS != 0 81 return clOptions.isConstructed() && clOptions->disableThreading; 82 #else 83 return true; 84 #endif 85 } 86 87 /// Register a set of useful command-line options that can be used to configure 88 /// various flags within the MLIRContext. These flags are used when constructing 89 /// an MLIR context for initialization. 90 void mlir::registerMLIRContextCLOptions() { 91 // Make sure that the options struct has been initialized. 92 *clOptions; 93 } 94 95 //===----------------------------------------------------------------------===// 96 // Locking Utilities 97 //===----------------------------------------------------------------------===// 98 99 namespace { 100 /// Utility writer lock that takes a runtime flag that specifies if we really 101 /// need to lock. 102 struct ScopedWriterLock { 103 ScopedWriterLock(llvm::sys::SmartRWMutex<true> &mutexParam, bool shouldLock) 104 : mutex(shouldLock ? &mutexParam : nullptr) { 105 if (mutex) 106 mutex->lock(); 107 } 108 ~ScopedWriterLock() { 109 if (mutex) 110 mutex->unlock(); 111 } 112 llvm::sys::SmartRWMutex<true> *mutex; 113 }; 114 } // namespace 115 116 //===----------------------------------------------------------------------===// 117 // MLIRContextImpl 118 //===----------------------------------------------------------------------===// 119 120 namespace mlir { 121 /// This is the implementation of the MLIRContext class, using the pImpl idiom. 122 /// This class is completely private to this file, so everything is public. 123 class MLIRContextImpl { 124 public: 125 //===--------------------------------------------------------------------===// 126 // Debugging 127 //===--------------------------------------------------------------------===// 128 129 /// An action handler for handling actions that are dispatched through this 130 /// context. 131 std::function<void(function_ref<void()>, const tracing::Action &)> 132 actionHandler; 133 134 //===--------------------------------------------------------------------===// 135 // Diagnostics 136 //===--------------------------------------------------------------------===// 137 DiagnosticEngine diagEngine; 138 139 //===--------------------------------------------------------------------===// 140 // Options 141 //===--------------------------------------------------------------------===// 142 143 /// In most cases, creating operation in unregistered dialect is not desired 144 /// and indicate a misconfiguration of the compiler. This option enables to 145 /// detect such use cases 146 bool allowUnregisteredDialects = false; 147 148 /// Enable support for multi-threading within MLIR. 149 bool threadingIsEnabled = true; 150 151 /// Track if we are currently executing in a threaded execution environment 152 /// (like the pass-manager): this is only a debugging feature to help reducing 153 /// the chances of data races one some context APIs. 154 #ifndef NDEBUG 155 std::atomic<int> multiThreadedExecutionContext{0}; 156 #endif 157 158 /// If the operation should be attached to diagnostics printed via the 159 /// Operation::emit methods. 160 bool printOpOnDiagnostic = true; 161 162 /// If the current stack trace should be attached when emitting diagnostics. 163 bool printStackTraceOnDiagnostic = false; 164 165 //===--------------------------------------------------------------------===// 166 // Other 167 //===--------------------------------------------------------------------===// 168 169 /// This points to the ThreadPool used when processing MLIR tasks in parallel. 170 /// It can't be nullptr when multi-threading is enabled. Otherwise if 171 /// multi-threading is disabled, and the threadpool wasn't externally provided 172 /// using `setThreadPool`, this will be nullptr. 173 llvm::ThreadPoolInterface *threadPool = nullptr; 174 175 /// In case where the thread pool is owned by the context, this ensures 176 /// destruction with the context. 177 std::unique_ptr<llvm::ThreadPoolInterface> ownedThreadPool; 178 179 /// An allocator used for AbstractAttribute and AbstractType objects. 180 llvm::BumpPtrAllocator abstractDialectSymbolAllocator; 181 182 /// This is a mapping from operation name to the operation info describing it. 183 llvm::StringMap<std::unique_ptr<OperationName::Impl>> operations; 184 185 /// A vector of operation info specifically for registered operations. 186 llvm::DenseMap<TypeID, RegisteredOperationName> registeredOperations; 187 llvm::StringMap<RegisteredOperationName> registeredOperationsByName; 188 189 /// This is a sorted container of registered operations for a deterministic 190 /// and efficient `getRegisteredOperations` implementation. 191 SmallVector<RegisteredOperationName, 0> sortedRegisteredOperations; 192 193 /// This is a list of dialects that are created referring to this context. 194 /// The MLIRContext owns the objects. These need to be declared after the 195 /// registered operations to ensure correct destruction order. 196 DenseMap<StringRef, std::unique_ptr<Dialect>> loadedDialects; 197 DialectRegistry dialectsRegistry; 198 199 /// A mutex used when accessing operation information. 200 llvm::sys::SmartRWMutex<true> operationInfoMutex; 201 202 //===--------------------------------------------------------------------===// 203 // Affine uniquing 204 //===--------------------------------------------------------------------===// 205 206 // Affine expression, map and integer set uniquing. 207 StorageUniquer affineUniquer; 208 209 //===--------------------------------------------------------------------===// 210 // Type uniquing 211 //===--------------------------------------------------------------------===// 212 213 DenseMap<TypeID, AbstractType *> registeredTypes; 214 StorageUniquer typeUniquer; 215 216 /// This is a mapping from type name to the abstract type describing it. 217 /// It is used by `AbstractType::lookup` to get an `AbstractType` from a name. 218 /// As this map needs to be populated before `StringAttr` is loaded, we 219 /// cannot use `StringAttr` as the key. The context does not take ownership 220 /// of the key, so the `StringRef` must outlive the context. 221 llvm::DenseMap<StringRef, AbstractType *> nameToType; 222 223 /// Cached Type Instances. 224 BFloat16Type bf16Ty; 225 Float16Type f16Ty; 226 FloatTF32Type tf32Ty; 227 Float32Type f32Ty; 228 Float64Type f64Ty; 229 Float80Type f80Ty; 230 Float128Type f128Ty; 231 IndexType indexTy; 232 IntegerType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty; 233 NoneType noneType; 234 235 //===--------------------------------------------------------------------===// 236 // Attribute uniquing 237 //===--------------------------------------------------------------------===// 238 239 DenseMap<TypeID, AbstractAttribute *> registeredAttributes; 240 StorageUniquer attributeUniquer; 241 242 /// This is a mapping from attribute name to the abstract attribute describing 243 /// it. It is used by `AbstractType::lookup` to get an `AbstractType` from a 244 /// name. 245 /// As this map needs to be populated before `StringAttr` is loaded, we 246 /// cannot use `StringAttr` as the key. The context does not take ownership 247 /// of the key, so the `StringRef` must outlive the context. 248 llvm::DenseMap<StringRef, AbstractAttribute *> nameToAttribute; 249 250 /// Cached Attribute Instances. 251 BoolAttr falseAttr, trueAttr; 252 UnitAttr unitAttr; 253 UnknownLoc unknownLocAttr; 254 DictionaryAttr emptyDictionaryAttr; 255 StringAttr emptyStringAttr; 256 257 /// Map of string attributes that may reference a dialect, that are awaiting 258 /// that dialect to be loaded. 259 llvm::sys::SmartMutex<true> dialectRefStrAttrMutex; 260 DenseMap<StringRef, SmallVector<StringAttrStorage *>> 261 dialectReferencingStrAttrs; 262 263 /// A distinct attribute allocator that allocates every time since the 264 /// address of the distinct attribute storage serves as unique identifier. The 265 /// allocator is thread safe and frees the allocated storage after its 266 /// destruction. 267 DistinctAttributeAllocator distinctAttributeAllocator; 268 269 public: 270 MLIRContextImpl(bool threadingIsEnabled) 271 : threadingIsEnabled(threadingIsEnabled) { 272 if (threadingIsEnabled) { 273 ownedThreadPool = std::make_unique<llvm::DefaultThreadPool>(); 274 threadPool = ownedThreadPool.get(); 275 } 276 } 277 ~MLIRContextImpl() { 278 for (auto typeMapping : registeredTypes) 279 typeMapping.second->~AbstractType(); 280 for (auto attrMapping : registeredAttributes) 281 attrMapping.second->~AbstractAttribute(); 282 } 283 }; 284 } // namespace mlir 285 286 MLIRContext::MLIRContext(Threading setting) 287 : MLIRContext(DialectRegistry(), setting) {} 288 289 MLIRContext::MLIRContext(const DialectRegistry ®istry, Threading setting) 290 : impl(new MLIRContextImpl(setting == Threading::ENABLED && 291 !isThreadingGloballyDisabled())) { 292 // Initialize values based on the command line flags if they were provided. 293 if (clOptions.isConstructed()) { 294 printOpOnDiagnostic(clOptions->printOpOnDiagnostic); 295 printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic); 296 } 297 298 // Pre-populate the registry. 299 registry.appendTo(impl->dialectsRegistry); 300 301 // Ensure the builtin dialect is always pre-loaded. 302 getOrLoadDialect<BuiltinDialect>(); 303 304 // Initialize several common attributes and types to avoid the need to lock 305 // the context when accessing them. 306 307 //// Types. 308 /// Floating-point Types. 309 impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this); 310 impl->f16Ty = TypeUniquer::get<Float16Type>(this); 311 impl->tf32Ty = TypeUniquer::get<FloatTF32Type>(this); 312 impl->f32Ty = TypeUniquer::get<Float32Type>(this); 313 impl->f64Ty = TypeUniquer::get<Float64Type>(this); 314 impl->f80Ty = TypeUniquer::get<Float80Type>(this); 315 impl->f128Ty = TypeUniquer::get<Float128Type>(this); 316 /// Index Type. 317 impl->indexTy = TypeUniquer::get<IndexType>(this); 318 /// Integer Types. 319 impl->int1Ty = TypeUniquer::get<IntegerType>(this, 1, IntegerType::Signless); 320 impl->int8Ty = TypeUniquer::get<IntegerType>(this, 8, IntegerType::Signless); 321 impl->int16Ty = 322 TypeUniquer::get<IntegerType>(this, 16, IntegerType::Signless); 323 impl->int32Ty = 324 TypeUniquer::get<IntegerType>(this, 32, IntegerType::Signless); 325 impl->int64Ty = 326 TypeUniquer::get<IntegerType>(this, 64, IntegerType::Signless); 327 impl->int128Ty = 328 TypeUniquer::get<IntegerType>(this, 128, IntegerType::Signless); 329 /// None Type. 330 impl->noneType = TypeUniquer::get<NoneType>(this); 331 332 //// Attributes. 333 //// Note: These must be registered after the types as they may generate one 334 //// of the above types internally. 335 /// Unknown Location Attribute. 336 impl->unknownLocAttr = AttributeUniquer::get<UnknownLoc>(this); 337 /// Bool Attributes. 338 impl->falseAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, false); 339 impl->trueAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, true); 340 /// Unit Attribute. 341 impl->unitAttr = AttributeUniquer::get<UnitAttr>(this); 342 /// The empty dictionary attribute. 343 impl->emptyDictionaryAttr = DictionaryAttr::getEmptyUnchecked(this); 344 /// The empty string attribute. 345 impl->emptyStringAttr = StringAttr::getEmptyStringAttrUnchecked(this); 346 347 // Register the affine storage objects with the uniquer. 348 impl->affineUniquer 349 .registerParametricStorageType<AffineBinaryOpExprStorage>(); 350 impl->affineUniquer 351 .registerParametricStorageType<AffineConstantExprStorage>(); 352 impl->affineUniquer.registerParametricStorageType<AffineDimExprStorage>(); 353 impl->affineUniquer.registerParametricStorageType<AffineMapStorage>(); 354 impl->affineUniquer.registerParametricStorageType<IntegerSetStorage>(); 355 } 356 357 MLIRContext::~MLIRContext() = default; 358 359 /// Copy the specified array of elements into memory managed by the provided 360 /// bump pointer allocator. This assumes the elements are all PODs. 361 template <typename T> 362 static ArrayRef<T> copyArrayRefInto(llvm::BumpPtrAllocator &allocator, 363 ArrayRef<T> elements) { 364 auto result = allocator.Allocate<T>(elements.size()); 365 std::uninitialized_copy(elements.begin(), elements.end(), result); 366 return ArrayRef<T>(result, elements.size()); 367 } 368 369 //===----------------------------------------------------------------------===// 370 // Action Handling 371 //===----------------------------------------------------------------------===// 372 373 void MLIRContext::registerActionHandler(HandlerTy handler) { 374 getImpl().actionHandler = std::move(handler); 375 } 376 377 /// Dispatch the provided action to the handler if any, or just execute it. 378 void MLIRContext::executeActionInternal(function_ref<void()> actionFn, 379 const tracing::Action &action) { 380 assert(getImpl().actionHandler); 381 getImpl().actionHandler(actionFn, action); 382 } 383 384 bool MLIRContext::hasActionHandler() { return (bool)getImpl().actionHandler; } 385 386 //===----------------------------------------------------------------------===// 387 // Diagnostic Handlers 388 //===----------------------------------------------------------------------===// 389 390 /// Returns the diagnostic engine for this context. 391 DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; } 392 393 //===----------------------------------------------------------------------===// 394 // Dialect and Operation Registration 395 //===----------------------------------------------------------------------===// 396 397 void MLIRContext::appendDialectRegistry(const DialectRegistry ®istry) { 398 if (registry.isSubsetOf(impl->dialectsRegistry)) 399 return; 400 401 assert(impl->multiThreadedExecutionContext == 0 && 402 "appending to the MLIRContext dialect registry while in a " 403 "multi-threaded execution context"); 404 registry.appendTo(impl->dialectsRegistry); 405 406 // For the already loaded dialects, apply any possible extensions immediately. 407 registry.applyExtensions(this); 408 } 409 410 const DialectRegistry &MLIRContext::getDialectRegistry() { 411 return impl->dialectsRegistry; 412 } 413 414 /// Return information about all registered IR dialects. 415 std::vector<Dialect *> MLIRContext::getLoadedDialects() { 416 std::vector<Dialect *> result; 417 result.reserve(impl->loadedDialects.size()); 418 for (auto &dialect : impl->loadedDialects) 419 result.push_back(dialect.second.get()); 420 llvm::array_pod_sort(result.begin(), result.end(), 421 [](Dialect *const *lhs, Dialect *const *rhs) -> int { 422 return (*lhs)->getNamespace() < (*rhs)->getNamespace(); 423 }); 424 return result; 425 } 426 std::vector<StringRef> MLIRContext::getAvailableDialects() { 427 std::vector<StringRef> result; 428 for (auto dialect : impl->dialectsRegistry.getDialectNames()) 429 result.push_back(dialect); 430 return result; 431 } 432 433 /// Get a registered IR dialect with the given namespace. If none is found, 434 /// then return nullptr. 435 Dialect *MLIRContext::getLoadedDialect(StringRef name) { 436 // Dialects are sorted by name, so we can use binary search for lookup. 437 auto it = impl->loadedDialects.find(name); 438 return (it != impl->loadedDialects.end()) ? it->second.get() : nullptr; 439 } 440 441 Dialect *MLIRContext::getOrLoadDialect(StringRef name) { 442 Dialect *dialect = getLoadedDialect(name); 443 if (dialect) 444 return dialect; 445 DialectAllocatorFunctionRef allocator = 446 impl->dialectsRegistry.getDialectAllocator(name); 447 return allocator ? allocator(this) : nullptr; 448 } 449 450 /// Get a dialect for the provided namespace and TypeID: abort the program if a 451 /// dialect exist for this namespace with different TypeID. Returns a pointer to 452 /// the dialect owned by the context. 453 Dialect * 454 MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID, 455 function_ref<std::unique_ptr<Dialect>()> ctor) { 456 auto &impl = getImpl(); 457 // Get the correct insertion position sorted by namespace. 458 auto dialectIt = impl.loadedDialects.try_emplace(dialectNamespace, nullptr); 459 460 if (dialectIt.second) { 461 LLVM_DEBUG(llvm::dbgs() 462 << "Load new dialect in Context " << dialectNamespace << "\n"); 463 #ifndef NDEBUG 464 if (impl.multiThreadedExecutionContext != 0) 465 llvm::report_fatal_error( 466 "Loading a dialect (" + dialectNamespace + 467 ") while in a multi-threaded execution context (maybe " 468 "the PassManager): this can indicate a " 469 "missing `dependentDialects` in a pass for example."); 470 #endif // NDEBUG 471 // loadedDialects entry is initialized to nullptr, indicating that the 472 // dialect is currently being loaded. Re-lookup the address in 473 // loadedDialects because the table might have been rehashed by recursive 474 // dialect loading in ctor(). 475 std::unique_ptr<Dialect> &dialectOwned = 476 impl.loadedDialects[dialectNamespace] = ctor(); 477 Dialect *dialect = dialectOwned.get(); 478 assert(dialect && "dialect ctor failed"); 479 480 // Refresh all the identifiers dialect field, this catches cases where a 481 // dialect may be loaded after identifier prefixed with this dialect name 482 // were already created. 483 auto stringAttrsIt = impl.dialectReferencingStrAttrs.find(dialectNamespace); 484 if (stringAttrsIt != impl.dialectReferencingStrAttrs.end()) { 485 for (StringAttrStorage *storage : stringAttrsIt->second) 486 storage->referencedDialect = dialect; 487 impl.dialectReferencingStrAttrs.erase(stringAttrsIt); 488 } 489 490 // Apply any extensions to this newly loaded dialect. 491 impl.dialectsRegistry.applyExtensions(dialect); 492 return dialect; 493 } 494 495 #ifndef NDEBUG 496 if (dialectIt.first->second == nullptr) 497 llvm::report_fatal_error( 498 "Loading (and getting) a dialect (" + dialectNamespace + 499 ") while the same dialect is still loading: use loadDialect instead " 500 "of getOrLoadDialect."); 501 #endif // NDEBUG 502 503 // Abort if dialect with namespace has already been registered. 504 std::unique_ptr<Dialect> &dialect = dialectIt.first->second; 505 if (dialect->getTypeID() != dialectID) 506 llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace + 507 "' has already been registered"); 508 509 return dialect.get(); 510 } 511 512 bool MLIRContext::isDialectLoading(StringRef dialectNamespace) { 513 auto it = getImpl().loadedDialects.find(dialectNamespace); 514 // nullptr indicates that the dialect is currently being loaded. 515 return it != getImpl().loadedDialects.end() && it->second == nullptr; 516 } 517 518 DynamicDialect *MLIRContext::getOrLoadDynamicDialect( 519 StringRef dialectNamespace, function_ref<void(DynamicDialect *)> ctor) { 520 auto &impl = getImpl(); 521 // Get the correct insertion position sorted by namespace. 522 auto dialectIt = impl.loadedDialects.find(dialectNamespace); 523 524 if (dialectIt != impl.loadedDialects.end()) { 525 if (auto *dynDialect = dyn_cast<DynamicDialect>(dialectIt->second.get())) 526 return dynDialect; 527 llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace + 528 "' has already been registered"); 529 } 530 531 LLVM_DEBUG(llvm::dbgs() << "Load new dynamic dialect in Context " 532 << dialectNamespace << "\n"); 533 #ifndef NDEBUG 534 if (impl.multiThreadedExecutionContext != 0) 535 llvm::report_fatal_error( 536 "Loading a dynamic dialect (" + dialectNamespace + 537 ") while in a multi-threaded execution context (maybe " 538 "the PassManager): this can indicate a " 539 "missing `dependentDialects` in a pass for example."); 540 #endif 541 542 auto name = StringAttr::get(this, dialectNamespace); 543 auto *dialect = new DynamicDialect(name, this); 544 (void)getOrLoadDialect(name, dialect->getTypeID(), [dialect, ctor]() { 545 ctor(dialect); 546 return std::unique_ptr<DynamicDialect>(dialect); 547 }); 548 // This is the same result as `getOrLoadDialect` (if it didn't failed), 549 // since it has the same TypeID, and TypeIDs are unique. 550 return dialect; 551 } 552 553 void MLIRContext::loadAllAvailableDialects() { 554 for (StringRef name : getAvailableDialects()) 555 getOrLoadDialect(name); 556 } 557 558 llvm::hash_code MLIRContext::getRegistryHash() { 559 llvm::hash_code hash(0); 560 // Factor in number of loaded dialects, attributes, operations, types. 561 hash = llvm::hash_combine(hash, impl->loadedDialects.size()); 562 hash = llvm::hash_combine(hash, impl->registeredAttributes.size()); 563 hash = llvm::hash_combine(hash, impl->registeredOperations.size()); 564 hash = llvm::hash_combine(hash, impl->registeredTypes.size()); 565 return hash; 566 } 567 568 bool MLIRContext::allowsUnregisteredDialects() { 569 return impl->allowUnregisteredDialects; 570 } 571 572 void MLIRContext::allowUnregisteredDialects(bool allowing) { 573 assert(impl->multiThreadedExecutionContext == 0 && 574 "changing MLIRContext `allow-unregistered-dialects` configuration " 575 "while in a multi-threaded execution context"); 576 impl->allowUnregisteredDialects = allowing; 577 } 578 579 /// Return true if multi-threading is enabled by the context. 580 bool MLIRContext::isMultithreadingEnabled() { 581 return impl->threadingIsEnabled && llvm::llvm_is_multithreaded(); 582 } 583 584 /// Set the flag specifying if multi-threading is disabled by the context. 585 void MLIRContext::disableMultithreading(bool disable) { 586 // This API can be overridden by the global debugging flag 587 // --mlir-disable-threading 588 if (isThreadingGloballyDisabled()) 589 return; 590 assert(impl->multiThreadedExecutionContext == 0 && 591 "changing MLIRContext `disable-threading` configuration while " 592 "in a multi-threaded execution context"); 593 594 impl->threadingIsEnabled = !disable; 595 596 // Update the threading mode for each of the uniquers. 597 impl->affineUniquer.disableMultithreading(disable); 598 impl->attributeUniquer.disableMultithreading(disable); 599 impl->typeUniquer.disableMultithreading(disable); 600 601 // Destroy thread pool (stop all threads) if it is no longer needed, or create 602 // a new one if multithreading was re-enabled. 603 if (disable) { 604 // If the thread pool is owned, explicitly set it to nullptr to avoid 605 // keeping a dangling pointer around. If the thread pool is externally 606 // owned, we don't do anything. 607 if (impl->ownedThreadPool) { 608 assert(impl->threadPool); 609 impl->threadPool = nullptr; 610 impl->ownedThreadPool.reset(); 611 } 612 } else if (!impl->threadPool) { 613 // The thread pool isn't externally provided. 614 assert(!impl->ownedThreadPool); 615 impl->ownedThreadPool = std::make_unique<llvm::DefaultThreadPool>(); 616 impl->threadPool = impl->ownedThreadPool.get(); 617 } 618 } 619 620 void MLIRContext::setThreadPool(llvm::ThreadPoolInterface &pool) { 621 assert(!isMultithreadingEnabled() && 622 "expected multi-threading to be disabled when setting a ThreadPool"); 623 impl->threadPool = &pool; 624 impl->ownedThreadPool.reset(); 625 enableMultithreading(); 626 } 627 628 unsigned MLIRContext::getNumThreads() { 629 if (isMultithreadingEnabled()) { 630 assert(impl->threadPool && 631 "multi-threading is enabled but threadpool not set"); 632 return impl->threadPool->getMaxConcurrency(); 633 } 634 // No multithreading or active thread pool. Return 1 thread. 635 return 1; 636 } 637 638 llvm::ThreadPoolInterface &MLIRContext::getThreadPool() { 639 assert(isMultithreadingEnabled() && 640 "expected multi-threading to be enabled within the context"); 641 assert(impl->threadPool && 642 "multi-threading is enabled but threadpool not set"); 643 return *impl->threadPool; 644 } 645 646 void MLIRContext::enterMultiThreadedExecution() { 647 #ifndef NDEBUG 648 ++impl->multiThreadedExecutionContext; 649 #endif 650 } 651 void MLIRContext::exitMultiThreadedExecution() { 652 #ifndef NDEBUG 653 --impl->multiThreadedExecutionContext; 654 #endif 655 } 656 657 /// Return true if we should attach the operation to diagnostics emitted via 658 /// Operation::emit. 659 bool MLIRContext::shouldPrintOpOnDiagnostic() { 660 return impl->printOpOnDiagnostic; 661 } 662 663 /// Set the flag specifying if we should attach the operation to diagnostics 664 /// emitted via Operation::emit. 665 void MLIRContext::printOpOnDiagnostic(bool enable) { 666 assert(impl->multiThreadedExecutionContext == 0 && 667 "changing MLIRContext `print-op-on-diagnostic` configuration while in " 668 "a multi-threaded execution context"); 669 impl->printOpOnDiagnostic = enable; 670 } 671 672 /// Return true if we should attach the current stacktrace to diagnostics when 673 /// emitted. 674 bool MLIRContext::shouldPrintStackTraceOnDiagnostic() { 675 return impl->printStackTraceOnDiagnostic; 676 } 677 678 /// Set the flag specifying if we should attach the current stacktrace when 679 /// emitting diagnostics. 680 void MLIRContext::printStackTraceOnDiagnostic(bool enable) { 681 assert(impl->multiThreadedExecutionContext == 0 && 682 "changing MLIRContext `print-stacktrace-on-diagnostic` configuration " 683 "while in a multi-threaded execution context"); 684 impl->printStackTraceOnDiagnostic = enable; 685 } 686 687 /// Return information about all registered operations. 688 ArrayRef<RegisteredOperationName> MLIRContext::getRegisteredOperations() { 689 return impl->sortedRegisteredOperations; 690 } 691 692 /// Return information for registered operations by dialect. 693 ArrayRef<RegisteredOperationName> 694 MLIRContext::getRegisteredOperationsByDialect(StringRef dialectName) { 695 auto lowerBound = 696 std::lower_bound(impl->sortedRegisteredOperations.begin(), 697 impl->sortedRegisteredOperations.end(), dialectName, 698 [](auto &lhs, auto &rhs) { 699 return lhs.getDialect().getNamespace().compare(rhs); 700 }); 701 702 if (lowerBound == impl->sortedRegisteredOperations.end() || 703 lowerBound->getDialect().getNamespace() != dialectName) 704 return ArrayRef<RegisteredOperationName>(); 705 706 auto upperBound = 707 std::upper_bound(lowerBound, impl->sortedRegisteredOperations.end(), 708 dialectName, [](auto &lhs, auto &rhs) { 709 return lhs.compare(rhs.getDialect().getNamespace()); 710 }); 711 712 size_t count = std::distance(lowerBound, upperBound); 713 return ArrayRef(&*lowerBound, count); 714 } 715 716 bool MLIRContext::isOperationRegistered(StringRef name) { 717 return RegisteredOperationName::lookup(name, this).has_value(); 718 } 719 720 void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) { 721 auto &impl = context->getImpl(); 722 assert(impl.multiThreadedExecutionContext == 0 && 723 "Registering a new type kind while in a multi-threaded execution " 724 "context"); 725 auto *newInfo = 726 new (impl.abstractDialectSymbolAllocator.Allocate<AbstractType>()) 727 AbstractType(std::move(typeInfo)); 728 if (!impl.registeredTypes.insert({typeID, newInfo}).second) 729 llvm::report_fatal_error("Dialect Type already registered."); 730 if (!impl.nameToType.insert({newInfo->getName(), newInfo}).second) 731 llvm::report_fatal_error("Dialect Type with name " + newInfo->getName() + 732 " is already registered."); 733 } 734 735 void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) { 736 auto &impl = context->getImpl(); 737 assert(impl.multiThreadedExecutionContext == 0 && 738 "Registering a new attribute kind while in a multi-threaded execution " 739 "context"); 740 auto *newInfo = 741 new (impl.abstractDialectSymbolAllocator.Allocate<AbstractAttribute>()) 742 AbstractAttribute(std::move(attrInfo)); 743 if (!impl.registeredAttributes.insert({typeID, newInfo}).second) 744 llvm::report_fatal_error("Dialect Attribute already registered."); 745 if (!impl.nameToAttribute.insert({newInfo->getName(), newInfo}).second) 746 llvm::report_fatal_error("Dialect Attribute with name " + 747 newInfo->getName() + " is already registered."); 748 } 749 750 //===----------------------------------------------------------------------===// 751 // AbstractAttribute 752 //===----------------------------------------------------------------------===// 753 754 /// Get the dialect that registered the attribute with the provided typeid. 755 const AbstractAttribute &AbstractAttribute::lookup(TypeID typeID, 756 MLIRContext *context) { 757 const AbstractAttribute *abstract = lookupMutable(typeID, context); 758 if (!abstract) 759 llvm::report_fatal_error("Trying to create an Attribute that was not " 760 "registered in this MLIRContext."); 761 return *abstract; 762 } 763 764 AbstractAttribute *AbstractAttribute::lookupMutable(TypeID typeID, 765 MLIRContext *context) { 766 auto &impl = context->getImpl(); 767 return impl.registeredAttributes.lookup(typeID); 768 } 769 770 std::optional<std::reference_wrapper<const AbstractAttribute>> 771 AbstractAttribute::lookup(StringRef name, MLIRContext *context) { 772 MLIRContextImpl &impl = context->getImpl(); 773 const AbstractAttribute *type = impl.nameToAttribute.lookup(name); 774 775 if (!type) 776 return std::nullopt; 777 return {*type}; 778 } 779 780 //===----------------------------------------------------------------------===// 781 // OperationName 782 //===----------------------------------------------------------------------===// 783 784 OperationName::Impl::Impl(StringRef name, Dialect *dialect, TypeID typeID, 785 detail::InterfaceMap interfaceMap) 786 : Impl(StringAttr::get(dialect->getContext(), name), dialect, typeID, 787 std::move(interfaceMap)) {} 788 789 OperationName::OperationName(StringRef name, MLIRContext *context) { 790 MLIRContextImpl &ctxImpl = context->getImpl(); 791 792 // Check for an existing name in read-only mode. 793 bool isMultithreadingEnabled = context->isMultithreadingEnabled(); 794 if (isMultithreadingEnabled) { 795 // Check the registered info map first. In the overwhelmingly common case, 796 // the entry will be in here and it also removes the need to acquire any 797 // locks. 798 auto registeredIt = ctxImpl.registeredOperationsByName.find(name); 799 if (LLVM_LIKELY(registeredIt != ctxImpl.registeredOperationsByName.end())) { 800 impl = registeredIt->second.impl; 801 return; 802 } 803 804 llvm::sys::SmartScopedReader<true> contextLock(ctxImpl.operationInfoMutex); 805 auto it = ctxImpl.operations.find(name); 806 if (it != ctxImpl.operations.end()) { 807 impl = it->second.get(); 808 return; 809 } 810 } 811 812 // Acquire a writer-lock so that we can safely create the new instance. 813 ScopedWriterLock lock(ctxImpl.operationInfoMutex, isMultithreadingEnabled); 814 815 auto it = ctxImpl.operations.insert({name, nullptr}); 816 if (it.second) { 817 auto nameAttr = StringAttr::get(context, name); 818 it.first->second = std::make_unique<UnregisteredOpModel>( 819 nameAttr, nameAttr.getReferencedDialect(), TypeID::get<void>(), 820 detail::InterfaceMap()); 821 } 822 impl = it.first->second.get(); 823 } 824 825 StringRef OperationName::getDialectNamespace() const { 826 if (Dialect *dialect = getDialect()) 827 return dialect->getNamespace(); 828 return getStringRef().split('.').first; 829 } 830 831 LogicalResult 832 OperationName::UnregisteredOpModel::foldHook(Operation *, ArrayRef<Attribute>, 833 SmallVectorImpl<OpFoldResult> &) { 834 return failure(); 835 } 836 void OperationName::UnregisteredOpModel::getCanonicalizationPatterns( 837 RewritePatternSet &, MLIRContext *) {} 838 bool OperationName::UnregisteredOpModel::hasTrait(TypeID) { return false; } 839 840 OperationName::ParseAssemblyFn 841 OperationName::UnregisteredOpModel::getParseAssemblyFn() { 842 llvm::report_fatal_error("getParseAssemblyFn hook called on unregistered op"); 843 } 844 void OperationName::UnregisteredOpModel::populateDefaultAttrs( 845 const OperationName &, NamedAttrList &) {} 846 void OperationName::UnregisteredOpModel::printAssembly( 847 Operation *op, OpAsmPrinter &p, StringRef defaultDialect) { 848 p.printGenericOp(op); 849 } 850 LogicalResult 851 OperationName::UnregisteredOpModel::verifyInvariants(Operation *) { 852 return success(); 853 } 854 LogicalResult 855 OperationName::UnregisteredOpModel::verifyRegionInvariants(Operation *) { 856 return success(); 857 } 858 859 std::optional<Attribute> 860 OperationName::UnregisteredOpModel::getInherentAttr(Operation *op, 861 StringRef name) { 862 auto dict = dyn_cast_or_null<DictionaryAttr>(getPropertiesAsAttr(op)); 863 if (!dict) 864 return std::nullopt; 865 if (Attribute attr = dict.get(name)) 866 return attr; 867 return std::nullopt; 868 } 869 void OperationName::UnregisteredOpModel::setInherentAttr(Operation *op, 870 StringAttr name, 871 Attribute value) { 872 auto dict = dyn_cast_or_null<DictionaryAttr>(getPropertiesAsAttr(op)); 873 assert(dict); 874 NamedAttrList attrs(dict); 875 attrs.set(name, value); 876 *op->getPropertiesStorage().as<Attribute *>() = 877 attrs.getDictionary(op->getContext()); 878 } 879 void OperationName::UnregisteredOpModel::populateInherentAttrs( 880 Operation *op, NamedAttrList &attrs) {} 881 LogicalResult OperationName::UnregisteredOpModel::verifyInherentAttrs( 882 OperationName opName, NamedAttrList &attributes, 883 function_ref<InFlightDiagnostic()> emitError) { 884 return success(); 885 } 886 int OperationName::UnregisteredOpModel::getOpPropertyByteSize() { 887 return sizeof(Attribute); 888 } 889 void OperationName::UnregisteredOpModel::initProperties( 890 OperationName opName, OpaqueProperties storage, OpaqueProperties init) { 891 new (storage.as<Attribute *>()) Attribute(); 892 } 893 void OperationName::UnregisteredOpModel::deleteProperties( 894 OpaqueProperties prop) { 895 prop.as<Attribute *>()->~Attribute(); 896 } 897 void OperationName::UnregisteredOpModel::populateDefaultProperties( 898 OperationName opName, OpaqueProperties properties) {} 899 LogicalResult OperationName::UnregisteredOpModel::setPropertiesFromAttr( 900 OperationName opName, OpaqueProperties properties, Attribute attr, 901 function_ref<InFlightDiagnostic()> emitError) { 902 *properties.as<Attribute *>() = attr; 903 return success(); 904 } 905 Attribute 906 OperationName::UnregisteredOpModel::getPropertiesAsAttr(Operation *op) { 907 return *op->getPropertiesStorage().as<Attribute *>(); 908 } 909 void OperationName::UnregisteredOpModel::copyProperties(OpaqueProperties lhs, 910 OpaqueProperties rhs) { 911 *lhs.as<Attribute *>() = *rhs.as<Attribute *>(); 912 } 913 bool OperationName::UnregisteredOpModel::compareProperties( 914 OpaqueProperties lhs, OpaqueProperties rhs) { 915 return *lhs.as<Attribute *>() == *rhs.as<Attribute *>(); 916 } 917 llvm::hash_code 918 OperationName::UnregisteredOpModel::hashProperties(OpaqueProperties prop) { 919 return llvm::hash_combine(*prop.as<Attribute *>()); 920 } 921 922 //===----------------------------------------------------------------------===// 923 // RegisteredOperationName 924 //===----------------------------------------------------------------------===// 925 926 std::optional<RegisteredOperationName> 927 RegisteredOperationName::lookup(TypeID typeID, MLIRContext *ctx) { 928 auto &impl = ctx->getImpl(); 929 auto it = impl.registeredOperations.find(typeID); 930 if (it != impl.registeredOperations.end()) 931 return it->second; 932 return std::nullopt; 933 } 934 935 std::optional<RegisteredOperationName> 936 RegisteredOperationName::lookup(StringRef name, MLIRContext *ctx) { 937 auto &impl = ctx->getImpl(); 938 auto it = impl.registeredOperationsByName.find(name); 939 if (it != impl.registeredOperationsByName.end()) 940 return it->getValue(); 941 return std::nullopt; 942 } 943 944 void RegisteredOperationName::insert( 945 std::unique_ptr<RegisteredOperationName::Impl> ownedImpl, 946 ArrayRef<StringRef> attrNames) { 947 RegisteredOperationName::Impl *impl = ownedImpl.get(); 948 MLIRContext *ctx = impl->getDialect()->getContext(); 949 auto &ctxImpl = ctx->getImpl(); 950 assert(ctxImpl.multiThreadedExecutionContext == 0 && 951 "registering a new operation kind while in a multi-threaded execution " 952 "context"); 953 954 // Register the attribute names of this operation. 955 MutableArrayRef<StringAttr> cachedAttrNames; 956 if (!attrNames.empty()) { 957 cachedAttrNames = MutableArrayRef<StringAttr>( 958 ctxImpl.abstractDialectSymbolAllocator.Allocate<StringAttr>( 959 attrNames.size()), 960 attrNames.size()); 961 for (unsigned i : llvm::seq<unsigned>(0, attrNames.size())) 962 new (&cachedAttrNames[i]) StringAttr(StringAttr::get(ctx, attrNames[i])); 963 impl->attributeNames = cachedAttrNames; 964 } 965 StringRef name = impl->getName().strref(); 966 // Insert the operation info if it doesn't exist yet. 967 ctxImpl.operations[name] = std::move(ownedImpl); 968 969 // Update the registered info for this operation. 970 auto emplaced = ctxImpl.registeredOperations.try_emplace( 971 impl->getTypeID(), RegisteredOperationName(impl)); 972 assert(emplaced.second && "operation name registration must be successful"); 973 auto emplacedByName = ctxImpl.registeredOperationsByName.try_emplace( 974 name, RegisteredOperationName(impl)); 975 (void)emplacedByName; 976 assert(emplacedByName.second && 977 "operation name registration must be successful"); 978 979 // Add emplaced operation name to the sorted operations container. 980 RegisteredOperationName &value = emplaced.first->second; 981 ctxImpl.sortedRegisteredOperations.insert( 982 llvm::upper_bound(ctxImpl.sortedRegisteredOperations, value, 983 [](auto &lhs, auto &rhs) { 984 return lhs.getIdentifier().compare( 985 rhs.getIdentifier()); 986 }), 987 value); 988 } 989 990 //===----------------------------------------------------------------------===// 991 // AbstractType 992 //===----------------------------------------------------------------------===// 993 994 const AbstractType &AbstractType::lookup(TypeID typeID, MLIRContext *context) { 995 const AbstractType *type = lookupMutable(typeID, context); 996 if (!type) 997 llvm::report_fatal_error( 998 "Trying to create a Type that was not registered in this MLIRContext."); 999 return *type; 1000 } 1001 1002 AbstractType *AbstractType::lookupMutable(TypeID typeID, MLIRContext *context) { 1003 auto &impl = context->getImpl(); 1004 return impl.registeredTypes.lookup(typeID); 1005 } 1006 1007 std::optional<std::reference_wrapper<const AbstractType>> 1008 AbstractType::lookup(StringRef name, MLIRContext *context) { 1009 MLIRContextImpl &impl = context->getImpl(); 1010 const AbstractType *type = impl.nameToType.lookup(name); 1011 1012 if (!type) 1013 return std::nullopt; 1014 return {*type}; 1015 } 1016 1017 //===----------------------------------------------------------------------===// 1018 // Type uniquing 1019 //===----------------------------------------------------------------------===// 1020 1021 /// Returns the storage uniquer used for constructing type storage instances. 1022 /// This should not be used directly. 1023 StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; } 1024 1025 BFloat16Type BFloat16Type::get(MLIRContext *context) { 1026 return context->getImpl().bf16Ty; 1027 } 1028 Float16Type Float16Type::get(MLIRContext *context) { 1029 return context->getImpl().f16Ty; 1030 } 1031 FloatTF32Type FloatTF32Type::get(MLIRContext *context) { 1032 return context->getImpl().tf32Ty; 1033 } 1034 Float32Type Float32Type::get(MLIRContext *context) { 1035 return context->getImpl().f32Ty; 1036 } 1037 Float64Type Float64Type::get(MLIRContext *context) { 1038 return context->getImpl().f64Ty; 1039 } 1040 Float80Type Float80Type::get(MLIRContext *context) { 1041 return context->getImpl().f80Ty; 1042 } 1043 Float128Type Float128Type::get(MLIRContext *context) { 1044 return context->getImpl().f128Ty; 1045 } 1046 1047 /// Get an instance of the IndexType. 1048 IndexType IndexType::get(MLIRContext *context) { 1049 return context->getImpl().indexTy; 1050 } 1051 1052 /// Return an existing integer type instance if one is cached within the 1053 /// context. 1054 static IntegerType 1055 getCachedIntegerType(unsigned width, 1056 IntegerType::SignednessSemantics signedness, 1057 MLIRContext *context) { 1058 if (signedness != IntegerType::Signless) 1059 return IntegerType(); 1060 1061 switch (width) { 1062 case 1: 1063 return context->getImpl().int1Ty; 1064 case 8: 1065 return context->getImpl().int8Ty; 1066 case 16: 1067 return context->getImpl().int16Ty; 1068 case 32: 1069 return context->getImpl().int32Ty; 1070 case 64: 1071 return context->getImpl().int64Ty; 1072 case 128: 1073 return context->getImpl().int128Ty; 1074 default: 1075 return IntegerType(); 1076 } 1077 } 1078 1079 IntegerType IntegerType::get(MLIRContext *context, unsigned width, 1080 IntegerType::SignednessSemantics signedness) { 1081 if (auto cached = getCachedIntegerType(width, signedness, context)) 1082 return cached; 1083 return Base::get(context, width, signedness); 1084 } 1085 1086 IntegerType 1087 IntegerType::getChecked(function_ref<InFlightDiagnostic()> emitError, 1088 MLIRContext *context, unsigned width, 1089 SignednessSemantics signedness) { 1090 if (auto cached = getCachedIntegerType(width, signedness, context)) 1091 return cached; 1092 return Base::getChecked(emitError, context, width, signedness); 1093 } 1094 1095 /// Get an instance of the NoneType. 1096 NoneType NoneType::get(MLIRContext *context) { 1097 if (NoneType cachedInst = context->getImpl().noneType) 1098 return cachedInst; 1099 // Note: May happen when initializing the singleton attributes of the builtin 1100 // dialect. 1101 return Base::get(context); 1102 } 1103 1104 //===----------------------------------------------------------------------===// 1105 // Attribute uniquing 1106 //===----------------------------------------------------------------------===// 1107 1108 /// Returns the storage uniquer used for constructing attribute storage 1109 /// instances. This should not be used directly. 1110 StorageUniquer &MLIRContext::getAttributeUniquer() { 1111 return getImpl().attributeUniquer; 1112 } 1113 1114 /// Initialize the given attribute storage instance. 1115 void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage, 1116 MLIRContext *ctx, 1117 TypeID attrID) { 1118 storage->initializeAbstractAttribute(AbstractAttribute::lookup(attrID, ctx)); 1119 } 1120 1121 BoolAttr BoolAttr::get(MLIRContext *context, bool value) { 1122 return value ? context->getImpl().trueAttr : context->getImpl().falseAttr; 1123 } 1124 1125 UnitAttr UnitAttr::get(MLIRContext *context) { 1126 return context->getImpl().unitAttr; 1127 } 1128 1129 UnknownLoc UnknownLoc::get(MLIRContext *context) { 1130 return context->getImpl().unknownLocAttr; 1131 } 1132 1133 DistinctAttrStorage * 1134 detail::DistinctAttributeUniquer::allocateStorage(MLIRContext *context, 1135 Attribute referencedAttr) { 1136 return context->getImpl().distinctAttributeAllocator.allocate(referencedAttr); 1137 } 1138 1139 /// Return empty dictionary. 1140 DictionaryAttr DictionaryAttr::getEmpty(MLIRContext *context) { 1141 return context->getImpl().emptyDictionaryAttr; 1142 } 1143 1144 void StringAttrStorage::initialize(MLIRContext *context) { 1145 // Check for a dialect namespace prefix, if there isn't one we don't need to 1146 // do any additional initialization. 1147 auto dialectNamePair = value.split('.'); 1148 if (dialectNamePair.first.empty() || dialectNamePair.second.empty()) 1149 return; 1150 1151 // If one exists, we check to see if this dialect is loaded. If it is, we set 1152 // the dialect now, if it isn't we record this storage for initialization 1153 // later if the dialect ever gets loaded. 1154 if ((referencedDialect = context->getLoadedDialect(dialectNamePair.first))) 1155 return; 1156 1157 MLIRContextImpl &impl = context->getImpl(); 1158 llvm::sys::SmartScopedLock<true> lock(impl.dialectRefStrAttrMutex); 1159 impl.dialectReferencingStrAttrs[dialectNamePair.first].push_back(this); 1160 } 1161 1162 /// Return an empty string. 1163 StringAttr StringAttr::get(MLIRContext *context) { 1164 return context->getImpl().emptyStringAttr; 1165 } 1166 1167 //===----------------------------------------------------------------------===// 1168 // AffineMap uniquing 1169 //===----------------------------------------------------------------------===// 1170 1171 StorageUniquer &MLIRContext::getAffineUniquer() { 1172 return getImpl().affineUniquer; 1173 } 1174 1175 AffineMap AffineMap::getImpl(unsigned dimCount, unsigned symbolCount, 1176 ArrayRef<AffineExpr> results, 1177 MLIRContext *context) { 1178 auto &impl = context->getImpl(); 1179 auto *storage = impl.affineUniquer.get<AffineMapStorage>( 1180 [&](AffineMapStorage *storage) { storage->context = context; }, dimCount, 1181 symbolCount, results); 1182 return AffineMap(storage); 1183 } 1184 1185 /// Check whether the arguments passed to the AffineMap::get() are consistent. 1186 /// This method checks whether the highest index of dimensional identifier 1187 /// present in result expressions is less than `dimCount` and the highest index 1188 /// of symbolic identifier present in result expressions is less than 1189 /// `symbolCount`. 1190 LLVM_ATTRIBUTE_UNUSED static bool 1191 willBeValidAffineMap(unsigned dimCount, unsigned symbolCount, 1192 ArrayRef<AffineExpr> results) { 1193 int64_t maxDimPosition = -1; 1194 int64_t maxSymbolPosition = -1; 1195 getMaxDimAndSymbol(ArrayRef<ArrayRef<AffineExpr>>(results), maxDimPosition, 1196 maxSymbolPosition); 1197 if ((maxDimPosition >= dimCount) || (maxSymbolPosition >= symbolCount)) { 1198 LLVM_DEBUG( 1199 llvm::dbgs() 1200 << "maximum dimensional identifier position in result expression must " 1201 "be less than `dimCount` and maximum symbolic identifier position " 1202 "in result expression must be less than `symbolCount`\n"); 1203 return false; 1204 } 1205 return true; 1206 } 1207 1208 AffineMap AffineMap::get(MLIRContext *context) { 1209 return getImpl(/*dimCount=*/0, /*symbolCount=*/0, /*results=*/{}, context); 1210 } 1211 1212 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount, 1213 MLIRContext *context) { 1214 return getImpl(dimCount, symbolCount, /*results=*/{}, context); 1215 } 1216 1217 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount, 1218 AffineExpr result) { 1219 assert(willBeValidAffineMap(dimCount, symbolCount, {result})); 1220 return getImpl(dimCount, symbolCount, {result}, result.getContext()); 1221 } 1222 1223 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount, 1224 ArrayRef<AffineExpr> results, MLIRContext *context) { 1225 assert(willBeValidAffineMap(dimCount, symbolCount, results)); 1226 return getImpl(dimCount, symbolCount, results, context); 1227 } 1228 1229 //===----------------------------------------------------------------------===// 1230 // Integer Sets: these are allocated into the bump pointer, and are immutable. 1231 // Unlike AffineMap's, these are uniqued only if they are small. 1232 //===----------------------------------------------------------------------===// 1233 1234 IntegerSet IntegerSet::get(unsigned dimCount, unsigned symbolCount, 1235 ArrayRef<AffineExpr> constraints, 1236 ArrayRef<bool> eqFlags) { 1237 // The number of constraints can't be zero. 1238 assert(!constraints.empty()); 1239 assert(constraints.size() == eqFlags.size()); 1240 1241 auto &impl = constraints[0].getContext()->getImpl(); 1242 auto *storage = impl.affineUniquer.get<IntegerSetStorage>( 1243 [](IntegerSetStorage *) {}, dimCount, symbolCount, constraints, eqFlags); 1244 return IntegerSet(storage); 1245 } 1246 1247 //===----------------------------------------------------------------------===// 1248 // StorageUniquerSupport 1249 //===----------------------------------------------------------------------===// 1250 1251 /// Utility method to generate a callback that can be used to generate a 1252 /// diagnostic when checking the construction invariants of a storage object. 1253 /// This is defined out-of-line to avoid the need to include Location.h. 1254 llvm::unique_function<InFlightDiagnostic()> 1255 mlir::detail::getDefaultDiagnosticEmitFn(MLIRContext *ctx) { 1256 return [ctx] { return emitError(UnknownLoc::get(ctx)); }; 1257 } 1258 llvm::unique_function<InFlightDiagnostic()> 1259 mlir::detail::getDefaultDiagnosticEmitFn(const Location &loc) { 1260 return [=] { return emitError(loc); }; 1261 } 1262