xref: /llvm-project/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp (revision 53406427cdf4290986d1a48ea0d582ad195bff15)
1 //===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/FunctionImplementation.h"
12 
13 using namespace mlir;
14 using namespace mlir::ml_program;
15 
16 //===----------------------------------------------------------------------===//
17 // Custom asm helpers
18 //===----------------------------------------------------------------------===//
19 
20 /// Parse and print an ordering clause for a variadic of consuming tokens
21 /// and an producing token.
22 ///
23 /// Syntax:
24 ///   ordering(%0, %1 -> !ml_program.token)
25 ///   ordering(() -> !ml_program.token)
26 ///
27 /// If both the consuming and producing token are not present on the op, then
28 /// the clause prints nothing.
29 static ParseResult parseTokenOrdering(
30     OpAsmParser &parser,
31     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &consumeTokens,
32     Type &produceTokenType) {
33   if (failed(parser.parseOptionalKeyword("ordering")) ||
34       failed(parser.parseLParen()))
35     return success();
36 
37   // Parse consuming token list. If there are no consuming tokens, the
38   // '()' null list represents this.
39   if (succeeded(parser.parseOptionalLParen())) {
40     if (failed(parser.parseRParen()))
41       return failure();
42   } else {
43     if (failed(parser.parseOperandList(consumeTokens,
44                                        /*requiredOperandCount=*/-1)))
45       return failure();
46   }
47 
48   // Parse producer token.
49   if (failed(parser.parseArrow()))
50     return failure();
51   if (failed(parser.parseType(produceTokenType)))
52     return failure();
53 
54   if (failed(parser.parseRParen()))
55     return failure();
56 
57   return success();
58 }
59 
60 static void printTokenOrdering(OpAsmPrinter &p, Operation *op,
61                                OperandRange consumeTokens,
62                                Type produceTokenType) {
63   if (consumeTokens.empty() && !produceTokenType)
64     return;
65 
66   p << " ordering(";
67   if (consumeTokens.empty())
68     p << "()";
69   else
70     p.printOperands(consumeTokens);
71   if (produceTokenType) {
72     p << " -> ";
73     p.printType(produceTokenType);
74   }
75   p << ")";
76 }
77 
78 /// some.op custom<TypeOrAttr>($type, $attr)
79 ///
80 /// Uninitialized:
81 ///   some.op : tensor<3xi32>
82 /// Initialized to narrower type than op:
83 ///   some.op (dense<0> : tensor<3xi32>) : tensor<?xi32>
84 static ParseResult parseTypedInitialValue(OpAsmParser &parser,
85                                           TypeAttr &typeAttr, Attribute &attr) {
86   if (succeeded(parser.parseOptionalLParen())) {
87     if (failed(parser.parseAttribute(attr)))
88       return failure();
89     if (failed(parser.parseRParen()))
90       return failure();
91   }
92 
93   Type type;
94   if (failed(parser.parseColonType(type)))
95     return failure();
96   typeAttr = TypeAttr::get(type);
97   return success();
98 }
99 
100 static void printTypedInitialValue(OpAsmPrinter &p, Operation *op,
101                                    TypeAttr type, Attribute attr) {
102   if (attr) {
103     p << "(";
104     p.printAttribute(attr);
105     p << ")";
106   }
107 
108   p << " : ";
109   p.printAttribute(type);
110 }
111 
112 /// some.op custom<SymbolVisibility>($sym_visibility) $sym_name
113 /// ->
114 /// some.op public @foo
115 /// some.op private @foo
116 static ParseResult parseSymbolVisibility(OpAsmParser &parser,
117                                          StringAttr &symVisibilityAttr) {
118   StringRef symVisibility;
119   (void)parser.parseOptionalKeyword(&symVisibility,
120                                     {"public", "private", "nested"});
121   if (symVisibility.empty())
122     return parser.emitError(parser.getCurrentLocation())
123            << "expected 'public', 'private', or 'nested'";
124   if (!symVisibility.empty())
125     symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility);
126   return success();
127 }
128 
129 static void printSymbolVisibility(OpAsmPrinter &p, Operation *op,
130                                   StringAttr symVisibilityAttr) {
131   if (!symVisibilityAttr)
132     p << "public";
133   else
134     p << symVisibilityAttr.getValue();
135 }
136 
137 //===----------------------------------------------------------------------===//
138 // TableGen'd op method definitions
139 //===----------------------------------------------------------------------===//
140 
141 #define GET_OP_CLASSES
142 #include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
143 
144 //===----------------------------------------------------------------------===//
145 // FuncOp
146 //===----------------------------------------------------------------------===//
147 
148 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
149   auto buildFuncType =
150       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
151          function_interface_impl::VariadicFlag,
152          std::string &) { return builder.getFunctionType(argTypes, results); };
153 
154   return function_interface_impl::parseFunctionOp(
155       parser, result, /*allowVariadic=*/false,
156       getFunctionTypeAttrName(result.name), buildFuncType,
157       getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
158 }
159 
160 void FuncOp::print(OpAsmPrinter &p) {
161   function_interface_impl::printFunctionOp(
162       p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
163       getArgAttrsAttrName(), getResAttrsAttrName());
164 }
165 
166 //===----------------------------------------------------------------------===//
167 // GlobalOp
168 //===----------------------------------------------------------------------===//
169 
170 LogicalResult GlobalOp::verify() {
171   if (!getIsMutable() && !getValue())
172     return emitOpError() << "immutable global must have an initial value";
173   return success();
174 }
175 
176 //===----------------------------------------------------------------------===//
177 // GlobalLoadOp
178 //===----------------------------------------------------------------------===//
179 
180 GlobalOp GlobalLoadOp::getGlobalOp(SymbolTableCollection &symbolTable) {
181   return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
182       getOperation()->getParentOp(), getGlobalAttr());
183 }
184 
185 LogicalResult
186 GlobalLoadOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
187   GlobalOp referrent = getGlobalOp(symbolTable);
188   if (!referrent)
189     return emitOpError() << "undefined global: " << getGlobal();
190 
191   if (referrent.getType() != getResult().getType()) {
192     return emitOpError() << "cannot load from global typed "
193                          << referrent.getType() << " as "
194                          << getResult().getType();
195   }
196 
197   return success();
198 }
199 
200 //===----------------------------------------------------------------------===//
201 // GlobalLoadConstOp
202 //===----------------------------------------------------------------------===//
203 
204 GlobalOp GlobalLoadConstOp::getGlobalOp(SymbolTableCollection &symbolTable) {
205   return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
206       getOperation()->getParentOp(), getGlobalAttr());
207 }
208 
209 LogicalResult
210 GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
211   GlobalOp referrent = getGlobalOp(symbolTable);
212   if (!referrent)
213     return emitOpError() << "undefined global: " << getGlobal();
214 
215   if (referrent.getIsMutable())
216     return emitOpError() << "cannot load as const from mutable global "
217                          << getGlobal();
218 
219   if (referrent.getType() != getResult().getType())
220     return emitOpError() << "cannot load from global typed "
221                          << referrent.getType() << " as "
222                          << getResult().getType();
223 
224   return success();
225 }
226 
227 //===----------------------------------------------------------------------===//
228 // GlobalLoadGraphOp
229 //===----------------------------------------------------------------------===//
230 
231 GlobalOp GlobalLoadGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
232   return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
233       getOperation()->getParentOp(), getGlobalAttr());
234 }
235 
236 LogicalResult
237 GlobalLoadGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
238   GlobalOp referrent = getGlobalOp(symbolTable);
239   if (!referrent)
240     return emitOpError() << "undefined global: " << getGlobal();
241 
242   if (referrent.getType() != getResult().getType()) {
243     return emitOpError() << "cannot load from global typed "
244                          << referrent.getType() << " as "
245                          << getResult().getType();
246   }
247 
248   return success();
249 }
250 
251 //===----------------------------------------------------------------------===//
252 // GlobalStoreOp
253 //===----------------------------------------------------------------------===//
254 
255 GlobalOp GlobalStoreOp::getGlobalOp(SymbolTableCollection &symbolTable) {
256   return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
257       getOperation()->getParentOp(), getGlobalAttr());
258 }
259 
260 LogicalResult
261 GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
262   GlobalOp referrent = getGlobalOp(symbolTable);
263   if (!referrent)
264     return emitOpError() << "undefined global: " << getGlobal();
265 
266   if (!referrent.getIsMutable()) {
267     return emitOpError() << "cannot store to an immutable global "
268                          << getGlobal();
269   }
270 
271   if (referrent.getType() != getValue().getType()) {
272     return emitOpError() << "cannot store to a global typed "
273                          << referrent.getType() << " from "
274                          << getValue().getType();
275   }
276 
277   return success();
278 }
279 
280 //===----------------------------------------------------------------------===//
281 // GlobalStoreGraphOp
282 //===----------------------------------------------------------------------===//
283 
284 GlobalOp GlobalStoreGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
285   return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
286       getOperation()->getParentOp(), getGlobalAttr());
287 }
288 
289 LogicalResult
290 GlobalStoreGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
291   GlobalOp referrent = getGlobalOp(symbolTable);
292   if (!referrent)
293     return emitOpError() << "undefined global: " << getGlobal();
294 
295   if (!referrent.getIsMutable()) {
296     return emitOpError() << "cannot store to an immutable global "
297                          << getGlobal();
298   }
299 
300   if (referrent.getType() != getValue().getType()) {
301     return emitOpError() << "cannot store to a global typed "
302                          << referrent.getType() << " from "
303                          << getValue().getType();
304   }
305 
306   return success();
307 }
308 
309 //===----------------------------------------------------------------------===//
310 // SubgraphOp
311 //===----------------------------------------------------------------------===//
312 
313 ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) {
314   auto buildFuncType =
315       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
316          function_interface_impl::VariadicFlag,
317          std::string &) { return builder.getFunctionType(argTypes, results); };
318 
319   return function_interface_impl::parseFunctionOp(
320       parser, result, /*allowVariadic=*/false,
321       getFunctionTypeAttrName(result.name), buildFuncType,
322       getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
323 }
324 
325 void SubgraphOp::print(OpAsmPrinter &p) {
326   function_interface_impl::printFunctionOp(
327       p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
328       getArgAttrsAttrName(), getResAttrsAttrName());
329 }
330 
331 //===----------------------------------------------------------------------===//
332 // OutputOp
333 //===----------------------------------------------------------------------===//
334 
335 LogicalResult OutputOp::verify() {
336   auto function = cast<SubgraphOp>((*this)->getParentOp());
337 
338   // The operand number and types must match the function signature.
339   const auto &results = function.getFunctionType().getResults();
340   if (getNumOperands() != results.size())
341     return emitOpError("has ")
342            << getNumOperands() << " operands, but enclosing function (@"
343            << function.getName() << ") outputs " << results.size();
344 
345   for (unsigned i = 0, e = results.size(); i != e; ++i)
346     if (getOperand(i).getType() != results[i])
347       return emitError() << "type of output operand " << i << " ("
348                          << getOperand(i).getType()
349                          << ") doesn't match function result type ("
350                          << results[i] << ")"
351                          << " in function @" << function.getName();
352 
353   return success();
354 }
355 
356 //===----------------------------------------------------------------------===//
357 // ReturnOp
358 //===----------------------------------------------------------------------===//
359 
360 LogicalResult ReturnOp::verify() {
361   auto function = cast<FuncOp>((*this)->getParentOp());
362 
363   // The operand number and types must match the function signature.
364   const auto &results = function.getFunctionType().getResults();
365   if (getNumOperands() != results.size())
366     return emitOpError("has ")
367            << getNumOperands() << " operands, but enclosing function (@"
368            << function.getName() << ") returns " << results.size();
369 
370   for (unsigned i = 0, e = results.size(); i != e; ++i)
371     if (getOperand(i).getType() != results[i])
372       return emitError() << "type of return operand " << i << " ("
373                          << getOperand(i).getType()
374                          << ") doesn't match function result type ("
375                          << results[i] << ")"
376                          << " in function @" << function.getName();
377 
378   return success();
379 }
380