1 //===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/MLProgram/IR/MLProgram.h" 10 #include "mlir/IR/Builders.h" 11 #include "mlir/IR/FunctionImplementation.h" 12 13 using namespace mlir; 14 using namespace mlir::ml_program; 15 16 //===----------------------------------------------------------------------===// 17 // Custom asm helpers 18 //===----------------------------------------------------------------------===// 19 20 /// some.op custom<TypeOrAttr>($type, $attr) 21 /// 22 /// Uninitialized: 23 /// some.op : tensor<3xi32> 24 /// Initialized to narrower type than op: 25 /// some.op (dense<0> : tensor<3xi32>) : tensor<?xi32> 26 static ParseResult parseTypedInitialValue(OpAsmParser &parser, 27 TypeAttr &typeAttr, Attribute &attr) { 28 if (succeeded(parser.parseOptionalLParen())) { 29 if (failed(parser.parseAttribute(attr))) 30 return failure(); 31 if (failed(parser.parseRParen())) 32 return failure(); 33 } 34 35 Type type; 36 if (failed(parser.parseColonType(type))) 37 return failure(); 38 typeAttr = TypeAttr::get(type); 39 return success(); 40 } 41 42 static void printTypedInitialValue(OpAsmPrinter &p, Operation *op, 43 TypeAttr type, Attribute attr) { 44 if (attr) { 45 p << "("; 46 p.printAttribute(attr); 47 p << ")"; 48 } 49 50 p << " : "; 51 p.printAttribute(type); 52 } 53 54 /// some.op custom<SymbolVisibility>($sym_visibility) $sym_name 55 /// -> 56 /// some.op public @foo 57 /// some.op private @foo 58 static ParseResult parseSymbolVisibility(OpAsmParser &parser, 59 StringAttr &symVisibilityAttr) { 60 StringRef symVisibility; 61 (void)parser.parseOptionalKeyword(&symVisibility, 62 {"public", "private", "nested"}); 63 if (symVisibility.empty()) 64 return parser.emitError(parser.getCurrentLocation()) 65 << "expected 'public', 'private', or 'nested'"; 66 if (!symVisibility.empty()) 67 symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility); 68 return success(); 69 } 70 71 static void printSymbolVisibility(OpAsmPrinter &p, Operation *op, 72 StringAttr symVisibilityAttr) { 73 if (!symVisibilityAttr) 74 p << "public"; 75 else 76 p << symVisibilityAttr.getValue(); 77 } 78 79 //===----------------------------------------------------------------------===// 80 // TableGen'd op method definitions 81 //===----------------------------------------------------------------------===// 82 83 #define GET_OP_CLASSES 84 #include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc" 85 86 //===----------------------------------------------------------------------===// 87 // FuncOp 88 //===----------------------------------------------------------------------===// 89 90 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { 91 auto buildFuncType = 92 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, 93 function_interface_impl::VariadicFlag, 94 std::string &) { return builder.getFunctionType(argTypes, results); }; 95 96 return function_interface_impl::parseFunctionOp( 97 parser, result, /*allowVariadic=*/false, buildFuncType); 98 } 99 100 void FuncOp::print(OpAsmPrinter &p) { 101 function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); 102 } 103 104 //===----------------------------------------------------------------------===// 105 // GlobalOp 106 //===----------------------------------------------------------------------===// 107 108 LogicalResult GlobalOp::verify() { 109 if (!getIsMutable() && !getValue()) 110 return emitOpError() << "immutable global must have an initial value"; 111 return success(); 112 } 113 114 //===----------------------------------------------------------------------===// 115 // GlobalLoadConstOp 116 //===----------------------------------------------------------------------===// 117 118 GlobalOp GlobalLoadConstOp::getGlobalOp(SymbolTableCollection &symbolTable) { 119 return symbolTable.lookupNearestSymbolFrom<GlobalOp>( 120 getOperation()->getParentOp(), getGlobalAttr()); 121 } 122 123 LogicalResult 124 GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 125 GlobalOp referrent = getGlobalOp(symbolTable); 126 if (!referrent) 127 return emitOpError() << "undefined global: " << getGlobal(); 128 129 if (referrent.getIsMutable()) 130 return emitOpError() << "cannot load as const from mutable global " 131 << getGlobal(); 132 133 if (referrent.getType() != getResult().getType()) 134 return emitOpError() << "cannot load from global typed " 135 << referrent.getType() << " as " 136 << getResult().getType(); 137 138 return success(); 139 } 140 141 //===----------------------------------------------------------------------===// 142 // SubgraphOp 143 //===----------------------------------------------------------------------===// 144 145 ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) { 146 auto buildFuncType = 147 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, 148 function_interface_impl::VariadicFlag, 149 std::string &) { return builder.getFunctionType(argTypes, results); }; 150 151 return function_interface_impl::parseFunctionOp( 152 parser, result, /*allowVariadic=*/false, buildFuncType); 153 } 154 155 void SubgraphOp::print(OpAsmPrinter &p) { 156 function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); 157 } 158 159 //===----------------------------------------------------------------------===// 160 // OutputOp 161 //===----------------------------------------------------------------------===// 162 163 LogicalResult OutputOp::verify() { 164 auto function = cast<SubgraphOp>((*this)->getParentOp()); 165 166 // The operand number and types must match the function signature. 167 const auto &results = function.getFunctionType().getResults(); 168 if (getNumOperands() != results.size()) 169 return emitOpError("has ") 170 << getNumOperands() << " operands, but enclosing function (@" 171 << function.getName() << ") outputs " << results.size(); 172 173 for (unsigned i = 0, e = results.size(); i != e; ++i) 174 if (getOperand(i).getType() != results[i]) 175 return emitError() << "type of output operand " << i << " (" 176 << getOperand(i).getType() 177 << ") doesn't match function result type (" 178 << results[i] << ")" 179 << " in function @" << function.getName(); 180 181 return success(); 182 } 183 184 //===----------------------------------------------------------------------===// 185 // ReturnOp 186 //===----------------------------------------------------------------------===// 187 188 LogicalResult ReturnOp::verify() { 189 auto function = cast<FuncOp>((*this)->getParentOp()); 190 191 // The operand number and types must match the function signature. 192 const auto &results = function.getFunctionType().getResults(); 193 if (getNumOperands() != results.size()) 194 return emitOpError("has ") 195 << getNumOperands() << " operands, but enclosing function (@" 196 << function.getName() << ") returns " << results.size(); 197 198 for (unsigned i = 0, e = results.size(); i != e; ++i) 199 if (getOperand(i).getType() != results[i]) 200 return emitError() << "type of return operand " << i << " (" 201 << getOperand(i).getType() 202 << ") doesn't match function result type (" 203 << results[i] << ")" 204 << " in function @" << function.getName(); 205 206 return success(); 207 } 208