xref: /llvm-project/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp (revision 2bb252852c72a4563fd7cd36604a1698c34d22a8)
1 //===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/FunctionImplementation.h"
12 
13 using namespace mlir;
14 using namespace mlir::ml_program;
15 
16 //===----------------------------------------------------------------------===//
17 // Custom asm helpers
18 //===----------------------------------------------------------------------===//
19 
20 /// some.op custom<TypeOrAttr>($type, $attr)
21 ///
22 /// Uninitialized:
23 ///   some.op : tensor<3xi32>
24 /// Initialized to narrower type than op:
25 ///   some.op (dense<0> : tensor<3xi32>) : tensor<?xi32>
26 static ParseResult parseTypedInitialValue(OpAsmParser &parser,
27                                           TypeAttr &typeAttr, Attribute &attr) {
28   if (succeeded(parser.parseOptionalLParen())) {
29     if (failed(parser.parseAttribute(attr)))
30       return failure();
31     if (failed(parser.parseRParen()))
32       return failure();
33   }
34 
35   Type type;
36   if (failed(parser.parseColonType(type)))
37     return failure();
38   typeAttr = TypeAttr::get(type);
39   return success();
40 }
41 
42 static void printTypedInitialValue(OpAsmPrinter &p, Operation *op,
43                                    TypeAttr type, Attribute attr) {
44   if (attr) {
45     p << "(";
46     p.printAttribute(attr);
47     p << ")";
48   }
49 
50   p << " : ";
51   p.printAttribute(type);
52 }
53 
54 /// some.op custom<SymbolVisibility>($sym_visibility) $sym_name
55 /// ->
56 /// some.op public @foo
57 /// some.op private @foo
58 static ParseResult parseSymbolVisibility(OpAsmParser &parser,
59                                          StringAttr &symVisibilityAttr) {
60   StringRef symVisibility;
61   (void)parser.parseOptionalKeyword(&symVisibility,
62                                     {"public", "private", "nested"});
63   if (symVisibility.empty())
64     return parser.emitError(parser.getCurrentLocation())
65            << "expected 'public', 'private', or 'nested'";
66   if (!symVisibility.empty())
67     symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility);
68   return success();
69 }
70 
71 static void printSymbolVisibility(OpAsmPrinter &p, Operation *op,
72                                   StringAttr symVisibilityAttr) {
73   if (!symVisibilityAttr)
74     p << "public";
75   else
76     p << symVisibilityAttr.getValue();
77 }
78 
79 //===----------------------------------------------------------------------===//
80 // TableGen'd op method definitions
81 //===----------------------------------------------------------------------===//
82 
83 #define GET_OP_CLASSES
84 #include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
85 
86 //===----------------------------------------------------------------------===//
87 // FuncOp
88 //===----------------------------------------------------------------------===//
89 
90 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
91   auto buildFuncType =
92       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
93          function_interface_impl::VariadicFlag,
94          std::string &) { return builder.getFunctionType(argTypes, results); };
95 
96   return function_interface_impl::parseFunctionOp(
97       parser, result, /*allowVariadic=*/false, buildFuncType);
98 }
99 
100 void FuncOp::print(OpAsmPrinter &p) {
101   function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
102 }
103 
104 //===----------------------------------------------------------------------===//
105 // GlobalOp
106 //===----------------------------------------------------------------------===//
107 
108 LogicalResult GlobalOp::verify() {
109   if (!getIsMutable() && !getValue())
110     return emitOpError() << "immutable global must have an initial value";
111   return success();
112 }
113 
114 //===----------------------------------------------------------------------===//
115 // GlobalLoadConstOp
116 //===----------------------------------------------------------------------===//
117 
118 GlobalOp GlobalLoadConstOp::getGlobalOp(SymbolTableCollection &symbolTable) {
119   return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
120       getOperation()->getParentOp(), getGlobalAttr());
121 }
122 
123 LogicalResult
124 GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
125   GlobalOp referrent = getGlobalOp(symbolTable);
126   if (!referrent)
127     return emitOpError() << "undefined global: " << getGlobal();
128 
129   if (referrent.getIsMutable())
130     return emitOpError() << "cannot load as const from mutable global "
131                          << getGlobal();
132 
133   if (referrent.getType() != getResult().getType())
134     return emitOpError() << "cannot load from global typed "
135                          << referrent.getType() << " as "
136                          << getResult().getType();
137 
138   return success();
139 }
140 
141 //===----------------------------------------------------------------------===//
142 // SubgraphOp
143 //===----------------------------------------------------------------------===//
144 
145 ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) {
146   auto buildFuncType =
147       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
148          function_interface_impl::VariadicFlag,
149          std::string &) { return builder.getFunctionType(argTypes, results); };
150 
151   return function_interface_impl::parseFunctionOp(
152       parser, result, /*allowVariadic=*/false, buildFuncType);
153 }
154 
155 void SubgraphOp::print(OpAsmPrinter &p) {
156   function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
157 }
158 
159 //===----------------------------------------------------------------------===//
160 // OutputOp
161 //===----------------------------------------------------------------------===//
162 
163 LogicalResult OutputOp::verify() {
164   auto function = cast<SubgraphOp>((*this)->getParentOp());
165 
166   // The operand number and types must match the function signature.
167   const auto &results = function.getFunctionType().getResults();
168   if (getNumOperands() != results.size())
169     return emitOpError("has ")
170            << getNumOperands() << " operands, but enclosing function (@"
171            << function.getName() << ") outputs " << results.size();
172 
173   for (unsigned i = 0, e = results.size(); i != e; ++i)
174     if (getOperand(i).getType() != results[i])
175       return emitError() << "type of output operand " << i << " ("
176                          << getOperand(i).getType()
177                          << ") doesn't match function result type ("
178                          << results[i] << ")"
179                          << " in function @" << function.getName();
180 
181   return success();
182 }
183 
184 //===----------------------------------------------------------------------===//
185 // ReturnOp
186 //===----------------------------------------------------------------------===//
187 
188 LogicalResult ReturnOp::verify() {
189   auto function = cast<FuncOp>((*this)->getParentOp());
190 
191   // The operand number and types must match the function signature.
192   const auto &results = function.getFunctionType().getResults();
193   if (getNumOperands() != results.size())
194     return emitOpError("has ")
195            << getNumOperands() << " operands, but enclosing function (@"
196            << function.getName() << ") returns " << results.size();
197 
198   for (unsigned i = 0, e = results.size(); i != e; ++i)
199     if (getOperand(i).getType() != results[i])
200       return emitError() << "type of return operand " << i << " ("
201                          << getOperand(i).getType()
202                          << ") doesn't match function result type ("
203                          << results[i] << ")"
204                          << " in function @" << function.getName();
205 
206   return success();
207 }
208