1 //===- TransformDialect.cpp - Transform Dialect Definition ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 10 #include "mlir/Analysis/CallGraph.h" 11 #include "mlir/Dialect/Transform/IR/TransformOps.h" 12 #include "mlir/Dialect/Transform/IR/TransformTypes.h" 13 #include "mlir/Dialect/Transform/IR/Utils.h" 14 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" 15 #include "mlir/IR/DialectImplementation.h" 16 #include "llvm/ADT/SCCIterator.h" 17 18 using namespace mlir; 19 20 #include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc" 21 22 #ifndef NDEBUG 23 void transform::detail::checkImplementsTransformOpInterface( 24 StringRef name, MLIRContext *context) { 25 // Since the operation is being inserted into the Transform dialect and the 26 // dialect does not implement the interface fallback, only check for the op 27 // itself having the interface implementation. 28 RegisteredOperationName opName = 29 *RegisteredOperationName::lookup(name, context); 30 assert((opName.hasInterface<TransformOpInterface>() || 31 opName.hasInterface<PatternDescriptorOpInterface>() || 32 opName.hasInterface<ConversionPatternDescriptorOpInterface>() || 33 opName.hasInterface<TypeConverterBuilderOpInterface>() || 34 opName.hasTrait<OpTrait::IsTerminator>()) && 35 "non-terminator ops injected into the transform dialect must " 36 "implement TransformOpInterface or PatternDescriptorOpInterface or " 37 "ConversionPatternDescriptorOpInterface"); 38 if (!opName.hasInterface<PatternDescriptorOpInterface>() && 39 !opName.hasInterface<ConversionPatternDescriptorOpInterface>() && 40 !opName.hasInterface<TypeConverterBuilderOpInterface>()) { 41 assert(opName.hasInterface<MemoryEffectOpInterface>() && 42 "ops injected into the transform dialect must implement " 43 "MemoryEffectsOpInterface"); 44 } 45 } 46 47 void transform::detail::checkImplementsTransformHandleTypeInterface( 48 TypeID typeID, MLIRContext *context) { 49 const auto &abstractType = AbstractType::lookup(typeID, context); 50 assert((abstractType.hasInterface( 51 TransformHandleTypeInterface::getInterfaceID()) || 52 abstractType.hasInterface( 53 TransformParamTypeInterface::getInterfaceID()) || 54 abstractType.hasInterface( 55 TransformValueHandleTypeInterface::getInterfaceID())) && 56 "expected Transform dialect type to implement one of the three " 57 "interfaces"); 58 } 59 #endif // NDEBUG 60 61 void transform::TransformDialect::initialize() { 62 // Using the checked versions to enable the same assertions as for the ops 63 // from extensions. 64 addOperationsChecked< 65 #define GET_OP_LIST 66 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" 67 >(); 68 initializeTypes(); 69 initializeLibraryModule(); 70 } 71 72 Type transform::TransformDialect::parseType(DialectAsmParser &parser) const { 73 StringRef keyword; 74 SMLoc loc = parser.getCurrentLocation(); 75 if (failed(parser.parseKeyword(&keyword))) 76 return nullptr; 77 78 auto it = typeParsingHooks.find(keyword); 79 if (it == typeParsingHooks.end()) { 80 parser.emitError(loc) << "unknown type mnemonic: " << keyword; 81 return nullptr; 82 } 83 84 return it->getValue()(parser); 85 } 86 87 void transform::TransformDialect::printType(Type type, 88 DialectAsmPrinter &printer) const { 89 auto it = typePrintingHooks.find(type.getTypeID()); 90 assert(it != typePrintingHooks.end() && "printing unknown type"); 91 it->getSecond()(type, printer); 92 } 93 94 LogicalResult transform::TransformDialect::loadIntoLibraryModule( 95 ::mlir::OwningOpRef<::mlir::ModuleOp> &&library) { 96 return detail::mergeSymbolsInto(getLibraryModule(), std::move(library)); 97 } 98 99 void transform::TransformDialect::initializeLibraryModule() { 100 MLIRContext *context = getContext(); 101 auto loc = 102 FileLineColLoc::get(context, "<transform-dialect-library-module>", 0, 0); 103 libraryModule = ModuleOp::create(loc, "__transform_library"); 104 libraryModule.get()->setAttr(TransformDialect::kWithNamedSequenceAttrName, 105 UnitAttr::get(context)); 106 } 107 108 void transform::TransformDialect::reportDuplicateTypeRegistration( 109 StringRef mnemonic) { 110 std::string buffer; 111 llvm::raw_string_ostream msg(buffer); 112 msg << "extensible dialect type '" << mnemonic 113 << "' is already registered with a different implementation"; 114 llvm::report_fatal_error(StringRef(buffer)); 115 } 116 117 void transform::TransformDialect::reportDuplicateOpRegistration( 118 StringRef opName) { 119 std::string buffer; 120 llvm::raw_string_ostream msg(buffer); 121 msg << "extensible dialect operation '" << opName 122 << "' is already registered with a mismatching TypeID"; 123 llvm::report_fatal_error(StringRef(buffer)); 124 } 125 126 LogicalResult transform::TransformDialect::verifyOperationAttribute( 127 Operation *op, NamedAttribute attribute) { 128 if (attribute.getName().getValue() == kWithNamedSequenceAttrName) { 129 if (!op->hasTrait<OpTrait::SymbolTable>()) { 130 return emitError(op->getLoc()) << attribute.getName() 131 << " attribute can only be attached to " 132 "operations with symbol tables"; 133 } 134 135 const mlir::CallGraph callgraph(op); 136 for (auto scc = llvm::scc_begin(&callgraph); !scc.isAtEnd(); ++scc) { 137 if (!scc.hasCycle()) 138 continue; 139 140 // Need to check this here additionally because this verification may run 141 // before we check the nested operations. 142 if ((*scc->begin())->isExternal()) 143 return op->emitOpError() << "contains a call to an external operation, " 144 "which is not allowed"; 145 146 Operation *first = (*scc->begin())->getCallableRegion()->getParentOp(); 147 InFlightDiagnostic diag = emitError(first->getLoc()) 148 << "recursion not allowed in named sequences"; 149 for (auto it = std::next(scc->begin()); it != scc->end(); ++it) { 150 // Need to check this here additionally because this verification may 151 // run before we check the nested operations. 152 if ((*it)->isExternal()) { 153 return op->emitOpError() << "contains a call to an external " 154 "operation, which is not allowed"; 155 } 156 157 Operation *current = (*it)->getCallableRegion()->getParentOp(); 158 diag.attachNote(current->getLoc()) << "operation on recursion stack"; 159 } 160 return diag; 161 } 162 return success(); 163 } 164 if (attribute.getName().getValue() == kTargetTagAttrName) { 165 if (!llvm::isa<StringAttr>(attribute.getValue())) { 166 return op->emitError() 167 << attribute.getName() << " attribute must be a string"; 168 } 169 return success(); 170 } 171 if (attribute.getName().getValue() == kArgConsumedAttrName || 172 attribute.getName().getValue() == kArgReadOnlyAttrName) { 173 if (!llvm::isa<UnitAttr>(attribute.getValue())) { 174 return op->emitError() 175 << attribute.getName() << " must be a unit attribute"; 176 } 177 return success(); 178 } 179 if (attribute.getName().getValue() == 180 FindPayloadReplacementOpInterface::kSilenceTrackingFailuresAttrName) { 181 if (!llvm::isa<UnitAttr>(attribute.getValue())) { 182 return op->emitError() 183 << attribute.getName() << " must be a unit attribute"; 184 } 185 return success(); 186 } 187 return emitError(op->getLoc()) 188 << "unknown attribute: " << attribute.getName(); 189 } 190