//===- LowerGpuOpsToNVVMOps.cpp - MLIR GPU to NVVM lowering passes --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to generate NVVMIR operations for higher-level
// GPU operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"

#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTGPUOPSTONVVMOPS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

/// Convert gpu dialect shfl mode enum to the equivalent nvvm one.
static NVVM::ShflKind convertShflKind(gpu::ShuffleMode mode) {
  switch (mode) {
  case gpu::ShuffleMode::XOR:
    return NVVM::ShflKind::bfly;
  case gpu::ShuffleMode::UP:
    return NVVM::ShflKind::up;
  case gpu::ShuffleMode::DOWN:
    return NVVM::ShflKind::down;
  case gpu::ShuffleMode::IDX:
    return NVVM::ShflKind::idx;
  }
  llvm_unreachable("unknown shuffle mode");
}

static std::optional<NVVM::ReduxKind>
convertReduxKind(gpu::AllReduceOperation mode) {
  switch (mode) {
  case gpu::AllReduceOperation::ADD:
    return NVVM::ReduxKind::ADD;
  case gpu::AllReduceOperation::MUL:
    return std::nullopt;
  case gpu::AllReduceOperation::MINSI:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MINUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MINNUMF:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MAXSI:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::MAXUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MAXNUMF:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::AND:
    return NVVM::ReduxKind::AND;
  case gpu::AllReduceOperation::OR:
    return NVVM::ReduxKind::OR;
  case gpu::AllReduceOperation::XOR:
    return NVVM::ReduxKind::XOR;
  case gpu::AllReduceOperation::MINIMUMF:
  case gpu::AllReduceOperation::MAXIMUMF:
    return std::nullopt;
  }
  return std::nullopt;
}

/// This pattern lowers the gpu.subgroup_reduce op into the nvvm.redux op. The
/// op must be run by the entire subgroup, otherwise it is undefined behaviour.
struct GPUSubgroupReduceOpLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupReduceOp> {
  using ConvertOpToLLVMPattern<gpu::SubgroupReduceOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getClusterSize())
      return rewriter.notifyMatchFailure(
          op, "lowering for clustered reduce not implemented");

    if (!op.getUniform())
      return rewriter.notifyMatchFailure(
          op, "cannot be lowered to redux as the op must be run "
              "uniformly (entire subgroup).");
    if (!op.getValue().getType().isInteger(32))
      return rewriter.notifyMatchFailure(op, "unsupported data type");

    std::optional<NVVM::ReduxKind> mode = convertReduxKind(op.getOp());
    if (!mode.has_value())
      return rewriter.notifyMatchFailure(
          op, "unsupported reduction mode for redux");

    Location loc = op->getLoc();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    Value offset = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);

    auto reduxOp = rewriter.create<NVVM::ReduxOp>(loc, int32Type, op.getValue(),
                                                  mode.value(), offset);

    rewriter.replaceOp(op, reduxOp->getResult(0));
    return success();
  }
};

struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;

  /// Lowers a shuffle to the corresponding NVVM op.
  ///
  /// Convert the `width` argument into an activeMask (a bitmask which
  /// specifies which threads participate in the shuffle) and a maskAndClamp
  /// (specifying the highest lane which participates in the shuffle).
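  /// For the `up` shuffle mode, maskAndClamp is instead set to `32 - width`;
  /// the sketch below shows the other modes.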
  ///
  ///     %one = llvm.constant(1 : i32) : i32
  ///     %minus_one = llvm.constant(-1 : i32) : i32
  ///     %thirty_two = llvm.constant(32 : i32) : i32
  ///     %num_lanes = llvm.sub %thirty_two, %width : i32
  ///     %active_mask = llvm.lshr %minus_one, %num_lanes : i32
  ///     %mask_and_clamp = llvm.sub %width, %one : i32
  ///     %shfl = nvvm.shfl.sync.bfly %active_mask, %value, %offset,
  ///         %mask_and_clamp : !llvm<"{ float, i1 }">
  ///     %shfl_value = llvm.extractvalue %shfl[0] :
  ///         !llvm<"{ float, i1 }">
  ///     %shfl_pred = llvm.extractvalue %shfl[1] :
  ///         !llvm<"{ float, i1 }">
  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();

    auto valueTy = adaptor.getValue().getType();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    auto predTy = IntegerType::get(rewriter.getContext(), 1);

    Value one = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 1);
    Value minusOne = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);
    Value thirtyTwo = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 32);
    Value numLeadInactiveLane = rewriter.create<LLVM::SubOp>(
        loc, int32Type, thirtyTwo, adaptor.getWidth());
    // Bit mask of active lanes: `(-1) >> (32 - activeWidth)`.
    Value activeMask = rewriter.create<LLVM::LShrOp>(loc, int32Type, minusOne,
                                                     numLeadInactiveLane);
    Value maskAndClamp;
    if (op.getMode() == gpu::ShuffleMode::UP) {
      // Clamp lane: `32 - activeWidth`
      maskAndClamp = numLeadInactiveLane;
    } else {
      // Clamp lane: `activeWidth - 1`
      maskAndClamp =
          rewriter.create<LLVM::SubOp>(loc, int32Type, adaptor.getWidth(), one);
    }

    bool predIsUsed = !op->getResult(1).use_empty();
    UnitAttr returnValueAndIsValidAttr = nullptr;
    Type resultTy = valueTy;
    if (predIsUsed) {
      returnValueAndIsValidAttr = rewriter.getUnitAttr();
      resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                  {valueTy, predTy});
    }
    Value shfl = rewriter.create<NVVM::ShflOp>(
        loc, resultTy, activeMask, adaptor.getValue(), adaptor.getOffset(),
        maskAndClamp, convertShflKind(op.getMode()), returnValueAndIsValidAttr);
    if (predIsUsed) {
      Value shflValue = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 0);
      Value isActiveSrcLane =
          rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 1);
      rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
    } else {
      rewriter.replaceOp(op, {shfl, nullptr});
    }
    return success();
  }
};

struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
  using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    MLIRContext *context = rewriter.getContext();
    LLVM::ConstantRangeAttr bounds = nullptr;
    if (std::optional<APInt> upperBound = op.getUpperBound())
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/0, upperBound->getZExtValue());
    else
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize);
    Value newOp =
        rewriter.create<NVVM::LaneIdOp>(loc, rewriter.getI32Type(), bounds);
    // Truncate or extend the result depending on the index bitwidth specified
    // by the LLVMTypeConverter options.
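    // The i32 lane id is sign-extended when the index bitwidth is larger than
    // 32 and truncated when it is smaller.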
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    if (indexBitwidth > 32) {
      newOp = rewriter.create<LLVM::SExtOp>(
          loc, IntegerType::get(context, indexBitwidth), newOp);
    } else if (indexBitwidth < 32) {
      newOp = rewriter.create<LLVM::TruncOp>(
          loc, IntegerType::get(context, indexBitwidth), newOp);
    }
    rewriter.replaceOp(op, {newOp});
    return success();
  }
};

/// Lowering of cf.assert into a conditional __assertfail.
struct AssertOpToAssertfailLowering
    : public ConvertOpToLLVMPattern<cf::AssertOp> {
  using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = rewriter.getContext();
    Location loc = assertOp.getLoc();
    Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8));
    Type i32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    Type i64Type = typeConverter->convertType(rewriter.getIntegerType(64));
    Type ptrType = LLVM::LLVMPointerType::get(ctx);
    Type voidType = LLVM::LLVMVoidType::get(ctx);

    // Find or create __assertfail function declaration.
    auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>();
    auto assertfailType = LLVM::LLVMFunctionType::get(
        voidType, {ptrType, ptrType, i32Type, ptrType, i64Type});
    LLVM::LLVMFuncOp assertfailDecl = getOrDefineFunction(
        moduleOp, loc, rewriter, "__assertfail", assertfailType);
    assertfailDecl.setPassthroughAttr(
        ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn")));

    // Split blocks and insert conditional branch.
    // ^before:
    //   ...
    //   cf.cond_br %condition, ^after, ^assert
    // ^assert:
    //   cf.assert
    //   cf.br ^after
    // ^after:
    //   ...
    Block *beforeBlock = assertOp->getBlock();
    Block *assertBlock =
        rewriter.splitBlock(beforeBlock, assertOp->getIterator());
    Block *afterBlock =
        rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
    rewriter.setInsertionPointToEnd(beforeBlock);
    rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(), afterBlock,
                                      assertBlock);
    rewriter.setInsertionPointToEnd(assertBlock);
    rewriter.create<cf::BranchOp>(loc, afterBlock);

    // Continue cf.assert lowering.
    rewriter.setInsertionPoint(assertOp);

    // Populate file name, line number and function name from the location of
    // the AssertOp.
    StringRef fileName = "(unknown)";
    StringRef funcName = "(unknown)";
    int32_t fileLine = 0;
    while (auto callSiteLoc = dyn_cast<CallSiteLoc>(loc))
      loc = callSiteLoc.getCallee();
    if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
      fileName = fileLineColLoc.getFilename().strref();
      fileLine = fileLineColLoc.getStartLine();
    } else if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
      funcName = nameLoc.getName().strref();
      if (auto fileLineColLoc =
              dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
        fileName = fileLineColLoc.getFilename().strref();
        fileLine = fileLineColLoc.getStartLine();
      }
    }

    // Create constants.
    auto getGlobal = [&](LLVM::GlobalOp global) {
      // Get a pointer to the string's first element.
      Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
          loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
          global.getSymNameAttr());
      Value start =
          rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
                                       globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
      return start;
    };
    Value assertMessage = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_message_", assertOp.getMsg()));
    Value assertFile = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_file_", fileName));
    Value assertFunc = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_func_", funcName));
    Value assertLine =
        rewriter.create<LLVM::ConstantOp>(loc, i32Type, fileLine);
    Value c1 = rewriter.create<LLVM::ConstantOp>(loc, i64Type, 1);

    // Insert function call to __assertfail.
    SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
                                 assertFunc, c1};
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(assertOp, assertfailDecl,
                                              arguments);
    return success();
  }
};

/// Import the GPU Ops to NVVM Patterns.
#include "GPUToNVVM.cpp.inc"

/// A pass that replaces all occurrences of GPU device operations with their
/// corresponding NVVM equivalent.
///
/// This pass only handles device code and is not meant to be run on GPU host
/// code.
struct LowerGpuOpsToNVVMOpsPass
    : public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
  using Base::Base;

  void runOnOperation() override {
    gpu::GPUModuleOp m = getOperation();

    // Request C wrapper emission.
    for (auto func : m.getOps<func::FuncOp>()) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(&getContext()));
    }

    // Customize the bitwidth used for the device-side index computations.
    LowerToLLVMOptions options(
        m.getContext(),
        DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);
    options.useBarePtrCallConv = useBarePtrCallConv;

    // Apply in-dialect lowering first. These patterns replace ops that need
    // to be lowered in several steps, which a single conversion pass does not
    // support.
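    // (For example, gpu.all_reduce is expanded into gpu.shuffle-based code by
    // these patterns.)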
    {
      RewritePatternSet patterns(m.getContext());
      populateGpuRewritePatterns(patterns);
      if (failed(applyPatternsGreedily(m, std::move(patterns))))
        return signalPassFailure();
    }

    LLVMTypeConverter converter(m.getContext(), options);
    configureGpuToNVVMTypeConverter(converter);
    RewritePatternSet llvmPatterns(m.getContext());

    arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
    cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
    populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
    populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
    populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
    populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
    populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
    if (this->hasRedux)
      populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns);
    LLVMConversionTarget target(getContext());
    configureGpuToNVVMConversionLegality(target);
    if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
      signalPassFailure();
  }
};

} // namespace

void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
  target.addIllegalOp<func::FuncOp>();
  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
  target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
  target.addIllegalDialect<gpu::GPUDialect>();
  target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
                      LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FMAOp,
                      LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op,
                      LLVM::PowOp, LLVM::RoundEvenOp, LLVM::RoundOp,
                      LLVM::SinOp, LLVM::SqrtOp>();

  // TODO: Remove once we support replacing non-root ops.
  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
}

void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
  // NVVM uses alloca in the default address space to represent private
  // memory allocations, so drop private annotations. NVVM uses address
  // space 3 for shared memory and address space 1 for global memory.
  populateGpuMemorySpaceAttributeConversions(
      converter, [](gpu::AddressSpace space) -> unsigned {
        switch (space) {
        case gpu::AddressSpace::Global:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
        case gpu::AddressSpace::Workgroup:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kSharedMemorySpace);
        case gpu::AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
        return 0;
      });
  // Lowering for MMAMatrixType.
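  // (e.g. an f16 WMMA fragment is converted to an LLVM struct of vector<2xf16>
  // registers; see convertMMAToLLVMType).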
  converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
    return convertMMAToLLVMType(type);
  });
}

template <typename OpTy>
static void populateOpPatterns(const LLVMTypeConverter &converter,
                               RewritePatternSet &patterns, StringRef f32Func,
                               StringRef f64Func, StringRef f32ApproxFunc = "",
                               StringRef f16Func = "") {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
                                           f32ApproxFunc, f16Func);
}

template <typename OpTy>
static void populateIntOpPatterns(const LLVMTypeConverter &converter,
                                  RewritePatternSet &patterns,
                                  StringRef i32Func) {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, "", "", "", "", i32Func);
}

template <typename OpTy>
static void populateFloatIntOpPatterns(const LLVMTypeConverter &converter,
                                       RewritePatternSet &patterns,
                                       StringRef f32Func, StringRef f64Func) {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, "", "");
}

void mlir::populateGpuSubgroupReduceOpLoweringPattern(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<GPUSubgroupReduceOpLowering>(converter);
}

void mlir::populateGpuToNVVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  using gpu::index_lowering::IndexKind;
  using gpu::index_lowering::IntrType;
  populateWithGenerated(patterns);
  patterns.add<GPUPrintfOpToVPrintfLowering, AssertOpToAssertfailLowering>(
      converter);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
                                      NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
      converter, IndexKind::Block, IntrType::Id);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
                                      NVVM::BlockDimYOp, NVVM::BlockDimZOp>>(
      converter, IndexKind::Block, IntrType::Dim);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::ClusterIdOp, NVVM::ClusterIdXOp,
                                      NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>>(
      converter, IndexKind::Other, IntrType::Id);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterDimOp, NVVM::ClusterDimXOp, NVVM::ClusterDimYOp,
      NVVM::ClusterDimZOp>>(converter, IndexKind::Other, IntrType::Dim);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterBlockIdOp, NVVM::BlockInClusterIdXOp,
      NVVM::BlockInClusterIdYOp, NVVM::BlockInClusterIdZOp>>(
      converter, IndexKind::Other, IntrType::Id);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterDimBlocksOp, NVVM::ClusterDimBlocksXOp,
      NVVM::ClusterDimBlocksYOp, NVVM::ClusterDimBlocksZOp>>(
      converter, IndexKind::Other, IntrType::Dim);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::BlockIdOp, NVVM::BlockIdXOp, NVVM::BlockIdYOp, NVVM::BlockIdZOp>>(
      converter, IndexKind::Grid, IntrType::Id);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::GridDimOp, NVVM::GridDimXOp, NVVM::GridDimYOp, NVVM::GridDimZOp>>(
      converter, IndexKind::Grid, IntrType::Dim);
  patterns.add<GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
      converter);

  patterns.add<GPUDynamicSharedMemoryOpLowering>(
      converter, NVVM::kSharedMemoryAlignmentBit);

  // Explicitly drop memory space when lowering private memory
  // attributions since NVVM models it as `alloca`s in the default
  // memory space and does not support `alloca`s with addrspace(5).
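  // GPUFuncOpLowering also marks kernels with the NVVM kernel attribute and
  // forwards known block sizes through the NVVM maxntid attribute.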
  patterns.add<GPUFuncOpLowering>(
      converter,
      GPUFuncOpLoweringOptions{
          /*allocaAddrSpace=*/0,
          /*workgroupAddrSpace=*/
          static_cast<unsigned>(NVVM::NVVMMemorySpace::kSharedMemorySpace),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getKernelFuncAttrName()),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getMaxntidAttrName())});

  populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
                                    "__nv_fmod");
  populateIntOpPatterns<math::AbsIOp>(converter, patterns, "__nv_abs");
  populateOpPatterns<math::AbsFOp>(converter, patterns, "__nv_fabsf",
                                   "__nv_fabs");
  populateOpPatterns<math::AcosOp>(converter, patterns, "__nv_acosf",
                                   "__nv_acos");
  populateOpPatterns<math::AcoshOp>(converter, patterns, "__nv_acoshf",
                                    "__nv_acosh");
  populateOpPatterns<math::AsinOp>(converter, patterns, "__nv_asinf",
                                   "__nv_asin");
  populateOpPatterns<math::AsinhOp>(converter, patterns, "__nv_asinhf",
                                    "__nv_asinh");
  populateOpPatterns<math::AtanOp>(converter, patterns, "__nv_atanf",
                                   "__nv_atan");
  populateOpPatterns<math::Atan2Op>(converter, patterns, "__nv_atan2f",
                                    "__nv_atan2");
  populateOpPatterns<math::AtanhOp>(converter, patterns, "__nv_atanhf",
                                    "__nv_atanh");
  populateOpPatterns<math::CbrtOp>(converter, patterns, "__nv_cbrtf",
                                   "__nv_cbrt");
  populateOpPatterns<math::CeilOp>(converter, patterns, "__nv_ceilf",
                                   "__nv_ceil");
  populateOpPatterns<math::CopySignOp>(converter, patterns, "__nv_copysignf",
                                       "__nv_copysign");
  populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos",
                                  "__nv_fast_cosf");
  populateOpPatterns<math::CoshOp>(converter, patterns, "__nv_coshf",
                                   "__nv_cosh");
  populateOpPatterns<math::ErfOp>(converter, patterns, "__nv_erff", "__nv_erf");
  populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp",
                                  "__nv_fast_expf");
  populateOpPatterns<math::Exp2Op>(converter, patterns, "__nv_exp2f",
                                   "__nv_exp2");
  populateOpPatterns<math::ExpM1Op>(converter, patterns, "__nv_expm1f",
                                    "__nv_expm1");
  populateOpPatterns<math::FloorOp>(converter, patterns, "__nv_floorf",
                                    "__nv_floor");
  populateOpPatterns<math::FmaOp>(converter, patterns, "__nv_fmaf", "__nv_fma");
  populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log",
                                  "__nv_fast_logf");
  populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
                                    "__nv_log10", "__nv_fast_log10f");
  populateOpPatterns<math::Log1pOp>(converter, patterns, "__nv_log1pf",
                                    "__nv_log1p");
  populateOpPatterns<math::Log2Op>(converter, patterns, "__nv_log2f",
                                   "__nv_log2", "__nv_fast_log2f");
  populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf", "__nv_pow",
                                   "__nv_fast_powf");
  populateFloatIntOpPatterns<math::FPowIOp>(converter, patterns, "__nv_powif",
                                            "__nv_powi");
  populateOpPatterns<math::RoundOp>(converter, patterns, "__nv_roundf",
                                    "__nv_round");
  populateOpPatterns<math::RoundEvenOp>(converter, patterns, "__nv_rintf",
                                        "__nv_rint");
  populateOpPatterns<math::RsqrtOp>(converter, patterns, "__nv_rsqrtf",
                                    "__nv_rsqrt");
  populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin",
                                  "__nv_fast_sinf");
  populateOpPatterns<math::SinhOp>(converter, patterns, "__nv_sinhf",
                                   "__nv_sinh");
  populateOpPatterns<math::SqrtOp>(converter, patterns, "__nv_sqrtf",
                                   "__nv_sqrt");
  populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan",
                                  "__nv_fast_tanf");
  populateOpPatterns<math::TanhOp>(converter, patterns, "__nv_tanhf",
                                   "__nv_tanh");
}

//===----------------------------------------------------------------------===//
// NVVMTargetAttr convert to LLVM attr interface
//===----------------------------------------------------------------------===//

namespace {
struct NVVMTargetConvertToLLVMAttrInterface
    : public ConvertToLLVMAttrInterface::ExternalModel<
          NVVMTargetConvertToLLVMAttrInterface, NVVM::NVVMTargetAttr> {
  /// Configure GPU to NVVM.
  void populateConvertToLLVMConversionPatterns(
      Attribute attr, ConversionTarget &target,
      LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const;
};
} // namespace

void NVVMTargetConvertToLLVMAttrInterface::
    populateConvertToLLVMConversionPatterns(Attribute attr,
                                            ConversionTarget &target,
                                            LLVMTypeConverter &typeConverter,
                                            RewritePatternSet &patterns) const {
  configureGpuToNVVMConversionLegality(target);
  configureGpuToNVVMTypeConverter(typeConverter);
  populateGpuToNVVMConversionPatterns(typeConverter, patterns);
}

void mlir::NVVM::registerConvertGpuToNVVMInterface(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, NVVMDialect *dialect) {
    NVVMTargetAttr::attachInterface<NVVMTargetConvertToLLVMAttrInterface>(*ctx);
  });
}