xref: /llvm-project/mlir/lib/Dialect/Async/IR/Async.cpp (revision 91d5653e3ae9742f7fb847f809b534ee128501b0)
105a3b4feSEugene Zhulenev //===- Async.cpp - MLIR Async Operations ----------------------------------===//
205a3b4feSEugene Zhulenev //
305a3b4feSEugene Zhulenev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
405a3b4feSEugene Zhulenev // See https://llvm.org/LICENSE.txt for license information.
505a3b4feSEugene Zhulenev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
605a3b4feSEugene Zhulenev //
705a3b4feSEugene Zhulenev //===----------------------------------------------------------------------===//
805a3b4feSEugene Zhulenev 
905a3b4feSEugene Zhulenev #include "mlir/Dialect/Async/IR/Async.h"
1005a3b4feSEugene Zhulenev 
1105a3b4feSEugene Zhulenev #include "mlir/IR/DialectImplementation.h"
124d67b278SJeff Niu #include "mlir/IR/IRMapping.h"
1334a35a8bSMartin Erhart #include "mlir/Interfaces/FunctionImplementation.h"
14145d2a50Syijiagu #include "llvm/ADT/MapVector.h"
1505a3b4feSEugene Zhulenev #include "llvm/ADT/TypeSwitch.h"
1605a3b4feSEugene Zhulenev 
174e69a529SEugene Zhulenev using namespace mlir;
184e69a529SEugene Zhulenev using namespace mlir::async;
1905a3b4feSEugene Zhulenev 
20485cc55eSStella Laurenzo #include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc"
21485cc55eSStella Laurenzo 
// Out-of-line definition of the static constexpr member declared (with its
// initializer) inside `AsyncDialect`. It is needed when the member is ODR-used
// under pre-C++17 rules, and is redundant but harmless with C++17 inline
// variables.
constexpr StringRef AsyncDialect::kAllowedToBlockAttrName;
239a5bc836Sbakhtiyar 
/// Registers every Async dialect operation and type with the dialect.
///
/// Both lists are TableGen-generated and spliced in via the `.cpp.inc`
/// includes. The `GET_OP_LIST` / `GET_TYPEDEF_LIST` macros select the
/// comma-separated class lists from those files.
void AsyncDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
      >();
}
3405a3b4feSEugene Zhulenev 
35655af658SEugene Zhulenev //===----------------------------------------------------------------------===//
36655af658SEugene Zhulenev /// ExecuteOp
37655af658SEugene Zhulenev //===----------------------------------------------------------------------===//
38655af658SEugene Zhulenev 
// Name of the attribute recording how `ExecuteOp`'s flat operand list splits
// into its two variadic groups: {#dependency tokens, #body operands}.
// `ExecuteOp::build` and `ExecuteOp::parse` populate it, and
// `ExecuteOp::print` elides it from the printed attribute dictionary.
constexpr char kOperandSegmentSizesAttr[] = "operandSegmentSizes";
404e69a529SEugene Zhulenev 
getEntrySuccessorOperands(RegionBranchPoint point)414dd744acSMarkus Böck OperandRange ExecuteOp::getEntrySuccessorOperands(RegionBranchPoint point) {
424dd744acSMarkus Böck   assert(point == getBodyRegion() && "invalid region index");
43a5aa7836SRiver Riddle   return getBodyOperands();
449775c0c9SVladislav Vinogradov }
459775c0c9SVladislav Vinogradov 
areTypesCompatible(Type lhs,Type rhs)46e7c7b16aSMogball bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) {
47e7c7b16aSMogball   const auto getValueOrTokenType = [](Type type) {
48c1fa60b4STres Popp     if (auto value = llvm::dyn_cast<ValueType>(type))
49e7c7b16aSMogball       return value.getValueType();
50e7c7b16aSMogball     return type;
51e7c7b16aSMogball   };
52e7c7b16aSMogball   return getValueOrTokenType(lhs) == getValueOrTokenType(rhs);
53e7c7b16aSMogball }
54e7c7b16aSMogball 
getSuccessorRegions(RegionBranchPoint point,SmallVectorImpl<RegionSuccessor> & regions)554dd744acSMarkus Böck void ExecuteOp::getSuccessorRegions(RegionBranchPoint point,
56bb0d5f76SEugene Zhulenev                                     SmallVectorImpl<RegionSuccessor> &regions) {
57bb0d5f76SEugene Zhulenev   // The `body` region branch back to the parent operation.
584dd744acSMarkus Böck   if (point == getBodyRegion()) {
59a5aa7836SRiver Riddle     regions.push_back(RegionSuccessor(getBodyResults()));
60bb0d5f76SEugene Zhulenev     return;
61bb0d5f76SEugene Zhulenev   }
62bb0d5f76SEugene Zhulenev 
63bb0d5f76SEugene Zhulenev   // Otherwise the successor is the body region.
64986b5c56SRiver Riddle   regions.push_back(
65a5aa7836SRiver Riddle       RegionSuccessor(&getBodyRegion(), getBodyRegion().getArguments()));
66bb0d5f76SEugene Zhulenev }
67bb0d5f76SEugene Zhulenev 
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,ValueRange dependencies,ValueRange operands,BodyBuilderFn bodyBuilder)68c30ab6c2SEugene Zhulenev void ExecuteOp::build(OpBuilder &builder, OperationState &result,
69c30ab6c2SEugene Zhulenev                       TypeRange resultTypes, ValueRange dependencies,
70c30ab6c2SEugene Zhulenev                       ValueRange operands, BodyBuilderFn bodyBuilder) {
71*91d5653eSMatthias Springer   OpBuilder::InsertionGuard guard(builder);
72c30ab6c2SEugene Zhulenev   result.addOperands(dependencies);
73c30ab6c2SEugene Zhulenev   result.addOperands(operands);
74c30ab6c2SEugene Zhulenev 
75363b6559SMehdi Amini   // Add derived `operandSegmentSizes` attribute based on parsed operands.
76c30ab6c2SEugene Zhulenev   int32_t numDependencies = dependencies.size();
77c30ab6c2SEugene Zhulenev   int32_t numOperands = operands.size();
7858a47508SJeff Niu   auto operandSegmentSizes =
7958a47508SJeff Niu       builder.getDenseI32ArrayAttr({numDependencies, numOperands});
80c30ab6c2SEugene Zhulenev   result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
81c30ab6c2SEugene Zhulenev 
82c30ab6c2SEugene Zhulenev   // First result is always a token, and then `resultTypes` wrapped into
83c30ab6c2SEugene Zhulenev   // `async.value`.
84c30ab6c2SEugene Zhulenev   result.addTypes({TokenType::get(result.getContext())});
85c30ab6c2SEugene Zhulenev   for (Type type : resultTypes)
86c30ab6c2SEugene Zhulenev     result.addTypes(ValueType::get(type));
87c30ab6c2SEugene Zhulenev 
88c30ab6c2SEugene Zhulenev   // Add a body region with block arguments as unwrapped async value operands.
89c30ab6c2SEugene Zhulenev   Region *bodyRegion = result.addRegion();
90*91d5653eSMatthias Springer   Block *bodyBlock = builder.createBlock(bodyRegion);
91c30ab6c2SEugene Zhulenev   for (Value operand : operands) {
92c1fa60b4STres Popp     auto valueType = llvm::dyn_cast<ValueType>(operand.getType());
93*91d5653eSMatthias Springer     bodyBlock->addArgument(valueType ? valueType.getValueType()
94e084679fSRiver Riddle                                      : operand.getType(),
95e084679fSRiver Riddle                            operand.getLoc());
96c30ab6c2SEugene Zhulenev   }
97c30ab6c2SEugene Zhulenev 
98c30ab6c2SEugene Zhulenev   // Create the default terminator if the builder is not provided and if the
99c30ab6c2SEugene Zhulenev   // expected result is empty. Otherwise, leave this to the caller
100c30ab6c2SEugene Zhulenev   // because we don't know which values to return from the execute op.
101c30ab6c2SEugene Zhulenev   if (resultTypes.empty() && !bodyBuilder) {
102c30ab6c2SEugene Zhulenev     builder.create<async::YieldOp>(result.location, ValueRange());
103c30ab6c2SEugene Zhulenev   } else if (bodyBuilder) {
104*91d5653eSMatthias Springer     bodyBuilder(builder, result.location, bodyBlock->getArguments());
105c30ab6c2SEugene Zhulenev   }
106c30ab6c2SEugene Zhulenev }
107c30ab6c2SEugene Zhulenev 
print(OpAsmPrinter & p)1082418cd92SRiver Riddle void ExecuteOp::print(OpAsmPrinter &p) {
1094e69a529SEugene Zhulenev   // [%tokens,...]
110a5aa7836SRiver Riddle   if (!getDependencies().empty())
111a5aa7836SRiver Riddle     p << " [" << getDependencies() << "]";
1124e69a529SEugene Zhulenev 
1134e69a529SEugene Zhulenev   // (%value as %unwrapped: !async.value<!arg.type>, ...)
114a5aa7836SRiver Riddle   if (!getBodyOperands().empty()) {
1154e69a529SEugene Zhulenev     p << " (";
116a5aa7836SRiver Riddle     Block *entry = getBodyRegion().empty() ? nullptr : &getBodyRegion().front();
117a5aa7836SRiver Riddle     llvm::interleaveComma(
118a5aa7836SRiver Riddle         getBodyOperands(), p, [&, n = 0](Value operand) mutable {
119674dd9d0SChristian Sigg           Value argument = entry ? entry->getArgument(n++) : Value();
120674dd9d0SChristian Sigg           p << operand << " as " << argument << ": " << operand.getType();
121655af658SEugene Zhulenev         });
1224e69a529SEugene Zhulenev     p << ")";
1234e69a529SEugene Zhulenev   }
1244e69a529SEugene Zhulenev 
1254e69a529SEugene Zhulenev   // -> (!async.value<!return.type>, ...)
1262418cd92SRiver Riddle   p.printOptionalArrowTypeList(llvm::drop_begin(getResultTypes()));
1272418cd92SRiver Riddle   p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
1288c074cb0SChristian Sigg                                      {kOperandSegmentSizesAttr});
1295c36ee8dSMogball   p << ' ';
130a5aa7836SRiver Riddle   p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false);
131655af658SEugene Zhulenev }
132655af658SEugene Zhulenev 
parse(OpAsmParser & parser,OperationState & result)1332418cd92SRiver Riddle ParseResult ExecuteOp::parse(OpAsmParser &parser, OperationState &result) {
134655af658SEugene Zhulenev   MLIRContext *ctx = result.getContext();
135655af658SEugene Zhulenev 
1364e69a529SEugene Zhulenev   // Sizes of parsed variadic operands, will be updated below after parsing.
1374e69a529SEugene Zhulenev   int32_t numDependencies = 0;
1384e69a529SEugene Zhulenev 
1394e69a529SEugene Zhulenev   auto tokenTy = TokenType::get(ctx);
1404e69a529SEugene Zhulenev 
1414e69a529SEugene Zhulenev   // Parse dependency tokens.
1424e69a529SEugene Zhulenev   if (succeeded(parser.parseOptionalLSquare())) {
143e13d23bcSMarkus Böck     SmallVector<OpAsmParser::UnresolvedOperand, 4> tokenArgs;
1444e69a529SEugene Zhulenev     if (parser.parseOperandList(tokenArgs) ||
1454e69a529SEugene Zhulenev         parser.resolveOperands(tokenArgs, tokenTy, result.operands) ||
1464e69a529SEugene Zhulenev         parser.parseRSquare())
147655af658SEugene Zhulenev       return failure();
148655af658SEugene Zhulenev 
1494e69a529SEugene Zhulenev     numDependencies = tokenArgs.size();
1504e69a529SEugene Zhulenev   }
1514e69a529SEugene Zhulenev 
1524e69a529SEugene Zhulenev   // Parse async value operands (%value as %unwrapped : !async.value<!type>).
153e13d23bcSMarkus Böck   SmallVector<OpAsmParser::UnresolvedOperand, 4> valueArgs;
154d85eb4e2SChris Lattner   SmallVector<OpAsmParser::Argument, 4> unwrappedArgs;
1554e69a529SEugene Zhulenev   SmallVector<Type, 4> valueTypes;
1564e69a529SEugene Zhulenev 
1574e69a529SEugene Zhulenev   // Parse a single instance of `%value as %unwrapped : !async.value<!type>`.
1584e69a529SEugene Zhulenev   auto parseAsyncValueArg = [&]() -> ParseResult {
1594e69a529SEugene Zhulenev     if (parser.parseOperand(valueArgs.emplace_back()) ||
1604e69a529SEugene Zhulenev         parser.parseKeyword("as") ||
161d85eb4e2SChris Lattner         parser.parseArgument(unwrappedArgs.emplace_back()) ||
1624e69a529SEugene Zhulenev         parser.parseColonType(valueTypes.emplace_back()))
1634e69a529SEugene Zhulenev       return failure();
1644e69a529SEugene Zhulenev 
165c1fa60b4STres Popp     auto valueTy = llvm::dyn_cast<ValueType>(valueTypes.back());
166d85eb4e2SChris Lattner     unwrappedArgs.back().type = valueTy ? valueTy.getValueType() : Type();
1674e69a529SEugene Zhulenev     return success();
1684e69a529SEugene Zhulenev   };
1694e69a529SEugene Zhulenev 
17058abc8c3SChris Lattner   auto argsLoc = parser.getCurrentLocation();
17158abc8c3SChris Lattner   if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::OptionalParen,
17258abc8c3SChris Lattner                                      parseAsyncValueArg) ||
17358abc8c3SChris Lattner       parser.resolveOperands(valueArgs, valueTypes, argsLoc, result.operands))
1744e69a529SEugene Zhulenev     return failure();
1754e69a529SEugene Zhulenev 
17658abc8c3SChris Lattner   int32_t numOperands = valueArgs.size();
1774e69a529SEugene Zhulenev 
178363b6559SMehdi Amini   // Add derived `operandSegmentSizes` attribute based on parsed operands.
17958a47508SJeff Niu   auto operandSegmentSizes =
18058a47508SJeff Niu       parser.getBuilder().getDenseI32ArrayAttr({numDependencies, numOperands});
1814e69a529SEugene Zhulenev   result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
1824e69a529SEugene Zhulenev 
1834e69a529SEugene Zhulenev   // Parse the types of results returned from the async execute op.
1844e69a529SEugene Zhulenev   SmallVector<Type, 4> resultTypes;
185655af658SEugene Zhulenev   NamedAttrList attrs;
1861d7b5cd5SChris Lattner   if (parser.parseOptionalArrowTypeList(resultTypes) ||
1871d7b5cd5SChris Lattner       // Async execute first result is always a completion token.
1881d7b5cd5SChris Lattner       parser.addTypeToList(tokenTy, result.types) ||
1891d7b5cd5SChris Lattner       parser.addTypesToList(resultTypes, result.types) ||
1901d7b5cd5SChris Lattner       // Parse operation attributes.
1911d7b5cd5SChris Lattner       parser.parseOptionalAttrDictWithKeyword(attrs))
192655af658SEugene Zhulenev     return failure();
1931d7b5cd5SChris Lattner 
194655af658SEugene Zhulenev   result.addAttributes(attrs);
195655af658SEugene Zhulenev 
1964e69a529SEugene Zhulenev   // Parse asynchronous region.
1974e69a529SEugene Zhulenev   Region *body = result.addRegion();
198d85eb4e2SChris Lattner   return parser.parseRegion(*body, /*arguments=*/unwrappedArgs);
199655af658SEugene Zhulenev }
200655af658SEugene Zhulenev 
verifyRegions()201ed645f63SChia-hung Duan LogicalResult ExecuteOp::verifyRegions() {
2024e69a529SEugene Zhulenev   // Unwrap async.execute value operands types.
203a5aa7836SRiver Riddle   auto unwrappedTypes = llvm::map_range(getBodyOperands(), [](Value operand) {
204c1fa60b4STres Popp     return llvm::cast<ValueType>(operand.getType()).getValueType();
2054e69a529SEugene Zhulenev   });
2064e69a529SEugene Zhulenev 
2074e69a529SEugene Zhulenev   // Verify that unwrapped argument types matches the body region arguments.
208a5aa7836SRiver Riddle   if (getBodyRegion().getArgumentTypes() != unwrappedTypes)
2091be88f5aSRiver Riddle     return emitOpError("async body region argument types do not match the "
2104e69a529SEugene Zhulenev                        "execute operation arguments types");
2114e69a529SEugene Zhulenev 
2124e69a529SEugene Zhulenev   return success();
2134e69a529SEugene Zhulenev }
214655af658SEugene Zhulenev 
21561dce0f3SEugene Zhulenev //===----------------------------------------------------------------------===//
216a8f819c6SEugene Zhulenev /// CreateGroupOp
217a8f819c6SEugene Zhulenev //===----------------------------------------------------------------------===//
218a8f819c6SEugene Zhulenev 
canonicalize(CreateGroupOp op,PatternRewriter & rewriter)219a8f819c6SEugene Zhulenev LogicalResult CreateGroupOp::canonicalize(CreateGroupOp op,
220a8f819c6SEugene Zhulenev                                           PatternRewriter &rewriter) {
221a8f819c6SEugene Zhulenev   // Find all `await_all` users of the group.
222a8f819c6SEugene Zhulenev   llvm::SmallVector<AwaitAllOp> awaitAllUsers;
223a8f819c6SEugene Zhulenev 
224a8f819c6SEugene Zhulenev   auto isAwaitAll = [&](Operation *op) -> bool {
225a8f819c6SEugene Zhulenev     if (AwaitAllOp awaitAll = dyn_cast<AwaitAllOp>(op)) {
226a8f819c6SEugene Zhulenev       awaitAllUsers.push_back(awaitAll);
227a8f819c6SEugene Zhulenev       return true;
228a8f819c6SEugene Zhulenev     }
229a8f819c6SEugene Zhulenev     return false;
230a8f819c6SEugene Zhulenev   };
231a8f819c6SEugene Zhulenev 
232a8f819c6SEugene Zhulenev   // Check if all users of the group are `await_all` operations.
233a8f819c6SEugene Zhulenev   if (!llvm::all_of(op->getUsers(), isAwaitAll))
234a8f819c6SEugene Zhulenev     return failure();
235a8f819c6SEugene Zhulenev 
236a8f819c6SEugene Zhulenev   // If group is only awaited without adding anything to it, we can safely erase
237a8f819c6SEugene Zhulenev   // the create operation and all users.
238a8f819c6SEugene Zhulenev   for (AwaitAllOp awaitAll : awaitAllUsers)
239a8f819c6SEugene Zhulenev     rewriter.eraseOp(awaitAll);
240a8f819c6SEugene Zhulenev   rewriter.eraseOp(op);
241a8f819c6SEugene Zhulenev 
242a8f819c6SEugene Zhulenev   return success();
243a8f819c6SEugene Zhulenev }
244a8f819c6SEugene Zhulenev 
245a8f819c6SEugene Zhulenev //===----------------------------------------------------------------------===//
24661dce0f3SEugene Zhulenev /// AwaitOp
24761dce0f3SEugene Zhulenev //===----------------------------------------------------------------------===//
24861dce0f3SEugene Zhulenev 
build(OpBuilder & builder,OperationState & result,Value operand,ArrayRef<NamedAttribute> attrs)24961dce0f3SEugene Zhulenev void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand,
25061dce0f3SEugene Zhulenev                     ArrayRef<NamedAttribute> attrs) {
25161dce0f3SEugene Zhulenev   result.addOperands({operand});
25261dce0f3SEugene Zhulenev   result.attributes.append(attrs.begin(), attrs.end());
25361dce0f3SEugene Zhulenev 
25461dce0f3SEugene Zhulenev   // Add unwrapped async.value type to the returned values types.
255c1fa60b4STres Popp   if (auto valueType = llvm::dyn_cast<ValueType>(operand.getType()))
25661dce0f3SEugene Zhulenev     result.addTypes(valueType.getValueType());
25761dce0f3SEugene Zhulenev }
25861dce0f3SEugene Zhulenev 
parseAwaitResultType(OpAsmParser & parser,Type & operandType,Type & resultType)25961dce0f3SEugene Zhulenev static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType,
26061dce0f3SEugene Zhulenev                                         Type &resultType) {
26161dce0f3SEugene Zhulenev   if (parser.parseType(operandType))
26261dce0f3SEugene Zhulenev     return failure();
26361dce0f3SEugene Zhulenev 
26461dce0f3SEugene Zhulenev   // Add unwrapped async.value type to the returned values types.
265c1fa60b4STres Popp   if (auto valueType = llvm::dyn_cast<ValueType>(operandType))
26661dce0f3SEugene Zhulenev     resultType = valueType.getValueType();
26761dce0f3SEugene Zhulenev 
26861dce0f3SEugene Zhulenev   return success();
26961dce0f3SEugene Zhulenev }
27061dce0f3SEugene Zhulenev 
printAwaitResultType(OpAsmPrinter & p,Operation * op,Type operandType,Type resultType)271035e12e6SJohn Demme static void printAwaitResultType(OpAsmPrinter &p, Operation *op,
272035e12e6SJohn Demme                                  Type operandType, Type resultType) {
27361dce0f3SEugene Zhulenev   p << operandType;
27461dce0f3SEugene Zhulenev }
27561dce0f3SEugene Zhulenev 
verify()2761be88f5aSRiver Riddle LogicalResult AwaitOp::verify() {
277a5aa7836SRiver Riddle   Type argType = getOperand().getType();
27861dce0f3SEugene Zhulenev 
27961dce0f3SEugene Zhulenev   // Awaiting on a token does not have any results.
280c1fa60b4STres Popp   if (llvm::isa<TokenType>(argType) && !getResultTypes().empty())
2811be88f5aSRiver Riddle     return emitOpError("awaiting on a token must have empty result");
28261dce0f3SEugene Zhulenev 
28361dce0f3SEugene Zhulenev   // Awaiting on a value unwraps the async value type.
284c1fa60b4STres Popp   if (auto value = llvm::dyn_cast<ValueType>(argType)) {
2851be88f5aSRiver Riddle     if (*getResultType() != value.getValueType())
2861be88f5aSRiver Riddle       return emitOpError() << "result type " << *getResultType()
2871be88f5aSRiver Riddle                            << " does not match async value type "
2881be88f5aSRiver Riddle                            << value.getValueType();
28961dce0f3SEugene Zhulenev   }
29061dce0f3SEugene Zhulenev 
29161dce0f3SEugene Zhulenev   return success();
29261dce0f3SEugene Zhulenev }
29361dce0f3SEugene Zhulenev 
2942f7baffdSEugene Zhulenev //===----------------------------------------------------------------------===//
295145d2a50Syijiagu // FuncOp
296145d2a50Syijiagu //===----------------------------------------------------------------------===//
297145d2a50Syijiagu 
build(OpBuilder & builder,OperationState & state,StringRef name,FunctionType type,ArrayRef<NamedAttribute> attrs,ArrayRef<DictionaryAttr> argAttrs)298145d2a50Syijiagu void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
299145d2a50Syijiagu                    FunctionType type, ArrayRef<NamedAttribute> attrs,
300145d2a50Syijiagu                    ArrayRef<DictionaryAttr> argAttrs) {
301145d2a50Syijiagu   state.addAttribute(SymbolTable::getSymbolAttrName(),
302145d2a50Syijiagu                      builder.getStringAttr(name));
30353406427SJeff Niu   state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
304145d2a50Syijiagu 
305145d2a50Syijiagu   state.attributes.append(attrs.begin(), attrs.end());
306145d2a50Syijiagu   state.addRegion();
307145d2a50Syijiagu 
308145d2a50Syijiagu   if (argAttrs.empty())
309145d2a50Syijiagu     return;
310145d2a50Syijiagu   assert(type.getNumInputs() == argAttrs.size());
31153406427SJeff Niu   function_interface_impl::addArgAndResultAttrs(
31253406427SJeff Niu       builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
31353406427SJeff Niu       getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
314145d2a50Syijiagu }
315145d2a50Syijiagu 
parse(OpAsmParser & parser,OperationState & result)316145d2a50Syijiagu ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
317145d2a50Syijiagu   auto buildFuncType =
318145d2a50Syijiagu       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
319145d2a50Syijiagu          function_interface_impl::VariadicFlag,
320145d2a50Syijiagu          std::string &) { return builder.getFunctionType(argTypes, results); };
321145d2a50Syijiagu 
322145d2a50Syijiagu   return function_interface_impl::parseFunctionOp(
32353406427SJeff Niu       parser, result, /*allowVariadic=*/false,
32453406427SJeff Niu       getFunctionTypeAttrName(result.name), buildFuncType,
32553406427SJeff Niu       getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
326145d2a50Syijiagu }
327145d2a50Syijiagu 
print(OpAsmPrinter & p)328145d2a50Syijiagu void FuncOp::print(OpAsmPrinter &p) {
32953406427SJeff Niu   function_interface_impl::printFunctionOp(
33053406427SJeff Niu       p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
33153406427SJeff Niu       getArgAttrsAttrName(), getResAttrsAttrName());
332145d2a50Syijiagu }
333145d2a50Syijiagu 
334145d2a50Syijiagu /// Check that the result type of async.func is not void and must be
335145d2a50Syijiagu /// some async token or async values.
verify()336145d2a50Syijiagu LogicalResult FuncOp::verify() {
337145d2a50Syijiagu   auto resultTypes = getResultTypes();
338145d2a50Syijiagu   if (resultTypes.empty())
339145d2a50Syijiagu     return emitOpError()
340145d2a50Syijiagu            << "result is expected to be at least of size 1, but got "
341145d2a50Syijiagu            << resultTypes.size();
342145d2a50Syijiagu 
343145d2a50Syijiagu   for (unsigned i = 0, e = resultTypes.size(); i != e; ++i) {
344145d2a50Syijiagu     auto type = resultTypes[i];
345c1fa60b4STres Popp     if (!llvm::isa<TokenType>(type) && !llvm::isa<ValueType>(type))
346145d2a50Syijiagu       return emitOpError() << "result type must be async value type or async "
347145d2a50Syijiagu                               "token type, but got "
348145d2a50Syijiagu                            << type;
349145d2a50Syijiagu     // We only allow AsyncToken appear as the first return value
350c1fa60b4STres Popp     if (llvm::isa<TokenType>(type) && i != 0) {
351145d2a50Syijiagu       return emitOpError()
352145d2a50Syijiagu              << " results' (optional) async token type is expected "
353145d2a50Syijiagu                 "to appear as the 1st return value, but got "
354145d2a50Syijiagu              << i + 1;
355145d2a50Syijiagu     }
356145d2a50Syijiagu   }
357145d2a50Syijiagu 
358145d2a50Syijiagu   return success();
359145d2a50Syijiagu }
360145d2a50Syijiagu 
361145d2a50Syijiagu //===----------------------------------------------------------------------===//
362145d2a50Syijiagu /// CallOp
363145d2a50Syijiagu //===----------------------------------------------------------------------===//
364145d2a50Syijiagu 
verifySymbolUses(SymbolTableCollection & symbolTable)365145d2a50Syijiagu LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
366145d2a50Syijiagu   // Check that the callee attribute was specified.
367145d2a50Syijiagu   auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
368145d2a50Syijiagu   if (!fnAttr)
369145d2a50Syijiagu     return emitOpError("requires a 'callee' symbol reference attribute");
370145d2a50Syijiagu   FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
371145d2a50Syijiagu   if (!fn)
372145d2a50Syijiagu     return emitOpError() << "'" << fnAttr.getValue()
373145d2a50Syijiagu                          << "' does not reference a valid async function";
374145d2a50Syijiagu 
375145d2a50Syijiagu   // Verify that the operand and result types match the callee.
376145d2a50Syijiagu   auto fnType = fn.getFunctionType();
377145d2a50Syijiagu   if (fnType.getNumInputs() != getNumOperands())
378145d2a50Syijiagu     return emitOpError("incorrect number of operands for callee");
379145d2a50Syijiagu 
380145d2a50Syijiagu   for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
381145d2a50Syijiagu     if (getOperand(i).getType() != fnType.getInput(i))
382145d2a50Syijiagu       return emitOpError("operand type mismatch: expected operand type ")
383145d2a50Syijiagu              << fnType.getInput(i) << ", but provided "
384145d2a50Syijiagu              << getOperand(i).getType() << " for operand number " << i;
385145d2a50Syijiagu 
386145d2a50Syijiagu   if (fnType.getNumResults() != getNumResults())
387145d2a50Syijiagu     return emitOpError("incorrect number of results for callee");
388145d2a50Syijiagu 
389145d2a50Syijiagu   for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
390145d2a50Syijiagu     if (getResult(i).getType() != fnType.getResult(i)) {
391145d2a50Syijiagu       auto diag = emitOpError("result type mismatch at index ") << i;
392145d2a50Syijiagu       diag.attachNote() << "      op result types: " << getResultTypes();
393145d2a50Syijiagu       diag.attachNote() << "function result types: " << fnType.getResults();
394145d2a50Syijiagu       return diag;
395145d2a50Syijiagu     }
396145d2a50Syijiagu 
397145d2a50Syijiagu   return success();
398145d2a50Syijiagu }
399145d2a50Syijiagu 
getCalleeType()400145d2a50Syijiagu FunctionType CallOp::getCalleeType() {
401145d2a50Syijiagu   return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
402145d2a50Syijiagu }
403145d2a50Syijiagu 
404145d2a50Syijiagu //===----------------------------------------------------------------------===//
405145d2a50Syijiagu /// ReturnOp
406145d2a50Syijiagu //===----------------------------------------------------------------------===//
407145d2a50Syijiagu 
verify()408145d2a50Syijiagu LogicalResult ReturnOp::verify() {
409145d2a50Syijiagu   auto funcOp = (*this)->getParentOfType<FuncOp>();
410145d2a50Syijiagu   ArrayRef<Type> resultTypes = funcOp.isStateful()
411145d2a50Syijiagu                                    ? funcOp.getResultTypes().drop_front()
412145d2a50Syijiagu                                    : funcOp.getResultTypes();
413145d2a50Syijiagu   // Get the underlying value types from async types returned from the
414145d2a50Syijiagu   // parent `async.func` operation.
415145d2a50Syijiagu   auto types = llvm::map_range(resultTypes, [](const Type &result) {
416c1fa60b4STres Popp     return llvm::cast<ValueType>(result).getValueType();
417145d2a50Syijiagu   });
418145d2a50Syijiagu 
419145d2a50Syijiagu   if (getOperandTypes() != types)
420145d2a50Syijiagu     return emitOpError("operand types do not match the types returned from "
421145d2a50Syijiagu                        "the parent FuncOp");
422145d2a50Syijiagu 
423145d2a50Syijiagu   return success();
424145d2a50Syijiagu }
425145d2a50Syijiagu 
426145d2a50Syijiagu //===----------------------------------------------------------------------===//
4272f7baffdSEugene Zhulenev // TableGen'd op method definitions
4282f7baffdSEugene Zhulenev //===----------------------------------------------------------------------===//
4292f7baffdSEugene Zhulenev 
43005a3b4feSEugene Zhulenev #define GET_OP_CLASSES
43105a3b4feSEugene Zhulenev #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
4322f7baffdSEugene Zhulenev 
4332f7baffdSEugene Zhulenev //===----------------------------------------------------------------------===//
4342f7baffdSEugene Zhulenev // TableGen'd type method definitions
4352f7baffdSEugene Zhulenev //===----------------------------------------------------------------------===//
4362f7baffdSEugene Zhulenev 
4372f7baffdSEugene Zhulenev #define GET_TYPEDEF_CLASSES
4382f7baffdSEugene Zhulenev #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
4392f7baffdSEugene Zhulenev 
print(AsmPrinter & printer) const440f97e72aaSMehdi Amini void ValueType::print(AsmPrinter &printer) const {
4412f7baffdSEugene Zhulenev   printer << "<";
4422f7baffdSEugene Zhulenev   printer.printType(getValueType());
4432f7baffdSEugene Zhulenev   printer << '>';
4442f7baffdSEugene Zhulenev }
4452f7baffdSEugene Zhulenev 
parse(mlir::AsmParser & parser)446f97e72aaSMehdi Amini Type ValueType::parse(mlir::AsmParser &parser) {
4472f7baffdSEugene Zhulenev   Type ty;
4482f7baffdSEugene Zhulenev   if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
4492f7baffdSEugene Zhulenev     parser.emitError(parser.getNameLoc(), "failed to parse async value type");
4502f7baffdSEugene Zhulenev     return Type();
4512f7baffdSEugene Zhulenev   }
4522f7baffdSEugene Zhulenev   return ValueType::get(ty);
4532f7baffdSEugene Zhulenev }
454