105a3b4feSEugene Zhulenev //===- Async.cpp - MLIR Async Operations ----------------------------------===//
205a3b4feSEugene Zhulenev //
305a3b4feSEugene Zhulenev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
405a3b4feSEugene Zhulenev // See https://llvm.org/LICENSE.txt for license information.
505a3b4feSEugene Zhulenev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
605a3b4feSEugene Zhulenev //
705a3b4feSEugene Zhulenev //===----------------------------------------------------------------------===//
805a3b4feSEugene Zhulenev
905a3b4feSEugene Zhulenev #include "mlir/Dialect/Async/IR/Async.h"
1005a3b4feSEugene Zhulenev
1105a3b4feSEugene Zhulenev #include "mlir/IR/DialectImplementation.h"
124d67b278SJeff Niu #include "mlir/IR/IRMapping.h"
1334a35a8bSMartin Erhart #include "mlir/Interfaces/FunctionImplementation.h"
14145d2a50Syijiagu #include "llvm/ADT/MapVector.h"
1505a3b4feSEugene Zhulenev #include "llvm/ADT/TypeSwitch.h"
1605a3b4feSEugene Zhulenev
174e69a529SEugene Zhulenev using namespace mlir;
184e69a529SEugene Zhulenev using namespace mlir::async;
1905a3b4feSEugene Zhulenev
20485cc55eSStella Laurenzo #include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc"
21485cc55eSStella Laurenzo
229a5bc836Sbakhtiyar constexpr StringRef AsyncDialect::kAllowedToBlockAttrName;
239a5bc836Sbakhtiyar
initialize()2405a3b4feSEugene Zhulenev void AsyncDialect::initialize() {
2505a3b4feSEugene Zhulenev addOperations<
2605a3b4feSEugene Zhulenev #define GET_OP_LIST
2705a3b4feSEugene Zhulenev #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
2805a3b4feSEugene Zhulenev >();
292f7baffdSEugene Zhulenev addTypes<
302f7baffdSEugene Zhulenev #define GET_TYPEDEF_LIST
312f7baffdSEugene Zhulenev #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
322f7baffdSEugene Zhulenev >();
3305a3b4feSEugene Zhulenev }
3405a3b4feSEugene Zhulenev
//===----------------------------------------------------------------------===//
/// ExecuteOp
//===----------------------------------------------------------------------===//
38655af658SEugene Zhulenev
// Name of the attribute that records how many operands of `async.execute`
// are dependencies and how many are body operands. It is derived by the
// builder/parser and elided by the printer.
constexpr char kOperandSegmentSizesAttr[] = "operandSegmentSizes";
404e69a529SEugene Zhulenev
getEntrySuccessorOperands(RegionBranchPoint point)414dd744acSMarkus Böck OperandRange ExecuteOp::getEntrySuccessorOperands(RegionBranchPoint point) {
424dd744acSMarkus Böck assert(point == getBodyRegion() && "invalid region index");
43a5aa7836SRiver Riddle return getBodyOperands();
449775c0c9SVladislav Vinogradov }
459775c0c9SVladislav Vinogradov
areTypesCompatible(Type lhs,Type rhs)46e7c7b16aSMogball bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) {
47e7c7b16aSMogball const auto getValueOrTokenType = [](Type type) {
48c1fa60b4STres Popp if (auto value = llvm::dyn_cast<ValueType>(type))
49e7c7b16aSMogball return value.getValueType();
50e7c7b16aSMogball return type;
51e7c7b16aSMogball };
52e7c7b16aSMogball return getValueOrTokenType(lhs) == getValueOrTokenType(rhs);
53e7c7b16aSMogball }
54e7c7b16aSMogball
getSuccessorRegions(RegionBranchPoint point,SmallVectorImpl<RegionSuccessor> & regions)554dd744acSMarkus Böck void ExecuteOp::getSuccessorRegions(RegionBranchPoint point,
56bb0d5f76SEugene Zhulenev SmallVectorImpl<RegionSuccessor> ®ions) {
57bb0d5f76SEugene Zhulenev // The `body` region branch back to the parent operation.
584dd744acSMarkus Böck if (point == getBodyRegion()) {
59a5aa7836SRiver Riddle regions.push_back(RegionSuccessor(getBodyResults()));
60bb0d5f76SEugene Zhulenev return;
61bb0d5f76SEugene Zhulenev }
62bb0d5f76SEugene Zhulenev
63bb0d5f76SEugene Zhulenev // Otherwise the successor is the body region.
64986b5c56SRiver Riddle regions.push_back(
65a5aa7836SRiver Riddle RegionSuccessor(&getBodyRegion(), getBodyRegion().getArguments()));
66bb0d5f76SEugene Zhulenev }
67bb0d5f76SEugene Zhulenev
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,ValueRange dependencies,ValueRange operands,BodyBuilderFn bodyBuilder)68c30ab6c2SEugene Zhulenev void ExecuteOp::build(OpBuilder &builder, OperationState &result,
69c30ab6c2SEugene Zhulenev TypeRange resultTypes, ValueRange dependencies,
70c30ab6c2SEugene Zhulenev ValueRange operands, BodyBuilderFn bodyBuilder) {
71*91d5653eSMatthias Springer OpBuilder::InsertionGuard guard(builder);
72c30ab6c2SEugene Zhulenev result.addOperands(dependencies);
73c30ab6c2SEugene Zhulenev result.addOperands(operands);
74c30ab6c2SEugene Zhulenev
75363b6559SMehdi Amini // Add derived `operandSegmentSizes` attribute based on parsed operands.
76c30ab6c2SEugene Zhulenev int32_t numDependencies = dependencies.size();
77c30ab6c2SEugene Zhulenev int32_t numOperands = operands.size();
7858a47508SJeff Niu auto operandSegmentSizes =
7958a47508SJeff Niu builder.getDenseI32ArrayAttr({numDependencies, numOperands});
80c30ab6c2SEugene Zhulenev result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
81c30ab6c2SEugene Zhulenev
82c30ab6c2SEugene Zhulenev // First result is always a token, and then `resultTypes` wrapped into
83c30ab6c2SEugene Zhulenev // `async.value`.
84c30ab6c2SEugene Zhulenev result.addTypes({TokenType::get(result.getContext())});
85c30ab6c2SEugene Zhulenev for (Type type : resultTypes)
86c30ab6c2SEugene Zhulenev result.addTypes(ValueType::get(type));
87c30ab6c2SEugene Zhulenev
88c30ab6c2SEugene Zhulenev // Add a body region with block arguments as unwrapped async value operands.
89c30ab6c2SEugene Zhulenev Region *bodyRegion = result.addRegion();
90*91d5653eSMatthias Springer Block *bodyBlock = builder.createBlock(bodyRegion);
91c30ab6c2SEugene Zhulenev for (Value operand : operands) {
92c1fa60b4STres Popp auto valueType = llvm::dyn_cast<ValueType>(operand.getType());
93*91d5653eSMatthias Springer bodyBlock->addArgument(valueType ? valueType.getValueType()
94e084679fSRiver Riddle : operand.getType(),
95e084679fSRiver Riddle operand.getLoc());
96c30ab6c2SEugene Zhulenev }
97c30ab6c2SEugene Zhulenev
98c30ab6c2SEugene Zhulenev // Create the default terminator if the builder is not provided and if the
99c30ab6c2SEugene Zhulenev // expected result is empty. Otherwise, leave this to the caller
100c30ab6c2SEugene Zhulenev // because we don't know which values to return from the execute op.
101c30ab6c2SEugene Zhulenev if (resultTypes.empty() && !bodyBuilder) {
102c30ab6c2SEugene Zhulenev builder.create<async::YieldOp>(result.location, ValueRange());
103c30ab6c2SEugene Zhulenev } else if (bodyBuilder) {
104*91d5653eSMatthias Springer bodyBuilder(builder, result.location, bodyBlock->getArguments());
105c30ab6c2SEugene Zhulenev }
106c30ab6c2SEugene Zhulenev }
107c30ab6c2SEugene Zhulenev
print(OpAsmPrinter & p)1082418cd92SRiver Riddle void ExecuteOp::print(OpAsmPrinter &p) {
1094e69a529SEugene Zhulenev // [%tokens,...]
110a5aa7836SRiver Riddle if (!getDependencies().empty())
111a5aa7836SRiver Riddle p << " [" << getDependencies() << "]";
1124e69a529SEugene Zhulenev
1134e69a529SEugene Zhulenev // (%value as %unwrapped: !async.value<!arg.type>, ...)
114a5aa7836SRiver Riddle if (!getBodyOperands().empty()) {
1154e69a529SEugene Zhulenev p << " (";
116a5aa7836SRiver Riddle Block *entry = getBodyRegion().empty() ? nullptr : &getBodyRegion().front();
117a5aa7836SRiver Riddle llvm::interleaveComma(
118a5aa7836SRiver Riddle getBodyOperands(), p, [&, n = 0](Value operand) mutable {
119674dd9d0SChristian Sigg Value argument = entry ? entry->getArgument(n++) : Value();
120674dd9d0SChristian Sigg p << operand << " as " << argument << ": " << operand.getType();
121655af658SEugene Zhulenev });
1224e69a529SEugene Zhulenev p << ")";
1234e69a529SEugene Zhulenev }
1244e69a529SEugene Zhulenev
1254e69a529SEugene Zhulenev // -> (!async.value<!return.type>, ...)
1262418cd92SRiver Riddle p.printOptionalArrowTypeList(llvm::drop_begin(getResultTypes()));
1272418cd92SRiver Riddle p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
1288c074cb0SChristian Sigg {kOperandSegmentSizesAttr});
1295c36ee8dSMogball p << ' ';
130a5aa7836SRiver Riddle p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false);
131655af658SEugene Zhulenev }
132655af658SEugene Zhulenev
parse(OpAsmParser & parser,OperationState & result)1332418cd92SRiver Riddle ParseResult ExecuteOp::parse(OpAsmParser &parser, OperationState &result) {
134655af658SEugene Zhulenev MLIRContext *ctx = result.getContext();
135655af658SEugene Zhulenev
1364e69a529SEugene Zhulenev // Sizes of parsed variadic operands, will be updated below after parsing.
1374e69a529SEugene Zhulenev int32_t numDependencies = 0;
1384e69a529SEugene Zhulenev
1394e69a529SEugene Zhulenev auto tokenTy = TokenType::get(ctx);
1404e69a529SEugene Zhulenev
1414e69a529SEugene Zhulenev // Parse dependency tokens.
1424e69a529SEugene Zhulenev if (succeeded(parser.parseOptionalLSquare())) {
143e13d23bcSMarkus Böck SmallVector<OpAsmParser::UnresolvedOperand, 4> tokenArgs;
1444e69a529SEugene Zhulenev if (parser.parseOperandList(tokenArgs) ||
1454e69a529SEugene Zhulenev parser.resolveOperands(tokenArgs, tokenTy, result.operands) ||
1464e69a529SEugene Zhulenev parser.parseRSquare())
147655af658SEugene Zhulenev return failure();
148655af658SEugene Zhulenev
1494e69a529SEugene Zhulenev numDependencies = tokenArgs.size();
1504e69a529SEugene Zhulenev }
1514e69a529SEugene Zhulenev
1524e69a529SEugene Zhulenev // Parse async value operands (%value as %unwrapped : !async.value<!type>).
153e13d23bcSMarkus Böck SmallVector<OpAsmParser::UnresolvedOperand, 4> valueArgs;
154d85eb4e2SChris Lattner SmallVector<OpAsmParser::Argument, 4> unwrappedArgs;
1554e69a529SEugene Zhulenev SmallVector<Type, 4> valueTypes;
1564e69a529SEugene Zhulenev
1574e69a529SEugene Zhulenev // Parse a single instance of `%value as %unwrapped : !async.value<!type>`.
1584e69a529SEugene Zhulenev auto parseAsyncValueArg = [&]() -> ParseResult {
1594e69a529SEugene Zhulenev if (parser.parseOperand(valueArgs.emplace_back()) ||
1604e69a529SEugene Zhulenev parser.parseKeyword("as") ||
161d85eb4e2SChris Lattner parser.parseArgument(unwrappedArgs.emplace_back()) ||
1624e69a529SEugene Zhulenev parser.parseColonType(valueTypes.emplace_back()))
1634e69a529SEugene Zhulenev return failure();
1644e69a529SEugene Zhulenev
165c1fa60b4STres Popp auto valueTy = llvm::dyn_cast<ValueType>(valueTypes.back());
166d85eb4e2SChris Lattner unwrappedArgs.back().type = valueTy ? valueTy.getValueType() : Type();
1674e69a529SEugene Zhulenev return success();
1684e69a529SEugene Zhulenev };
1694e69a529SEugene Zhulenev
17058abc8c3SChris Lattner auto argsLoc = parser.getCurrentLocation();
17158abc8c3SChris Lattner if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::OptionalParen,
17258abc8c3SChris Lattner parseAsyncValueArg) ||
17358abc8c3SChris Lattner parser.resolveOperands(valueArgs, valueTypes, argsLoc, result.operands))
1744e69a529SEugene Zhulenev return failure();
1754e69a529SEugene Zhulenev
17658abc8c3SChris Lattner int32_t numOperands = valueArgs.size();
1774e69a529SEugene Zhulenev
178363b6559SMehdi Amini // Add derived `operandSegmentSizes` attribute based on parsed operands.
17958a47508SJeff Niu auto operandSegmentSizes =
18058a47508SJeff Niu parser.getBuilder().getDenseI32ArrayAttr({numDependencies, numOperands});
1814e69a529SEugene Zhulenev result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
1824e69a529SEugene Zhulenev
1834e69a529SEugene Zhulenev // Parse the types of results returned from the async execute op.
1844e69a529SEugene Zhulenev SmallVector<Type, 4> resultTypes;
185655af658SEugene Zhulenev NamedAttrList attrs;
1861d7b5cd5SChris Lattner if (parser.parseOptionalArrowTypeList(resultTypes) ||
1871d7b5cd5SChris Lattner // Async execute first result is always a completion token.
1881d7b5cd5SChris Lattner parser.addTypeToList(tokenTy, result.types) ||
1891d7b5cd5SChris Lattner parser.addTypesToList(resultTypes, result.types) ||
1901d7b5cd5SChris Lattner // Parse operation attributes.
1911d7b5cd5SChris Lattner parser.parseOptionalAttrDictWithKeyword(attrs))
192655af658SEugene Zhulenev return failure();
1931d7b5cd5SChris Lattner
194655af658SEugene Zhulenev result.addAttributes(attrs);
195655af658SEugene Zhulenev
1964e69a529SEugene Zhulenev // Parse asynchronous region.
1974e69a529SEugene Zhulenev Region *body = result.addRegion();
198d85eb4e2SChris Lattner return parser.parseRegion(*body, /*arguments=*/unwrappedArgs);
199655af658SEugene Zhulenev }
200655af658SEugene Zhulenev
verifyRegions()201ed645f63SChia-hung Duan LogicalResult ExecuteOp::verifyRegions() {
2024e69a529SEugene Zhulenev // Unwrap async.execute value operands types.
203a5aa7836SRiver Riddle auto unwrappedTypes = llvm::map_range(getBodyOperands(), [](Value operand) {
204c1fa60b4STres Popp return llvm::cast<ValueType>(operand.getType()).getValueType();
2054e69a529SEugene Zhulenev });
2064e69a529SEugene Zhulenev
2074e69a529SEugene Zhulenev // Verify that unwrapped argument types matches the body region arguments.
208a5aa7836SRiver Riddle if (getBodyRegion().getArgumentTypes() != unwrappedTypes)
2091be88f5aSRiver Riddle return emitOpError("async body region argument types do not match the "
2104e69a529SEugene Zhulenev "execute operation arguments types");
2114e69a529SEugene Zhulenev
2124e69a529SEugene Zhulenev return success();
2134e69a529SEugene Zhulenev }
214655af658SEugene Zhulenev
//===----------------------------------------------------------------------===//
/// CreateGroupOp
//===----------------------------------------------------------------------===//
218a8f819c6SEugene Zhulenev
canonicalize(CreateGroupOp op,PatternRewriter & rewriter)219a8f819c6SEugene Zhulenev LogicalResult CreateGroupOp::canonicalize(CreateGroupOp op,
220a8f819c6SEugene Zhulenev PatternRewriter &rewriter) {
221a8f819c6SEugene Zhulenev // Find all `await_all` users of the group.
222a8f819c6SEugene Zhulenev llvm::SmallVector<AwaitAllOp> awaitAllUsers;
223a8f819c6SEugene Zhulenev
224a8f819c6SEugene Zhulenev auto isAwaitAll = [&](Operation *op) -> bool {
225a8f819c6SEugene Zhulenev if (AwaitAllOp awaitAll = dyn_cast<AwaitAllOp>(op)) {
226a8f819c6SEugene Zhulenev awaitAllUsers.push_back(awaitAll);
227a8f819c6SEugene Zhulenev return true;
228a8f819c6SEugene Zhulenev }
229a8f819c6SEugene Zhulenev return false;
230a8f819c6SEugene Zhulenev };
231a8f819c6SEugene Zhulenev
232a8f819c6SEugene Zhulenev // Check if all users of the group are `await_all` operations.
233a8f819c6SEugene Zhulenev if (!llvm::all_of(op->getUsers(), isAwaitAll))
234a8f819c6SEugene Zhulenev return failure();
235a8f819c6SEugene Zhulenev
236a8f819c6SEugene Zhulenev // If group is only awaited without adding anything to it, we can safely erase
237a8f819c6SEugene Zhulenev // the create operation and all users.
238a8f819c6SEugene Zhulenev for (AwaitAllOp awaitAll : awaitAllUsers)
239a8f819c6SEugene Zhulenev rewriter.eraseOp(awaitAll);
240a8f819c6SEugene Zhulenev rewriter.eraseOp(op);
241a8f819c6SEugene Zhulenev
242a8f819c6SEugene Zhulenev return success();
243a8f819c6SEugene Zhulenev }
244a8f819c6SEugene Zhulenev
//===----------------------------------------------------------------------===//
/// AwaitOp
//===----------------------------------------------------------------------===//
24861dce0f3SEugene Zhulenev
build(OpBuilder & builder,OperationState & result,Value operand,ArrayRef<NamedAttribute> attrs)24961dce0f3SEugene Zhulenev void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand,
25061dce0f3SEugene Zhulenev ArrayRef<NamedAttribute> attrs) {
25161dce0f3SEugene Zhulenev result.addOperands({operand});
25261dce0f3SEugene Zhulenev result.attributes.append(attrs.begin(), attrs.end());
25361dce0f3SEugene Zhulenev
25461dce0f3SEugene Zhulenev // Add unwrapped async.value type to the returned values types.
255c1fa60b4STres Popp if (auto valueType = llvm::dyn_cast<ValueType>(operand.getType()))
25661dce0f3SEugene Zhulenev result.addTypes(valueType.getValueType());
25761dce0f3SEugene Zhulenev }
25861dce0f3SEugene Zhulenev
parseAwaitResultType(OpAsmParser & parser,Type & operandType,Type & resultType)25961dce0f3SEugene Zhulenev static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType,
26061dce0f3SEugene Zhulenev Type &resultType) {
26161dce0f3SEugene Zhulenev if (parser.parseType(operandType))
26261dce0f3SEugene Zhulenev return failure();
26361dce0f3SEugene Zhulenev
26461dce0f3SEugene Zhulenev // Add unwrapped async.value type to the returned values types.
265c1fa60b4STres Popp if (auto valueType = llvm::dyn_cast<ValueType>(operandType))
26661dce0f3SEugene Zhulenev resultType = valueType.getValueType();
26761dce0f3SEugene Zhulenev
26861dce0f3SEugene Zhulenev return success();
26961dce0f3SEugene Zhulenev }
27061dce0f3SEugene Zhulenev
printAwaitResultType(OpAsmPrinter & p,Operation * op,Type operandType,Type resultType)271035e12e6SJohn Demme static void printAwaitResultType(OpAsmPrinter &p, Operation *op,
272035e12e6SJohn Demme Type operandType, Type resultType) {
27361dce0f3SEugene Zhulenev p << operandType;
27461dce0f3SEugene Zhulenev }
27561dce0f3SEugene Zhulenev
verify()2761be88f5aSRiver Riddle LogicalResult AwaitOp::verify() {
277a5aa7836SRiver Riddle Type argType = getOperand().getType();
27861dce0f3SEugene Zhulenev
27961dce0f3SEugene Zhulenev // Awaiting on a token does not have any results.
280c1fa60b4STres Popp if (llvm::isa<TokenType>(argType) && !getResultTypes().empty())
2811be88f5aSRiver Riddle return emitOpError("awaiting on a token must have empty result");
28261dce0f3SEugene Zhulenev
28361dce0f3SEugene Zhulenev // Awaiting on a value unwraps the async value type.
284c1fa60b4STres Popp if (auto value = llvm::dyn_cast<ValueType>(argType)) {
2851be88f5aSRiver Riddle if (*getResultType() != value.getValueType())
2861be88f5aSRiver Riddle return emitOpError() << "result type " << *getResultType()
2871be88f5aSRiver Riddle << " does not match async value type "
2881be88f5aSRiver Riddle << value.getValueType();
28961dce0f3SEugene Zhulenev }
29061dce0f3SEugene Zhulenev
29161dce0f3SEugene Zhulenev return success();
29261dce0f3SEugene Zhulenev }
29361dce0f3SEugene Zhulenev
//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//
297145d2a50Syijiagu
build(OpBuilder & builder,OperationState & state,StringRef name,FunctionType type,ArrayRef<NamedAttribute> attrs,ArrayRef<DictionaryAttr> argAttrs)298145d2a50Syijiagu void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
299145d2a50Syijiagu FunctionType type, ArrayRef<NamedAttribute> attrs,
300145d2a50Syijiagu ArrayRef<DictionaryAttr> argAttrs) {
301145d2a50Syijiagu state.addAttribute(SymbolTable::getSymbolAttrName(),
302145d2a50Syijiagu builder.getStringAttr(name));
30353406427SJeff Niu state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
304145d2a50Syijiagu
305145d2a50Syijiagu state.attributes.append(attrs.begin(), attrs.end());
306145d2a50Syijiagu state.addRegion();
307145d2a50Syijiagu
308145d2a50Syijiagu if (argAttrs.empty())
309145d2a50Syijiagu return;
310145d2a50Syijiagu assert(type.getNumInputs() == argAttrs.size());
31153406427SJeff Niu function_interface_impl::addArgAndResultAttrs(
31253406427SJeff Niu builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
31353406427SJeff Niu getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
314145d2a50Syijiagu }
315145d2a50Syijiagu
parse(OpAsmParser & parser,OperationState & result)316145d2a50Syijiagu ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
317145d2a50Syijiagu auto buildFuncType =
318145d2a50Syijiagu [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
319145d2a50Syijiagu function_interface_impl::VariadicFlag,
320145d2a50Syijiagu std::string &) { return builder.getFunctionType(argTypes, results); };
321145d2a50Syijiagu
322145d2a50Syijiagu return function_interface_impl::parseFunctionOp(
32353406427SJeff Niu parser, result, /*allowVariadic=*/false,
32453406427SJeff Niu getFunctionTypeAttrName(result.name), buildFuncType,
32553406427SJeff Niu getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
326145d2a50Syijiagu }
327145d2a50Syijiagu
print(OpAsmPrinter & p)328145d2a50Syijiagu void FuncOp::print(OpAsmPrinter &p) {
32953406427SJeff Niu function_interface_impl::printFunctionOp(
33053406427SJeff Niu p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
33153406427SJeff Niu getArgAttrsAttrName(), getResAttrsAttrName());
332145d2a50Syijiagu }
333145d2a50Syijiagu
334145d2a50Syijiagu /// Check that the result type of async.func is not void and must be
335145d2a50Syijiagu /// some async token or async values.
verify()336145d2a50Syijiagu LogicalResult FuncOp::verify() {
337145d2a50Syijiagu auto resultTypes = getResultTypes();
338145d2a50Syijiagu if (resultTypes.empty())
339145d2a50Syijiagu return emitOpError()
340145d2a50Syijiagu << "result is expected to be at least of size 1, but got "
341145d2a50Syijiagu << resultTypes.size();
342145d2a50Syijiagu
343145d2a50Syijiagu for (unsigned i = 0, e = resultTypes.size(); i != e; ++i) {
344145d2a50Syijiagu auto type = resultTypes[i];
345c1fa60b4STres Popp if (!llvm::isa<TokenType>(type) && !llvm::isa<ValueType>(type))
346145d2a50Syijiagu return emitOpError() << "result type must be async value type or async "
347145d2a50Syijiagu "token type, but got "
348145d2a50Syijiagu << type;
349145d2a50Syijiagu // We only allow AsyncToken appear as the first return value
350c1fa60b4STres Popp if (llvm::isa<TokenType>(type) && i != 0) {
351145d2a50Syijiagu return emitOpError()
352145d2a50Syijiagu << " results' (optional) async token type is expected "
353145d2a50Syijiagu "to appear as the 1st return value, but got "
354145d2a50Syijiagu << i + 1;
355145d2a50Syijiagu }
356145d2a50Syijiagu }
357145d2a50Syijiagu
358145d2a50Syijiagu return success();
359145d2a50Syijiagu }
360145d2a50Syijiagu
//===----------------------------------------------------------------------===//
/// CallOp
//===----------------------------------------------------------------------===//
364145d2a50Syijiagu
verifySymbolUses(SymbolTableCollection & symbolTable)365145d2a50Syijiagu LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
366145d2a50Syijiagu // Check that the callee attribute was specified.
367145d2a50Syijiagu auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
368145d2a50Syijiagu if (!fnAttr)
369145d2a50Syijiagu return emitOpError("requires a 'callee' symbol reference attribute");
370145d2a50Syijiagu FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
371145d2a50Syijiagu if (!fn)
372145d2a50Syijiagu return emitOpError() << "'" << fnAttr.getValue()
373145d2a50Syijiagu << "' does not reference a valid async function";
374145d2a50Syijiagu
375145d2a50Syijiagu // Verify that the operand and result types match the callee.
376145d2a50Syijiagu auto fnType = fn.getFunctionType();
377145d2a50Syijiagu if (fnType.getNumInputs() != getNumOperands())
378145d2a50Syijiagu return emitOpError("incorrect number of operands for callee");
379145d2a50Syijiagu
380145d2a50Syijiagu for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
381145d2a50Syijiagu if (getOperand(i).getType() != fnType.getInput(i))
382145d2a50Syijiagu return emitOpError("operand type mismatch: expected operand type ")
383145d2a50Syijiagu << fnType.getInput(i) << ", but provided "
384145d2a50Syijiagu << getOperand(i).getType() << " for operand number " << i;
385145d2a50Syijiagu
386145d2a50Syijiagu if (fnType.getNumResults() != getNumResults())
387145d2a50Syijiagu return emitOpError("incorrect number of results for callee");
388145d2a50Syijiagu
389145d2a50Syijiagu for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
390145d2a50Syijiagu if (getResult(i).getType() != fnType.getResult(i)) {
391145d2a50Syijiagu auto diag = emitOpError("result type mismatch at index ") << i;
392145d2a50Syijiagu diag.attachNote() << " op result types: " << getResultTypes();
393145d2a50Syijiagu diag.attachNote() << "function result types: " << fnType.getResults();
394145d2a50Syijiagu return diag;
395145d2a50Syijiagu }
396145d2a50Syijiagu
397145d2a50Syijiagu return success();
398145d2a50Syijiagu }
399145d2a50Syijiagu
getCalleeType()400145d2a50Syijiagu FunctionType CallOp::getCalleeType() {
401145d2a50Syijiagu return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
402145d2a50Syijiagu }
403145d2a50Syijiagu
//===----------------------------------------------------------------------===//
/// ReturnOp
//===----------------------------------------------------------------------===//
407145d2a50Syijiagu
verify()408145d2a50Syijiagu LogicalResult ReturnOp::verify() {
409145d2a50Syijiagu auto funcOp = (*this)->getParentOfType<FuncOp>();
410145d2a50Syijiagu ArrayRef<Type> resultTypes = funcOp.isStateful()
411145d2a50Syijiagu ? funcOp.getResultTypes().drop_front()
412145d2a50Syijiagu : funcOp.getResultTypes();
413145d2a50Syijiagu // Get the underlying value types from async types returned from the
414145d2a50Syijiagu // parent `async.func` operation.
415145d2a50Syijiagu auto types = llvm::map_range(resultTypes, [](const Type &result) {
416c1fa60b4STres Popp return llvm::cast<ValueType>(result).getValueType();
417145d2a50Syijiagu });
418145d2a50Syijiagu
419145d2a50Syijiagu if (getOperandTypes() != types)
420145d2a50Syijiagu return emitOpError("operand types do not match the types returned from "
421145d2a50Syijiagu "the parent FuncOp");
422145d2a50Syijiagu
423145d2a50Syijiagu return success();
424145d2a50Syijiagu }
425145d2a50Syijiagu
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
4292f7baffdSEugene Zhulenev
43005a3b4feSEugene Zhulenev #define GET_OP_CLASSES
43105a3b4feSEugene Zhulenev #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
4322f7baffdSEugene Zhulenev
//===----------------------------------------------------------------------===//
// TableGen'd type method definitions
//===----------------------------------------------------------------------===//
4362f7baffdSEugene Zhulenev
4372f7baffdSEugene Zhulenev #define GET_TYPEDEF_CLASSES
4382f7baffdSEugene Zhulenev #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
4392f7baffdSEugene Zhulenev
print(AsmPrinter & printer) const440f97e72aaSMehdi Amini void ValueType::print(AsmPrinter &printer) const {
4412f7baffdSEugene Zhulenev printer << "<";
4422f7baffdSEugene Zhulenev printer.printType(getValueType());
4432f7baffdSEugene Zhulenev printer << '>';
4442f7baffdSEugene Zhulenev }
4452f7baffdSEugene Zhulenev
parse(mlir::AsmParser & parser)446f97e72aaSMehdi Amini Type ValueType::parse(mlir::AsmParser &parser) {
4472f7baffdSEugene Zhulenev Type ty;
4482f7baffdSEugene Zhulenev if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
4492f7baffdSEugene Zhulenev parser.emitError(parser.getNameLoc(), "failed to parse async value type");
4502f7baffdSEugene Zhulenev return Type();
4512f7baffdSEugene Zhulenev }
4522f7baffdSEugene Zhulenev return ValueType::get(ty);
4532f7baffdSEugene Zhulenev }
454