1 //===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/MLProgram/IR/MLProgram.h" 10 #include "mlir/IR/Builders.h" 11 #include "mlir/IR/FunctionImplementation.h" 12 13 using namespace mlir; 14 using namespace mlir::ml_program; 15 16 //===----------------------------------------------------------------------===// 17 // Custom asm helpers 18 //===----------------------------------------------------------------------===// 19 20 /// Parse and print an ordering clause for a variadic of consuming tokens 21 /// and an producing token. 22 /// 23 /// Syntax: 24 /// ordering(%0, %1 -> !ml_program.token) 25 /// ordering(() -> !ml_program.token) 26 /// 27 /// If both the consuming and producing token are not present on the op, then 28 /// the clause prints nothing. 29 static ParseResult parseTokenOrdering( 30 OpAsmParser &parser, 31 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &consumeTokens, 32 Type &produceTokenType) { 33 if (failed(parser.parseOptionalKeyword("ordering")) || 34 failed(parser.parseLParen())) 35 return success(); 36 37 // Parse consuming token list. If there are no consuming tokens, the 38 // '()' null list represents this. 39 if (succeeded(parser.parseOptionalLParen())) { 40 if (failed(parser.parseRParen())) 41 return failure(); 42 } else { 43 if (failed(parser.parseOperandList(consumeTokens, 44 /*requiredOperandCount=*/-1))) 45 return failure(); 46 } 47 48 // Parse producer token. 49 if (failed(parser.parseArrow())) 50 return failure(); 51 if (failed(parser.parseType(produceTokenType))) 52 return failure(); 53 54 if (failed(parser.parseRParen())) 55 return failure(); 56 57 return success(); 58 } 59 60 static void printTokenOrdering(OpAsmPrinter &p, Operation *op, 61 OperandRange consumeTokens, 62 Type produceTokenType) { 63 if (consumeTokens.empty() && !produceTokenType) 64 return; 65 66 p << " ordering("; 67 if (consumeTokens.empty()) 68 p << "()"; 69 else 70 p.printOperands(consumeTokens); 71 if (produceTokenType) { 72 p << " -> "; 73 p.printType(produceTokenType); 74 } 75 p << ")"; 76 } 77 78 /// some.op custom<TypeOrAttr>($type, $attr) 79 /// 80 /// Uninitialized: 81 /// some.op : tensor<3xi32> 82 /// Initialized to narrower type than op: 83 /// some.op (dense<0> : tensor<3xi32>) : tensor<?xi32> 84 static ParseResult parseTypedInitialValue(OpAsmParser &parser, 85 TypeAttr &typeAttr, Attribute &attr) { 86 if (succeeded(parser.parseOptionalLParen())) { 87 if (failed(parser.parseAttribute(attr))) 88 return failure(); 89 if (failed(parser.parseRParen())) 90 return failure(); 91 } 92 93 Type type; 94 if (failed(parser.parseColonType(type))) 95 return failure(); 96 typeAttr = TypeAttr::get(type); 97 return success(); 98 } 99 100 static void printTypedInitialValue(OpAsmPrinter &p, Operation *op, 101 TypeAttr type, Attribute attr) { 102 if (attr) { 103 p << "("; 104 p.printAttribute(attr); 105 p << ")"; 106 } 107 108 p << " : "; 109 p.printAttribute(type); 110 } 111 112 /// some.op custom<SymbolVisibility>($sym_visibility) $sym_name 113 /// -> 114 /// some.op public @foo 115 /// some.op private @foo 116 static ParseResult parseSymbolVisibility(OpAsmParser &parser, 117 StringAttr &symVisibilityAttr) { 118 StringRef symVisibility; 119 (void)parser.parseOptionalKeyword(&symVisibility, 120 {"public", "private", "nested"}); 121 if (symVisibility.empty()) 122 return parser.emitError(parser.getCurrentLocation()) 123 << "expected 'public', 'private', or 'nested'"; 124 if (!symVisibility.empty()) 125 symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility); 126 return success(); 127 } 128 129 static void printSymbolVisibility(OpAsmPrinter &p, Operation *op, 130 StringAttr symVisibilityAttr) { 131 if (!symVisibilityAttr) 132 p << "public"; 133 else 134 p << symVisibilityAttr.getValue(); 135 } 136 137 //===----------------------------------------------------------------------===// 138 // TableGen'd op method definitions 139 //===----------------------------------------------------------------------===// 140 141 #define GET_OP_CLASSES 142 #include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc" 143 144 //===----------------------------------------------------------------------===// 145 // FuncOp 146 //===----------------------------------------------------------------------===// 147 148 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { 149 auto buildFuncType = 150 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, 151 function_interface_impl::VariadicFlag, 152 std::string &) { return builder.getFunctionType(argTypes, results); }; 153 154 return function_interface_impl::parseFunctionOp( 155 parser, result, /*allowVariadic=*/false, 156 getFunctionTypeAttrName(result.name), buildFuncType, 157 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); 158 } 159 160 void FuncOp::print(OpAsmPrinter &p) { 161 function_interface_impl::printFunctionOp( 162 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), 163 getArgAttrsAttrName(), getResAttrsAttrName()); 164 } 165 166 //===----------------------------------------------------------------------===// 167 // GlobalOp 168 //===----------------------------------------------------------------------===// 169 170 LogicalResult GlobalOp::verify() { 171 if (!getIsMutable() && !getValue()) 172 return emitOpError() << "immutable global must have an initial value"; 173 return success(); 174 } 175 176 //===----------------------------------------------------------------------===// 177 // GlobalLoadOp 178 //===----------------------------------------------------------------------===// 179 180 GlobalOp GlobalLoadOp::getGlobalOp(SymbolTableCollection &symbolTable) { 181 return symbolTable.lookupNearestSymbolFrom<GlobalOp>( 182 getOperation()->getParentOp(), getGlobalAttr()); 183 } 184 185 LogicalResult 186 GlobalLoadOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 187 GlobalOp referrent = getGlobalOp(symbolTable); 188 if (!referrent) 189 return emitOpError() << "undefined global: " << getGlobal(); 190 191 if (referrent.getType() != getResult().getType()) { 192 return emitOpError() << "cannot load from global typed " 193 << referrent.getType() << " as " 194 << getResult().getType(); 195 } 196 197 return success(); 198 } 199 200 //===----------------------------------------------------------------------===// 201 // GlobalLoadConstOp 202 //===----------------------------------------------------------------------===// 203 204 GlobalOp GlobalLoadConstOp::getGlobalOp(SymbolTableCollection &symbolTable) { 205 return symbolTable.lookupNearestSymbolFrom<GlobalOp>( 206 getOperation()->getParentOp(), getGlobalAttr()); 207 } 208 209 LogicalResult 210 GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 211 GlobalOp referrent = getGlobalOp(symbolTable); 212 if (!referrent) 213 return emitOpError() << "undefined global: " << getGlobal(); 214 215 if (referrent.getIsMutable()) 216 return emitOpError() << "cannot load as const from mutable global " 217 << getGlobal(); 218 219 if (referrent.getType() != getResult().getType()) 220 return emitOpError() << "cannot load from global typed " 221 << referrent.getType() << " as " 222 << getResult().getType(); 223 224 return success(); 225 } 226 227 //===----------------------------------------------------------------------===// 228 // GlobalLoadGraphOp 229 //===----------------------------------------------------------------------===// 230 231 GlobalOp GlobalLoadGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) { 232 return symbolTable.lookupNearestSymbolFrom<GlobalOp>( 233 getOperation()->getParentOp(), getGlobalAttr()); 234 } 235 236 LogicalResult 237 GlobalLoadGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 238 GlobalOp referrent = getGlobalOp(symbolTable); 239 if (!referrent) 240 return emitOpError() << "undefined global: " << getGlobal(); 241 242 if (referrent.getType() != getResult().getType()) { 243 return emitOpError() << "cannot load from global typed " 244 << referrent.getType() << " as " 245 << getResult().getType(); 246 } 247 248 return success(); 249 } 250 251 //===----------------------------------------------------------------------===// 252 // GlobalStoreOp 253 //===----------------------------------------------------------------------===// 254 255 GlobalOp GlobalStoreOp::getGlobalOp(SymbolTableCollection &symbolTable) { 256 return symbolTable.lookupNearestSymbolFrom<GlobalOp>( 257 getOperation()->getParentOp(), getGlobalAttr()); 258 } 259 260 LogicalResult 261 GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 262 GlobalOp referrent = getGlobalOp(symbolTable); 263 if (!referrent) 264 return emitOpError() << "undefined global: " << getGlobal(); 265 266 if (!referrent.getIsMutable()) { 267 return emitOpError() << "cannot store to an immutable global " 268 << getGlobal(); 269 } 270 271 if (referrent.getType() != getValue().getType()) { 272 return emitOpError() << "cannot store to a global typed " 273 << referrent.getType() << " from " 274 << getValue().getType(); 275 } 276 277 return success(); 278 } 279 280 //===----------------------------------------------------------------------===// 281 // GlobalStoreGraphOp 282 //===----------------------------------------------------------------------===// 283 284 GlobalOp GlobalStoreGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) { 285 return symbolTable.lookupNearestSymbolFrom<GlobalOp>( 286 getOperation()->getParentOp(), getGlobalAttr()); 287 } 288 289 LogicalResult 290 GlobalStoreGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 291 GlobalOp referrent = getGlobalOp(symbolTable); 292 if (!referrent) 293 return emitOpError() << "undefined global: " << getGlobal(); 294 295 if (!referrent.getIsMutable()) { 296 return emitOpError() << "cannot store to an immutable global " 297 << getGlobal(); 298 } 299 300 if (referrent.getType() != getValue().getType()) { 301 return emitOpError() << "cannot store to a global typed " 302 << referrent.getType() << " from " 303 << getValue().getType(); 304 } 305 306 return success(); 307 } 308 309 //===----------------------------------------------------------------------===// 310 // SubgraphOp 311 //===----------------------------------------------------------------------===// 312 313 ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) { 314 auto buildFuncType = 315 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, 316 function_interface_impl::VariadicFlag, 317 std::string &) { return builder.getFunctionType(argTypes, results); }; 318 319 return function_interface_impl::parseFunctionOp( 320 parser, result, /*allowVariadic=*/false, 321 getFunctionTypeAttrName(result.name), buildFuncType, 322 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); 323 } 324 325 void SubgraphOp::print(OpAsmPrinter &p) { 326 function_interface_impl::printFunctionOp( 327 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), 328 getArgAttrsAttrName(), getResAttrsAttrName()); 329 } 330 331 //===----------------------------------------------------------------------===// 332 // OutputOp 333 //===----------------------------------------------------------------------===// 334 335 LogicalResult OutputOp::verify() { 336 auto function = cast<SubgraphOp>((*this)->getParentOp()); 337 338 // The operand number and types must match the function signature. 339 const auto &results = function.getFunctionType().getResults(); 340 if (getNumOperands() != results.size()) 341 return emitOpError("has ") 342 << getNumOperands() << " operands, but enclosing function (@" 343 << function.getName() << ") outputs " << results.size(); 344 345 for (unsigned i = 0, e = results.size(); i != e; ++i) 346 if (getOperand(i).getType() != results[i]) 347 return emitError() << "type of output operand " << i << " (" 348 << getOperand(i).getType() 349 << ") doesn't match function result type (" 350 << results[i] << ")" 351 << " in function @" << function.getName(); 352 353 return success(); 354 } 355 356 //===----------------------------------------------------------------------===// 357 // ReturnOp 358 //===----------------------------------------------------------------------===// 359 360 LogicalResult ReturnOp::verify() { 361 auto function = cast<FuncOp>((*this)->getParentOp()); 362 363 // The operand number and types must match the function signature. 364 const auto &results = function.getFunctionType().getResults(); 365 if (getNumOperands() != results.size()) 366 return emitOpError("has ") 367 << getNumOperands() << " operands, but enclosing function (@" 368 << function.getName() << ") returns " << results.size(); 369 370 for (unsigned i = 0, e = results.size(); i != e; ++i) 371 if (getOperand(i).getType() != results[i]) 372 return emitError() << "type of return operand " << i << " (" 373 << getOperand(i).getType() 374 << ") doesn't match function result type (" 375 << results[i] << ")" 376 << " in function @" << function.getName(); 377 378 return success(); 379 } 380