1 //===- AsmParserState.cpp -------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/AsmParser/AsmParserState.h" 10 #include "mlir/IR/Attributes.h" 11 #include "mlir/IR/Operation.h" 12 #include "mlir/IR/SymbolTable.h" 13 #include "mlir/IR/Types.h" 14 #include "mlir/IR/Value.h" 15 #include "mlir/Support/LLVM.h" 16 #include "llvm/ADT/ArrayRef.h" 17 #include "llvm/ADT/STLExtras.h" 18 #include "llvm/ADT/StringExtras.h" 19 #include "llvm/ADT/StringMap.h" 20 #include "llvm/ADT/iterator.h" 21 #include "llvm/Support/ErrorHandling.h" 22 #include <cassert> 23 #include <cctype> 24 #include <memory> 25 #include <utility> 26 27 using namespace mlir; 28 29 //===----------------------------------------------------------------------===// 30 // AsmParserState::Impl 31 //===----------------------------------------------------------------------===// 32 33 struct AsmParserState::Impl { 34 /// A map from a SymbolRefAttr to a range of uses. 35 using SymbolUseMap = 36 DenseMap<Attribute, SmallVector<SmallVector<SMRange>, 0>>; 37 38 struct PartialOpDef { 39 explicit PartialOpDef(const OperationName &opName) { 40 if (opName.hasTrait<OpTrait::SymbolTable>()) 41 symbolTable = std::make_unique<SymbolUseMap>(); 42 } 43 44 /// Return if this operation is a symbol table. 45 bool isSymbolTable() const { return symbolTable.get(); } 46 47 /// If this operation is a symbol table, the following contains symbol uses 48 /// within this operation. 49 std::unique_ptr<SymbolUseMap> symbolTable; 50 }; 51 52 /// Resolve any symbol table uses in the IR. 53 void resolveSymbolUses(); 54 55 /// A mapping from operations in the input source file to their parser state. 56 SmallVector<std::unique_ptr<OperationDefinition>> operations; 57 DenseMap<Operation *, unsigned> operationToIdx; 58 59 /// A mapping from blocks in the input source file to their parser state. 60 SmallVector<std::unique_ptr<BlockDefinition>> blocks; 61 DenseMap<Block *, unsigned> blocksToIdx; 62 63 /// A mapping from aliases in the input source file to their parser state. 64 SmallVector<std::unique_ptr<AttributeAliasDefinition>> attrAliases; 65 SmallVector<std::unique_ptr<TypeAliasDefinition>> typeAliases; 66 llvm::StringMap<unsigned> attrAliasToIdx; 67 llvm::StringMap<unsigned> typeAliasToIdx; 68 69 /// A set of value definitions that are placeholders for forward references. 70 /// This map should be empty if the parser finishes successfully. 71 DenseMap<Value, SmallVector<SMLoc>> placeholderValueUses; 72 73 /// The symbol table operations within the IR. 74 SmallVector<std::pair<Operation *, std::unique_ptr<SymbolUseMap>>> 75 symbolTableOperations; 76 77 /// A stack of partial operation definitions that have been started but not 78 /// yet finalized. 79 SmallVector<PartialOpDef> partialOperations; 80 81 /// A stack of symbol use scopes. This is used when collecting symbol table 82 /// uses during parsing. 83 SmallVector<SymbolUseMap *> symbolUseScopes; 84 85 /// A symbol table containing all of the symbol table operations in the IR. 86 SymbolTableCollection symbolTable; 87 }; 88 89 void AsmParserState::Impl::resolveSymbolUses() { 90 SmallVector<Operation *> symbolOps; 91 for (auto &opAndUseMapIt : symbolTableOperations) { 92 for (auto &it : *opAndUseMapIt.second) { 93 symbolOps.clear(); 94 if (failed(symbolTable.lookupSymbolIn( 95 opAndUseMapIt.first, cast<SymbolRefAttr>(it.first), symbolOps))) 96 continue; 97 98 for (ArrayRef<SMRange> useRange : it.second) { 99 for (const auto &symIt : llvm::zip(symbolOps, useRange)) { 100 auto opIt = operationToIdx.find(std::get<0>(symIt)); 101 if (opIt != operationToIdx.end()) 102 operations[opIt->second]->symbolUses.push_back(std::get<1>(symIt)); 103 } 104 } 105 } 106 } 107 } 108 109 //===----------------------------------------------------------------------===// 110 // AsmParserState 111 //===----------------------------------------------------------------------===// 112 113 AsmParserState::AsmParserState() : impl(std::make_unique<Impl>()) {} 114 AsmParserState::~AsmParserState() = default; 115 AsmParserState &AsmParserState::operator=(AsmParserState &&other) { 116 impl = std::move(other.impl); 117 return *this; 118 } 119 120 //===----------------------------------------------------------------------===// 121 // Access State 122 123 auto AsmParserState::getBlockDefs() const -> iterator_range<BlockDefIterator> { 124 return llvm::make_pointee_range(llvm::ArrayRef(impl->blocks)); 125 } 126 127 auto AsmParserState::getBlockDef(Block *block) const 128 -> const BlockDefinition * { 129 auto it = impl->blocksToIdx.find(block); 130 return it == impl->blocksToIdx.end() ? nullptr : &*impl->blocks[it->second]; 131 } 132 133 auto AsmParserState::getOpDefs() const -> iterator_range<OperationDefIterator> { 134 return llvm::make_pointee_range(llvm::ArrayRef(impl->operations)); 135 } 136 137 auto AsmParserState::getOpDef(Operation *op) const 138 -> const OperationDefinition * { 139 auto it = impl->operationToIdx.find(op); 140 return it == impl->operationToIdx.end() ? nullptr 141 : &*impl->operations[it->second]; 142 } 143 144 auto AsmParserState::getAttributeAliasDefs() const 145 -> iterator_range<AttributeDefIterator> { 146 return llvm::make_pointee_range(ArrayRef(impl->attrAliases)); 147 } 148 149 auto AsmParserState::getAttributeAliasDef(StringRef name) const 150 -> const AttributeAliasDefinition * { 151 auto it = impl->attrAliasToIdx.find(name); 152 return it == impl->attrAliasToIdx.end() ? nullptr 153 : &*impl->attrAliases[it->second]; 154 } 155 156 auto AsmParserState::getTypeAliasDefs() const 157 -> iterator_range<TypeDefIterator> { 158 return llvm::make_pointee_range(ArrayRef(impl->typeAliases)); 159 } 160 161 auto AsmParserState::getTypeAliasDef(StringRef name) const 162 -> const TypeAliasDefinition * { 163 auto it = impl->typeAliasToIdx.find(name); 164 return it == impl->typeAliasToIdx.end() ? nullptr 165 : &*impl->typeAliases[it->second]; 166 } 167 168 /// Lex a string token whose contents start at the given `curPtr`. Returns the 169 /// position at the end of the string, after a terminal or invalid character 170 /// (e.g. `"` or `\0`). 171 static const char *lexLocStringTok(const char *curPtr) { 172 while (char c = *curPtr++) { 173 // Check for various terminal characters. 174 if (StringRef("\"\n\v\f").contains(c)) 175 return curPtr; 176 177 // Check for escape sequences. 178 if (c == '\\') { 179 // Check a few known escapes and \xx hex digits. 180 if (*curPtr == '"' || *curPtr == '\\' || *curPtr == 'n' || *curPtr == 't') 181 ++curPtr; 182 else if (llvm::isHexDigit(*curPtr) && llvm::isHexDigit(curPtr[1])) 183 curPtr += 2; 184 else 185 return curPtr; 186 } 187 } 188 189 // If we hit this point, we've reached the end of the buffer. Update the end 190 // pointer to not point past the buffer. 191 return curPtr - 1; 192 } 193 194 SMRange AsmParserState::convertIdLocToRange(SMLoc loc) { 195 if (!loc.isValid()) 196 return SMRange(); 197 const char *curPtr = loc.getPointer(); 198 199 // Check if this is a string token. 200 if (*curPtr == '"') { 201 curPtr = lexLocStringTok(curPtr + 1); 202 203 // Otherwise, default to handling an identifier. 204 } else { 205 // Return if the given character is a valid identifier character. 206 auto isIdentifierChar = [](char c) { 207 return isalnum(c) || c == '$' || c == '.' || c == '_' || c == '-'; 208 }; 209 210 while (*curPtr && isIdentifierChar(*(++curPtr))) 211 continue; 212 } 213 214 return SMRange(loc, SMLoc::getFromPointer(curPtr)); 215 } 216 217 //===----------------------------------------------------------------------===// 218 // Populate State 219 220 void AsmParserState::initialize(Operation *topLevelOp) { 221 startOperationDefinition(topLevelOp->getName()); 222 223 // If the top-level operation is a symbol table, push a new symbol scope. 224 Impl::PartialOpDef &partialOpDef = impl->partialOperations.back(); 225 if (partialOpDef.isSymbolTable()) 226 impl->symbolUseScopes.push_back(partialOpDef.symbolTable.get()); 227 } 228 229 void AsmParserState::finalize(Operation *topLevelOp) { 230 assert(!impl->partialOperations.empty() && 231 "expected valid partial operation definition"); 232 Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val(); 233 234 // If this operation is a symbol table, resolve any symbol uses. 235 if (partialOpDef.isSymbolTable()) { 236 impl->symbolTableOperations.emplace_back( 237 topLevelOp, std::move(partialOpDef.symbolTable)); 238 } 239 impl->resolveSymbolUses(); 240 } 241 242 void AsmParserState::startOperationDefinition(const OperationName &opName) { 243 impl->partialOperations.emplace_back(opName); 244 } 245 246 void AsmParserState::finalizeOperationDefinition( 247 Operation *op, SMRange nameLoc, SMLoc endLoc, 248 ArrayRef<std::pair<unsigned, SMLoc>> resultGroups) { 249 assert(!impl->partialOperations.empty() && 250 "expected valid partial operation definition"); 251 Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val(); 252 253 // Build the full operation definition. 254 std::unique_ptr<OperationDefinition> def = 255 std::make_unique<OperationDefinition>(op, nameLoc, endLoc); 256 for (auto &resultGroup : resultGroups) 257 def->resultGroups.emplace_back(resultGroup.first, 258 convertIdLocToRange(resultGroup.second)); 259 impl->operationToIdx.try_emplace(op, impl->operations.size()); 260 impl->operations.emplace_back(std::move(def)); 261 262 // If this operation is a symbol table, resolve any symbol uses. 263 if (partialOpDef.isSymbolTable()) { 264 impl->symbolTableOperations.emplace_back( 265 op, std::move(partialOpDef.symbolTable)); 266 } 267 } 268 269 void AsmParserState::startRegionDefinition() { 270 assert(!impl->partialOperations.empty() && 271 "expected valid partial operation definition"); 272 273 // If the parent operation of this region is a symbol table, we also push a 274 // new symbol scope. 275 Impl::PartialOpDef &partialOpDef = impl->partialOperations.back(); 276 if (partialOpDef.isSymbolTable()) 277 impl->symbolUseScopes.push_back(partialOpDef.symbolTable.get()); 278 } 279 280 void AsmParserState::finalizeRegionDefinition() { 281 assert(!impl->partialOperations.empty() && 282 "expected valid partial operation definition"); 283 284 // If the parent operation of this region is a symbol table, pop the symbol 285 // scope for this region. 286 Impl::PartialOpDef &partialOpDef = impl->partialOperations.back(); 287 if (partialOpDef.isSymbolTable()) 288 impl->symbolUseScopes.pop_back(); 289 } 290 291 void AsmParserState::addDefinition(Block *block, SMLoc location) { 292 auto [it, inserted] = 293 impl->blocksToIdx.try_emplace(block, impl->blocks.size()); 294 if (inserted) { 295 impl->blocks.emplace_back(std::make_unique<BlockDefinition>( 296 block, convertIdLocToRange(location))); 297 return; 298 } 299 300 // If an entry already exists, this was a forward declaration that now has a 301 // proper definition. 302 impl->blocks[it->second]->definition.loc = convertIdLocToRange(location); 303 } 304 305 void AsmParserState::addDefinition(BlockArgument blockArg, SMLoc location) { 306 auto it = impl->blocksToIdx.find(blockArg.getOwner()); 307 assert(it != impl->blocksToIdx.end() && 308 "expected owner block to have an entry"); 309 BlockDefinition &def = *impl->blocks[it->second]; 310 unsigned argIdx = blockArg.getArgNumber(); 311 312 if (def.arguments.size() <= argIdx) 313 def.arguments.resize(argIdx + 1); 314 def.arguments[argIdx] = SMDefinition(convertIdLocToRange(location)); 315 } 316 317 void AsmParserState::addAttrAliasDefinition(StringRef name, SMRange location, 318 Attribute value) { 319 auto [it, inserted] = 320 impl->attrAliasToIdx.try_emplace(name, impl->attrAliases.size()); 321 // Location aliases may be referenced before they are defined. 322 if (inserted) { 323 impl->attrAliases.push_back( 324 std::make_unique<AttributeAliasDefinition>(name, location, value)); 325 } else { 326 AttributeAliasDefinition &attr = *impl->attrAliases[it->second]; 327 attr.definition.loc = location; 328 attr.value = value; 329 } 330 } 331 332 void AsmParserState::addTypeAliasDefinition(StringRef name, SMRange location, 333 Type value) { 334 [[maybe_unused]] auto [it, inserted] = 335 impl->typeAliasToIdx.try_emplace(name, impl->typeAliases.size()); 336 assert(inserted && "unexpected attribute alias redefinition"); 337 impl->typeAliases.push_back( 338 std::make_unique<TypeAliasDefinition>(name, location, value)); 339 } 340 341 void AsmParserState::addUses(Value value, ArrayRef<SMLoc> locations) { 342 // Handle the case where the value is an operation result. 343 if (OpResult result = dyn_cast<OpResult>(value)) { 344 // Check to see if a definition for the parent operation has been recorded. 345 // If one hasn't, we treat the provided value as a placeholder value that 346 // will be refined further later. 347 Operation *parentOp = result.getOwner(); 348 auto existingIt = impl->operationToIdx.find(parentOp); 349 if (existingIt == impl->operationToIdx.end()) { 350 impl->placeholderValueUses[value].append(locations.begin(), 351 locations.end()); 352 return; 353 } 354 355 // If a definition does exist, locate the value's result group and add the 356 // use. The result groups are ordered by increasing start index, so we just 357 // need to find the last group that has a smaller/equal start index. 358 unsigned resultNo = result.getResultNumber(); 359 OperationDefinition &def = *impl->operations[existingIt->second]; 360 for (auto &resultGroup : llvm::reverse(def.resultGroups)) { 361 if (resultNo >= resultGroup.startIndex) { 362 for (SMLoc loc : locations) 363 resultGroup.definition.uses.push_back(convertIdLocToRange(loc)); 364 return; 365 } 366 } 367 llvm_unreachable("expected valid result group for value use"); 368 } 369 370 // Otherwise, this is a block argument. 371 BlockArgument arg = cast<BlockArgument>(value); 372 auto existingIt = impl->blocksToIdx.find(arg.getOwner()); 373 assert(existingIt != impl->blocksToIdx.end() && 374 "expected valid block definition for block argument"); 375 BlockDefinition &blockDef = *impl->blocks[existingIt->second]; 376 SMDefinition &argDef = blockDef.arguments[arg.getArgNumber()]; 377 for (SMLoc loc : locations) 378 argDef.uses.emplace_back(convertIdLocToRange(loc)); 379 } 380 381 void AsmParserState::addUses(Block *block, ArrayRef<SMLoc> locations) { 382 auto [it, inserted] = 383 impl->blocksToIdx.try_emplace(block, impl->blocks.size()); 384 if (inserted) 385 impl->blocks.emplace_back(std::make_unique<BlockDefinition>(block)); 386 387 BlockDefinition &def = *impl->blocks[it->second]; 388 for (SMLoc loc : locations) 389 def.definition.uses.push_back(convertIdLocToRange(loc)); 390 } 391 392 void AsmParserState::addUses(SymbolRefAttr refAttr, 393 ArrayRef<SMRange> locations) { 394 // Ignore this symbol if no scopes are active. 395 if (impl->symbolUseScopes.empty()) 396 return; 397 398 assert((refAttr.getNestedReferences().size() + 1) == locations.size() && 399 "expected the same number of references as provided locations"); 400 (*impl->symbolUseScopes.back())[refAttr].emplace_back(locations.begin(), 401 locations.end()); 402 } 403 404 void AsmParserState::addAttrAliasUses(StringRef name, SMRange location) { 405 auto it = impl->attrAliasToIdx.find(name); 406 // Location aliases may be referenced before they are defined. 407 if (it == impl->attrAliasToIdx.end()) { 408 it = impl->attrAliasToIdx.try_emplace(name, impl->attrAliases.size()).first; 409 impl->attrAliases.push_back( 410 std::make_unique<AttributeAliasDefinition>(name)); 411 } 412 AttributeAliasDefinition &def = *impl->attrAliases[it->second]; 413 def.definition.uses.push_back(location); 414 } 415 416 void AsmParserState::addTypeAliasUses(StringRef name, SMRange location) { 417 auto it = impl->typeAliasToIdx.find(name); 418 // Location aliases may be referenced before they are defined. 419 assert(it != impl->typeAliasToIdx.end() && 420 "expected valid type alias definition"); 421 TypeAliasDefinition &def = *impl->typeAliases[it->second]; 422 def.definition.uses.push_back(location); 423 } 424 425 void AsmParserState::refineDefinition(Value oldValue, Value newValue) { 426 auto it = impl->placeholderValueUses.find(oldValue); 427 assert(it != impl->placeholderValueUses.end() && 428 "expected `oldValue` to be a placeholder"); 429 addUses(newValue, it->second); 430 impl->placeholderValueUses.erase(oldValue); 431 } 432