1//===-- NVVMOps.td - NVVM IR dialect op definition file ----*- tablegen -*-===// 2// 3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4// See https://llvm.org/LICENSE.txt for license information. 5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6// 7//===----------------------------------------------------------------------===// 8// 9// This is the NVVM IR operation definition file. 10// 11//===----------------------------------------------------------------------===// 12 13#ifndef NVVMIR_OPS 14#define NVVMIR_OPS 15 16include "mlir/IR/EnumAttr.td" 17include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td" 18include "mlir/Dialect/LLVMIR/LLVMOpBase.td" 19include "mlir/Interfaces/SideEffectInterfaces.td" 20include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td" 21include "mlir/Interfaces/InferIntRangeInterface.td" 22 23def LLVM_PointerGeneric : LLVM_PointerInAddressSpace<0>; 24def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>; 25def LLVM_PointerShared : LLVM_PointerInAddressSpace<3>; 26 27//===----------------------------------------------------------------------===// 28// NVVM dialect definitions 29//===----------------------------------------------------------------------===// 30 31def NVVM_Dialect : Dialect { 32 let name = "nvvm"; 33 let cppNamespace = "::mlir::NVVM"; 34 let dependentDialects = ["LLVM::LLVMDialect"]; 35 let hasOperationAttrVerify = 1; 36 37 let extraClassDeclaration = [{ 38 /// Get the name of the attribute used to annotate external kernel 39 /// functions. 40 static StringRef getKernelFuncAttrName() { return "nvvm.kernel"; } 41 /// Get the name of the attribute used to annotate max threads required 42 /// per CTA for kernel functions. 43 static StringRef getMaxntidAttrName() { return "nvvm.maxntid"; } 44 /// Get the name of the metadata names for each dimension 45 static StringRef getMaxntidXName() { return "maxntidx"; } 46 static StringRef getMaxntidYName() { return "maxntidy"; } 47 static StringRef getMaxntidZName() { return "maxntidz"; } 48 49 /// Get the name of the attribute used to annotate exact threads required 50 /// per CTA for kernel functions. 51 static StringRef getReqntidAttrName() { return "nvvm.reqntid"; } 52 /// Get the name of the metadata names for each dimension 53 static StringRef getReqntidXName() { return "reqntidx"; } 54 static StringRef getReqntidYName() { return "reqntidy"; } 55 static StringRef getReqntidZName() { return "reqntidz"; } 56 57 /// Get the name of the attribute used to annotate exact CTAs required 58 /// per cluster for kernel functions. 59 static StringRef getClusterDimAttrName() { return "nvvm.cluster_dim"; } 60 /// Get the name of the metadata names for each dimension 61 static StringRef getClusterDimXName() { return "cluster_dim_x"; } 62 static StringRef getClusterDimYName() { return "cluster_dim_y"; } 63 static StringRef getClusterDimZName() { return "cluster_dim_z"; } 64 65 /// Get the name of the attribute used to annotate maximum number of 66 /// CTAs per cluster for kernel functions. 67 static StringRef getClusterMaxBlocksAttrName() { return "nvvm.cluster_max_blocks"; } 68 69 /// Get the name of the attribute used to annotate min CTA required 70 /// per SM for kernel functions. 71 static StringRef getMinctasmAttrName() { return "nvvm.minctasm"; } 72 73 /// Get the name of the attribute used to annotate max number of 74 /// registers that can be allocated per thread. 75 static StringRef getMaxnregAttrName() { return "nvvm.maxnreg"; } 76 77 /// Get the name of the attribute used to annotate kernel arguments that 78 /// are grid constants. 79 static StringRef getGridConstantAttrName() { return "nvvm.grid_constant"; } 80 81 /// Verify an attribute from this dialect on the argument at 'argIndex' for 82 /// the region at 'regionIndex' on the given operation. Returns failure if 83 /// the verification failed, success otherwise. This hook may optionally be 84 /// invoked from any operation containing a region. 85 LogicalResult verifyRegionArgAttribute(Operation *op, 86 unsigned regionIndex, 87 unsigned argIndex, 88 NamedAttribute argAttr) override; 89 }]; 90 91 let useDefaultAttributePrinterParser = 1; 92} 93 94//===----------------------------------------------------------------------===// 95// NVVM op definitions 96//===----------------------------------------------------------------------===// 97 98class NVVM_Op<string mnemonic, list<Trait> traits = []> : 99 LLVM_OpBase<NVVM_Dialect, mnemonic, traits> { 100} 101 102/// Base class that defines BasicPtxBuilderOpInterface. 103class NVVM_PTXBuilder_Op<string mnemonic, 104 list<Trait> traits = [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]> : 105 LLVM_OpBase<NVVM_Dialect, mnemonic, traits> { 106} 107 108//===----------------------------------------------------------------------===// 109// NVVM attribute definitions 110//===----------------------------------------------------------------------===// 111 112class NVVM_Attr<string attrName, string attrMnemonic, list<Trait> traits = []> 113 : AttrDef<NVVM_Dialect, attrName, traits> { 114 let mnemonic = attrMnemonic; 115} 116 117//===----------------------------------------------------------------------===// 118// NVVM intrinsic operations 119//===----------------------------------------------------------------------===// 120 121class NVVM_IntrOp<string mnem, list<Trait> traits = [], 122 int numResults = 0> 123 : LLVM_IntrOpBase<NVVM_Dialect, mnem, "nvvm_" # !subst(".", "_", mnem), 124 /*list<int> overloadedResults=*/[], 125 /*list<int> overloadedOperands=*/[], 126 traits, numResults>; 127 128//===----------------------------------------------------------------------===// 129// NVVM special register op definitions 130//===----------------------------------------------------------------------===// 131 132class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> : 133 NVVM_IntrOp<mnemonic, !listconcat(traits, [Pure]), 1> { 134 let arguments = (ins); 135 let assemblyFormat = "attr-dict `:` type($res)"; 136} 137 138class NVVM_SpecialRangeableRegisterOp<string mnemonic> : 139 NVVM_SpecialRegisterOp<mnemonic, [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> { 140 let arguments = (ins OptionalAttr<LLVM_ConstantRangeAttr>:$range); 141 let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)"; 142 let llvmBuilder = baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda; 143 let mlirBuilder = baseMlirBuilder # importRangeRetAttrCode # baseMlirBuilderCoda; 144 145 // Backwards-compatibility builder for an unspecified range. 146 let builders = [ 147 OpBuilder<(ins "Type":$resultType), [{ 148 build($_builder, $_state, resultType, ::mlir::LLVM::ConstantRangeAttr{}); 149 }]> 150 ]; 151 152 // Define this method for the InferIntRangeInterface. 153 let extraClassDefinition = [{ 154 // Infer the result ranges based on the range attribute. 155 void $cppClass::inferResultRanges( 156 ArrayRef<::mlir::ConstantIntRanges> argRanges, 157 SetIntRangeFn setResultRanges) { 158 nvvmInferResultRanges(getOperation(), getResult(), argRanges, setResultRanges); 159 } 160 }]; 161 162} 163 164//===----------------------------------------------------------------------===// 165// Lane, Warp, SM, Grid index and range 166def NVVM_LaneIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.laneid">; 167def NVVM_WarpSizeOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpsize">; 168def NVVM_WarpIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpid">; 169def NVVM_WarpDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nwarpid">; 170def NVVM_SmIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.smid">; 171def NVVM_SmDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nsmid">; 172def NVVM_GridIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.gridid">; 173 174//===----------------------------------------------------------------------===// 175// Lane Mask Comparison Ops 176def NVVM_LaneMaskEqOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.eq">; 177def NVVM_LaneMaskLeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.le">; 178def NVVM_LaneMaskLtOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.lt">; 179def NVVM_LaneMaskGeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.ge">; 180def NVVM_LaneMaskGtOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.gt">; 181 182//===----------------------------------------------------------------------===// 183// Thread index and range 184def NVVM_ThreadIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.tid.x">; 185def NVVM_ThreadIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.tid.y">; 186def NVVM_ThreadIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.tid.z">; 187def NVVM_BlockDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ntid.x">; 188def NVVM_BlockDimYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ntid.y">; 189def NVVM_BlockDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ntid.z">; 190 191//===----------------------------------------------------------------------===// 192// Block index and range 193def NVVM_BlockIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.x">; 194def NVVM_BlockIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.y">; 195def NVVM_BlockIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.z">; 196def NVVM_GridDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.x">; 197def NVVM_GridDimYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.y">; 198def NVVM_GridDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.z">; 199 200//===----------------------------------------------------------------------===// 201// CTA Cluster index and range 202def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x">; 203def NVVM_ClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.y">; 204def NVVM_ClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.z">; 205def NVVM_ClusterDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.x">; 206def NVVM_ClusterDimYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.y">; 207def NVVM_ClusterDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.z">; 208 209 210//===----------------------------------------------------------------------===// 211// CTA index and range within Cluster 212def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x">; 213def NVVM_BlockInClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y">; 214def NVVM_BlockInClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z">; 215def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x">; 216def NVVM_ClusterDimBlocksYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.y">; 217def NVVM_ClusterDimBlocksZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.z">; 218 219//===----------------------------------------------------------------------===// 220// CTA index and across Cluster dimensions 221def NVVM_ClusterId : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctarank">; 222def NVVM_ClusterDim : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctarank">; 223 224//===----------------------------------------------------------------------===// 225// Clock registers 226def NVVM_ClockOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.clock">; 227def NVVM_Clock64Op : NVVM_SpecialRegisterOp<"read.ptx.sreg.clock64">; 228def NVVM_GlobalTimerOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.globaltimer">; 229 230//===----------------------------------------------------------------------===// 231// envreg registers 232foreach index = !range(0, 32) in { 233 def NVVM_EnvReg # index # Op : NVVM_SpecialRegisterOp<"read.ptx.sreg.envreg" # index>; 234} 235 236//===----------------------------------------------------------------------===// 237// NVVM approximate op definitions 238//===----------------------------------------------------------------------===// 239 240def NVVM_RcpApproxFtzF32Op : NVVM_IntrOp<"rcp.approx.ftz.f", [Pure], 1> { 241 let arguments = (ins F32:$arg); 242 let results = (outs F32:$res); 243 let assemblyFormat = "$arg attr-dict `:` type($res)"; 244} 245 246//===----------------------------------------------------------------------===// 247// NVVM redux op definitions 248//===----------------------------------------------------------------------===// 249 250def ReduxKindNone : I32EnumAttrCase<"NONE", 0, "none">; 251def ReduxKindAdd : I32EnumAttrCase<"ADD", 1, "add">; 252def ReduxKindAnd : I32EnumAttrCase<"AND", 2, "and">; 253def ReduxKindMax : I32EnumAttrCase<"MAX", 3, "max">; 254def ReduxKindMin : I32EnumAttrCase<"MIN", 4, "min">; 255def ReduxKindOr : I32EnumAttrCase<"OR", 5, "or">; 256def ReduxKindUmax : I32EnumAttrCase<"UMAX", 6, "umax">; 257def ReduxKindUmin : I32EnumAttrCase<"UMIN", 7, "umin">; 258def ReduxKindXor : I32EnumAttrCase<"XOR", 8, "xor">; 259 260/// Enum attribute of the different kinds. 261def ReduxKind : I32EnumAttr<"ReduxKind", "NVVM redux kind", 262 [ReduxKindAdd, ReduxKindAnd, ReduxKindMax, ReduxKindMin, ReduxKindOr, 263 ReduxKindUmax, ReduxKindUmin, ReduxKindXor]> { 264 let genSpecializedAttr = 0; 265 let cppNamespace = "::mlir::NVVM"; 266} 267 268def ReduxKindAttr : EnumAttr<NVVM_Dialect, ReduxKind, "redux_kind">; 269 270def NVVM_ReduxOp : 271 NVVM_Op<"redux.sync">, 272 Results<(outs LLVM_Type:$res)>, 273 Arguments<(ins LLVM_Type:$val, 274 ReduxKindAttr:$kind, 275 I32:$mask_and_clamp)> { 276 string llvmBuilder = [{ 277 auto intId = getReduxIntrinsicId($_resultType, $kind); 278 $res = createIntrinsicCall(builder, intId, {$val, $mask_and_clamp}); 279 }]; 280 let assemblyFormat = [{ 281 $kind $val `,` $mask_and_clamp attr-dict `:` type($val) `->` type($res) 282 }]; 283} 284 285//===----------------------------------------------------------------------===// 286// NVVM Split arrive/wait barrier 287//===----------------------------------------------------------------------===// 288 289/// mbarrier.init instruction with generic pointer type 290def NVVM_MBarrierInitOp : NVVM_PTXBuilder_Op<"mbarrier.init">, 291 Arguments<(ins LLVM_AnyPointer:$addr, I32:$count, PtxPredicate:$predicate)> { 292 string llvmBuilder = [{ 293 createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init, {$addr, $count}); 294 }]; 295 let assemblyFormat = "$addr `,` $count (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)"; 296 let extraClassDeclaration = [{ 297 bool hasIntrinsic() { if(getPredicate()) return false; return true; } 298 }]; 299 let extraClassDefinition = [{ 300 std::string $cppClass::getPtx() { return std::string("mbarrier.init.b64 [%0], %1;"); } 301 }]; 302} 303 304/// mbarrier.init instruction with shared pointer type 305def NVVM_MBarrierInitSharedOp : NVVM_PTXBuilder_Op<"mbarrier.init.shared">, 306 Arguments<(ins LLVM_PointerShared:$addr, I32:$count, PtxPredicate:$predicate)> { 307 string llvmBuilder = [{ 308 createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init_shared, {$addr, $count}); 309 }]; 310 let assemblyFormat = "$addr `,` $count (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)"; 311 let extraClassDeclaration = "bool hasIntrinsic() { return !getPredicate(); }"; 312 let extraClassDefinition = [{ 313 std::string $cppClass::getPtx() { return std::string("mbarrier.init.shared.b64 [%0], %1;"); } 314 }]; 315} 316 317def NVVM_MBarrierInvalOp : NVVM_Op<"mbarrier.inval">, 318 Arguments<(ins LLVM_AnyPointer:$addr)> { 319 string llvmBuilder = [{ 320 createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_inval, {$addr}); 321 }]; 322 let assemblyFormat = "$addr attr-dict `:` type(operands)"; 323} 324 325def NVVM_MBarrierInvalSharedOp : NVVM_Op<"mbarrier.inval.shared">, 326 Arguments<(ins LLVM_PointerShared:$addr)> { 327 string llvmBuilder = [{ 328 createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_inval_shared, {$addr}); 329 }]; 330 let assemblyFormat = "$addr attr-dict `:` type(operands)"; 331} 332 333def NVVM_MBarrierArriveOp : NVVM_Op<"mbarrier.arrive">, 334 Results<(outs LLVM_Type:$res)>, 335 Arguments<(ins LLVM_AnyPointer:$addr)> { 336 string llvmBuilder = [{ 337 $res = createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_arrive, {$addr}); 338 }]; 339 let assemblyFormat = "$addr attr-dict `:` type($addr) `->` type($res)"; 340} 341 342def NVVM_MBarrierArriveSharedOp : NVVM_Op<"mbarrier.arrive.shared">, 343 Results<(outs LLVM_Type:$res)>, 344 Arguments<(ins LLVM_PointerShared:$addr)> { 345 string llvmBuilder = [{ 346 $res = createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_arrive_shared, {$addr}); 347 }]; 348 let assemblyFormat = "$addr attr-dict `:` qualified(type($addr)) `->` type($res)"; 349} 350 351def NVVM_MBarrierArriveNocompleteOp : NVVM_Op<"mbarrier.arrive.nocomplete">, 352 Results<(outs LLVM_Type:$res)>, 353 Arguments<(ins LLVM_AnyPointer:$addr, I32:$count)> { 354 string llvmBuilder = [{ 355 $res = createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete, {$addr, $count}); 356 }]; 357 let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands) `->` type($res)"; 358} 359 360def NVVM_MBarrierArriveNocompleteSharedOp : NVVM_Op<"mbarrier.arrive.nocomplete.shared">, 361 Results<(outs LLVM_Type:$res)>, 362 Arguments<(ins LLVM_PointerShared:$addr, I32:$count)> { 363 string llvmBuilder = [{ 364 $res = createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete_shared, {$addr, $count}); 365 }]; 366 let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands) `->` type($res)"; 367} 368 369def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx">, 370 Arguments<(ins LLVM_AnyPointer:$addr, I32:$txcount, PtxPredicate:$predicate)> { 371 let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)"; 372 let extraClassDefinition = [{ 373 std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;"); } 374 }]; 375} 376 377def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx.shared">, 378 Arguments<(ins LLVM_PointerShared:$addr, I32:$txcount, PtxPredicate:$predicate)> { 379 let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)"; 380 let extraClassDefinition = [{ 381 std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"); } 382 }]; 383} 384 385def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity">, 386 Arguments<(ins LLVM_AnyPointer:$addr, I32:$phase, I32:$ticks)> { 387 let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)"; 388 let extraClassDefinition = [{ 389 std::string $cppClass::getPtx() { 390 return std::string( 391 "{\n\t" 392 ".reg .pred P1; \n\t" 393 "LAB_WAIT: \n\t" 394 "mbarrier.try_wait.parity.b64 P1, [%0], %1, %2; \n\t" 395 "@P1 bra.uni DONE; \n\t" 396 "bra.uni LAB_WAIT; \n\t" 397 "DONE: \n\t" 398 "}" 399 ); 400 } 401 }]; 402} 403 404def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity.shared">, 405 Arguments<(ins LLVM_PointerShared:$addr, I32:$phase, I32:$ticks)> { 406 let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)"; 407 let extraClassDefinition = [{ 408 std::string $cppClass::getPtx() { 409 return std::string( 410 "{\n\t" 411 ".reg .pred P1; \n\t" 412 "LAB_WAIT: \n\t" 413 "mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t" 414 "@P1 bra.uni DONE; \n\t" 415 "bra.uni LAB_WAIT; \n\t" 416 "DONE: \n\t" 417 "}" 418 ); 419 } 420 }]; 421} 422 423def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait">, 424 Results<(outs LLVM_Type:$res)>, 425 Arguments<(ins LLVM_AnyPointer:$addr, LLVM_Type:$state)> { 426 string llvmBuilder = [{ 427 $res = createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_test_wait, {$addr, $state}); 428 }]; 429 let assemblyFormat = "$addr `,` $state attr-dict `:` type(operands) `->` type($res)"; 430} 431 432def NVVM_MBarrierTestWaitSharedOp : NVVM_Op<"mbarrier.test.wait.shared">, 433 Results<(outs LLVM_Type:$res)>, 434 Arguments<(ins LLVM_PointerShared:$addr, LLVM_Type:$state)> { 435 string llvmBuilder = [{ 436 $res = createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_test_wait_shared, {$addr, $state}); 437 }]; 438 let assemblyFormat = "$addr `,` $state attr-dict `:` type(operands) `->` type($res)"; 439} 440 441//===----------------------------------------------------------------------===// 442// NVVM synchronization op definitions 443//===----------------------------------------------------------------------===// 444 445def NVVM_Barrier0Op : NVVM_IntrOp<"barrier0"> { 446 let assemblyFormat = "attr-dict"; 447} 448 449def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> { 450 let arguments = (ins 451 Optional<I32>:$barrierId, 452 Optional<I32>:$numberOfThreads); 453 string llvmBuilder = [{ 454 if ($numberOfThreads && $barrierId) { 455 createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier, 456 {$barrierId, $numberOfThreads}); 457 } else if($barrierId) { 458 createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier_n, 459 {$barrierId}); 460 } else { 461 createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier0); 462 } 463 }]; 464 let hasVerifier = 1; 465 let assemblyFormat = "(`id` `=` $barrierId^)? (`number_of_threads` `=` $numberOfThreads^)? attr-dict"; 466} 467 468def NVVM_BarrierArriveOp : NVVM_PTXBuilder_Op<"barrier.arrive"> 469{ 470 let arguments = (ins Optional<I32>:$barrierId, I32:$numberOfThreads); 471 472 let description = [{ 473 Thread that executes this op announces their arrival at the barrier with 474 given id and continue their execution. 475 476 The default barrier id is 0 that is similar to `nvvm.barrier` Op. When 477 `barrierId` is not present, the default barrier id is used. 478 479 [For more information, see PTX ISA] 480 (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar) 481 }]; 482 483 let assemblyFormat = "(`id` `=` $barrierId^)? `number_of_threads` `=` $numberOfThreads attr-dict"; 484 485 let extraClassDefinition = [{ 486 std::string $cppClass::getPtx() { 487 std::string ptx = "bar.arrive "; 488 if (getBarrierId()) { ptx += "%0, %1;"; } 489 else { ptx += "0, %0;"; } 490 return ptx; 491 } 492 }]; 493} 494 495def NVVM_ClusterArriveOp : NVVM_Op<"cluster.arrive"> { 496 let arguments = (ins OptionalAttr<UnitAttr>:$aligned); 497 498 let summary = "Cluster Barrier Arrive Op"; 499 let description = [{ 500 The `cluster.arrive` can be used by the threads within the cluster for synchronization and 501 communication. The `cluster.arrive` instruction marks the warps' arrival at the barrier 502 without causing the executing thread to wait for other participating threads. 503 504 The `aligned` attribute, when provided, generates the .aligned version of the PTX instruction. 505 506 [For more information, see PTX ISA] 507 (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-barrier-cluster) 508 }]; 509 510 string llvmBuilder = [{ 511 if ($aligned) 512 createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier_cluster_arrive_aligned); 513 else 514 createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier_cluster_arrive); 515 }]; 516 let assemblyFormat = "attr-dict"; 517} 518 519def NVVM_ClusterArriveRelaxedOp : NVVM_Op<"cluster.arrive.relaxed"> { 520 let arguments = (ins OptionalAttr<UnitAttr>:$aligned); 521 522 let summary = "Cluster Barrier Relaxed Arrive Op"; 523 let description = [{ 524 The `cluster.arrive` can be used by the threads within the cluster for synchronization and 525 communication. The `cluster.arrive` instruction marks the warps' arrival at the barrier 526 without causing the executing thread to wait for other participating threads. 527 528 The `aligned` attribute, when provided, generates the .aligned version of the PTX instruction. 529 The .relaxed qualifier on `cluster.arrive` specifies that there are no memory 530 ordering and visibility guarantees provided for the memory accesses performed prior to 531 `cluster.arrive`. 532 533 [For more information, see PTX ISA] 534 (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-barrier-cluster) 535 }]; 536 537 string llvmBuilder = [{ 538 if ($aligned) 539 createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier_cluster_arrive_relaxed_aligned); 540 else 541 createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier_cluster_arrive_relaxed); 542 }]; 543 let assemblyFormat = "attr-dict"; 544} 545 546def NVVM_ClusterWaitOp : NVVM_Op<"cluster.wait"> { 547 let arguments = (ins OptionalAttr<UnitAttr>:$aligned); 548 549 let summary = "Cluster Barrier Wait Op"; 550 let description = [{ 551 The `cluster.wait` causes the executing thread to wait for all non-exited threads 552 of the cluster to perform `cluster.arrive`. The `aligned` attribute, when provided, 553 generates the .aligned version of the PTX instruction. 554 555 [For more information, see PTX ISA] 556 (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-barrier-cluster) 557 }]; 558 559 string llvmBuilder = [{ 560 if ($aligned) 561 createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier_cluster_wait_aligned); 562 else 563 createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier_cluster_wait); 564 }]; 565 let assemblyFormat = "attr-dict"; 566} 567 568def NVVM_FenceScClusterOp : NVVM_Op<"fence.sc.cluster"> { 569 string llvmBuilder = [{ 570 createIntrinsicCall(builder, llvm::Intrinsic::nvvm_fence_sc_cluster); 571 }]; 572 let assemblyFormat = "attr-dict"; 573} 574 575def SharedSpaceCTA : I32EnumAttrCase<"shared_cta", 0, "cta">; 576def SharedSpaceCluster : I32EnumAttrCase<"shared_cluster", 1, "cluster">; 577def SharedSpace : I32EnumAttr<"SharedSpace", "Shared memory space", 578 [SharedSpaceCTA, SharedSpaceCluster]> { 579 let genSpecializedAttr = 0; 580 let cppNamespace = "::mlir::NVVM"; 581} 582def SharedSpaceAttr : EnumAttr<NVVM_Dialect, SharedSpace, "shared_space"> { 583 let assemblyFormat = "`<` $value `>`"; 584} 585 586def ProxyAlias : I32EnumAttrCase<"alias", 0, "alias">; 587def ProxyAsync : I32EnumAttrCase<"async", 1, "async">; 588def ProxyAsyncGlobal : I32EnumAttrCase<"async_global", 2, "async.global">; 589def ProxyAsyncShared : I32EnumAttrCase<"async_shared", 3, "async.shared">; 590def ProxyTensorMap : I32EnumAttrCase<"TENSORMAP", 4, "tensormap">; 591def ProxyGeneric : I32EnumAttrCase<"GENERIC", 5, "generic">; 592def ProxyKind : I32EnumAttr<"ProxyKind", "Proxy kind", 593 [ProxyAlias, ProxyAsync, ProxyAsyncGlobal, ProxyAsyncShared, ProxyTensorMap, ProxyGeneric]> { 594 let genSpecializedAttr = 0; 595 let cppNamespace = "::mlir::NVVM"; 596} 597 598def ProxyKindAttr : EnumAttr<NVVM_Dialect, ProxyKind, "proxy_kind"> { 599 let assemblyFormat = "`<` $value `>`"; 600} 601 602def NVVM_FenceProxyOp : NVVM_PTXBuilder_Op<"fence.proxy">, 603 Arguments<(ins ProxyKindAttr:$kind, 604 OptionalAttr<SharedSpaceAttr>:$space)> { 605 let description = [{ 606 Fence operation with proxy to establish an ordering between memory accesses 607 that may happen through different proxies. 608 [For more information, see PTX ISA] 609 (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-membar) 610 }]; 611 612 let assemblyFormat = "attr-dict"; 613 let extraClassDefinition = [{ 614 std::string $cppClass::getPtx() { 615 std::string ptx = "fence.proxy."; 616 ptx += stringifyProxyKind(getKind()); 617 if(getKind() == NVVM::ProxyKind::async_shared) 618 { ptx += "::"; ptx += stringifySharedSpace(getSpace().value()); } 619 ptx += ";"; 620 return ptx; 621 } 622 }]; 623 let hasVerifier = 1; 624} 625 626// Attrs describing the scope of the Memory Operation 627def MemScopeKindCTA : I32EnumAttrCase<"CTA", 0, "cta">; 628def MemScopeKindCluster : I32EnumAttrCase<"CLUSTER", 1, "cluster">; 629def MemScopeKindGPU : I32EnumAttrCase<"GPU", 2, "gpu">; 630def MemScopeKindSYS : I32EnumAttrCase<"SYS", 3, "sys">; 631 632def MemScopeKind : I32EnumAttr<"MemScopeKind", "NVVM Memory Scope kind", 633 [MemScopeKindCTA, MemScopeKindCluster, MemScopeKindGPU, MemScopeKindSYS]> { 634 let genSpecializedAttr = 0; 635 let cppNamespace = "::mlir::NVVM"; 636} 637def MemScopeKindAttr : EnumAttr<NVVM_Dialect, MemScopeKind, "mem_scope"> { 638 let assemblyFormat = "`<` $value `>`"; 639} 640 641def NVVM_FenceProxyAcquireOp : NVVM_Op<"fence.proxy.acquire">, 642 Arguments<(ins MemScopeKindAttr:$scope, LLVM_PointerGeneric:$addr, I32:$size, 643 DefaultValuedAttr<ProxyKindAttr, 644 "ProxyKind::GENERIC">:$fromProxy, 645 DefaultValuedAttr<ProxyKindAttr, 646 "ProxyKind::TENSORMAP">:$toProxy)> { 647 let summary = "Uni-directional proxy fence operation with acquire semantics"; 648 let description = [{ 649 `fence.proxy.acquire` is a uni-directional fence used to establish ordering 650 between a prior memory access performed via the generic proxy and a 651 subsequent memory access performed via the tensormap proxy 652 653 The address operand `addr` and the operand `size` together specify the 654 memory range `[addr, addr+size)` on which the ordering guarantees on the 655 memory accesses across the proxies is to be provided. The only supported 656 value for the `size` operand is 128 and must be an immediate. Generic Addressing 657 is used unconditionally, and the address specified by the operand `addr` must 658 fall within the `.global` state space. Otherwise, the behavior is undefined 659 [For more information, see PTX ISA] 660 (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-membar) 661 }]; 662 663 let assemblyFormat = "$scope $addr `,` $size (`from_proxy` `=` $fromProxy^)? (`to_proxy` `=` $toProxy^)? attr-dict"; 664 let llvmBuilder = [{ 665 createIntrinsicCall( 666 builder, 667 getUnidirectionalFenceProxyID($fromProxy, $toProxy, $scope, false), 668 {$addr, $size}); 669 }]; 670 671 let hasVerifier = 1; 672} 673 674def NVVM_FenceProxyReleaseOp : NVVM_Op<"fence.proxy.release">, 675 Arguments<(ins MemScopeKindAttr:$scope, 676 DefaultValuedAttr<ProxyKindAttr, 677 "ProxyKind::GENERIC">:$fromProxy, 678 DefaultValuedAttr<ProxyKindAttr, 679 "ProxyKind::TENSORMAP">:$toProxy)> { 680 let summary = "Uni-directional proxy fence operation with release semantics"; 681 let description = [{ 682 `fence.proxy.release` is a uni-directional fence used to establish ordering 683 between a prior memory access performed via the generic proxy and a 684 subsequent memory access performed via the tensormap proxy. `fence.proxy.release` 685 operation can form a release sequence that synchronizes with an acquire 686 sequence that contains the fence.proxy.acquire proxy fence operation 687 [For more information, see PTX ISA] 688 (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-membar) 689 }]; 690 691 let assemblyFormat = "$scope (`from_proxy` `=` $fromProxy^)? (`to_proxy` `=` $toProxy^)? attr-dict"; 692 let llvmBuilder = [{ 693 createIntrinsicCall(builder, getUnidirectionalFenceProxyID( 694 $fromProxy, $toProxy, $scope, true)); 695 }]; 696 697 let hasVerifier = 1; 698} 699 700def SetMaxRegisterActionIncrease : I32EnumAttrCase<"increase", 0>; 701def SetMaxRegisterActionDecrease : I32EnumAttrCase<"decrease", 1>; 702def SetMaxRegisterAction : I32EnumAttr<"SetMaxRegisterAction", "NVVM set max register action", 703 [SetMaxRegisterActionDecrease, SetMaxRegisterActionIncrease]> { 704 let genSpecializedAttr = 0; 705 let cppNamespace = "::mlir::NVVM"; 706} 707def SetMaxRegisterActionAttr : EnumAttr<NVVM_Dialect, SetMaxRegisterAction, "action">; 708 709def NVVM_SetMaxRegisterOp : NVVM_Op<"setmaxregister"> { 710 let arguments = (ins I32Attr:$regCount, SetMaxRegisterActionAttr:$action); 711 let assemblyFormat = "$action $regCount attr-dict"; 712 let hasVerifier = 1; 713 string llvmBuilder = [{ 714 auto intId = (op.getAction() == NVVM::SetMaxRegisterAction::increase) ? 715 llvm::Intrinsic::nvvm_setmaxnreg_inc_sync_aligned_u32 : 716 llvm::Intrinsic::nvvm_setmaxnreg_dec_sync_aligned_u32; 717 718 createIntrinsicCall(builder, intId, builder.getInt32($regCount)); 719 }]; 720} 721 722def NVVM_FenceMbarrierInitOp : NVVM_PTXBuilder_Op<"fence.mbarrier.init"> { 723 let arguments = (ins ); 724 let description = [{ 725 Fence operation that applies on the prior nvvm.mbarrier.init 726 [For more information, see PTX ISA] 727 (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-membar) 728 }]; 729 730 let assemblyFormat = "attr-dict"; 731 let extraClassDefinition = [{ 732 std::string $cppClass::getPtx() { 733 return std::string("fence.mbarrier_init.release.cluster;"); 734 } 735 }]; 736} 737 738def ShflKindBfly : I32EnumAttrCase<"bfly", 0>; 739def ShflKindUp : I32EnumAttrCase<"up", 1>; 740def ShflKindDown : I32EnumAttrCase<"down", 2>; 741def ShflKindIdx : I32EnumAttrCase<"idx", 3>; 742 743/// Enum attribute of the different shuffle kinds. 744def ShflKind : I32EnumAttr<"ShflKind", "NVVM shuffle kind", 745 [ShflKindBfly, ShflKindUp, ShflKindDown, ShflKindIdx]> { 746 let genSpecializedAttr = 0; 747 let cppNamespace = "::mlir::NVVM"; 748} 749def ShflKindAttr : EnumAttr<NVVM_Dialect, ShflKind, "shfl_kind">; 750 751def NVVM_ShflOp : 752 NVVM_Op<"shfl.sync">, 753 Results<(outs LLVM_Type:$res)>, 754 Arguments<(ins I32:$thread_mask, 755 LLVM_Type:$val, 756 I32:$offset, 757 I32:$mask_and_clamp, 758 ShflKindAttr:$kind, 759 OptionalAttr<UnitAttr>:$return_value_and_is_valid)> { 760 let summary = "NVVM Dialect Op for shfl.sync"; 761 let description = [{ 762 The `shfl.sync` Op implements data shuffle within threads of a warp. 763 The `thread_mask` denotes the threads participating in the Op where 764 the bit position corresponds to a particular thread’s laneid. 765 The `offset` specifies a source lane or source lane offset 766 (depending on `kind`). The `val` is the input value to be copied from 767 the source. The `mask_and_clamp` contains two packed values specifying 768 a mask for logically splitting warps into sub-segments and an upper bound 769 for clamping the source lane index. 770 [For more information, refer PTX ISA] 771 (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-shfl-sync) 772 }]; 773 string llvmBuilder = [{ 774 auto intId = getShflIntrinsicId( 775 $_resultType, $kind, static_cast<bool>($return_value_and_is_valid)); 776 $res = createIntrinsicCall(builder, 777 intId, {$thread_mask, $val, $offset, $mask_and_clamp}); 778 }]; 779 let assemblyFormat = [{ 780 $kind $thread_mask `,` $val `,` $offset `,` $mask_and_clamp attr-dict 781 `:` type($val) `->` type($res) 782 }]; 783 let hasVerifier = 1; 784} 785 786def NVVM_VoteBallotOp : 787 NVVM_Op<"vote.ballot.sync">, 788 Results<(outs LLVM_Type:$res)>, 789 Arguments<(ins LLVM_Type:$mask, LLVM_Type:$pred)> { 790 string llvmBuilder = [{ 791 $res = createIntrinsicCall(builder, 792 llvm::Intrinsic::nvvm_vote_ballot_sync, {$mask, $pred}); 793 }]; 794 let hasCustomAssemblyFormat = 1; 795} 796 797def NVVM_SyncWarpOp : 798 NVVM_Op<"bar.warp.sync">, 799 Arguments<(ins LLVM_Type:$mask)> { 800 string llvmBuilder = [{ 801 createIntrinsicCall(builder, llvm::Intrinsic::nvvm_bar_warp_sync, {$mask}); 802 }]; 803 let assemblyFormat = "$mask attr-dict `:` type($mask)"; 804} 805 806def NVVM_ElectSyncOp : NVVM_Op<"elect.sync"> 807{ 808 let summary = "Elect one leader thread"; 809 let description = [{ 810 The `elect.sync` instruction elects one predicated active leader 811 thread from among a set of threads specified in membermask. 812 The membermask is set to `0xFFFFFFFF` for the current version 813 of this Op. The predicate result is set to `True` for the 814 leader thread, and `False` for all other threads. 815 816 [For more information, see PTX ISA] 817 (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-elect-sync) 818 }]; 819 820 let results = (outs I1:$pred); 821 let assemblyFormat = "attr-dict `->` type(results)"; 822 string llvmBuilder = [{ 823 auto *resultTuple = createIntrinsicCall(builder, 824 llvm::Intrinsic::nvvm_elect_sync, {builder.getInt32(0xFFFFFFFF)}); 825 // Extract the second value into $pred 826 $pred = builder.CreateExtractValue(resultTuple, 1); 827 }]; 828} 829 830def LoadCacheModifierCA : I32EnumAttrCase<"CA", 0, "ca">; 831def LoadCacheModifierCG : I32EnumAttrCase<"CG", 1, "cg">; 832def LoadCacheModifierCS : I32EnumAttrCase<"CS", 2, "cs">; 833def LoadCacheModifierLU : I32EnumAttrCase<"LU", 3, "lu">; 834def LoadCacheModifierCV : I32EnumAttrCase<"CV", 4, "cv">; 835 836/// Enum attribute of the different kinds. 837def LoadCacheModifierKind : I32EnumAttr<"LoadCacheModifierKind", 838 "NVVM load cache modifier kind", 839 [LoadCacheModifierCA, LoadCacheModifierCG, LoadCacheModifierCS, 840 LoadCacheModifierLU, LoadCacheModifierCV]> { 841 let genSpecializedAttr = 0; 842 let cppNamespace = "::mlir::NVVM"; 843 let description = [{ 844 Enum attribute of the different kinds of cache operators for load instructions. 845 846 [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#id62) 847 }]; 848} 849 850def LoadCacheModifierAttr : EnumAttr<NVVM_Dialect, LoadCacheModifierKind, "load_cache_modifier">; 851 852def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global">, 853 Arguments<(ins LLVM_PointerShared:$dst, 854 LLVM_PointerGlobal:$src, 855 I32Attr:$size, 856 LoadCacheModifierAttr:$modifier, 857 Optional<LLVM_Type>:$cpSize)> { 858 let assemblyFormat = "$dst `,` $src `,` $size `,` `cache` `=` $modifier (`,` $cpSize^)? attr-dict `:` type(operands)"; 859 let hasVerifier = 1; 860 let extraClassDeclaration = [{ 861 static llvm::Intrinsic::ID 862 getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, 863 llvm::SmallVector<llvm::Value *> &args); 864 }]; 865 string llvmBuilder = [{ 866 llvm::SmallVector<llvm::Value *> translatedOperands; 867 auto id = NVVM::CpAsyncOp::getIntrinsicIDAndArgs( 868 *op, moduleTranslation, translatedOperands); 869 createIntrinsicCall(builder, id, translatedOperands); 870 }]; 871} 872 873def NVVM_CpAsyncCommitGroupOp : NVVM_Op<"cp.async.commit.group"> { 874 string llvmBuilder = [{ 875 createIntrinsicCall(builder, llvm::Intrinsic::nvvm_cp_async_commit_group); 876 }]; 877 let assemblyFormat = "attr-dict"; 878} 879 880def NVVM_CpAsyncWaitGroupOp : NVVM_Op<"cp.async.wait.group">, 881 Arguments<(ins I32Attr:$n)> { 882 string llvmBuilder = [{ 883 createIntrinsicCall( 884 builder, 885 llvm::Intrinsic::nvvm_cp_async_wait_group, 886 llvm::ConstantInt::get( 887 llvm::Type::getInt32Ty(moduleTranslation.getLLVMContext()), 888 $n)); 889 }]; 890 let assemblyFormat = "$n attr-dict"; 891} 892 893def NVVM_CpAsyncMBarrierArriveOp : NVVM_Op<"cp.async.mbarrier.arrive"> { 894 let summary = "NVVM Dialect Op for cp.async.mbarrier.arrive"; 895 let description = [{ 896 The `cp.async.mbarrier.arrive` Op makes the mbarrier object track 897 all prior cp.async operations initiated by the executing thread. 898 The `addr` operand specifies the address of the mbarrier object 899 in generic address space. The `noinc` attr impacts how the 900 mbarrier's state is updated. 901 [For more information, refer PTX ISA] 902 (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive) 903 }]; 904 let assemblyFormat = "$addr attr-dict `:` type(operands)"; 905 906 let arguments = (ins 907 LLVM_AnyPointer:$addr, DefaultValuedAttr<I1Attr, "0">:$noinc); 908 909 string llvmBuilder = [{ 910 auto intId = $noinc ? 911 llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc : 912 llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive; 913 914 createIntrinsicCall(builder, intId, {$addr}); 915 }]; 916} 917 918def NVVM_CpAsyncMBarrierArriveSharedOp : NVVM_Op<"cp.async.mbarrier.arrive.shared"> { 919 let summary = "NVVM Dialect Op for cp.async.mbarrier.arrive.shared"; 920 let description = [{ 921 The `cp.async.mbarrier.arrive.shared` Op makes the mbarrier object 922 track all prior cp.async operations initiated by the executing thread. 923 The `addr` operand specifies the address of the mbarrier object in 924 shared memory. The `noinc` attr impacts how the mbarrier's state 925 is updated. [For more information, refer PTX ISA] 926 (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive) 927 }]; 928 let assemblyFormat = "$addr attr-dict `:` type(operands)"; 929 930 let arguments = (ins 931 LLVM_PointerShared:$addr, DefaultValuedAttr<I1Attr, "0">:$noinc); 932 933 string llvmBuilder = [{ 934 auto intId = $noinc ? 935 llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc_shared : 936 llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_shared; 937 938 createIntrinsicCall(builder, intId, {$addr}); 939 }]; 940} 941 942//===----------------------------------------------------------------------===// 943// NVVM Conversion Ops (for "cvt.*" family of PTX instructions) 944//===----------------------------------------------------------------------===// 945 946// Attributes for the floating point rounding modes supported by PTX 947def FPRoundingModeNone : I32EnumAttrCase<"NONE", 0, "none">; 948def FPRoundingModeRN : I32EnumAttrCase<"RN", 1, "rn">; 949def FPRoundingModeRM : I32EnumAttrCase<"RM", 2, "rm">; 950def FPRoundingModeRP : I32EnumAttrCase<"RP", 3, "rp">; 951def FPRoundingModeRZ : I32EnumAttrCase<"RZ", 4, "rz">; 952def FPRoundingModeRNA : I32EnumAttrCase<"RNA", 5, "rna">; 953 954def FPRoundingMode : I32EnumAttr<"FPRoundingMode", "NVVM FPRoundingMode kind", 955 [FPRoundingModeNone, FPRoundingModeRN, FPRoundingModeRM, 956 FPRoundingModeRP, FPRoundingModeRZ, FPRoundingModeRNA]> { 957 let genSpecializedAttr = 0; 958 let cppNamespace = "::mlir::NVVM"; 959} 960def FPRoundingModeAttr : EnumAttr<NVVM_Dialect, FPRoundingMode, "fp_rnd_mode"> { 961 let assemblyFormat = "`<` $value `>`"; 962} 963 964def SaturationModeNone : I32EnumAttrCase<"NONE", 0, "none">; 965def SaturationModeFinite : I32EnumAttrCase<"SATFINITE", 1, "satfinite">; 966 967def SaturationMode : I32EnumAttr<"SaturationMode", "NVVM SaturationMode kind", 968 [SaturationModeNone, SaturationModeFinite]> { 969 let genSpecializedAttr = 0; 970 let cppNamespace = "::mlir::NVVM"; 971} 972def SaturationModeAttr : EnumAttr<NVVM_Dialect, SaturationMode, "sat_mode"> { 973 let assemblyFormat = "`<` $value `>`"; 974} 975 976def NVVM_CvtFloatToTF32Op : NVVM_Op<"cvt.float.to.tf32"> { 977 let summary = "Convert the given float input to TF32"; 978 let description = [{ 979 This Op converts the given f32 input to tf32. 980 The result `res` is represented as an i32 type. 981 The `relu` attribute, when set, lowers to the '.relu' variant of 982 the cvt instruction. The `rnd` and `sat` attributes specify the 983 the rounding and saturation modes respectively. 984 [For more information, see PTX ISA] 985 (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) 986 }]; 987 988 let hasVerifier = 1; 989 let results = (outs I32:$res); 990 let arguments = (ins 991 F32:$src, 992 DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd, 993 DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat, 994 DefaultValuedAttr<BoolAttr, "false">:$relu); 995 996 let assemblyFormat = "$src attr-dict"; 997 998 let extraClassDeclaration = [{ 999 static llvm::Intrinsic::ID getIntrinsicID(NVVM::FPRoundingMode, 1000 NVVM::SaturationMode, 1001 bool hasRelu); 1002 }]; 1003 1004 string llvmBuilder = [{ 1005 auto intId = NVVM::CvtFloatToTF32Op::getIntrinsicID($rnd, $sat, $relu); 1006 $res = createIntrinsicCall(builder, intId, {$src}); 1007 }]; 1008} 1009 1010//===----------------------------------------------------------------------===// 1011// NVVM MMA Ops 1012//===----------------------------------------------------------------------===// 1013/// Helpers to instantiate different version of wmma intrinsics. 1014/// This matches the hierarchy used in IntrinsicsNVVM.td to define all the 1015/// combinations of the intrinsics. 1016class GEOM<int M, int N, int K> { 1017 int m = M; 1018 int n = N; 1019 int k = K; 1020} 1021 1022/// Class containing information about valid mma matrix types. 1023class WMMA_REGS<GEOM Geom, string Frag, string PtxEltType> { 1024 int m = Geom.m; 1025 int n = Geom.n; 1026 int k = Geom.k; 1027 string geom = "m"#Geom.m#"n"#Geom.n#"k"#Geom.k; 1028 string frag = Frag; 1029 string ptx_elt_type = PtxEltType; 1030 string gft = geom#":"#Frag#":"#ptx_elt_type; 1031} 1032 1033//// Generate enum value of the mma.load/mma.store intrinsic. 1034class WMMA_NAME_LDST<string Op, WMMA_REGS Frag, string Layout, int WithStride> { 1035 string id = "llvm::Intrinsic::nvvm_wmma" 1036 # "_" # Frag.geom 1037 # "_" # Op 1038 # "_" # Frag.frag 1039 # "_" # Frag.ptx_elt_type 1040 # "_" # Layout 1041 # !if(WithStride, "_stride", ""); 1042} 1043 1044/// Generate the signature part of the mma intrinsic name. 1045class MMA_SIGNATURE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> { 1046 list<WMMA_REGS> id_frags = !cond( 1047 // FP16 ops are identified by accumulator & result type. 1048 !eq(A.ptx_elt_type, "f16") : [D, C], 1049 // other ops are identified by input types. 1050 !ne(A.ptx_elt_type, B.ptx_elt_type): [A, B], 1051 true: [A] 1052 ); 1053 string ret = !foldl("", id_frags, a, b, !strconcat(a, "_", b.ptx_elt_type)); 1054} 1055 1056/// Generate enum value of the wmma.mma intrinsic. 1057class WMMA_NAME<string Op, string ALayout, string BLayout, WMMA_REGS A, 1058 WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> { 1059 string signature = MMA_SIGNATURE<A, B, C, D>.ret; 1060 string id = "llvm::Intrinsic::nvvm_wmma" 1061 # "_" # A.geom 1062 # "_" # Op 1063 # "_" # ALayout 1064 # "_" # BLayout 1065 # signature; 1066} 1067 1068// Generates list of 4-tuples of WMMA_REGS representing a valid MMA op. 1069// Geom: list of supported geometries. 1070// TypeN: PTX type of the corresponding fragment's element. 1071// TypeB and TypeD may be empty if it must match that of TypeA or TypeC. 1072class MMA_OPS<list<GEOM> Geom, list<string> TypeA, list<string> TypeB, 1073 list<string> TypeC, list<string> TypeD> { 1074 list<list<WMMA_REGS>> ret = 1075 !foldl([]<list<WMMA_REGS>>, Geom, t1, geom, !listconcat(t1, 1076 !foldl([]<list<WMMA_REGS>>, TypeA, t2, type_a, !listconcat(t2, 1077 !foldl([]<list<WMMA_REGS>>, !if(!size(TypeB), TypeB, [type_a]), t3, type_b, !listconcat(t3, 1078 !foldl([]<list<WMMA_REGS>>, TypeC, t4, type_c, !listconcat(t4, 1079 !foldl([]<list<WMMA_REGS>>, !if(!size(TypeD), TypeD, [type_c]), t5, type_d, !listconcat(t5, 1080 [[WMMA_REGS<geom, "a", type_a>, 1081 WMMA_REGS<geom, "b", type_b>, 1082 WMMA_REGS<geom, "c", type_c>, 1083 WMMA_REGS<geom, "d", type_d>]])))))))))); 1084 // Debugging aid for readable representation of the list above. 1085 list<list<string>> ops = !foreach(x, ret, [x[0].gft, x[1].gft, x[2].gft, x[3].gft]); 1086} 1087 1088/// Creates a list of combinations of load/store operations supported. 1089class MMA_LDST_OPS<list<GEOM> Geom, list<string> Frags, list<string> Types> { 1090 list<WMMA_REGS> ret = 1091 !foldl([]<WMMA_REGS>, Geom, t1, geom, !listconcat(t1, 1092 !foldl([]<WMMA_REGS>, Frags, t2, frag, !listconcat(t2, 1093 !foldl([]<WMMA_REGS>, Types, t3, type, !listconcat(t3, 1094 [WMMA_REGS<geom, frag, type>])))))); 1095 // Debugging aid for readable representation of the list above. 1096 list<string> ops = !foreach(x, ret, x.gft); 1097} 1098 1099// Creates list of valid combinations of fragments. This is a subset of what 1100// llvm supports and can be extended as needed. 1101class NVVM_MMA_OPS { 1102 // "wmma" operations 1103 list<list<WMMA_REGS>> tf32_wmma_ops = MMA_OPS< 1104 [GEOM<16, 16, 8>], 1105 ["tf32"], [], ["f32"], []>.ret; 1106 list<list<WMMA_REGS>> fp_wmma_ops = MMA_OPS< 1107 [GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>], 1108 ["f16"], [], ["f16", "f32"], []>.ret; 1109 list<list<WMMA_REGS>> i8_wmma_ops = MMA_OPS< 1110 [GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>], 1111 ["s8","u8"], [], ["s32"], []>.ret; 1112 list<list<WMMA_REGS>> all_wmma_ops = !listconcat( 1113 tf32_wmma_ops, 1114 fp_wmma_ops, 1115 i8_wmma_ops); 1116 1117 list<WMMA_REGS> ldst_ab_ops = MMA_LDST_OPS< 1118 [GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>], 1119 ["a", "b"], ["f16","s8","u8"]>.ret; 1120 list<WMMA_REGS> ldst_cd_ops = MMA_LDST_OPS< 1121 [GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>], 1122 ["c", "d"], ["f16", "f32","s32"]>.ret; 1123 list<WMMA_REGS> ldst_tf32_ab_ops = MMA_LDST_OPS< 1124 [GEOM<16, 16, 8>], 1125 ["a", "b"], ["tf32"]>.ret; 1126 list<WMMA_REGS> ldst_tf32_cd_ops = MMA_LDST_OPS< 1127 [GEOM<16, 16, 8>], 1128 ["c", "d"], ["f32"]>.ret; 1129 list<WMMA_REGS> all_ldst_ops = !listconcat(ldst_ab_ops, ldst_cd_ops, 1130 ldst_tf32_ab_ops, 1131 ldst_tf32_cd_ops); 1132 // Separate A/B/C fragments (loads) from D (stores). 1133 list<WMMA_REGS> all_ld_ops = !filter(op, all_ldst_ops, !ne(op.frag, "d")); 1134 list<WMMA_REGS> all_st_ops = !filter(op, all_ldst_ops, !eq(op.frag, "d")); 1135 1136 // "mma_sync" operations 1137 list<list<WMMA_REGS>> tf32_mma_ops = MMA_OPS< 1138 [GEOM<16,8,4>, GEOM<16,8,8>], 1139 ["tf32"], [], ["f32"], []>.ret; 1140 list<list<WMMA_REGS>> bf16_mma_ops = MMA_OPS< 1141 [GEOM<16,8,16>, GEOM<16,8,8>], 1142 ["bf16"], [], ["f32"], []>.ret; 1143 list<list<WMMA_REGS>> f64_mma_ops = MMA_OPS< 1144 [GEOM<8,8,4>], 1145 ["f64"], [], ["f64"], []>.ret; 1146 list<list<WMMA_REGS>> fp_mma_ops = MMA_OPS< 1147 [GEOM<8,8,4>, GEOM<16,8,8>, GEOM<16,8,16>], 1148 ["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret; 1149 list<list<WMMA_REGS>> int_mma_ops = MMA_OPS< 1150 [GEOM<8,8,16>, GEOM<16,8,16>, GEOM<16,8,32>], 1151 ["s8", "u8"], ["s8", "u8"], ["s32"], []>.ret; 1152 list<list<WMMA_REGS>> subint_mma_ops = MMA_OPS< 1153 [GEOM<8,8,32>, GEOM<16,8,32>, GEOM<16,8,64>], 1154 ["s4", "u4"], ["s4", "u4"], ["s32"], []>.ret; 1155 list<list<WMMA_REGS>> bit_mma_ops = MMA_OPS< 1156 [GEOM<8,8,128>, GEOM<16,8,128>, GEOM<16,8,256>], 1157 ["b1"], [], ["s32"], []>.ret; 1158 list<list<WMMA_REGS>> all_mma_sync_ops = !listconcat( 1159 tf32_mma_ops, bf16_mma_ops, f64_mma_ops, 1160 fp_mma_ops, int_mma_ops, subint_mma_ops, bit_mma_ops); 1161} 1162 1163def NVVM_MMA_OPS : NVVM_MMA_OPS; 1164 1165/// Helper to create the mapping between the configuration and the store 1166/// intrinsic enum value. 1167class MMA_ST_INTR<string op> { 1168 list<list<string>> cond0 = !foreach(frag, NVVM_MMA_OPS.all_st_ops, 1169 !foreach(layout, ["row", "col"], 1170 "if (layout == \"" # layout # "\" && m == " # frag.m # " &&" 1171 " n == " #frag.n # " && k == " # frag.k # " && \"" # 1172 frag.ptx_elt_type # "\" == eltype)" 1173 " return " #WMMA_NAME_LDST<op, frag, layout, 1>.id #";")); 1174 string id = !foldl("", 1175 !foldl([""], cond0, acc, el, !listconcat(acc, el)), 1176 acc1, el1, acc1 # "\n" # el1); 1177} 1178 1179/// Helper to map a mxk shape to a supported mxnxk matrix type. This will return 1180/// the n value of the supported configuration. 1181class MMA_ST_INFER_N<list<WMMA_REGS> ldst> { 1182 list<string> cond = !foreach(frag, ldst, 1183 "if (m == " # frag.m # " && k == " #frag.k # " && \"" # 1184 frag.ptx_elt_type # "\" == eltype)" 1185 " return "# frag.n #";"); 1186 string id = !foldl("", cond, acc, el, acc # "\n" # el); 1187} 1188 1189/// Helper to map a kxn shape to a supported mxnxk matrix type. This will return 1190/// the m value of the supported configuration. 1191class MMA_ST_INFER_M<list<WMMA_REGS> ldst> { 1192 list<string> cond = !foreach(frag, ldst, 1193 "if (n == " # frag.n # " && k == " #frag.k # " && \"" # 1194 frag.ptx_elt_type # "\" == eltype)" 1195 " return "# frag.m #";"); 1196 string id = !foldl("", cond, acc, el, acc # "\n" # el); 1197} 1198 1199/// Helper to map a mxn shape to a supported mxnxk matrix type. This will return 1200/// the k value of the supported configuration. 1201class MMA_ST_INFER_K<list<WMMA_REGS> ldst> { 1202 list<string> cond = !foreach(frag, ldst, 1203 "if (m == " # frag.m # " && n == " #frag.n # " && \"" # 1204 frag.ptx_elt_type # "\" == eltype)" 1205 " return "# frag.k #";"); 1206 string id = !foldl("", cond, acc, el, acc # "\n" # el); 1207} 1208 1209/// Helper to create the mapping between the configuration and the load 1210/// intrinsic enum value. 1211class MMA_LD_INTR<string op> { 1212 list<list<string>> cond0 = !foreach(frag, NVVM_MMA_OPS.all_ld_ops, 1213 !foreach(layout, ["row", "col"], 1214 "if (layout == \"" # layout # "\" && m == " # frag.m # " &&" 1215 " n == " #frag.n # " && k == " # frag.k # " && \"" # 1216 frag.ptx_elt_type # "\" == eltype && frag == \""#frag.frag#"\")" 1217 " return "# WMMA_NAME_LDST<op, frag, layout, 1>.id #";")); 1218 string id = !foldl("", 1219 !foldl([""], cond0, acc, el, !listconcat(acc, el)), 1220 acc1, el1, acc1 # "\n" # el1); 1221} 1222 1223/// Helper to create the mapping between the configuration and the wmma.mma 1224/// intrinsic enum value. 1225class MMA_MMA_INTR<string opName> { 1226 list<list<list<string>>> cond0 = 1227 !foreach(op, NVVM_MMA_OPS.all_wmma_ops, 1228 !foreach(layoutA, ["row", "col"], 1229 !foreach(layoutB, ["row", "col"], 1230 "if (layoutA == \"" # layoutA # "\" && layoutB == \"" # layoutB # "\" && " 1231 " m == " # op[0].m # " && n == " #op[0].n # " && k == " # op[0].k # 1232 " && \"" # op[0].ptx_elt_type # "\" == eltypeA && \"" 1233 # op[3].ptx_elt_type # "\" == eltypeB)" 1234 " return " # 1235 WMMA_NAME<opName, layoutA, layoutB, op[0], op[1], op[2], op[3]>.id # ";"))); 1236 list<string> f = !foldl([""], 1237 !foldl([[""]], cond0, acc, el, !listconcat(acc, el)), 1238 acc1, el1, !listconcat(acc1, el1)); 1239 string id = !foldl("", f, acc, el, acc # "\n" # el); 1240} 1241 1242/// Enum attribute for binary (b1) MMA operation type 1243def MMAB1OpNone : I32EnumAttrCase<"none", 0>; 1244def MMAB1OpXorPopc : I32EnumAttrCase<"xor_popc", 1>; 1245def MMAB1OpAndPopc : I32EnumAttrCase<"and_popc", 2>; 1246def MMAB1Op : I32EnumAttr<"MMAB1Op", "MMA binary operations", 1247 [MMAB1OpNone, MMAB1OpXorPopc, MMAB1OpAndPopc]> { 1248 let genSpecializedAttr = 0; 1249 let cppNamespace = "::mlir::NVVM"; 1250} 1251def MMAB1OpAttr : EnumAttr<NVVM_Dialect, MMAB1Op, "mma_b1op"> { 1252 let assemblyFormat = "`<` $value `>`"; 1253} 1254 1255/// Enum attribute type for the overflow behavior of MMA integer operations 1256def MMAIntOverflowWrap : I32EnumAttrCase<"wrapped", 0>; 1257def MMAIntOverflowSat : I32EnumAttrCase<"satfinite", 1>; 1258def MMAIntOverflow : I32EnumAttr<"MMAIntOverflow", "MMA overflow options", 1259 [MMAIntOverflowSat, MMAIntOverflowWrap]> { 1260 let genSpecializedAttr = 0; 1261 let cppNamespace = "::mlir::NVVM"; 1262} 1263def MMAIntOverflowAttr : EnumAttr<NVVM_Dialect, MMAIntOverflow, "mma_int_overflow"> { 1264 let assemblyFormat = "`<` $value `>`"; 1265} 1266 1267/// Attribute to hold the MMA shape 1268def NVVM_MMAShapeAttr : NVVM_Attr<"MMAShape", "shape"> { 1269 let summary = "Attribute for MMA operation shape."; 1270 let parameters = (ins "int":$m, "int":$n, "int":$k); 1271 let assemblyFormat = "`<` struct(params) `>`"; 1272} 1273 1274// Returns true if this combination of layout/satf for MMA ops is supported; 1275// false otherwise. 1276// E.g. 1277// if NVVM_MMA_SUPPORTED<...>.ret then 1278// def : FOO<>; // The record will only be defined for supported ops. 1279// 1280class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b, int satf> { 1281 // MMA ops check both layouts. 1282 string layout = layout_a # ":" # layout_b; 1283 string a_type = frags[0].ptx_elt_type; 1284 string b_type = frags[1].ptx_elt_type; 1285 string c_type = frags[2].ptx_elt_type; 1286 string d_type = frags[3].ptx_elt_type; 1287 string geom = frags[0].geom; 1288 1289 // gcd is a shortcut used to identify instructions that depend on 1290 // geom+frag_c+frag_d. 1291 string gcd = geom # ":" # c_type # d_type; 1292 bit ret = !cond( 1293 1294 // Limit satf to valid types 1295 !and(!eq(satf, 1), 1296 !ne(a_type, "s8"), 1297 !ne(a_type, "u8"), 1298 !ne(a_type, "s4"), 1299 !ne(a_type, "u4")): false, 1300 1301 // m8n8k4 has no C=f32 D=f16 variant. 1302 !eq(gcd, "m8n8k4:f32f16"): false, 1303 1304 // only m8n8k4 for f16 does not require row:col layout 1305 !and(!ne(layout, "row:col"), 1306 !or(!ne(geom, "m8n8k4"), 1307 !ne(a_type, "f16"))) : false, 1308 1309 // m16n8k8 requires A and B to be the same type and C and D to be the same 1310 // type. 1311 !and(!eq(geom, "m16n8k8"), 1312 !or(!ne(a_type, b_type), 1313 !ne(c_type, d_type))): false, 1314 1315 // m16n8k8 requires C and D to be the same type. 1316 !and(!eq(geom, "m16n8k8"), 1317 !ne(c_type, d_type)): false, 1318 1319 // All other are OK. 1320 true: true 1321 ); 1322} 1323 1324// Returns a list of operation suffixes corresponding to possible b1 1325// multiply-and-accumulate operations for all fragments which have a 1326// b1 type. For all other fragments, the list returned holds a list 1327// containing the empty string. 1328class NVVM_MMA_B1OPS<list<WMMA_REGS> frags> { 1329 list<string> ret = !cond( 1330 !eq(frags[0].ptx_elt_type, "b1") : ["xor_popc", "and_popc"], 1331 true: [""] 1332 ); 1333} 1334 1335/// Generate enum value of the mma.sync intrinsic. 1336class MMA_SYNC_NAME<string ALayout, string BLayout, string b1op, int Satfinite, 1337 WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> { 1338 string signature = MMA_SIGNATURE<A, B, C, D>.ret; 1339 string id = "llvm::Intrinsic::nvvm_mma" 1340 # !if(!ne(b1op, ""), "_" # b1op, "") 1341 # "_" # A.geom 1342 # "_" # ALayout 1343 # "_" # BLayout 1344 # !if(Satfinite, "_satfinite", "") 1345 # signature; 1346} 1347 1348/// Helper to create the mapping between the configuration and the mma.sync 1349/// intrinsic enum value. 1350class MMA_SYNC_INTR { 1351 list<list<list<list<list<string>>>>> cond0 = 1352 !foreach(op, NVVM_MMA_OPS.all_mma_sync_ops, 1353 !foreach(layoutA, ["row", "col"], 1354 !foreach(layoutB, ["row", "col"], 1355 !foreach (sat, [0, 1], 1356 !foreach (b1op, NVVM_MMA_B1OPS<op>.ret, 1357 !if(NVVM_MMA_SUPPORTED<[op[0], op[1], op[2], op[3]], 1358 layoutA, layoutB, sat>.ret, 1359 "if (layoutA == \"" # layoutA # "\" && layoutB == \"" # layoutB # "\" && " 1360 " m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k # 1361 " && \"" # op[0].ptx_elt_type # "\" == eltypeA && \"" 1362 # op[1].ptx_elt_type # "\" == eltypeB && " 1363 # " \"" # op[2].ptx_elt_type # "\" == eltypeC && " 1364 # " \"" # op[3].ptx_elt_type # "\" == eltypeD " 1365 # " && (sat.has_value() ? " # sat # " == static_cast<int>(*sat) : true)" 1366 # !if(!ne(b1op, ""), " && (b1Op.has_value() ? MMAB1Op::" # b1op # " == *b1Op : true)", "") # ")\n" 1367 # " return " # 1368 MMA_SYNC_NAME<layoutA, layoutB, b1op, sat, op[0], op[1], op[2], op[3]>.id # ";", 1369 "") // if supported 1370 ) // b1op 1371 ) // sat 1372 ) // layoutB 1373 ) // layoutA 1374 ); // all_mma_sync_ops 1375 list<list<list<string>>> f1 = !foldl([[[""]]], 1376 !foldl([[[[""]]]], cond0, acc, el, 1377 !listconcat(acc, el)), 1378 acc1, el1, !listconcat(acc1, el1)); 1379 list<list<string>> f2 = !foldl([[""]], f1, acc1, el1, !listconcat(acc1, el1)); 1380 list<string> f3 = !foldl([""], f2, acc, el, !listconcat(acc, el)); 1381 string id = !foldl("", f3, acc, el, acc # "\n" # el); 1382} 1383 1384def MMALayoutRow : I32EnumAttrCase<"row", 0>; 1385def MMALayoutCol : I32EnumAttrCase<"col", 1>; 1386 1387/// Enum attribute of the different matrix layout. 1388def MMALayout : I32EnumAttr<"MMALayout", "NVVM MMA layout", 1389 [MMALayoutRow, MMALayoutCol]> { 1390 let genSpecializedAttr = 0; 1391 let cppNamespace = "::mlir::NVVM"; 1392} 1393def MMALayoutAttr : EnumAttr<NVVM_Dialect, MMALayout, "mma_layout"> { 1394 let assemblyFormat = "`<` $value `>`"; 1395} 1396 1397/// Enum attribute of the different PTX element types used for MMA operands. 1398def MMATypeF16 : I32EnumAttrCase<"f16", 0>; 1399def MMATypeF32 : I32EnumAttrCase<"f32", 1>; 1400def MMATypeTF32 : I32EnumAttrCase<"tf32", 2>; 1401def MMATypeU8 : I32EnumAttrCase<"u8", 3>; 1402def MMATypeS8 : I32EnumAttrCase<"s8", 4>; 1403def MMATypeS32 : I32EnumAttrCase<"s32", 5>; 1404def MMATypeB1 : I32EnumAttrCase<"b1", 6>; 1405def MMATypeU4 : I32EnumAttrCase<"u4", 7>; 1406def MMATypeS4 : I32EnumAttrCase<"s4", 8>; 1407def MMATypeBF16 : I32EnumAttrCase<"bf16", 9>; 1408def MMATypeF64 : I32EnumAttrCase<"f64", 10>; 1409 1410def MMATypes : I32EnumAttr<"MMATypes", "NVVM MMA types", 1411 [MMATypeF16, MMATypeF32, MMATypeTF32, 1412 MMATypeBF16, MMATypeS8, MMATypeU8, 1413 MMATypeS32, MMATypeS4, MMATypeU4, 1414 MMATypeB1, MMATypeF64]> { 1415 let genSpecializedAttr = 0; 1416 let cppNamespace = "::mlir::NVVM"; 1417} 1418def MMATypesAttr : EnumAttr<NVVM_Dialect, MMATypes, "mma_type"> { 1419 let assemblyFormat = "`<` $value `>`"; 1420} 1421 1422def MMAFragA : I32EnumAttrCase<"a", 0>; 1423def MMAFragB : I32EnumAttrCase<"b", 1>; 1424def MMAFragC : I32EnumAttrCase<"c", 2>; 1425 1426/// Enum attribute of the different frag types. 1427def MMAFrag: I32EnumAttr<"MMAFrag", "NVVM MMA frag type", 1428 [MMAFragA, MMAFragB, MMAFragC]> { 1429 let genSpecializedAttr = 0; 1430 let cppNamespace = "::mlir::NVVM"; 1431} 1432def MMAFragAttr : EnumAttr<NVVM_Dialect, MMAFrag, "mma_frag"> { 1433 let assemblyFormat = "`<` $value `>`"; 1434} 1435 1436def NVVM_WMMALoadOp: NVVM_Op<"wmma.load">, 1437 Results<(outs LLVM_AnyStruct:$res)>, 1438 Arguments<(ins LLVM_AnyPointer: $ptr, I32: $stride, I32Attr:$m, 1439 I32Attr:$n, I32Attr:$k, MMALayoutAttr:$layout, 1440 MMATypesAttr:$eltype, MMAFragAttr:$frag)> { 1441 1442 let summary = "Warp synchronous matrix load"; 1443 1444 // Since LLVM intrinsic IDs are enum that cannot be dynamically generated in 1445 // C++ we instanciate a function in tablegen to map the valide configuration 1446 // to the corresponsding intrinsic ID. 1447 // Because we want a single source of truth, this mean the source of truth 1448 // about valid combinations needs to be in tablgen, therefore we generate 1449 // extra helpers to query valid configurations based on the shapes of 1450 // load/store operations. 1451 let extraClassDeclaration = 1452 "static llvm::Intrinsic::ID getIntrinsicID(" 1453 "int m, int n, int k, mlir::NVVM::MMALayout layoutEnum," 1454 "mlir::NVVM::MMATypes eltypeEnum,mlir::NVVM::MMAFrag fragEnum) {" 1455 "llvm::StringRef layout = stringifyEnum(layoutEnum);" 1456 "llvm::StringRef eltype = stringifyEnum(eltypeEnum);" 1457 "llvm::StringRef frag = stringifyEnum(fragEnum);" 1458 #MMA_LD_INTR<"load">.id# "\n" 1459 "return 0;" 1460 "}\n" 1461 "/// Helpers to find valid n dimension based on mxk load shape.\n" 1462 "static int inferNDimension(int m, int k, mlir::NVVM::MMATypes eltypeEnum) {" 1463 " llvm::StringRef eltype = stringifyEnum(eltypeEnum);" 1464 #MMA_ST_INFER_N<!filter(op, NVVM_MMA_OPS.all_ld_ops, !eq(op.frag, "a"))>.id# "\n" 1465 "return 0;" 1466 "}\n" 1467 "/// Helpers to find valid m dimension based on kxn load shape.\n" 1468 "static int inferMDimension(int k, int n, mlir::NVVM::MMATypes eltypeEnum) {" 1469 " llvm::StringRef eltype = stringifyEnum(eltypeEnum);" 1470 #MMA_ST_INFER_M<!filter(op, NVVM_MMA_OPS.all_ld_ops, !eq(op.frag, "b"))>.id# "\n" 1471 "return 0;" 1472 "}\n" 1473 "/// Helpers to find valid k dimension based on mxn load shape.\n" 1474 "static int inferKDimension(int m, int n, mlir::NVVM::MMATypes eltypeEnum) {" 1475 " llvm::StringRef eltype = stringifyEnum(eltypeEnum);" 1476 #MMA_ST_INFER_K<!filter(op, NVVM_MMA_OPS.all_ld_ops, !eq(op.frag, "c"))>.id# "\n" 1477 "return 0;" 1478 "}\n"; 1479 1480 1481 string llvmBuilder = [{ 1482 auto operands = moduleTranslation.lookupValues(opInst.getOperands()); 1483 auto intId = mlir::NVVM::WMMALoadOp::getIntrinsicID( 1484 $m, $n, $k, $layout, $eltype, $frag); 1485 $res = createIntrinsicCall(builder, intId, operands, {operands[0]->getType()}); 1486 }]; 1487 1488 string baseDescription = [{ 1489 The `nvvm.wmma.load` operation loads a matrix collectively using all the 1490 threads in a warp. 1491 1492 The operation takes two arguments, the address from where the matrix 1493 elements are to be loaded from and a stride. The stride argument 1494 represents the leading dimension of the source matrix. The address and 1495 the stride are required to be the same across all threads in the warp. 1496 Each thread in a warp holds a certain number of elements. The Op returns 1497 a LLVMStruct which holds the elements of the matrix held by this thread. 1498 1499 This op is meant to be used along with `nvvm.wmma.store` and 1500 `nvvm.wmma.mma`. 1501 1502 Example: 1503 1504 ```mlir 1505 %2 = nvvm.wmma.load %0, %1 1506 {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} 1507 : (!llvm.ptr<3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> 1508 ``` 1509 }]; 1510 1511 let assemblyFormat = "$ptr `,` $stride attr-dict `:` functional-type($ptr, $res)"; 1512 let hasVerifier = 1; 1513} 1514 1515def NVVM_WMMAStoreOp : NVVM_Op<"wmma.store">, 1516 Arguments<(ins LLVM_AnyPointer: $ptr, 1517 I32Attr:$m, I32Attr:$n, I32Attr:$k, MMALayoutAttr:$layout, 1518 MMATypesAttr:$eltype, Variadic<LLVM_Type>:$args, I32: $stride)>{ 1519 let summary = "Warp synchronous matrix store"; 1520 1521 let extraClassDeclaration = 1522 "static llvm::Intrinsic::ID getIntrinsicID(" 1523 "int m, int n, int k, mlir::NVVM::MMALayout layoutEnum," 1524 "mlir::NVVM::MMATypes eltypeEnum) {" 1525 " llvm::StringRef layout = stringifyEnum(layoutEnum);" 1526 " llvm::StringRef eltype = stringifyEnum(eltypeEnum);" 1527 #MMA_ST_INTR<"store">.id# "\n" 1528 "return 0;" 1529 "}\n" 1530 "/// Helpers to find valid k dimension based on mxn store shape.\n" 1531 "static int inferKDimension(int m, int n, mlir::NVVM::MMATypes eltypeEnum) {" 1532 " llvm::StringRef eltype = stringifyEnum(eltypeEnum);" 1533 #MMA_ST_INFER_K<NVVM_MMA_OPS.all_st_ops>.id# "\n" 1534 "return 0;" 1535 "}"; 1536 1537 string llvmBuilder = [{ 1538 auto operands = moduleTranslation.lookupValues(opInst.getOperands()); 1539 auto intId = 1540 mlir::NVVM::WMMAStoreOp::getIntrinsicID($m, $n, $k, $layout, $eltype); 1541 createIntrinsicCall(builder, intId, operands, {operands[0]->getType()}); 1542 }]; 1543 1544 string baseDescription = [{ 1545 The `nvvm.wmma.store` operation stores a matrix collectively using 1546 all the threads in a warp. 1547 1548 The operation takes as arguments the address to where the matrix elements are 1549 to be stored, a stride and the elements to store, held by the current thread. 1550 The stride argument represents the leading dimension of the destination matrix. 1551 The address and the stride are required to be the same across all threads in the 1552 warp. 1553 1554 This op is meant to be used along with `nvvm.wmma.m16n16k16.load` and 1555 `nvvm.wmma.m16n16k16.mma`. 1556 1557 Example: 1558 1559 ```mlir 1560 nvvm.wmma.store %0, %1, %2, %3, %4, %5 1561 {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} 1562 : !llvm.ptr<3>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16> 1563 ``` 1564 }]; 1565 1566 let assemblyFormat = [{ 1567 $ptr `,` $stride `,` $args attr-dict `:` qualified(type($ptr)) `,` 1568 type($args) 1569 }]; 1570 let hasVerifier = 1; 1571} 1572 1573// Base class for all the variants of WMMA mmaOps that may be defined. 1574def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">, 1575 Results<(outs LLVM_AnyStruct:$res)>, 1576 Arguments<(ins I32Attr:$m, I32Attr:$n, I32Attr:$k, MMALayoutAttr:$layoutA, 1577 MMALayoutAttr:$layoutB, MMATypesAttr:$eltypeA, 1578 MMATypesAttr:$eltypeB, Variadic<LLVM_Type>:$args)>{ 1579 let summary = "Warp synchronous matrix-multiply accumulate using tensor cores."; 1580 1581 let extraClassDeclaration = 1582 "static llvm::Intrinsic::ID getIntrinsicID(" 1583 "int m, int n, int k, mlir::NVVM::MMALayout layoutAEnum," 1584 "mlir::NVVM::MMALayout layoutBEnum, mlir::NVVM::MMATypes eltypeAEnum," 1585 "mlir::NVVM::MMATypes eltypeBEnum) {" 1586 "llvm::StringRef layoutA = stringifyEnum(layoutAEnum);" 1587 "llvm::StringRef layoutB = stringifyEnum(layoutBEnum);" 1588 "llvm::StringRef eltypeA = stringifyEnum(eltypeAEnum);" 1589 "llvm::StringRef eltypeB = stringifyEnum(eltypeBEnum);" 1590 #MMA_MMA_INTR<"mma">.id# "\n" 1591 "return 0;" 1592 "}"; 1593 1594 string llvmBuilder = [{ 1595 auto operands = moduleTranslation.lookupValues(opInst.getOperands()); 1596 auto intId = mlir::NVVM::WMMAMmaOp::getIntrinsicID( 1597 $m, $n, $k, $layoutA, $layoutB, $eltypeA, $eltypeB); 1598 $res = createIntrinsicCall(builder, intId, operands); 1599 }]; 1600 1601 string baseDescription = [{ 1602 The `nvvm.wmma.mma` operation performs a matrix-multiply accumulate 1603 (mma) operation using all the threads in a warp. 1604 1605 The operation performed is represented as `D = A * B + C`. The operation takes 1606 as arguments the elements of the matrices `A`, `B`, `C` and `D`, held by the 1607 current thread. The op returns a LLVM struct which holds a part of the result 1608 held by the current thread. 1609 1610 This op is meant to be used along with `nvvm.wmma.load` and 1611 `nvvm.wmma.store`. 1612 1613 Example: 1614 1615 ```mlir 1616 %16 = nvvm.wmma.mma %0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15 1617 {eltypeA = "tf32", eltypeB = "f32", k = 8 : i32, layoutA = "row", layoutB = "row", m = 16 : i32, n = 16 : i32} 1618 : (i32, i32, i32, i32, i32, i32, i32, i32, f32, f32, f32, f32, f32, f32, f32, f32) 1619 -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> 1620 ``` 1621 }]; 1622 1623 let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)"; 1624 let hasVerifier = 1; 1625} 1626 1627def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">, 1628 Arguments<(ins LLVM_PointerShared:$ptr, 1629 Variadic<I32>:$sources, 1630 MMALayoutAttr:$layout)> { 1631 let summary = "cooperative matrix store"; 1632 let description = [{ 1633 Collectively store one or more matrices across all threads in a warp to the 1634 location indicated by the address operand $ptr in shared memory. 1635 [For more information, see PTX ISA] 1636 (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix) 1637 }]; 1638 1639 let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)"; 1640 let extraClassDefinition = [{ 1641 std::string $cppClass::getPtx() { 1642 int d = getSources().size(); 1643 std::string ptx = "stmatrix.sync.aligned"; 1644 ptx += ".x" + std::to_string(d); 1645 if (getLayout() == NVVM::MMALayout::col) 1646 ptx += ".trans"; 1647 if(d == 1) ptx += ".m8n8.shared.b16 [%0], {%1};"; 1648 if(d == 2) ptx += ".m8n8.shared.b16 [%0], {%1, %2};"; 1649 if(d == 4) ptx += ".m8n8.shared.b16 [%0], {%1, %2, %3, %4};"; 1650 return ptx; 1651 } 1652 }]; 1653 let hasVerifier = 1; 1654} 1655 1656def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">, 1657 Results<(outs AnyType:$res)>, 1658 Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr:$num, MMALayoutAttr:$layout)> { 1659 1660 let summary = "cooperative matrix load"; 1661 1662 string llvmBuilder = [{ 1663 auto operands = moduleTranslation.lookupValues(opInst.getOperands()); 1664 auto intId = getLdMatrixIntrinsicId($layout, $num); 1665 $res = createIntrinsicCall(builder, intId, operands, {operands[0]->getType()}); 1666 }]; 1667 1668 string baseDescription = [{ 1669 The `nvvm.ldmatrix` operation collectively loads one or more matrices across 1670 all threads in a warp from the location indicated by the address operand 1671 `ptr` from shared memory. 1672 1673 The attribute `num` indicates how many 8x8 16-bit matrices are to be loaded. 1674 1675 All the threads in the warp must execute the same ldmatrix operations. 1676 1677 Each row of 8 elements needs to be consecutive in memory. Each lane of the 1678 warp contains the start address of a row of 8 elements laid out as below: 1679 1680 ``` 1681 num | lane 0--7 | Threads 8--15 | Threads 16--31 1682 1 | addr0--addr7 | | 1683 2 | addr0--addr7 | addr8--addr15 | 1684 4 | addr0--addr7 | addr8--addr15 | addr16--addr31 1685 ``` 1686 1687 Example: 1688 ```mlir 1689 %l1 = nvvm.ldmatrix %ptr {num = 1 : i32, layout = #nvvm.mma_layout<row>} : 1690 (!llvm.ptr<3>) -> i32 1691 %l2 = nvvm.ldmatrix %ptr {num = 4 : i32, layout = #nvvm.mma_layout<row>} : 1692 (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> 1693 ``` 1694 }]; 1695 1696 let assemblyFormat = "$ptr attr-dict `:` functional-type($ptr, $res)"; 1697 let hasVerifier = 1; 1698} 1699 1700def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> { 1701 1702 let summary = "cooperative matrix-multiply and accumulate"; 1703 1704 let description = [{ 1705 The `nvvm.mma.sync` operation collectively performs the operation 1706 `D = matmul(A, B) + C` using all threads in a warp. 1707 1708 All the threads in the warp must execute the same `mma.sync` operation. 1709 1710 For each possible multiplicand PTX data type, there are one or more possible 1711 instruction shapes given as "mMnNkK". The below table describes the posssibilities 1712 as well as the types required for the operands. Note that the data type for 1713 C (the accumulator) and D (the result) can vary independently when there are 1714 multiple possibilities in the "C/D Type" column. 1715 1716 When an optional attribute cannot be immediately inferred from the types of 1717 the operands and the result during parsing or validation, an error will be 1718 raised. 1719 1720 `b1Op` is only relevant when the binary (b1) type is given to 1721 `multiplicandDataType`. It specifies how the multiply-and-acumulate is 1722 performed and is either `xor_popc` or `and_poc`. The default is `xor_popc`. 1723 1724 `intOverflowBehavior` is only relevant when the `multiplicandType` attribute 1725 is one of `u8, s8, u4, s4`, this attribute describes how overflow is handled 1726 in the accumulator. When the attribute is `satfinite`, the accumulator values 1727 are clamped in the int32 range on overflow. This is the default behavior. 1728 Alternatively, accumulator behavior `wrapped` can also be specified, in 1729 which case overflow wraps from one end of the range to the other. 1730 1731 `layoutA` and `layoutB` are required and should generally be set to 1732 `#nvvm.mma_layout<row>` and `#nvvm.mma_layout<col>` respectively, but other 1733 combinations are possible for certain layouts according to the table below. 1734 1735 ``` 1736 | A/B Type | Shape | ALayout | BLayout | A Type | B Type | C/D Type | 1737 |----------|-----------|---------|---------|----------|----------|-------------------| 1738 | f64 | .m8n8k4 | row | col | 1x f64 | 1x f64 | 2x f64 | 1739 | f16 | .m8n8k4 | row/col | row/col | 2x f16x2 | 2x f16x2 | 4x f16x2 or 8xf32 | 1740 | | .m16n8k8 | row | col | 2x f16x2 | 1x f16x2 | 2x f16x2 or 4 f32 | 1741 | | .m16n8k16 | row | col | 4x f16x2 | 2x f16x2 | 2x f16x2 or 4 f32 | 1742 | bf16 | .m16n8k8 | row | col | 2x i32 | 1x i32 | 4x f32 | 1743 | | .m16n8k16 | row | col | 4x i32 | 2x i32 | 4x f32 | 1744 | tf32 | .m16n8k4 | row | col | 2x i32 | 1x i32 | 4x f32 | 1745 | | .m16n8k8 | row | col | 4x i32 | 2x i32 | 2x f16x2 or 4 f32 | 1746 | u8/s8 | .m8n8k16 | row | col | 1x i32 | 1x i32 | 2x i32 | 1747 | | .m16n8k16 | row | col | 2x i32 | 1x i32 | 4x i32 | 1748 | | .m16n8k32 | row | col | 4x i32 | 2x i32 | 4x i32 | 1749 | u4/s4 | .m8n8k32 | row | col | 1x i32 | 1x i32 | 2x i32 | 1750 | | m16n8k32 | row | col | 2x i32 | 1x i32 | 4x i32 | 1751 | | m16n8k64 | row | col | 4x i32 | 2x i32 | 4x i32 | 1752 | b1 | m8n8k128 | row | col | 1x i32 | 1x i32 | 2x i32 | 1753 | | m16n8k128 | row | col | 2x i32 | 1x i32 | 4x i32 | 1754 ``` 1755 1756 1757 Example: 1758 ```mlir 1759 1760 %128 = nvvm.mma.sync A[%120, %121, %122, %123] 1761 B[%124, %125] 1762 C[%126, %127] 1763 {layoutA = #nvvm.mma_layout<row>, 1764 layoutB = #nvvm.mma_layout<col>, 1765 shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} 1766 : (vector<2xf16>, vector<2xf16>, vector<2xf16>) 1767 -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> 1768 ``` 1769 }]; 1770 1771 let results = (outs LLVM_AnyStruct:$res); 1772 let arguments = (ins NVVM_MMAShapeAttr:$shape, 1773 OptionalAttr<MMAB1OpAttr>:$b1Op, 1774 OptionalAttr<MMAIntOverflowAttr>:$intOverflowBehavior, 1775 MMALayoutAttr:$layoutA, 1776 MMALayoutAttr:$layoutB, 1777 OptionalAttr<MMATypesAttr>:$multiplicandAPtxType, 1778 OptionalAttr<MMATypesAttr>:$multiplicandBPtxType, 1779 Variadic<LLVM_Type>:$operandA, 1780 Variadic<LLVM_Type>:$operandB, 1781 Variadic<LLVM_Type>:$operandC); 1782 1783 let extraClassDeclaration = !strconcat([{ 1784 static llvm::Intrinsic::ID getIntrinsicID( 1785 int64_t m, int64_t n, uint64_t k, 1786 std::optional<MMAB1Op> b1Op, 1787 std::optional<MMAIntOverflow> sat, 1788 mlir::NVVM::MMALayout layoutAEnum, mlir::NVVM::MMALayout layoutBEnum, 1789 mlir::NVVM::MMATypes eltypeAEnum, mlir::NVVM::MMATypes eltypeBEnum, 1790 mlir::NVVM::MMATypes eltypeCEnum, mlir::NVVM::MMATypes eltypeDEnum) { 1791 llvm::StringRef layoutA = stringifyEnum(layoutAEnum); 1792 llvm::StringRef layoutB = stringifyEnum(layoutBEnum); 1793 llvm::StringRef eltypeA = stringifyEnum(eltypeAEnum); 1794 llvm::StringRef eltypeB = stringifyEnum(eltypeBEnum); 1795 llvm::StringRef eltypeC = stringifyEnum(eltypeCEnum); 1796 llvm::StringRef eltypeD = stringifyEnum(eltypeDEnum); 1797 }], 1798 MMA_SYNC_INTR<>.id, [{ 1799 return 0; 1800 } 1801 1802 static std::optional<mlir::NVVM::MMATypes> inferOperandMMAType(Type operandElType, 1803 bool isAccumulator); 1804 1805 MMATypes accumPtxType(); 1806 MMATypes resultPtxType(); 1807 }]); 1808 1809 let builders = [ 1810 OpBuilder<(ins "Type":$resultType, "ValueRange":$operandA, 1811 "ValueRange":$operandB, "ValueRange":$operandC, 1812 "ArrayRef<int64_t>":$shape, "std::optional<MMAB1Op>":$b1Op, 1813 "std::optional<MMAIntOverflow>":$intOverflow, 1814 "std::optional<std::array<MMATypes, 2>>":$multiplicandPtxTypes, 1815 "std::optional<std::array<MMALayout, 2>>":$multiplicandLayouts)> 1816 ]; 1817 1818 string llvmBuilder = [{ 1819 auto operands = moduleTranslation.lookupValues(opInst.getOperands()); 1820 auto intId = mlir::NVVM::MmaOp::getIntrinsicID( 1821 $shape.getM(), $shape.getN(), $shape.getK(), 1822 $b1Op, $intOverflowBehavior, 1823 $layoutA, $layoutB, 1824 *$multiplicandAPtxType, 1825 *$multiplicandBPtxType, 1826 op.accumPtxType(), 1827 op.resultPtxType()); 1828 1829 $res = createIntrinsicCall( 1830 builder, intId, operands); 1831 }]; 1832 1833 let hasCustomAssemblyFormat = 1; 1834 let hasVerifier = 1; 1835} 1836 1837//===----------------------------------------------------------------------===// 1838// NVVM TMA Ops 1839//===----------------------------------------------------------------------===// 1840 1841def NVVM_CpAsyncBulkCommitGroupOp : NVVM_Op<"cp.async.bulk.commit.group">, 1842 Arguments<(ins )> { 1843 let assemblyFormat = "attr-dict"; 1844 let description = [{ 1845 This Op commits all prior initiated but uncommitted cp.async.bulk 1846 instructions into a cp.async.bulk-group. 1847 1848 [For more information, see PTX ISA] 1849 (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group) 1850 }]; 1851 1852 string llvmBuilder = [{ 1853 createIntrinsicCall(builder, llvm::Intrinsic::nvvm_cp_async_bulk_commit_group); 1854 }]; 1855} 1856 1857def NVVM_CpAsyncBulkWaitGroupOp : NVVM_Op<"cp.async.bulk.wait_group">, 1858 Arguments<(ins 1859 ConfinedAttr<I32Attr, [IntMinValue<0>]>:$group, 1860 OptionalAttr<UnitAttr>:$read)> { 1861 let assemblyFormat = "$group attr-dict"; 1862 let description = [{ 1863 Op waits for completion of the most recent bulk async-groups. 1864 1865 The `$group` operand tells waiting has to be done until for $group or fewer 1866 of the most recent bulk async-groups. If `$group` is 0, the op wait until 1867 all the most recent bulk async-groups have completed. 1868 1869 The `$read` indicates that the waiting has to be done until all the bulk 1870 async operations in the specified bulk async-group have completed reading 1871 from their source locations. 1872 1873 [For more information, see PTX ISA] 1874 (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group) 1875 }]; 1876 1877 string llvmBuilder = [{ 1878 auto intId = op.getRead() ? 1879 llvm::Intrinsic::nvvm_cp_async_bulk_wait_group_read : 1880 llvm::Intrinsic::nvvm_cp_async_bulk_wait_group; 1881 createIntrinsicCall(builder, intId, builder.getInt32($group)); 1882 }]; 1883} 1884 1885def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : 1886 NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global", 1887 [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>, 1888 AttrSizedOperandSegments]>, 1889 Arguments<(ins LLVM_PointerShared:$dstMem, 1890 LLVM_AnyPointer:$tmaDescriptor, 1891 Variadic<I32>:$coordinates, 1892 LLVM_PointerShared:$mbar, 1893 Variadic<I16>:$im2colOffsets, 1894 Optional<I16>:$multicastMask, 1895 Optional<I64>:$l2CacheHint, 1896 PtxPredicate:$predicate)> { 1897 let description = [{ 1898 Initiates an asynchronous copy operation on the tensor data from global 1899 memory to shared memory. 1900 1901 The Op operates has two load modes: 1902 1) Tiled Mode: It's the default mode. The source multi-dimensional tensor 1903 layout is preserved at the destination. 1904 1905 2) Im2col Mode: This mode is used when `im2colOffsets` operands are present. 1906 the elements in the Bounding Box of the source tensor are rearranged into 1907 columns at the destination. In this mode, the tensor has to be at least 1908 3-dimensional. 1909 1910 The `multicastMask` operand is optional. When it is present, the Op copies 1911 data from global memory to shared memory of multiple CTAs in the cluster. 1912 Operand `multicastMask` specifies the destination CTAs in the cluster such 1913 that each bit position in the 16-bit `multicastMask` operand corresponds to 1914 the `nvvm.read.ptx.sreg.ctaid` of the destination CTA. 1915 1916 The `l2CacheHint` operand is optional, and it is used to specify cache 1917 eviction policy that may be used during the memory access. 1918 1919 [For more information, see PTX ISA] 1920 (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor) 1921 }]; 1922 1923 let assemblyFormat = [{ 1924 $dstMem `,` 1925 $tmaDescriptor `,` 1926 $mbar `,` 1927 `box` `[`$coordinates `]` 1928 (`im2col` `[` $im2colOffsets^ `]` )? 1929 (`multicast_mask` `=` $multicastMask^ )? 1930 (`l2_cache_hint` `=` $l2CacheHint^ )? 1931 (`predicate` `=` $predicate^)? 1932 attr-dict `:` type($dstMem) `,` type($tmaDescriptor) 1933 }]; 1934 1935 let extraClassDefinition = [{ 1936 std::string $cppClass::getPtx() { 1937 int im2colDim = getIm2colOffsets().size(); 1938 int dim = getCoordinates().size(); 1939 std::string ptx = "cp.async.bulk.tensor."; 1940 ptx += std::to_string(dim) + "d."; 1941 ptx += "shared::cluster.global.mbarrier::complete_tx::bytes"; 1942 if(im2colDim) ptx += ".im2col"; 1943 if(getMulticastMask()) ptx += ".multicast::cluster"; 1944 if(getL2CacheHint()) ptx += ".L2::cache_hint"; 1945 1946 auto preg = [](int r) { return "%" + std::to_string(r); }; 1947 1948 // Build Registers 1949 ptx += " [%0], [%1, {"; 1950 int r = 2; 1951 for(int i = 0; i < dim; i++) ptx += preg(r+i) + ","; 1952 ptx.pop_back(); r += dim; 1953 ptx += "} ], [%" + std::to_string(r++) + "]"; 1954 if(im2colDim) { 1955 ptx += ",{"; 1956 for(int i = 0; i < im2colDim; i++) ptx += preg(r+i) + ","; 1957 ptx.pop_back(); r += im2colDim; 1958 ptx += "}"; 1959 } 1960 if(getMulticastMask()) ptx += ", " + preg(r++); 1961 if(getL2CacheHint()) ptx += ", " + preg(r++); 1962 ptx += ";"; 1963 return ptx; 1964 } 1965 }]; 1966 let hasVerifier = 1; 1967} 1968 1969def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp : 1970 NVVM_Op<"cp.async.bulk.tensor.global.shared.cta", 1971 [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>, 1972 AttrSizedOperandSegments]>, 1973 Arguments<(ins LLVM_AnyPointer:$tmaDescriptor, 1974 LLVM_PointerShared:$srcMem, 1975 Variadic<I32>:$coordinates, 1976 PtxPredicate:$predicate)> { 1977 let assemblyFormat = [{ 1978 $tmaDescriptor `,` 1979 $srcMem `,` 1980 `box` `[`$coordinates `]` 1981 (`,` `predicate` `=` $predicate^)? 1982 attr-dict `:` type(operands) 1983 }]; 1984 let extraClassDefinition = [{ 1985 std::string $cppClass::getPtx() { 1986 int dim = getCoordinates().size(); 1987 std::string ptx = "cp.async.bulk.tensor."; 1988 ptx += std::to_string(dim) + "d."; 1989 ptx += "global.shared::cta.bulk_group"; 1990 if(dim == 1) ptx += " [%0, {%2} ], [%1];"; 1991 if(dim == 2) ptx += " [%0, {%2, %3} ], [%1];"; 1992 if(dim == 3) ptx += " [%0, {%2, %3, %4} ], [%1];"; 1993 if(dim == 4) ptx += " [%0, {%2, %3, %4, %5} ], [%1];"; 1994 if(dim == 5) ptx += " [%0, {%2, %3, %4, %5, %6} ], [%1];"; 1995 return ptx; 1996 } 1997 }]; 1998 let hasVerifier = 1; 1999} 2000 2001def NVVM_PrefetchTensorMapOp : NVVM_Op<"prefetch.tensormap", 2002 [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>, 2003 Arguments<(ins LLVM_AnyPointer:$tmaDescriptor, PtxPredicate:$predicate)> { 2004 let assemblyFormat = "$tmaDescriptor (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)"; 2005 let extraClassDefinition = [{ 2006 std::string $cppClass::getPtx() { 2007 return std::string("prefetch.tensormap [%0];"); 2008 } 2009 }]; 2010} 2011 2012def NVVM_CpAsyncBulkTensorPrefetchOp : 2013 NVVM_Op<"cp.async.bulk.tensor.prefetch", [AttrSizedOperandSegments]> { 2014 let arguments = (ins 2015 LLVM_AnyPointer:$tmaDescriptor, 2016 Variadic<I32>:$coordinates, 2017 Variadic<I16>:$im2colOffsets, 2018 Optional<I64>:$l2CacheHint); 2019 2020 let description = [{ 2021 Initiates an asynchronous prefetch operation on the tensor data from global 2022 memory to L2 cache. 2023 2024 The Op has two modes: 2025 1) Tiled Mode: It's the default mode. The source multi-dimensional tensor 2026 layout is preserved at the destination. 2027 2028 2) Im2col Mode: This mode is used when `im2colOffsets` operands are present. 2029 the elements in the Bounding Box of the source tensor are rearranged into 2030 columns at the destination. In this mode, the tensor has to be at least 2031 3-dimensional. 2032 2033 The `l2CacheHint` operand is optional, and it is used to specify cache 2034 eviction policy that may be used during the memory access. 2035 2036 [For more information, see PTX ISA] 2037 (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-prefetch-tensor) 2038 }]; 2039 2040 let assemblyFormat = [{ 2041 $tmaDescriptor `,` 2042 `box` `[`$coordinates `]` 2043 (`im2col` `[` $im2colOffsets^ `]` )? 2044 (`l2_cache_hint` `=` $l2CacheHint^ )? 2045 attr-dict `:` type($tmaDescriptor) 2046 }]; 2047 2048 let extraClassDeclaration = [{ 2049 static llvm::Intrinsic::ID getIntrinsicID(int tensorDims, bool isIm2Col); 2050 }]; 2051 2052 let hasVerifier = 1; 2053 2054 string llvmBuilder = [{ 2055 // Arguments to the intrinsic: 2056 // tmaDesc, tensorDims, im2colOffsets 2057 // cache_hint(if applicable) and flag(boolean) 2058 llvm::SmallVector<llvm::Value *> translatedOperands; 2059 translatedOperands.push_back($tmaDescriptor); 2060 2061 for (auto v : op.getCoordinates()) 2062 translatedOperands.push_back(moduleTranslation.lookupValue(v)); 2063 2064 for (auto v : op.getIm2colOffsets()) 2065 translatedOperands.push_back(moduleTranslation.lookupValue(v)); 2066 2067 llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext(); 2068 auto *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0); 2069 2070 bool isCacheHint = op.getL2CacheHint() ? true : false; 2071 translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Unused); 2072 translatedOperands.push_back(builder.getInt1(isCacheHint)); 2073 2074 auto intId = NVVM::CpAsyncBulkTensorPrefetchOp::getIntrinsicID( 2075 op.getCoordinates().size(), op.getIm2colOffsets().size() > 0); 2076 createIntrinsicCall(builder, intId, translatedOperands); 2077 }]; 2078} 2079 2080// List of modes supported for TMA Store and Reduction Ops 2081def TMAStoreModeTile : I32EnumAttrCase<"TILE", 0, "tile">; 2082def TMAStoreModeIm2Col : I32EnumAttrCase<"IM2COL", 1, "im2col">; 2083 2084def TMAStoreMode : I32EnumAttr<"TMAStoreMode", "NVVM TMA Store Mode", 2085 [TMAStoreModeTile, TMAStoreModeIm2Col]> { 2086 let genSpecializedAttr = 0; 2087 let cppNamespace = "::mlir::NVVM"; 2088} 2089def TMAStoreModeAttr : EnumAttr<NVVM_Dialect, TMAStoreMode, "tma_store_mode"> { 2090 let assemblyFormat = "`<` $value `>`"; 2091} 2092 2093// List of Reduction Ops supported with TMA Store 2094def TMAReduxKindAdd : I32EnumAttrCase<"ADD", 0, "add">; 2095def TMAReduxKindMin : I32EnumAttrCase<"MIN", 1, "min">; 2096def TMAReduxKindMax : I32EnumAttrCase<"MAX", 2, "max">; 2097def TMAReduxKindInc : I32EnumAttrCase<"INC", 3, "inc">; 2098def TMAReduxKindDec : I32EnumAttrCase<"DEC", 4, "dec">; 2099def TMAReduxKindAnd : I32EnumAttrCase<"AND", 5, "and">; 2100def TMAReduxKindOr : I32EnumAttrCase<"OR", 6, "or">; 2101def TMAReduxKindXor : I32EnumAttrCase<"XOR", 7, "xor">; 2102 2103def TMAReduxKind : I32EnumAttr<"TMAReduxKind", "NVVM TMA redux kind", 2104 [TMAReduxKindAdd, TMAReduxKindMax, TMAReduxKindMin, 2105 TMAReduxKindInc, TMAReduxKindDec, TMAReduxKindAnd, 2106 TMAReduxKindOr, TMAReduxKindXor]> { 2107 let genSpecializedAttr = 0; 2108 let cppNamespace = "::mlir::NVVM"; 2109} 2110def TMAReduxKindAttr : EnumAttr<NVVM_Dialect, TMAReduxKind, "tma_redux_kind"> { 2111 let assemblyFormat = "`<` $value `>`"; 2112} 2113 2114def NVVM_CpAsyncBulkTensorReduceOp : 2115 NVVM_Op<"cp.async.bulk.tensor.reduce", [AttrSizedOperandSegments]> { 2116 let arguments = (ins 2117 LLVM_AnyPointer:$tmaDescriptor, 2118 LLVM_PointerShared:$srcMem, 2119 TMAReduxKindAttr:$redKind, 2120 DefaultValuedAttr<TMAStoreModeAttr, "TMAStoreMode::TILE">:$mode, 2121 Variadic<I32>:$coordinates, 2122 Optional<I64>:$l2CacheHint); 2123 2124 let description = [{ 2125 Initiates an asynchronous reduction operation of tensor data in 2126 global memory with tensor data in shared memory. 2127 2128 The `mode` attribute indicates whether the copy mode is tile or im2col. 2129 The `redOp` attribute specifies the reduction operations applied. 2130 The supported reduction operations are: 2131 {add, min, max, inc, dec, and, or, xor} 2132 2133 The `l2CacheHint` operand is optional, and it is used to specify cache 2134 eviction policy that may be used during the memory access. 2135 2136 [For more information, see PTX ISA] 2137 (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-reduce-async-bulk-tensor) 2138 }]; 2139 2140 let assemblyFormat = [{ 2141 $tmaDescriptor `,` 2142 $srcMem `,` 2143 `box` `[`$coordinates `]` 2144 (`l2_cache_hint` `=` $l2CacheHint^ )? 2145 attr-dict `:` type($tmaDescriptor) `,` type($srcMem) 2146 }]; 2147 2148 let extraClassDeclaration = [{ 2149 static llvm::Intrinsic::ID getIntrinsicID(int tensorDims, 2150 NVVM::TMAReduxKind kind, 2151 bool isIm2Col); 2152 }]; 2153 2154 let hasVerifier = 1; 2155 2156 string llvmBuilder = [{ 2157 // Arguments to the intrinsic: 2158 // shared_mem_ptr, tmaDesc, tensorDims 2159 // cache_hint(if applicable) and flag(boolean) 2160 llvm::SmallVector<llvm::Value *> translatedOperands; 2161 translatedOperands.push_back($srcMem); 2162 translatedOperands.push_back($tmaDescriptor); 2163 2164 for (auto v : op.getCoordinates()) 2165 translatedOperands.push_back(moduleTranslation.lookupValue(v)); 2166 2167 llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext(); 2168 auto *i64Undef = llvm::UndefValue::get(llvm::IntegerType::get(ctx, 64)); 2169 2170 bool isCacheHint = op.getL2CacheHint() ? true : false; 2171 translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Undef); 2172 translatedOperands.push_back(builder.getInt1(isCacheHint)); 2173 2174 auto intId = NVVM::CpAsyncBulkTensorReduceOp::getIntrinsicID( 2175 op.getCoordinates().size(), $redKind, 2176 (op.getMode() == NVVM::TMAStoreMode::IM2COL)); 2177 createIntrinsicCall(builder, intId, translatedOperands); 2178 }]; 2179} 2180 2181def NVVM_CpAsyncBulkGlobalToSharedClusterOp : 2182 NVVM_Op<"cp.async.bulk.shared.cluster.global", [AttrSizedOperandSegments]> { 2183 let summary = "Async bulk copy from global memory to Shared cluster memory"; 2184 let description = [{ 2185 Initiates an asynchronous copy operation from global memory to cluster's 2186 shared memory. 2187 2188 The `multicastMask` operand is optional. When it is present, the Op copies 2189 data from global memory to shared memory of multiple CTAs in the cluster. 2190 Operand `multicastMask` specifies the destination CTAs in the cluster such 2191 that each bit position in the 16-bit `multicastMask` operand corresponds to 2192 the `nvvm.read.ptx.sreg.ctaid` of the destination CTA. 2193 2194 The `l2CacheHint` operand is optional, and it is used to specify cache 2195 eviction policy that may be used during the memory access. 2196 [For more information, see PTX ISA] 2197 (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk) 2198 }]; 2199 2200 let arguments = (ins 2201 LLVM_PointerShared:$dstMem, 2202 LLVM_PointerGlobal:$srcMem, 2203 LLVM_PointerShared:$mbar, 2204 I32:$size, 2205 Optional<I16>:$multicastMask, 2206 Optional<I64>:$l2CacheHint); 2207 2208 let assemblyFormat = [{ 2209 $dstMem `,` $srcMem `,` $mbar `,` $size 2210 (`multicast_mask` `=` $multicastMask^ )? 2211 (`l2_cache_hint` `=` $l2CacheHint^ )? 2212 attr-dict `:` type($dstMem) `,` type($srcMem) 2213 }]; 2214 2215 string llvmBuilder = [{ 2216 // Arguments to the intrinsic: 2217 // dst, mbar, src, size 2218 // multicast_mask, cache_hint, 2219 // flag for multicast_mask, 2220 // flag for cache_hint 2221 llvm::SmallVector<llvm::Value *> translatedOperands; 2222 translatedOperands.push_back($dstMem); 2223 translatedOperands.push_back($mbar); 2224 translatedOperands.push_back($srcMem); 2225 translatedOperands.push_back($size); 2226 2227 // Multicast, if available 2228 llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext(); 2229 auto *i16Unused = llvm::ConstantInt::get(llvm::Type::getInt16Ty(ctx), 0); 2230 bool isMulticast = op.getMulticastMask() ? true : false; 2231 translatedOperands.push_back(isMulticast ? $multicastMask : i16Unused); 2232 2233 // Cachehint, if available 2234 auto *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0); 2235 bool isCacheHint = op.getL2CacheHint() ? true : false; 2236 translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Unused); 2237 2238 // Flag arguments for multicast and cachehint 2239 translatedOperands.push_back(builder.getInt1(isMulticast)); 2240 translatedOperands.push_back(builder.getInt1(isCacheHint)); 2241 2242 createIntrinsicCall(builder, 2243 llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster, translatedOperands); 2244 }]; 2245} 2246 2247def NVVM_CpAsyncBulkSharedCTAToSharedClusterOp : 2248 NVVM_Op<"cp.async.bulk.shared.cluster.shared.cta"> { 2249 let summary = "Async bulk copy from Shared CTA memory to Shared cluster memory"; 2250 let description = [{ 2251 Initiates an asynchronous copy operation from Shared CTA memory to Shared 2252 cluster memory. 2253 2254 [For more information, see PTX ISA] 2255 (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk) 2256 }]; 2257 2258 let arguments = (ins 2259 LLVM_PointerShared:$dstMem, 2260 LLVM_PointerShared:$srcMem, 2261 LLVM_PointerShared:$mbar, 2262 I32:$size); 2263 2264 let assemblyFormat = [{ 2265 $dstMem `,` $srcMem `,` $mbar `,` $size 2266 attr-dict `:` type($dstMem) `,` type($srcMem) 2267 }]; 2268 2269 string llvmBuilder = [{ 2270 createIntrinsicCall(builder, 2271 llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_cluster, 2272 {$dstMem, $mbar, $srcMem, $size}); 2273 }]; 2274} 2275 2276def NVVM_CpAsyncBulkSharedCTAToGlobalOp : 2277 NVVM_Op<"cp.async.bulk.global.shared.cta"> { 2278 let summary = "Async bulk copy from Shared CTA memory to Global memory"; 2279 let description = [{ 2280 Initiates an asynchronous copy operation from Shared CTA memory to 2281 global memory. 2282 2283 The `l2CacheHint` operand is optional, and it is used to specify cache 2284 eviction policy that may be used during the memory access. 2285 [For more information, see PTX ISA] 2286 (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk) 2287 }]; 2288 2289 let arguments = (ins 2290 LLVM_PointerGlobal:$dstMem, 2291 LLVM_PointerShared:$srcMem, 2292 I32:$size, 2293 Optional<I64>:$l2CacheHint); 2294 2295 let assemblyFormat = [{ 2296 $dstMem `,` $srcMem `,` $size 2297 (`l2_cache_hint` `=` $l2CacheHint^ )? 2298 attr-dict `:` type($dstMem) `,` type($srcMem) 2299 }]; 2300 2301 string llvmBuilder = [{ 2302 // Arguments to the intrinsic: 2303 // dst, src, size, cache_hint, 2304 // Flag for cache_hint 2305 // 2306 llvm::SmallVector<llvm::Value *> translatedOperands; 2307 translatedOperands.push_back($dstMem); 2308 translatedOperands.push_back($srcMem); 2309 translatedOperands.push_back($size); 2310 2311 // Cachehint, if available 2312 llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext(); 2313 auto *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0); 2314 bool isCacheHint = op.getL2CacheHint() ? true : false; 2315 translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Unused); 2316 2317 // Flag argument for cachehint 2318 translatedOperands.push_back(builder.getInt1(isCacheHint)); 2319 2320 createIntrinsicCall(builder, 2321 llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global, translatedOperands); 2322 }]; 2323} 2324 2325//===----------------------------------------------------------------------===// 2326// NVVM Wgmma Ops 2327//===----------------------------------------------------------------------===// 2328 2329def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned"> { 2330 let arguments = (ins); 2331 let description = [{ 2332 Enforce an ordering of register accesses between warpgroup level matrix 2333 multiplication and other operations. 2334 2335 [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-fence) 2336 }]; 2337 let assemblyFormat = "attr-dict"; 2338 string llvmBuilder = [{ 2339 createIntrinsicCall(builder, llvm::Intrinsic::nvvm_wgmma_fence_sync_aligned); 2340 }]; 2341} 2342 2343def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned">, 2344 Arguments<(ins )> { 2345 let assemblyFormat = "attr-dict"; 2346 let description = [{ 2347 Commits all prior uncommitted warpgroup level matrix multiplication operations. 2348 2349 [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-commit-group) 2350 }]; 2351 string llvmBuilder = [{ 2352 createIntrinsicCall(builder, llvm::Intrinsic::nvvm_wgmma_commit_group_sync_aligned); 2353 }]; 2354} 2355 2356def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned">{ 2357 let arguments = (ins I64Attr:$group); 2358 let assemblyFormat = "attr-dict $group"; 2359 let description = [{ 2360 Signal the completion of a preceding warpgroup operation. 2361 2362 [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-wait-group) 2363 }]; 2364 string llvmBuilder = [{ 2365 createIntrinsicCall(builder, llvm::Intrinsic::nvvm_wgmma_wait_group_sync_aligned, builder.getInt64($group)); 2366 }]; 2367} 2368 2369/// Enum attribute type for the negating of input operands 2370def WGMMAScaleInNeg : I32EnumAttrCase<"neg", -1>; 2371def WGMMAScaleInOne : I32EnumAttrCase<"one", 1>; 2372def WGMMAScaleIn : I32EnumAttr<"WGMMAScaleIn", "WGMMA overflow options", 2373 [WGMMAScaleInOne, WGMMAScaleInNeg]> { 2374 let genSpecializedAttr = 0; 2375 let cppNamespace = "::mlir::NVVM"; 2376} 2377def WGMMAScaleInAttr : EnumAttr<NVVM_Dialect, WGMMAScaleIn, "wgmma_scale_in"> { 2378 let assemblyFormat = "`<` $value `>`"; 2379} 2380 2381/// Enum attribute type for the output operand 2382def WGMMAScaleOutZero : I32EnumAttrCase<"zero", 0>; 2383def WGMMAScaleOutOne : I32EnumAttrCase<"one", 1>; 2384def WGMMAScaleOut : I32EnumAttr<"WGMMAScaleOut", "WGMMA input predicate", 2385 [WGMMAScaleOutZero, WGMMAScaleOutOne]> { 2386 let genSpecializedAttr = 0; 2387 let cppNamespace = "::mlir::NVVM"; 2388} 2389def WGMMAScaleOutAttr : EnumAttr<NVVM_Dialect, WGMMAScaleOut, "wgmma_scale_out"> { 2390 let assemblyFormat = "`<` $value `>`"; 2391} 2392 2393/// Enum attribute of the different PTX element types used for WGMMA operands. 2394def WGMMATypeF16 : I32EnumAttrCase<"f16", 0>; 2395def WGMMATypeTF32 : I32EnumAttrCase<"tf32", 1>; 2396def WGMMATypeU8 : I32EnumAttrCase<"u8", 2>; 2397def WGMMATypeS8 : I32EnumAttrCase<"s8", 3>; 2398def WGMMATypeB1 : I32EnumAttrCase<"b1", 4>; 2399def WGMMATypeBF16 : I32EnumAttrCase<"bf16", 5>; 2400def WGMMATypeF8E4M3 : I32EnumAttrCase<"e4m3", 6>; 2401def WGMMATypeF8E5M2 : I32EnumAttrCase<"e5m2", 7>; 2402def WGMMATypeF32 : I32EnumAttrCase<"f32", 8>; 2403def WGMMATypeS32 : I32EnumAttrCase<"s32", 9>; 2404 2405def WGMMATypes : I32EnumAttr<"WGMMATypes", "NVVM WGMMA types", 2406 [WGMMATypeF16, WGMMATypeTF32, 2407 WGMMATypeU8, WGMMATypeS8, 2408 WGMMATypeB1, WGMMATypeBF16, WGMMATypeF8E4M3, 2409 WGMMATypeF8E5M2, WGMMATypeF32, WGMMATypeS32]> { 2410 let genSpecializedAttr = 0; 2411 let cppNamespace = "::mlir::NVVM"; 2412} 2413def WGMMATypesAttr : EnumAttr<NVVM_Dialect, WGMMATypes, "wgmma_type"> { 2414 let assemblyFormat = "`<` $value `>`"; 2415} 2416 2417 2418def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async", 2419 [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>, 2420 PredOpTrait<"input struct and result struct must be the same type", 2421 TCresIsSameAsOpBase<0, 0>>,]> 2422{ 2423 let results = (outs LLVM_AnyStruct:$results); 2424 let arguments = (ins 2425 LLVM_AnyStruct:$inouts, 2426 I64:$descriptorA, 2427 I64:$descriptorB, 2428 NVVM_MMAShapeAttr:$shape, 2429 WGMMATypesAttr:$typeA, 2430 WGMMATypesAttr:$typeB, 2431 WGMMATypesAttr:$typeD, 2432 WGMMAScaleOutAttr:$scaleD, 2433 WGMMAScaleInAttr:$scaleA, 2434 WGMMAScaleInAttr:$scaleB, 2435 MMALayoutAttr:$layoutA, 2436 MMALayoutAttr:$layoutB, 2437 OptionalAttr<MMAIntOverflowAttr>:$satfinite 2438 ); 2439 2440 let assemblyFormat = [{ 2441 $descriptorA `,` $descriptorB `,` $inouts `,` $shape `,` 2442 `D` `[` $typeD `,` $scaleD (`,` $satfinite^)? `]` `,` 2443 `A` `[` $typeA `,` $scaleA `,` $layoutA `]` `,` 2444 `B` `[` $typeB `,` $scaleB `,` $layoutB `]` 2445 attr-dict `:` 2446 type($inouts) `->` type($results) 2447 }]; 2448 2449 let description = [{ 2450 The warpgroup (128 threads) level matrix multiply and accumulate operation 2451 has either of the following forms, where matrix D is called accumulator: 2452 D = A * B + D 2453 D = A * B, where the input from accumulator D is disabled. 2454 2455 Supported shapes: 2456 ``` 2457 |--------------|--------------|------------|--------------|---------------| 2458 | | | | |f16+=e4m3*e4m3 | 2459 | | | | |f16+=e5m2*e5m2 | 2460 |f32+=tf32*tf32|f16+=f16 *f16 | s32+=s8*s8 |s32 += b1 * b1|f16+=e5m2*e4m3 | 2461 | |f32+=f16 *f16 | s32+=u8*u8 | |f16+=e4m3*e5m2 | 2462 | |f32+=bf16*bf16| s32+=u8*u8 | |f16+=e4m3*e5m2 | 2463 | |f32+=bf16*bf16| s32+=s8*u8 | |f32+=e4m3*e4m3 | 2464 | | | s32+=u8*s8 | |f32+=e5m2*e5m2 | 2465 | | | | |f32+=e4m3*e5m2 | 2466 | | | | |f32+=e4m3*e5m2 | 2467 |--------------|--------------|------------|--------------|---------------| 2468 | .m64n8k8 | .m64n8k16 | .m64n8k32 | .m64n8k256 | .m64n8k32 | 2469 | .m64n16k8 | .m64n16k16 | .m64n16k32 | .m64n16k256 | .m64n16k32 | 2470 | .m64n24k8 | .m64n24k16 | .m64n24k32 | .m64n24k256 | .m64n24k32 | 2471 | .m64n32k8 | .m64n32k16 | .m64n32k32 | .m64n32k256 | .m64n32k32 | 2472 | .m64n40k8 | .m64n40k16 | .m64n48k32 | .m64n48k256 | .m64n40k32 | 2473 | .m64n48k8 | .m64n48k16 | .m64n64k32 | .m64n64k256 | .m64n48k32 | 2474 | .m64n56k8 | .m64n56k16 | .m64n80k32 | .m64n80k256 | .m64n56k32 | 2475 | .m64n64k8 | .m64n64k16 | .m64n96k32 | .m64n96k256 | .m64n64k32 | 2476 | .m64n72k8 | .m64n72k16 | .m64n112k32| .m64n112k256 | .m64n72k32 | 2477 | .m64n80k8 | .m64n80k16 | .m64n128k32| .m64n128k256 | .m64n80k32 | 2478 | .m64n88k8 | .m64n88k16 | .m64n144k32| .m64n144k256 | .m64n88k32 | 2479 | .m64n96k8 | .m64n96k16 | .m64n160k32| .m64n160k256 | .m64n96k32 | 2480 | .m64n104k8 | .m64n104k16 | .m64n176k32| .m64n176k256 | .m64n104k32 | 2481 | .m64n112k8 | .m64n112k16 | .m64n192k32| .m64n192k256 | .m64n112k32 | 2482 | .m64n120k8 | .m64n120k16 | .m64n208k32| .m64n208k256 | .m64n120k32 | 2483 | .m64n128k8 | .m64n128k16 | .m64n224k32| .m64n224k256 | .m64n128k32 | 2484 | .m64n136k8 | .m64n136k16 | .m64n240k32| .m64n240k256 | .m64n136k32 | 2485 | .m64n144k8 | .m64n144k16 | .m64n256k32| .m64n256k256 | .m64n144k32 | 2486 | .m64n152k8 | .m64n152k16 | | | .m64n152k32 | 2487 | .m64n160k8 | .m64n160k16 | | | .m64n160k32 | 2488 | .m64n168k8 | .m64n168k16 | | | .m64n168k32 | 2489 | .m64n176k8 | .m64n176k16 | | | .m64n176k32 | 2490 | .m64n184k8 | .m64n184k16 | | | .m64n184k32 | 2491 | .m64n192k8 | .m64n192k16 | | | .m64n192k32 | 2492 | .m64n200k8 | .m64n200k16 | | | .m64n200k32 | 2493 | .m64n208k8 | .m64n208k16 | | | .m64n208k32 | 2494 | .m64n216k8 | .m64n216k16 | | | .m64n216k32 | 2495 | .m64n224k8 | .m64n224k16 | | | .m64n224k32 | 2496 | .m64n232k8 | .m64n232k16 | | | .m64n232k32 | 2497 | .m64n240k8 | .m64n240k16 | | | .m64n240k32 | 2498 | .m64n248k8 | .m64n248k16 | | | .m64n248k32 | 2499 | .m64n256k8 | .m64n256k16 | | | .m64n256k32 | 2500 |--------------|--------------|------------|--------------|---------------| 2501 ``` 2502 2503 2504 [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions) 2505 }]; 2506 2507 let hasVerifier = 1; 2508 2509 let extraClassDeclaration = [{ 2510 void getAsmValues(RewriterBase &rewriter, 2511 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &asmValues); 2512 }]; 2513} 2514 2515//===----------------------------------------------------------------------===// 2516// NVVM Griddepcontrol Ops 2517//===----------------------------------------------------------------------===// 2518 2519def NVVM_GriddepcontrolWaitOp : NVVM_IntrOp<"griddepcontrol.wait", [], 0> { 2520 let assemblyFormat = "attr-dict"; 2521 2522 let description = [{ 2523 Causes the executing thread to wait until all prerequisite grids in flight 2524 have completed and all the memory operations from the prerequisite grids 2525 are performed and made visible to the current grid. 2526 [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-griddepcontrol) 2527 }]; 2528} 2529 2530def NVVM_GriddepcontrolLaunchDependentsOp 2531 : NVVM_IntrOp<"griddepcontrol.launch.dependents", [], 0> { 2532 let assemblyFormat = "attr-dict"; 2533 2534 let description = [{ 2535 Signals that specific dependents the runtime system designated to react to 2536 this instruction can be scheduled as soon as all other CTAs in the grid 2537 issue the same instruction or have completed. 2538 [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-griddepcontrol) 2539 }]; 2540} 2541 2542def NVVM_Exit : NVVM_Op<"exit"> { 2543 let summary = "Exit Op"; 2544 let description = [{ 2545 Ends execution of a thread. 2546 [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-exit) 2547 }]; 2548 string llvmBuilder = [{ 2549 createIntrinsicCall(builder, llvm::Intrinsic::nvvm_exit); 2550 }]; 2551 2552 let assemblyFormat = "attr-dict"; 2553} 2554 2555 2556//===----------------------------------------------------------------------===// 2557// NVVM breakpoint Op 2558//===----------------------------------------------------------------------===// 2559 2560def NVVM_Breakpoint : NVVM_Op<"breakpoint"> { 2561 let summary = "Breakpoint Op"; 2562 let description = [{ 2563 Breakpoint suspends execution of the program for debugging. 2564 [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#miscellaneous-instructions-brkpt) 2565 }]; 2566 string llvmBuilder = [{ 2567 createIntrinsicCall(builder, llvm::Intrinsic::debugtrap); 2568 }]; 2569 2570 let assemblyFormat = "attr-dict"; 2571} 2572 2573//===----------------------------------------------------------------------===// 2574// NVVM target attribute. 2575//===----------------------------------------------------------------------===// 2576 2577def NVVM_TargettAttr : NVVM_Attr<"NVVMTarget", "target"> { 2578 let description = [{ 2579 GPU target attribute for controlling compilation of NVIDIA targets. All 2580 parameters decay into default values if not present. 2581 2582 Examples: 2583 2584 1. Target with default values. 2585 ``` 2586 gpu.module @mymodule [#nvvm.target] attributes {...} { 2587 ... 2588 } 2589 ``` 2590 2591 2. Target with `sm_90` chip and fast math. 2592 ``` 2593 gpu.module @mymodule [#nvvm.target<chip = "sm_90", flags = {fast}>] { 2594 ... 2595 } 2596 ``` 2597 }]; 2598 let parameters = (ins 2599 DefaultValuedParameter<"int", "2", "Optimization level to apply.">:$O, 2600 StringRefParameter<"Target triple.", "\"nvptx64-nvidia-cuda\"">:$triple, 2601 StringRefParameter<"Target chip.", "\"sm_50\"">:$chip, 2602 StringRefParameter<"Target chip features.", "\"+ptx60\"">:$features, 2603 OptionalParameter<"DictionaryAttr", "Target specific flags.">:$flags, 2604 OptionalParameter<"ArrayAttr", "Files to link to the LLVM module.">:$link 2605 ); 2606 let assemblyFormat = [{ 2607 (`<` struct($O, $triple, $chip, $features, $flags, $link)^ `>`)? 2608 }]; 2609 let builders = [ 2610 AttrBuilder<(ins CArg<"int", "2">:$optLevel, 2611 CArg<"StringRef", "\"nvptx64-nvidia-cuda\"">:$triple, 2612 CArg<"StringRef", "\"sm_50\"">:$chip, 2613 CArg<"StringRef", "\"+ptx60\"">:$features, 2614 CArg<"DictionaryAttr", "nullptr">:$targetFlags, 2615 CArg<"ArrayAttr", "nullptr">:$linkFiles), [{ 2616 return Base::get($_ctxt, optLevel, triple, chip, features, targetFlags, linkFiles); 2617 }]> 2618 ]; 2619 let skipDefaultBuilders = 1; 2620 let genVerifyDecl = 1; 2621 let extraClassDeclaration = [{ 2622 bool hasFlag(StringRef flag) const; 2623 bool hasFastMath() const; 2624 bool hasFtz() const; 2625 }]; 2626 let extraClassDefinition = [{ 2627 bool $cppClass::hasFlag(StringRef flag) const { 2628 if (DictionaryAttr flags = getFlags()) 2629 return flags.get(flag) != nullptr; 2630 return false; 2631 } 2632 bool $cppClass::hasFastMath() const { 2633 return hasFlag("fast"); 2634 } 2635 bool $cppClass::hasFtz() const { 2636 return hasFlag("ftz"); 2637 } 2638 }]; 2639} 2640 2641#endif // NVVMIR_OPS 2642