xref: /llvm-project/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
1 //===- IRDLVerifiers.cpp - IRDL verifiers ------------------------- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Verifiers for objects declared by IRDL.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/IRDL/IRDLVerifiers.h"
14 #include "mlir/IR/Attributes.h"
15 #include "mlir/IR/Block.h"
16 #include "mlir/IR/BuiltinAttributes.h"
17 #include "mlir/IR/Diagnostics.h"
18 #include "mlir/IR/ExtensibleDialect.h"
19 #include "mlir/IR/Location.h"
20 #include "mlir/IR/Region.h"
21 #include "mlir/IR/Value.h"
22 #include "llvm/Support/FormatVariadic.h"
23 
24 using namespace mlir;
25 using namespace mlir::irdl;
26 
ConstraintVerifier(ArrayRef<std::unique_ptr<Constraint>> constraints)27 ConstraintVerifier::ConstraintVerifier(
28     ArrayRef<std::unique_ptr<Constraint>> constraints)
29     : constraints(constraints), assigned() {
30   assigned.resize(this->constraints.size());
31 }
32 
33 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Attribute attr,unsigned variable)34 ConstraintVerifier::verify(function_ref<InFlightDiagnostic()> emitError,
35                            Attribute attr, unsigned variable) {
36 
37   assert(variable < constraints.size() && "invalid constraint variable");
38 
39   // If the variable is already assigned, check that the attribute is the same.
40   if (assigned[variable].has_value()) {
41     if (attr == assigned[variable].value()) {
42       return success();
43     }
44     if (emitError)
45       return emitError() << "expected '" << assigned[variable].value()
46                          << "' but got '" << attr << "'";
47     return failure();
48   }
49 
50   // Otherwise, check the constraint and assign the attribute to the variable.
51   LogicalResult result = constraints[variable]->verify(emitError, attr, *this);
52   if (succeeded(result))
53     assigned[variable] = attr;
54 
55   return result;
56 }
57 
verify(function_ref<InFlightDiagnostic ()> emitError,Attribute attr,ConstraintVerifier & context) const58 LogicalResult IsConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
59                                    Attribute attr,
60                                    ConstraintVerifier &context) const {
61   if (attr == expectedAttribute)
62     return success();
63 
64   if (emitError)
65     return emitError() << "expected '" << expectedAttribute << "' but got '"
66                        << attr << "'";
67   return failure();
68 }
69 
70 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Attribute attr,ConstraintVerifier & context) const71 BaseAttrConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
72                            Attribute attr, ConstraintVerifier &context) const {
73   if (attr.getTypeID() == baseTypeID)
74     return success();
75 
76   if (emitError)
77     return emitError() << "expected base attribute '" << baseName
78                        << "' but got '" << attr.getAbstractAttribute().getName()
79                        << "'";
80   return failure();
81 }
82 
83 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Attribute attr,ConstraintVerifier & context) const84 BaseTypeConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
85                            Attribute attr, ConstraintVerifier &context) const {
86   auto typeAttr = dyn_cast<TypeAttr>(attr);
87   if (!typeAttr) {
88     if (emitError)
89       return emitError() << "expected type, got attribute '" << attr;
90     return failure();
91   }
92 
93   Type type = typeAttr.getValue();
94   if (type.getTypeID() == baseTypeID)
95     return success();
96 
97   if (emitError)
98     return emitError() << "expected base type '" << baseName << "' but got '"
99                        << type.getAbstractType().getName() << "'";
100   return failure();
101 }
102 
verify(function_ref<InFlightDiagnostic ()> emitError,Attribute attr,ConstraintVerifier & context) const103 LogicalResult DynParametricAttrConstraint::verify(
104     function_ref<InFlightDiagnostic()> emitError, Attribute attr,
105     ConstraintVerifier &context) const {
106 
107   // Check that the base is the expected one.
108   auto dynAttr = dyn_cast<DynamicAttr>(attr);
109   if (!dynAttr || dynAttr.getAttrDef() != attrDef) {
110     if (emitError) {
111       StringRef dialectName = attrDef->getDialect()->getNamespace();
112       StringRef attrName = attrDef->getName();
113       return emitError() << "expected base attribute '" << attrName << '.'
114                          << dialectName << "' but got '" << attr << "'";
115     }
116     return failure();
117   }
118 
119   // Check that the parameters satisfy the constraints.
120   ArrayRef<Attribute> params = dynAttr.getParams();
121   if (params.size() != constraints.size()) {
122     if (emitError) {
123       StringRef dialectName = attrDef->getDialect()->getNamespace();
124       StringRef attrName = attrDef->getName();
125       emitError() << "attribute '" << dialectName << "." << attrName
126                   << "' expects " << params.size() << " parameters but got "
127                   << constraints.size();
128     }
129     return failure();
130   }
131 
132   for (size_t i = 0, s = params.size(); i < s; i++)
133     if (failed(context.verify(emitError, params[i], constraints[i])))
134       return failure();
135 
136   return success();
137 }
138 
verify(function_ref<InFlightDiagnostic ()> emitError,Attribute attr,ConstraintVerifier & context) const139 LogicalResult DynParametricTypeConstraint::verify(
140     function_ref<InFlightDiagnostic()> emitError, Attribute attr,
141     ConstraintVerifier &context) const {
142   // Check that the base is a TypeAttr.
143   auto typeAttr = dyn_cast<TypeAttr>(attr);
144   if (!typeAttr) {
145     if (emitError)
146       return emitError() << "expected type, got attribute '" << attr;
147     return failure();
148   }
149 
150   // Check that the type base is the expected one.
151   auto dynType = dyn_cast<DynamicType>(typeAttr.getValue());
152   if (!dynType || dynType.getTypeDef() != typeDef) {
153     if (emitError) {
154       StringRef dialectName = typeDef->getDialect()->getNamespace();
155       StringRef attrName = typeDef->getName();
156       return emitError() << "expected base type '" << dialectName << '.'
157                          << attrName << "' but got '" << attr << "'";
158     }
159     return failure();
160   }
161 
162   // Check that the parameters satisfy the constraints.
163   ArrayRef<Attribute> params = dynType.getParams();
164   if (params.size() != constraints.size()) {
165     if (emitError) {
166       StringRef dialectName = typeDef->getDialect()->getNamespace();
167       StringRef attrName = typeDef->getName();
168       emitError() << "attribute '" << dialectName << "." << attrName
169                   << "' expects " << params.size() << " parameters but got "
170                   << constraints.size();
171     }
172     return failure();
173   }
174 
175   for (size_t i = 0, s = params.size(); i < s; i++)
176     if (failed(context.verify(emitError, params[i], constraints[i])))
177       return failure();
178 
179   return success();
180 }
181 
182 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Attribute attr,ConstraintVerifier & context) const183 AnyOfConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
184                         Attribute attr, ConstraintVerifier &context) const {
185   for (unsigned constr : constraints) {
186     // We do not pass the `emitError` here, since we want to emit an error
187     // only if none of the constraints are satisfied.
188     if (succeeded(context.verify({}, attr, constr))) {
189       return success();
190     }
191   }
192 
193   if (emitError)
194     return emitError() << "'" << attr << "' does not satisfy the constraint";
195   return failure();
196 }
197 
198 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Attribute attr,ConstraintVerifier & context) const199 AllOfConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
200                         Attribute attr, ConstraintVerifier &context) const {
201   for (unsigned constr : constraints) {
202     if (failed(context.verify(emitError, attr, constr))) {
203       return failure();
204     }
205   }
206 
207   return success();
208 }
209 
210 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Attribute attr,ConstraintVerifier & context) const211 AnyAttributeConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
212                                Attribute attr,
213                                ConstraintVerifier &context) const {
214   return success();
215 }
216 
verify(mlir::Region & region,ConstraintVerifier & constraintContext)217 LogicalResult RegionConstraint::verify(mlir::Region &region,
218                                        ConstraintVerifier &constraintContext) {
219   const auto emitError = [parentOp = region.getParentOp()](mlir::Location loc) {
220     return [loc, parentOp] {
221       InFlightDiagnostic diag = mlir::emitError(loc);
222       // If we already have been given location of the parent operation, which
223       // might happen when the region location is passed, we do not want to
224       // produce the note on the same location
225       if (loc != parentOp->getLoc())
226         diag.attachNote(parentOp->getLoc()).append("see the operation");
227       return diag;
228     };
229   };
230 
231   if (blockCount.has_value() && *blockCount != region.getBlocks().size()) {
232     return emitError(region.getLoc())()
233            << "expected region " << region.getRegionNumber() << " to have "
234            << *blockCount << " block(s) but got " << region.getBlocks().size();
235   }
236 
237   if (argumentConstraints.has_value()) {
238     auto actualArgs = region.getArguments();
239     if (actualArgs.size() != argumentConstraints->size()) {
240       const mlir::Location firstArgLoc =
241           actualArgs.empty() ? region.getLoc() : actualArgs.front().getLoc();
242       return emitError(firstArgLoc)()
243              << "expected region " << region.getRegionNumber() << " to have "
244              << argumentConstraints->size() << " arguments but got "
245              << actualArgs.size();
246     }
247 
248     for (auto [arg, constraint] : llvm::zip(actualArgs, *argumentConstraints)) {
249       mlir::Attribute type = TypeAttr::get(arg.getType());
250       if (failed(constraintContext.verify(emitError(arg.getLoc()), type,
251                                           constraint))) {
252         return failure();
253       }
254     }
255   }
256   return success();
257 }
258