1 //===- Async.cpp - MLIR Async Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Async/IR/Async.h"
10
11 #include "mlir/IR/DialectImplementation.h"
12 #include "mlir/IR/IRMapping.h"
13 #include "mlir/Interfaces/FunctionImplementation.h"
14 #include "llvm/ADT/MapVector.h"
15 #include "llvm/ADT/TypeSwitch.h"
16
17 using namespace mlir;
18 using namespace mlir::async;
19
20 #include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc"
21
// Out-of-class definition of the static constexpr data member
// `kAllowedToBlockAttrName`; its initializer lives with the in-class
// declaration, which is not part of this file.
constexpr StringRef AsyncDialect::kAllowedToBlockAttrName;
23
/// Registers the operations and types of the Async dialect.
void AsyncDialect::initialize() {
  // With GET_OP_LIST defined, the included file supplies the template
  // argument list for `addOperations`.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
      >();
  // With GET_TYPEDEF_LIST defined, the included file supplies the template
  // argument list for `addTypes`.
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
      >();
}
34
35 //===----------------------------------------------------------------------===//
36 /// ExecuteOp
37 //===----------------------------------------------------------------------===//
38
/// Name of the attribute that records how many operands of `async.execute`
/// are dependencies and how many are body operands.
static constexpr char kOperandSegmentSizesAttr[] = "operandSegmentSizes";
40
getEntrySuccessorOperands(RegionBranchPoint point)41 OperandRange ExecuteOp::getEntrySuccessorOperands(RegionBranchPoint point) {
42 assert(point == getBodyRegion() && "invalid region index");
43 return getBodyOperands();
44 }
45
areTypesCompatible(Type lhs,Type rhs)46 bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) {
47 const auto getValueOrTokenType = [](Type type) {
48 if (auto value = llvm::dyn_cast<ValueType>(type))
49 return value.getValueType();
50 return type;
51 };
52 return getValueOrTokenType(lhs) == getValueOrTokenType(rhs);
53 }
54
getSuccessorRegions(RegionBranchPoint point,SmallVectorImpl<RegionSuccessor> & regions)55 void ExecuteOp::getSuccessorRegions(RegionBranchPoint point,
56 SmallVectorImpl<RegionSuccessor> ®ions) {
57 // The `body` region branch back to the parent operation.
58 if (point == getBodyRegion()) {
59 regions.push_back(RegionSuccessor(getBodyResults()));
60 return;
61 }
62
63 // Otherwise the successor is the body region.
64 regions.push_back(
65 RegionSuccessor(&getBodyRegion(), getBodyRegion().getArguments()));
66 }
67
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,ValueRange dependencies,ValueRange operands,BodyBuilderFn bodyBuilder)68 void ExecuteOp::build(OpBuilder &builder, OperationState &result,
69 TypeRange resultTypes, ValueRange dependencies,
70 ValueRange operands, BodyBuilderFn bodyBuilder) {
71 OpBuilder::InsertionGuard guard(builder);
72 result.addOperands(dependencies);
73 result.addOperands(operands);
74
75 // Add derived `operandSegmentSizes` attribute based on parsed operands.
76 int32_t numDependencies = dependencies.size();
77 int32_t numOperands = operands.size();
78 auto operandSegmentSizes =
79 builder.getDenseI32ArrayAttr({numDependencies, numOperands});
80 result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
81
82 // First result is always a token, and then `resultTypes` wrapped into
83 // `async.value`.
84 result.addTypes({TokenType::get(result.getContext())});
85 for (Type type : resultTypes)
86 result.addTypes(ValueType::get(type));
87
88 // Add a body region with block arguments as unwrapped async value operands.
89 Region *bodyRegion = result.addRegion();
90 Block *bodyBlock = builder.createBlock(bodyRegion);
91 for (Value operand : operands) {
92 auto valueType = llvm::dyn_cast<ValueType>(operand.getType());
93 bodyBlock->addArgument(valueType ? valueType.getValueType()
94 : operand.getType(),
95 operand.getLoc());
96 }
97
98 // Create the default terminator if the builder is not provided and if the
99 // expected result is empty. Otherwise, leave this to the caller
100 // because we don't know which values to return from the execute op.
101 if (resultTypes.empty() && !bodyBuilder) {
102 builder.create<async::YieldOp>(result.location, ValueRange());
103 } else if (bodyBuilder) {
104 bodyBuilder(builder, result.location, bodyBlock->getArguments());
105 }
106 }
107
print(OpAsmPrinter & p)108 void ExecuteOp::print(OpAsmPrinter &p) {
109 // [%tokens,...]
110 if (!getDependencies().empty())
111 p << " [" << getDependencies() << "]";
112
113 // (%value as %unwrapped: !async.value<!arg.type>, ...)
114 if (!getBodyOperands().empty()) {
115 p << " (";
116 Block *entry = getBodyRegion().empty() ? nullptr : &getBodyRegion().front();
117 llvm::interleaveComma(
118 getBodyOperands(), p, [&, n = 0](Value operand) mutable {
119 Value argument = entry ? entry->getArgument(n++) : Value();
120 p << operand << " as " << argument << ": " << operand.getType();
121 });
122 p << ")";
123 }
124
125 // -> (!async.value<!return.type>, ...)
126 p.printOptionalArrowTypeList(llvm::drop_begin(getResultTypes()));
127 p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
128 {kOperandSegmentSizesAttr});
129 p << ' ';
130 p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false);
131 }
132
parse(OpAsmParser & parser,OperationState & result)133 ParseResult ExecuteOp::parse(OpAsmParser &parser, OperationState &result) {
134 MLIRContext *ctx = result.getContext();
135
136 // Sizes of parsed variadic operands, will be updated below after parsing.
137 int32_t numDependencies = 0;
138
139 auto tokenTy = TokenType::get(ctx);
140
141 // Parse dependency tokens.
142 if (succeeded(parser.parseOptionalLSquare())) {
143 SmallVector<OpAsmParser::UnresolvedOperand, 4> tokenArgs;
144 if (parser.parseOperandList(tokenArgs) ||
145 parser.resolveOperands(tokenArgs, tokenTy, result.operands) ||
146 parser.parseRSquare())
147 return failure();
148
149 numDependencies = tokenArgs.size();
150 }
151
152 // Parse async value operands (%value as %unwrapped : !async.value<!type>).
153 SmallVector<OpAsmParser::UnresolvedOperand, 4> valueArgs;
154 SmallVector<OpAsmParser::Argument, 4> unwrappedArgs;
155 SmallVector<Type, 4> valueTypes;
156
157 // Parse a single instance of `%value as %unwrapped : !async.value<!type>`.
158 auto parseAsyncValueArg = [&]() -> ParseResult {
159 if (parser.parseOperand(valueArgs.emplace_back()) ||
160 parser.parseKeyword("as") ||
161 parser.parseArgument(unwrappedArgs.emplace_back()) ||
162 parser.parseColonType(valueTypes.emplace_back()))
163 return failure();
164
165 auto valueTy = llvm::dyn_cast<ValueType>(valueTypes.back());
166 unwrappedArgs.back().type = valueTy ? valueTy.getValueType() : Type();
167 return success();
168 };
169
170 auto argsLoc = parser.getCurrentLocation();
171 if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::OptionalParen,
172 parseAsyncValueArg) ||
173 parser.resolveOperands(valueArgs, valueTypes, argsLoc, result.operands))
174 return failure();
175
176 int32_t numOperands = valueArgs.size();
177
178 // Add derived `operandSegmentSizes` attribute based on parsed operands.
179 auto operandSegmentSizes =
180 parser.getBuilder().getDenseI32ArrayAttr({numDependencies, numOperands});
181 result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
182
183 // Parse the types of results returned from the async execute op.
184 SmallVector<Type, 4> resultTypes;
185 NamedAttrList attrs;
186 if (parser.parseOptionalArrowTypeList(resultTypes) ||
187 // Async execute first result is always a completion token.
188 parser.addTypeToList(tokenTy, result.types) ||
189 parser.addTypesToList(resultTypes, result.types) ||
190 // Parse operation attributes.
191 parser.parseOptionalAttrDictWithKeyword(attrs))
192 return failure();
193
194 result.addAttributes(attrs);
195
196 // Parse asynchronous region.
197 Region *body = result.addRegion();
198 return parser.parseRegion(*body, /*arguments=*/unwrappedArgs);
199 }
200
verifyRegions()201 LogicalResult ExecuteOp::verifyRegions() {
202 // Unwrap async.execute value operands types.
203 auto unwrappedTypes = llvm::map_range(getBodyOperands(), [](Value operand) {
204 return llvm::cast<ValueType>(operand.getType()).getValueType();
205 });
206
207 // Verify that unwrapped argument types matches the body region arguments.
208 if (getBodyRegion().getArgumentTypes() != unwrappedTypes)
209 return emitOpError("async body region argument types do not match the "
210 "execute operation arguments types");
211
212 return success();
213 }
214
215 //===----------------------------------------------------------------------===//
216 /// CreateGroupOp
217 //===----------------------------------------------------------------------===//
218
canonicalize(CreateGroupOp op,PatternRewriter & rewriter)219 LogicalResult CreateGroupOp::canonicalize(CreateGroupOp op,
220 PatternRewriter &rewriter) {
221 // Find all `await_all` users of the group.
222 llvm::SmallVector<AwaitAllOp> awaitAllUsers;
223
224 auto isAwaitAll = [&](Operation *op) -> bool {
225 if (AwaitAllOp awaitAll = dyn_cast<AwaitAllOp>(op)) {
226 awaitAllUsers.push_back(awaitAll);
227 return true;
228 }
229 return false;
230 };
231
232 // Check if all users of the group are `await_all` operations.
233 if (!llvm::all_of(op->getUsers(), isAwaitAll))
234 return failure();
235
236 // If group is only awaited without adding anything to it, we can safely erase
237 // the create operation and all users.
238 for (AwaitAllOp awaitAll : awaitAllUsers)
239 rewriter.eraseOp(awaitAll);
240 rewriter.eraseOp(op);
241
242 return success();
243 }
244
245 //===----------------------------------------------------------------------===//
246 /// AwaitOp
247 //===----------------------------------------------------------------------===//
248
build(OpBuilder & builder,OperationState & result,Value operand,ArrayRef<NamedAttribute> attrs)249 void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand,
250 ArrayRef<NamedAttribute> attrs) {
251 result.addOperands({operand});
252 result.attributes.append(attrs.begin(), attrs.end());
253
254 // Add unwrapped async.value type to the returned values types.
255 if (auto valueType = llvm::dyn_cast<ValueType>(operand.getType()))
256 result.addTypes(valueType.getValueType());
257 }
258
parseAwaitResultType(OpAsmParser & parser,Type & operandType,Type & resultType)259 static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType,
260 Type &resultType) {
261 if (parser.parseType(operandType))
262 return failure();
263
264 // Add unwrapped async.value type to the returned values types.
265 if (auto valueType = llvm::dyn_cast<ValueType>(operandType))
266 resultType = valueType.getValueType();
267
268 return success();
269 }
270
printAwaitResultType(OpAsmPrinter & p,Operation * op,Type operandType,Type resultType)271 static void printAwaitResultType(OpAsmPrinter &p, Operation *op,
272 Type operandType, Type resultType) {
273 p << operandType;
274 }
275
verify()276 LogicalResult AwaitOp::verify() {
277 Type argType = getOperand().getType();
278
279 // Awaiting on a token does not have any results.
280 if (llvm::isa<TokenType>(argType) && !getResultTypes().empty())
281 return emitOpError("awaiting on a token must have empty result");
282
283 // Awaiting on a value unwraps the async value type.
284 if (auto value = llvm::dyn_cast<ValueType>(argType)) {
285 if (*getResultType() != value.getValueType())
286 return emitOpError() << "result type " << *getResultType()
287 << " does not match async value type "
288 << value.getValueType();
289 }
290
291 return success();
292 }
293
294 //===----------------------------------------------------------------------===//
295 // FuncOp
296 //===----------------------------------------------------------------------===//
297
build(OpBuilder & builder,OperationState & state,StringRef name,FunctionType type,ArrayRef<NamedAttribute> attrs,ArrayRef<DictionaryAttr> argAttrs)298 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
299 FunctionType type, ArrayRef<NamedAttribute> attrs,
300 ArrayRef<DictionaryAttr> argAttrs) {
301 state.addAttribute(SymbolTable::getSymbolAttrName(),
302 builder.getStringAttr(name));
303 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
304
305 state.attributes.append(attrs.begin(), attrs.end());
306 state.addRegion();
307
308 if (argAttrs.empty())
309 return;
310 assert(type.getNumInputs() == argAttrs.size());
311 function_interface_impl::addArgAndResultAttrs(
312 builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
313 getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
314 }
315
parse(OpAsmParser & parser,OperationState & result)316 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
317 auto buildFuncType =
318 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
319 function_interface_impl::VariadicFlag,
320 std::string &) { return builder.getFunctionType(argTypes, results); };
321
322 return function_interface_impl::parseFunctionOp(
323 parser, result, /*allowVariadic=*/false,
324 getFunctionTypeAttrName(result.name), buildFuncType,
325 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
326 }
327
print(OpAsmPrinter & p)328 void FuncOp::print(OpAsmPrinter &p) {
329 function_interface_impl::printFunctionOp(
330 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
331 getArgAttrsAttrName(), getResAttrsAttrName());
332 }
333
334 /// Check that the result type of async.func is not void and must be
335 /// some async token or async values.
verify()336 LogicalResult FuncOp::verify() {
337 auto resultTypes = getResultTypes();
338 if (resultTypes.empty())
339 return emitOpError()
340 << "result is expected to be at least of size 1, but got "
341 << resultTypes.size();
342
343 for (unsigned i = 0, e = resultTypes.size(); i != e; ++i) {
344 auto type = resultTypes[i];
345 if (!llvm::isa<TokenType>(type) && !llvm::isa<ValueType>(type))
346 return emitOpError() << "result type must be async value type or async "
347 "token type, but got "
348 << type;
349 // We only allow AsyncToken appear as the first return value
350 if (llvm::isa<TokenType>(type) && i != 0) {
351 return emitOpError()
352 << " results' (optional) async token type is expected "
353 "to appear as the 1st return value, but got "
354 << i + 1;
355 }
356 }
357
358 return success();
359 }
360
361 //===----------------------------------------------------------------------===//
362 /// CallOp
363 //===----------------------------------------------------------------------===//
364
verifySymbolUses(SymbolTableCollection & symbolTable)365 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
366 // Check that the callee attribute was specified.
367 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
368 if (!fnAttr)
369 return emitOpError("requires a 'callee' symbol reference attribute");
370 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
371 if (!fn)
372 return emitOpError() << "'" << fnAttr.getValue()
373 << "' does not reference a valid async function";
374
375 // Verify that the operand and result types match the callee.
376 auto fnType = fn.getFunctionType();
377 if (fnType.getNumInputs() != getNumOperands())
378 return emitOpError("incorrect number of operands for callee");
379
380 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
381 if (getOperand(i).getType() != fnType.getInput(i))
382 return emitOpError("operand type mismatch: expected operand type ")
383 << fnType.getInput(i) << ", but provided "
384 << getOperand(i).getType() << " for operand number " << i;
385
386 if (fnType.getNumResults() != getNumResults())
387 return emitOpError("incorrect number of results for callee");
388
389 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
390 if (getResult(i).getType() != fnType.getResult(i)) {
391 auto diag = emitOpError("result type mismatch at index ") << i;
392 diag.attachNote() << " op result types: " << getResultTypes();
393 diag.attachNote() << "function result types: " << fnType.getResults();
394 return diag;
395 }
396
397 return success();
398 }
399
getCalleeType()400 FunctionType CallOp::getCalleeType() {
401 return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
402 }
403
404 //===----------------------------------------------------------------------===//
405 /// ReturnOp
406 //===----------------------------------------------------------------------===//
407
verify()408 LogicalResult ReturnOp::verify() {
409 auto funcOp = (*this)->getParentOfType<FuncOp>();
410 ArrayRef<Type> resultTypes = funcOp.isStateful()
411 ? funcOp.getResultTypes().drop_front()
412 : funcOp.getResultTypes();
413 // Get the underlying value types from async types returned from the
414 // parent `async.func` operation.
415 auto types = llvm::map_range(resultTypes, [](const Type &result) {
416 return llvm::cast<ValueType>(result).getValueType();
417 });
418
419 if (getOperandTypes() != types)
420 return emitOpError("operand types do not match the types returned from "
421 "the parent FuncOp");
422
423 return success();
424 }
425
426 //===----------------------------------------------------------------------===//
427 // TableGen'd op method definitions
428 //===----------------------------------------------------------------------===//
429
430 #define GET_OP_CLASSES
431 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
432
433 //===----------------------------------------------------------------------===//
434 // TableGen'd type method definitions
435 //===----------------------------------------------------------------------===//
436
437 #define GET_TYPEDEF_CLASSES
438 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
439
print(AsmPrinter & printer) const440 void ValueType::print(AsmPrinter &printer) const {
441 printer << "<";
442 printer.printType(getValueType());
443 printer << '>';
444 }
445
parse(mlir::AsmParser & parser)446 Type ValueType::parse(mlir::AsmParser &parser) {
447 Type ty;
448 if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
449 parser.emitError(parser.getNameLoc(), "failed to parse async value type");
450 return Type();
451 }
452 return ValueType::get(ty);
453 }
454