1 //===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/Interfaces/FunctionImplementation.h"
12
13 using namespace mlir;
14 using namespace mlir::ml_program;
15
16 //===----------------------------------------------------------------------===//
17 // Custom asm helpers
18 //===----------------------------------------------------------------------===//
19
20 /// Parse and print an ordering clause for a variadic of consuming tokens
21 /// and an producing token.
22 ///
23 /// Syntax:
24 /// ordering(%0, %1 -> !ml_program.token)
25 /// ordering(() -> !ml_program.token)
26 ///
27 /// If both the consuming and producing token are not present on the op, then
28 /// the clause prints nothing.
parseTokenOrdering(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & consumeTokens,Type & produceTokenType)29 static ParseResult parseTokenOrdering(
30 OpAsmParser &parser,
31 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &consumeTokens,
32 Type &produceTokenType) {
33 if (failed(parser.parseOptionalKeyword("ordering")) ||
34 failed(parser.parseLParen()))
35 return success();
36
37 // Parse consuming token list. If there are no consuming tokens, the
38 // '()' null list represents this.
39 if (succeeded(parser.parseOptionalLParen())) {
40 if (failed(parser.parseRParen()))
41 return failure();
42 } else {
43 if (failed(parser.parseOperandList(consumeTokens,
44 /*requiredOperandCount=*/-1)))
45 return failure();
46 }
47
48 // Parse producer token.
49 if (failed(parser.parseArrow()))
50 return failure();
51 if (failed(parser.parseType(produceTokenType)))
52 return failure();
53
54 if (failed(parser.parseRParen()))
55 return failure();
56
57 return success();
58 }
59
printTokenOrdering(OpAsmPrinter & p,Operation * op,OperandRange consumeTokens,Type produceTokenType)60 static void printTokenOrdering(OpAsmPrinter &p, Operation *op,
61 OperandRange consumeTokens,
62 Type produceTokenType) {
63 if (consumeTokens.empty() && !produceTokenType)
64 return;
65
66 p << " ordering(";
67 if (consumeTokens.empty())
68 p << "()";
69 else
70 p.printOperands(consumeTokens);
71 if (produceTokenType) {
72 p << " -> ";
73 p.printType(produceTokenType);
74 }
75 p << ")";
76 }
77
78 /// some.op custom<TypeOrAttr>($type, $attr)
79 ///
80 /// Uninitialized:
81 /// some.op : tensor<3xi32>
82 /// Initialized to narrower type than op:
83 /// some.op (dense<0> : tensor<3xi32>) : tensor<?xi32>
parseTypedInitialValue(OpAsmParser & parser,TypeAttr & typeAttr,Attribute & attr)84 static ParseResult parseTypedInitialValue(OpAsmParser &parser,
85 TypeAttr &typeAttr, Attribute &attr) {
86 if (succeeded(parser.parseOptionalLParen())) {
87 if (failed(parser.parseAttribute(attr)))
88 return failure();
89 if (failed(parser.parseRParen()))
90 return failure();
91 }
92
93 Type type;
94 if (failed(parser.parseColonType(type)))
95 return failure();
96 typeAttr = TypeAttr::get(type);
97 return success();
98 }
99
printTypedInitialValue(OpAsmPrinter & p,Operation * op,TypeAttr type,Attribute attr)100 static void printTypedInitialValue(OpAsmPrinter &p, Operation *op,
101 TypeAttr type, Attribute attr) {
102 if (attr) {
103 p << "(";
104 p.printAttribute(attr);
105 p << ")";
106 }
107
108 p << " : ";
109 p.printAttribute(type);
110 }
111
112 /// some.op custom<SymbolVisibility>($sym_visibility) $sym_name
113 /// ->
114 /// some.op public @foo
115 /// some.op private @foo
parseSymbolVisibility(OpAsmParser & parser,StringAttr & symVisibilityAttr)116 static ParseResult parseSymbolVisibility(OpAsmParser &parser,
117 StringAttr &symVisibilityAttr) {
118 StringRef symVisibility;
119 (void)parser.parseOptionalKeyword(&symVisibility,
120 {"public", "private", "nested"});
121 if (symVisibility.empty())
122 return parser.emitError(parser.getCurrentLocation())
123 << "expected 'public', 'private', or 'nested'";
124 if (!symVisibility.empty())
125 symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility);
126 return success();
127 }
128
printSymbolVisibility(OpAsmPrinter & p,Operation * op,StringAttr symVisibilityAttr)129 static void printSymbolVisibility(OpAsmPrinter &p, Operation *op,
130 StringAttr symVisibilityAttr) {
131 if (!symVisibilityAttr)
132 p << "public";
133 else
134 p << symVisibilityAttr.getValue();
135 }
136
137 //===----------------------------------------------------------------------===//
138 // TableGen'd op method definitions
139 //===----------------------------------------------------------------------===//
140
141 #define GET_OP_CLASSES
142 #include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
143
144 //===----------------------------------------------------------------------===//
145 // FuncOp
146 //===----------------------------------------------------------------------===//
147
parse(OpAsmParser & parser,OperationState & result)148 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
149 auto buildFuncType =
150 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
151 function_interface_impl::VariadicFlag,
152 std::string &) { return builder.getFunctionType(argTypes, results); };
153
154 return function_interface_impl::parseFunctionOp(
155 parser, result, /*allowVariadic=*/false,
156 getFunctionTypeAttrName(result.name), buildFuncType,
157 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
158 }
159
print(OpAsmPrinter & p)160 void FuncOp::print(OpAsmPrinter &p) {
161 function_interface_impl::printFunctionOp(
162 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
163 getArgAttrsAttrName(), getResAttrsAttrName());
164 }
165
166 //===----------------------------------------------------------------------===//
167 // GlobalOp
168 //===----------------------------------------------------------------------===//
169
verify()170 LogicalResult GlobalOp::verify() {
171 if (!getIsMutable() && !getValue())
172 return emitOpError() << "immutable global must have an initial value";
173 return success();
174 }
175
176 //===----------------------------------------------------------------------===//
177 // GlobalLoadOp
178 //===----------------------------------------------------------------------===//
179
getGlobalOp(SymbolTableCollection & symbolTable)180 GlobalOp GlobalLoadOp::getGlobalOp(SymbolTableCollection &symbolTable) {
181 for (auto *parent = getOperation()->getParentOp(); parent;
182 parent = parent->getParentOp()) {
183 if (auto nearest = symbolTable.lookupNearestSymbolFrom<GlobalOp>(
184 parent, getGlobalAttr())) {
185 return nearest;
186 }
187 }
188 return {};
189 }
190
191 LogicalResult
verifySymbolUses(SymbolTableCollection & symbolTable)192 GlobalLoadOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
193 GlobalOp referrent = getGlobalOp(symbolTable);
194 if (!referrent)
195 return emitOpError() << "undefined global: " << getGlobal();
196
197 if (referrent.getType() != getResult().getType()) {
198 return emitOpError() << "cannot load from global typed "
199 << referrent.getType() << " as "
200 << getResult().getType();
201 }
202
203 return success();
204 }
205
206 //===----------------------------------------------------------------------===//
207 // GlobalLoadConstOp
208 //===----------------------------------------------------------------------===//
209
getGlobalOp(SymbolTableCollection & symbolTable)210 GlobalOp GlobalLoadConstOp::getGlobalOp(SymbolTableCollection &symbolTable) {
211 return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
212 getOperation()->getParentOp(), getGlobalAttr());
213 }
214
215 LogicalResult
verifySymbolUses(SymbolTableCollection & symbolTable)216 GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
217 GlobalOp referrent = getGlobalOp(symbolTable);
218 if (!referrent)
219 return emitOpError() << "undefined global: " << getGlobal();
220
221 if (referrent.getIsMutable())
222 return emitOpError() << "cannot load as const from mutable global "
223 << getGlobal();
224
225 if (referrent.getType() != getResult().getType())
226 return emitOpError() << "cannot load from global typed "
227 << referrent.getType() << " as "
228 << getResult().getType();
229
230 return success();
231 }
232
233 //===----------------------------------------------------------------------===//
234 // GlobalLoadGraphOp
235 //===----------------------------------------------------------------------===//
236
getGlobalOp(SymbolTableCollection & symbolTable)237 GlobalOp GlobalLoadGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
238 return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
239 getOperation()->getParentOp(), getGlobalAttr());
240 }
241
242 LogicalResult
verifySymbolUses(SymbolTableCollection & symbolTable)243 GlobalLoadGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
244 GlobalOp referrent = getGlobalOp(symbolTable);
245 if (!referrent)
246 return emitOpError() << "undefined global: " << getGlobal();
247
248 if (referrent.getType() != getResult().getType()) {
249 return emitOpError() << "cannot load from global typed "
250 << referrent.getType() << " as "
251 << getResult().getType();
252 }
253
254 return success();
255 }
256
257 //===----------------------------------------------------------------------===//
258 // GlobalStoreOp
259 //===----------------------------------------------------------------------===//
260
getGlobalOp(SymbolTableCollection & symbolTable)261 GlobalOp GlobalStoreOp::getGlobalOp(SymbolTableCollection &symbolTable) {
262 for (auto *parent = getOperation()->getParentOp(); parent;) {
263 if (auto nearest = symbolTable.lookupNearestSymbolFrom<GlobalOp>(
264 parent, getGlobalAttr())) {
265 return nearest;
266 }
267 parent = parent->getParentOp();
268 }
269 return {};
270 }
271
272 LogicalResult
verifySymbolUses(SymbolTableCollection & symbolTable)273 GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
274 GlobalOp referrent = getGlobalOp(symbolTable);
275 if (!referrent)
276 return emitOpError() << "undefined global: " << getGlobal();
277
278 if (!referrent.getIsMutable()) {
279 return emitOpError() << "cannot store to an immutable global "
280 << getGlobal();
281 }
282
283 if (referrent.getType() != getValue().getType()) {
284 return emitOpError() << "cannot store to a global typed "
285 << referrent.getType() << " from "
286 << getValue().getType();
287 }
288
289 return success();
290 }
291
292 //===----------------------------------------------------------------------===//
293 // GlobalStoreGraphOp
294 //===----------------------------------------------------------------------===//
295
getGlobalOp(SymbolTableCollection & symbolTable)296 GlobalOp GlobalStoreGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
297 return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
298 getOperation()->getParentOp(), getGlobalAttr());
299 }
300
301 LogicalResult
verifySymbolUses(SymbolTableCollection & symbolTable)302 GlobalStoreGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
303 GlobalOp referrent = getGlobalOp(symbolTable);
304 if (!referrent)
305 return emitOpError() << "undefined global: " << getGlobal();
306
307 if (!referrent.getIsMutable()) {
308 return emitOpError() << "cannot store to an immutable global "
309 << getGlobal();
310 }
311
312 if (referrent.getType() != getValue().getType()) {
313 return emitOpError() << "cannot store to a global typed "
314 << referrent.getType() << " from "
315 << getValue().getType();
316 }
317
318 return success();
319 }
320
321 //===----------------------------------------------------------------------===//
322 // SubgraphOp
323 //===----------------------------------------------------------------------===//
324
parse(OpAsmParser & parser,OperationState & result)325 ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) {
326 auto buildFuncType =
327 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
328 function_interface_impl::VariadicFlag,
329 std::string &) { return builder.getFunctionType(argTypes, results); };
330
331 return function_interface_impl::parseFunctionOp(
332 parser, result, /*allowVariadic=*/false,
333 getFunctionTypeAttrName(result.name), buildFuncType,
334 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
335 }
336
print(OpAsmPrinter & p)337 void SubgraphOp::print(OpAsmPrinter &p) {
338 function_interface_impl::printFunctionOp(
339 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
340 getArgAttrsAttrName(), getResAttrsAttrName());
341 }
342
343 //===----------------------------------------------------------------------===//
344 // OutputOp
345 //===----------------------------------------------------------------------===//
346
verify()347 LogicalResult OutputOp::verify() {
348 auto function = cast<SubgraphOp>((*this)->getParentOp());
349
350 // The operand number and types must match the function signature.
351 const auto &results = function.getFunctionType().getResults();
352 if (getNumOperands() != results.size())
353 return emitOpError("has ")
354 << getNumOperands() << " operands, but enclosing function (@"
355 << function.getName() << ") outputs " << results.size();
356
357 for (unsigned i = 0, e = results.size(); i != e; ++i)
358 if (getOperand(i).getType() != results[i])
359 return emitError() << "type of output operand " << i << " ("
360 << getOperand(i).getType()
361 << ") doesn't match function result type ("
362 << results[i] << ")"
363 << " in function @" << function.getName();
364
365 return success();
366 }
367
368 //===----------------------------------------------------------------------===//
369 // ReturnOp
370 //===----------------------------------------------------------------------===//
371
verify()372 LogicalResult ReturnOp::verify() {
373 auto function = cast<FuncOp>((*this)->getParentOp());
374
375 // The operand number and types must match the function signature.
376 const auto &results = function.getFunctionType().getResults();
377 if (getNumOperands() != results.size())
378 return emitOpError("has ")
379 << getNumOperands() << " operands, but enclosing function (@"
380 << function.getName() << ") returns " << results.size();
381
382 for (unsigned i = 0, e = results.size(); i != e; ++i)
383 if (getOperand(i).getType() != results[i])
384 return emitError() << "type of return operand " << i << " ("
385 << getOperand(i).getType()
386 << ") doesn't match function result type ("
387 << results[i] << ")"
388 << " in function @" << function.getName();
389
390 return success();
391 }
392