xref: /llvm-project/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp (revision b9f73714359a2839b651ee094cad481fb6a636ce)
1 //===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/Interfaces/FunctionImplementation.h"
12 
13 using namespace mlir;
14 using namespace mlir::ml_program;
15 
16 //===----------------------------------------------------------------------===//
17 // Custom asm helpers
18 //===----------------------------------------------------------------------===//
19 
20 /// Parse and print an ordering clause for a variadic of consuming tokens
21 /// and an producing token.
22 ///
23 /// Syntax:
24 ///   ordering(%0, %1 -> !ml_program.token)
25 ///   ordering(() -> !ml_program.token)
26 ///
27 /// If both the consuming and producing token are not present on the op, then
28 /// the clause prints nothing.
parseTokenOrdering(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & consumeTokens,Type & produceTokenType)29 static ParseResult parseTokenOrdering(
30     OpAsmParser &parser,
31     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &consumeTokens,
32     Type &produceTokenType) {
33   if (failed(parser.parseOptionalKeyword("ordering")) ||
34       failed(parser.parseLParen()))
35     return success();
36 
37   // Parse consuming token list. If there are no consuming tokens, the
38   // '()' null list represents this.
39   if (succeeded(parser.parseOptionalLParen())) {
40     if (failed(parser.parseRParen()))
41       return failure();
42   } else {
43     if (failed(parser.parseOperandList(consumeTokens,
44                                        /*requiredOperandCount=*/-1)))
45       return failure();
46   }
47 
48   // Parse producer token.
49   if (failed(parser.parseArrow()))
50     return failure();
51   if (failed(parser.parseType(produceTokenType)))
52     return failure();
53 
54   if (failed(parser.parseRParen()))
55     return failure();
56 
57   return success();
58 }
59 
printTokenOrdering(OpAsmPrinter & p,Operation * op,OperandRange consumeTokens,Type produceTokenType)60 static void printTokenOrdering(OpAsmPrinter &p, Operation *op,
61                                OperandRange consumeTokens,
62                                Type produceTokenType) {
63   if (consumeTokens.empty() && !produceTokenType)
64     return;
65 
66   p << " ordering(";
67   if (consumeTokens.empty())
68     p << "()";
69   else
70     p.printOperands(consumeTokens);
71   if (produceTokenType) {
72     p << " -> ";
73     p.printType(produceTokenType);
74   }
75   p << ")";
76 }
77 
78 /// some.op custom<TypeOrAttr>($type, $attr)
79 ///
80 /// Uninitialized:
81 ///   some.op : tensor<3xi32>
82 /// Initialized to narrower type than op:
83 ///   some.op (dense<0> : tensor<3xi32>) : tensor<?xi32>
parseTypedInitialValue(OpAsmParser & parser,TypeAttr & typeAttr,Attribute & attr)84 static ParseResult parseTypedInitialValue(OpAsmParser &parser,
85                                           TypeAttr &typeAttr, Attribute &attr) {
86   if (succeeded(parser.parseOptionalLParen())) {
87     if (failed(parser.parseAttribute(attr)))
88       return failure();
89     if (failed(parser.parseRParen()))
90       return failure();
91   }
92 
93   Type type;
94   if (failed(parser.parseColonType(type)))
95     return failure();
96   typeAttr = TypeAttr::get(type);
97   return success();
98 }
99 
printTypedInitialValue(OpAsmPrinter & p,Operation * op,TypeAttr type,Attribute attr)100 static void printTypedInitialValue(OpAsmPrinter &p, Operation *op,
101                                    TypeAttr type, Attribute attr) {
102   if (attr) {
103     p << "(";
104     p.printAttribute(attr);
105     p << ")";
106   }
107 
108   p << " : ";
109   p.printAttribute(type);
110 }
111 
112 /// some.op custom<SymbolVisibility>($sym_visibility) $sym_name
113 /// ->
114 /// some.op public @foo
115 /// some.op private @foo
parseSymbolVisibility(OpAsmParser & parser,StringAttr & symVisibilityAttr)116 static ParseResult parseSymbolVisibility(OpAsmParser &parser,
117                                          StringAttr &symVisibilityAttr) {
118   StringRef symVisibility;
119   (void)parser.parseOptionalKeyword(&symVisibility,
120                                     {"public", "private", "nested"});
121   if (symVisibility.empty())
122     return parser.emitError(parser.getCurrentLocation())
123            << "expected 'public', 'private', or 'nested'";
124   if (!symVisibility.empty())
125     symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility);
126   return success();
127 }
128 
printSymbolVisibility(OpAsmPrinter & p,Operation * op,StringAttr symVisibilityAttr)129 static void printSymbolVisibility(OpAsmPrinter &p, Operation *op,
130                                   StringAttr symVisibilityAttr) {
131   if (!symVisibilityAttr)
132     p << "public";
133   else
134     p << symVisibilityAttr.getValue();
135 }
136 
137 //===----------------------------------------------------------------------===//
138 // TableGen'd op method definitions
139 //===----------------------------------------------------------------------===//
140 
141 #define GET_OP_CLASSES
142 #include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
143 
144 //===----------------------------------------------------------------------===//
145 // FuncOp
146 //===----------------------------------------------------------------------===//
147 
parse(OpAsmParser & parser,OperationState & result)148 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
149   auto buildFuncType =
150       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
151          function_interface_impl::VariadicFlag,
152          std::string &) { return builder.getFunctionType(argTypes, results); };
153 
154   return function_interface_impl::parseFunctionOp(
155       parser, result, /*allowVariadic=*/false,
156       getFunctionTypeAttrName(result.name), buildFuncType,
157       getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
158 }
159 
print(OpAsmPrinter & p)160 void FuncOp::print(OpAsmPrinter &p) {
161   function_interface_impl::printFunctionOp(
162       p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
163       getArgAttrsAttrName(), getResAttrsAttrName());
164 }
165 
166 //===----------------------------------------------------------------------===//
167 // GlobalOp
168 //===----------------------------------------------------------------------===//
169 
verify()170 LogicalResult GlobalOp::verify() {
171   if (!getIsMutable() && !getValue())
172     return emitOpError() << "immutable global must have an initial value";
173   return success();
174 }
175 
176 //===----------------------------------------------------------------------===//
177 // GlobalLoadOp
178 //===----------------------------------------------------------------------===//
179 
getGlobalOp(SymbolTableCollection & symbolTable)180 GlobalOp GlobalLoadOp::getGlobalOp(SymbolTableCollection &symbolTable) {
181   for (auto *parent = getOperation()->getParentOp(); parent;
182        parent = parent->getParentOp()) {
183     if (auto nearest = symbolTable.lookupNearestSymbolFrom<GlobalOp>(
184             parent, getGlobalAttr())) {
185       return nearest;
186     }
187   }
188   return {};
189 }
190 
191 LogicalResult
verifySymbolUses(SymbolTableCollection & symbolTable)192 GlobalLoadOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
193   GlobalOp referrent = getGlobalOp(symbolTable);
194   if (!referrent)
195     return emitOpError() << "undefined global: " << getGlobal();
196 
197   if (referrent.getType() != getResult().getType()) {
198     return emitOpError() << "cannot load from global typed "
199                          << referrent.getType() << " as "
200                          << getResult().getType();
201   }
202 
203   return success();
204 }
205 
206 //===----------------------------------------------------------------------===//
207 // GlobalLoadConstOp
208 //===----------------------------------------------------------------------===//
209 
getGlobalOp(SymbolTableCollection & symbolTable)210 GlobalOp GlobalLoadConstOp::getGlobalOp(SymbolTableCollection &symbolTable) {
211   return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
212       getOperation()->getParentOp(), getGlobalAttr());
213 }
214 
215 LogicalResult
verifySymbolUses(SymbolTableCollection & symbolTable)216 GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
217   GlobalOp referrent = getGlobalOp(symbolTable);
218   if (!referrent)
219     return emitOpError() << "undefined global: " << getGlobal();
220 
221   if (referrent.getIsMutable())
222     return emitOpError() << "cannot load as const from mutable global "
223                          << getGlobal();
224 
225   if (referrent.getType() != getResult().getType())
226     return emitOpError() << "cannot load from global typed "
227                          << referrent.getType() << " as "
228                          << getResult().getType();
229 
230   return success();
231 }
232 
233 //===----------------------------------------------------------------------===//
234 // GlobalLoadGraphOp
235 //===----------------------------------------------------------------------===//
236 
getGlobalOp(SymbolTableCollection & symbolTable)237 GlobalOp GlobalLoadGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
238   return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
239       getOperation()->getParentOp(), getGlobalAttr());
240 }
241 
242 LogicalResult
verifySymbolUses(SymbolTableCollection & symbolTable)243 GlobalLoadGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
244   GlobalOp referrent = getGlobalOp(symbolTable);
245   if (!referrent)
246     return emitOpError() << "undefined global: " << getGlobal();
247 
248   if (referrent.getType() != getResult().getType()) {
249     return emitOpError() << "cannot load from global typed "
250                          << referrent.getType() << " as "
251                          << getResult().getType();
252   }
253 
254   return success();
255 }
256 
257 //===----------------------------------------------------------------------===//
258 // GlobalStoreOp
259 //===----------------------------------------------------------------------===//
260 
getGlobalOp(SymbolTableCollection & symbolTable)261 GlobalOp GlobalStoreOp::getGlobalOp(SymbolTableCollection &symbolTable) {
262   for (auto *parent = getOperation()->getParentOp(); parent;) {
263     if (auto nearest = symbolTable.lookupNearestSymbolFrom<GlobalOp>(
264             parent, getGlobalAttr())) {
265       return nearest;
266     }
267     parent = parent->getParentOp();
268   }
269   return {};
270 }
271 
272 LogicalResult
verifySymbolUses(SymbolTableCollection & symbolTable)273 GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
274   GlobalOp referrent = getGlobalOp(symbolTable);
275   if (!referrent)
276     return emitOpError() << "undefined global: " << getGlobal();
277 
278   if (!referrent.getIsMutable()) {
279     return emitOpError() << "cannot store to an immutable global "
280                          << getGlobal();
281   }
282 
283   if (referrent.getType() != getValue().getType()) {
284     return emitOpError() << "cannot store to a global typed "
285                          << referrent.getType() << " from "
286                          << getValue().getType();
287   }
288 
289   return success();
290 }
291 
292 //===----------------------------------------------------------------------===//
293 // GlobalStoreGraphOp
294 //===----------------------------------------------------------------------===//
295 
getGlobalOp(SymbolTableCollection & symbolTable)296 GlobalOp GlobalStoreGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
297   return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
298       getOperation()->getParentOp(), getGlobalAttr());
299 }
300 
301 LogicalResult
verifySymbolUses(SymbolTableCollection & symbolTable)302 GlobalStoreGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
303   GlobalOp referrent = getGlobalOp(symbolTable);
304   if (!referrent)
305     return emitOpError() << "undefined global: " << getGlobal();
306 
307   if (!referrent.getIsMutable()) {
308     return emitOpError() << "cannot store to an immutable global "
309                          << getGlobal();
310   }
311 
312   if (referrent.getType() != getValue().getType()) {
313     return emitOpError() << "cannot store to a global typed "
314                          << referrent.getType() << " from "
315                          << getValue().getType();
316   }
317 
318   return success();
319 }
320 
321 //===----------------------------------------------------------------------===//
322 // SubgraphOp
323 //===----------------------------------------------------------------------===//
324 
parse(OpAsmParser & parser,OperationState & result)325 ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) {
326   auto buildFuncType =
327       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
328          function_interface_impl::VariadicFlag,
329          std::string &) { return builder.getFunctionType(argTypes, results); };
330 
331   return function_interface_impl::parseFunctionOp(
332       parser, result, /*allowVariadic=*/false,
333       getFunctionTypeAttrName(result.name), buildFuncType,
334       getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
335 }
336 
print(OpAsmPrinter & p)337 void SubgraphOp::print(OpAsmPrinter &p) {
338   function_interface_impl::printFunctionOp(
339       p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
340       getArgAttrsAttrName(), getResAttrsAttrName());
341 }
342 
343 //===----------------------------------------------------------------------===//
344 // OutputOp
345 //===----------------------------------------------------------------------===//
346 
verify()347 LogicalResult OutputOp::verify() {
348   auto function = cast<SubgraphOp>((*this)->getParentOp());
349 
350   // The operand number and types must match the function signature.
351   const auto &results = function.getFunctionType().getResults();
352   if (getNumOperands() != results.size())
353     return emitOpError("has ")
354            << getNumOperands() << " operands, but enclosing function (@"
355            << function.getName() << ") outputs " << results.size();
356 
357   for (unsigned i = 0, e = results.size(); i != e; ++i)
358     if (getOperand(i).getType() != results[i])
359       return emitError() << "type of output operand " << i << " ("
360                          << getOperand(i).getType()
361                          << ") doesn't match function result type ("
362                          << results[i] << ")"
363                          << " in function @" << function.getName();
364 
365   return success();
366 }
367 
368 //===----------------------------------------------------------------------===//
369 // ReturnOp
370 //===----------------------------------------------------------------------===//
371 
verify()372 LogicalResult ReturnOp::verify() {
373   auto function = cast<FuncOp>((*this)->getParentOp());
374 
375   // The operand number and types must match the function signature.
376   const auto &results = function.getFunctionType().getResults();
377   if (getNumOperands() != results.size())
378     return emitOpError("has ")
379            << getNumOperands() << " operands, but enclosing function (@"
380            << function.getName() << ") returns " << results.size();
381 
382   for (unsigned i = 0, e = results.size(); i != e; ++i)
383     if (getOperand(i).getType() != results[i])
384       return emitError() << "type of return operand " << i << " ("
385                          << getOperand(i).getType()
386                          << ") doesn't match function result type ("
387                          << results[i] << ")"
388                          << " in function @" << function.getName();
389 
390   return success();
391 }
392