1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" 10 11 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" 12 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" 13 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 14 #include "mlir/Conversion/LLVMCommon/Pattern.h" 15 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 16 #include "mlir/Dialect/Arith/IR/Arith.h" 17 #include "mlir/Dialect/Async/IR/Async.h" 18 #include "mlir/Dialect/Func/IR/FuncOps.h" 19 #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 20 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 21 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 22 #include "mlir/IR/ImplicitLocOpBuilder.h" 23 #include "mlir/IR/TypeUtilities.h" 24 #include "mlir/Pass/Pass.h" 25 #include "mlir/Transforms/DialectConversion.h" 26 #include "llvm/ADT/TypeSwitch.h" 27 28 namespace mlir { 29 #define GEN_PASS_DEF_CONVERTASYNCTOLLVMPASS 30 #include "mlir/Conversion/Passes.h.inc" 31 } // namespace mlir 32 33 #define DEBUG_TYPE "convert-async-to-llvm" 34 35 using namespace mlir; 36 using namespace mlir::async; 37 38 //===----------------------------------------------------------------------===// 39 // Async Runtime C API declaration. 40 //===----------------------------------------------------------------------===// 41 42 static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef"; 43 static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef"; 44 static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; 45 static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue"; 46 static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; 47 static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; 48 static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue"; 49 static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError"; 50 static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError"; 51 static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError"; 52 static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError"; 53 static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError"; 54 static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken"; 55 static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue"; 56 static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup"; 57 static constexpr const char *kExecute = "mlirAsyncRuntimeExecute"; 58 static constexpr const char *kGetValueStorage = 59 "mlirAsyncRuntimeGetValueStorage"; 60 static constexpr const char *kAddTokenToGroup = 61 "mlirAsyncRuntimeAddTokenToGroup"; 62 static constexpr const char *kAwaitTokenAndExecute = 63 "mlirAsyncRuntimeAwaitTokenAndExecute"; 64 static constexpr const char *kAwaitValueAndExecute = 65 "mlirAsyncRuntimeAwaitValueAndExecute"; 66 static constexpr const char *kAwaitAllAndExecute = 67 "mlirAsyncRuntimeAwaitAllInGroupAndExecute"; 68 static constexpr const char *kGetNumWorkerThreads = 69 "mlirAsyncRuntimGetNumWorkerThreads"; 70 71 namespace { 72 /// Async Runtime API function types. 73 /// 74 /// Because we can't create API function signature for type parametrized 75 /// async.getValue type, we use opaque pointers (!llvm.ptr) instead. After 76 /// lowering all async data types become opaque pointers at runtime. 77 struct AsyncAPI { 78 // All async types are lowered to opaque LLVM pointers at runtime. 79 static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) { 80 return LLVM::LLVMPointerType::get(ctx); 81 } 82 83 static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) { 84 return LLVM::LLVMTokenType::get(ctx); 85 } 86 87 static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) { 88 auto ref = opaquePointerType(ctx); 89 auto count = IntegerType::get(ctx, 64); 90 return FunctionType::get(ctx, {ref, count}, {}); 91 } 92 93 static FunctionType createTokenFunctionType(MLIRContext *ctx) { 94 return FunctionType::get(ctx, {}, {TokenType::get(ctx)}); 95 } 96 97 static FunctionType createValueFunctionType(MLIRContext *ctx) { 98 auto i64 = IntegerType::get(ctx, 64); 99 auto value = opaquePointerType(ctx); 100 return FunctionType::get(ctx, {i64}, {value}); 101 } 102 103 static FunctionType createGroupFunctionType(MLIRContext *ctx) { 104 auto i64 = IntegerType::get(ctx, 64); 105 return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)}); 106 } 107 108 static FunctionType getValueStorageFunctionType(MLIRContext *ctx) { 109 auto ptrType = opaquePointerType(ctx); 110 return FunctionType::get(ctx, {ptrType}, {ptrType}); 111 } 112 113 static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { 114 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 115 } 116 117 static FunctionType emplaceValueFunctionType(MLIRContext *ctx) { 118 auto value = opaquePointerType(ctx); 119 return FunctionType::get(ctx, {value}, {}); 120 } 121 122 static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) { 123 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 124 } 125 126 static FunctionType setValueErrorFunctionType(MLIRContext *ctx) { 127 auto value = opaquePointerType(ctx); 128 return FunctionType::get(ctx, {value}, {}); 129 } 130 131 static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) { 132 auto i1 = IntegerType::get(ctx, 1); 133 return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1}); 134 } 135 136 static FunctionType isValueErrorFunctionType(MLIRContext *ctx) { 137 auto value = opaquePointerType(ctx); 138 auto i1 = IntegerType::get(ctx, 1); 139 return FunctionType::get(ctx, {value}, {i1}); 140 } 141 142 static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) { 143 auto i1 = IntegerType::get(ctx, 1); 144 return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1}); 145 } 146 147 static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { 148 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 149 } 150 151 static FunctionType awaitValueFunctionType(MLIRContext *ctx) { 152 auto value = opaquePointerType(ctx); 153 return FunctionType::get(ctx, {value}, {}); 154 } 155 156 static FunctionType awaitGroupFunctionType(MLIRContext *ctx) { 157 return FunctionType::get(ctx, {GroupType::get(ctx)}, {}); 158 } 159 160 static FunctionType executeFunctionType(MLIRContext *ctx) { 161 auto ptrType = opaquePointerType(ctx); 162 return FunctionType::get(ctx, {ptrType, ptrType}, {}); 163 } 164 165 static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) { 166 auto i64 = IntegerType::get(ctx, 64); 167 return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)}, 168 {i64}); 169 } 170 171 static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) { 172 auto ptrType = opaquePointerType(ctx); 173 return FunctionType::get(ctx, {TokenType::get(ctx), ptrType, ptrType}, {}); 174 } 175 176 static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) { 177 auto ptrType = opaquePointerType(ctx); 178 return FunctionType::get(ctx, {ptrType, ptrType, ptrType}, {}); 179 } 180 181 static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { 182 auto ptrType = opaquePointerType(ctx); 183 return FunctionType::get(ctx, {GroupType::get(ctx), ptrType, ptrType}, {}); 184 } 185 186 static FunctionType getNumWorkerThreads(MLIRContext *ctx) { 187 return FunctionType::get(ctx, {}, {IndexType::get(ctx)}); 188 } 189 190 // Auxiliary coroutine resume intrinsic wrapper. 191 static Type resumeFunctionType(MLIRContext *ctx) { 192 auto voidTy = LLVM::LLVMVoidType::get(ctx); 193 auto ptrType = opaquePointerType(ctx); 194 return LLVM::LLVMFunctionType::get(voidTy, {ptrType}, false); 195 } 196 }; 197 } // namespace 198 199 /// Adds Async Runtime C API declarations to the module. 200 static void addAsyncRuntimeApiDeclarations(ModuleOp module) { 201 auto builder = 202 ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody()); 203 204 auto addFuncDecl = [&](StringRef name, FunctionType type) { 205 if (module.lookupSymbol(name)) 206 return; 207 builder.create<func::FuncOp>(name, type).setPrivate(); 208 }; 209 210 MLIRContext *ctx = module.getContext(); 211 addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 212 addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 213 addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx)); 214 addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx)); 215 addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx)); 216 addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx)); 217 addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx)); 218 addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx)); 219 addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx)); 220 addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx)); 221 addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx)); 222 addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx)); 223 addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx)); 224 addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx)); 225 addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx)); 226 addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx)); 227 addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx)); 228 addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx)); 229 addFuncDecl(kAwaitTokenAndExecute, 230 AsyncAPI::awaitTokenAndExecuteFunctionType(ctx)); 231 addFuncDecl(kAwaitValueAndExecute, 232 AsyncAPI::awaitValueAndExecuteFunctionType(ctx)); 233 addFuncDecl(kAwaitAllAndExecute, 234 AsyncAPI::awaitAllAndExecuteFunctionType(ctx)); 235 addFuncDecl(kGetNumWorkerThreads, AsyncAPI::getNumWorkerThreads(ctx)); 236 } 237 238 //===----------------------------------------------------------------------===// 239 // Coroutine resume function wrapper. 240 //===----------------------------------------------------------------------===// 241 242 static constexpr const char *kResume = "__resume"; 243 244 /// A function that takes a coroutine handle and calls a `llvm.coro.resume` 245 /// intrinsics. We need this function to be able to pass it to the async 246 /// runtime execute API. 247 static void addResumeFunction(ModuleOp module) { 248 if (module.lookupSymbol(kResume)) 249 return; 250 251 MLIRContext *ctx = module.getContext(); 252 auto loc = module.getLoc(); 253 auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody()); 254 255 auto voidTy = LLVM::LLVMVoidType::get(ctx); 256 Type ptrType = AsyncAPI::opaquePointerType(ctx); 257 258 auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>( 259 kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType})); 260 resumeOp.setPrivate(); 261 262 auto *block = resumeOp.addEntryBlock(moduleBuilder); 263 auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block); 264 265 blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0)); 266 blockBuilder.create<LLVM::ReturnOp>(ValueRange()); 267 } 268 269 //===----------------------------------------------------------------------===// 270 // Convert Async dialect types to LLVM types. 271 //===----------------------------------------------------------------------===// 272 273 namespace { 274 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to 275 /// their runtime type (opaque pointers) and does not convert any other types. 276 class AsyncRuntimeTypeConverter : public TypeConverter { 277 public: 278 AsyncRuntimeTypeConverter(const LowerToLLVMOptions &options) { 279 addConversion([](Type type) { return type; }); 280 addConversion([](Type type) { return convertAsyncTypes(type); }); 281 282 // Use UnrealizedConversionCast as the bridge so that we don't need to pull 283 // in patterns for other dialects. 284 auto addUnrealizedCast = [](OpBuilder &builder, Type type, 285 ValueRange inputs, Location loc) -> Value { 286 auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs); 287 return cast.getResult(0); 288 }; 289 290 addSourceMaterialization(addUnrealizedCast); 291 addTargetMaterialization(addUnrealizedCast); 292 } 293 294 static std::optional<Type> convertAsyncTypes(Type type) { 295 if (isa<TokenType, GroupType, ValueType>(type)) 296 return AsyncAPI::opaquePointerType(type.getContext()); 297 298 if (isa<CoroIdType, CoroStateType>(type)) 299 return AsyncAPI::tokenType(type.getContext()); 300 if (isa<CoroHandleType>(type)) 301 return AsyncAPI::opaquePointerType(type.getContext()); 302 303 return std::nullopt; 304 } 305 }; 306 307 /// Base class for conversion patterns requiring AsyncRuntimeTypeConverter 308 /// as type converter. Allows access to it via the 'getTypeConverter' 309 /// convenience method. 310 template <typename SourceOp> 311 class AsyncOpConversionPattern : public OpConversionPattern<SourceOp> { 312 313 using Base = OpConversionPattern<SourceOp>; 314 315 public: 316 AsyncOpConversionPattern(const AsyncRuntimeTypeConverter &typeConverter, 317 MLIRContext *context) 318 : Base(typeConverter, context) {} 319 320 /// Returns the 'AsyncRuntimeTypeConverter' of the pattern. 321 const AsyncRuntimeTypeConverter *getTypeConverter() const { 322 return static_cast<const AsyncRuntimeTypeConverter *>( 323 Base::getTypeConverter()); 324 } 325 }; 326 327 } // namespace 328 329 //===----------------------------------------------------------------------===// 330 // Convert async.coro.id to @llvm.coro.id intrinsic. 331 //===----------------------------------------------------------------------===// 332 333 namespace { 334 class CoroIdOpConversion : public AsyncOpConversionPattern<CoroIdOp> { 335 public: 336 using AsyncOpConversionPattern::AsyncOpConversionPattern; 337 338 LogicalResult 339 matchAndRewrite(CoroIdOp op, OpAdaptor adaptor, 340 ConversionPatternRewriter &rewriter) const override { 341 auto token = AsyncAPI::tokenType(op->getContext()); 342 auto ptrType = AsyncAPI::opaquePointerType(op->getContext()); 343 auto loc = op->getLoc(); 344 345 // Constants for initializing coroutine frame. 346 auto constZero = 347 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0); 348 auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, ptrType); 349 350 // Get coroutine id: @llvm.coro.id. 351 rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>( 352 op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr})); 353 354 return success(); 355 } 356 }; 357 } // namespace 358 359 //===----------------------------------------------------------------------===// 360 // Convert async.coro.begin to @llvm.coro.begin intrinsic. 361 //===----------------------------------------------------------------------===// 362 363 namespace { 364 class CoroBeginOpConversion : public AsyncOpConversionPattern<CoroBeginOp> { 365 public: 366 using AsyncOpConversionPattern::AsyncOpConversionPattern; 367 368 LogicalResult 369 matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor, 370 ConversionPatternRewriter &rewriter) const override { 371 auto ptrType = AsyncAPI::opaquePointerType(op->getContext()); 372 auto loc = op->getLoc(); 373 374 // Get coroutine frame size: @llvm.coro.size.i64. 375 Value coroSize = 376 rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type()); 377 // Get coroutine frame alignment: @llvm.coro.align.i64. 378 Value coroAlign = 379 rewriter.create<LLVM::CoroAlignOp>(loc, rewriter.getI64Type()); 380 381 // Round up the size to be multiple of the alignment. Since aligned_alloc 382 // requires the size parameter be an integral multiple of the alignment 383 // parameter. 384 auto makeConstant = [&](uint64_t c) { 385 return rewriter.create<LLVM::ConstantOp>(op->getLoc(), 386 rewriter.getI64Type(), c); 387 }; 388 coroSize = rewriter.create<LLVM::AddOp>(op->getLoc(), coroSize, coroAlign); 389 coroSize = 390 rewriter.create<LLVM::SubOp>(op->getLoc(), coroSize, makeConstant(1)); 391 Value negCoroAlign = 392 rewriter.create<LLVM::SubOp>(op->getLoc(), makeConstant(0), coroAlign); 393 coroSize = 394 rewriter.create<LLVM::AndOp>(op->getLoc(), coroSize, negCoroAlign); 395 396 // Allocate memory for the coroutine frame. 397 auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn( 398 op->getParentOfType<ModuleOp>(), rewriter.getI64Type()); 399 if (failed(allocFuncOp)) 400 return failure(); 401 auto coroAlloc = rewriter.create<LLVM::CallOp>( 402 loc, allocFuncOp.value(), ValueRange{coroAlign, coroSize}); 403 404 // Begin a coroutine: @llvm.coro.begin. 405 auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId(); 406 rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>( 407 op, ptrType, ValueRange({coroId, coroAlloc.getResult()})); 408 409 return success(); 410 } 411 }; 412 } // namespace 413 414 //===----------------------------------------------------------------------===// 415 // Convert async.coro.free to @llvm.coro.free intrinsic. 416 //===----------------------------------------------------------------------===// 417 418 namespace { 419 class CoroFreeOpConversion : public AsyncOpConversionPattern<CoroFreeOp> { 420 public: 421 using AsyncOpConversionPattern::AsyncOpConversionPattern; 422 423 LogicalResult 424 matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor, 425 ConversionPatternRewriter &rewriter) const override { 426 auto ptrType = AsyncAPI::opaquePointerType(op->getContext()); 427 auto loc = op->getLoc(); 428 429 // Get a pointer to the coroutine frame memory: @llvm.coro.free. 430 auto coroMem = 431 rewriter.create<LLVM::CoroFreeOp>(loc, ptrType, adaptor.getOperands()); 432 433 // Free the memory. 434 auto freeFuncOp = 435 LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>()); 436 if (failed(freeFuncOp)) 437 return failure(); 438 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp.value(), 439 ValueRange(coroMem.getResult())); 440 441 return success(); 442 } 443 }; 444 } // namespace 445 446 //===----------------------------------------------------------------------===// 447 // Convert async.coro.end to @llvm.coro.end intrinsic. 448 //===----------------------------------------------------------------------===// 449 450 namespace { 451 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> { 452 public: 453 using OpConversionPattern::OpConversionPattern; 454 455 LogicalResult 456 matchAndRewrite(CoroEndOp op, OpAdaptor adaptor, 457 ConversionPatternRewriter &rewriter) const override { 458 // We are not in the block that is part of the unwind sequence. 459 auto constFalse = rewriter.create<LLVM::ConstantOp>( 460 op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false)); 461 auto noneToken = rewriter.create<LLVM::NoneTokenOp>(op->getLoc()); 462 463 // Mark the end of a coroutine: @llvm.coro.end. 464 auto coroHdl = adaptor.getHandle(); 465 rewriter.create<LLVM::CoroEndOp>( 466 op->getLoc(), rewriter.getI1Type(), 467 ValueRange({coroHdl, constFalse, noneToken})); 468 rewriter.eraseOp(op); 469 470 return success(); 471 } 472 }; 473 } // namespace 474 475 //===----------------------------------------------------------------------===// 476 // Convert async.coro.save to @llvm.coro.save intrinsic. 477 //===----------------------------------------------------------------------===// 478 479 namespace { 480 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> { 481 public: 482 using OpConversionPattern::OpConversionPattern; 483 484 LogicalResult 485 matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor, 486 ConversionPatternRewriter &rewriter) const override { 487 // Save the coroutine state: @llvm.coro.save 488 rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>( 489 op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands()); 490 491 return success(); 492 } 493 }; 494 } // namespace 495 496 //===----------------------------------------------------------------------===// 497 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic. 498 //===----------------------------------------------------------------------===// 499 500 namespace { 501 502 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and 503 /// branch to the appropriate block based on the return code. 504 /// 505 /// Before: 506 /// 507 /// ^suspended: 508 /// "opBefore"(...) 509 /// async.coro.suspend %state, ^suspend, ^resume, ^cleanup 510 /// ^resume: 511 /// "op"(...) 512 /// ^cleanup: ... 513 /// ^suspend: ... 514 /// 515 /// After: 516 /// 517 /// ^suspended: 518 /// "opBefore"(...) 519 /// %suspend = llmv.intr.coro.suspend ... 520 /// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] 521 /// ^resume: 522 /// "op"(...) 523 /// ^cleanup: ... 524 /// ^suspend: ... 525 /// 526 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> { 527 public: 528 using OpConversionPattern::OpConversionPattern; 529 530 LogicalResult 531 matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor, 532 ConversionPatternRewriter &rewriter) const override { 533 auto i8 = rewriter.getIntegerType(8); 534 auto i32 = rewriter.getI32Type(); 535 auto loc = op->getLoc(); 536 537 // This is not a final suspension point. 538 auto constFalse = rewriter.create<LLVM::ConstantOp>( 539 loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); 540 541 // Suspend a coroutine: @llvm.coro.suspend 542 auto coroState = adaptor.getState(); 543 auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>( 544 loc, i8, ValueRange({coroState, constFalse})); 545 546 // Cast return code to i32. 547 548 // After a suspension point decide if we should branch into resume, cleanup 549 // or suspend block of the coroutine (see @llvm.coro.suspend return code 550 // documentation). 551 llvm::SmallVector<int32_t, 2> caseValues = {0, 1}; 552 llvm::SmallVector<Block *, 2> caseDest = {op.getResumeDest(), 553 op.getCleanupDest()}; 554 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>( 555 op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()), 556 /*defaultDestination=*/op.getSuspendDest(), 557 /*defaultOperands=*/ValueRange(), 558 /*caseValues=*/caseValues, 559 /*caseDestinations=*/caseDest, 560 /*caseOperands=*/ArrayRef<ValueRange>({ValueRange(), ValueRange()}), 561 /*branchWeights=*/ArrayRef<int32_t>()); 562 563 return success(); 564 } 565 }; 566 } // namespace 567 568 //===----------------------------------------------------------------------===// 569 // Convert async.runtime.create to the corresponding runtime API call. 570 // 571 // To allocate storage for the async values we use getelementptr trick: 572 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt 573 //===----------------------------------------------------------------------===// 574 575 namespace { 576 class RuntimeCreateOpLowering : public ConvertOpToLLVMPattern<RuntimeCreateOp> { 577 public: 578 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; 579 580 LogicalResult 581 matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor, 582 ConversionPatternRewriter &rewriter) const override { 583 const TypeConverter *converter = getTypeConverter(); 584 Type resultType = op->getResultTypes()[0]; 585 586 // Tokens creation maps to a simple function call. 587 if (isa<TokenType>(resultType)) { 588 rewriter.replaceOpWithNewOp<func::CallOp>( 589 op, kCreateToken, converter->convertType(resultType)); 590 return success(); 591 } 592 593 // To create a value we need to compute the storage requirement. 594 if (auto value = dyn_cast<ValueType>(resultType)) { 595 // Returns the size requirements for the async value storage. 596 auto sizeOf = [&](ValueType valueType) -> Value { 597 auto loc = op->getLoc(); 598 auto i64 = rewriter.getI64Type(); 599 600 auto storedType = converter->convertType(valueType.getValueType()); 601 auto storagePtrType = 602 AsyncAPI::opaquePointerType(rewriter.getContext()); 603 604 // %Size = getelementptr %T* null, int 1 605 // %SizeI = ptrtoint %T* %Size to i64 606 auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, storagePtrType); 607 auto gep = 608 rewriter.create<LLVM::GEPOp>(loc, storagePtrType, storedType, 609 nullPtr, ArrayRef<LLVM::GEPArg>{1}); 610 return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep); 611 }; 612 613 rewriter.replaceOpWithNewOp<func::CallOp>(op, kCreateValue, resultType, 614 sizeOf(value)); 615 616 return success(); 617 } 618 619 return rewriter.notifyMatchFailure(op, "unsupported async type"); 620 } 621 }; 622 } // namespace 623 624 //===----------------------------------------------------------------------===// 625 // Convert async.runtime.create_group to the corresponding runtime API call. 626 //===----------------------------------------------------------------------===// 627 628 namespace { 629 class RuntimeCreateGroupOpLowering 630 : public ConvertOpToLLVMPattern<RuntimeCreateGroupOp> { 631 public: 632 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; 633 634 LogicalResult 635 matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor, 636 ConversionPatternRewriter &rewriter) const override { 637 const TypeConverter *converter = getTypeConverter(); 638 Type resultType = op.getResult().getType(); 639 640 rewriter.replaceOpWithNewOp<func::CallOp>( 641 op, kCreateGroup, converter->convertType(resultType), 642 adaptor.getOperands()); 643 return success(); 644 } 645 }; 646 } // namespace 647 648 //===----------------------------------------------------------------------===// 649 // Convert async.runtime.set_available to the corresponding runtime API call. 650 //===----------------------------------------------------------------------===// 651 652 namespace { 653 class RuntimeSetAvailableOpLowering 654 : public OpConversionPattern<RuntimeSetAvailableOp> { 655 public: 656 using OpConversionPattern::OpConversionPattern; 657 658 LogicalResult 659 matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor, 660 ConversionPatternRewriter &rewriter) const override { 661 StringRef apiFuncName = 662 TypeSwitch<Type, StringRef>(op.getOperand().getType()) 663 .Case<TokenType>([](Type) { return kEmplaceToken; }) 664 .Case<ValueType>([](Type) { return kEmplaceValue; }); 665 666 rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(), 667 adaptor.getOperands()); 668 669 return success(); 670 } 671 }; 672 } // namespace 673 674 //===----------------------------------------------------------------------===// 675 // Convert async.runtime.set_error to the corresponding runtime API call. 676 //===----------------------------------------------------------------------===// 677 678 namespace { 679 class RuntimeSetErrorOpLowering 680 : public OpConversionPattern<RuntimeSetErrorOp> { 681 public: 682 using OpConversionPattern::OpConversionPattern; 683 684 LogicalResult 685 matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor, 686 ConversionPatternRewriter &rewriter) const override { 687 StringRef apiFuncName = 688 TypeSwitch<Type, StringRef>(op.getOperand().getType()) 689 .Case<TokenType>([](Type) { return kSetTokenError; }) 690 .Case<ValueType>([](Type) { return kSetValueError; }); 691 692 rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(), 693 adaptor.getOperands()); 694 695 return success(); 696 } 697 }; 698 } // namespace 699 700 //===----------------------------------------------------------------------===// 701 // Convert async.runtime.is_error to the corresponding runtime API call. 702 //===----------------------------------------------------------------------===// 703 704 namespace { 705 class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> { 706 public: 707 using OpConversionPattern::OpConversionPattern; 708 709 LogicalResult 710 matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor, 711 ConversionPatternRewriter &rewriter) const override { 712 StringRef apiFuncName = 713 TypeSwitch<Type, StringRef>(op.getOperand().getType()) 714 .Case<TokenType>([](Type) { return kIsTokenError; }) 715 .Case<GroupType>([](Type) { return kIsGroupError; }) 716 .Case<ValueType>([](Type) { return kIsValueError; }); 717 718 rewriter.replaceOpWithNewOp<func::CallOp>( 719 op, apiFuncName, rewriter.getI1Type(), adaptor.getOperands()); 720 return success(); 721 } 722 }; 723 } // namespace 724 725 //===----------------------------------------------------------------------===// 726 // Convert async.runtime.await to the corresponding runtime API call. 727 //===----------------------------------------------------------------------===// 728 729 namespace { 730 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> { 731 public: 732 using OpConversionPattern::OpConversionPattern; 733 734 LogicalResult 735 matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor, 736 ConversionPatternRewriter &rewriter) const override { 737 StringRef apiFuncName = 738 TypeSwitch<Type, StringRef>(op.getOperand().getType()) 739 .Case<TokenType>([](Type) { return kAwaitToken; }) 740 .Case<ValueType>([](Type) { return kAwaitValue; }) 741 .Case<GroupType>([](Type) { return kAwaitGroup; }); 742 743 rewriter.create<func::CallOp>(op->getLoc(), apiFuncName, TypeRange(), 744 adaptor.getOperands()); 745 rewriter.eraseOp(op); 746 747 return success(); 748 } 749 }; 750 } // namespace 751 752 //===----------------------------------------------------------------------===// 753 // Convert async.runtime.await_and_resume to the corresponding runtime API call. 754 //===----------------------------------------------------------------------===// 755 756 namespace { 757 class RuntimeAwaitAndResumeOpLowering 758 : public AsyncOpConversionPattern<RuntimeAwaitAndResumeOp> { 759 public: 760 using AsyncOpConversionPattern::AsyncOpConversionPattern; 761 762 LogicalResult 763 matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor, 764 ConversionPatternRewriter &rewriter) const override { 765 StringRef apiFuncName = 766 TypeSwitch<Type, StringRef>(op.getOperand().getType()) 767 .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; }) 768 .Case<ValueType>([](Type) { return kAwaitValueAndExecute; }) 769 .Case<GroupType>([](Type) { return kAwaitAllAndExecute; }); 770 771 Value operand = adaptor.getOperand(); 772 Value handle = adaptor.getHandle(); 773 774 // A pointer to coroutine resume intrinsic wrapper. 775 addResumeFunction(op->getParentOfType<ModuleOp>()); 776 auto resumePtr = rewriter.create<LLVM::AddressOfOp>( 777 op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()), 778 kResume); 779 780 rewriter.create<func::CallOp>( 781 op->getLoc(), apiFuncName, TypeRange(), 782 ValueRange({operand, handle, resumePtr.getRes()})); 783 rewriter.eraseOp(op); 784 785 return success(); 786 } 787 }; 788 } // namespace 789 790 //===----------------------------------------------------------------------===// 791 // Convert async.runtime.resume to the corresponding runtime API call. 792 //===----------------------------------------------------------------------===// 793 794 namespace { 795 class RuntimeResumeOpLowering 796 : public AsyncOpConversionPattern<RuntimeResumeOp> { 797 public: 798 using AsyncOpConversionPattern::AsyncOpConversionPattern; 799 800 LogicalResult 801 matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor, 802 ConversionPatternRewriter &rewriter) const override { 803 // A pointer to coroutine resume intrinsic wrapper. 804 addResumeFunction(op->getParentOfType<ModuleOp>()); 805 auto resumePtr = rewriter.create<LLVM::AddressOfOp>( 806 op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()), 807 kResume); 808 809 // Call async runtime API to execute a coroutine in the managed thread. 810 auto coroHdl = adaptor.getHandle(); 811 rewriter.replaceOpWithNewOp<func::CallOp>( 812 op, TypeRange(), kExecute, ValueRange({coroHdl, resumePtr.getRes()})); 813 814 return success(); 815 } 816 }; 817 } // namespace 818 819 //===----------------------------------------------------------------------===// 820 // Convert async.runtime.store to the corresponding runtime API call. 821 //===----------------------------------------------------------------------===// 822 823 namespace { 824 class RuntimeStoreOpLowering : public ConvertOpToLLVMPattern<RuntimeStoreOp> { 825 public: 826 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; 827 828 LogicalResult 829 matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor, 830 ConversionPatternRewriter &rewriter) const override { 831 Location loc = op->getLoc(); 832 833 // Get a pointer to the async value storage from the runtime. 834 auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext()); 835 auto storage = adaptor.getStorage(); 836 auto storagePtr = rewriter.create<func::CallOp>( 837 loc, kGetValueStorage, TypeRange(ptrType), storage); 838 839 // Cast from i8* to the LLVM pointer type. 840 auto valueType = op.getValue().getType(); 841 auto llvmValueType = getTypeConverter()->convertType(valueType); 842 if (!llvmValueType) 843 return rewriter.notifyMatchFailure( 844 op, "failed to convert stored value type to LLVM type"); 845 846 Value castedStoragePtr = storagePtr.getResult(0); 847 // Store the yielded value into the async value storage. 848 auto value = adaptor.getValue(); 849 rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr); 850 851 // Erase the original runtime store operation. 852 rewriter.eraseOp(op); 853 854 return success(); 855 } 856 }; 857 } // namespace 858 859 //===----------------------------------------------------------------------===// 860 // Convert async.runtime.load to the corresponding runtime API call. 861 //===----------------------------------------------------------------------===// 862 863 namespace { 864 class RuntimeLoadOpLowering : public ConvertOpToLLVMPattern<RuntimeLoadOp> { 865 public: 866 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; 867 868 LogicalResult 869 matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor, 870 ConversionPatternRewriter &rewriter) const override { 871 Location loc = op->getLoc(); 872 873 // Get a pointer to the async value storage from the runtime. 874 auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext()); 875 auto storage = adaptor.getStorage(); 876 auto storagePtr = rewriter.create<func::CallOp>( 877 loc, kGetValueStorage, TypeRange(ptrType), storage); 878 879 // Cast from i8* to the LLVM pointer type. 880 auto valueType = op.getResult().getType(); 881 auto llvmValueType = getTypeConverter()->convertType(valueType); 882 if (!llvmValueType) 883 return rewriter.notifyMatchFailure( 884 op, "failed to convert loaded value type to LLVM type"); 885 886 Value castedStoragePtr = storagePtr.getResult(0); 887 888 // Load from the casted pointer. 889 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, llvmValueType, 890 castedStoragePtr); 891 892 return success(); 893 } 894 }; 895 } // namespace 896 897 //===----------------------------------------------------------------------===// 898 // Convert async.runtime.add_to_group to the corresponding runtime API call. 899 //===----------------------------------------------------------------------===// 900 901 namespace { 902 class RuntimeAddToGroupOpLowering 903 : public OpConversionPattern<RuntimeAddToGroupOp> { 904 public: 905 using OpConversionPattern::OpConversionPattern; 906 907 LogicalResult 908 matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor, 909 ConversionPatternRewriter &rewriter) const override { 910 // Currently we can only add tokens to the group. 911 if (!isa<TokenType>(op.getOperand().getType())) 912 return rewriter.notifyMatchFailure(op, "only token type is supported"); 913 914 // Replace with a runtime API function call. 915 rewriter.replaceOpWithNewOp<func::CallOp>( 916 op, kAddTokenToGroup, rewriter.getI64Type(), adaptor.getOperands()); 917 918 return success(); 919 } 920 }; 921 } // namespace 922 923 //===----------------------------------------------------------------------===// 924 // Convert async.runtime.num_worker_threads to the corresponding runtime API 925 // call. 926 //===----------------------------------------------------------------------===// 927 928 namespace { 929 class RuntimeNumWorkerThreadsOpLowering 930 : public OpConversionPattern<RuntimeNumWorkerThreadsOp> { 931 public: 932 using OpConversionPattern::OpConversionPattern; 933 934 LogicalResult 935 matchAndRewrite(RuntimeNumWorkerThreadsOp op, OpAdaptor adaptor, 936 ConversionPatternRewriter &rewriter) const override { 937 938 // Replace with a runtime API function call. 939 rewriter.replaceOpWithNewOp<func::CallOp>(op, kGetNumWorkerThreads, 940 rewriter.getIndexType()); 941 942 return success(); 943 } 944 }; 945 } // namespace 946 947 //===----------------------------------------------------------------------===// 948 // Async reference counting ops lowering (`async.runtime.add_ref` and 949 // `async.runtime.drop_ref` to the corresponding API calls). 950 //===----------------------------------------------------------------------===// 951 952 namespace { 953 template <typename RefCountingOp> 954 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> { 955 public: 956 explicit RefCountingOpLowering(const TypeConverter &converter, 957 MLIRContext *ctx, StringRef apiFunctionName) 958 : OpConversionPattern<RefCountingOp>(converter, ctx), 959 apiFunctionName(apiFunctionName) {} 960 961 LogicalResult 962 matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor, 963 ConversionPatternRewriter &rewriter) const override { 964 auto count = rewriter.create<arith::ConstantOp>( 965 op->getLoc(), rewriter.getI64Type(), 966 rewriter.getI64IntegerAttr(op.getCount())); 967 968 auto operand = adaptor.getOperand(); 969 rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), apiFunctionName, 970 ValueRange({operand, count})); 971 972 return success(); 973 } 974 975 private: 976 StringRef apiFunctionName; 977 }; 978 979 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> { 980 public: 981 explicit RuntimeAddRefOpLowering(const TypeConverter &converter, 982 MLIRContext *ctx) 983 : RefCountingOpLowering(converter, ctx, kAddRef) {} 984 }; 985 986 class RuntimeDropRefOpLowering 987 : public RefCountingOpLowering<RuntimeDropRefOp> { 988 public: 989 explicit RuntimeDropRefOpLowering(const TypeConverter &converter, 990 MLIRContext *ctx) 991 : RefCountingOpLowering(converter, ctx, kDropRef) {} 992 }; 993 } // namespace 994 995 //===----------------------------------------------------------------------===// 996 // Convert return operations that return async values from async regions. 997 //===----------------------------------------------------------------------===// 998 999 namespace { 1000 class ReturnOpOpConversion : public OpConversionPattern<func::ReturnOp> { 1001 public: 1002 using OpConversionPattern::OpConversionPattern; 1003 1004 LogicalResult 1005 matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, 1006 ConversionPatternRewriter &rewriter) const override { 1007 rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands()); 1008 return success(); 1009 } 1010 }; 1011 } // namespace 1012 1013 //===----------------------------------------------------------------------===// 1014 1015 namespace { 1016 struct ConvertAsyncToLLVMPass 1017 : public impl::ConvertAsyncToLLVMPassBase<ConvertAsyncToLLVMPass> { 1018 using Base::Base; 1019 1020 void runOnOperation() override; 1021 }; 1022 } // namespace 1023 1024 void ConvertAsyncToLLVMPass::runOnOperation() { 1025 ModuleOp module = getOperation(); 1026 MLIRContext *ctx = module->getContext(); 1027 1028 LowerToLLVMOptions options(ctx); 1029 1030 // Add declarations for most functions required by the coroutines lowering. 1031 // We delay adding the resume function until it's needed because it currently 1032 // fails to compile unless '-O0' is specified. 1033 addAsyncRuntimeApiDeclarations(module); 1034 1035 // Lower async.runtime and async.coro operations to Async Runtime API and 1036 // LLVM coroutine intrinsics. 1037 1038 // Convert async dialect types and operations to LLVM dialect. 1039 AsyncRuntimeTypeConverter converter(options); 1040 RewritePatternSet patterns(ctx); 1041 1042 // We use conversion to LLVM type to lower async.runtime load and store 1043 // operations. 1044 LLVMTypeConverter llvmConverter(ctx, options); 1045 llvmConverter.addConversion([&](Type type) { 1046 return AsyncRuntimeTypeConverter::convertAsyncTypes(type); 1047 }); 1048 1049 // Convert async types in function signatures and function calls. 1050 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns, 1051 converter); 1052 populateCallOpTypeConversionPattern(patterns, converter); 1053 1054 // Convert return operations inside async.execute regions. 1055 patterns.add<ReturnOpOpConversion>(converter, ctx); 1056 1057 // Lower async.runtime operations to the async runtime API calls. 1058 patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering, 1059 RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering, 1060 RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering, 1061 RuntimeAddToGroupOpLowering, RuntimeNumWorkerThreadsOpLowering, 1062 RuntimeAddRefOpLowering, RuntimeDropRefOpLowering>(converter, 1063 ctx); 1064 1065 // Lower async.runtime operations that rely on LLVM type converter to convert 1066 // from async value payload type to the LLVM type. 1067 patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering, 1068 RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter); 1069 1070 // Lower async coroutine operations to LLVM coroutine intrinsics. 1071 patterns 1072 .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion, 1073 CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>( 1074 converter, ctx); 1075 1076 ConversionTarget target(*ctx); 1077 target.addLegalOp<arith::ConstantOp, func::ConstantOp, 1078 UnrealizedConversionCastOp>(); 1079 target.addLegalDialect<LLVM::LLVMDialect>(); 1080 1081 // All operations from Async dialect must be lowered to the runtime API and 1082 // LLVM intrinsics calls. 1083 target.addIllegalDialect<AsyncDialect>(); 1084 1085 // Add dynamic legality constraints to apply conversions defined above. 1086 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 1087 return converter.isSignatureLegal(op.getFunctionType()); 1088 }); 1089 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) { 1090 return converter.isLegal(op.getOperandTypes()); 1091 }); 1092 target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) { 1093 return converter.isSignatureLegal(op.getCalleeType()); 1094 }); 1095 1096 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 1097 signalPassFailure(); 1098 } 1099 1100 //===----------------------------------------------------------------------===// 1101 // Patterns for structural type conversions for the Async dialect operations. 1102 //===----------------------------------------------------------------------===// 1103 1104 namespace { 1105 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> { 1106 public: 1107 using OpConversionPattern::OpConversionPattern; 1108 LogicalResult 1109 matchAndRewrite(ExecuteOp op, OpAdaptor adaptor, 1110 ConversionPatternRewriter &rewriter) const override { 1111 ExecuteOp newOp = 1112 cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation())); 1113 rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), 1114 newOp.getRegion().end()); 1115 1116 // Set operands and update block argument and result types. 1117 newOp->setOperands(adaptor.getOperands()); 1118 if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter))) 1119 return failure(); 1120 for (auto result : newOp.getResults()) 1121 result.setType(typeConverter->convertType(result.getType())); 1122 1123 rewriter.replaceOp(op, newOp.getResults()); 1124 return success(); 1125 } 1126 }; 1127 1128 // Dummy pattern to trigger the appropriate type conversion / materialization. 1129 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> { 1130 public: 1131 using OpConversionPattern::OpConversionPattern; 1132 LogicalResult 1133 matchAndRewrite(AwaitOp op, OpAdaptor adaptor, 1134 ConversionPatternRewriter &rewriter) const override { 1135 rewriter.replaceOpWithNewOp<AwaitOp>(op, adaptor.getOperands().front()); 1136 return success(); 1137 } 1138 }; 1139 1140 // Dummy pattern to trigger the appropriate type conversion / materialization. 1141 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> { 1142 public: 1143 using OpConversionPattern::OpConversionPattern; 1144 LogicalResult 1145 matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, 1146 ConversionPatternRewriter &rewriter) const override { 1147 rewriter.replaceOpWithNewOp<async::YieldOp>(op, adaptor.getOperands()); 1148 return success(); 1149 } 1150 }; 1151 } // namespace 1152 1153 void mlir::populateAsyncStructuralTypeConversionsAndLegality( 1154 TypeConverter &typeConverter, RewritePatternSet &patterns, 1155 ConversionTarget &target) { 1156 typeConverter.addConversion([&](TokenType type) { return type; }); 1157 typeConverter.addConversion([&](ValueType type) { 1158 Type converted = typeConverter.convertType(type.getValueType()); 1159 return converted ? ValueType::get(converted) : converted; 1160 }); 1161 1162 patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>( 1163 typeConverter, patterns.getContext()); 1164 1165 target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>( 1166 [&](Operation *op) { return typeConverter.isLegal(op); }); 1167 } 1168