xref: /llvm-project/mlir/lib/AsmParser/AsmParserState.cpp (revision 59b7461c139d30ea57db4211decebe43117676fa)
1 //===- AsmParserState.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/AsmParser/AsmParserState.h"
10 #include "mlir/IR/Attributes.h"
11 #include "mlir/IR/Operation.h"
12 #include "mlir/IR/SymbolTable.h"
13 #include "mlir/IR/Types.h"
14 #include "mlir/IR/Value.h"
15 #include "mlir/Support/LLVM.h"
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/StringExtras.h"
19 #include "llvm/ADT/StringMap.h"
20 #include "llvm/ADT/iterator.h"
21 #include "llvm/Support/ErrorHandling.h"
22 #include <cassert>
23 #include <cctype>
24 #include <memory>
25 #include <utility>
26 
27 using namespace mlir;
28 
29 //===----------------------------------------------------------------------===//
30 // AsmParserState::Impl
31 //===----------------------------------------------------------------------===//
32 
33 struct AsmParserState::Impl {
34   /// A map from a SymbolRefAttr to a range of uses.
35   using SymbolUseMap =
36       DenseMap<Attribute, SmallVector<SmallVector<SMRange>, 0>>;
37 
38   struct PartialOpDef {
39     explicit PartialOpDef(const OperationName &opName) {
40       if (opName.hasTrait<OpTrait::SymbolTable>())
41         symbolTable = std::make_unique<SymbolUseMap>();
42     }
43 
44     /// Return if this operation is a symbol table.
45     bool isSymbolTable() const { return symbolTable.get(); }
46 
47     /// If this operation is a symbol table, the following contains symbol uses
48     /// within this operation.
49     std::unique_ptr<SymbolUseMap> symbolTable;
50   };
51 
52   /// Resolve any symbol table uses in the IR.
53   void resolveSymbolUses();
54 
55   /// A mapping from operations in the input source file to their parser state.
56   SmallVector<std::unique_ptr<OperationDefinition>> operations;
57   DenseMap<Operation *, unsigned> operationToIdx;
58 
59   /// A mapping from blocks in the input source file to their parser state.
60   SmallVector<std::unique_ptr<BlockDefinition>> blocks;
61   DenseMap<Block *, unsigned> blocksToIdx;
62 
63   /// A mapping from aliases in the input source file to their parser state.
64   SmallVector<std::unique_ptr<AttributeAliasDefinition>> attrAliases;
65   SmallVector<std::unique_ptr<TypeAliasDefinition>> typeAliases;
66   llvm::StringMap<unsigned> attrAliasToIdx;
67   llvm::StringMap<unsigned> typeAliasToIdx;
68 
69   /// A set of value definitions that are placeholders for forward references.
70   /// This map should be empty if the parser finishes successfully.
71   DenseMap<Value, SmallVector<SMLoc>> placeholderValueUses;
72 
73   /// The symbol table operations within the IR.
74   SmallVector<std::pair<Operation *, std::unique_ptr<SymbolUseMap>>>
75       symbolTableOperations;
76 
77   /// A stack of partial operation definitions that have been started but not
78   /// yet finalized.
79   SmallVector<PartialOpDef> partialOperations;
80 
81   /// A stack of symbol use scopes. This is used when collecting symbol table
82   /// uses during parsing.
83   SmallVector<SymbolUseMap *> symbolUseScopes;
84 
85   /// A symbol table containing all of the symbol table operations in the IR.
86   SymbolTableCollection symbolTable;
87 };
88 
89 void AsmParserState::Impl::resolveSymbolUses() {
90   SmallVector<Operation *> symbolOps;
91   for (auto &opAndUseMapIt : symbolTableOperations) {
92     for (auto &it : *opAndUseMapIt.second) {
93       symbolOps.clear();
94       if (failed(symbolTable.lookupSymbolIn(
95               opAndUseMapIt.first, cast<SymbolRefAttr>(it.first), symbolOps)))
96         continue;
97 
98       for (ArrayRef<SMRange> useRange : it.second) {
99         for (const auto &symIt : llvm::zip(symbolOps, useRange)) {
100           auto opIt = operationToIdx.find(std::get<0>(symIt));
101           if (opIt != operationToIdx.end())
102             operations[opIt->second]->symbolUses.push_back(std::get<1>(symIt));
103         }
104       }
105     }
106   }
107 }
108 
109 //===----------------------------------------------------------------------===//
110 // AsmParserState
111 //===----------------------------------------------------------------------===//
112 
113 AsmParserState::AsmParserState() : impl(std::make_unique<Impl>()) {}
114 AsmParserState::~AsmParserState() = default;
115 AsmParserState &AsmParserState::operator=(AsmParserState &&other) {
116   impl = std::move(other.impl);
117   return *this;
118 }
119 
120 //===----------------------------------------------------------------------===//
121 // Access State
122 
123 auto AsmParserState::getBlockDefs() const -> iterator_range<BlockDefIterator> {
124   return llvm::make_pointee_range(llvm::ArrayRef(impl->blocks));
125 }
126 
127 auto AsmParserState::getBlockDef(Block *block) const
128     -> const BlockDefinition * {
129   auto it = impl->blocksToIdx.find(block);
130   return it == impl->blocksToIdx.end() ? nullptr : &*impl->blocks[it->second];
131 }
132 
133 auto AsmParserState::getOpDefs() const -> iterator_range<OperationDefIterator> {
134   return llvm::make_pointee_range(llvm::ArrayRef(impl->operations));
135 }
136 
137 auto AsmParserState::getOpDef(Operation *op) const
138     -> const OperationDefinition * {
139   auto it = impl->operationToIdx.find(op);
140   return it == impl->operationToIdx.end() ? nullptr
141                                           : &*impl->operations[it->second];
142 }
143 
144 auto AsmParserState::getAttributeAliasDefs() const
145     -> iterator_range<AttributeDefIterator> {
146   return llvm::make_pointee_range(ArrayRef(impl->attrAliases));
147 }
148 
149 auto AsmParserState::getAttributeAliasDef(StringRef name) const
150     -> const AttributeAliasDefinition * {
151   auto it = impl->attrAliasToIdx.find(name);
152   return it == impl->attrAliasToIdx.end() ? nullptr
153                                           : &*impl->attrAliases[it->second];
154 }
155 
156 auto AsmParserState::getTypeAliasDefs() const
157     -> iterator_range<TypeDefIterator> {
158   return llvm::make_pointee_range(ArrayRef(impl->typeAliases));
159 }
160 
161 auto AsmParserState::getTypeAliasDef(StringRef name) const
162     -> const TypeAliasDefinition * {
163   auto it = impl->typeAliasToIdx.find(name);
164   return it == impl->typeAliasToIdx.end() ? nullptr
165                                           : &*impl->typeAliases[it->second];
166 }
167 
168 /// Lex a string token whose contents start at the given `curPtr`. Returns the
169 /// position at the end of the string, after a terminal or invalid character
170 /// (e.g. `"` or `\0`).
171 static const char *lexLocStringTok(const char *curPtr) {
172   while (char c = *curPtr++) {
173     // Check for various terminal characters.
174     if (StringRef("\"\n\v\f").contains(c))
175       return curPtr;
176 
177     // Check for escape sequences.
178     if (c == '\\') {
179       // Check a few known escapes and \xx hex digits.
180       if (*curPtr == '"' || *curPtr == '\\' || *curPtr == 'n' || *curPtr == 't')
181         ++curPtr;
182       else if (llvm::isHexDigit(*curPtr) && llvm::isHexDigit(curPtr[1]))
183         curPtr += 2;
184       else
185         return curPtr;
186     }
187   }
188 
189   // If we hit this point, we've reached the end of the buffer. Update the end
190   // pointer to not point past the buffer.
191   return curPtr - 1;
192 }
193 
194 SMRange AsmParserState::convertIdLocToRange(SMLoc loc) {
195   if (!loc.isValid())
196     return SMRange();
197   const char *curPtr = loc.getPointer();
198 
199   // Check if this is a string token.
200   if (*curPtr == '"') {
201     curPtr = lexLocStringTok(curPtr + 1);
202 
203     // Otherwise, default to handling an identifier.
204   } else {
205     // Return if the given character is a valid identifier character.
206     auto isIdentifierChar = [](char c) {
207       return isalnum(c) || c == '$' || c == '.' || c == '_' || c == '-';
208     };
209 
210     while (*curPtr && isIdentifierChar(*(++curPtr)))
211       continue;
212   }
213 
214   return SMRange(loc, SMLoc::getFromPointer(curPtr));
215 }
216 
217 //===----------------------------------------------------------------------===//
218 // Populate State
219 
220 void AsmParserState::initialize(Operation *topLevelOp) {
221   startOperationDefinition(topLevelOp->getName());
222 
223   // If the top-level operation is a symbol table, push a new symbol scope.
224   Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
225   if (partialOpDef.isSymbolTable())
226     impl->symbolUseScopes.push_back(partialOpDef.symbolTable.get());
227 }
228 
229 void AsmParserState::finalize(Operation *topLevelOp) {
230   assert(!impl->partialOperations.empty() &&
231          "expected valid partial operation definition");
232   Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val();
233 
234   // If this operation is a symbol table, resolve any symbol uses.
235   if (partialOpDef.isSymbolTable()) {
236     impl->symbolTableOperations.emplace_back(
237         topLevelOp, std::move(partialOpDef.symbolTable));
238   }
239   impl->resolveSymbolUses();
240 }
241 
242 void AsmParserState::startOperationDefinition(const OperationName &opName) {
243   impl->partialOperations.emplace_back(opName);
244 }
245 
246 void AsmParserState::finalizeOperationDefinition(
247     Operation *op, SMRange nameLoc, SMLoc endLoc,
248     ArrayRef<std::pair<unsigned, SMLoc>> resultGroups) {
249   assert(!impl->partialOperations.empty() &&
250          "expected valid partial operation definition");
251   Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val();
252 
253   // Build the full operation definition.
254   std::unique_ptr<OperationDefinition> def =
255       std::make_unique<OperationDefinition>(op, nameLoc, endLoc);
256   for (auto &resultGroup : resultGroups)
257     def->resultGroups.emplace_back(resultGroup.first,
258                                    convertIdLocToRange(resultGroup.second));
259   impl->operationToIdx.try_emplace(op, impl->operations.size());
260   impl->operations.emplace_back(std::move(def));
261 
262   // If this operation is a symbol table, resolve any symbol uses.
263   if (partialOpDef.isSymbolTable()) {
264     impl->symbolTableOperations.emplace_back(
265         op, std::move(partialOpDef.symbolTable));
266   }
267 }
268 
269 void AsmParserState::startRegionDefinition() {
270   assert(!impl->partialOperations.empty() &&
271          "expected valid partial operation definition");
272 
273   // If the parent operation of this region is a symbol table, we also push a
274   // new symbol scope.
275   Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
276   if (partialOpDef.isSymbolTable())
277     impl->symbolUseScopes.push_back(partialOpDef.symbolTable.get());
278 }
279 
280 void AsmParserState::finalizeRegionDefinition() {
281   assert(!impl->partialOperations.empty() &&
282          "expected valid partial operation definition");
283 
284   // If the parent operation of this region is a symbol table, pop the symbol
285   // scope for this region.
286   Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
287   if (partialOpDef.isSymbolTable())
288     impl->symbolUseScopes.pop_back();
289 }
290 
291 void AsmParserState::addDefinition(Block *block, SMLoc location) {
292   auto [it, inserted] =
293       impl->blocksToIdx.try_emplace(block, impl->blocks.size());
294   if (inserted) {
295     impl->blocks.emplace_back(std::make_unique<BlockDefinition>(
296         block, convertIdLocToRange(location)));
297     return;
298   }
299 
300   // If an entry already exists, this was a forward declaration that now has a
301   // proper definition.
302   impl->blocks[it->second]->definition.loc = convertIdLocToRange(location);
303 }
304 
305 void AsmParserState::addDefinition(BlockArgument blockArg, SMLoc location) {
306   auto it = impl->blocksToIdx.find(blockArg.getOwner());
307   assert(it != impl->blocksToIdx.end() &&
308          "expected owner block to have an entry");
309   BlockDefinition &def = *impl->blocks[it->second];
310   unsigned argIdx = blockArg.getArgNumber();
311 
312   if (def.arguments.size() <= argIdx)
313     def.arguments.resize(argIdx + 1);
314   def.arguments[argIdx] = SMDefinition(convertIdLocToRange(location));
315 }
316 
317 void AsmParserState::addAttrAliasDefinition(StringRef name, SMRange location,
318                                             Attribute value) {
319   auto [it, inserted] =
320       impl->attrAliasToIdx.try_emplace(name, impl->attrAliases.size());
321   // Location aliases may be referenced before they are defined.
322   if (inserted) {
323     impl->attrAliases.push_back(
324         std::make_unique<AttributeAliasDefinition>(name, location, value));
325   } else {
326     AttributeAliasDefinition &attr = *impl->attrAliases[it->second];
327     attr.definition.loc = location;
328     attr.value = value;
329   }
330 }
331 
332 void AsmParserState::addTypeAliasDefinition(StringRef name, SMRange location,
333                                             Type value) {
334   [[maybe_unused]] auto [it, inserted] =
335       impl->typeAliasToIdx.try_emplace(name, impl->typeAliases.size());
336   assert(inserted && "unexpected attribute alias redefinition");
337   impl->typeAliases.push_back(
338       std::make_unique<TypeAliasDefinition>(name, location, value));
339 }
340 
341 void AsmParserState::addUses(Value value, ArrayRef<SMLoc> locations) {
342   // Handle the case where the value is an operation result.
343   if (OpResult result = dyn_cast<OpResult>(value)) {
344     // Check to see if a definition for the parent operation has been recorded.
345     // If one hasn't, we treat the provided value as a placeholder value that
346     // will be refined further later.
347     Operation *parentOp = result.getOwner();
348     auto existingIt = impl->operationToIdx.find(parentOp);
349     if (existingIt == impl->operationToIdx.end()) {
350       impl->placeholderValueUses[value].append(locations.begin(),
351                                                locations.end());
352       return;
353     }
354 
355     // If a definition does exist, locate the value's result group and add the
356     // use. The result groups are ordered by increasing start index, so we just
357     // need to find the last group that has a smaller/equal start index.
358     unsigned resultNo = result.getResultNumber();
359     OperationDefinition &def = *impl->operations[existingIt->second];
360     for (auto &resultGroup : llvm::reverse(def.resultGroups)) {
361       if (resultNo >= resultGroup.startIndex) {
362         for (SMLoc loc : locations)
363           resultGroup.definition.uses.push_back(convertIdLocToRange(loc));
364         return;
365       }
366     }
367     llvm_unreachable("expected valid result group for value use");
368   }
369 
370   // Otherwise, this is a block argument.
371   BlockArgument arg = cast<BlockArgument>(value);
372   auto existingIt = impl->blocksToIdx.find(arg.getOwner());
373   assert(existingIt != impl->blocksToIdx.end() &&
374          "expected valid block definition for block argument");
375   BlockDefinition &blockDef = *impl->blocks[existingIt->second];
376   SMDefinition &argDef = blockDef.arguments[arg.getArgNumber()];
377   for (SMLoc loc : locations)
378     argDef.uses.emplace_back(convertIdLocToRange(loc));
379 }
380 
381 void AsmParserState::addUses(Block *block, ArrayRef<SMLoc> locations) {
382   auto [it, inserted] =
383       impl->blocksToIdx.try_emplace(block, impl->blocks.size());
384   if (inserted)
385     impl->blocks.emplace_back(std::make_unique<BlockDefinition>(block));
386 
387   BlockDefinition &def = *impl->blocks[it->second];
388   for (SMLoc loc : locations)
389     def.definition.uses.push_back(convertIdLocToRange(loc));
390 }
391 
392 void AsmParserState::addUses(SymbolRefAttr refAttr,
393                              ArrayRef<SMRange> locations) {
394   // Ignore this symbol if no scopes are active.
395   if (impl->symbolUseScopes.empty())
396     return;
397 
398   assert((refAttr.getNestedReferences().size() + 1) == locations.size() &&
399          "expected the same number of references as provided locations");
400   (*impl->symbolUseScopes.back())[refAttr].emplace_back(locations.begin(),
401                                                         locations.end());
402 }
403 
404 void AsmParserState::addAttrAliasUses(StringRef name, SMRange location) {
405   auto it = impl->attrAliasToIdx.find(name);
406   // Location aliases may be referenced before they are defined.
407   if (it == impl->attrAliasToIdx.end()) {
408     it = impl->attrAliasToIdx.try_emplace(name, impl->attrAliases.size()).first;
409     impl->attrAliases.push_back(
410         std::make_unique<AttributeAliasDefinition>(name));
411   }
412   AttributeAliasDefinition &def = *impl->attrAliases[it->second];
413   def.definition.uses.push_back(location);
414 }
415 
416 void AsmParserState::addTypeAliasUses(StringRef name, SMRange location) {
417   auto it = impl->typeAliasToIdx.find(name);
418   // Location aliases may be referenced before they are defined.
419   assert(it != impl->typeAliasToIdx.end() &&
420          "expected valid type alias definition");
421   TypeAliasDefinition &def = *impl->typeAliases[it->second];
422   def.definition.uses.push_back(location);
423 }
424 
425 void AsmParserState::refineDefinition(Value oldValue, Value newValue) {
426   auto it = impl->placeholderValueUses.find(oldValue);
427   assert(it != impl->placeholderValueUses.end() &&
428          "expected `oldValue` to be a placeholder");
429   addUses(newValue, it->second);
430   impl->placeholderValueUses.erase(oldValue);
431 }
432