xref: /llvm-project/mlir/lib/Dialect/IRDL/IRDLLoading.cpp (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
1 //===- IRDLLoading.cpp - IRDL dialect loading --------------------- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Manages the loading of MLIR objects from IRDL operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/IRDL/IRDLLoading.h"
14 #include "mlir/Dialect/IRDL/IR/IRDL.h"
15 #include "mlir/Dialect/IRDL/IR/IRDLInterfaces.h"
16 #include "mlir/Dialect/IRDL/IRDLSymbols.h"
17 #include "mlir/Dialect/IRDL/IRDLVerifiers.h"
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/BuiltinOps.h"
20 #include "mlir/IR/ExtensibleDialect.h"
21 #include "mlir/IR/OperationSupport.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/SmallPtrSet.h"
24 #include "llvm/Support/SMLoc.h"
25 #include <numeric>
26 
27 using namespace mlir;
28 using namespace mlir::irdl;
29 
30 /// Verify that the given list of parameters satisfy the given constraints.
31 /// This encodes the logic of the verification method for attributes and types
32 /// defined with IRDL.
33 static LogicalResult
irdlAttrOrTypeVerifier(function_ref<InFlightDiagnostic ()> emitError,ArrayRef<Attribute> params,ArrayRef<std::unique_ptr<Constraint>> constraints,ArrayRef<size_t> paramConstraints)34 irdlAttrOrTypeVerifier(function_ref<InFlightDiagnostic()> emitError,
35                        ArrayRef<Attribute> params,
36                        ArrayRef<std::unique_ptr<Constraint>> constraints,
37                        ArrayRef<size_t> paramConstraints) {
38   if (params.size() != paramConstraints.size()) {
39     emitError() << "expected " << paramConstraints.size()
40                 << " type arguments, but had " << params.size();
41     return failure();
42   }
43 
44   ConstraintVerifier verifier(constraints);
45 
46   // Check that each parameter satisfies its constraint.
47   for (auto [i, param] : enumerate(params))
48     if (failed(verifier.verify(emitError, param, paramConstraints[i])))
49       return failure();
50 
51   return success();
52 }
53 
54 /// Get the operand segment sizes from the attribute dictionary.
getSegmentSizesFromAttr(Operation * op,StringRef elemName,StringRef attrName,unsigned numElements,ArrayRef<Variadicity> variadicities,SmallVectorImpl<int> & segmentSizes)55 LogicalResult getSegmentSizesFromAttr(Operation *op, StringRef elemName,
56                                       StringRef attrName, unsigned numElements,
57                                       ArrayRef<Variadicity> variadicities,
58                                       SmallVectorImpl<int> &segmentSizes) {
59   // Get the segment sizes attribute, and check that it is of the right type.
60   Attribute segmentSizesAttr = op->getAttr(attrName);
61   if (!segmentSizesAttr) {
62     return op->emitError() << "'" << attrName
63                            << "' attribute is expected but not provided";
64   }
65 
66   auto denseSegmentSizes = dyn_cast<DenseI32ArrayAttr>(segmentSizesAttr);
67   if (!denseSegmentSizes) {
68     return op->emitError() << "'" << attrName
69                            << "' attribute is expected to be a dense i32 array";
70   }
71 
72   if (denseSegmentSizes.size() != (int64_t)variadicities.size()) {
73     return op->emitError() << "'" << attrName << "' attribute for specifying "
74                            << elemName << " segments must have "
75                            << variadicities.size() << " elements, but got "
76                            << denseSegmentSizes.size();
77   }
78 
79   // Check that the segment sizes are corresponding to the given variadicities,
80   for (auto [i, segmentSize, variadicity] :
81        enumerate(denseSegmentSizes.asArrayRef(), variadicities)) {
82     if (segmentSize < 0)
83       return op->emitError()
84              << "'" << attrName << "' attribute for specifying " << elemName
85              << " segments must have non-negative values";
86     if (variadicity == Variadicity::single && segmentSize != 1)
87       return op->emitError() << "element " << i << " in '" << attrName
88                              << "' attribute must be equal to 1";
89 
90     if (variadicity == Variadicity::optional && segmentSize > 1)
91       return op->emitError() << "element " << i << " in '" << attrName
92                              << "' attribute must be equal to 0 or 1";
93 
94     segmentSizes.push_back(segmentSize);
95   }
96 
97   // Check that the sum of the segment sizes is equal to the number of elements.
98   int32_t sum = 0;
99   for (int32_t segmentSize : denseSegmentSizes.asArrayRef())
100     sum += segmentSize;
101   if (sum != static_cast<int32_t>(numElements))
102     return op->emitError() << "sum of elements in '" << attrName
103                            << "' attribute must be equal to the number of "
104                            << elemName << "s";
105 
106   return success();
107 }
108 
109 /// Compute the segment sizes of the given element (operands, results).
110 /// If the operation has more than two non-single elements (optional or
111 /// variadic), then get the segment sizes from the attribute dictionary.
112 /// Otherwise, compute the segment sizes from the number of elements.
113 /// `elemName` should be either `"operand"` or `"result"`.
getSegmentSizes(Operation * op,StringRef elemName,StringRef attrName,unsigned numElements,ArrayRef<Variadicity> variadicities,SmallVectorImpl<int> & segmentSizes)114 LogicalResult getSegmentSizes(Operation *op, StringRef elemName,
115                               StringRef attrName, unsigned numElements,
116                               ArrayRef<Variadicity> variadicities,
117                               SmallVectorImpl<int> &segmentSizes) {
118   // If we have more than one non-single variadicity, we need to get the
119   // segment sizes from the attribute dictionary.
120   int numberNonSingle = count_if(
121       variadicities, [](Variadicity v) { return v != Variadicity::single; });
122   if (numberNonSingle > 1)
123     return getSegmentSizesFromAttr(op, elemName, attrName, numElements,
124                                    variadicities, segmentSizes);
125 
126   // If we only have single variadicities, the segments sizes are all 1.
127   if (numberNonSingle == 0) {
128     if (numElements != variadicities.size()) {
129       return op->emitError() << "op expects exactly " << variadicities.size()
130                              << " " << elemName << "s, but got " << numElements;
131     }
132     for (size_t i = 0, e = variadicities.size(); i < e; ++i)
133       segmentSizes.push_back(1);
134     return success();
135   }
136 
137   assert(numberNonSingle == 1);
138 
139   // There is exactly one non-single element, so we can
140   // compute its size and check that it is valid.
141   int nonSingleSegmentSize = static_cast<int>(numElements) -
142                              static_cast<int>(variadicities.size()) + 1;
143 
144   if (nonSingleSegmentSize < 0) {
145     return op->emitError() << "op expects at least " << variadicities.size() - 1
146                            << " " << elemName << "s, but got " << numElements;
147   }
148 
149   // Add the segment sizes.
150   for (Variadicity variadicity : variadicities) {
151     if (variadicity == Variadicity::single) {
152       segmentSizes.push_back(1);
153       continue;
154     }
155 
156     // If we have an optional element, we should check that it represents
157     // zero or one elements.
158     if (nonSingleSegmentSize > 1 && variadicity == Variadicity::optional)
159       return op->emitError() << "op expects at most " << variadicities.size()
160                              << " " << elemName << "s, but got " << numElements;
161 
162     segmentSizes.push_back(nonSingleSegmentSize);
163   }
164 
165   return success();
166 }
167 
168 /// Compute the segment sizes of the given operands.
169 /// If the operation has more than two non-single operands (optional or
170 /// variadic), then get the segment sizes from the attribute dictionary.
171 /// Otherwise, compute the segment sizes from the number of operands.
getOperandSegmentSizes(Operation * op,ArrayRef<Variadicity> variadicities,SmallVectorImpl<int> & segmentSizes)172 LogicalResult getOperandSegmentSizes(Operation *op,
173                                      ArrayRef<Variadicity> variadicities,
174                                      SmallVectorImpl<int> &segmentSizes) {
175   return getSegmentSizes(op, "operand", "operand_segment_sizes",
176                          op->getNumOperands(), variadicities, segmentSizes);
177 }
178 
179 /// Compute the segment sizes of the given results.
180 /// If the operation has more than two non-single results (optional or
181 /// variadic), then get the segment sizes from the attribute dictionary.
182 /// Otherwise, compute the segment sizes from the number of results.
getResultSegmentSizes(Operation * op,ArrayRef<Variadicity> variadicities,SmallVectorImpl<int> & segmentSizes)183 LogicalResult getResultSegmentSizes(Operation *op,
184                                     ArrayRef<Variadicity> variadicities,
185                                     SmallVectorImpl<int> &segmentSizes) {
186   return getSegmentSizes(op, "result", "result_segment_sizes",
187                          op->getNumResults(), variadicities, segmentSizes);
188 }
189 
190 /// Verify that the given operation satisfies the given constraints.
191 /// This encodes the logic of the verification method for operations defined
192 /// with IRDL.
irdlOpVerifier(Operation * op,ConstraintVerifier & verifier,ArrayRef<size_t> operandConstrs,ArrayRef<Variadicity> operandVariadicity,ArrayRef<size_t> resultConstrs,ArrayRef<Variadicity> resultVariadicity,const DenseMap<StringAttr,size_t> & attributeConstrs)193 static LogicalResult irdlOpVerifier(
194     Operation *op, ConstraintVerifier &verifier,
195     ArrayRef<size_t> operandConstrs, ArrayRef<Variadicity> operandVariadicity,
196     ArrayRef<size_t> resultConstrs, ArrayRef<Variadicity> resultVariadicity,
197     const DenseMap<StringAttr, size_t> &attributeConstrs) {
198   // Get the segment sizes for the operands.
199   // This will check that the number of operands is correct.
200   SmallVector<int> operandSegmentSizes;
201   if (failed(
202           getOperandSegmentSizes(op, operandVariadicity, operandSegmentSizes)))
203     return failure();
204 
205   // Get the segment sizes for the results.
206   // This will check that the number of results is correct.
207   SmallVector<int> resultSegmentSizes;
208   if (failed(getResultSegmentSizes(op, resultVariadicity, resultSegmentSizes)))
209     return failure();
210 
211   auto emitError = [op] { return op->emitError(); };
212 
213   /// Сheck that we have all needed attributes passed
214   /// and they satisfy the constraints.
215   DictionaryAttr actualAttrs = op->getAttrDictionary();
216 
217   for (auto [name, constraint] : attributeConstrs) {
218     /// First, check if the attribute actually passed.
219     std::optional<NamedAttribute> actual = actualAttrs.getNamed(name);
220     if (!actual.has_value())
221       return op->emitOpError()
222              << "attribute " << name << " is expected but not provided";
223 
224     /// Then, check if the attribute value satisfies the constraint.
225     if (failed(verifier.verify({emitError}, actual->getValue(), constraint)))
226       return failure();
227   }
228 
229   // Check that all operands satisfy the constraints
230   int operandIdx = 0;
231   for (auto [defIndex, segmentSize] : enumerate(operandSegmentSizes)) {
232     for (int i = 0; i < segmentSize; i++) {
233       if (failed(verifier.verify(
234               {emitError}, TypeAttr::get(op->getOperandTypes()[operandIdx]),
235               operandConstrs[defIndex])))
236         return failure();
237       ++operandIdx;
238     }
239   }
240 
241   // Check that all results satisfy the constraints
242   int resultIdx = 0;
243   for (auto [defIndex, segmentSize] : enumerate(resultSegmentSizes)) {
244     for (int i = 0; i < segmentSize; i++) {
245       if (failed(verifier.verify({emitError},
246                                  TypeAttr::get(op->getResultTypes()[resultIdx]),
247                                  resultConstrs[defIndex])))
248         return failure();
249       ++resultIdx;
250     }
251   }
252 
253   return success();
254 }
255 
irdlRegionVerifier(Operation * op,ConstraintVerifier & verifier,ArrayRef<std::unique_ptr<RegionConstraint>> regionsConstraints)256 static LogicalResult irdlRegionVerifier(
257     Operation *op, ConstraintVerifier &verifier,
258     ArrayRef<std::unique_ptr<RegionConstraint>> regionsConstraints) {
259   if (op->getNumRegions() != regionsConstraints.size()) {
260     return op->emitOpError()
261            << "unexpected number of regions: expected "
262            << regionsConstraints.size() << " but got " << op->getNumRegions();
263   }
264 
265   for (auto [constraint, region] :
266        llvm::zip(regionsConstraints, op->getRegions()))
267     if (failed(constraint->verify(region, verifier)))
268       return failure();
269 
270   return success();
271 }
272 
273 llvm::unique_function<LogicalResult(Operation *) const>
createVerifier(OperationOp op,const DenseMap<irdl::TypeOp,std::unique_ptr<DynamicTypeDefinition>> & types,const DenseMap<irdl::AttributeOp,std::unique_ptr<DynamicAttrDefinition>> & attrs)274 mlir::irdl::createVerifier(
275     OperationOp op,
276     const DenseMap<irdl::TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
277     const DenseMap<irdl::AttributeOp, std::unique_ptr<DynamicAttrDefinition>>
278         &attrs) {
279   // Resolve SSA values to verifier constraint slots
280   SmallVector<Value> constrToValue;
281   SmallVector<Value> regionToValue;
282   for (Operation &op : op->getRegion(0).getOps()) {
283     if (isa<VerifyConstraintInterface>(op)) {
284       if (op.getNumResults() != 1) {
285         op.emitError()
286             << "IRDL constraint operations must have exactly one result";
287         return nullptr;
288       }
289       constrToValue.push_back(op.getResult(0));
290     }
291     if (isa<VerifyRegionInterface>(op)) {
292       if (op.getNumResults() != 1) {
293         op.emitError()
294             << "IRDL constraint operations must have exactly one result";
295         return nullptr;
296       }
297       regionToValue.push_back(op.getResult(0));
298     }
299   }
300 
301   // Build the verifiers for each constraint slot
302   SmallVector<std::unique_ptr<Constraint>> constraints;
303   for (Value v : constrToValue) {
304     VerifyConstraintInterface op =
305         cast<VerifyConstraintInterface>(v.getDefiningOp());
306     std::unique_ptr<Constraint> verifier =
307         op.getVerifier(constrToValue, types, attrs);
308     if (!verifier)
309       return nullptr;
310     constraints.push_back(std::move(verifier));
311   }
312 
313   // Build region constraints
314   SmallVector<std::unique_ptr<RegionConstraint>> regionConstraints;
315   for (Value v : regionToValue) {
316     VerifyRegionInterface op = cast<VerifyRegionInterface>(v.getDefiningOp());
317     std::unique_ptr<RegionConstraint> verifier =
318         op.getVerifier(constrToValue, types, attrs);
319     regionConstraints.push_back(std::move(verifier));
320   }
321 
322   SmallVector<size_t> operandConstraints;
323   SmallVector<Variadicity> operandVariadicity;
324 
325   // Gather which constraint slots correspond to operand constraints
326   auto operandsOp = op.getOp<OperandsOp>();
327   if (operandsOp.has_value()) {
328     operandConstraints.reserve(operandsOp->getArgs().size());
329     for (Value operand : operandsOp->getArgs()) {
330       for (auto [i, constr] : enumerate(constrToValue)) {
331         if (constr == operand) {
332           operandConstraints.push_back(i);
333           break;
334         }
335       }
336     }
337 
338     // Gather the variadicities of each operand
339     for (VariadicityAttr attr : operandsOp->getVariadicity())
340       operandVariadicity.push_back(attr.getValue());
341   }
342 
343   SmallVector<size_t> resultConstraints;
344   SmallVector<Variadicity> resultVariadicity;
345 
346   // Gather which constraint slots correspond to result constraints
347   auto resultsOp = op.getOp<ResultsOp>();
348   if (resultsOp.has_value()) {
349     resultConstraints.reserve(resultsOp->getArgs().size());
350     for (Value result : resultsOp->getArgs()) {
351       for (auto [i, constr] : enumerate(constrToValue)) {
352         if (constr == result) {
353           resultConstraints.push_back(i);
354           break;
355         }
356       }
357     }
358 
359     // Gather the variadicities of each result
360     for (Attribute attr : resultsOp->getVariadicity())
361       resultVariadicity.push_back(cast<VariadicityAttr>(attr).getValue());
362   }
363 
364   // Gather which constraint slots correspond to attributes constraints
365   DenseMap<StringAttr, size_t> attributeConstraints;
366   auto attributesOp = op.getOp<AttributesOp>();
367   if (attributesOp.has_value()) {
368     const Operation::operand_range values = attributesOp->getAttributeValues();
369     const ArrayAttr names = attributesOp->getAttributeValueNames();
370 
371     for (const auto &[name, value] : llvm::zip(names, values)) {
372       for (auto [i, constr] : enumerate(constrToValue)) {
373         if (constr == value) {
374           attributeConstraints[cast<StringAttr>(name)] = i;
375           break;
376         }
377       }
378     }
379   }
380 
381   return
382       [constraints{std::move(constraints)},
383        regionConstraints{std::move(regionConstraints)},
384        operandConstraints{std::move(operandConstraints)},
385        operandVariadicity{std::move(operandVariadicity)},
386        resultConstraints{std::move(resultConstraints)},
387        resultVariadicity{std::move(resultVariadicity)},
388        attributeConstraints{std::move(attributeConstraints)}](Operation *op) {
389         ConstraintVerifier verifier(constraints);
390         const LogicalResult opVerifierResult = irdlOpVerifier(
391             op, verifier, operandConstraints, operandVariadicity,
392             resultConstraints, resultVariadicity, attributeConstraints);
393         const LogicalResult opRegionVerifierResult =
394             irdlRegionVerifier(op, verifier, regionConstraints);
395         return LogicalResult::success(opVerifierResult.succeeded() &&
396                                       opRegionVerifierResult.succeeded());
397       };
398 }
399 
400 /// Define and load an operation represented by a `irdl.operation`
401 /// operation.
loadOperation(OperationOp op,ExtensibleDialect * dialect,const DenseMap<TypeOp,std::unique_ptr<DynamicTypeDefinition>> & types,const DenseMap<AttributeOp,std::unique_ptr<DynamicAttrDefinition>> & attrs)402 static WalkResult loadOperation(
403     OperationOp op, ExtensibleDialect *dialect,
404     const DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
405     const DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>>
406         &attrs) {
407 
408   // IRDL does not support defining custom parsers or printers.
409   auto parser = [](OpAsmParser &parser, OperationState &result) {
410     return failure();
411   };
412   auto printer = [](Operation *op, OpAsmPrinter &printer, StringRef) {
413     printer.printGenericOp(op);
414   };
415 
416   auto verifier = createVerifier(op, types, attrs);
417   if (!verifier)
418     return WalkResult::interrupt();
419 
420   // IRDL supports only checking number of blocks and argument constraints
421   // It is done in the main verifier to reuse `ConstraintVerifier` context
422   auto regionVerifier = [](Operation *op) { return LogicalResult::success(); };
423 
424   auto opDef = DynamicOpDefinition::get(
425       op.getName(), dialect, std::move(verifier), std::move(regionVerifier),
426       std::move(parser), std::move(printer));
427   dialect->registerDynamicOp(std::move(opDef));
428 
429   return WalkResult::advance();
430 }
431 
432 /// Get the verifier of a type or attribute definition.
433 /// Return nullptr if the definition is invalid.
getAttrOrTypeVerifier(Operation * attrOrTypeDef,ExtensibleDialect * dialect,DenseMap<TypeOp,std::unique_ptr<DynamicTypeDefinition>> & types,DenseMap<AttributeOp,std::unique_ptr<DynamicAttrDefinition>> & attrs)434 static DynamicAttrDefinition::VerifierFn getAttrOrTypeVerifier(
435     Operation *attrOrTypeDef, ExtensibleDialect *dialect,
436     DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
437     DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
438   assert((isa<AttributeOp>(attrOrTypeDef) || isa<TypeOp>(attrOrTypeDef)) &&
439          "Expected an attribute or type definition");
440 
441   // Resolve SSA values to verifier constraint slots
442   SmallVector<Value> constrToValue;
443   for (Operation &op : attrOrTypeDef->getRegion(0).getOps()) {
444     if (isa<VerifyConstraintInterface>(op)) {
445       assert(op.getNumResults() == 1 &&
446              "IRDL constraint operations must have exactly one result");
447       constrToValue.push_back(op.getResult(0));
448     }
449   }
450 
451   // Build the verifiers for each constraint slot
452   SmallVector<std::unique_ptr<Constraint>> constraints;
453   for (Value v : constrToValue) {
454     VerifyConstraintInterface op =
455         cast<VerifyConstraintInterface>(v.getDefiningOp());
456     std::unique_ptr<Constraint> verifier =
457         op.getVerifier(constrToValue, types, attrs);
458     if (!verifier)
459       return {};
460     constraints.push_back(std::move(verifier));
461   }
462 
463   // Get the parameter definitions.
464   std::optional<ParametersOp> params;
465   if (auto attr = dyn_cast<AttributeOp>(attrOrTypeDef))
466     params = attr.getOp<ParametersOp>();
467   else if (auto type = dyn_cast<TypeOp>(attrOrTypeDef))
468     params = type.getOp<ParametersOp>();
469 
470   // Gather which constraint slots correspond to parameter constraints
471   SmallVector<size_t> paramConstraints;
472   if (params.has_value()) {
473     paramConstraints.reserve(params->getArgs().size());
474     for (Value param : params->getArgs()) {
475       for (auto [i, constr] : enumerate(constrToValue)) {
476         if (constr == param) {
477           paramConstraints.push_back(i);
478           break;
479         }
480       }
481     }
482   }
483 
484   auto verifier = [paramConstraints{std::move(paramConstraints)},
485                    constraints{std::move(constraints)}](
486                       function_ref<InFlightDiagnostic()> emitError,
487                       ArrayRef<Attribute> params) {
488     return irdlAttrOrTypeVerifier(emitError, params, constraints,
489                                   paramConstraints);
490   };
491 
492   // While the `std::move` is not required, not adding it triggers a bug in
493   // clang-10.
494   return std::move(verifier);
495 }
496 
497 /// Get the possible bases of a constraint. Return `true` if all bases can
498 /// potentially be matched.
499 /// A base is a type or an attribute definition. For instance, the base of
500 /// `irdl.parametric "!builtin.complex"(...)` is `builtin.complex`.
501 /// This function returns the following information through arguments:
502 /// - `paramIds`: the set of type or attribute IDs that are used as bases.
503 /// - `paramIrdlOps`: the set of IRDL operations that are used as bases.
504 /// - `isIds`: the set of type or attribute IDs that are used in `irdl.is`
505 ///   constraints.
getBases(Operation * op,SmallPtrSet<TypeID,4> & paramIds,SmallPtrSet<Operation *,4> & paramIrdlOps,SmallPtrSet<TypeID,4> & isIds)506 static bool getBases(Operation *op, SmallPtrSet<TypeID, 4> &paramIds,
507                      SmallPtrSet<Operation *, 4> &paramIrdlOps,
508                      SmallPtrSet<TypeID, 4> &isIds) {
509   // For `irdl.any_of`, we get the bases from all its arguments.
510   if (auto anyOf = dyn_cast<AnyOfOp>(op)) {
511     bool hasAny = false;
512     for (Value arg : anyOf.getArgs())
513       hasAny &= getBases(arg.getDefiningOp(), paramIds, paramIrdlOps, isIds);
514     return hasAny;
515   }
516 
517   // For `irdl.all_of`, we get the bases from the first argument.
518   // This is restrictive, but we can relax it later if needed.
519   if (auto allOf = dyn_cast<AllOfOp>(op))
520     return getBases(allOf.getArgs()[0].getDefiningOp(), paramIds, paramIrdlOps,
521                     isIds);
522 
523   // For `irdl.parametric`, we get directly the base from the operation.
524   if (auto params = dyn_cast<ParametricOp>(op)) {
525     SymbolRefAttr symRef = params.getBaseType();
526     Operation *defOp = irdl::lookupSymbolNearDialect(op, symRef);
527     assert(defOp && "symbol reference should refer to an existing operation");
528     paramIrdlOps.insert(defOp);
529     return false;
530   }
531 
532   // For `irdl.is`, we get the base TypeID directly.
533   if (auto is = dyn_cast<IsOp>(op)) {
534     Attribute expected = is.getExpected();
535     isIds.insert(expected.getTypeID());
536     return false;
537   }
538 
539   // For `irdl.any`, we return `false` since we can match any type or attribute
540   // base.
541   if (auto isA = dyn_cast<AnyOp>(op))
542     return true;
543 
544   llvm_unreachable("unknown IRDL constraint");
545 }
546 
547 /// Check that an any_of is in the subset IRDL can handle.
548 /// IRDL uses a greedy algorithm to match constraints. This means that if we
549 /// encounter an `any_of` with multiple constraints, we will match the first
550 /// constraint that is satisfied. Thus, the order of constraints matter in
551 /// `any_of` with our current algorithm.
552 /// In order to make the order of constraints irrelevant, we require that
553 /// all `any_of` constraint parameters are disjoint. For this, we check that
554 /// the base parameters are all disjoints between `parametric` operations, and
555 /// that they are disjoint between `parametric` and `is` operations.
556 /// This restriction will be relaxed in the future, when we will change our
557 /// algorithm to be non-greedy.
checkCorrectAnyOf(AnyOfOp anyOf)558 static LogicalResult checkCorrectAnyOf(AnyOfOp anyOf) {
559   SmallPtrSet<TypeID, 4> paramIds;
560   SmallPtrSet<Operation *, 4> paramIrdlOps;
561   SmallPtrSet<TypeID, 4> isIds;
562 
563   for (Value arg : anyOf.getArgs()) {
564     Operation *argOp = arg.getDefiningOp();
565     SmallPtrSet<TypeID, 4> argParamIds;
566     SmallPtrSet<Operation *, 4> argParamIrdlOps;
567     SmallPtrSet<TypeID, 4> argIsIds;
568 
569     // Get the bases of this argument. If it can match any type or attribute,
570     // then our `any_of` should not be allowed.
571     if (getBases(argOp, argParamIds, argParamIrdlOps, argIsIds))
572       return failure();
573 
574     // We check that the base parameters are all disjoints between `parametric`
575     // operations, and that they are disjoint between `parametric` and `is`
576     // operations.
577     for (TypeID id : argParamIds) {
578       if (isIds.count(id))
579         return failure();
580       bool inserted = paramIds.insert(id).second;
581       if (!inserted)
582         return failure();
583     }
584 
585     // We check that the base parameters are all disjoints with `irdl.is`
586     // operations.
587     for (TypeID id : isIds) {
588       if (paramIds.count(id))
589         return failure();
590       isIds.insert(id);
591     }
592 
593     // We check that all `parametric` operations are disjoint. We do not
594     // need to check that they are disjoint with `is` operations, since
595     // `is` operations cannot refer to attributes defined with `irdl.parametric`
596     // operations.
597     for (Operation *op : argParamIrdlOps) {
598       bool inserted = paramIrdlOps.insert(op).second;
599       if (!inserted)
600         return failure();
601     }
602   }
603 
604   return success();
605 }
606 
607 /// Load all dialects in the given module, without loading any operation, type
608 /// or attribute definitions.
loadEmptyDialects(ModuleOp op)609 static DenseMap<DialectOp, ExtensibleDialect *> loadEmptyDialects(ModuleOp op) {
610   DenseMap<DialectOp, ExtensibleDialect *> dialects;
611   op.walk([&](DialectOp dialectOp) {
612     MLIRContext *ctx = dialectOp.getContext();
613     StringRef dialectName = dialectOp.getName();
614 
615     DynamicDialect *dialect = ctx->getOrLoadDynamicDialect(
616         dialectName, [](DynamicDialect *dialect) {});
617 
618     dialects.insert({dialectOp, dialect});
619   });
620   return dialects;
621 }
622 
623 /// Preallocate type definitions objects with empty verifiers.
624 /// This in particular allocates a TypeID for each type definition.
625 static DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>>
preallocateTypeDefs(ModuleOp op,DenseMap<DialectOp,ExtensibleDialect * > dialects)626 preallocateTypeDefs(ModuleOp op,
627                     DenseMap<DialectOp, ExtensibleDialect *> dialects) {
628   DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> typeDefs;
629   op.walk([&](TypeOp typeOp) {
630     ExtensibleDialect *dialect = dialects[typeOp.getParentOp()];
631     auto typeDef = DynamicTypeDefinition::get(
632         typeOp.getName(), dialect,
633         [](function_ref<InFlightDiagnostic()>, ArrayRef<Attribute>) {
634           return success();
635         });
636     typeDefs.try_emplace(typeOp, std::move(typeDef));
637   });
638   return typeDefs;
639 }
640 
641 /// Preallocate attribute definitions objects with empty verifiers.
642 /// This in particular allocates a TypeID for each attribute definition.
643 static DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>>
preallocateAttrDefs(ModuleOp op,DenseMap<DialectOp,ExtensibleDialect * > dialects)644 preallocateAttrDefs(ModuleOp op,
645                     DenseMap<DialectOp, ExtensibleDialect *> dialects) {
646   DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> attrDefs;
647   op.walk([&](AttributeOp attrOp) {
648     ExtensibleDialect *dialect = dialects[attrOp.getParentOp()];
649     auto attrDef = DynamicAttrDefinition::get(
650         attrOp.getName(), dialect,
651         [](function_ref<InFlightDiagnostic()>, ArrayRef<Attribute>) {
652           return success();
653         });
654     attrDefs.try_emplace(attrOp, std::move(attrDef));
655   });
656   return attrDefs;
657 }
658 
loadDialects(ModuleOp op)659 LogicalResult mlir::irdl::loadDialects(ModuleOp op) {
660   // First, check that all any_of constraints are in a correct form.
661   // This is to ensure we can do the verification correctly.
662   WalkResult anyOfCorrects = op.walk(
663       [](AnyOfOp anyOf) { return (WalkResult)checkCorrectAnyOf(anyOf); });
664   if (anyOfCorrects.wasInterrupted())
665     return op.emitError("any_of constraints are not in the correct form");
666 
667   // Preallocate all dialects, and type and attribute definitions.
668   // In particular, this allocates TypeIDs so type and attributes can have
669   // verifiers that refer to each other.
670   DenseMap<DialectOp, ExtensibleDialect *> dialects = loadEmptyDialects(op);
671   DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> types =
672       preallocateTypeDefs(op, dialects);
673   DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> attrs =
674       preallocateAttrDefs(op, dialects);
675 
676   // Set the verifier for types.
677   WalkResult res = op.walk([&](TypeOp typeOp) {
678     DynamicAttrDefinition::VerifierFn verifier = getAttrOrTypeVerifier(
679         typeOp, dialects[typeOp.getParentOp()], types, attrs);
680     if (!verifier)
681       return WalkResult::interrupt();
682     types[typeOp]->setVerifyFn(std::move(verifier));
683     return WalkResult::advance();
684   });
685   if (res.wasInterrupted())
686     return failure();
687 
688   // Set the verifier for attributes.
689   res = op.walk([&](AttributeOp attrOp) {
690     DynamicAttrDefinition::VerifierFn verifier = getAttrOrTypeVerifier(
691         attrOp, dialects[attrOp.getParentOp()], types, attrs);
692     if (!verifier)
693       return WalkResult::interrupt();
694     attrs[attrOp]->setVerifyFn(std::move(verifier));
695     return WalkResult::advance();
696   });
697   if (res.wasInterrupted())
698     return failure();
699 
700   // Define and load all operations.
701   res = op.walk([&](OperationOp opOp) {
702     return loadOperation(opOp, dialects[opOp.getParentOp()], types, attrs);
703   });
704   if (res.wasInterrupted())
705     return failure();
706 
707   // Load all types in their dialects.
708   for (auto &pair : types) {
709     ExtensibleDialect *dialect = dialects[pair.first.getParentOp()];
710     dialect->registerDynamicType(std::move(pair.second));
711   }
712 
713   // Load all attributes in their dialects.
714   for (auto &pair : attrs) {
715     ExtensibleDialect *dialect = dialects[pair.first.getParentOp()];
716     dialect->registerDynamicAttr(std::move(pair.second));
717   }
718 
719   return success();
720 }
721