1//===-- NVVMOps.td - NVVM IR dialect op definition file ----*- tablegen -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This is the NVVM IR operation definition file.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef NVVMIR_OPS
14#define NVVMIR_OPS
15
16include "mlir/IR/EnumAttr.td"
17include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
18include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
19include "mlir/Interfaces/SideEffectInterfaces.td"
20include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
21include "mlir/Interfaces/InferIntRangeInterface.td"
22
23def LLVM_PointerGeneric : LLVM_PointerInAddressSpace<0>;
24def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
25def LLVM_PointerShared : LLVM_PointerInAddressSpace<3>;
26
27//===----------------------------------------------------------------------===//
28// NVVM dialect definitions
29//===----------------------------------------------------------------------===//
30
31def NVVM_Dialect : Dialect {
32  let name = "nvvm";
33  let cppNamespace = "::mlir::NVVM";
34  let dependentDialects = ["LLVM::LLVMDialect"];
35  let hasOperationAttrVerify = 1;
36
37  let extraClassDeclaration = [{
38    /// Get the name of the attribute used to annotate external kernel
39    /// functions.
40    static StringRef getKernelFuncAttrName() { return "nvvm.kernel"; }
41    /// Get the name of the attribute used to annotate max threads required
42    /// per CTA for kernel functions.
43    static StringRef getMaxntidAttrName() { return "nvvm.maxntid"; }
    /// Get the metadata names for each dimension
45    static StringRef getMaxntidXName() { return "maxntidx"; }
46    static StringRef getMaxntidYName() { return "maxntidy"; }
47    static StringRef getMaxntidZName() { return "maxntidz"; }
48
49    /// Get the name of the attribute used to annotate exact threads required
50    /// per CTA for kernel functions.
51    static StringRef getReqntidAttrName() { return "nvvm.reqntid"; }
    /// Get the metadata names for each dimension
53    static StringRef getReqntidXName() { return "reqntidx"; }
54    static StringRef getReqntidYName() { return "reqntidy"; }
55    static StringRef getReqntidZName() { return "reqntidz"; }
56
57    /// Get the name of the attribute used to annotate exact CTAs required
58    /// per cluster for kernel functions.
59    static StringRef getClusterDimAttrName() { return "nvvm.cluster_dim"; }
    /// Get the metadata names for each dimension
61    static StringRef getClusterDimXName() { return "cluster_dim_x"; }
62    static StringRef getClusterDimYName() { return "cluster_dim_y"; }
63    static StringRef getClusterDimZName() { return "cluster_dim_z"; }
64
65    /// Get the name of the attribute used to annotate maximum number of
66    /// CTAs per cluster for kernel functions.
67    static StringRef getClusterMaxBlocksAttrName() {  return "nvvm.cluster_max_blocks"; }
68
69    /// Get the name of the attribute used to annotate min CTA required
70    /// per SM for kernel functions.
71    static StringRef getMinctasmAttrName() { return "nvvm.minctasm"; }
72
73    /// Get the name of the attribute used to annotate max number of
74    /// registers that can be allocated per thread.
75    static StringRef getMaxnregAttrName() { return "nvvm.maxnreg"; }
76
77    /// Get the name of the attribute used to annotate kernel arguments that
78    /// are grid constants.
79    static StringRef getGridConstantAttrName() { return "nvvm.grid_constant"; }
80
81    /// Verify an attribute from this dialect on the argument at 'argIndex' for
82    /// the region at 'regionIndex' on the given operation. Returns failure if
83    /// the verification failed, success otherwise. This hook may optionally be
84    /// invoked from any operation containing a region.
85    LogicalResult verifyRegionArgAttribute(Operation *op,
86                                           unsigned regionIndex,
87                                           unsigned argIndex,
88                                           NamedAttribute argAttr) override;
89  }];
90
91  let useDefaultAttributePrinterParser = 1;
92}
93
94//===----------------------------------------------------------------------===//
95// NVVM op definitions
96//===----------------------------------------------------------------------===//
97
98class NVVM_Op<string mnemonic, list<Trait> traits = []> :
99  LLVM_OpBase<NVVM_Dialect, mnemonic, traits> {
100}
101
102/// Base class that defines BasicPtxBuilderOpInterface.
103class NVVM_PTXBuilder_Op<string mnemonic,
104  list<Trait> traits = [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]> :
105  LLVM_OpBase<NVVM_Dialect, mnemonic, traits> {
106}
107
108//===----------------------------------------------------------------------===//
109// NVVM attribute definitions
110//===----------------------------------------------------------------------===//
111
112class NVVM_Attr<string attrName, string attrMnemonic, list<Trait> traits = []>
113    : AttrDef<NVVM_Dialect, attrName, traits> {
114  let mnemonic = attrMnemonic;
115}
116
117//===----------------------------------------------------------------------===//
118// NVVM intrinsic operations
119//===----------------------------------------------------------------------===//
120
121class NVVM_IntrOp<string mnem, list<Trait> traits = [],
122                  int numResults = 0>
123  : LLVM_IntrOpBase<NVVM_Dialect, mnem, "nvvm_" # !subst(".", "_", mnem),
124                    /*list<int> overloadedResults=*/[],
125                    /*list<int> overloadedOperands=*/[],
126                    traits, numResults>;
127
128//===----------------------------------------------------------------------===//
129// NVVM special register op definitions
130//===----------------------------------------------------------------------===//
131
132class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
133  NVVM_IntrOp<mnemonic, !listconcat(traits, [Pure]), 1> {
134  let arguments = (ins);
135  let assemblyFormat = "attr-dict `:` type($res)";
136}
137
138class NVVM_SpecialRangeableRegisterOp<string mnemonic> :
139  NVVM_SpecialRegisterOp<mnemonic, [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
140  let arguments = (ins OptionalAttr<LLVM_ConstantRangeAttr>:$range);
141  let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)";
142  let llvmBuilder = baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda;
143  let mlirBuilder = baseMlirBuilder # importRangeRetAttrCode # baseMlirBuilderCoda;
144
145  // Backwards-compatibility builder for an unspecified range.
146  let builders = [
147    OpBuilder<(ins "Type":$resultType), [{
148      build($_builder, $_state, resultType, ::mlir::LLVM::ConstantRangeAttr{});
149    }]>
150  ];
151
152  // Define this method for the InferIntRangeInterface.
153  let extraClassDefinition = [{
154    // Infer the result ranges based on the range attribute.
155    void $cppClass::inferResultRanges(
156        ArrayRef<::mlir::ConstantIntRanges> argRanges,
157        SetIntRangeFn setResultRanges) {
158        nvvmInferResultRanges(getOperation(), getResult(), argRanges, setResultRanges);
159    }
160  }];
161
162}
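
// Ops derived from the register classes above read the named special register
// and return its value, e.g. (an illustrative sketch; the SSA value name is a
// placeholder):
//   %lane = nvvm.read.ptx.sreg.laneid : i32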
163
164//===----------------------------------------------------------------------===//
165// Lane, Warp, SM, Grid index and range
166def NVVM_LaneIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.laneid">;
167def NVVM_WarpSizeOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpsize">;
168def NVVM_WarpIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpid">;
169def NVVM_WarpDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nwarpid">;
170def NVVM_SmIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.smid">;
171def NVVM_SmDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nsmid">;
172def NVVM_GridIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.gridid">;
173
174//===----------------------------------------------------------------------===//
175// Lane Mask Comparison Ops
176def NVVM_LaneMaskEqOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.eq">;
177def NVVM_LaneMaskLeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.le">;
178def NVVM_LaneMaskLtOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.lt">;
179def NVVM_LaneMaskGeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.ge">;
180def NVVM_LaneMaskGtOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.gt">;
181
182//===----------------------------------------------------------------------===//
183// Thread index and range
184def NVVM_ThreadIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.tid.x">;
185def NVVM_ThreadIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.tid.y">;
186def NVVM_ThreadIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.tid.z">;
187def NVVM_BlockDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ntid.x">;
188def NVVM_BlockDimYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ntid.y">;
189def NVVM_BlockDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ntid.z">;
190
191//===----------------------------------------------------------------------===//
192// Block index and range
193def NVVM_BlockIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.x">;
194def NVVM_BlockIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.y">;
195def NVVM_BlockIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.z">;
196def NVVM_GridDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.x">;
197def NVVM_GridDimYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.y">;
198def NVVM_GridDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.z">;
199
200//===----------------------------------------------------------------------===//
201// CTA Cluster index and range
202def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x">;
203def NVVM_ClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.y">;
204def NVVM_ClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.z">;
205def NVVM_ClusterDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.x">;
206def NVVM_ClusterDimYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.y">;
207def NVVM_ClusterDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.z">;
208
209
210//===----------------------------------------------------------------------===//
211// CTA index and range within Cluster
212def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x">;
213def NVVM_BlockInClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y">;
214def NVVM_BlockInClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z">;
215def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x">;
216def NVVM_ClusterDimBlocksYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.y">;
217def NVVM_ClusterDimBlocksZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.z">;
218
219//===----------------------------------------------------------------------===//
220// CTA index and across Cluster dimensions
221def NVVM_ClusterId : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctarank">;
222def NVVM_ClusterDim : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctarank">;
223
224//===----------------------------------------------------------------------===//
225// Clock registers
226def NVVM_ClockOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.clock">;
227def NVVM_Clock64Op : NVVM_SpecialRegisterOp<"read.ptx.sreg.clock64">;
228def NVVM_GlobalTimerOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.globaltimer">;
229
230//===----------------------------------------------------------------------===//
231// envreg registers
232foreach index = !range(0, 32) in {
233  def NVVM_EnvReg # index # Op : NVVM_SpecialRegisterOp<"read.ptx.sreg.envreg" # index>;
234}
235
236//===----------------------------------------------------------------------===//
237// NVVM approximate op definitions
238//===----------------------------------------------------------------------===//
239
240def NVVM_RcpApproxFtzF32Op : NVVM_IntrOp<"rcp.approx.ftz.f", [Pure], 1> {
241  let arguments = (ins F32:$arg);
242  let results = (outs F32:$res);
243  let assemblyFormat = "$arg attr-dict `:` type($res)";
244}
245
246//===----------------------------------------------------------------------===//
247// NVVM redux op definitions
248//===----------------------------------------------------------------------===//
249
250def ReduxKindNone : I32EnumAttrCase<"NONE", 0, "none">;
251def ReduxKindAdd  : I32EnumAttrCase<"ADD", 1, "add">;
252def ReduxKindAnd  : I32EnumAttrCase<"AND", 2, "and">;
253def ReduxKindMax  : I32EnumAttrCase<"MAX", 3, "max">;
254def ReduxKindMin  : I32EnumAttrCase<"MIN", 4, "min">;
255def ReduxKindOr   : I32EnumAttrCase<"OR", 5, "or">;
256def ReduxKindUmax : I32EnumAttrCase<"UMAX", 6, "umax">;
257def ReduxKindUmin : I32EnumAttrCase<"UMIN", 7, "umin">;
258def ReduxKindXor  : I32EnumAttrCase<"XOR", 8, "xor">;
259
260/// Enum attribute of the different kinds.
261def ReduxKind : I32EnumAttr<"ReduxKind", "NVVM redux kind",
262  [ReduxKindAdd, ReduxKindAnd, ReduxKindMax, ReduxKindMin, ReduxKindOr,
263    ReduxKindUmax, ReduxKindUmin, ReduxKindXor]> {
264  let genSpecializedAttr = 0;
265  let cppNamespace = "::mlir::NVVM";
266}
267
268def ReduxKindAttr : EnumAttr<NVVM_Dialect, ReduxKind, "redux_kind">;
269
270def NVVM_ReduxOp :
271  NVVM_Op<"redux.sync">,
272  Results<(outs LLVM_Type:$res)>,
273  Arguments<(ins LLVM_Type:$val,
274                 ReduxKindAttr:$kind,
275                 I32:$mask_and_clamp)> {
276  string llvmBuilder = [{
277      auto intId = getReduxIntrinsicId($_resultType, $kind);
278      $res = createIntrinsicCall(builder, intId, {$val, $mask_and_clamp});
279  }];
280  let assemblyFormat = [{
281    $kind $val `,` $mask_and_clamp  attr-dict `:` type($val) `->` type($res)
282   }];
283}
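
// Example usage (an illustrative sketch; SSA value names are placeholders):
//   %res = nvvm.redux.sync add %val, %mask_and_clamp : i32 -> i32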
284
285//===----------------------------------------------------------------------===//
286// NVVM Split arrive/wait barrier
287//===----------------------------------------------------------------------===//
288
289/// mbarrier.init instruction with generic pointer type
290def NVVM_MBarrierInitOp : NVVM_PTXBuilder_Op<"mbarrier.init">,
291  Arguments<(ins LLVM_AnyPointer:$addr, I32:$count, PtxPredicate:$predicate)> {
292  string llvmBuilder = [{
293      createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init, {$addr, $count});
294  }];
295  let assemblyFormat = "$addr `,` $count (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
296  let extraClassDeclaration = [{
    bool hasIntrinsic() { return !getPredicate(); }
298  }];
299  let extraClassDefinition = [{
300    std::string $cppClass::getPtx() { return std::string("mbarrier.init.b64 [%0], %1;"); }
301  }];
302}
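
// Example usage (an illustrative sketch; SSA value names are placeholders):
//   nvvm.mbarrier.init %barrier, %count : !llvm.ptr, i32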
303
304/// mbarrier.init instruction with shared pointer type
305def NVVM_MBarrierInitSharedOp : NVVM_PTXBuilder_Op<"mbarrier.init.shared">,
306  Arguments<(ins LLVM_PointerShared:$addr, I32:$count, PtxPredicate:$predicate)> {
307  string llvmBuilder = [{
308      createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init_shared, {$addr, $count});
309  }];
310  let assemblyFormat = "$addr `,` $count (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
311  let extraClassDeclaration = "bool hasIntrinsic() { return !getPredicate(); }";
312  let extraClassDefinition = [{
313    std::string $cppClass::getPtx() { return std::string("mbarrier.init.shared.b64 [%0], %1;"); }
314  }];
315}
316
317def NVVM_MBarrierInvalOp : NVVM_Op<"mbarrier.inval">,
318  Arguments<(ins LLVM_AnyPointer:$addr)> {
319  string llvmBuilder = [{
320      createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_inval, {$addr});
321  }];
322  let assemblyFormat = "$addr attr-dict `:` type(operands)";
323}
324
325def NVVM_MBarrierInvalSharedOp : NVVM_Op<"mbarrier.inval.shared">,
326  Arguments<(ins LLVM_PointerShared:$addr)> {
327  string llvmBuilder = [{
328      createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_inval_shared, {$addr});
329  }];
330  let assemblyFormat = "$addr attr-dict `:` type(operands)";
331}
332
333def NVVM_MBarrierArriveOp : NVVM_Op<"mbarrier.arrive">,
334  Results<(outs LLVM_Type:$res)>,
335  Arguments<(ins LLVM_AnyPointer:$addr)> {
336  string llvmBuilder = [{
337      $res = createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_arrive, {$addr});
338  }];
339  let assemblyFormat = "$addr attr-dict `:` type($addr) `->` type($res)";
340}
341
342def NVVM_MBarrierArriveSharedOp : NVVM_Op<"mbarrier.arrive.shared">,
343  Results<(outs LLVM_Type:$res)>,
344  Arguments<(ins LLVM_PointerShared:$addr)> {
345  string llvmBuilder = [{
346      $res = createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_arrive_shared, {$addr});
347  }];
348  let assemblyFormat = "$addr attr-dict `:` qualified(type($addr)) `->` type($res)";
349}
350
351def NVVM_MBarrierArriveNocompleteOp : NVVM_Op<"mbarrier.arrive.nocomplete">,
352  Results<(outs LLVM_Type:$res)>,
353  Arguments<(ins LLVM_AnyPointer:$addr, I32:$count)> {
354  string llvmBuilder = [{
355      $res = createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete, {$addr, $count});
356  }];
357  let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands) `->` type($res)";
358}
359
360def NVVM_MBarrierArriveNocompleteSharedOp : NVVM_Op<"mbarrier.arrive.nocomplete.shared">,
361  Results<(outs LLVM_Type:$res)>,
362  Arguments<(ins LLVM_PointerShared:$addr, I32:$count)> {
363  string llvmBuilder = [{
364      $res = createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete_shared, {$addr, $count});
365  }];
366  let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands) `->` type($res)";
367}
368
369def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx">,
370  Arguments<(ins LLVM_AnyPointer:$addr, I32:$txcount, PtxPredicate:$predicate)> {
371  let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
372  let extraClassDefinition = [{
373    std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;"); }
374  }];
375}
376
377def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx.shared">,
378  Arguments<(ins LLVM_PointerShared:$addr, I32:$txcount, PtxPredicate:$predicate)> {
379  let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
380  let extraClassDefinition = [{
381    std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"); }
382  }];
383}
384
385def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity">,
386  Arguments<(ins LLVM_AnyPointer:$addr, I32:$phase, I32:$ticks)> {
387  let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
388  let extraClassDefinition = [{
389    std::string $cppClass::getPtx() {
390      return std::string(
391        "{\n\t"
392        ".reg .pred       P1; \n\t"
393        "LAB_WAIT: \n\t"
394        "mbarrier.try_wait.parity.b64 P1, [%0], %1, %2; \n\t"
395        "@P1 bra.uni DONE; \n\t"
396        "bra.uni     LAB_WAIT; \n\t"
397        "DONE: \n\t"
398        "}"
399      );
400    }
401  }];
402}
403
404def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity.shared">,
405  Arguments<(ins LLVM_PointerShared:$addr, I32:$phase, I32:$ticks)> {
406  let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
407  let extraClassDefinition = [{
408    std::string $cppClass::getPtx() {
409      return std::string(
410        "{\n\t"
411        ".reg .pred       P1; \n\t"
412        "LAB_WAIT: \n\t"
413        "mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t"
414        "@P1 bra.uni DONE; \n\t"
415        "bra.uni     LAB_WAIT; \n\t"
416        "DONE: \n\t"
417        "}"
418      );
419    }
420  }];
421}
422
423def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait">,
424  Results<(outs LLVM_Type:$res)>,
425  Arguments<(ins LLVM_AnyPointer:$addr, LLVM_Type:$state)> {
426  string llvmBuilder = [{
427      $res = createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_test_wait, {$addr, $state});
428  }];
429  let assemblyFormat = "$addr `,` $state attr-dict `:` type(operands) `->` type($res)";
430}
431
432def NVVM_MBarrierTestWaitSharedOp : NVVM_Op<"mbarrier.test.wait.shared">,
433  Results<(outs LLVM_Type:$res)>,
434  Arguments<(ins LLVM_PointerShared:$addr, LLVM_Type:$state)> {
435  string llvmBuilder = [{
436      $res = createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_test_wait_shared, {$addr, $state});
437  }];
438  let assemblyFormat = "$addr `,` $state attr-dict `:` type(operands) `->` type($res)";
439}
440
441//===----------------------------------------------------------------------===//
442// NVVM synchronization op definitions
443//===----------------------------------------------------------------------===//
444
445def NVVM_Barrier0Op : NVVM_IntrOp<"barrier0"> {
446  let assemblyFormat = "attr-dict";
447}
448
449def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> {
450  let arguments = (ins
451    Optional<I32>:$barrierId,
452    Optional<I32>:$numberOfThreads);
453  string llvmBuilder = [{
454    if ($numberOfThreads && $barrierId) {
455      createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier,
456                {$barrierId, $numberOfThreads});
457    } else if($barrierId) {
458      createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier_n,
459                {$barrierId});
460    } else {
461      createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier0);
462    }
463  }];
464  let hasVerifier = 1;
465  let assemblyFormat = "(`id` `=` $barrierId^)? (`number_of_threads` `=` $numberOfThreads^)? attr-dict";
466}
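
// Example usage (an illustrative sketch; SSA value names are placeholders):
//   nvvm.barrier
//   nvvm.barrier id = %id number_of_threads = %n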
467
468def NVVM_BarrierArriveOp : NVVM_PTXBuilder_Op<"barrier.arrive">
469{
470  let arguments = (ins Optional<I32>:$barrierId, I32:$numberOfThreads);
471
472  let description = [{
    A thread that executes this op announces its arrival at the barrier with
    the given id and continues its execution.

    The default barrier id is 0, which is similar to the `nvvm.barrier` Op.
    When `barrierId` is not present, the default barrier id is used.
478
479    [For more information, see PTX ISA]
480    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar)
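
    Example (an illustrative sketch; the SSA value name is a placeholder):

    ```mlir
    nvvm.barrier.arrive number_of_threads = %n
    ```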
481  }];
482
483  let assemblyFormat = "(`id` `=` $barrierId^)? `number_of_threads` `=` $numberOfThreads attr-dict";
484
485  let extraClassDefinition = [{
486    std::string $cppClass::getPtx() {
487      std::string ptx = "bar.arrive ";
488      if (getBarrierId()) { ptx += "%0, %1;"; }
489      else { ptx += "0, %0;"; }
490      return ptx;
491    }
492  }];
493}
494
495def NVVM_ClusterArriveOp : NVVM_Op<"cluster.arrive"> {
496  let arguments = (ins OptionalAttr<UnitAttr>:$aligned);
497
498  let summary = "Cluster Barrier Arrive Op";
499  let description = [{
    The `cluster.arrive` op can be used by the threads within the cluster for
    synchronization and communication. It marks the warps' arrival at the
    barrier without causing the executing thread to wait for other
    participating threads.
503
504    The `aligned` attribute, when provided, generates the .aligned version of the PTX instruction.
505
506    [For more information, see PTX ISA]
507    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-barrier-cluster)
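
    Example (an illustrative sketch; the `aligned` unit attribute is shown in
    the attribute dictionary):

    ```mlir
    nvvm.cluster.arrive {aligned}
    ```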
508  }];
509
510  string llvmBuilder = [{
511      if ($aligned)
512        createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier_cluster_arrive_aligned);
513      else
514        createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier_cluster_arrive);
515  }];
516  let assemblyFormat = "attr-dict";
517}
518
519def NVVM_ClusterArriveRelaxedOp : NVVM_Op<"cluster.arrive.relaxed"> {
520  let arguments = (ins OptionalAttr<UnitAttr>:$aligned);
521
522  let summary = "Cluster Barrier Relaxed Arrive Op";
523  let description = [{
    The `cluster.arrive.relaxed` op can be used by the threads within the
    cluster for synchronization and communication. It marks the warps' arrival
    at the barrier without causing the executing thread to wait for other
    participating threads.
527
528    The `aligned` attribute, when provided, generates the .aligned version of the PTX instruction.
529    The .relaxed qualifier on `cluster.arrive` specifies that there are no memory
530    ordering and visibility guarantees provided for the memory accesses performed prior to
531    `cluster.arrive`.
532
533    [For more information, see PTX ISA]
534    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-barrier-cluster)
535  }];
536
537  string llvmBuilder = [{
538      if ($aligned)
539        createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier_cluster_arrive_relaxed_aligned);
540      else
541        createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier_cluster_arrive_relaxed);
542  }];
543  let assemblyFormat = "attr-dict";
544}
545
546def NVVM_ClusterWaitOp : NVVM_Op<"cluster.wait"> {
547  let arguments = (ins OptionalAttr<UnitAttr>:$aligned);
548
549  let summary = "Cluster Barrier Wait Op";
550  let description = [{
551    The `cluster.wait` causes the executing thread to wait for all non-exited threads
552    of the cluster to perform `cluster.arrive`. The `aligned` attribute, when provided,
553    generates the .aligned version of the PTX instruction.
554
555    [For more information, see PTX ISA]
556    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-barrier-cluster)
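
    Example (an illustrative sketch):

    ```mlir
    nvvm.cluster.wait {aligned}
    ```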
557  }];
558
559  string llvmBuilder = [{
560      if ($aligned)
561        createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier_cluster_wait_aligned);
562      else
563        createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier_cluster_wait);
564  }];
565  let assemblyFormat = "attr-dict";
566}
567
568def NVVM_FenceScClusterOp : NVVM_Op<"fence.sc.cluster"> {
569  string llvmBuilder = [{
570      createIntrinsicCall(builder, llvm::Intrinsic::nvvm_fence_sc_cluster);
571  }];
572  let assemblyFormat = "attr-dict";
573}
574
575def SharedSpaceCTA : I32EnumAttrCase<"shared_cta", 0, "cta">;
576def SharedSpaceCluster   : I32EnumAttrCase<"shared_cluster", 1, "cluster">;
577def SharedSpace : I32EnumAttr<"SharedSpace", "Shared memory space",
578  [SharedSpaceCTA, SharedSpaceCluster]> {
579  let genSpecializedAttr = 0;
580  let cppNamespace = "::mlir::NVVM";
581}
582def SharedSpaceAttr : EnumAttr<NVVM_Dialect, SharedSpace, "shared_space"> {
583  let assemblyFormat = "`<` $value `>`";
584}
585
586def ProxyAlias : I32EnumAttrCase<"alias", 0, "alias">;
587def ProxyAsync   : I32EnumAttrCase<"async", 1, "async">;
588def ProxyAsyncGlobal   : I32EnumAttrCase<"async_global", 2, "async.global">;
589def ProxyAsyncShared   : I32EnumAttrCase<"async_shared", 3, "async.shared">;
590def ProxyTensorMap : I32EnumAttrCase<"TENSORMAP", 4, "tensormap">;
591def ProxyGeneric : I32EnumAttrCase<"GENERIC", 5, "generic">;
592def ProxyKind : I32EnumAttr<"ProxyKind", "Proxy kind",
593  [ProxyAlias, ProxyAsync, ProxyAsyncGlobal, ProxyAsyncShared, ProxyTensorMap, ProxyGeneric]> {
594  let genSpecializedAttr = 0;
595  let cppNamespace = "::mlir::NVVM";
596}
597
598def ProxyKindAttr : EnumAttr<NVVM_Dialect, ProxyKind, "proxy_kind"> {
599  let assemblyFormat = "`<` $value `>`";
600}
601
602def NVVM_FenceProxyOp : NVVM_PTXBuilder_Op<"fence.proxy">,
603  Arguments<(ins ProxyKindAttr:$kind,
604                 OptionalAttr<SharedSpaceAttr>:$space)> {
605  let description = [{
606    Fence operation with proxy to establish an ordering between memory accesses
607    that may happen through different proxies.
608    [For more information, see PTX ISA]
609    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-membar)
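
    Example (an illustrative sketch; the proxy kind shown is only one of the
    supported values):

    ```mlir
    nvvm.fence.proxy {kind = #nvvm.proxy_kind<alias>}
    ```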
610  }];
611
612  let assemblyFormat = "attr-dict";
613  let extraClassDefinition = [{
614    std::string $cppClass::getPtx() {
615      std::string ptx = "fence.proxy.";
616      ptx += stringifyProxyKind(getKind());
617      if(getKind() == NVVM::ProxyKind::async_shared)
618        { ptx += "::"; ptx += stringifySharedSpace(getSpace().value()); }
619      ptx += ";";
620      return ptx;
621    }
622  }];
623  let hasVerifier = 1;
624}
625
626// Attrs describing the scope of the Memory Operation
627def MemScopeKindCTA      : I32EnumAttrCase<"CTA", 0, "cta">;
628def MemScopeKindCluster  : I32EnumAttrCase<"CLUSTER", 1, "cluster">;
629def MemScopeKindGPU      : I32EnumAttrCase<"GPU", 2, "gpu">;
630def MemScopeKindSYS      : I32EnumAttrCase<"SYS", 3, "sys">;
631
632def MemScopeKind : I32EnumAttr<"MemScopeKind", "NVVM Memory Scope kind",
633  [MemScopeKindCTA, MemScopeKindCluster, MemScopeKindGPU, MemScopeKindSYS]> {
634  let genSpecializedAttr = 0;
635  let cppNamespace = "::mlir::NVVM";
636}
637def MemScopeKindAttr : EnumAttr<NVVM_Dialect, MemScopeKind, "mem_scope"> {
638  let assemblyFormat = "`<` $value `>`";
639}
640
641def NVVM_FenceProxyAcquireOp : NVVM_Op<"fence.proxy.acquire">,
642      Arguments<(ins MemScopeKindAttr:$scope, LLVM_PointerGeneric:$addr, I32:$size,
643                     DefaultValuedAttr<ProxyKindAttr,
644                                       "ProxyKind::GENERIC">:$fromProxy,
645                     DefaultValuedAttr<ProxyKindAttr,
646                                       "ProxyKind::TENSORMAP">:$toProxy)> {
647  let summary = "Uni-directional proxy fence operation with acquire semantics";
648  let description = [{
    `fence.proxy.acquire` is a uni-directional fence used to establish ordering
    between a prior memory access performed via the generic proxy and a
    subsequent memory access performed via the tensormap proxy.

    The address operand `addr` and the operand `size` together specify the
    memory range `[addr, addr+size)` over which the ordering guarantees on the
    memory accesses across the proxies are provided. The only supported value
    for the `size` operand is 128, and it must be an immediate. Generic
    addressing is used unconditionally, and the address specified by the
    operand `addr` must fall within the `.global` state space; otherwise, the
    behavior is undefined.
659    [For more information, see PTX ISA]
660    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-membar)
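
    Example (an illustrative sketch; SSA value names are placeholders):

    ```mlir
    nvvm.fence.proxy.acquire #nvvm.mem_scope<cta> %addr, %size
    ```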
661  }];
662
663  let assemblyFormat = "$scope $addr `,` $size (`from_proxy` `=` $fromProxy^)? (`to_proxy` `=` $toProxy^)? attr-dict";
664  let llvmBuilder = [{
665    createIntrinsicCall(
666        builder,
667        getUnidirectionalFenceProxyID($fromProxy, $toProxy, $scope, false),
668        {$addr, $size});
669  }];
670
671  let hasVerifier = 1;
672}
673
674def NVVM_FenceProxyReleaseOp : NVVM_Op<"fence.proxy.release">,
675      Arguments<(ins MemScopeKindAttr:$scope,
676                     DefaultValuedAttr<ProxyKindAttr,
677                                       "ProxyKind::GENERIC">:$fromProxy,
678                     DefaultValuedAttr<ProxyKindAttr,
679                                       "ProxyKind::TENSORMAP">:$toProxy)> {
680  let summary = "Uni-directional proxy fence operation with release semantics";
681  let description = [{
    `fence.proxy.release` is a uni-directional fence used to establish ordering
    between a prior memory access performed via the generic proxy and a
    subsequent memory access performed via the tensormap proxy. The
    `fence.proxy.release` operation can form a release sequence that
    synchronizes with an acquire sequence containing the `fence.proxy.acquire`
    proxy fence operation.
687    [For more information, see PTX ISA]
688    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-membar)
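
    Example (an illustrative sketch):

    ```mlir
    nvvm.fence.proxy.release #nvvm.mem_scope<cta>
    ```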
689  }];
690
691  let assemblyFormat = "$scope (`from_proxy` `=` $fromProxy^)? (`to_proxy` `=` $toProxy^)? attr-dict";
692  let llvmBuilder = [{
693    createIntrinsicCall(builder, getUnidirectionalFenceProxyID(
694                                     $fromProxy, $toProxy, $scope, true));
695  }];
696
697  let hasVerifier = 1;
698}
699
700def SetMaxRegisterActionIncrease : I32EnumAttrCase<"increase", 0>;
701def SetMaxRegisterActionDecrease   : I32EnumAttrCase<"decrease", 1>;
702def SetMaxRegisterAction : I32EnumAttr<"SetMaxRegisterAction", "NVVM set max register action",
703  [SetMaxRegisterActionDecrease, SetMaxRegisterActionIncrease]> {
704  let genSpecializedAttr = 0;
705  let cppNamespace = "::mlir::NVVM";
706}
707def SetMaxRegisterActionAttr : EnumAttr<NVVM_Dialect, SetMaxRegisterAction, "action">;
708
709def NVVM_SetMaxRegisterOp : NVVM_Op<"setmaxregister"> {
710  let arguments = (ins I32Attr:$regCount, SetMaxRegisterActionAttr:$action);
711  let assemblyFormat = "$action $regCount attr-dict";
712  let hasVerifier = 1;
713  string llvmBuilder = [{
714    auto intId = (op.getAction() == NVVM::SetMaxRegisterAction::increase) ?
715      llvm::Intrinsic::nvvm_setmaxnreg_inc_sync_aligned_u32 :
716      llvm::Intrinsic::nvvm_setmaxnreg_dec_sync_aligned_u32;
717
718    createIntrinsicCall(builder, intId, builder.getInt32($regCount));
719  }];
720}
721
722def NVVM_FenceMbarrierInitOp : NVVM_PTXBuilder_Op<"fence.mbarrier.init"> {
  let arguments = (ins);
  let description = [{
    Fence operation that applies to the prior `nvvm.mbarrier.init` operations.
726    [For more information, see PTX ISA]
727    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-membar)
728  }];
729
730  let assemblyFormat = "attr-dict";
731  let extraClassDefinition = [{
732    std::string $cppClass::getPtx() {
733      return std::string("fence.mbarrier_init.release.cluster;");
734    }
735  }];
736}
737
738def ShflKindBfly : I32EnumAttrCase<"bfly", 0>;
739def ShflKindUp   : I32EnumAttrCase<"up", 1>;
740def ShflKindDown : I32EnumAttrCase<"down", 2>;
741def ShflKindIdx  : I32EnumAttrCase<"idx", 3>;
742
743/// Enum attribute of the different shuffle kinds.
744def ShflKind : I32EnumAttr<"ShflKind", "NVVM shuffle kind",
745  [ShflKindBfly, ShflKindUp, ShflKindDown, ShflKindIdx]> {
746  let genSpecializedAttr = 0;
747  let cppNamespace = "::mlir::NVVM";
748}
749def ShflKindAttr : EnumAttr<NVVM_Dialect, ShflKind, "shfl_kind">;
750
751def NVVM_ShflOp :
752  NVVM_Op<"shfl.sync">,
753  Results<(outs LLVM_Type:$res)>,
754  Arguments<(ins I32:$thread_mask,
755                 LLVM_Type:$val,
756                 I32:$offset,
757                 I32:$mask_and_clamp,
758                 ShflKindAttr:$kind,
759                 OptionalAttr<UnitAttr>:$return_value_and_is_valid)> {
760  let summary = "NVVM Dialect Op for shfl.sync";
761  let description = [{
762    The `shfl.sync` Op implements data shuffle within threads of a warp.
763    The `thread_mask` denotes the threads participating in the Op where
764    the bit position corresponds to a particular thread’s laneid.
765    The `offset` specifies a source lane or source lane offset
766    (depending on `kind`). The `val` is the input value to be copied from
767    the source. The `mask_and_clamp` contains two packed values specifying
768    a mask for logically splitting warps into sub-segments and an upper bound
769    for clamping the source lane index.
    [For more information, refer to the PTX ISA]
771    (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-shfl-sync)
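
    Example (an illustrative sketch; SSA value names are placeholders):

    ```mlir
    %r = nvvm.shfl.sync bfly %mask, %val, %offset, %clamp : f32 -> f32
    ```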
772  }];
773  string llvmBuilder = [{
774      auto intId = getShflIntrinsicId(
775          $_resultType, $kind, static_cast<bool>($return_value_and_is_valid));
776      $res = createIntrinsicCall(builder,
777          intId, {$thread_mask, $val, $offset, $mask_and_clamp});
778  }];
779  let assemblyFormat = [{
780    $kind $thread_mask `,` $val `,` $offset `,` $mask_and_clamp  attr-dict
781     `:` type($val) `->` type($res)
782   }];
783   let hasVerifier = 1;
784}
785
786def NVVM_VoteBallotOp :
787  NVVM_Op<"vote.ballot.sync">,
788  Results<(outs LLVM_Type:$res)>,
789  Arguments<(ins LLVM_Type:$mask, LLVM_Type:$pred)> {
790  string llvmBuilder = [{
791      $res = createIntrinsicCall(builder,
792            llvm::Intrinsic::nvvm_vote_ballot_sync, {$mask, $pred});
793  }];
794  let hasCustomAssemblyFormat = 1;
795}
796
797def NVVM_SyncWarpOp :
798  NVVM_Op<"bar.warp.sync">,
799  Arguments<(ins LLVM_Type:$mask)> {
800  string llvmBuilder = [{
801      createIntrinsicCall(builder, llvm::Intrinsic::nvvm_bar_warp_sync, {$mask});
802  }];
803  let assemblyFormat = "$mask attr-dict `:` type($mask)";
804}
805
806def NVVM_ElectSyncOp : NVVM_Op<"elect.sync">
807{
808  let summary = "Elect one leader thread";
809  let description = [{
810    The `elect.sync` instruction elects one predicated active leader
    thread from among a set of threads specified by `membermask`.
    The `membermask` is set to `0xFFFFFFFF` for the current version
813    of this Op. The predicate result is set to `True` for the
814    leader thread, and `False` for all other threads.
815
816    [For more information, see PTX ISA]
817    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-elect-sync)
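
    Example (an illustrative sketch):

    ```mlir
    %pred = nvvm.elect.sync -> i1
    ```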
818  }];
819
820  let results = (outs I1:$pred);
821  let assemblyFormat = "attr-dict `->` type(results)";
822  string llvmBuilder = [{
823    auto *resultTuple = createIntrinsicCall(builder,
824        llvm::Intrinsic::nvvm_elect_sync, {builder.getInt32(0xFFFFFFFF)});
825    // Extract the second value into $pred
826    $pred = builder.CreateExtractValue(resultTuple, 1);
827  }];
828}
829
830def LoadCacheModifierCA : I32EnumAttrCase<"CA", 0, "ca">;
831def LoadCacheModifierCG : I32EnumAttrCase<"CG", 1, "cg">;
832def LoadCacheModifierCS : I32EnumAttrCase<"CS", 2, "cs">;
833def LoadCacheModifierLU : I32EnumAttrCase<"LU", 3, "lu">;
834def LoadCacheModifierCV : I32EnumAttrCase<"CV", 4, "cv">;
835
836/// Enum attribute of the different kinds.
837def LoadCacheModifierKind : I32EnumAttr<"LoadCacheModifierKind",
838                                "NVVM load cache modifier kind",
839  [LoadCacheModifierCA, LoadCacheModifierCG, LoadCacheModifierCS,
840    LoadCacheModifierLU, LoadCacheModifierCV]> {
841  let genSpecializedAttr = 0;
842  let cppNamespace = "::mlir::NVVM";
843  let description = [{
844    Enum attribute of the different kinds of cache operators for load instructions.
845
846    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#id62)
847  }];
848}
849
850def LoadCacheModifierAttr : EnumAttr<NVVM_Dialect, LoadCacheModifierKind, "load_cache_modifier">;
851
852def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global">,
853  Arguments<(ins LLVM_PointerShared:$dst,
854                 LLVM_PointerGlobal:$src,
855                 I32Attr:$size,
856                 LoadCacheModifierAttr:$modifier,
857                 Optional<LLVM_Type>:$cpSize)> {
858  let assemblyFormat = "$dst `,` $src `,` $size `,` `cache` `=` $modifier (`,` $cpSize^)? attr-dict `:` type(operands)";
859  let hasVerifier = 1;
860  let extraClassDeclaration = [{
861    static llvm::Intrinsic::ID
862      getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
863                            llvm::SmallVector<llvm::Value *> &args);
864  }];
865  string llvmBuilder = [{
866    llvm::SmallVector<llvm::Value *> translatedOperands;
867    auto id = NVVM::CpAsyncOp::getIntrinsicIDAndArgs(
868      *op, moduleTranslation, translatedOperands);
869    createIntrinsicCall(builder, id, translatedOperands);
870  }];
871}
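
// Example usage (an illustrative sketch; SSA value names are placeholders and
// the cache modifier shown is only one of the supported values):
//   nvvm.cp.async.shared.global %dst, %src, 16, cache = cg : !llvm.ptr<3>, !llvm.ptr<1>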
872
873def NVVM_CpAsyncCommitGroupOp : NVVM_Op<"cp.async.commit.group"> {
874  string llvmBuilder = [{
875      createIntrinsicCall(builder, llvm::Intrinsic::nvvm_cp_async_commit_group);
876  }];
877  let assemblyFormat = "attr-dict";
878}
879
880def NVVM_CpAsyncWaitGroupOp : NVVM_Op<"cp.async.wait.group">,
881  Arguments<(ins I32Attr:$n)> {
882  string llvmBuilder = [{
883      createIntrinsicCall(
884        builder,
885        llvm::Intrinsic::nvvm_cp_async_wait_group,
886        llvm::ConstantInt::get(
887          llvm::Type::getInt32Ty(moduleTranslation.getLLVMContext()),
888          $n));
889  }];
890  let assemblyFormat = "$n attr-dict";
891}
892
893def NVVM_CpAsyncMBarrierArriveOp : NVVM_Op<"cp.async.mbarrier.arrive"> {
894  let summary = "NVVM Dialect Op for cp.async.mbarrier.arrive";
895  let description = [{
896    The `cp.async.mbarrier.arrive` Op makes the mbarrier object track
897    all prior cp.async operations initiated by the executing thread.
898    The `addr` operand specifies the address of the mbarrier object
899    in generic address space. The `noinc` attr impacts how the
900    mbarrier's state is updated.
    [For more information, refer to the PTX ISA]
902    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive)
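
    Example (an illustrative sketch; the SSA value name is a placeholder):

    ```mlir
    nvvm.cp.async.mbarrier.arrive %barrier : !llvm.ptr
    ```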
903  }];
904  let assemblyFormat = "$addr attr-dict `:` type(operands)";
905
906  let arguments = (ins
907    LLVM_AnyPointer:$addr, DefaultValuedAttr<I1Attr, "0">:$noinc);
908
909  string llvmBuilder = [{
910    auto intId = $noinc ?
911      llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc :
912      llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive;
913
914    createIntrinsicCall(builder, intId, {$addr});
915  }];
916}
917
918def NVVM_CpAsyncMBarrierArriveSharedOp : NVVM_Op<"cp.async.mbarrier.arrive.shared"> {
919  let summary = "NVVM Dialect Op for cp.async.mbarrier.arrive.shared";
920  let description = [{
921    The `cp.async.mbarrier.arrive.shared` Op makes the mbarrier object
922    track all prior cp.async operations initiated by the executing thread.
923    The `addr` operand specifies the address of the mbarrier object in
924    shared memory. The `noinc` attr impacts how the mbarrier's state
    is updated. [For more information, refer to the PTX ISA]
926    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive)
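
    Example (an illustrative sketch; the SSA value name is a placeholder):

    ```mlir
    nvvm.cp.async.mbarrier.arrive.shared %barrier : !llvm.ptr<3>
    ```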
927  }];
928  let assemblyFormat = "$addr attr-dict `:` type(operands)";
929
930  let arguments = (ins
931    LLVM_PointerShared:$addr, DefaultValuedAttr<I1Attr, "0">:$noinc);
932
933  string llvmBuilder = [{
934    auto intId = $noinc ?
935      llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc_shared :
936      llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_shared;
937
938    createIntrinsicCall(builder, intId, {$addr});
939  }];
940}
941
942//===----------------------------------------------------------------------===//
943// NVVM Conversion Ops (for "cvt.*" family of PTX instructions)
944//===----------------------------------------------------------------------===//
945
946// Attributes for the floating point rounding modes supported by PTX
947def FPRoundingModeNone : I32EnumAttrCase<"NONE", 0, "none">;
948def FPRoundingModeRN   : I32EnumAttrCase<"RN",   1, "rn">;
949def FPRoundingModeRM   : I32EnumAttrCase<"RM",   2, "rm">;
950def FPRoundingModeRP   : I32EnumAttrCase<"RP",   3, "rp">;
951def FPRoundingModeRZ   : I32EnumAttrCase<"RZ",   4, "rz">;
952def FPRoundingModeRNA  : I32EnumAttrCase<"RNA",  5, "rna">;
953
954def FPRoundingMode : I32EnumAttr<"FPRoundingMode", "NVVM FPRoundingMode kind",
955  [FPRoundingModeNone, FPRoundingModeRN, FPRoundingModeRM,
956    FPRoundingModeRP, FPRoundingModeRZ, FPRoundingModeRNA]> {
957  let genSpecializedAttr = 0;
958  let cppNamespace = "::mlir::NVVM";
959}
960def FPRoundingModeAttr : EnumAttr<NVVM_Dialect, FPRoundingMode, "fp_rnd_mode"> {
961  let assemblyFormat = "`<` $value `>`";
962}
963
964def SaturationModeNone   : I32EnumAttrCase<"NONE", 0, "none">;
965def SaturationModeFinite : I32EnumAttrCase<"SATFINITE", 1, "satfinite">;
966
967def SaturationMode : I32EnumAttr<"SaturationMode", "NVVM SaturationMode kind",
968  [SaturationModeNone, SaturationModeFinite]> {
969  let genSpecializedAttr = 0;
970  let cppNamespace = "::mlir::NVVM";
971}
972def SaturationModeAttr : EnumAttr<NVVM_Dialect, SaturationMode, "sat_mode"> {
973  let assemblyFormat = "`<` $value `>`";
974}
975
976def NVVM_CvtFloatToTF32Op : NVVM_Op<"cvt.float.to.tf32"> {
977  let summary = "Convert the given float input to TF32";
978  let description = [{
979    This Op converts the given f32 input to tf32.
980    The result `res` is represented as an i32 type.
981    The `relu` attribute, when set, lowers to the '.relu' variant of
    the cvt instruction. The `rnd` and `sat` attributes specify the rounding
    and saturation modes, respectively.
984    [For more information, see PTX ISA]
985    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
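
    Example (an illustrative sketch; the SSA value names are placeholders and
    the rounding-mode attribute is optional):

    ```mlir
    %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rna>}
    ```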
986  }];
987
988  let hasVerifier = 1;
989  let results = (outs I32:$res);
990  let arguments = (ins
991    F32:$src,
992    DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
993    DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
994    DefaultValuedAttr<BoolAttr, "false">:$relu);
995
996  let assemblyFormat = "$src attr-dict";
997
998  let extraClassDeclaration = [{
999    static llvm::Intrinsic::ID getIntrinsicID(NVVM::FPRoundingMode,
1000                                              NVVM::SaturationMode,
1001                                              bool hasRelu);
1002  }];
1003
1004  string llvmBuilder = [{
1005    auto intId = NVVM::CvtFloatToTF32Op::getIntrinsicID($rnd, $sat, $relu);
1006    $res = createIntrinsicCall(builder, intId, {$src});
1007  }];
1008}
1009
1010//===----------------------------------------------------------------------===//
1011// NVVM MMA Ops
1012//===----------------------------------------------------------------------===//
1013/// Helpers to instantiate different version of wmma intrinsics.
1014/// This matches the hierarchy used in IntrinsicsNVVM.td to define all the
1015/// combinations of the intrinsics.
1016class GEOM<int M, int N, int K> {
1017  int m = M;
1018  int n = N;
1019  int k = K;
1020}
1021
1022/// Class containing information about valid mma matrix types.
1023class WMMA_REGS<GEOM Geom, string Frag, string PtxEltType> {
1024  int m = Geom.m;
1025  int n = Geom.n;
1026  int k = Geom.k;
1027  string geom = "m"#Geom.m#"n"#Geom.n#"k"#Geom.k;
1028  string frag = Frag;
1029  string ptx_elt_type = PtxEltType;
1030  string gft = geom#":"#Frag#":"#ptx_elt_type;
1031}
1032
/// Generate the enum value of the mma.load/mma.store intrinsic.
1034class WMMA_NAME_LDST<string Op, WMMA_REGS Frag, string Layout, int WithStride> {
1035  string id =   "llvm::Intrinsic::nvvm_wmma"
1036                # "_" # Frag.geom
1037                # "_" # Op
1038                # "_" # Frag.frag
1039                # "_" # Frag.ptx_elt_type
1040                # "_" # Layout
1041                # !if(WithStride, "_stride", "");
1042}
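
// For example (derived from the concatenation above),
// WMMA_NAME_LDST<"load", WMMA_REGS<GEOM<16, 16, 16>, "a", "f16">, "row", 1>.id
// evaluates to "llvm::Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride".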
1043
1044/// Generate the signature part of the mma intrinsic name.
1045class MMA_SIGNATURE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
1046  list<WMMA_REGS> id_frags = !cond(
1047     // FP16 ops are identified by accumulator & result type.
1048     !eq(A.ptx_elt_type, "f16") : [D, C],
1049     // other ops are identified by input types.
1050     !ne(A.ptx_elt_type, B.ptx_elt_type): [A, B],
1051     true: [A]
1052     );
1053   string ret = !foldl("", id_frags, a, b, !strconcat(a, "_", b.ptx_elt_type));
1054}
1055
1056/// Generate enum value of the wmma.mma intrinsic.
1057class WMMA_NAME<string Op, string ALayout, string BLayout, WMMA_REGS A,
1058  WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
1059  string signature = MMA_SIGNATURE<A, B, C, D>.ret;
1060  string id =   "llvm::Intrinsic::nvvm_wmma"
1061                # "_" # A.geom
1062                # "_" # Op
1063                # "_" # ALayout
1064                # "_" # BLayout
1065                # signature;
1066}
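
// For example (derived from the concatenation above), WMMA_NAME<"mma", "row",
// "col", A, B, C, D>.id with m16n16k16 f16 fragments for A/B and f16
// accumulators for C/D evaluates to
// "llvm::Intrinsic::nvvm_wmma_m16n16k16_mma_row_col_f16_f16".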
1067
1068// Generates list of 4-tuples of WMMA_REGS representing a valid MMA op.
1069//   Geom: list of supported geometries.
1070//   TypeN: PTX type of the corresponding fragment's element.
//   TypeB and TypeD may be empty, in which case they match TypeA and TypeC
//   respectively.
1072class MMA_OPS<list<GEOM> Geom, list<string> TypeA, list<string> TypeB,
1073            list<string> TypeC, list<string> TypeD> {
1074  list<list<WMMA_REGS>> ret =
1075     !foldl([]<list<WMMA_REGS>>, Geom, t1, geom, !listconcat(t1,
1076     !foldl([]<list<WMMA_REGS>>, TypeA, t2, type_a, !listconcat(t2,
1077     !foldl([]<list<WMMA_REGS>>, !if(!size(TypeB), TypeB, [type_a]), t3, type_b, !listconcat(t3,
1078     !foldl([]<list<WMMA_REGS>>, TypeC, t4, type_c, !listconcat(t4,
1079     !foldl([]<list<WMMA_REGS>>, !if(!size(TypeD), TypeD, [type_c]), t5, type_d, !listconcat(t5,
1080            [[WMMA_REGS<geom, "a", type_a>,
1081              WMMA_REGS<geom, "b", type_b>,
1082              WMMA_REGS<geom, "c", type_c>,
1083              WMMA_REGS<geom, "d", type_d>]]))))))))));
1084   // Debugging aid for readable representation of the list above.
1085   list<list<string>> ops = !foreach(x, ret, [x[0].gft, x[1].gft, x[2].gft, x[3].gft]);
1086}
1087
1088/// Creates a list of combinations of load/store operations supported.
1089class MMA_LDST_OPS<list<GEOM> Geom, list<string> Frags, list<string> Types> {
1090  list<WMMA_REGS> ret =
1091     !foldl([]<WMMA_REGS>, Geom, t1, geom, !listconcat(t1,
1092     !foldl([]<WMMA_REGS>, Frags, t2, frag, !listconcat(t2,
1093     !foldl([]<WMMA_REGS>, Types, t3, type, !listconcat(t3,
1094            [WMMA_REGS<geom, frag, type>]))))));
1095   // Debugging aid for readable representation of the list above.
1096   list<string> ops = !foreach(x, ret, x.gft);
1097}
1098
1099// Creates list of valid combinations of fragments. This is a subset of what
1100// llvm supports and can be extended as needed.
1101class NVVM_MMA_OPS {
1102  // "wmma" operations
1103  list<list<WMMA_REGS>> tf32_wmma_ops = MMA_OPS<
1104            [GEOM<16, 16, 8>],
1105            ["tf32"], [], ["f32"], []>.ret;
1106  list<list<WMMA_REGS>> fp_wmma_ops = MMA_OPS<
1107            [GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>],
1108            ["f16"], [], ["f16", "f32"], []>.ret;
1109  list<list<WMMA_REGS>> i8_wmma_ops = MMA_OPS<
1110            [GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>],
1111            ["s8","u8"], [], ["s32"], []>.ret;
1112  list<list<WMMA_REGS>> all_wmma_ops = !listconcat(
1113            tf32_wmma_ops,
1114            fp_wmma_ops,
1115            i8_wmma_ops);
1116
1117  list<WMMA_REGS> ldst_ab_ops = MMA_LDST_OPS<
1118            [GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>],
1119            ["a", "b"], ["f16","s8","u8"]>.ret;
1120  list<WMMA_REGS> ldst_cd_ops = MMA_LDST_OPS<
1121            [GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>],
1122            ["c", "d"], ["f16", "f32","s32"]>.ret;
1123  list<WMMA_REGS> ldst_tf32_ab_ops = MMA_LDST_OPS<
1124            [GEOM<16, 16, 8>],
1125            ["a", "b"], ["tf32"]>.ret;
1126  list<WMMA_REGS> ldst_tf32_cd_ops = MMA_LDST_OPS<
1127            [GEOM<16, 16, 8>],
1128            ["c", "d"], ["f32"]>.ret;
1129  list<WMMA_REGS> all_ldst_ops = !listconcat(ldst_ab_ops, ldst_cd_ops,
1130                                             ldst_tf32_ab_ops,
1131                                             ldst_tf32_cd_ops);
1132  // Separate A/B/C fragments (loads) from D (stores).
1133  list<WMMA_REGS> all_ld_ops = !filter(op, all_ldst_ops, !ne(op.frag, "d"));
1134  list<WMMA_REGS> all_st_ops = !filter(op, all_ldst_ops, !eq(op.frag, "d"));
1135
1136  // "mma_sync" operations
1137  list<list<WMMA_REGS>> tf32_mma_ops = MMA_OPS<
1138            [GEOM<16,8,4>, GEOM<16,8,8>],
1139            ["tf32"], [], ["f32"], []>.ret;
1140  list<list<WMMA_REGS>> bf16_mma_ops = MMA_OPS<
1141            [GEOM<16,8,16>, GEOM<16,8,8>],
1142            ["bf16"], [], ["f32"], []>.ret;
1143  list<list<WMMA_REGS>> f64_mma_ops = MMA_OPS<
1144            [GEOM<8,8,4>],
1145            ["f64"], [], ["f64"], []>.ret;
1146  list<list<WMMA_REGS>> fp_mma_ops = MMA_OPS<
1147            [GEOM<8,8,4>, GEOM<16,8,8>, GEOM<16,8,16>],
1148            ["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret;
1149  list<list<WMMA_REGS>> int_mma_ops = MMA_OPS<
1150            [GEOM<8,8,16>, GEOM<16,8,16>, GEOM<16,8,32>],
1151            ["s8", "u8"], ["s8", "u8"], ["s32"], []>.ret;
1152  list<list<WMMA_REGS>> subint_mma_ops = MMA_OPS<
1153            [GEOM<8,8,32>, GEOM<16,8,32>, GEOM<16,8,64>],
1154            ["s4", "u4"], ["s4", "u4"], ["s32"], []>.ret;
1155  list<list<WMMA_REGS>> bit_mma_ops = MMA_OPS<
1156            [GEOM<8,8,128>, GEOM<16,8,128>, GEOM<16,8,256>],
1157            ["b1"], [], ["s32"], []>.ret;
1158  list<list<WMMA_REGS>> all_mma_sync_ops = !listconcat(
1159            tf32_mma_ops, bf16_mma_ops, f64_mma_ops,
1160            fp_mma_ops, int_mma_ops, subint_mma_ops, bit_mma_ops);
1161}
1162
1163def NVVM_MMA_OPS : NVVM_MMA_OPS;
1164
1165/// Helper to create the mapping between the configuration and the store
1166/// intrinsic enum value.
1167class MMA_ST_INTR<string op> {
1168  list<list<string>> cond0 = !foreach(frag, NVVM_MMA_OPS.all_st_ops,
1169                                !foreach(layout, ["row", "col"],
1170  "if (layout == \"" # layout #  "\" && m == " # frag.m # " &&"
1171  "    n == " #frag.n # " && k == " # frag.k # " && \"" #
1172       frag.ptx_elt_type # "\" == eltype)"
1173  "  return " #WMMA_NAME_LDST<op, frag, layout, 1>.id #";"));
1174  string id = !foldl("",
1175                !foldl([""], cond0, acc, el, !listconcat(acc, el)),
1176                acc1, el1, acc1 # "\n" # el1);
1177}
1178
1179/// Helper to map a mxk shape to a supported mxnxk matrix type. This will return
1180/// the n value of the supported configuration.
1181class MMA_ST_INFER_N<list<WMMA_REGS> ldst> {
1182  list<string> cond = !foreach(frag, ldst,
1183  "if (m == " # frag.m # " && k == " #frag.k # " && \"" #
1184       frag.ptx_elt_type # "\" == eltype)"
1185  "  return "# frag.n #";");
1186  string id = !foldl("", cond, acc, el, acc # "\n" # el);
1187}
1188
1189/// Helper to map a kxn shape to a supported mxnxk matrix type. This will return
1190/// the m value of the supported configuration.
1191class MMA_ST_INFER_M<list<WMMA_REGS> ldst> {
1192  list<string> cond = !foreach(frag, ldst,
1193  "if (n == " # frag.n # " && k == " #frag.k # " && \"" #
1194       frag.ptx_elt_type # "\" == eltype)"
1195  "  return "# frag.m #";");
1196  string id = !foldl("", cond, acc, el, acc # "\n" # el);
1197}
1198
1199/// Helper to map a mxn shape to a supported mxnxk matrix type. This will return
1200/// the k value of the supported configuration.
1201class MMA_ST_INFER_K<list<WMMA_REGS> ldst> {
1202  list<string> cond = !foreach(frag, ldst,
1203  "if (m == " # frag.m # " && n == " #frag.n # " && \"" #
1204       frag.ptx_elt_type # "\" == eltype)"
1205  "  return "# frag.k #";");
1206  string id = !foldl("", cond, acc, el, acc # "\n" # el);
1207}
1208
1209/// Helper to create the mapping between the configuration and the load
1210/// intrinsic enum value.
1211class MMA_LD_INTR<string op> {
1212  list<list<string>> cond0 = !foreach(frag, NVVM_MMA_OPS.all_ld_ops,
1213                                !foreach(layout, ["row", "col"],
1214  "if (layout == \"" # layout #  "\" && m == " # frag.m # " &&"
1215  "    n == " #frag.n # " && k == " # frag.k # " && \"" #
1216       frag.ptx_elt_type # "\" == eltype && frag == \""#frag.frag#"\")"
1217  "  return "# WMMA_NAME_LDST<op, frag, layout, 1>.id #";"));
1218  string id = !foldl("",
1219                !foldl([""], cond0, acc, el, !listconcat(acc, el)),
1220                acc1, el1, acc1 # "\n" # el1);
1221}
1222
1223/// Helper to create the mapping between the configuration and the wmma.mma
1224/// intrinsic enum value.
1225class MMA_MMA_INTR<string opName> {
1226  list<list<list<string>>> cond0 =
1227    !foreach(op, NVVM_MMA_OPS.all_wmma_ops,
1228      !foreach(layoutA, ["row", "col"],
1229        !foreach(layoutB, ["row", "col"],
1230  "if (layoutA == \"" # layoutA #  "\" && layoutB == \"" # layoutB #  "\" && "
1231  "    m == " # op[0].m # " && n == " #op[0].n # " && k == " # op[0].k #
1232  "    && \"" # op[0].ptx_elt_type # "\" == eltypeA && \""
1233   # op[3].ptx_elt_type # "\" == eltypeB)"
1234  "  return " #
1235       WMMA_NAME<opName, layoutA, layoutB, op[0], op[1], op[2], op[3]>.id # ";")));
1236  list<string> f = !foldl([""],
1237                     !foldl([[""]], cond0, acc, el, !listconcat(acc, el)),
1238                          acc1, el1, !listconcat(acc1, el1));
1239  string id = !foldl("", f, acc, el, acc # "\n" # el);
1240}
1241
1242/// Enum attribute for binary (b1) MMA operation type
1243def MMAB1OpNone : I32EnumAttrCase<"none", 0>;
1244def MMAB1OpXorPopc : I32EnumAttrCase<"xor_popc", 1>;
1245def MMAB1OpAndPopc : I32EnumAttrCase<"and_popc", 2>;
1246def MMAB1Op : I32EnumAttr<"MMAB1Op", "MMA binary operations",
1247  [MMAB1OpNone, MMAB1OpXorPopc, MMAB1OpAndPopc]> {
1248  let genSpecializedAttr = 0;
1249  let cppNamespace = "::mlir::NVVM";
1250}
1251def MMAB1OpAttr : EnumAttr<NVVM_Dialect, MMAB1Op, "mma_b1op"> {
1252  let assemblyFormat = "`<` $value `>`";
1253}
1254
1255/// Enum attribute type for the overflow behavior of MMA integer operations
1256def MMAIntOverflowWrap : I32EnumAttrCase<"wrapped", 0>;
1257def MMAIntOverflowSat : I32EnumAttrCase<"satfinite", 1>;
1258def MMAIntOverflow : I32EnumAttr<"MMAIntOverflow", "MMA overflow options",
1259  [MMAIntOverflowSat, MMAIntOverflowWrap]> {
1260  let genSpecializedAttr = 0;
1261  let cppNamespace = "::mlir::NVVM";
1262}
1263def MMAIntOverflowAttr : EnumAttr<NVVM_Dialect, MMAIntOverflow, "mma_int_overflow"> {
1264  let assemblyFormat = "`<` $value `>`";
1265}
1266
1267/// Attribute to hold the MMA shape
1268def NVVM_MMAShapeAttr : NVVM_Attr<"MMAShape", "shape"> {
1269  let summary = "Attribute for MMA operation shape.";
1270  let parameters = (ins "int":$m, "int":$n, "int":$k);
1271  let assemblyFormat = "`<` struct(params) `>`";
1272}
1273
1274// Returns true if this combination of layout/satf for MMA ops is supported;
1275// false otherwise.
1276// E.g.
1277// if NVVM_MMA_SUPPORTED<...>.ret then
1278//   def : FOO<>; // The record will only be defined for supported ops.
1279//
1280class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b, int satf> {
1281  // MMA ops check both layouts.
1282  string layout = layout_a # ":" # layout_b;
1283  string a_type = frags[0].ptx_elt_type;
1284  string b_type = frags[1].ptx_elt_type;
1285  string c_type = frags[2].ptx_elt_type;
1286  string d_type = frags[3].ptx_elt_type;
1287  string geom = frags[0].geom;
1288
1289  // gcd is a shortcut used to identify instructions that depend on
1290  // geom+frag_c+frag_d.
1291  string gcd = geom # ":" # c_type # d_type;
1292  bit ret = !cond(
1293
1294    // Limit satf to valid types
1295    !and(!eq(satf, 1),
1296         !ne(a_type, "s8"),
1297         !ne(a_type, "u8"),
1298         !ne(a_type, "s4"),
1299         !ne(a_type, "u4")): false,
1300
1301    // m8n8k4 has no C=f32 D=f16 variant.
1302    !eq(gcd, "m8n8k4:f32f16"): false,
1303
1304    // only m8n8k4 for f16 does not require row:col layout
1305    !and(!ne(layout, "row:col"),
1306         !or(!ne(geom, "m8n8k4"),
1307             !ne(a_type, "f16"))) : false,
1308
1309    // m16n8k8 requires A and B to be the same type and C and D to be the same
1310    // type.
1311    !and(!eq(geom, "m16n8k8"),
1312         !or(!ne(a_type, b_type),
1313             !ne(c_type, d_type))): false,
1314
1315    // m16n8k8 requires C and D to be the same type.
1316    !and(!eq(geom, "m16n8k8"),
1317         !ne(c_type, d_type)): false,
1318
1319    // All other are OK.
1320    true: true
1321  );
1322}
1323
1324// Returns a list of operation suffixes corresponding to possible b1
1325// multiply-and-accumulate operations for all fragments which have a
1326// b1 type. For all other fragments, the list returned holds a list
1327// containing the empty string.
1328class NVVM_MMA_B1OPS<list<WMMA_REGS> frags> {
1329  list<string> ret = !cond(
1330    !eq(frags[0].ptx_elt_type, "b1") : ["xor_popc", "and_popc"],
1331    true: [""]
1332  );
1333}
1334
1335/// Generate enum value of the mma.sync intrinsic.
1336class MMA_SYNC_NAME<string ALayout, string BLayout, string b1op, int Satfinite,
1337               WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
1338  string signature = MMA_SIGNATURE<A, B, C, D>.ret;
1339  string id = "llvm::Intrinsic::nvvm_mma"
1340                # !if(!ne(b1op, ""), "_" # b1op, "")
1341                # "_" # A.geom
1342                # "_" # ALayout
1343                # "_" # BLayout
1344                # !if(Satfinite, "_satfinite", "")
1345                # signature;
1346}
1347
1348/// Helper to create the mapping between the configuration and the mma.sync
1349/// intrinsic enum value.
1350class MMA_SYNC_INTR {
1351  list<list<list<list<list<string>>>>> cond0 =
1352    !foreach(op, NVVM_MMA_OPS.all_mma_sync_ops,
1353      !foreach(layoutA, ["row", "col"],
1354        !foreach(layoutB, ["row", "col"],
1355          !foreach (sat, [0, 1],
1356            !foreach (b1op, NVVM_MMA_B1OPS<op>.ret,
1357              !if(NVVM_MMA_SUPPORTED<[op[0], op[1], op[2], op[3]],
1358                                     layoutA, layoutB, sat>.ret,
1359      "if (layoutA == \"" # layoutA #  "\" && layoutB == \"" # layoutB #  "\" && "
1360      "    m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k #
1361      "    && \"" # op[0].ptx_elt_type # "\" == eltypeA && \""
1362       # op[1].ptx_elt_type # "\" == eltypeB && "
1363       # " \"" # op[2].ptx_elt_type # "\" == eltypeC && "
1364       # " \"" # op[3].ptx_elt_type # "\" == eltypeD "
1365       # " && (sat.has_value()  ? " # sat # " == static_cast<int>(*sat) : true)"
1366       # !if(!ne(b1op, ""), " && (b1Op.has_value() ? MMAB1Op::" # b1op # " == *b1Op : true)", "") # ")\n"
1367       # "  return " #
1368       MMA_SYNC_NAME<layoutA, layoutB, b1op, sat, op[0], op[1], op[2], op[3]>.id # ";",
1369          "") // if supported
1370          ) // b1op
1371        ) // sat
1372      ) // layoutB
1373    ) // layoutA
1374  ); // all_mma_sync_ops
1375  list<list<list<string>>> f1 = !foldl([[[""]]],
1376                                  !foldl([[[[""]]]], cond0, acc, el,
1377                                      !listconcat(acc, el)),
1378                                    acc1, el1, !listconcat(acc1, el1));
1379  list<list<string>> f2 = !foldl([[""]], f1, acc1, el1, !listconcat(acc1, el1));
1380  list<string> f3 = !foldl([""], f2, acc, el, !listconcat(acc, el));
1381  string id = !foldl("", f3, acc, el, acc # "\n" # el);
1382}
1383
1384def MMALayoutRow : I32EnumAttrCase<"row", 0>;
1385def MMALayoutCol : I32EnumAttrCase<"col", 1>;
1386
1387/// Enum attribute of the different matrix layout.
1388def MMALayout : I32EnumAttr<"MMALayout", "NVVM MMA layout",
1389  [MMALayoutRow, MMALayoutCol]> {
1390  let genSpecializedAttr = 0;
1391  let cppNamespace = "::mlir::NVVM";
1392}
1393def MMALayoutAttr : EnumAttr<NVVM_Dialect, MMALayout, "mma_layout"> {
1394  let assemblyFormat = "`<` $value `>`";
1395}
1396
1397/// Enum attribute of the different PTX element types used for MMA operands.
1398def MMATypeF16  : I32EnumAttrCase<"f16", 0>;
1399def MMATypeF32  : I32EnumAttrCase<"f32", 1>;
1400def MMATypeTF32 : I32EnumAttrCase<"tf32", 2>;
1401def MMATypeU8 : I32EnumAttrCase<"u8", 3>;
1402def MMATypeS8 : I32EnumAttrCase<"s8", 4>;
1403def MMATypeS32 : I32EnumAttrCase<"s32", 5>;
1404def MMATypeB1 : I32EnumAttrCase<"b1", 6>;
1405def MMATypeU4 : I32EnumAttrCase<"u4", 7>;
1406def MMATypeS4 : I32EnumAttrCase<"s4", 8>;
1407def MMATypeBF16 : I32EnumAttrCase<"bf16", 9>;
1408def MMATypeF64 : I32EnumAttrCase<"f64", 10>;
1409
1410def MMATypes : I32EnumAttr<"MMATypes", "NVVM MMA types",
1411  [MMATypeF16, MMATypeF32, MMATypeTF32,
1412  MMATypeBF16, MMATypeS8, MMATypeU8,
1413  MMATypeS32, MMATypeS4, MMATypeU4,
1414  MMATypeB1, MMATypeF64]> {
1415  let genSpecializedAttr = 0;
1416  let cppNamespace = "::mlir::NVVM";
1417}
1418def MMATypesAttr : EnumAttr<NVVM_Dialect, MMATypes, "mma_type"> {
1419  let assemblyFormat = "`<` $value `>`";
1420}
1421
1422def MMAFragA : I32EnumAttrCase<"a", 0>;
1423def MMAFragB : I32EnumAttrCase<"b", 1>;
1424def MMAFragC : I32EnumAttrCase<"c", 2>;
1425
1426/// Enum attribute of the different frag types.
1427def MMAFrag: I32EnumAttr<"MMAFrag", "NVVM MMA frag type",
1428  [MMAFragA, MMAFragB, MMAFragC]> {
1429  let genSpecializedAttr = 0;
1430  let cppNamespace = "::mlir::NVVM";
1431}
1432def MMAFragAttr : EnumAttr<NVVM_Dialect, MMAFrag, "mma_frag"> {
1433  let assemblyFormat = "`<` $value `>`";
1434}
1435
1436def NVVM_WMMALoadOp: NVVM_Op<"wmma.load">,
1437  Results<(outs LLVM_AnyStruct:$res)>,
1438  Arguments<(ins LLVM_AnyPointer: $ptr, I32: $stride, I32Attr:$m,
1439             I32Attr:$n, I32Attr:$k, MMALayoutAttr:$layout,
1440             MMATypesAttr:$eltype, MMAFragAttr:$frag)> {
1441
1442  let summary = "Warp synchronous matrix load";
1443
1444  // Since LLVM intrinsic IDs are enums that cannot be dynamically generated in
1445  // C++, we instantiate a function in tablegen to map each valid configuration
1446  // to the corresponding intrinsic ID.
1447  // Because we want a single source of truth, this means the source of truth
1448  // about valid combinations needs to be in tablegen; therefore we generate
1449  // extra helpers to query valid configurations based on the shapes of
1450  // load/store operations.
1451  let extraClassDeclaration =
1452    "static llvm::Intrinsic::ID getIntrinsicID("
1453    "int m, int n, int k, mlir::NVVM::MMALayout layoutEnum,"
1454    "mlir::NVVM::MMATypes eltypeEnum,mlir::NVVM::MMAFrag fragEnum) {"
1455    "llvm::StringRef layout = stringifyEnum(layoutEnum);"
1456    "llvm::StringRef eltype = stringifyEnum(eltypeEnum);"
1457    "llvm::StringRef frag = stringifyEnum(fragEnum);"
1458    #MMA_LD_INTR<"load">.id# "\n"
1459    "return 0;"
1460    "}\n"
1461    "/// Helpers to find valid n dimension based on mxk load shape.\n"
1462    "static int inferNDimension(int m, int k, mlir::NVVM::MMATypes eltypeEnum) {"
1463    "  llvm::StringRef eltype = stringifyEnum(eltypeEnum);"
1464    #MMA_ST_INFER_N<!filter(op, NVVM_MMA_OPS.all_ld_ops, !eq(op.frag, "a"))>.id# "\n"
1465    "return 0;"
1466    "}\n"
1467    "/// Helpers to find valid m dimension based on kxn load shape.\n"
1468    "static int inferMDimension(int k, int n, mlir::NVVM::MMATypes eltypeEnum) {"
1469    "  llvm::StringRef eltype = stringifyEnum(eltypeEnum);"
1470    #MMA_ST_INFER_M<!filter(op, NVVM_MMA_OPS.all_ld_ops, !eq(op.frag, "b"))>.id# "\n"
1471    "return 0;"
1472    "}\n"
1473    "/// Helpers to find valid k dimension based on mxn load shape.\n"
1474    "static int inferKDimension(int m, int n, mlir::NVVM::MMATypes eltypeEnum) {"
1475    "  llvm::StringRef eltype = stringifyEnum(eltypeEnum);"
1476    #MMA_ST_INFER_K<!filter(op, NVVM_MMA_OPS.all_ld_ops, !eq(op.frag, "c"))>.id# "\n"
1477    "return 0;"
1478    "}\n";
1479
1480
1481  string llvmBuilder = [{
1482      auto operands = moduleTranslation.lookupValues(opInst.getOperands());
1483      auto intId = mlir::NVVM::WMMALoadOp::getIntrinsicID(
1484        $m, $n, $k, $layout, $eltype, $frag);
1485      $res = createIntrinsicCall(builder, intId, operands, {operands[0]->getType()});
1486  }];
1487
1488  string baseDescription = [{
1489    The `nvvm.wmma.load` operation loads a matrix collectively using all the
1490    threads in a warp.
1491
1492    The operation takes two arguments: the address from which the matrix
1493    elements are to be loaded and a stride. The stride argument
1494    represents the leading dimension of the source matrix. The address and
1495    the stride are required to be the same across all threads in the warp.
1496    Each thread in a warp holds a certain number of elements. The Op returns
1497    an LLVM struct which holds the elements of the matrix held by this thread.
1498
1499    This op is meant to be used along with `nvvm.wmma.store` and
1500    `nvvm.wmma.mma`.
1501
1502    Example:
1503
1504    ```mlir
1505    %2 = nvvm.wmma.load %0, %1
1506      {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32}
1507      : (!llvm.ptr<3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
1508    ```
1509    }];
1510
1511  let assemblyFormat = "$ptr `,` $stride attr-dict `:` functional-type($ptr, $res)";
1512  let hasVerifier = 1;
1513}
1514
1515def NVVM_WMMAStoreOp : NVVM_Op<"wmma.store">,
1516  Arguments<(ins LLVM_AnyPointer: $ptr,
1517             I32Attr:$m, I32Attr:$n, I32Attr:$k, MMALayoutAttr:$layout,
1518             MMATypesAttr:$eltype, Variadic<LLVM_Type>:$args, I32: $stride)>{
1519  let summary = "Warp synchronous matrix store";
1520
1521  let extraClassDeclaration =
1522    "static llvm::Intrinsic::ID getIntrinsicID("
1523    "int m, int n, int k, mlir::NVVM::MMALayout layoutEnum,"
1524    "mlir::NVVM::MMATypes eltypeEnum) {"
1525    "  llvm::StringRef layout = stringifyEnum(layoutEnum);"
1526    "  llvm::StringRef eltype = stringifyEnum(eltypeEnum);"
1527    #MMA_ST_INTR<"store">.id# "\n"
1528    "return 0;"
1529    "}\n"
1530    "/// Helpers to find valid k dimension based on mxn store shape.\n"
1531    "static int inferKDimension(int m, int n, mlir::NVVM::MMATypes eltypeEnum) {"
1532    "  llvm::StringRef eltype = stringifyEnum(eltypeEnum);"
1533    #MMA_ST_INFER_K<NVVM_MMA_OPS.all_st_ops>.id#  "\n"
1534    "return 0;"
1535    "}";
1536
1537  string llvmBuilder = [{
1538      auto operands = moduleTranslation.lookupValues(opInst.getOperands());
1539      auto intId =
1540        mlir::NVVM::WMMAStoreOp::getIntrinsicID($m, $n, $k, $layout, $eltype);
1541      createIntrinsicCall(builder, intId, operands, {operands[0]->getType()});
1542  }];
1543
1544  string baseDescription = [{
1545    The `nvvm.wmma.store` operation stores a matrix collectively using
1546    all the threads in a warp.
1547
1548    The operation takes as arguments the address to where the matrix elements are
1549    to be stored, a stride and the elements to store, held by the current thread.
1550    The stride argument represents the leading dimension of the destination matrix.
1551    The address and the stride are required to be the same across all threads in the
1552    warp.
1553
1554    This op is meant to be used along with `nvvm.wmma.m16n16k16.load` and
1555    `nvvm.wmma.m16n16k16.mma`.
1556
1557    Example:
1558
1559    ```mlir
1560    nvvm.wmma.store %0, %1, %2, %3, %4, %5
1561      {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32}
1562      : !llvm.ptr<3>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>
1563    ```
1564  }];
1565
1566  let assemblyFormat = [{
1567    $ptr `,` $stride `,` $args attr-dict `:` qualified(type($ptr)) `,`
1568    type($args)
1569  }];
1570  let hasVerifier = 1;
1571}
1572
1573// Base class for all the variants of WMMA mmaOps that may be defined.
1574def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">,
1575  Results<(outs LLVM_AnyStruct:$res)>,
1576  Arguments<(ins I32Attr:$m, I32Attr:$n, I32Attr:$k, MMALayoutAttr:$layoutA,
1577             MMALayoutAttr:$layoutB, MMATypesAttr:$eltypeA,
1578             MMATypesAttr:$eltypeB, Variadic<LLVM_Type>:$args)>{
1579  let summary = "Warp synchronous matrix-multiply accumulate using tensor cores.";
1580
1581  let extraClassDeclaration =
1582    "static llvm::Intrinsic::ID getIntrinsicID("
1583    "int m, int n, int k, mlir::NVVM::MMALayout layoutAEnum,"
1584    "mlir::NVVM::MMALayout layoutBEnum, mlir::NVVM::MMATypes eltypeAEnum,"
1585    "mlir::NVVM::MMATypes eltypeBEnum) {"
1586    "llvm::StringRef layoutA = stringifyEnum(layoutAEnum);"
1587    "llvm::StringRef layoutB = stringifyEnum(layoutBEnum);"
1588    "llvm::StringRef eltypeA = stringifyEnum(eltypeAEnum);"
1589    "llvm::StringRef eltypeB = stringifyEnum(eltypeBEnum);"
1590    #MMA_MMA_INTR<"mma">.id# "\n"
1591    "return 0;"
1592    "}";
1593
1594  string llvmBuilder = [{
1595      auto operands = moduleTranslation.lookupValues(opInst.getOperands());
1596      auto intId = mlir::NVVM::WMMAMmaOp::getIntrinsicID(
1597        $m, $n, $k, $layoutA, $layoutB, $eltypeA, $eltypeB);
1598      $res = createIntrinsicCall(builder, intId, operands);
1599  }];
1600
1601  string baseDescription = [{
1602    The `nvvm.wmma.mma` operation performs a matrix-multiply accumulate
1603    (mma) operation using all the threads in a warp.
1604
1605    The operation performed is represented as `D = A * B + C`. The operation takes
1606    as arguments the elements of the matrices `A`, `B` and `C`, held by the
1607    current thread. The op returns an LLVM struct which holds the part of the
1608    result held by the current thread.
1609
1610    This op is meant to be used along with `nvvm.wmma.load` and
1611    `nvvm.wmma.store`.
1612
1613    Example:
1614
1615    ```mlir
1616    %16 = nvvm.wmma.mma %0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15
1617      {eltypeA = "tf32", eltypeB = "f32", k = 8 : i32, layoutA = "row", layoutB = "row", m = 16 : i32, n = 16 : i32}
1618      : (i32, i32, i32, i32, i32, i32, i32, i32, f32, f32, f32, f32, f32, f32, f32, f32)
1619      -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
1620    ```
1621  }];
1622
1623  let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)";
1624  let hasVerifier = 1;
1625}
1626
1627def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">,
1628  Arguments<(ins LLVM_PointerShared:$ptr,
1629                 Variadic<I32>:$sources,
1630                 MMALayoutAttr:$layout)> {
1631  let summary = "cooperative matrix store";
1632  let description = [{
1633    Collectively store one or more matrices across all threads in a warp to the
1634    location indicated by the address operand $ptr in shared memory.
1635    [For more information, see PTX ISA]
1636    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix)
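
    Example (an illustrative sketch; the SSA values `%ptr` and `%r1` are assumed
    to be a shared-memory pointer and an `i32` fragment defined elsewhere):

    ```mlir
    nvvm.stmatrix %ptr, %r1 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32
    ```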
1637  }];
1638
1639  let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)";
1640  let extraClassDefinition = [{
1641    std::string $cppClass::getPtx() {
1642      int d = getSources().size();
1643      std::string ptx = "stmatrix.sync.aligned";
1644      ptx += ".x" + std::to_string(d);
1645      if (getLayout() == NVVM::MMALayout::col)
1646        ptx += ".trans";
1647      if(d == 1) ptx += ".m8n8.shared.b16 [%0], {%1};";
1648      if(d == 2) ptx += ".m8n8.shared.b16 [%0], {%1, %2};";
1649      if(d == 4) ptx += ".m8n8.shared.b16 [%0], {%1, %2, %3, %4};";
1650      return ptx;
1651    }
1652  }];
1653  let hasVerifier = 1;
1654}
1655
1656def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
1657  Results<(outs AnyType:$res)>,
1658  Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr:$num, MMALayoutAttr:$layout)> {
1659
1660  let summary = "cooperative matrix load";
1661
1662  string llvmBuilder = [{
1663      auto operands = moduleTranslation.lookupValues(opInst.getOperands());
1664      auto intId = getLdMatrixIntrinsicId($layout, $num);
1665      $res = createIntrinsicCall(builder, intId, operands, {operands[0]->getType()});
1666  }];
1667
1668  string baseDescription = [{
1669    The `nvvm.ldmatrix` operation collectively loads one or more matrices across
1670    all threads in a warp from the location indicated by the address operand
1671    `ptr` from shared memory.
1672
1673    The attribute `num` indicates how many 8x8 16-bit matrices are to be loaded.
1674
1675    All the threads in the warp must execute the same ldmatrix operations.
1676
1677    Each row of 8 elements needs to be consecutive in memory. Each lane of the
1678    warp contains the start address of a row of 8 elements laid out as below:
1679
1680    ```
1681    num | Threads 0--7 | Threads 8--15  | Threads 16--31
1682    1   | addr0--addr7 |                |
1683    2   | addr0--addr7 | addr8--addr15  |
1684    4   | addr0--addr7 | addr8--addr15  | addr16--addr31
1685    ```
1686
1687    Example:
1688    ```mlir
1689    %l1 = nvvm.ldmatrix %ptr {num = 1 : i32, layout = #nvvm.mma_layout<row>} :
1690      (!llvm.ptr<3>) -> i32
1691    %l2 = nvvm.ldmatrix %ptr {num = 4 : i32, layout = #nvvm.mma_layout<row>} :
1692      (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
1693    ```
1694  }];
1695
1696  let assemblyFormat = "$ptr attr-dict `:` functional-type($ptr, $res)";
1697  let hasVerifier = 1;
1698}
1699
1700def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
1701
1702  let summary = "cooperative matrix-multiply and accumulate";
1703
1704  let description = [{
1705    The `nvvm.mma.sync` operation collectively performs the operation
1706    `D = matmul(A, B) + C` using all threads in a warp.
1707
1708    All the threads in the warp must execute the same `mma.sync` operation.
1709
1710    For each possible multiplicand PTX data type, there are one or more possible
1711    instruction shapes given as "mMnNkK". The table below describes the possibilities
1712    as well as the types required for the operands. Note that the data type for
1713    C (the accumulator) and D (the result) can vary independently when there are
1714    multiple possibilities in the "C/D Type" column.
1715
1716    When an optional attribute cannot be immediately inferred from the types of
1717    the operands and the result during parsing or validation, an error will be
1718    raised.
1719
1720    `b1Op` is only relevant when the binary (b1) type is given to
1721    `multiplicandDataType`. It specifies how the multiply-and-accumulate is
1722    performed and is either `xor_popc` or `and_popc`. The default is `xor_popc`.
1723
1724    `intOverflowBehavior` is only relevant when the `multiplicandType` attribute
1725    is one of `u8, s8, u4, s4`. This attribute describes how overflow is handled
1726    in the accumulator. When the attribute is `satfinite`, the accumulator values
1727    are clamped in the int32 range on overflow. This is the default behavior.
1728    Alternatively, accumulator behavior `wrapped` can also be specified, in
1729    which case overflow wraps from one end of the range to the other.
1730
1731    `layoutA` and `layoutB` are required and should generally be set to
1732    `#nvvm.mma_layout<row>` and `#nvvm.mma_layout<col>` respectively, but other
1733    combinations are possible for certain layouts according to the table below.
1734
1735    ```
1736    | A/B Type | Shape     | ALayout | BLayout | A Type   | B Type   | C/D Type          |
1737    |----------|-----------|---------|---------|----------|----------|-------------------|
1738    | f64      | .m8n8k4   | row     | col     | 1x f64   | 1x f64   | 2x f64            |
1739    | f16      | .m8n8k4   | row/col | row/col | 2x f16x2 | 2x f16x2 | 4x f16x2 or 8xf32 |
1740    |          | .m16n8k8  | row     | col     | 2x f16x2 | 1x f16x2 | 2x f16x2 or 4 f32 |
1741    |          | .m16n8k16 | row     | col     | 4x f16x2 | 2x f16x2 | 2x f16x2 or 4 f32 |
1742    | bf16     | .m16n8k8  | row     | col     | 2x i32   | 1x i32   | 4x f32            |
1743    |          | .m16n8k16 | row     | col     | 4x i32   | 2x i32   | 4x f32            |
1744    | tf32     | .m16n8k4  | row     | col     | 2x i32   | 1x i32   | 4x f32            |
1745    |          | .m16n8k8  | row     | col     | 4x i32   | 2x i32   | 2x f16x2 or 4 f32 |
1746    | u8/s8    | .m8n8k16  | row     | col     | 1x i32   | 1x i32   | 2x i32            |
1747    |          | .m16n8k16 | row     | col     | 2x i32   | 1x i32   | 4x i32            |
1748    |          | .m16n8k32 | row     | col     | 4x i32   | 2x i32   | 4x i32            |
1749    | u4/s4    | .m8n8k32  | row     | col     | 1x i32   | 1x i32   | 2x i32            |
1750    |          | m16n8k32  | row     | col     | 2x i32   | 1x i32   | 4x i32            |
1751    |          | m16n8k64  | row     | col     | 4x i32   | 2x i32   | 4x i32            |
1752    | b1       | m8n8k128  | row     | col     | 1x i32   | 1x i32   | 2x i32            |
1753    |          | m16n8k128 | row     | col     | 2x i32   | 1x i32   | 4x i32            |
1754    ```
1755
1756
1757    Example:
1758    ```mlir
1759
1760    %128 = nvvm.mma.sync A[%120, %121, %122, %123]
1761                         B[%124, %125]
1762                         C[%126, %127]
1763                         {layoutA = #nvvm.mma_layout<row>,
1764                          layoutB = #nvvm.mma_layout<col>,
1765                          shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}}
1766        : (vector<2xf16>, vector<2xf16>, vector<2xf16>)
1767           -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
1768    ```
1769  }];
1770
1771  let results = (outs LLVM_AnyStruct:$res);
1772  let arguments = (ins NVVM_MMAShapeAttr:$shape,
1773             OptionalAttr<MMAB1OpAttr>:$b1Op,
1774             OptionalAttr<MMAIntOverflowAttr>:$intOverflowBehavior,
1775             MMALayoutAttr:$layoutA,
1776             MMALayoutAttr:$layoutB,
1777             OptionalAttr<MMATypesAttr>:$multiplicandAPtxType,
1778             OptionalAttr<MMATypesAttr>:$multiplicandBPtxType,
1779             Variadic<LLVM_Type>:$operandA,
1780             Variadic<LLVM_Type>:$operandB,
1781             Variadic<LLVM_Type>:$operandC);
1782
1783  let extraClassDeclaration = !strconcat([{
1784      static llvm::Intrinsic::ID getIntrinsicID(
1785            int64_t m, int64_t n, uint64_t k,
1786            std::optional<MMAB1Op> b1Op,
1787            std::optional<MMAIntOverflow> sat,
1788            mlir::NVVM::MMALayout layoutAEnum, mlir::NVVM::MMALayout layoutBEnum,
1789            mlir::NVVM::MMATypes eltypeAEnum, mlir::NVVM::MMATypes eltypeBEnum,
1790            mlir::NVVM::MMATypes eltypeCEnum, mlir::NVVM::MMATypes eltypeDEnum) {
1791        llvm::StringRef layoutA = stringifyEnum(layoutAEnum);
1792        llvm::StringRef layoutB = stringifyEnum(layoutBEnum);
1793        llvm::StringRef eltypeA = stringifyEnum(eltypeAEnum);
1794        llvm::StringRef eltypeB = stringifyEnum(eltypeBEnum);
1795        llvm::StringRef eltypeC = stringifyEnum(eltypeCEnum);
1796        llvm::StringRef eltypeD = stringifyEnum(eltypeDEnum);
1797        }],
1798        MMA_SYNC_INTR<>.id, [{
1799        return 0;
1800      }
1801
1802      static std::optional<mlir::NVVM::MMATypes> inferOperandMMAType(Type operandElType,
1803        bool isAccumulator);
1804
1805      MMATypes accumPtxType();
1806      MMATypes resultPtxType();
1807    }]);
1808
1809  let builders = [
1810      OpBuilder<(ins  "Type":$resultType, "ValueRange":$operandA,
1811        "ValueRange":$operandB, "ValueRange":$operandC,
1812        "ArrayRef<int64_t>":$shape, "std::optional<MMAB1Op>":$b1Op,
1813        "std::optional<MMAIntOverflow>":$intOverflow,
1814        "std::optional<std::array<MMATypes, 2>>":$multiplicandPtxTypes,
1815        "std::optional<std::array<MMALayout, 2>>":$multiplicandLayouts)>
1816    ];
1817
1818  string llvmBuilder = [{
1819    auto operands = moduleTranslation.lookupValues(opInst.getOperands());
1820    auto intId = mlir::NVVM::MmaOp::getIntrinsicID(
1821        $shape.getM(), $shape.getN(), $shape.getK(),
1822        $b1Op, $intOverflowBehavior,
1823        $layoutA, $layoutB,
1824        *$multiplicandAPtxType,
1825        *$multiplicandBPtxType,
1826        op.accumPtxType(),
1827        op.resultPtxType());
1828
1829    $res = createIntrinsicCall(
1830      builder, intId, operands);
1831  }];
1832
1833  let hasCustomAssemblyFormat = 1;
1834  let hasVerifier = 1;
1835}
1836
1837//===----------------------------------------------------------------------===//
1838// NVVM TMA Ops
1839//===----------------------------------------------------------------------===//
1840
1841def NVVM_CpAsyncBulkCommitGroupOp : NVVM_Op<"cp.async.bulk.commit.group">,
1842  Arguments<(ins )> {
1843  let assemblyFormat = "attr-dict";
1844  let description = [{
1845    This Op commits all prior initiated but uncommitted cp.async.bulk
1846    instructions into a cp.async.bulk-group.
1847
1848    [For more information, see PTX ISA]
1849    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group)
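
    Example (the Op takes no operands or attributes):

    ```mlir
    nvvm.cp.async.bulk.commit.group
    ```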
1850  }];
1851
1852  string llvmBuilder = [{
1853    createIntrinsicCall(builder, llvm::Intrinsic::nvvm_cp_async_bulk_commit_group);
1854  }];
1855}
1856
1857def NVVM_CpAsyncBulkWaitGroupOp : NVVM_Op<"cp.async.bulk.wait_group">,
1858  Arguments<(ins
1859    ConfinedAttr<I32Attr, [IntMinValue<0>]>:$group,
1860    OptionalAttr<UnitAttr>:$read)> {
1861  let assemblyFormat = "$group attr-dict";
1862  let description = [{
1863    Op waits for completion of the most recent bulk async-groups.
1864
1865    The `$group` operand indicates that the wait completes once at most `$group`
1866    of the most recent bulk async-groups are still pending. If `$group` is 0, the
1867    op waits until all the most recent bulk async-groups have completed.
1868
1869    The `$read` attribute indicates that the waiting has to be done until all the bulk
1870    async operations in the specified bulk async-group have completed reading
1871    from their source locations.
1872
1873    [For more information, see PTX ISA]
1874    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group)
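
    Example (illustrative; the group count is an attribute and `read` is the
    optional unit attribute described above):

    ```mlir
    nvvm.cp.async.bulk.wait_group 0
    nvvm.cp.async.bulk.wait_group 2 {read}
    ```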
1875  }];
1876
1877  string llvmBuilder = [{
1878    auto intId = op.getRead() ?
1879      llvm::Intrinsic::nvvm_cp_async_bulk_wait_group_read :
1880      llvm::Intrinsic::nvvm_cp_async_bulk_wait_group;
1881    createIntrinsicCall(builder, intId, builder.getInt32($group));
1882  }];
1883}
1884
1885def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
1886  NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global",
1887  [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
1888  AttrSizedOperandSegments]>,
1889  Arguments<(ins  LLVM_PointerShared:$dstMem,
1890                  LLVM_AnyPointer:$tmaDescriptor,
1891                  Variadic<I32>:$coordinates,
1892                  LLVM_PointerShared:$mbar,
1893                  Variadic<I16>:$im2colOffsets,
1894                  Optional<I16>:$multicastMask,
1895                  Optional<I64>:$l2CacheHint,
1896                  PtxPredicate:$predicate)> {
1897  let description = [{
1898    Initiates an asynchronous copy operation on the tensor data from global
1899    memory to shared memory.
1900
1901    The Op has two load modes:
1902    1) Tiled Mode: It's the default mode. The source multi-dimensional tensor
1903    layout is preserved at the destination.
1904
1905    2) Im2col Mode: This mode is used when `im2colOffsets` operands are present.
1906    The elements in the Bounding Box of the source tensor are rearranged into
1907    columns at the destination. In this mode, the tensor has to be at least
1908    3-dimensional.
1909
1910    The `multicastMask` operand is optional. When it is present, the Op copies
1911    data from global memory to shared memory of multiple CTAs in the cluster.
1912    Operand `multicastMask` specifies the destination CTAs in the cluster such
1913    that each bit position in the 16-bit `multicastMask` operand corresponds to
1914    the `nvvm.read.ptx.sreg.ctaid` of the destination CTA.
1915
1916    The `l2CacheHint` operand is optional, and it is used to specify the cache
1917    eviction policy that may be used during the memory access.
1918
1919    [For more information, see PTX ISA]
1920    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor)
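
    Example (a sketch of a 2-d tiled-mode copy; the SSA values `%dst`, `%tmaDesc`,
    `%mbar`, `%crd0` and `%crd1` are assumed to be defined elsewhere):

    ```mlir
    nvvm.cp.async.bulk.tensor.shared.cluster.global %dst, %tmaDesc, %mbar,
      box [%crd0, %crd1] : !llvm.ptr<3>, !llvm.ptr
    ```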
1921  }];
1922
1923  let assemblyFormat = [{
1924    $dstMem `,`
1925    $tmaDescriptor `,`
1926    $mbar `,`
1927    `box` `[`$coordinates `]`
1928    (`im2col` `[` $im2colOffsets^ `]` )?
1929    (`multicast_mask` `=` $multicastMask^ )?
1930    (`l2_cache_hint` `=` $l2CacheHint^ )?
1931    (`predicate` `=` $predicate^)?
1932    attr-dict  `:` type($dstMem) `,` type($tmaDescriptor)
1933  }];
1934
1935  let extraClassDefinition = [{
1936    std::string $cppClass::getPtx() {
1937      int im2colDim = getIm2colOffsets().size();
1938      int dim = getCoordinates().size();
1939      std::string ptx = "cp.async.bulk.tensor.";
1940      ptx += std::to_string(dim) + "d.";
1941      ptx += "shared::cluster.global.mbarrier::complete_tx::bytes";
1942      if(im2colDim) ptx += ".im2col";
1943      if(getMulticastMask()) ptx += ".multicast::cluster";
1944      if(getL2CacheHint()) ptx += ".L2::cache_hint";
1945
1946      auto preg = [](int r) { return "%" + std::to_string(r); };
1947
1948      // Build Registers
1949      ptx += " [%0], [%1, {";
1950      int r = 2;
1951      for(int i = 0; i < dim; i++) ptx += preg(r+i) + ",";
1952      ptx.pop_back(); r += dim;
1953      ptx += "} ], [%" + std::to_string(r++) + "]";
1954      if(im2colDim) {
1955        ptx += ",{";
1956        for(int i = 0; i < im2colDim; i++) ptx += preg(r+i) + ",";
1957        ptx.pop_back(); r += im2colDim;
1958        ptx += "}";
1959      }
1960      if(getMulticastMask()) ptx += ", " + preg(r++);
1961      if(getL2CacheHint()) ptx += ", " + preg(r++);
1962      ptx += ";";
1963      return ptx;
1964    }
1965  }];
1966  let hasVerifier = 1;
1967}
1968
1969def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp :
1970  NVVM_Op<"cp.async.bulk.tensor.global.shared.cta",
1971  [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
1972  AttrSizedOperandSegments]>,
1973  Arguments<(ins  LLVM_AnyPointer:$tmaDescriptor,
1974                  LLVM_PointerShared:$srcMem,
1975                  Variadic<I32>:$coordinates,
1976                  PtxPredicate:$predicate)> {
1977  let assemblyFormat = [{
1978    $tmaDescriptor `,`
1979    $srcMem `,`
1980    `box` `[`$coordinates `]`
1981    (`,` `predicate` `=` $predicate^)?
1982    attr-dict  `:` type(operands)
1983  }];
1984  let extraClassDefinition = [{
1985    std::string $cppClass::getPtx() {
1986      int dim = getCoordinates().size();
1987      std::string ptx = "cp.async.bulk.tensor.";
1988      ptx += std::to_string(dim) + "d.";
1989      ptx += "global.shared::cta.bulk_group";
1990      if(dim == 1) ptx += " [%0, {%2} ], [%1];";
1991      if(dim == 2) ptx += " [%0, {%2, %3} ], [%1];";
1992      if(dim == 3) ptx += " [%0, {%2, %3, %4} ], [%1];";
1993      if(dim == 4) ptx += " [%0, {%2, %3, %4, %5} ], [%1];";
1994      if(dim == 5) ptx += " [%0, {%2, %3, %4, %5, %6} ], [%1];";
1995      return ptx;
1996    }
1997  }];
1998  let hasVerifier = 1;
1999}
2000
2001def NVVM_PrefetchTensorMapOp : NVVM_Op<"prefetch.tensormap",
2002                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
2003  Arguments<(ins LLVM_AnyPointer:$tmaDescriptor, PtxPredicate:$predicate)> {
2004  let assemblyFormat = "$tmaDescriptor (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
2005  let extraClassDefinition = [{
2006    std::string $cppClass::getPtx() {
2007      return std::string("prefetch.tensormap [%0];");
2008    }
2009  }];
2010}
2011
2012def NVVM_CpAsyncBulkTensorPrefetchOp :
2013  NVVM_Op<"cp.async.bulk.tensor.prefetch", [AttrSizedOperandSegments]> {
2014  let arguments = (ins
2015    LLVM_AnyPointer:$tmaDescriptor,
2016    Variadic<I32>:$coordinates,
2017    Variadic<I16>:$im2colOffsets,
2018    Optional<I64>:$l2CacheHint);
2019
2020  let description = [{
2021    Initiates an asynchronous prefetch operation on the tensor data from global
2022    memory to L2 cache.
2023
2024    The Op has two modes:
2025    1) Tiled Mode: It's the default mode. The source multi-dimensional tensor
2026    layout is preserved at the destination.
2027
2028    2) Im2col Mode: This mode is used when `im2colOffsets` operands are present.
2029    The elements in the Bounding Box of the source tensor are rearranged into
2030    columns at the destination. In this mode, the tensor has to be at least
2031    3-dimensional.
2032
2033    The `l2CacheHint` operand is optional, and it is used to specify the cache
2034    eviction policy that may be used during the memory access.
2035
2036    [For more information, see PTX ISA]
2037    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-prefetch-tensor)
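
    Example (a sketch of a 2-d tile-mode prefetch; `%tmaDesc`, `%crd0` and `%crd1`
    are assumed SSA values):

    ```mlir
    nvvm.cp.async.bulk.tensor.prefetch %tmaDesc, box [%crd0, %crd1] : !llvm.ptr
    ```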
2038  }];
2039
2040  let assemblyFormat = [{
2041    $tmaDescriptor `,`
2042    `box` `[`$coordinates `]`
2043    (`im2col` `[` $im2colOffsets^ `]` )?
2044    (`l2_cache_hint` `=` $l2CacheHint^ )?
2045    attr-dict  `:` type($tmaDescriptor)
2046  }];
2047
2048  let extraClassDeclaration = [{
2049    static llvm::Intrinsic::ID getIntrinsicID(int tensorDims, bool isIm2Col);
2050  }];
2051
2052  let hasVerifier = 1;
2053
2054  string llvmBuilder = [{
2055    // Arguments to the intrinsic:
2056    // tmaDesc, tensorDims, im2colOffsets
2057    // cache_hint(if applicable) and flag(boolean)
2058    llvm::SmallVector<llvm::Value *> translatedOperands;
2059    translatedOperands.push_back($tmaDescriptor);
2060
2061    for (auto v : op.getCoordinates())
2062      translatedOperands.push_back(moduleTranslation.lookupValue(v));
2063
2064    for (auto v : op.getIm2colOffsets())
2065      translatedOperands.push_back(moduleTranslation.lookupValue(v));
2066
2067    llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
2068    auto *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
2069
2070    bool isCacheHint = op.getL2CacheHint() ? true : false;
2071    translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Unused);
2072    translatedOperands.push_back(builder.getInt1(isCacheHint));
2073
2074    auto intId = NVVM::CpAsyncBulkTensorPrefetchOp::getIntrinsicID(
2075        op.getCoordinates().size(), op.getIm2colOffsets().size() > 0);
2076    createIntrinsicCall(builder, intId, translatedOperands);
2077  }];
2078}
2079
2080// List of modes supported for TMA Store and Reduction Ops
2081def TMAStoreModeTile   : I32EnumAttrCase<"TILE", 0, "tile">;
2082def TMAStoreModeIm2Col : I32EnumAttrCase<"IM2COL", 1, "im2col">;
2083
2084def TMAStoreMode : I32EnumAttr<"TMAStoreMode", "NVVM TMA Store Mode",
2085    [TMAStoreModeTile, TMAStoreModeIm2Col]> {
2086  let genSpecializedAttr = 0;
2087  let cppNamespace = "::mlir::NVVM";
2088}
2089def TMAStoreModeAttr : EnumAttr<NVVM_Dialect, TMAStoreMode, "tma_store_mode"> {
2090  let assemblyFormat = "`<` $value `>`";
2091}
2092
2093// List of Reduction Ops supported with TMA Store
2094def TMAReduxKindAdd : I32EnumAttrCase<"ADD", 0, "add">;
2095def TMAReduxKindMin : I32EnumAttrCase<"MIN", 1, "min">;
2096def TMAReduxKindMax : I32EnumAttrCase<"MAX", 2, "max">;
2097def TMAReduxKindInc : I32EnumAttrCase<"INC", 3, "inc">;
2098def TMAReduxKindDec : I32EnumAttrCase<"DEC", 4, "dec">;
2099def TMAReduxKindAnd : I32EnumAttrCase<"AND", 5, "and">;
2100def TMAReduxKindOr  : I32EnumAttrCase<"OR",  6, "or">;
2101def TMAReduxKindXor : I32EnumAttrCase<"XOR", 7, "xor">;
2102
2103def TMAReduxKind : I32EnumAttr<"TMAReduxKind", "NVVM TMA redux kind",
2104    [TMAReduxKindAdd, TMAReduxKindMax, TMAReduxKindMin,
2105     TMAReduxKindInc, TMAReduxKindDec, TMAReduxKindAnd,
2106     TMAReduxKindOr,  TMAReduxKindXor]> {
2107  let genSpecializedAttr = 0;
2108  let cppNamespace = "::mlir::NVVM";
2109}
2110def TMAReduxKindAttr : EnumAttr<NVVM_Dialect, TMAReduxKind, "tma_redux_kind"> {
2111  let assemblyFormat = "`<` $value `>`";
2112}
2113
2114def NVVM_CpAsyncBulkTensorReduceOp :
2115  NVVM_Op<"cp.async.bulk.tensor.reduce", [AttrSizedOperandSegments]> {
2116  let arguments = (ins
2117    LLVM_AnyPointer:$tmaDescriptor,
2118    LLVM_PointerShared:$srcMem,
2119    TMAReduxKindAttr:$redKind,
2120    DefaultValuedAttr<TMAStoreModeAttr, "TMAStoreMode::TILE">:$mode,
2121    Variadic<I32>:$coordinates,
2122    Optional<I64>:$l2CacheHint);
2123
2124  let description = [{
2125    Initiates an asynchronous reduction operation of tensor data in
2126    global memory with tensor data in shared memory.
2127
2128    The `mode` attribute indicates whether the copy mode is tile or im2col.
2129    The `redKind` attribute specifies the reduction operation applied.
2130    The supported reduction operations are:
2131    {add, min, max, inc, dec, and, or, xor}
2132
2133    The `l2CacheHint` operand is optional, and it is used to specify the cache
2134    eviction policy that may be used during the memory access.
2135
2136    [For more information, see PTX ISA]
2137    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-reduce-async-bulk-tensor)
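
    Example (a sketch of a 2-d tile-mode `add` reduction; SSA value names are
    assumed and the attribute dictionary is illustrative):

    ```mlir
    nvvm.cp.async.bulk.tensor.reduce %tmaDesc, %src, box [%crd0, %crd1]
      {redKind = #nvvm.tma_redux_kind<add>} : !llvm.ptr, !llvm.ptr<3>
    ```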
2138  }];
2139
2140  let assemblyFormat = [{
2141    $tmaDescriptor `,`
2142    $srcMem `,`
2143    `box` `[`$coordinates `]`
2144    (`l2_cache_hint` `=` $l2CacheHint^ )?
2145    attr-dict  `:` type($tmaDescriptor) `,` type($srcMem)
2146  }];
2147
2148  let extraClassDeclaration = [{
2149    static llvm::Intrinsic::ID getIntrinsicID(int tensorDims,
2150                                              NVVM::TMAReduxKind kind,
2151                                              bool isIm2Col);
2152  }];
2153
2154  let hasVerifier = 1;
2155
2156  string llvmBuilder = [{
2157    // Arguments to the intrinsic:
2158    // shared_mem_ptr, tmaDesc, tensorDims
2159    // cache_hint(if applicable) and flag(boolean)
2160    llvm::SmallVector<llvm::Value *> translatedOperands;
2161    translatedOperands.push_back($srcMem);
2162    translatedOperands.push_back($tmaDescriptor);
2163
2164    for (auto v : op.getCoordinates())
2165      translatedOperands.push_back(moduleTranslation.lookupValue(v));
2166
2167    llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
2168    auto *i64Undef = llvm::UndefValue::get(llvm::IntegerType::get(ctx, 64));
2169
2170    bool isCacheHint = op.getL2CacheHint() ? true : false;
2171    translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Undef);
2172    translatedOperands.push_back(builder.getInt1(isCacheHint));
2173
2174    auto intId = NVVM::CpAsyncBulkTensorReduceOp::getIntrinsicID(
2175                 op.getCoordinates().size(), $redKind,
2176                 (op.getMode() == NVVM::TMAStoreMode::IM2COL));
2177    createIntrinsicCall(builder, intId, translatedOperands);
2178  }];
2179}
2180
2181def NVVM_CpAsyncBulkGlobalToSharedClusterOp :
2182  NVVM_Op<"cp.async.bulk.shared.cluster.global", [AttrSizedOperandSegments]> {
2183  let summary = "Async bulk copy from global memory to Shared cluster memory";
2184  let description = [{
2185    Initiates an asynchronous copy operation from global memory to cluster's
2186    shared memory.
2187
2188    The `multicastMask` operand is optional. When it is present, the Op copies
2189    data from global memory to shared memory of multiple CTAs in the cluster.
2190    Operand `multicastMask` specifies the destination CTAs in the cluster such
2191    that each bit position in the 16-bit `multicastMask` operand corresponds to
2192    the `nvvm.read.ptx.sreg.ctaid` of the destination CTA.
2193
2194    The `l2CacheHint` operand is optional, and it is used to specify the cache
2195    eviction policy that may be used during the memory access.
2196    [For more information, see PTX ISA]
2197    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk)
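
    Example (a minimal sketch without the optional multicast mask and cache hint;
    SSA value names are assumed):

    ```mlir
    nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size
      : !llvm.ptr<3>, !llvm.ptr<1>
    ```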
2198  }];
2199
2200  let arguments = (ins
2201    LLVM_PointerShared:$dstMem,
2202    LLVM_PointerGlobal:$srcMem,
2203    LLVM_PointerShared:$mbar,
2204    I32:$size,
2205    Optional<I16>:$multicastMask,
2206    Optional<I64>:$l2CacheHint);
2207
2208  let assemblyFormat = [{
2209    $dstMem `,` $srcMem `,` $mbar `,` $size
2210    (`multicast_mask` `=` $multicastMask^ )?
2211    (`l2_cache_hint` `=` $l2CacheHint^ )?
2212    attr-dict  `:` type($dstMem) `,` type($srcMem)
2213  }];
2214
2215  string llvmBuilder = [{
2216    // Arguments to the intrinsic:
2217    // dst, mbar, src, size
2218    // multicast_mask, cache_hint,
2219    // flag for multicast_mask,
2220    // flag for cache_hint
2221    llvm::SmallVector<llvm::Value *> translatedOperands;
2222    translatedOperands.push_back($dstMem);
2223    translatedOperands.push_back($mbar);
2224    translatedOperands.push_back($srcMem);
2225    translatedOperands.push_back($size);
2226
2227    // Multicast, if available
2228    llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
2229    auto *i16Unused = llvm::ConstantInt::get(llvm::Type::getInt16Ty(ctx), 0);
2230    bool isMulticast = op.getMulticastMask() ? true : false;
2231    translatedOperands.push_back(isMulticast ? $multicastMask : i16Unused);
2232
2233    // Cachehint, if available
2234    auto *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
2235    bool isCacheHint = op.getL2CacheHint() ? true : false;
2236    translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Unused);
2237
2238    // Flag arguments for multicast and cachehint
2239    translatedOperands.push_back(builder.getInt1(isMulticast));
2240    translatedOperands.push_back(builder.getInt1(isCacheHint));
2241
2242    createIntrinsicCall(builder,
2243      llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster, translatedOperands);
2244  }];
2245}
2246
2247def NVVM_CpAsyncBulkSharedCTAToSharedClusterOp :
2248  NVVM_Op<"cp.async.bulk.shared.cluster.shared.cta"> {
2249  let summary = "Async bulk copy from Shared CTA memory to Shared cluster memory";
2250  let description = [{
2251    Initiates an asynchronous copy operation from Shared CTA memory to Shared
2252    cluster memory.
2253
2254    [For more information, see PTX ISA]
2255    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk)
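
    Example (a minimal sketch; SSA value names are assumed):

    ```mlir
    nvvm.cp.async.bulk.shared.cluster.shared.cta %dst, %src, %mbar, %size
      : !llvm.ptr<3>, !llvm.ptr<3>
    ```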
2256  }];
2257
2258  let arguments = (ins
2259    LLVM_PointerShared:$dstMem,
2260    LLVM_PointerShared:$srcMem,
2261    LLVM_PointerShared:$mbar,
2262    I32:$size);
2263
2264  let assemblyFormat = [{
2265    $dstMem `,` $srcMem `,` $mbar `,` $size
2266    attr-dict  `:` type($dstMem) `,` type($srcMem)
2267  }];
2268
2269  string llvmBuilder = [{
2270    createIntrinsicCall(builder,
2271      llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_cluster,
2272      {$dstMem, $mbar, $srcMem, $size});
2273  }];
2274}
2275
2276def NVVM_CpAsyncBulkSharedCTAToGlobalOp :
2277  NVVM_Op<"cp.async.bulk.global.shared.cta"> {
2278  let summary = "Async bulk copy from Shared CTA memory to Global memory";
2279  let description = [{
2280    Initiates an asynchronous copy operation from Shared CTA memory to
2281    global memory.
2282
2283    The `l2CacheHint` operand is optional, and it is used to specify the cache
2284    eviction policy that may be used during the memory access.
2285    [For more information, see PTX ISA]
2286    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk)
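
    Example (a minimal sketch without the optional cache hint; SSA value names
    are assumed):

    ```mlir
    nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size : !llvm.ptr<1>, !llvm.ptr<3>
    ```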
2287  }];
2288
2289  let arguments = (ins
2290    LLVM_PointerGlobal:$dstMem,
2291    LLVM_PointerShared:$srcMem,
2292    I32:$size,
2293    Optional<I64>:$l2CacheHint);
2294
2295  let assemblyFormat = [{
2296    $dstMem `,` $srcMem `,` $size
2297    (`l2_cache_hint` `=` $l2CacheHint^ )?
2298    attr-dict  `:` type($dstMem) `,` type($srcMem)
2299  }];
2300
2301  string llvmBuilder = [{
2302    // Arguments to the intrinsic:
2303    // dst, src, size, cache_hint,
2304    // Flag for cache_hint
2305    //
2306    llvm::SmallVector<llvm::Value *> translatedOperands;
2307    translatedOperands.push_back($dstMem);
2308    translatedOperands.push_back($srcMem);
2309    translatedOperands.push_back($size);
2310
2311    // Cachehint, if available
2312    llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
2313    auto *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
2314    bool isCacheHint = op.getL2CacheHint() ? true : false;
2315    translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Unused);
2316
2317    // Flag argument for cachehint
2318    translatedOperands.push_back(builder.getInt1(isCacheHint));
2319
2320    createIntrinsicCall(builder,
2321      llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global, translatedOperands);
2322  }];
2323}
2324
2325//===----------------------------------------------------------------------===//
2326// NVVM Wgmma Ops
2327//===----------------------------------------------------------------------===//
2328
2329def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned"> {
2330  let arguments = (ins);
2331  let description = [{
2332    Enforce an ordering of register accesses between warpgroup level matrix
2333    multiplication and other operations.
2334
2335    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-fence)
2336  }];
2337  let assemblyFormat = "attr-dict";
2338  string llvmBuilder = [{
2339    createIntrinsicCall(builder, llvm::Intrinsic::nvvm_wgmma_fence_sync_aligned);
2340  }];
2341}
2342
2343def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned">,
2344  Arguments<(ins )> {
2345  let assemblyFormat = "attr-dict";
2346  let description = [{
2347    Commits all prior uncommitted warpgroup level matrix multiplication operations.
2348
2349    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-commit-group)
2350  }];
2351  string llvmBuilder = [{
2352    createIntrinsicCall(builder, llvm::Intrinsic::nvvm_wgmma_commit_group_sync_aligned);
2353  }];
2354}
2355
2356def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned">{
2357  let arguments = (ins I64Attr:$group);
2358  let assemblyFormat = "attr-dict $group";
2359  let description = [{
2360    Signal the completion of a preceding warpgroup operation.
2361
2362    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-wait-group)
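
    Example (waits until at most one of the committed warpgroup MMA groups is
    still pending; the group count is an attribute):

    ```mlir
    nvvm.wgmma.wait.group.sync.aligned 1
    ```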
2363  }];
2364  string llvmBuilder = [{
2365    createIntrinsicCall(builder, llvm::Intrinsic::nvvm_wgmma_wait_group_sync_aligned, builder.getInt64($group));
2366  }];
2367}
2368
2369/// Enum attribute type for the negating of input operands
2370def WGMMAScaleInNeg : I32EnumAttrCase<"neg", -1>;
2371def WGMMAScaleInOne : I32EnumAttrCase<"one", 1>;
2372def WGMMAScaleIn : I32EnumAttr<"WGMMAScaleIn", "WGMMA overflow options",
2373  [WGMMAScaleInOne, WGMMAScaleInNeg]> {
2374  let genSpecializedAttr = 0;
2375  let cppNamespace = "::mlir::NVVM";
2376}
2377def WGMMAScaleInAttr : EnumAttr<NVVM_Dialect, WGMMAScaleIn, "wgmma_scale_in"> {
2378  let assemblyFormat = "`<` $value `>`";
2379}
2380
2381/// Enum attribute type for the output operand
2382def WGMMAScaleOutZero : I32EnumAttrCase<"zero", 0>;
2383def WGMMAScaleOutOne : I32EnumAttrCase<"one", 1>;
2384def WGMMAScaleOut : I32EnumAttr<"WGMMAScaleOut", "WGMMA input predicate",
2385  [WGMMAScaleOutZero, WGMMAScaleOutOne]> {
2386  let genSpecializedAttr = 0;
2387  let cppNamespace = "::mlir::NVVM";
2388}
2389def WGMMAScaleOutAttr : EnumAttr<NVVM_Dialect, WGMMAScaleOut, "wgmma_scale_out"> {
2390  let assemblyFormat = "`<` $value `>`";
2391}
2392
2393/// Enum attribute of the different PTX element types used for WGMMA operands.
2394def WGMMATypeF16  : I32EnumAttrCase<"f16", 0>;
2395def WGMMATypeTF32 : I32EnumAttrCase<"tf32", 1>;
2396def WGMMATypeU8 : I32EnumAttrCase<"u8", 2>;
2397def WGMMATypeS8 : I32EnumAttrCase<"s8", 3>;
2398def WGMMATypeB1 : I32EnumAttrCase<"b1", 4>;
2399def WGMMATypeBF16 : I32EnumAttrCase<"bf16", 5>;
2400def WGMMATypeF8E4M3 : I32EnumAttrCase<"e4m3", 6>;
2401def WGMMATypeF8E5M2 : I32EnumAttrCase<"e5m2", 7>;
2402def WGMMATypeF32 : I32EnumAttrCase<"f32", 8>;
2403def WGMMATypeS32 : I32EnumAttrCase<"s32", 9>;
2404
2405def WGMMATypes : I32EnumAttr<"WGMMATypes", "NVVM WGMMA types",
2406  [WGMMATypeF16, WGMMATypeTF32,
2407    WGMMATypeU8, WGMMATypeS8,
2408    WGMMATypeB1, WGMMATypeBF16, WGMMATypeF8E4M3,
2409    WGMMATypeF8E5M2, WGMMATypeF32, WGMMATypeS32]> {
2410  let genSpecializedAttr = 0;
2411  let cppNamespace = "::mlir::NVVM";
2412}
2413def WGMMATypesAttr : EnumAttr<NVVM_Dialect, WGMMATypes, "wgmma_type"> {
2414  let assemblyFormat = "`<` $value `>`";
2415}
2416
2417
2418def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
2419              [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
2420                PredOpTrait<"input struct and result struct must be the same type",
2421                  TCresIsSameAsOpBase<0, 0>>,]>
2422{
2423  let results = (outs LLVM_AnyStruct:$results);
2424  let arguments = (ins
2425    LLVM_AnyStruct:$inouts,
2426    I64:$descriptorA,
2427    I64:$descriptorB,
2428    NVVM_MMAShapeAttr:$shape,
2429    WGMMATypesAttr:$typeA,
2430    WGMMATypesAttr:$typeB,
2431    WGMMATypesAttr:$typeD,
2432    WGMMAScaleOutAttr:$scaleD,
2433    WGMMAScaleInAttr:$scaleA,
2434    WGMMAScaleInAttr:$scaleB,
2435    MMALayoutAttr:$layoutA,
2436    MMALayoutAttr:$layoutB,
2437    OptionalAttr<MMAIntOverflowAttr>:$satfinite
2438  );
2439
2440   let assemblyFormat = [{
2441      $descriptorA `,` $descriptorB `,` $inouts `,` $shape `,`
2442      `D` `[` $typeD `,` $scaleD (`,` $satfinite^)? `]` `,`
2443      `A` `[` $typeA `,` $scaleA `,` $layoutA `]` `,`
2444      `B` `[` $typeB `,` $scaleB `,` $layoutB `]`
2445      attr-dict `:`
2446      type($inouts) `->` type($results)
2447    }];
2448
2449  let description = [{
2450    The warpgroup (128 threads) level matrix multiply and accumulate operation
2451    has either of the following forms, where matrix D is called the accumulator:
2452      D = A * B + D
2453      D = A * B, where the input from accumulator D is disabled.
2454
    Supported shapes:
    ```
    |--------------|--------------|------------|--------------|---------------|
    |              |              |            |              |f16+=e4m3*e4m3 |
    |              |              |            |              |f16+=e5m2*e5m2 |
    |f32+=tf32*tf32|f16+=f16 *f16 | s32+=s8*s8 |s32 += b1 * b1|f16+=e5m2*e4m3 |
    |              |f32+=f16 *f16 | s32+=u8*u8 |              |f16+=e4m3*e5m2 |
    |              |f32+=bf16*bf16| s32+=s8*u8 |              |f32+=e4m3*e4m3 |
    |              |              | s32+=u8*s8 |              |f32+=e5m2*e5m2 |
    |              |              |            |              |f32+=e5m2*e4m3 |
    |              |              |            |              |f32+=e4m3*e5m2 |
    |--------------|--------------|------------|--------------|---------------|
    |   .m64n8k8   |  .m64n8k16   | .m64n8k32  | .m64n8k256   | .m64n8k32     |
    |   .m64n16k8  |  .m64n16k16  | .m64n16k32 | .m64n16k256  | .m64n16k32    |
    |   .m64n24k8  |  .m64n24k16  | .m64n24k32 | .m64n24k256  | .m64n24k32    |
    |   .m64n32k8  |  .m64n32k16  | .m64n32k32 | .m64n32k256  | .m64n32k32    |
    |   .m64n40k8  |  .m64n40k16  | .m64n48k32 | .m64n48k256  | .m64n40k32    |
    |   .m64n48k8  |  .m64n48k16  | .m64n64k32 | .m64n64k256  | .m64n48k32    |
    |   .m64n56k8  |  .m64n56k16  | .m64n80k32 | .m64n80k256  | .m64n56k32    |
    |   .m64n64k8  |  .m64n64k16  | .m64n96k32 | .m64n96k256  | .m64n64k32    |
    |   .m64n72k8  |  .m64n72k16  | .m64n112k32| .m64n112k256 | .m64n72k32    |
    |   .m64n80k8  |  .m64n80k16  | .m64n128k32| .m64n128k256 | .m64n80k32    |
    |   .m64n88k8  |  .m64n88k16  | .m64n144k32| .m64n144k256 | .m64n88k32    |
    |   .m64n96k8  |  .m64n96k16  | .m64n160k32| .m64n160k256 | .m64n96k32    |
    |   .m64n104k8 |  .m64n104k16 | .m64n176k32| .m64n176k256 | .m64n104k32   |
    |   .m64n112k8 |  .m64n112k16 | .m64n192k32| .m64n192k256 | .m64n112k32   |
    |   .m64n120k8 |  .m64n120k16 | .m64n208k32| .m64n208k256 | .m64n120k32   |
    |   .m64n128k8 |  .m64n128k16 | .m64n224k32| .m64n224k256 | .m64n128k32   |
    |   .m64n136k8 |  .m64n136k16 | .m64n240k32| .m64n240k256 | .m64n136k32   |
    |   .m64n144k8 |  .m64n144k16 | .m64n256k32| .m64n256k256 | .m64n144k32   |
    |   .m64n152k8 |  .m64n152k16 |            |              | .m64n152k32   |
    |   .m64n160k8 |  .m64n160k16 |            |              | .m64n160k32   |
    |   .m64n168k8 |  .m64n168k16 |            |              | .m64n168k32   |
    |   .m64n176k8 |  .m64n176k16 |            |              | .m64n176k32   |
    |   .m64n184k8 |  .m64n184k16 |            |              | .m64n184k32   |
    |   .m64n192k8 |  .m64n192k16 |            |              | .m64n192k32   |
    |   .m64n200k8 |  .m64n200k16 |            |              | .m64n200k32   |
    |   .m64n208k8 |  .m64n208k16 |            |              | .m64n208k32   |
    |   .m64n216k8 |  .m64n216k16 |            |              | .m64n216k32   |
    |   .m64n224k8 |  .m64n224k16 |            |              | .m64n224k32   |
    |   .m64n232k8 |  .m64n232k16 |            |              | .m64n232k32   |
    |   .m64n240k8 |  .m64n240k16 |            |              | .m64n240k32   |
    |   .m64n248k8 |  .m64n248k16 |            |              | .m64n248k32   |
    |   .m64n256k8 |  .m64n256k16 |            |              | .m64n256k32   |
    |--------------|--------------|------------|--------------|---------------|
    ```
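
    An illustrative example (a sketch only: the accumulator struct type and the
    `#nvvm.shape`, `#nvvm.wgmma_type`, `#nvvm.wgmma_scale_out`,
    `#nvvm.wgmma_scale_in` and `#nvvm.mma_layout` attribute spellings are
    assumed here), computing D = A * B + D for an f32 += f16 * f16
    accumulation of shape m64n32k16:
    ```
    %d = nvvm.wgmma.mma_async %descA, %descB, %acc,
        #nvvm.shape<m = 64, n = 32, k = 16>,
        D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
        A [#nvvm.wgmma_type<f16>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
        B [#nvvm.wgmma_type<f16>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
        : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32,
                        f32, f32, f32, f32, f32, f32, f32, f32)>
        -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32,
                         f32, f32, f32, f32, f32, f32, f32, f32)>
    ```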

    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions)
  }];

  let hasVerifier = 1;

  let extraClassDeclaration = [{
    void getAsmValues(RewriterBase &rewriter,
        llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &asmValues);
  }];
}

//===----------------------------------------------------------------------===//
// NVVM Griddepcontrol Ops
//===----------------------------------------------------------------------===//

def NVVM_GriddepcontrolWaitOp : NVVM_IntrOp<"griddepcontrol.wait", [], 0> {
  let assemblyFormat = "attr-dict";

  let description = [{
    Causes the executing thread to wait until all prerequisite grids in flight
    have completed and all the memory operations from the prerequisite grids
    are performed and made visible to the current grid.
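
    A minimal usage sketch (the op takes no operands or results):
    ```
    // Block until all prerequisite grids have completed and their memory
    // effects are visible to this grid.
    nvvm.griddepcontrol.wait
    ```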
    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-griddepcontrol)
  }];
}

def NVVM_GriddepcontrolLaunchDependentsOp
    : NVVM_IntrOp<"griddepcontrol.launch.dependents", [], 0> {
  let assemblyFormat = "attr-dict";

  let description = [{
    Signals that the dependents designated by the runtime system to react to
    this instruction can be scheduled as soon as all other CTAs in the grid
    have issued the same instruction or have completed.
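
    A minimal usage sketch (the op takes no operands or results):
    ```
    // Allow dependent grids to be scheduled once every CTA in the current
    // grid has issued this instruction or has exited.
    nvvm.griddepcontrol.launch.dependents
    ```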
    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-griddepcontrol)
  }];
}

def NVVM_Exit : NVVM_Op<"exit"> {
  let summary = "Exit Op";
  let description = [{
    Ends execution of a thread.
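
    A minimal usage sketch:
    ```
    // Terminate the executing thread.
    nvvm.exit
    ```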
    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-exit)
  }];
  string llvmBuilder = [{
    createIntrinsicCall(builder, llvm::Intrinsic::nvvm_exit);
  }];

  let assemblyFormat = "attr-dict";
}

//===----------------------------------------------------------------------===//
// NVVM breakpoint Op
//===----------------------------------------------------------------------===//

def NVVM_Breakpoint : NVVM_Op<"breakpoint"> {
  let summary = "Breakpoint Op";
  let description = [{
    Breakpoint suspends execution of the program for debugging.
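
    A minimal usage sketch:
    ```
    // Suspend the program so that an attached debugger can take over.
    nvvm.breakpoint
    ```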
    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#miscellaneous-instructions-brkpt)
  }];
  string llvmBuilder = [{
    createIntrinsicCall(builder, llvm::Intrinsic::debugtrap);
  }];

  let assemblyFormat = "attr-dict";
}

//===----------------------------------------------------------------------===//
// NVVM target attribute.
//===----------------------------------------------------------------------===//

def NVVM_TargettAttr : NVVM_Attr<"NVVMTarget", "target"> {
  let description = [{
    GPU target attribute for controlling the compilation of NVIDIA targets.
    Any parameter that is not specified assumes its default value.

    Examples:

    1. Target with default values.
    ```
      gpu.module @mymodule [#nvvm.target] attributes {...} {
        ...
      }
    ```

    2. Target with `sm_90` chip and fast math.
    ```
      gpu.module @mymodule [#nvvm.target<chip = "sm_90", flags = {fast}>] {
        ...
      }
    ```
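
    3. Target with an `sm_90` chip, `+ptx80` features, and optimization level
       3 (illustrative; any subset of the parameters may be specified).
    ```
      gpu.module @mymodule [#nvvm.target<O = 3, chip = "sm_90", features = "+ptx80">] {
        ...
      }
    ```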
  }];
  let parameters = (ins
    DefaultValuedParameter<"int", "2", "Optimization level to apply.">:$O,
    StringRefParameter<"Target triple.", "\"nvptx64-nvidia-cuda\"">:$triple,
    StringRefParameter<"Target chip.", "\"sm_50\"">:$chip,
    StringRefParameter<"Target chip features.", "\"+ptx60\"">:$features,
    OptionalParameter<"DictionaryAttr", "Target specific flags.">:$flags,
    OptionalParameter<"ArrayAttr", "Files to link to the LLVM module.">:$link
  );
  let assemblyFormat = [{
    (`<` struct($O, $triple, $chip, $features, $flags, $link)^ `>`)?
  }];
  let builders = [
    AttrBuilder<(ins CArg<"int", "2">:$optLevel,
                     CArg<"StringRef", "\"nvptx64-nvidia-cuda\"">:$triple,
                     CArg<"StringRef", "\"sm_50\"">:$chip,
                     CArg<"StringRef", "\"+ptx60\"">:$features,
                     CArg<"DictionaryAttr", "nullptr">:$targetFlags,
                     CArg<"ArrayAttr", "nullptr">:$linkFiles), [{
      return Base::get($_ctxt, optLevel, triple, chip, features, targetFlags, linkFiles);
    }]>
  ];
  let skipDefaultBuilders = 1;
  let genVerifyDecl = 1;
  let extraClassDeclaration = [{
    bool hasFlag(StringRef flag) const;
    bool hasFastMath() const;
    bool hasFtz() const;
  }];
  let extraClassDefinition = [{
    bool $cppClass::hasFlag(StringRef flag) const {
      if (DictionaryAttr flags = getFlags())
        return flags.get(flag) != nullptr;
      return false;
    }
    bool $cppClass::hasFastMath() const {
      return hasFlag("fast");
    }
    bool $cppClass::hasFtz() const {
      return hasFlag("ftz");
    }
  }];
}

#endif // NVVMIR_OPS