1d064c480SAlex Zinenko //===- TransformDialect.cpp - Transform Dialect Definition ----------------===// 2d064c480SAlex Zinenko // 3d064c480SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4d064c480SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information. 5d064c480SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6d064c480SAlex Zinenko // 7d064c480SAlex Zinenko //===----------------------------------------------------------------------===// 8d064c480SAlex Zinenko 9d064c480SAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformDialect.h" 10fb409a28SAlex Zinenko #include "mlir/Analysis/CallGraph.h" 110eb403adSAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformOps.h" 12bba85ebdSAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformTypes.h" 1399c15eb4SIngo Müller #include "mlir/Dialect/Transform/IR/Utils.h" 14*5a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" 15bba85ebdSAlex Zinenko #include "mlir/IR/DialectImplementation.h" 16fb409a28SAlex Zinenko #include "llvm/ADT/SCCIterator.h" 17d064c480SAlex Zinenko 18d064c480SAlex Zinenko using namespace mlir; 19d064c480SAlex Zinenko 200eb403adSAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc" 210eb403adSAlex Zinenko 22bba85ebdSAlex Zinenko #ifndef NDEBUG 23b586d56cSAlex Zinenko void transform::detail::checkImplementsTransformOpInterface( 24b586d56cSAlex Zinenko StringRef name, MLIRContext *context) { 25b586d56cSAlex Zinenko // Since the operation is being inserted into the Transform dialect and the 26b586d56cSAlex Zinenko // dialect does not implement the interface fallback, only check for the op 27b586d56cSAlex Zinenko // itself having the interface implementation. 28b586d56cSAlex Zinenko RegisteredOperationName opName = 29b586d56cSAlex Zinenko *RegisteredOperationName::lookup(name, context); 30b586d56cSAlex Zinenko assert((opName.hasInterface<TransformOpInterface>() || 315a10f207SMatthias Springer opName.hasInterface<PatternDescriptorOpInterface>() || 32bcfdb3e4SMatthias Springer opName.hasInterface<ConversionPatternDescriptorOpInterface>() || 33bcfdb3e4SMatthias Springer opName.hasInterface<TypeConverterBuilderOpInterface>() || 34b586d56cSAlex Zinenko opName.hasTrait<OpTrait::IsTerminator>()) && 35b586d56cSAlex Zinenko "non-terminator ops injected into the transform dialect must " 36bcfdb3e4SMatthias Springer "implement TransformOpInterface or PatternDescriptorOpInterface or " 37bcfdb3e4SMatthias Springer "ConversionPatternDescriptorOpInterface"); 38bcfdb3e4SMatthias Springer if (!opName.hasInterface<PatternDescriptorOpInterface>() && 39bcfdb3e4SMatthias Springer !opName.hasInterface<ConversionPatternDescriptorOpInterface>() && 40bcfdb3e4SMatthias Springer !opName.hasInterface<TypeConverterBuilderOpInterface>()) { 41b586d56cSAlex Zinenko assert(opName.hasInterface<MemoryEffectOpInterface>() && 42b586d56cSAlex Zinenko "ops injected into the transform dialect must implement " 43b586d56cSAlex Zinenko "MemoryEffectsOpInterface"); 44b586d56cSAlex Zinenko } 455a10f207SMatthias Springer } 46b586d56cSAlex Zinenko 4797c05062SAlex Zinenko void transform::detail::checkImplementsTransformHandleTypeInterface( 48bba85ebdSAlex Zinenko TypeID typeID, MLIRContext *context) { 49bba85ebdSAlex Zinenko const auto &abstractType = AbstractType::lookup(typeID, context); 50a7026288SAlex Zinenko assert((abstractType.hasInterface( 5197c05062SAlex Zinenko TransformHandleTypeInterface::getInterfaceID()) || 52ed02fa81SAlex Zinenko abstractType.hasInterface( 53a7026288SAlex Zinenko TransformParamTypeInterface::getInterfaceID()) || 54a7026288SAlex Zinenko abstractType.hasInterface( 55a7026288SAlex Zinenko TransformValueHandleTypeInterface::getInterfaceID())) && 56a7026288SAlex Zinenko "expected Transform dialect type to implement one of the three " 57a7026288SAlex Zinenko "interfaces"); 58bba85ebdSAlex Zinenko } 59bba85ebdSAlex Zinenko #endif // NDEBUG 60bba85ebdSAlex Zinenko 610eb403adSAlex Zinenko void transform::TransformDialect::initialize() { 62bba85ebdSAlex Zinenko // Using the checked versions to enable the same assertions as for the ops 63bba85ebdSAlex Zinenko // from extensions. 6440a8bd63SAlex Zinenko addOperationsChecked< 650eb403adSAlex Zinenko #define GET_OP_LIST 660eb403adSAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" 670eb403adSAlex Zinenko >(); 683e1f6d02SAlex Zinenko initializeTypes(); 6999c15eb4SIngo Müller initializeLibraryModule(); 7030f22429SAlex Zinenko } 71a60ed954SAlex Zinenko 72bba85ebdSAlex Zinenko Type transform::TransformDialect::parseType(DialectAsmParser &parser) const { 73bba85ebdSAlex Zinenko StringRef keyword; 74bba85ebdSAlex Zinenko SMLoc loc = parser.getCurrentLocation(); 75bba85ebdSAlex Zinenko if (failed(parser.parseKeyword(&keyword))) 76bba85ebdSAlex Zinenko return nullptr; 77bba85ebdSAlex Zinenko 78bba85ebdSAlex Zinenko auto it = typeParsingHooks.find(keyword); 79bba85ebdSAlex Zinenko if (it == typeParsingHooks.end()) { 80bba85ebdSAlex Zinenko parser.emitError(loc) << "unknown type mnemonic: " << keyword; 81bba85ebdSAlex Zinenko return nullptr; 82bba85ebdSAlex Zinenko } 83bba85ebdSAlex Zinenko 84bba85ebdSAlex Zinenko return it->getValue()(parser); 85bba85ebdSAlex Zinenko } 86bba85ebdSAlex Zinenko 87bba85ebdSAlex Zinenko void transform::TransformDialect::printType(Type type, 88bba85ebdSAlex Zinenko DialectAsmPrinter &printer) const { 89bba85ebdSAlex Zinenko auto it = typePrintingHooks.find(type.getTypeID()); 90bba85ebdSAlex Zinenko assert(it != typePrintingHooks.end() && "printing unknown type"); 91bba85ebdSAlex Zinenko it->getSecond()(type, printer); 92bba85ebdSAlex Zinenko } 93bba85ebdSAlex Zinenko 9499c15eb4SIngo Müller LogicalResult transform::TransformDialect::loadIntoLibraryModule( 9599c15eb4SIngo Müller ::mlir::OwningOpRef<::mlir::ModuleOp> &&library) { 9699c15eb4SIngo Müller return detail::mergeSymbolsInto(getLibraryModule(), std::move(library)); 9799c15eb4SIngo Müller } 9899c15eb4SIngo Müller 9999c15eb4SIngo Müller void transform::TransformDialect::initializeLibraryModule() { 10099c15eb4SIngo Müller MLIRContext *context = getContext(); 10199c15eb4SIngo Müller auto loc = 10299c15eb4SIngo Müller FileLineColLoc::get(context, "<transform-dialect-library-module>", 0, 0); 10399c15eb4SIngo Müller libraryModule = ModuleOp::create(loc, "__transform_library"); 10499c15eb4SIngo Müller libraryModule.get()->setAttr(TransformDialect::kWithNamedSequenceAttrName, 10599c15eb4SIngo Müller UnitAttr::get(context)); 10699c15eb4SIngo Müller } 10799c15eb4SIngo Müller 108bba85ebdSAlex Zinenko void transform::TransformDialect::reportDuplicateTypeRegistration( 109bba85ebdSAlex Zinenko StringRef mnemonic) { 110bba85ebdSAlex Zinenko std::string buffer; 111bba85ebdSAlex Zinenko llvm::raw_string_ostream msg(buffer); 112b586d56cSAlex Zinenko msg << "extensible dialect type '" << mnemonic 113bba85ebdSAlex Zinenko << "' is already registered with a different implementation"; 114bba85ebdSAlex Zinenko llvm::report_fatal_error(StringRef(buffer)); 115bba85ebdSAlex Zinenko } 116bba85ebdSAlex Zinenko 117b586d56cSAlex Zinenko void transform::TransformDialect::reportDuplicateOpRegistration( 118b586d56cSAlex Zinenko StringRef opName) { 119b586d56cSAlex Zinenko std::string buffer; 120b586d56cSAlex Zinenko llvm::raw_string_ostream msg(buffer); 121b586d56cSAlex Zinenko msg << "extensible dialect operation '" << opName 122b586d56cSAlex Zinenko << "' is already registered with a mismatching TypeID"; 123b586d56cSAlex Zinenko llvm::report_fatal_error(StringRef(buffer)); 124b586d56cSAlex Zinenko } 125b586d56cSAlex Zinenko 126fb409a28SAlex Zinenko LogicalResult transform::TransformDialect::verifyOperationAttribute( 127fb409a28SAlex Zinenko Operation *op, NamedAttribute attribute) { 128fb409a28SAlex Zinenko if (attribute.getName().getValue() == kWithNamedSequenceAttrName) { 129fb409a28SAlex Zinenko if (!op->hasTrait<OpTrait::SymbolTable>()) { 130fb409a28SAlex Zinenko return emitError(op->getLoc()) << attribute.getName() 131fb409a28SAlex Zinenko << " attribute can only be attached to " 132fb409a28SAlex Zinenko "operations with symbol tables"; 133fb409a28SAlex Zinenko } 134fb409a28SAlex Zinenko 135fb409a28SAlex Zinenko const mlir::CallGraph callgraph(op); 136fb409a28SAlex Zinenko for (auto scc = llvm::scc_begin(&callgraph); !scc.isAtEnd(); ++scc) { 137fb409a28SAlex Zinenko if (!scc.hasCycle()) 138fb409a28SAlex Zinenko continue; 139fb409a28SAlex Zinenko 140fb409a28SAlex Zinenko // Need to check this here additionally because this verification may run 141fb409a28SAlex Zinenko // before we check the nested operations. 142fb409a28SAlex Zinenko if ((*scc->begin())->isExternal()) 143fb409a28SAlex Zinenko return op->emitOpError() << "contains a call to an external operation, " 144fb409a28SAlex Zinenko "which is not allowed"; 145fb409a28SAlex Zinenko 146fb409a28SAlex Zinenko Operation *first = (*scc->begin())->getCallableRegion()->getParentOp(); 147fb409a28SAlex Zinenko InFlightDiagnostic diag = emitError(first->getLoc()) 148fb409a28SAlex Zinenko << "recursion not allowed in named sequences"; 149fb409a28SAlex Zinenko for (auto it = std::next(scc->begin()); it != scc->end(); ++it) { 150fb409a28SAlex Zinenko // Need to check this here additionally because this verification may 151fb409a28SAlex Zinenko // run before we check the nested operations. 152fb409a28SAlex Zinenko if ((*it)->isExternal()) { 153fb409a28SAlex Zinenko return op->emitOpError() << "contains a call to an external " 154fb409a28SAlex Zinenko "operation, which is not allowed"; 155fb409a28SAlex Zinenko } 156fb409a28SAlex Zinenko 157fb409a28SAlex Zinenko Operation *current = (*it)->getCallableRegion()->getParentOp(); 158fb409a28SAlex Zinenko diag.attachNote(current->getLoc()) << "operation on recursion stack"; 159fb409a28SAlex Zinenko } 160fb409a28SAlex Zinenko return diag; 161fb409a28SAlex Zinenko } 162fb409a28SAlex Zinenko return success(); 163fb409a28SAlex Zinenko } 164fb409a28SAlex Zinenko if (attribute.getName().getValue() == kTargetTagAttrName) { 165c1fa60b4STres Popp if (!llvm::isa<StringAttr>(attribute.getValue())) { 166fb409a28SAlex Zinenko return op->emitError() 167fb409a28SAlex Zinenko << attribute.getName() << " attribute must be a string"; 168fb409a28SAlex Zinenko } 169fb409a28SAlex Zinenko return success(); 170fb409a28SAlex Zinenko } 17141109341SAlex Zinenko if (attribute.getName().getValue() == kArgConsumedAttrName || 17241109341SAlex Zinenko attribute.getName().getValue() == kArgReadOnlyAttrName) { 173c1fa60b4STres Popp if (!llvm::isa<UnitAttr>(attribute.getValue())) { 17441109341SAlex Zinenko return op->emitError() 17541109341SAlex Zinenko << attribute.getName() << " must be a unit attribute"; 17641109341SAlex Zinenko } 17741109341SAlex Zinenko return success(); 17841109341SAlex Zinenko } 179*5a9bdd85SOleksandr "Alex" Zinenko if (attribute.getName().getValue() == 180*5a9bdd85SOleksandr "Alex" Zinenko FindPayloadReplacementOpInterface::kSilenceTrackingFailuresAttrName) { 181c63d2b2cSMatthias Springer if (!llvm::isa<UnitAttr>(attribute.getValue())) { 182c63d2b2cSMatthias Springer return op->emitError() 183c63d2b2cSMatthias Springer << attribute.getName() << " must be a unit attribute"; 184c63d2b2cSMatthias Springer } 185c63d2b2cSMatthias Springer return success(); 186c63d2b2cSMatthias Springer } 187fb409a28SAlex Zinenko return emitError(op->getLoc()) 188fb409a28SAlex Zinenko << "unknown attribute: " << attribute.getName(); 189fb409a28SAlex Zinenko } 190