//===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/IR/Builders.h" #include "mlir/Interfaces/FunctionImplementation.h" using namespace mlir; using namespace mlir::ml_program; //===----------------------------------------------------------------------===// // Custom asm helpers //===----------------------------------------------------------------------===// /// Parse and print an ordering clause for a variadic of consuming tokens /// and an producing token. /// /// Syntax: /// ordering(%0, %1 -> !ml_program.token) /// ordering(() -> !ml_program.token) /// /// If both the consuming and producing token are not present on the op, then /// the clause prints nothing. static ParseResult parseTokenOrdering( OpAsmParser &parser, SmallVectorImpl &consumeTokens, Type &produceTokenType) { if (failed(parser.parseOptionalKeyword("ordering")) || failed(parser.parseLParen())) return success(); // Parse consuming token list. If there are no consuming tokens, the // '()' null list represents this. if (succeeded(parser.parseOptionalLParen())) { if (failed(parser.parseRParen())) return failure(); } else { if (failed(parser.parseOperandList(consumeTokens, /*requiredOperandCount=*/-1))) return failure(); } // Parse producer token. if (failed(parser.parseArrow())) return failure(); if (failed(parser.parseType(produceTokenType))) return failure(); if (failed(parser.parseRParen())) return failure(); return success(); } static void printTokenOrdering(OpAsmPrinter &p, Operation *op, OperandRange consumeTokens, Type produceTokenType) { if (consumeTokens.empty() && !produceTokenType) return; p << " ordering("; if (consumeTokens.empty()) p << "()"; else p.printOperands(consumeTokens); if (produceTokenType) { p << " -> "; p.printType(produceTokenType); } p << ")"; } /// some.op custom($type, $attr) /// /// Uninitialized: /// some.op : tensor<3xi32> /// Initialized to narrower type than op: /// some.op (dense<0> : tensor<3xi32>) : tensor static ParseResult parseTypedInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &attr) { if (succeeded(parser.parseOptionalLParen())) { if (failed(parser.parseAttribute(attr))) return failure(); if (failed(parser.parseRParen())) return failure(); } Type type; if (failed(parser.parseColonType(type))) return failure(); typeAttr = TypeAttr::get(type); return success(); } static void printTypedInitialValue(OpAsmPrinter &p, Operation *op, TypeAttr type, Attribute attr) { if (attr) { p << "("; p.printAttribute(attr); p << ")"; } p << " : "; p.printAttribute(type); } /// some.op custom($sym_visibility) $sym_name /// -> /// some.op public @foo /// some.op private @foo static ParseResult parseSymbolVisibility(OpAsmParser &parser, StringAttr &symVisibilityAttr) { StringRef symVisibility; (void)parser.parseOptionalKeyword(&symVisibility, {"public", "private", "nested"}); if (symVisibility.empty()) return parser.emitError(parser.getCurrentLocation()) << "expected 'public', 'private', or 'nested'"; if (!symVisibility.empty()) symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility); return success(); } static void printSymbolVisibility(OpAsmPrinter &p, Operation *op, StringAttr symVisibilityAttr) { if (!symVisibilityAttr) p << "public"; else p << symVisibilityAttr.getValue(); } //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// #define GET_OP_CLASSES #include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc" //===----------------------------------------------------------------------===// // FuncOp //===----------------------------------------------------------------------===// ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { auto buildFuncType = [](Builder &builder, ArrayRef argTypes, ArrayRef results, function_interface_impl::VariadicFlag, std::string &) { return builder.getFunctionType(argTypes, results); }; return function_interface_impl::parseFunctionOp( parser, result, /*allowVariadic=*/false, getFunctionTypeAttrName(result.name), buildFuncType, getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(OpAsmPrinter &p) { function_interface_impl::printFunctionOp( p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName()); } //===----------------------------------------------------------------------===// // GlobalOp //===----------------------------------------------------------------------===// LogicalResult GlobalOp::verify() { if (!getIsMutable() && !getValue()) return emitOpError() << "immutable global must have an initial value"; return success(); } //===----------------------------------------------------------------------===// // GlobalLoadOp //===----------------------------------------------------------------------===// GlobalOp GlobalLoadOp::getGlobalOp(SymbolTableCollection &symbolTable) { for (auto *parent = getOperation()->getParentOp(); parent; parent = parent->getParentOp()) { if (auto nearest = symbolTable.lookupNearestSymbolFrom( parent, getGlobalAttr())) { return nearest; } } return {}; } LogicalResult GlobalLoadOp::verifySymbolUses(SymbolTableCollection &symbolTable) { GlobalOp referrent = getGlobalOp(symbolTable); if (!referrent) return emitOpError() << "undefined global: " << getGlobal(); if (referrent.getType() != getResult().getType()) { return emitOpError() << "cannot load from global typed " << referrent.getType() << " as " << getResult().getType(); } return success(); } //===----------------------------------------------------------------------===// // GlobalLoadConstOp //===----------------------------------------------------------------------===// GlobalOp GlobalLoadConstOp::getGlobalOp(SymbolTableCollection &symbolTable) { return symbolTable.lookupNearestSymbolFrom( getOperation()->getParentOp(), getGlobalAttr()); } LogicalResult GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) { GlobalOp referrent = getGlobalOp(symbolTable); if (!referrent) return emitOpError() << "undefined global: " << getGlobal(); if (referrent.getIsMutable()) return emitOpError() << "cannot load as const from mutable global " << getGlobal(); if (referrent.getType() != getResult().getType()) return emitOpError() << "cannot load from global typed " << referrent.getType() << " as " << getResult().getType(); return success(); } //===----------------------------------------------------------------------===// // GlobalLoadGraphOp //===----------------------------------------------------------------------===// GlobalOp GlobalLoadGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) { return symbolTable.lookupNearestSymbolFrom( getOperation()->getParentOp(), getGlobalAttr()); } LogicalResult GlobalLoadGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) { GlobalOp referrent = getGlobalOp(symbolTable); if (!referrent) return emitOpError() << "undefined global: " << getGlobal(); if (referrent.getType() != getResult().getType()) { return emitOpError() << "cannot load from global typed " << referrent.getType() << " as " << getResult().getType(); } return success(); } //===----------------------------------------------------------------------===// // GlobalStoreOp //===----------------------------------------------------------------------===// GlobalOp GlobalStoreOp::getGlobalOp(SymbolTableCollection &symbolTable) { for (auto *parent = getOperation()->getParentOp(); parent;) { if (auto nearest = symbolTable.lookupNearestSymbolFrom( parent, getGlobalAttr())) { return nearest; } parent = parent->getParentOp(); } return {}; } LogicalResult GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) { GlobalOp referrent = getGlobalOp(symbolTable); if (!referrent) return emitOpError() << "undefined global: " << getGlobal(); if (!referrent.getIsMutable()) { return emitOpError() << "cannot store to an immutable global " << getGlobal(); } if (referrent.getType() != getValue().getType()) { return emitOpError() << "cannot store to a global typed " << referrent.getType() << " from " << getValue().getType(); } return success(); } //===----------------------------------------------------------------------===// // GlobalStoreGraphOp //===----------------------------------------------------------------------===// GlobalOp GlobalStoreGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) { return symbolTable.lookupNearestSymbolFrom( getOperation()->getParentOp(), getGlobalAttr()); } LogicalResult GlobalStoreGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) { GlobalOp referrent = getGlobalOp(symbolTable); if (!referrent) return emitOpError() << "undefined global: " << getGlobal(); if (!referrent.getIsMutable()) { return emitOpError() << "cannot store to an immutable global " << getGlobal(); } if (referrent.getType() != getValue().getType()) { return emitOpError() << "cannot store to a global typed " << referrent.getType() << " from " << getValue().getType(); } return success(); } //===----------------------------------------------------------------------===// // SubgraphOp //===----------------------------------------------------------------------===// ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) { auto buildFuncType = [](Builder &builder, ArrayRef argTypes, ArrayRef results, function_interface_impl::VariadicFlag, std::string &) { return builder.getFunctionType(argTypes, results); }; return function_interface_impl::parseFunctionOp( parser, result, /*allowVariadic=*/false, getFunctionTypeAttrName(result.name), buildFuncType, getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void SubgraphOp::print(OpAsmPrinter &p) { function_interface_impl::printFunctionOp( p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName()); } //===----------------------------------------------------------------------===// // OutputOp //===----------------------------------------------------------------------===// LogicalResult OutputOp::verify() { auto function = cast((*this)->getParentOp()); // The operand number and types must match the function signature. const auto &results = function.getFunctionType().getResults(); if (getNumOperands() != results.size()) return emitOpError("has ") << getNumOperands() << " operands, but enclosing function (@" << function.getName() << ") outputs " << results.size(); for (unsigned i = 0, e = results.size(); i != e; ++i) if (getOperand(i).getType() != results[i]) return emitError() << "type of output operand " << i << " (" << getOperand(i).getType() << ") doesn't match function result type (" << results[i] << ")" << " in function @" << function.getName(); return success(); } //===----------------------------------------------------------------------===// // ReturnOp //===----------------------------------------------------------------------===// LogicalResult ReturnOp::verify() { auto function = cast((*this)->getParentOp()); // The operand number and types must match the function signature. const auto &results = function.getFunctionType().getResults(); if (getNumOperands() != results.size()) return emitOpError("has ") << getNumOperands() << " operands, but enclosing function (@" << function.getName() << ") returns " << results.size(); for (unsigned i = 0, e = results.size(); i != e; ++i) if (getOperand(i).getType() != results[i]) return emitError() << "type of return operand " << i << " (" << getOperand(i).getType() << ") doesn't match function result type (" << results[i] << ")" << " in function @" << function.getName(); return success(); }