xref: /llvm-project/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp (revision 34a35a8b244243f5a4ad5d531007bccfeaa0b02e)
1 //===- PDLInterp.cpp - PDL Interpreter Dialect ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
10 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
11 #include "mlir/IR/BuiltinTypes.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/Interfaces/FunctionImplementation.h"
14 
15 using namespace mlir;
16 using namespace mlir::pdl_interp;
17 
18 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOpsDialect.cpp.inc"
19 
20 //===----------------------------------------------------------------------===//
21 // PDLInterp Dialect
22 //===----------------------------------------------------------------------===//
23 
initialize()24 void PDLInterpDialect::initialize() {
25   addOperations<
26 #define GET_OP_LIST
27 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
28       >();
29 }
30 
31 template <typename OpT>
verifySwitchOp(OpT op)32 static LogicalResult verifySwitchOp(OpT op) {
33   // Verify that the number of case destinations matches the number of case
34   // values.
35   size_t numDests = op.getCases().size();
36   size_t numValues = op.getCaseValues().size();
37   if (numDests != numValues) {
38     return op.emitOpError(
39                "expected number of cases to match the number of case "
40                "values, got ")
41            << numDests << " but expected " << numValues;
42   }
43   return success();
44 }
45 
46 //===----------------------------------------------------------------------===//
47 // pdl_interp::CreateOperationOp
48 //===----------------------------------------------------------------------===//
49 
verify()50 LogicalResult CreateOperationOp::verify() {
51   if (!getInferredResultTypes())
52     return success();
53   if (!getInputResultTypes().empty()) {
54     return emitOpError("with inferred results cannot also have "
55                        "explicit result types");
56   }
57   OperationName opName(getName(), getContext());
58   if (!opName.hasInterface<InferTypeOpInterface>()) {
59     return emitOpError()
60            << "has inferred results, but the created operation '" << opName
61            << "' does not support result type inference (or is not "
62               "registered)";
63   }
64   return success();
65 }
66 
parseCreateOperationOpAttributes(OpAsmParser & p,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & attrOperands,ArrayAttr & attrNamesAttr)67 static ParseResult parseCreateOperationOpAttributes(
68     OpAsmParser &p,
69     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands,
70     ArrayAttr &attrNamesAttr) {
71   Builder &builder = p.getBuilder();
72   SmallVector<Attribute, 4> attrNames;
73   if (succeeded(p.parseOptionalLBrace())) {
74     auto parseOperands = [&]() {
75       StringAttr nameAttr;
76       OpAsmParser::UnresolvedOperand operand;
77       if (p.parseAttribute(nameAttr) || p.parseEqual() ||
78           p.parseOperand(operand))
79         return failure();
80       attrNames.push_back(nameAttr);
81       attrOperands.push_back(operand);
82       return success();
83     };
84     if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace())
85       return failure();
86   }
87   attrNamesAttr = builder.getArrayAttr(attrNames);
88   return success();
89 }
90 
printCreateOperationOpAttributes(OpAsmPrinter & p,CreateOperationOp op,OperandRange attrArgs,ArrayAttr attrNames)91 static void printCreateOperationOpAttributes(OpAsmPrinter &p,
92                                              CreateOperationOp op,
93                                              OperandRange attrArgs,
94                                              ArrayAttr attrNames) {
95   if (attrNames.empty())
96     return;
97   p << " {";
98   interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
99                   [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
100   p << '}';
101 }
102 
parseCreateOperationOpResults(OpAsmParser & p,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & resultOperands,SmallVectorImpl<Type> & resultTypes,UnitAttr & inferredResultTypes)103 static ParseResult parseCreateOperationOpResults(
104     OpAsmParser &p,
105     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &resultOperands,
106     SmallVectorImpl<Type> &resultTypes, UnitAttr &inferredResultTypes) {
107   if (failed(p.parseOptionalArrow()))
108     return success();
109 
110   // Handle the case of inferred results.
111   if (succeeded(p.parseOptionalLess())) {
112     if (p.parseKeyword("inferred") || p.parseGreater())
113       return failure();
114     inferredResultTypes = p.getBuilder().getUnitAttr();
115     return success();
116   }
117 
118   // Otherwise, parse the explicit results.
119   return failure(p.parseLParen() || p.parseOperandList(resultOperands) ||
120                  p.parseColonTypeList(resultTypes) || p.parseRParen());
121 }
122 
printCreateOperationOpResults(OpAsmPrinter & p,CreateOperationOp op,OperandRange resultOperands,TypeRange resultTypes,UnitAttr inferredResultTypes)123 static void printCreateOperationOpResults(OpAsmPrinter &p, CreateOperationOp op,
124                                           OperandRange resultOperands,
125                                           TypeRange resultTypes,
126                                           UnitAttr inferredResultTypes) {
127   // Handle the case of inferred results.
128   if (inferredResultTypes) {
129     p << " -> <inferred>";
130     return;
131   }
132 
133   // Otherwise, handle the explicit results.
134   if (!resultTypes.empty())
135     p << " -> (" << resultOperands << " : " << resultTypes << ")";
136 }
137 
138 //===----------------------------------------------------------------------===//
139 // pdl_interp::ForEachOp
140 //===----------------------------------------------------------------------===//
141 
build(::mlir::OpBuilder & builder,::mlir::OperationState & state,Value range,Block * successor,bool initLoop)142 void ForEachOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
143                       Value range, Block *successor, bool initLoop) {
144   build(builder, state, range, successor);
145   if (initLoop) {
146     // Create the block and the loop variable.
147     // FIXME: Allow passing in a proper location for the loop variable.
148     auto rangeType = llvm::cast<pdl::RangeType>(range.getType());
149     state.regions.front()->emplaceBlock();
150     state.regions.front()->addArgument(rangeType.getElementType(),
151                                        state.location);
152   }
153 }
154 
parse(OpAsmParser & parser,OperationState & result)155 ParseResult ForEachOp::parse(OpAsmParser &parser, OperationState &result) {
156   // Parse the loop variable followed by type.
157   OpAsmParser::Argument loopVariable;
158   OpAsmParser::UnresolvedOperand operandInfo;
159   if (parser.parseArgument(loopVariable, /*allowType=*/true) ||
160       parser.parseKeyword("in", " after loop variable") ||
161       // Parse the operand (value range).
162       parser.parseOperand(operandInfo))
163     return failure();
164 
165   // Resolve the operand.
166   Type rangeType = pdl::RangeType::get(loopVariable.type);
167   if (parser.resolveOperand(operandInfo, rangeType, result.operands))
168     return failure();
169 
170   // Parse the body region.
171   Region *body = result.addRegion();
172   Block *successor;
173   if (parser.parseRegion(*body, loopVariable) ||
174       parser.parseOptionalAttrDict(result.attributes) ||
175       // Parse the successor.
176       parser.parseArrow() || parser.parseSuccessor(successor))
177     return failure();
178 
179   result.addSuccessors(successor);
180   return success();
181 }
182 
print(OpAsmPrinter & p)183 void ForEachOp::print(OpAsmPrinter &p) {
184   BlockArgument arg = getLoopVariable();
185   p << ' ' << arg << " : " << arg.getType() << " in " << getValues() << ' ';
186   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
187   p.printOptionalAttrDict((*this)->getAttrs());
188   p << " -> ";
189   p.printSuccessor(getSuccessor());
190 }
191 
verify()192 LogicalResult ForEachOp::verify() {
193   // Verify that the operation has exactly one argument.
194   if (getRegion().getNumArguments() != 1)
195     return emitOpError("requires exactly one argument");
196 
197   // Verify that the loop variable and the operand (value range)
198   // have compatible types.
199   BlockArgument arg = getLoopVariable();
200   Type rangeType = pdl::RangeType::get(arg.getType());
201   if (rangeType != getValues().getType())
202     return emitOpError("operand must be a range of loop variable type");
203 
204   return success();
205 }
206 
207 //===----------------------------------------------------------------------===//
208 // pdl_interp::FuncOp
209 //===----------------------------------------------------------------------===//
210 
build(OpBuilder & builder,OperationState & state,StringRef name,FunctionType type,ArrayRef<NamedAttribute> attrs)211 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
212                    FunctionType type, ArrayRef<NamedAttribute> attrs) {
213   buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
214 }
215 
parse(OpAsmParser & parser,OperationState & result)216 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
217   auto buildFuncType =
218       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
219          function_interface_impl::VariadicFlag,
220          std::string &) { return builder.getFunctionType(argTypes, results); };
221 
222   return function_interface_impl::parseFunctionOp(
223       parser, result, /*allowVariadic=*/false,
224       getFunctionTypeAttrName(result.name), buildFuncType,
225       getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
226 }
227 
print(OpAsmPrinter & p)228 void FuncOp::print(OpAsmPrinter &p) {
229   function_interface_impl::printFunctionOp(
230       p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
231       getArgAttrsAttrName(), getResAttrsAttrName());
232 }
233 
234 //===----------------------------------------------------------------------===//
235 // pdl_interp::GetValueTypeOp
236 //===----------------------------------------------------------------------===//
237 
238 /// Given the result type of a `GetValueTypeOp`, return the expected input type.
getGetValueTypeOpValueType(Type type)239 static Type getGetValueTypeOpValueType(Type type) {
240   Type valueTy = pdl::ValueType::get(type.getContext());
241   return llvm::isa<pdl::RangeType>(type) ? pdl::RangeType::get(valueTy)
242                                          : valueTy;
243 }
244 
245 //===----------------------------------------------------------------------===//
// pdl_interp::CreateRangeOp
247 //===----------------------------------------------------------------------===//
248 
parseRangeType(OpAsmParser & p,TypeRange argumentTypes,Type & resultType)249 static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes,
250                                   Type &resultType) {
251   // If arguments were provided, infer the result type from the argument list.
252   if (!argumentTypes.empty()) {
253     resultType =
254         pdl::RangeType::get(pdl::getRangeElementTypeOrSelf(argumentTypes[0]));
255     return success();
256   }
257   // Otherwise, parse the type as a trailing type.
258   return p.parseColonType(resultType);
259 }
260 
printRangeType(OpAsmPrinter & p,CreateRangeOp op,TypeRange argumentTypes,Type resultType)261 static void printRangeType(OpAsmPrinter &p, CreateRangeOp op,
262                            TypeRange argumentTypes, Type resultType) {
263   if (argumentTypes.empty())
264     p << ": " << resultType;
265 }
266 
verify()267 LogicalResult CreateRangeOp::verify() {
268   Type elementType = getType().getElementType();
269   for (Type operandType : getOperandTypes()) {
270     Type operandElementType = pdl::getRangeElementTypeOrSelf(operandType);
271     if (operandElementType != elementType) {
272       return emitOpError("expected operand to have element type ")
273              << elementType << ", but got " << operandElementType;
274     }
275   }
276   return success();
277 }
278 
279 //===----------------------------------------------------------------------===//
280 // pdl_interp::SwitchAttributeOp
281 //===----------------------------------------------------------------------===//
282 
verify()283 LogicalResult SwitchAttributeOp::verify() { return verifySwitchOp(*this); }
284 
285 //===----------------------------------------------------------------------===//
286 // pdl_interp::SwitchOperandCountOp
287 //===----------------------------------------------------------------------===//
288 
verify()289 LogicalResult SwitchOperandCountOp::verify() { return verifySwitchOp(*this); }
290 
291 //===----------------------------------------------------------------------===//
292 // pdl_interp::SwitchOperationNameOp
293 //===----------------------------------------------------------------------===//
294 
verify()295 LogicalResult SwitchOperationNameOp::verify() { return verifySwitchOp(*this); }
296 
297 //===----------------------------------------------------------------------===//
298 // pdl_interp::SwitchResultCountOp
299 //===----------------------------------------------------------------------===//
300 
verify()301 LogicalResult SwitchResultCountOp::verify() { return verifySwitchOp(*this); }
302 
303 //===----------------------------------------------------------------------===//
304 // pdl_interp::SwitchTypeOp
305 //===----------------------------------------------------------------------===//
306 
verify()307 LogicalResult SwitchTypeOp::verify() { return verifySwitchOp(*this); }
308 
309 //===----------------------------------------------------------------------===//
310 // pdl_interp::SwitchTypesOp
311 //===----------------------------------------------------------------------===//
312 
verify()313 LogicalResult SwitchTypesOp::verify() { return verifySwitchOp(*this); }
314 
315 //===----------------------------------------------------------------------===//
316 // TableGen Auto-Generated Op and Interface Definitions
317 //===----------------------------------------------------------------------===//
318 
319 #define GET_OP_CLASSES
320 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
321