xref: /llvm-project/mlir/lib/Dialect/Async/IR/Async.cpp (revision 91d5653e3ae9742f7fb847f809b534ee128501b0)
1 //===- Async.cpp - MLIR Async Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Async/IR/Async.h"
10 
11 #include "mlir/IR/DialectImplementation.h"
12 #include "mlir/IR/IRMapping.h"
13 #include "mlir/Interfaces/FunctionImplementation.h"
14 #include "llvm/ADT/MapVector.h"
15 #include "llvm/ADT/TypeSwitch.h"
16 
17 using namespace mlir;
18 using namespace mlir::async;
19 
20 #include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc"
21 
// Out-of-line definition of the static `constexpr` data member; it carries no
// initializer here, so the value is presumably supplied by the in-class
// declaration in Async.h (not visible in this file).
constexpr StringRef AsyncDialect::kAllowedToBlockAttrName;
23 
initialize()24 void AsyncDialect::initialize() {
25   addOperations<
26 #define GET_OP_LIST
27 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
28       >();
29   addTypes<
30 #define GET_TYPEDEF_LIST
31 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
32       >();
33 }
34 
35 //===----------------------------------------------------------------------===//
36 /// ExecuteOp
37 //===----------------------------------------------------------------------===//
38 
// Name of the attribute holding the sizes of the two `async.execute` operand
// groups (dependencies, body operands). It is derived by `ExecuteOp::build`
// and `ExecuteOp::parse`, and elided by `ExecuteOp::print`.
constexpr char kOperandSegmentSizesAttr[] = "operandSegmentSizes";
40 
getEntrySuccessorOperands(RegionBranchPoint point)41 OperandRange ExecuteOp::getEntrySuccessorOperands(RegionBranchPoint point) {
42   assert(point == getBodyRegion() && "invalid region index");
43   return getBodyOperands();
44 }
45 
areTypesCompatible(Type lhs,Type rhs)46 bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) {
47   const auto getValueOrTokenType = [](Type type) {
48     if (auto value = llvm::dyn_cast<ValueType>(type))
49       return value.getValueType();
50     return type;
51   };
52   return getValueOrTokenType(lhs) == getValueOrTokenType(rhs);
53 }
54 
getSuccessorRegions(RegionBranchPoint point,SmallVectorImpl<RegionSuccessor> & regions)55 void ExecuteOp::getSuccessorRegions(RegionBranchPoint point,
56                                     SmallVectorImpl<RegionSuccessor> &regions) {
57   // The `body` region branch back to the parent operation.
58   if (point == getBodyRegion()) {
59     regions.push_back(RegionSuccessor(getBodyResults()));
60     return;
61   }
62 
63   // Otherwise the successor is the body region.
64   regions.push_back(
65       RegionSuccessor(&getBodyRegion(), getBodyRegion().getArguments()));
66 }
67 
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,ValueRange dependencies,ValueRange operands,BodyBuilderFn bodyBuilder)68 void ExecuteOp::build(OpBuilder &builder, OperationState &result,
69                       TypeRange resultTypes, ValueRange dependencies,
70                       ValueRange operands, BodyBuilderFn bodyBuilder) {
71   OpBuilder::InsertionGuard guard(builder);
72   result.addOperands(dependencies);
73   result.addOperands(operands);
74 
75   // Add derived `operandSegmentSizes` attribute based on parsed operands.
76   int32_t numDependencies = dependencies.size();
77   int32_t numOperands = operands.size();
78   auto operandSegmentSizes =
79       builder.getDenseI32ArrayAttr({numDependencies, numOperands});
80   result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
81 
82   // First result is always a token, and then `resultTypes` wrapped into
83   // `async.value`.
84   result.addTypes({TokenType::get(result.getContext())});
85   for (Type type : resultTypes)
86     result.addTypes(ValueType::get(type));
87 
88   // Add a body region with block arguments as unwrapped async value operands.
89   Region *bodyRegion = result.addRegion();
90   Block *bodyBlock = builder.createBlock(bodyRegion);
91   for (Value operand : operands) {
92     auto valueType = llvm::dyn_cast<ValueType>(operand.getType());
93     bodyBlock->addArgument(valueType ? valueType.getValueType()
94                                      : operand.getType(),
95                            operand.getLoc());
96   }
97 
98   // Create the default terminator if the builder is not provided and if the
99   // expected result is empty. Otherwise, leave this to the caller
100   // because we don't know which values to return from the execute op.
101   if (resultTypes.empty() && !bodyBuilder) {
102     builder.create<async::YieldOp>(result.location, ValueRange());
103   } else if (bodyBuilder) {
104     bodyBuilder(builder, result.location, bodyBlock->getArguments());
105   }
106 }
107 
print(OpAsmPrinter & p)108 void ExecuteOp::print(OpAsmPrinter &p) {
109   // [%tokens,...]
110   if (!getDependencies().empty())
111     p << " [" << getDependencies() << "]";
112 
113   // (%value as %unwrapped: !async.value<!arg.type>, ...)
114   if (!getBodyOperands().empty()) {
115     p << " (";
116     Block *entry = getBodyRegion().empty() ? nullptr : &getBodyRegion().front();
117     llvm::interleaveComma(
118         getBodyOperands(), p, [&, n = 0](Value operand) mutable {
119           Value argument = entry ? entry->getArgument(n++) : Value();
120           p << operand << " as " << argument << ": " << operand.getType();
121         });
122     p << ")";
123   }
124 
125   // -> (!async.value<!return.type>, ...)
126   p.printOptionalArrowTypeList(llvm::drop_begin(getResultTypes()));
127   p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
128                                      {kOperandSegmentSizesAttr});
129   p << ' ';
130   p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false);
131 }
132 
parse(OpAsmParser & parser,OperationState & result)133 ParseResult ExecuteOp::parse(OpAsmParser &parser, OperationState &result) {
134   MLIRContext *ctx = result.getContext();
135 
136   // Sizes of parsed variadic operands, will be updated below after parsing.
137   int32_t numDependencies = 0;
138 
139   auto tokenTy = TokenType::get(ctx);
140 
141   // Parse dependency tokens.
142   if (succeeded(parser.parseOptionalLSquare())) {
143     SmallVector<OpAsmParser::UnresolvedOperand, 4> tokenArgs;
144     if (parser.parseOperandList(tokenArgs) ||
145         parser.resolveOperands(tokenArgs, tokenTy, result.operands) ||
146         parser.parseRSquare())
147       return failure();
148 
149     numDependencies = tokenArgs.size();
150   }
151 
152   // Parse async value operands (%value as %unwrapped : !async.value<!type>).
153   SmallVector<OpAsmParser::UnresolvedOperand, 4> valueArgs;
154   SmallVector<OpAsmParser::Argument, 4> unwrappedArgs;
155   SmallVector<Type, 4> valueTypes;
156 
157   // Parse a single instance of `%value as %unwrapped : !async.value<!type>`.
158   auto parseAsyncValueArg = [&]() -> ParseResult {
159     if (parser.parseOperand(valueArgs.emplace_back()) ||
160         parser.parseKeyword("as") ||
161         parser.parseArgument(unwrappedArgs.emplace_back()) ||
162         parser.parseColonType(valueTypes.emplace_back()))
163       return failure();
164 
165     auto valueTy = llvm::dyn_cast<ValueType>(valueTypes.back());
166     unwrappedArgs.back().type = valueTy ? valueTy.getValueType() : Type();
167     return success();
168   };
169 
170   auto argsLoc = parser.getCurrentLocation();
171   if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::OptionalParen,
172                                      parseAsyncValueArg) ||
173       parser.resolveOperands(valueArgs, valueTypes, argsLoc, result.operands))
174     return failure();
175 
176   int32_t numOperands = valueArgs.size();
177 
178   // Add derived `operandSegmentSizes` attribute based on parsed operands.
179   auto operandSegmentSizes =
180       parser.getBuilder().getDenseI32ArrayAttr({numDependencies, numOperands});
181   result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
182 
183   // Parse the types of results returned from the async execute op.
184   SmallVector<Type, 4> resultTypes;
185   NamedAttrList attrs;
186   if (parser.parseOptionalArrowTypeList(resultTypes) ||
187       // Async execute first result is always a completion token.
188       parser.addTypeToList(tokenTy, result.types) ||
189       parser.addTypesToList(resultTypes, result.types) ||
190       // Parse operation attributes.
191       parser.parseOptionalAttrDictWithKeyword(attrs))
192     return failure();
193 
194   result.addAttributes(attrs);
195 
196   // Parse asynchronous region.
197   Region *body = result.addRegion();
198   return parser.parseRegion(*body, /*arguments=*/unwrappedArgs);
199 }
200 
verifyRegions()201 LogicalResult ExecuteOp::verifyRegions() {
202   // Unwrap async.execute value operands types.
203   auto unwrappedTypes = llvm::map_range(getBodyOperands(), [](Value operand) {
204     return llvm::cast<ValueType>(operand.getType()).getValueType();
205   });
206 
207   // Verify that unwrapped argument types matches the body region arguments.
208   if (getBodyRegion().getArgumentTypes() != unwrappedTypes)
209     return emitOpError("async body region argument types do not match the "
210                        "execute operation arguments types");
211 
212   return success();
213 }
214 
215 //===----------------------------------------------------------------------===//
216 /// CreateGroupOp
217 //===----------------------------------------------------------------------===//
218 
canonicalize(CreateGroupOp op,PatternRewriter & rewriter)219 LogicalResult CreateGroupOp::canonicalize(CreateGroupOp op,
220                                           PatternRewriter &rewriter) {
221   // Find all `await_all` users of the group.
222   llvm::SmallVector<AwaitAllOp> awaitAllUsers;
223 
224   auto isAwaitAll = [&](Operation *op) -> bool {
225     if (AwaitAllOp awaitAll = dyn_cast<AwaitAllOp>(op)) {
226       awaitAllUsers.push_back(awaitAll);
227       return true;
228     }
229     return false;
230   };
231 
232   // Check if all users of the group are `await_all` operations.
233   if (!llvm::all_of(op->getUsers(), isAwaitAll))
234     return failure();
235 
236   // If group is only awaited without adding anything to it, we can safely erase
237   // the create operation and all users.
238   for (AwaitAllOp awaitAll : awaitAllUsers)
239     rewriter.eraseOp(awaitAll);
240   rewriter.eraseOp(op);
241 
242   return success();
243 }
244 
245 //===----------------------------------------------------------------------===//
246 /// AwaitOp
247 //===----------------------------------------------------------------------===//
248 
build(OpBuilder & builder,OperationState & result,Value operand,ArrayRef<NamedAttribute> attrs)249 void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand,
250                     ArrayRef<NamedAttribute> attrs) {
251   result.addOperands({operand});
252   result.attributes.append(attrs.begin(), attrs.end());
253 
254   // Add unwrapped async.value type to the returned values types.
255   if (auto valueType = llvm::dyn_cast<ValueType>(operand.getType()))
256     result.addTypes(valueType.getValueType());
257 }
258 
parseAwaitResultType(OpAsmParser & parser,Type & operandType,Type & resultType)259 static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType,
260                                         Type &resultType) {
261   if (parser.parseType(operandType))
262     return failure();
263 
264   // Add unwrapped async.value type to the returned values types.
265   if (auto valueType = llvm::dyn_cast<ValueType>(operandType))
266     resultType = valueType.getValueType();
267 
268   return success();
269 }
270 
printAwaitResultType(OpAsmPrinter & p,Operation * op,Type operandType,Type resultType)271 static void printAwaitResultType(OpAsmPrinter &p, Operation *op,
272                                  Type operandType, Type resultType) {
273   p << operandType;
274 }
275 
verify()276 LogicalResult AwaitOp::verify() {
277   Type argType = getOperand().getType();
278 
279   // Awaiting on a token does not have any results.
280   if (llvm::isa<TokenType>(argType) && !getResultTypes().empty())
281     return emitOpError("awaiting on a token must have empty result");
282 
283   // Awaiting on a value unwraps the async value type.
284   if (auto value = llvm::dyn_cast<ValueType>(argType)) {
285     if (*getResultType() != value.getValueType())
286       return emitOpError() << "result type " << *getResultType()
287                            << " does not match async value type "
288                            << value.getValueType();
289   }
290 
291   return success();
292 }
293 
294 //===----------------------------------------------------------------------===//
295 // FuncOp
296 //===----------------------------------------------------------------------===//
297 
build(OpBuilder & builder,OperationState & state,StringRef name,FunctionType type,ArrayRef<NamedAttribute> attrs,ArrayRef<DictionaryAttr> argAttrs)298 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
299                    FunctionType type, ArrayRef<NamedAttribute> attrs,
300                    ArrayRef<DictionaryAttr> argAttrs) {
301   state.addAttribute(SymbolTable::getSymbolAttrName(),
302                      builder.getStringAttr(name));
303   state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
304 
305   state.attributes.append(attrs.begin(), attrs.end());
306   state.addRegion();
307 
308   if (argAttrs.empty())
309     return;
310   assert(type.getNumInputs() == argAttrs.size());
311   function_interface_impl::addArgAndResultAttrs(
312       builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
313       getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
314 }
315 
parse(OpAsmParser & parser,OperationState & result)316 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
317   auto buildFuncType =
318       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
319          function_interface_impl::VariadicFlag,
320          std::string &) { return builder.getFunctionType(argTypes, results); };
321 
322   return function_interface_impl::parseFunctionOp(
323       parser, result, /*allowVariadic=*/false,
324       getFunctionTypeAttrName(result.name), buildFuncType,
325       getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
326 }
327 
print(OpAsmPrinter & p)328 void FuncOp::print(OpAsmPrinter &p) {
329   function_interface_impl::printFunctionOp(
330       p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
331       getArgAttrsAttrName(), getResAttrsAttrName());
332 }
333 
334 /// Check that the result type of async.func is not void and must be
335 /// some async token or async values.
verify()336 LogicalResult FuncOp::verify() {
337   auto resultTypes = getResultTypes();
338   if (resultTypes.empty())
339     return emitOpError()
340            << "result is expected to be at least of size 1, but got "
341            << resultTypes.size();
342 
343   for (unsigned i = 0, e = resultTypes.size(); i != e; ++i) {
344     auto type = resultTypes[i];
345     if (!llvm::isa<TokenType>(type) && !llvm::isa<ValueType>(type))
346       return emitOpError() << "result type must be async value type or async "
347                               "token type, but got "
348                            << type;
349     // We only allow AsyncToken appear as the first return value
350     if (llvm::isa<TokenType>(type) && i != 0) {
351       return emitOpError()
352              << " results' (optional) async token type is expected "
353                 "to appear as the 1st return value, but got "
354              << i + 1;
355     }
356   }
357 
358   return success();
359 }
360 
361 //===----------------------------------------------------------------------===//
362 /// CallOp
363 //===----------------------------------------------------------------------===//
364 
verifySymbolUses(SymbolTableCollection & symbolTable)365 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
366   // Check that the callee attribute was specified.
367   auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
368   if (!fnAttr)
369     return emitOpError("requires a 'callee' symbol reference attribute");
370   FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
371   if (!fn)
372     return emitOpError() << "'" << fnAttr.getValue()
373                          << "' does not reference a valid async function";
374 
375   // Verify that the operand and result types match the callee.
376   auto fnType = fn.getFunctionType();
377   if (fnType.getNumInputs() != getNumOperands())
378     return emitOpError("incorrect number of operands for callee");
379 
380   for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
381     if (getOperand(i).getType() != fnType.getInput(i))
382       return emitOpError("operand type mismatch: expected operand type ")
383              << fnType.getInput(i) << ", but provided "
384              << getOperand(i).getType() << " for operand number " << i;
385 
386   if (fnType.getNumResults() != getNumResults())
387     return emitOpError("incorrect number of results for callee");
388 
389   for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
390     if (getResult(i).getType() != fnType.getResult(i)) {
391       auto diag = emitOpError("result type mismatch at index ") << i;
392       diag.attachNote() << "      op result types: " << getResultTypes();
393       diag.attachNote() << "function result types: " << fnType.getResults();
394       return diag;
395     }
396 
397   return success();
398 }
399 
getCalleeType()400 FunctionType CallOp::getCalleeType() {
401   return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
402 }
403 
404 //===----------------------------------------------------------------------===//
405 /// ReturnOp
406 //===----------------------------------------------------------------------===//
407 
verify()408 LogicalResult ReturnOp::verify() {
409   auto funcOp = (*this)->getParentOfType<FuncOp>();
410   ArrayRef<Type> resultTypes = funcOp.isStateful()
411                                    ? funcOp.getResultTypes().drop_front()
412                                    : funcOp.getResultTypes();
413   // Get the underlying value types from async types returned from the
414   // parent `async.func` operation.
415   auto types = llvm::map_range(resultTypes, [](const Type &result) {
416     return llvm::cast<ValueType>(result).getValueType();
417   });
418 
419   if (getOperandTypes() != types)
420     return emitOpError("operand types do not match the types returned from "
421                        "the parent FuncOp");
422 
423   return success();
424 }
425 
426 //===----------------------------------------------------------------------===//
427 // TableGen'd op method definitions
428 //===----------------------------------------------------------------------===//
429 
430 #define GET_OP_CLASSES
431 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
432 
433 //===----------------------------------------------------------------------===//
434 // TableGen'd type method definitions
435 //===----------------------------------------------------------------------===//
436 
437 #define GET_TYPEDEF_CLASSES
438 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
439 
print(AsmPrinter & printer) const440 void ValueType::print(AsmPrinter &printer) const {
441   printer << "<";
442   printer.printType(getValueType());
443   printer << '>';
444 }
445 
parse(mlir::AsmParser & parser)446 Type ValueType::parse(mlir::AsmParser &parser) {
447   Type ty;
448   if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
449     parser.emitError(parser.getNameLoc(), "failed to parse async value type");
450     return Type();
451   }
452   return ValueType::get(ty);
453 }
454