1 //===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/MLProgram/IR/MLProgram.h" 10 #include "mlir/IR/Builders.h" 11 #include "mlir/Interfaces/FunctionImplementation.h" 12 13 using namespace mlir; 14 using namespace mlir::ml_program; 15 16 //===----------------------------------------------------------------------===// 17 // Custom asm helpers 18 //===----------------------------------------------------------------------===// 19 20 /// Parse and print an ordering clause for a variadic of consuming tokens 21 /// and an producing token. 22 /// 23 /// Syntax: 24 /// ordering(%0, %1 -> !ml_program.token) 25 /// ordering(() -> !ml_program.token) 26 /// 27 /// If both the consuming and producing token are not present on the op, then 28 /// the clause prints nothing. 29 static ParseResult parseTokenOrdering( 30 OpAsmParser &parser, 31 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &consumeTokens, 32 Type &produceTokenType) { 33 if (failed(parser.parseOptionalKeyword("ordering")) || 34 failed(parser.parseLParen())) 35 return success(); 36 37 // Parse consuming token list. If there are no consuming tokens, the 38 // '()' null list represents this. 39 if (succeeded(parser.parseOptionalLParen())) { 40 if (failed(parser.parseRParen())) 41 return failure(); 42 } else { 43 if (failed(parser.parseOperandList(consumeTokens, 44 /*requiredOperandCount=*/-1))) 45 return failure(); 46 } 47 48 // Parse producer token. 49 if (failed(parser.parseArrow())) 50 return failure(); 51 if (failed(parser.parseType(produceTokenType))) 52 return failure(); 53 54 if (failed(parser.parseRParen())) 55 return failure(); 56 57 return success(); 58 } 59 60 static void printTokenOrdering(OpAsmPrinter &p, Operation *op, 61 OperandRange consumeTokens, 62 Type produceTokenType) { 63 if (consumeTokens.empty() && !produceTokenType) 64 return; 65 66 p << " ordering("; 67 if (consumeTokens.empty()) 68 p << "()"; 69 else 70 p.printOperands(consumeTokens); 71 if (produceTokenType) { 72 p << " -> "; 73 p.printType(produceTokenType); 74 } 75 p << ")"; 76 } 77 78 /// some.op custom<TypeOrAttr>($type, $attr) 79 /// 80 /// Uninitialized: 81 /// some.op : tensor<3xi32> 82 /// Initialized to narrower type than op: 83 /// some.op (dense<0> : tensor<3xi32>) : tensor<?xi32> 84 static ParseResult parseTypedInitialValue(OpAsmParser &parser, 85 TypeAttr &typeAttr, Attribute &attr) { 86 if (succeeded(parser.parseOptionalLParen())) { 87 if (failed(parser.parseAttribute(attr))) 88 return failure(); 89 if (failed(parser.parseRParen())) 90 return failure(); 91 } 92 93 Type type; 94 if (failed(parser.parseColonType(type))) 95 return failure(); 96 typeAttr = TypeAttr::get(type); 97 return success(); 98 } 99 100 static void printTypedInitialValue(OpAsmPrinter &p, Operation *op, 101 TypeAttr type, Attribute attr) { 102 if (attr) { 103 p << "("; 104 p.printAttribute(attr); 105 p << ")"; 106 } 107 108 p << " : "; 109 p.printAttribute(type); 110 } 111 112 /// some.op custom<SymbolVisibility>($sym_visibility) $sym_name 113 /// -> 114 /// some.op public @foo 115 /// some.op private @foo 116 static ParseResult parseSymbolVisibility(OpAsmParser &parser, 117 StringAttr &symVisibilityAttr) { 118 StringRef symVisibility; 119 (void)parser.parseOptionalKeyword(&symVisibility, 120 {"public", "private", "nested"}); 121 if (symVisibility.empty()) 122 return parser.emitError(parser.getCurrentLocation()) 123 << "expected 'public', 'private', or 'nested'"; 124 if (!symVisibility.empty()) 125 symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility); 126 return success(); 127 } 128 129 static void printSymbolVisibility(OpAsmPrinter &p, Operation *op, 130 StringAttr symVisibilityAttr) { 131 if (!symVisibilityAttr) 132 p << "public"; 133 else 134 p << symVisibilityAttr.getValue(); 135 } 136 137 //===----------------------------------------------------------------------===// 138 // TableGen'd op method definitions 139 //===----------------------------------------------------------------------===// 140 141 #define GET_OP_CLASSES 142 #include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc" 143 144 //===----------------------------------------------------------------------===// 145 // FuncOp 146 //===----------------------------------------------------------------------===// 147 148 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { 149 auto buildFuncType = 150 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, 151 function_interface_impl::VariadicFlag, 152 std::string &) { return builder.getFunctionType(argTypes, results); }; 153 154 return function_interface_impl::parseFunctionOp( 155 parser, result, /*allowVariadic=*/false, 156 getFunctionTypeAttrName(result.name), buildFuncType, 157 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); 158 } 159 160 void FuncOp::print(OpAsmPrinter &p) { 161 function_interface_impl::printFunctionOp( 162 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), 163 getArgAttrsAttrName(), getResAttrsAttrName()); 164 } 165 166 //===----------------------------------------------------------------------===// 167 // GlobalOp 168 //===----------------------------------------------------------------------===// 169 170 LogicalResult GlobalOp::verify() { 171 if (!getIsMutable() && !getValue()) 172 return emitOpError() << "immutable global must have an initial value"; 173 return success(); 174 } 175 176 //===----------------------------------------------------------------------===// 177 // GlobalLoadOp 178 //===----------------------------------------------------------------------===// 179 180 GlobalOp GlobalLoadOp::getGlobalOp(SymbolTableCollection &symbolTable) { 181 for (auto parent = getOperation()->getParentOp(); parent; 182 parent = parent->getParentOp()) { 183 if (auto nearest = symbolTable.lookupNearestSymbolFrom<GlobalOp>( 184 parent, getGlobalAttr())) { 185 return nearest; 186 } 187 } 188 return {}; 189 } 190 191 LogicalResult 192 GlobalLoadOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 193 GlobalOp referrent = getGlobalOp(symbolTable); 194 if (!referrent) 195 return emitOpError() << "undefined global: " << getGlobal(); 196 197 if (referrent.getType() != getResult().getType()) { 198 return emitOpError() << "cannot load from global typed " 199 << referrent.getType() << " as " 200 << getResult().getType(); 201 } 202 203 return success(); 204 } 205 206 //===----------------------------------------------------------------------===// 207 // GlobalLoadConstOp 208 //===----------------------------------------------------------------------===// 209 210 GlobalOp GlobalLoadConstOp::getGlobalOp(SymbolTableCollection &symbolTable) { 211 return symbolTable.lookupNearestSymbolFrom<GlobalOp>( 212 getOperation()->getParentOp(), getGlobalAttr()); 213 } 214 215 LogicalResult 216 GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 217 GlobalOp referrent = getGlobalOp(symbolTable); 218 if (!referrent) 219 return emitOpError() << "undefined global: " << getGlobal(); 220 221 if (referrent.getIsMutable()) 222 return emitOpError() << "cannot load as const from mutable global " 223 << getGlobal(); 224 225 if (referrent.getType() != getResult().getType()) 226 return emitOpError() << "cannot load from global typed " 227 << referrent.getType() << " as " 228 << getResult().getType(); 229 230 return success(); 231 } 232 233 //===----------------------------------------------------------------------===// 234 // GlobalLoadGraphOp 235 //===----------------------------------------------------------------------===// 236 237 GlobalOp GlobalLoadGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) { 238 return symbolTable.lookupNearestSymbolFrom<GlobalOp>( 239 getOperation()->getParentOp(), getGlobalAttr()); 240 } 241 242 LogicalResult 243 GlobalLoadGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 244 GlobalOp referrent = getGlobalOp(symbolTable); 245 if (!referrent) 246 return emitOpError() << "undefined global: " << getGlobal(); 247 248 if (referrent.getType() != getResult().getType()) { 249 return emitOpError() << "cannot load from global typed " 250 << referrent.getType() << " as " 251 << getResult().getType(); 252 } 253 254 return success(); 255 } 256 257 //===----------------------------------------------------------------------===// 258 // GlobalStoreOp 259 //===----------------------------------------------------------------------===// 260 261 GlobalOp GlobalStoreOp::getGlobalOp(SymbolTableCollection &symbolTable) { 262 for (auto parent = getOperation()->getParentOp(); parent;) { 263 if (auto nearest = symbolTable.lookupNearestSymbolFrom<GlobalOp>( 264 parent, getGlobalAttr())) { 265 return nearest; 266 } 267 parent = parent->getParentOp(); 268 } 269 return {}; 270 } 271 272 LogicalResult 273 GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 274 GlobalOp referrent = getGlobalOp(symbolTable); 275 if (!referrent) 276 return emitOpError() << "undefined global: " << getGlobal(); 277 278 if (!referrent.getIsMutable()) { 279 return emitOpError() << "cannot store to an immutable global " 280 << getGlobal(); 281 } 282 283 if (referrent.getType() != getValue().getType()) { 284 return emitOpError() << "cannot store to a global typed " 285 << referrent.getType() << " from " 286 << getValue().getType(); 287 } 288 289 return success(); 290 } 291 292 //===----------------------------------------------------------------------===// 293 // GlobalStoreGraphOp 294 //===----------------------------------------------------------------------===// 295 296 GlobalOp GlobalStoreGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) { 297 return symbolTable.lookupNearestSymbolFrom<GlobalOp>( 298 getOperation()->getParentOp(), getGlobalAttr()); 299 } 300 301 LogicalResult 302 GlobalStoreGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 303 GlobalOp referrent = getGlobalOp(symbolTable); 304 if (!referrent) 305 return emitOpError() << "undefined global: " << getGlobal(); 306 307 if (!referrent.getIsMutable()) { 308 return emitOpError() << "cannot store to an immutable global " 309 << getGlobal(); 310 } 311 312 if (referrent.getType() != getValue().getType()) { 313 return emitOpError() << "cannot store to a global typed " 314 << referrent.getType() << " from " 315 << getValue().getType(); 316 } 317 318 return success(); 319 } 320 321 //===----------------------------------------------------------------------===// 322 // SubgraphOp 323 //===----------------------------------------------------------------------===// 324 325 ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) { 326 auto buildFuncType = 327 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, 328 function_interface_impl::VariadicFlag, 329 std::string &) { return builder.getFunctionType(argTypes, results); }; 330 331 return function_interface_impl::parseFunctionOp( 332 parser, result, /*allowVariadic=*/false, 333 getFunctionTypeAttrName(result.name), buildFuncType, 334 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); 335 } 336 337 void SubgraphOp::print(OpAsmPrinter &p) { 338 function_interface_impl::printFunctionOp( 339 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), 340 getArgAttrsAttrName(), getResAttrsAttrName()); 341 } 342 343 //===----------------------------------------------------------------------===// 344 // OutputOp 345 //===----------------------------------------------------------------------===// 346 347 LogicalResult OutputOp::verify() { 348 auto function = cast<SubgraphOp>((*this)->getParentOp()); 349 350 // The operand number and types must match the function signature. 351 const auto &results = function.getFunctionType().getResults(); 352 if (getNumOperands() != results.size()) 353 return emitOpError("has ") 354 << getNumOperands() << " operands, but enclosing function (@" 355 << function.getName() << ") outputs " << results.size(); 356 357 for (unsigned i = 0, e = results.size(); i != e; ++i) 358 if (getOperand(i).getType() != results[i]) 359 return emitError() << "type of output operand " << i << " (" 360 << getOperand(i).getType() 361 << ") doesn't match function result type (" 362 << results[i] << ")" 363 << " in function @" << function.getName(); 364 365 return success(); 366 } 367 368 //===----------------------------------------------------------------------===// 369 // ReturnOp 370 //===----------------------------------------------------------------------===// 371 372 LogicalResult ReturnOp::verify() { 373 auto function = cast<FuncOp>((*this)->getParentOp()); 374 375 // The operand number and types must match the function signature. 376 const auto &results = function.getFunctionType().getResults(); 377 if (getNumOperands() != results.size()) 378 return emitOpError("has ") 379 << getNumOperands() << " operands, but enclosing function (@" 380 << function.getName() << ") returns " << results.size(); 381 382 for (unsigned i = 0, e = results.size(); i != e; ++i) 383 if (getOperand(i).getType() != results[i]) 384 return emitError() << "type of return operand " << i << " (" 385 << getOperand(i).getType() 386 << ") doesn't match function result type (" 387 << results[i] << ")" 388 << " in function @" << function.getName(); 389 390 return success(); 391 } 392