1 //===- IRDLVerifiers.cpp - IRDL verifiers ------------------------- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Verifiers for objects declared by IRDL.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/IRDL/IRDLVerifiers.h"
14 #include "mlir/IR/Attributes.h"
15 #include "mlir/IR/Block.h"
16 #include "mlir/IR/BuiltinAttributes.h"
17 #include "mlir/IR/Diagnostics.h"
18 #include "mlir/IR/ExtensibleDialect.h"
19 #include "mlir/IR/Location.h"
20 #include "mlir/IR/Region.h"
21 #include "mlir/IR/Value.h"
22 #include "llvm/Support/FormatVariadic.h"
23
24 using namespace mlir;
25 using namespace mlir::irdl;
26
ConstraintVerifier(ArrayRef<std::unique_ptr<Constraint>> constraints)27 ConstraintVerifier::ConstraintVerifier(
28 ArrayRef<std::unique_ptr<Constraint>> constraints)
29 : constraints(constraints), assigned() {
30 assigned.resize(this->constraints.size());
31 }
32
33 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Attribute attr,unsigned variable)34 ConstraintVerifier::verify(function_ref<InFlightDiagnostic()> emitError,
35 Attribute attr, unsigned variable) {
36
37 assert(variable < constraints.size() && "invalid constraint variable");
38
39 // If the variable is already assigned, check that the attribute is the same.
40 if (assigned[variable].has_value()) {
41 if (attr == assigned[variable].value()) {
42 return success();
43 }
44 if (emitError)
45 return emitError() << "expected '" << assigned[variable].value()
46 << "' but got '" << attr << "'";
47 return failure();
48 }
49
50 // Otherwise, check the constraint and assign the attribute to the variable.
51 LogicalResult result = constraints[variable]->verify(emitError, attr, *this);
52 if (succeeded(result))
53 assigned[variable] = attr;
54
55 return result;
56 }
57
verify(function_ref<InFlightDiagnostic ()> emitError,Attribute attr,ConstraintVerifier & context) const58 LogicalResult IsConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
59 Attribute attr,
60 ConstraintVerifier &context) const {
61 if (attr == expectedAttribute)
62 return success();
63
64 if (emitError)
65 return emitError() << "expected '" << expectedAttribute << "' but got '"
66 << attr << "'";
67 return failure();
68 }
69
70 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Attribute attr,ConstraintVerifier & context) const71 BaseAttrConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
72 Attribute attr, ConstraintVerifier &context) const {
73 if (attr.getTypeID() == baseTypeID)
74 return success();
75
76 if (emitError)
77 return emitError() << "expected base attribute '" << baseName
78 << "' but got '" << attr.getAbstractAttribute().getName()
79 << "'";
80 return failure();
81 }
82
83 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Attribute attr,ConstraintVerifier & context) const84 BaseTypeConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
85 Attribute attr, ConstraintVerifier &context) const {
86 auto typeAttr = dyn_cast<TypeAttr>(attr);
87 if (!typeAttr) {
88 if (emitError)
89 return emitError() << "expected type, got attribute '" << attr;
90 return failure();
91 }
92
93 Type type = typeAttr.getValue();
94 if (type.getTypeID() == baseTypeID)
95 return success();
96
97 if (emitError)
98 return emitError() << "expected base type '" << baseName << "' but got '"
99 << type.getAbstractType().getName() << "'";
100 return failure();
101 }
102
verify(function_ref<InFlightDiagnostic ()> emitError,Attribute attr,ConstraintVerifier & context) const103 LogicalResult DynParametricAttrConstraint::verify(
104 function_ref<InFlightDiagnostic()> emitError, Attribute attr,
105 ConstraintVerifier &context) const {
106
107 // Check that the base is the expected one.
108 auto dynAttr = dyn_cast<DynamicAttr>(attr);
109 if (!dynAttr || dynAttr.getAttrDef() != attrDef) {
110 if (emitError) {
111 StringRef dialectName = attrDef->getDialect()->getNamespace();
112 StringRef attrName = attrDef->getName();
113 return emitError() << "expected base attribute '" << attrName << '.'
114 << dialectName << "' but got '" << attr << "'";
115 }
116 return failure();
117 }
118
119 // Check that the parameters satisfy the constraints.
120 ArrayRef<Attribute> params = dynAttr.getParams();
121 if (params.size() != constraints.size()) {
122 if (emitError) {
123 StringRef dialectName = attrDef->getDialect()->getNamespace();
124 StringRef attrName = attrDef->getName();
125 emitError() << "attribute '" << dialectName << "." << attrName
126 << "' expects " << params.size() << " parameters but got "
127 << constraints.size();
128 }
129 return failure();
130 }
131
132 for (size_t i = 0, s = params.size(); i < s; i++)
133 if (failed(context.verify(emitError, params[i], constraints[i])))
134 return failure();
135
136 return success();
137 }
138
verify(function_ref<InFlightDiagnostic ()> emitError,Attribute attr,ConstraintVerifier & context) const139 LogicalResult DynParametricTypeConstraint::verify(
140 function_ref<InFlightDiagnostic()> emitError, Attribute attr,
141 ConstraintVerifier &context) const {
142 // Check that the base is a TypeAttr.
143 auto typeAttr = dyn_cast<TypeAttr>(attr);
144 if (!typeAttr) {
145 if (emitError)
146 return emitError() << "expected type, got attribute '" << attr;
147 return failure();
148 }
149
150 // Check that the type base is the expected one.
151 auto dynType = dyn_cast<DynamicType>(typeAttr.getValue());
152 if (!dynType || dynType.getTypeDef() != typeDef) {
153 if (emitError) {
154 StringRef dialectName = typeDef->getDialect()->getNamespace();
155 StringRef attrName = typeDef->getName();
156 return emitError() << "expected base type '" << dialectName << '.'
157 << attrName << "' but got '" << attr << "'";
158 }
159 return failure();
160 }
161
162 // Check that the parameters satisfy the constraints.
163 ArrayRef<Attribute> params = dynType.getParams();
164 if (params.size() != constraints.size()) {
165 if (emitError) {
166 StringRef dialectName = typeDef->getDialect()->getNamespace();
167 StringRef attrName = typeDef->getName();
168 emitError() << "attribute '" << dialectName << "." << attrName
169 << "' expects " << params.size() << " parameters but got "
170 << constraints.size();
171 }
172 return failure();
173 }
174
175 for (size_t i = 0, s = params.size(); i < s; i++)
176 if (failed(context.verify(emitError, params[i], constraints[i])))
177 return failure();
178
179 return success();
180 }
181
182 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Attribute attr,ConstraintVerifier & context) const183 AnyOfConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
184 Attribute attr, ConstraintVerifier &context) const {
185 for (unsigned constr : constraints) {
186 // We do not pass the `emitError` here, since we want to emit an error
187 // only if none of the constraints are satisfied.
188 if (succeeded(context.verify({}, attr, constr))) {
189 return success();
190 }
191 }
192
193 if (emitError)
194 return emitError() << "'" << attr << "' does not satisfy the constraint";
195 return failure();
196 }
197
198 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Attribute attr,ConstraintVerifier & context) const199 AllOfConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
200 Attribute attr, ConstraintVerifier &context) const {
201 for (unsigned constr : constraints) {
202 if (failed(context.verify(emitError, attr, constr))) {
203 return failure();
204 }
205 }
206
207 return success();
208 }
209
210 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Attribute attr,ConstraintVerifier & context) const211 AnyAttributeConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
212 Attribute attr,
213 ConstraintVerifier &context) const {
214 return success();
215 }
216
verify(mlir::Region & region,ConstraintVerifier & constraintContext)217 LogicalResult RegionConstraint::verify(mlir::Region ®ion,
218 ConstraintVerifier &constraintContext) {
219 const auto emitError = [parentOp = region.getParentOp()](mlir::Location loc) {
220 return [loc, parentOp] {
221 InFlightDiagnostic diag = mlir::emitError(loc);
222 // If we already have been given location of the parent operation, which
223 // might happen when the region location is passed, we do not want to
224 // produce the note on the same location
225 if (loc != parentOp->getLoc())
226 diag.attachNote(parentOp->getLoc()).append("see the operation");
227 return diag;
228 };
229 };
230
231 if (blockCount.has_value() && *blockCount != region.getBlocks().size()) {
232 return emitError(region.getLoc())()
233 << "expected region " << region.getRegionNumber() << " to have "
234 << *blockCount << " block(s) but got " << region.getBlocks().size();
235 }
236
237 if (argumentConstraints.has_value()) {
238 auto actualArgs = region.getArguments();
239 if (actualArgs.size() != argumentConstraints->size()) {
240 const mlir::Location firstArgLoc =
241 actualArgs.empty() ? region.getLoc() : actualArgs.front().getLoc();
242 return emitError(firstArgLoc)()
243 << "expected region " << region.getRegionNumber() << " to have "
244 << argumentConstraints->size() << " arguments but got "
245 << actualArgs.size();
246 }
247
248 for (auto [arg, constraint] : llvm::zip(actualArgs, *argumentConstraints)) {
249 mlir::Attribute type = TypeAttr::get(arg.getType());
250 if (failed(constraintContext.verify(emitError(arg.getLoc()), type,
251 constraint))) {
252 return failure();
253 }
254 }
255 }
256 return success();
257 }
258