xref: /llvm-project/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (revision 7a77f14c0abfbecbfb800ea8d974e66d81ee516a)
1 //===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
10 
11 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/Pattern.h"
14 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
19 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
22 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
23 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/ImplicitLocOpBuilder.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/IR/TypeUtilities.h"
27 #include "mlir/IR/Value.h"
28 #include "mlir/Pass/Pass.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/ErrorHandling.h"
31 #include "llvm/Support/raw_ostream.h"
32 #include <optional>
33 
34 #define DEBUG_TYPE "nvgpu-to-nvvm"
35 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
36 #define DBGSE() (llvm::dbgs())
37 
38 namespace mlir {
39 #define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
40 #include "mlir/Conversion/Passes.h.inc"
41 } // namespace mlir
42 
43 using namespace mlir;
44 
45 /// Number of bits that needs to be excluded when building matrix descriptor for
46 /// wgmma operations.
47 constexpr int exclude4LSB = 4;
48 
49 /// GPU has 32 bit registers, this function truncates values when larger width
50 /// is not needed.
51 static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
52   Type type = value.getType();
53   assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
54   if (type.getIntOrFloatBitWidth() <= 32)
55     return value;
56   return b.create<LLVM::TruncOp>(b.getI32Type(), value);
57 }
58 
59 /// Returns the type for the intrinsic given the vectorResultType of the
60 /// `gpu.mma.sync` operation.
61 static Type inferIntrinsicResultType(Type vectorResultType) {
62   MLIRContext *ctx = vectorResultType.getContext();
63   auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
64   auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
65   auto i32Ty = IntegerType::get(ctx, 32);
66   auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
67   Type f64Ty = Float64Type::get(ctx);
68   Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
69   Type f32Ty = Float32Type::get(ctx);
70   Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
71   if (a.getElementType() == f16x2Ty) {
72     return LLVM::LLVMStructType::getLiteral(
73         ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
74   }
75   if (a.getElementType() == i32x2Ty) {
76     return LLVM::LLVMStructType::getLiteral(
77         ctx,
78         SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
79   }
80   if (a.getElementType() == f64x2Ty) {
81     return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
82   }
83   if (a.getElementType() == f32x2Ty) {
84     return LLVM::LLVMStructType::getLiteral(
85         ctx,
86         SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
87   }
88   if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) {
89     return LLVM::LLVMStructType::getLiteral(
90         ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
91   }
92   return vectorResultType;
93 }
94 
95 /// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
96 /// always an LLVM struct) into a fragment that is compatible with the vector
97 /// type of this operation. This involves extracting elements from the struct
98 /// and inserting them into an LLVM array. These extra data-movement
99 /// operations should be canonicalized away by the LLVM backend.
100 static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
101                                     Type resultType, Value intrinsicResult,
102                                     RewriterBase &rewriter) {
103   MLIRContext *ctx = rewriter.getContext();
104   auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
105   auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);
106   Type i32Ty = rewriter.getI32Type();
107   Type f32Ty = rewriter.getF32Type();
108   Type f64Ty = rewriter.getF64Type();
109   Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
110   Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
111   Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
112   Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
113   Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
114 
115   auto makeConst = [&](int32_t index) -> Value {
116     return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
117                                              rewriter.getI32IntegerAttr(index));
118   };
119 
120   if (arrayType) {
121     SmallVector<Value, 4> elements;
122 
123     // The intrinsic returns 32-bit wide elements in a form which can be
124     // directly bitcasted and inserted into the result vector.
125     if (arrayType.getElementType() == f16x2Ty ||
126         arrayType.getElementType() == f32x1Ty) {
127       for (unsigned i = 0; i < structType.getBody().size(); i++) {
128         Value el =
129             rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i);
130         el = rewriter.createOrFold<LLVM::BitcastOp>(
131             loc, arrayType.getElementType(), el);
132         elements.push_back(el);
133       }
134     }
135 
136     // The intrinsic returns i32, f64, and f32 values as individual scalars,
137     // even when the result is notionally a 64-bit wide element (e.g. f32x2). We
138     // need to extract them from the struct and pack them into the 64-bit wide
139     // rows of the vector result.
140     if (arrayType.getElementType() == i32x2Ty ||
141         arrayType.getElementType() == f64x2Ty ||
142         arrayType.getElementType() == f32x2Ty) {
143 
144       for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
145         Value vec =
146             rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
147         Value x1 =
148             rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i * 2);
149         Value x2 = rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult,
150                                                          i * 2 + 1);
151         vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
152                                                      x1, makeConst(0));
153         vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
154                                                      x2, makeConst(1));
155         elements.push_back(vec);
156       }
157     }
158 
159     // Create the final vectorized result.
160     Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
161     for (const auto &el : llvm::enumerate(elements)) {
162       result = rewriter.create<LLVM::InsertValueOp>(loc, result, el.value(),
163                                                     el.index());
164     }
165     return result;
166   }
167 
168   return intrinsicResult;
169 }
170 
171 /// The `gpu.mma.sync` converter below expects matrix fragment operands to be
172 /// given as 2D `vectors` where the rows are 32b or 64b wide. The
173 /// `nvvm.mma.sync` op expects these argments to be a given in a long list of
174 /// scalars of certain types. This function helps unpack the `vector` arguments
175 /// and cast them to the types expected by `nvvm.mma.sync`.
176 static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
177                                               Value operand,
178                                               NVVM::MMATypes operandPtxType) {
179   SmallVector<Value> result;
180   Type i32Ty = b.getI32Type();
181   Type f64Ty = b.getF64Type();
182   Type f32Ty = b.getF32Type();
183   Type i64Ty = b.getI64Type();
184   Type i8x4Ty = LLVM::getFixedVectorType(b.getI8Type(), 4);
185   Type i4x8Ty = LLVM::getFixedVectorType(b.getIntegerType(4), 8);
186   Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
187   auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
188 
189   for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
190     Value toUse = b.create<LLVM::ExtractValueOp>(operand, i);
191 
192     // For 4xi8 vectors, the intrinsic expects these to be provided as i32
193     // scalar types.
194     if (arrayTy.getElementType() == i8x4Ty ||
195         arrayTy.getElementType() == i4x8Ty ||
196         (arrayTy.getElementType() == f32x1Ty &&
197          operandPtxType == NVVM::MMATypes::tf32)) {
198       result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse));
199       continue;
200     }
201 
202     // For some element types (i32, f32, f64), we need to unpack the inner
203     // vector/array type as well because the intrinsic expects individual
204     // scalars to be provided.
205     VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
206     if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
207                          innerArrayTy.getElementType() == f64Ty ||
208                          innerArrayTy.getElementType() == f32Ty)) {
209       for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
210            idx < innerSize; idx++) {
211         result.push_back(b.create<LLVM::ExtractElementOp>(
212             toUse,
213             b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(idx))));
214       }
215       continue;
216     }
217     result.push_back(toUse);
218   }
219   return result;
220 }
221 
222 /// Returns whether mbarrier object has shared memory address space.
223 static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType) {
224   return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
225       barrierType.getMemorySpace()));
226 }
227 
228 /// Returns the memory space attribute of the mbarrier object.
229 Attribute nvgpu::getMbarrierMemorySpace(MLIRContext *context,
230                                         nvgpu::MBarrierGroupType barrierType) {
231   Attribute memorySpace = {};
232   if (isMbarrierShared(barrierType)) {
233     memorySpace =
234         IntegerAttr::get(IntegerType::get(context, 64),
235                          nvgpu::NVGPUDialect::kSharedMemoryAddressSpace);
236   }
237   return memorySpace;
238 }
239 
240 /// Returns memref type of the mbarrier object. The type is defined in the
241 /// MBarrierGroupType.
242 MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context,
243                                         nvgpu::MBarrierGroupType barrierType) {
244   Attribute memorySpace = nvgpu::getMbarrierMemorySpace(context, barrierType);
245   MemRefLayoutAttrInterface layout;
246   return MemRefType::get({barrierType.getNumBarriers()},
247                          IntegerType::get(context, 64), layout, memorySpace);
248 }
249 
250 namespace {
251 
252 struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
253   using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;
254 
255   LogicalResult
256   matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
257                   ConversionPatternRewriter &rewriter) const override {
258     MLIRContext *ctx = getContext();
259     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
260 
261     // The result type of ldmatrix will always be a struct of 32bit integer
262     // registers if more than one 32bit value is returned. Otherwise, the result
263     // is a single i32. The result type of the GPU operation is always a vector
264     // of shape (NumRegisters, VectorRegister) where VectorRegister is the
265     // vector type of the result and always 32 bits long. We bitcast the result
266     // of the NVVM::LdMatrix to this vector type.
267     auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
268     if (!vectorResultType) {
269       return failure();
270     }
271     Type innerVectorType = LLVM::getFixedVectorType(
272         vectorResultType.getElementType(), vectorResultType.getDimSize(1));
273 
274     int64_t num32BitRegs = vectorResultType.getDimSize(0);
275 
276     Type ldMatrixResultType;
277     if (num32BitRegs > 1) {
278       ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
279           ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
280     } else {
281       ldMatrixResultType = rewriter.getI32Type();
282     }
283 
284     auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
285     Value srcPtr =
286         getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(),
287                              adaptor.getIndices(), rewriter);
288     Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
289         ldMatrixResultType, srcPtr,
290         /*num=*/op.getNumTiles(),
291         /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
292                                      : NVVM::MMALayout::row);
293 
294     // The ldmatrix operation returns either a single i32 value or a struct of
295     // i32 values. Here we unpack those values and cast them back to their
296     // actual vector type (still of width 32b) and repack them into a result
297     // struct.
298     Type finalResultType = typeConverter->convertType(vectorResultType);
299     Value result = b.create<LLVM::UndefOp>(finalResultType);
300     for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
301       Value i32Register =
302           num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i)
303                            : ldMatrixResult;
304       Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register);
305       result = b.create<LLVM::InsertValueOp>(result, casted, i);
306     }
307 
308     rewriter.replaceOp(op, result);
309     return success();
310   }
311 };
312 
313 /// Convert the given type into the corresponding PTX type (NVVM::MMATypes
314 /// enum).
315 static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
316   Type elType = getElementTypeOrSelf(t);
317   if (elType.isInteger(8))
318     return NVVM::MMATypes::s8;
319   if (elType.isInteger(4))
320     return NVVM::MMATypes::s4;
321   if (elType.isF16())
322     return NVVM::MMATypes::f16;
323   if (elType.isF64())
324     return NVVM::MMATypes::f64;
325   if (elType.isF32())
326     return NVVM::MMATypes::tf32;
327   return failure();
328 }
329 
330 struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
331   using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;
332 
333   LogicalResult
334   matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
335                   ConversionPatternRewriter &rewriter) const override {
336     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
337     // Get the shapes of the MMAMatrix type being used. The shapes will
338     // choose which intrinsic this op will be lowered to.
339     VectorType aType = op.getMatrixA().getType();
340     VectorType bType = op.getMatrixA().getType();
341     VectorType cType = op.getMatrixC().getType();
342 
343     std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();
344 
345     // Tensor Cores (mma.sync) on F32 works only with TensorFloat32 (TF32).
346     bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
347     if (aType.getElementType().isF32() && !tf32Enabled)
348       return failure();
349 
350     FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
351     if (failed(ptxTypeA))
352       return op->emitOpError("failed to deduce operand PTX types");
353     FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
354     if (failed(ptxTypeB))
355       return op->emitOpError("failed to deduce operand PTX types");
356     std::optional<NVVM::MMATypes> ptxTypeC =
357         NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
358                                          /*isAccumulator=*/true);
359     if (!ptxTypeC)
360       return op->emitError(
361           "could not infer the PTX type for the accumulator/result");
362 
363     // TODO: add an attribute to the op to customize this behavior.
364     std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
365     if (isa<IntegerType>(aType.getElementType()))
366       overflow = NVVM::MMAIntOverflow::satfinite;
367 
368     SmallVector<Value> matA =
369         unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
370     SmallVector<Value> matB =
371         unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
372     SmallVector<Value> matC =
373         unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
374 
375     Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
376     Type intrinsicResTy = inferIntrinsicResultType(
377         typeConverter->convertType(op->getResultTypes()[0]));
378     Value intrinsicResult = b.create<NVVM::MmaOp>(
379         intrinsicResTy, matA, matB, matC,
380         /*shape=*/gemmShape,
381         /*b1Op=*/std::nullopt,
382         /*intOverflow=*/overflow,
383         /*multiplicandPtxTypes=*/
384         std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
385         /*multiplicandLayouts=*/
386         std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
387                                        NVVM::MMALayout::col});
388     rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
389                                                   desiredRetTy, intrinsicResult,
390                                                   rewriter));
391     return success();
392   }
393 };
394 
395 struct ConvertNVGPUToNVVMPass
396     : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
397   using Base::Base;
398 
399   void getDependentDialects(DialectRegistry &registry) const override {
400     registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect,
401                     arith::ArithDialect>();
402   }
403 
404   void runOnOperation() override {
405     LowerToLLVMOptions options(&getContext());
406     RewritePatternSet patterns(&getContext());
407     LLVMTypeConverter converter(&getContext(), options);
408     IRRewriter rewriter(&getContext());
409     populateGpuMemorySpaceAttributeConversions(
410         converter, [](gpu::AddressSpace space) -> unsigned {
411           switch (space) {
412           case gpu::AddressSpace::Global:
413             return static_cast<unsigned>(
414                 NVVM::NVVMMemorySpace::kGlobalMemorySpace);
415           case gpu::AddressSpace::Workgroup:
416             return static_cast<unsigned>(
417                 NVVM::NVVMMemorySpace::kSharedMemorySpace);
418           case gpu::AddressSpace::Private:
419             return 0;
420           }
421           llvm_unreachable("unknown address space enum value");
422           return 0;
423         });
424     /// device-side async tokens cannot be materialized in nvvm. We just
425     /// convert them to a dummy i32 type in order to easily drop them during
426     /// conversion.
427     converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
428       return converter.convertType(IntegerType::get(type.getContext(), 32));
429     });
430     converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
431       Type elemType = type.getFragmented().getElementType();
432       int64_t sizeM = type.getFragmented().getDimSize(0);
433       int64_t sizeN = type.getFragmented().getDimSize(1);
434 
435       unsigned numMembers;
436       if (elemType.isF32() || elemType.isInteger(32))
437         numMembers = sizeN / 2;
438       else if (elemType.isF16())
439         numMembers = sizeN / 4;
440       else
441         llvm_unreachable("unsupported type for warpgroup accumulator");
442 
443       SmallVector<Type> innerStructBody;
444       for (unsigned i = 0; i < numMembers; i++)
445         innerStructBody.push_back(elemType);
446       auto innerStructType =
447           LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);
448 
449       SmallVector<Type> structBody;
450       for (int i = 0; i < sizeM; i += kWgmmaSizeM)
451         structBody.push_back(innerStructType);
452 
453       auto convertedType =
454           LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
455       return converter.convertType(convertedType);
456     });
457     converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
458       return converter.convertType(IntegerType::get(type.getContext(), 64));
459     });
460     converter.addConversion(
461         [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
462           return converter.convertType(IntegerType::get(type.getContext(), 64));
463         });
464     converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
465       return converter.convertType(
466           nvgpu::getMBarrierMemrefType(rewriter.getContext(), type));
467     });
468     converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type {
469       return LLVM::LLVMPointerType::get(type.getContext());
470     });
471     populateNVGPUToNVVMConversionPatterns(converter, patterns);
472     LLVMConversionTarget target(getContext());
473     target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
474     target.addLegalDialect<::mlir::arith::ArithDialect>();
475     target.addLegalDialect<::mlir::memref::MemRefDialect>();
476     target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
477     mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
478         converter, patterns, target);
479     if (failed(applyPartialConversion(getOperation(), target,
480                                       std::move(patterns))))
481       signalPassFailure();
482   }
483 };
484 
485 /// Returns the constraints for the sparse MMA inline assembly instruction.
486 static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
487                                                      unsigned matBSize,
488                                                      unsigned matCSize) {
489   std::string str;
490   llvm::raw_string_ostream ss(str);
491   for (unsigned i = 0; i < matCSize; i++)
492     ss << "=r,";
493   for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
494     ss << "r,";
495   // The final operand is for the sparsity metadata.
496   // The sparsity selector appears as direct literal.
497   ss << "r";
498   return str;
499 }
500 
501 /// Returns the string for the `mma.sp.sync` instruction that corresponds to
502 /// the given parameters. Note that this function doesn't do any validation,
503 /// it's expected that the provided parameters correspond to a valid
504 /// instruction.
505 static std::string buildMmaSparseAsmString(
506     const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
507     unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
508     NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
509     std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
510   auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
511     return NVVM::stringifyMMATypes(ptxType);
512   };
513 
514   std::string asmStr;
515   llvm::raw_string_ostream ss(asmStr);
516   ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
517      << shape[2] << ".row.col.";
518 
519   if (overflow)
520     ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";
521 
522   ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
523      << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
524   unsigned asmArgIdx = 0;
525 
526   // The operand string is structured into sections `{matC elements...},
527   // {matA elements...}, {matB elements...}, {matC elements}`.
528   for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
529     ss << "{";
530     for (unsigned i = 0; i < arrSize; i++)
531       ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
532     ss << "},";
533   }
534   ss << "$" << asmArgIdx++ << ",";
535   assert(metaDataSelector <= 1);
536   ss << "0x" << metaDataSelector << ";";
537   return asmStr;
538 }
539 
540 /// Builds an inline assembly operation corresponding to the specified MMA
541 /// sparse sync operation.
542 static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
543     ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
544     NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
545     std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
546     ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
547     int64_t metadataSelector, const std::array<int64_t, 3> &shape,
548     Type intrinsicResultType) {
549   auto asmDialectAttr =
550       LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);
551 
552   const unsigned matASize = unpackedAData.size();
553   const unsigned matBSize = unpackedB.size();
554   const unsigned matCSize = unpackedC.size();
555 
556   std::string asmStr = buildMmaSparseAsmString(
557       shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
558       ptxTypeD, overflow, metadataSelector);
559   std::string constraintStr =
560       buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);
561 
562   SmallVector<Value> asmVals;
563   asmVals.reserve(matASize + matBSize + matCSize + 1);
564   for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC})
565     llvm::append_range(asmVals, args);
566   asmVals.push_back(indexData);
567 
568   return b.create<LLVM::InlineAsmOp>(
569       /*resultTypes=*/intrinsicResultType,
570       /*operands=*/asmVals,
571       /*asm_string=*/asmStr,
572       /*constraints=*/constraintStr,
573       /*has_side_effects=*/true,
574       /*is_align_stack=*/false,
575       /*asm_dialect=*/asmDialectAttr,
576       /*operand_attrs=*/ArrayAttr());
577 }
578 
579 /// Lowers `nvgpu.mma.sp.sync` to inline assembly.
580 struct NVGPUMmaSparseSyncLowering
581     : public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> {
582   using ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp>::ConvertOpToLLVMPattern;
583 
584   LogicalResult
585   matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
586                   ConversionPatternRewriter &rewriter) const override {
587     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
588     // Get the shapes of the MMAMatrix type being used. The shapes will
589     // choose which intrinsic this op will be lowered to.
590     VectorType aType = op.getMatrixA().getType();
591     VectorType bType = op.getMatrixB().getType();
592     VectorType cType = op.getMatrixC().getType();
593 
594     FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
595     if (failed(ptxTypeA))
596       return op->emitOpError("failed to deduce operand PTX types");
597     FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
598     if (failed(ptxTypeB))
599       return op->emitOpError("failed to deduce operand PTX types");
600     std::optional<NVVM::MMATypes> ptxTypeC =
601         NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
602                                          /*isAccumulator=*/true);
603     if (!ptxTypeC)
604       return op->emitError(
605           "could not infer the PTX type for the accumulator/result");
606 
607     // Same as `mma.sync`, F32 works only with TensorFloat32 (TF32).
608     bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
609     if (aType.getElementType().isF32() && !tf32Enabled)
610       return failure();
611 
612     // TODO: add an attribute to the op to customize this behavior.
613     std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
614     if (isa<IntegerType>(aType.getElementType()))
615       overflow = NVVM::MMAIntOverflow::satfinite;
616 
617     SmallVector<Value> matA =
618         unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
619     SmallVector<Value> matB =
620         unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
621     SmallVector<Value> matC =
622         unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
623 
624     Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
625     Type intrinsicResTy = inferIntrinsicResultType(
626         typeConverter->convertType(op->getResultTypes()[0]));
627 
628     // Bitcast the sparse metadata from vector<2xf16> to an i32.
629     Value sparseMetadata = adaptor.getSparseMetadata();
630     if (sparseMetadata.getType() !=
631         LLVM::getFixedVectorType(rewriter.getI16Type(), 2))
632       return op->emitOpError() << "Expected metadata type to be LLVM "
633                                   "VectorType of 2 i16 elements";
634     sparseMetadata =
635         b.create<LLVM::BitcastOp>(rewriter.getI32Type(), sparseMetadata);
636 
637     FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
638         b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
639         matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
640         intrinsicResTy);
641     if (failed(intrinsicResult))
642       return failure();
643 
644     assert((*intrinsicResult).getNumResults() == 1 &&
645            "expected inline asm op returns a single LLVM struct type");
646     rewriter.replaceOp(
647         op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy,
648                                    (*intrinsicResult)->getResult(0), rewriter));
649     return success();
650   }
651 };
652 
653 struct NVGPUAsyncCopyLowering
654     : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
655   using ConvertOpToLLVMPattern<
656       nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;
657 
658   LogicalResult
659   matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
660                   ConversionPatternRewriter &rewriter) const override {
661     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
662     Location loc = op.getLoc();
663     auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
664     Value dstPtr =
665         getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(),
666                              adaptor.getDstIndices(), rewriter);
667     FailureOr<unsigned> dstAddressSpace =
668         getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
669     if (failed(dstAddressSpace))
670       return rewriter.notifyMatchFailure(
671           loc, "destination memref address space not convertible to integer");
672 
673     auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
674     FailureOr<unsigned> srcAddressSpace =
675         getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
676     if (failed(srcAddressSpace))
677       return rewriter.notifyMatchFailure(
678           loc, "source memref address space not convertible to integer");
679 
680     Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
681                                         adaptor.getSrcIndices(), rewriter);
682     // Intrinsics takes a global pointer so we need an address space cast.
683     auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
684         op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace);
685     scrPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, scrPtr);
686     int64_t dstElements = adaptor.getDstElements().getZExtValue();
687     int64_t sizeInBytes =
688         (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
689     // When the optional SrcElements argument is *not* present, the regular
690     // CpAsyncOp is generated. CopyAsyncOp reads bytes from source (global
691     // memory) to fill DstElements number of elements in the destination
692     // (shared memory).
693     Value srcBytes = adaptor.getSrcElements();
694     if (srcBytes) {
695       // When the optional SrcElements argument is present, the source (global
696       // memory) of CpAsyncOp is read only for SrcElements number of elements.
697       // The rest of the DstElements in the destination (shared memory) are
698       // filled with zeros.
699       Value c3I32 =
700           b.create<LLVM::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(3));
701       Value bitwidth = b.create<LLVM::ConstantOp>(
702           b.getI32Type(),
703           b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
704       Value srcElementsI32 = b.create<LLVM::TruncOp>(b.getI32Type(), srcBytes);
705       srcBytes = b.create<LLVM::LShrOp>(
706           b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32);
707     }
708     // Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
709     // 16 dst bytes.
710     NVVM::LoadCacheModifierKind cacheModifier =
711         (op.getBypassL1().value_or(false) && sizeInBytes == 16)
712             ? NVVM::LoadCacheModifierKind::CG
713             : NVVM::LoadCacheModifierKind::CA;
714 
715     b.create<NVVM::CpAsyncOp>(
716         dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
717         NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
718         srcBytes);
719 
720     // Drop the result token.
721     Value zero = b.create<LLVM::ConstantOp>(
722         IntegerType::get(op.getContext(), 32), rewriter.getI32IntegerAttr(0));
723     rewriter.replaceOp(op, zero);
724     return success();
725   }
726 };
727 
728 struct NVGPUAsyncCreateGroupLowering
729     : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
730   using ConvertOpToLLVMPattern<
731       nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;
732 
733   LogicalResult
734   matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
735                   ConversionPatternRewriter &rewriter) const override {
736     rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
737     // Drop the result token.
738     Value zero = rewriter.create<LLVM::ConstantOp>(
739         op->getLoc(), IntegerType::get(op.getContext(), 32),
740         rewriter.getI32IntegerAttr(0));
741     rewriter.replaceOp(op, zero);
742     return success();
743   }
744 };
745 
746 struct NVGPUAsyncWaitLowering
747     : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
748   using ConvertOpToLLVMPattern<
749       nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;
750 
751   LogicalResult
752   matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
753                   ConversionPatternRewriter &rewriter) const override {
754     // If numGroup is not present pick 0 as a conservative correct value.
755     int32_t numGroups = adaptor.getNumGroups().value_or(0);
756     rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
757     rewriter.eraseOp(op);
758     return success();
759   }
760 };
761 
762 /// Creates mbarrier object in shared memory
763 struct NVGPUMBarrierCreateLowering
764     : public ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp> {
765   using ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp>::ConvertOpToLLVMPattern;
766 
767   template <typename moduleT>
768   memref::GlobalOp generateGlobalBarrier(ConversionPatternRewriter &rewriter,
769                                          Operation *funcOp, moduleT moduleOp,
770                                          MemRefType barrierType) const {
771     SymbolTable symbolTable(moduleOp);
772     OpBuilder::InsertionGuard guard(rewriter);
773     rewriter.setInsertionPoint(&moduleOp.front());
774     auto global = rewriter.create<memref::GlobalOp>(
775         funcOp->getLoc(), "__mbarrier",
776         /*sym_visibility=*/rewriter.getStringAttr("private"),
777         /*type=*/barrierType,
778         /*initial_value=*/ElementsAttr(),
779         /*constant=*/false,
780         /*alignment=*/rewriter.getI64IntegerAttr(8));
781     symbolTable.insert(global);
782     return global;
783   }
784 
785   LogicalResult
786   matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
787                   ConversionPatternRewriter &rewriter) const override {
788     Operation *funcOp = op->getParentOp();
789     MemRefType barrierType = nvgpu::getMBarrierMemrefType(
790         rewriter.getContext(), op.getBarriers().getType());
791 
792     memref::GlobalOp global;
793     if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>())
794       global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
795     else if (auto moduleOp = funcOp->getParentOfType<ModuleOp>())
796       global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
797 
798     rewriter.setInsertionPoint(op);
799     rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, barrierType,
800                                                      global.getName());
801     return success();
802   }
803 };
804 
805 /// Base class for lowering mbarrier operations to nvvm intrinsics.
806 template <typename SourceOp>
807 struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
808 public:
809   using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
810   /// Returns the base pointer of the mbarrier object.
811   Value getMbarrierPtr(ImplicitLocOpBuilder &b,
812                        nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
813                        Value mbarId,
814                        ConversionPatternRewriter &rewriter) const {
815     MemRefType mbarrierMemrefType =
816         nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
817     return ConvertToLLVMPattern::getStridedElementPtr(
818         b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
819   }
820 };
821 
822 /// Lowers `nvgpu.mbarrier.init` to `nvvm.mbarrier.init`
823 struct NVGPUMBarrierInitLowering
824     : public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
825   using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;
826 
827   LogicalResult
828   matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
829                   ConversionPatternRewriter &rewriter) const override {
830     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
831     nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
832     rewriter.setInsertionPoint(op);
833     Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
834                                    adaptor.getMbarId(), rewriter);
835     Value count = truncToI32(b, adaptor.getCount());
836     if (isMbarrierShared(mbarrierType)) {
837       rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(
838           op, barrier, count, adaptor.getPredicate());
839     } else {
840       rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
841                                                         adaptor.getPredicate());
842     }
843     return success();
844   }
845 };
846 
847 /// Lowers `nvgpu.mbarrier.arrive` to `nvvm.mbarrier.arrive`
848 struct NVGPUMBarrierArriveLowering
849     : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
850   using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;
851   LogicalResult
852   matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
853                   ConversionPatternRewriter &rewriter) const override {
854     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
855     Value barrier =
856         getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
857                        adaptor.getMbarId(), rewriter);
858     Type tokenType = getTypeConverter()->convertType(
859         nvgpu::MBarrierTokenType::get(op->getContext()));
860     if (isMbarrierShared(op.getBarriers().getType())) {
861       rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveSharedOp>(op, tokenType,
862                                                                 barrier);
863     } else {
864       rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType,
865                                                           barrier);
866     }
867     return success();
868   }
869 };
870 
871 /// Lowers `nvgpu.mbarrier.arrive.nocomplete` to
872 /// `nvvm.mbarrier.arrive.nocomplete`
873 struct NVGPUMBarrierArriveNoCompleteLowering
874     : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
875   using MBarrierBasePattern<
876       nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;
877   LogicalResult
878   matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
879                   ConversionPatternRewriter &rewriter) const override {
880     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
881     Value barrier =
882         getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
883                        adaptor.getMbarId(), rewriter);
884     Type tokenType = getTypeConverter()->convertType(
885         nvgpu::MBarrierTokenType::get(op->getContext()));
886     Value count = truncToI32(b, adaptor.getCount());
887     if (isMbarrierShared(op.getBarriers().getType())) {
888       rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>(
889           op, tokenType, barrier, count);
890     } else {
891       rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>(
892           op, tokenType, barrier, count);
893     }
894     return success();
895   }
896 };
897 
898 /// Lowers `nvgpu.mbarrier.test.wait` to `nvvm.mbarrier.test.wait`
899 struct NVGPUMBarrierTestWaitLowering
900     : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
901   using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;
902   LogicalResult
903   matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
904                   ConversionPatternRewriter &rewriter) const override {
905     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
906     Value barrier =
907         getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
908                        adaptor.getMbarId(), rewriter);
909     Type retType = rewriter.getI1Type();
910     if (isMbarrierShared(op.getBarriers().getType())) {
911       rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitSharedOp>(
912           op, retType, barrier, adaptor.getToken());
913     } else {
914       rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(
915           op, retType, barrier, adaptor.getToken());
916     }
917     return success();
918   }
919 };
920 
921 struct NVGPUMBarrierArriveExpectTxLowering
922     : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
923   using MBarrierBasePattern<
924       nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;
925   LogicalResult
926   matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
927                   ConversionPatternRewriter &rewriter) const override {
928     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
929     Value barrier =
930         getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
931                        adaptor.getMbarId(), rewriter);
932     Value txcount = truncToI32(b, adaptor.getTxcount());
933 
934     if (isMbarrierShared(op.getBarriers().getType())) {
935       rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
936           op, barrier, txcount, adaptor.getPredicate());
937       return success();
938     }
939 
940     rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
941         op, barrier, txcount, adaptor.getPredicate());
942     return success();
943   }
944 };
945 
946 struct NVGPUMBarrierTryWaitParityLowering
947     : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> {
948   using MBarrierBasePattern<
949       nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern;
950   LogicalResult
951   matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
952                   ConversionPatternRewriter &rewriter) const override {
953     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
954     Value barrier =
955         getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
956                        adaptor.getMbarId(), rewriter);
957     Value ticks = truncToI32(b, adaptor.getTicks());
958     Value phase =
959         b.create<LLVM::ZExtOp>(b.getI32Type(), adaptor.getPhaseParity());
960 
961     if (isMbarrierShared(op.getBarriers().getType())) {
962       rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
963           op, barrier, phase, ticks);
964       return success();
965     }
966 
967     rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
968                                                                phase, ticks);
969     return success();
970   }
971 };
972 
973 struct NVGPUTmaAsyncLoadOpLowering
974     : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> {
975   using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern;
976   LogicalResult
977   matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
978                   ConversionPatternRewriter &rewriter) const override {
979     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
980     auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
981     Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
982                                       adaptor.getDst(), {}, rewriter);
983     Value barrier =
984         getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
985                        adaptor.getMbarId(), rewriter);
986 
987     SmallVector<Value> coords = adaptor.getCoordinates();
988     for (auto [index, value] : llvm::enumerate(coords)) {
989       coords[index] = truncToI32(b, value);
990     }
991     rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
992         op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
993         ValueRange{}, adaptor.getMulticastMask(), Value{},
994         adaptor.getPredicate());
995     return success();
996   }
997 };
998 
999 struct NVGPUTmaAsyncStoreOpLowering
1000     : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
1001   using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
1002   LogicalResult
1003   matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
1004                   ConversionPatternRewriter &rewriter) const override {
1005     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1006     auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
1007     Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
1008                                       adaptor.getSrc(), {}, rewriter);
1009     SmallVector<Value> coords = adaptor.getCoordinates();
1010     for (auto [index, value] : llvm::enumerate(coords)) {
1011       coords[index] = truncToI32(b, value);
1012     }
1013 
1014     rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
1015         op, adaptor.getTensorMapDescriptor(), dest, coords,
1016         adaptor.getPredicate());
1017     return success();
1018   }
1019 };
1020 
1021 struct NVGPUGenerateWarpgroupDescriptorLowering
1022     : public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> {
1023   using ConvertOpToLLVMPattern<
1024       nvgpu::WarpgroupGenerateDescriptorOp>::ConvertOpToLLVMPattern;
1025 
1026   LogicalResult
1027   matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor,
1028                   ConversionPatternRewriter &rewriter) const override {
1029 
1030     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1031 
1032     nvgpu::TensorMapSwizzleKind swizzleKind =
1033         op.getTensorMap().getType().getSwizzle();
1034 
1035     unsigned layout =
1036         (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B)  ? 128
1037         : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
1038         : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
1039                                                                     : 1;
1040     unsigned swizzle =
1041         (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B)  ? 1
1042         : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
1043         : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
1044                                                                     : 0;
1045 
1046     auto ti64 = b.getIntegerType(64);
1047     auto makeConst = [&](uint64_t index) -> Value {
1048       return b.create<LLVM::ConstantOp>(ti64, b.getI64IntegerAttr(index));
1049     };
1050     auto shiftLeft = [&](Value value, unsigned shift) -> Value {
1051       return b.create<LLVM::ShlOp>(ti64, value, makeConst(shift));
1052     };
1053     auto shiftRight = [&](Value value, unsigned shift) -> Value {
1054       return b.create<LLVM::LShrOp>(ti64, value, makeConst(shift));
1055     };
1056     auto insertBit = [&](Value desc, Value val, int startBit) {
1057       return b.create<LLVM::OrOp>(ti64, desc, shiftLeft(val, startBit));
1058     };
1059 
1060     int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
1061     uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
1062     uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
1063     uint64_t offsetVal = 0;
1064 
1065     Value strideDim = makeConst(strideDimVal);
1066     Value leadDim = makeConst(leadDimVal);
1067 
1068     Value baseAddr = getStridedElementPtr(
1069         op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
1070         adaptor.getTensor(), {}, rewriter);
1071     Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr);
1072     // Just use 14 bits for base address
1073     Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
1074 
1075     int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
1076         startLeadBit = 16, startBaseAddrBit = 0;
1077     Value dsc = makeConst(0);
1078     // // [62,64)  swizzle type
1079     dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
1080     // // [49,52)  base_offset
1081     dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
1082     // // [32,46)  stride
1083     dsc = insertBit(dsc, strideDim, startStrideBit);
1084     // // [16,30)  leading dimension
1085     dsc = insertBit(dsc, leadDim, startLeadBit);
1086     // // [0,14)   start_address
1087     dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
1088 
1089     LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
1090                       << "leading_off:" << leadDimVal << "\t"
1091                       << "stride_off :" << strideDimVal << "\t"
1092                       << "base_offset:" << offsetVal << "\t"
1093                       << "layout_type:" << swizzle << " ("
1094                       << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
1095                       << ")\n start_addr :  " << baseAddr << "\n");
1096 
1097     rewriter.replaceOp(op, dsc);
1098     return success();
1099   }
1100 };
1101 
1102 static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) {
1103   return b.create<LLVM::ConstantOp>(b.getIntegerType(64),
1104                                     b.getI32IntegerAttr(index));
1105 }
1106 
1107 /// Returns a Value that holds data type enum that is expected by CUDA driver.
1108 static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) {
1109   // Enum is from CUDA driver API
1110   // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
1111   enum CUtensorMapDataTypeEnum {
1112     CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
1113     CU_TENSOR_MAP_DATA_TYPE_UINT16,
1114     CU_TENSOR_MAP_DATA_TYPE_UINT32,
1115     CU_TENSOR_MAP_DATA_TYPE_INT32,
1116     CU_TENSOR_MAP_DATA_TYPE_UINT64,
1117     CU_TENSOR_MAP_DATA_TYPE_INT64,
1118     CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
1119     CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
1120     CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
1121     CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
1122     CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
1123     CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
1124     CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
1125   };
1126 
1127   if (type.isUnsignedInteger(8))
1128     return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
1129   if (type.isUnsignedInteger(16))
1130     return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
1131   if (type.isUnsignedInteger(32))
1132     return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
1133   if (type.isUnsignedInteger(64))
1134     return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
1135   if (type.isSignlessInteger(32))
1136     return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
1137   if (type.isSignlessInteger(64))
1138     return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
1139   if (type.isF16())
1140     return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
1141   if (type.isF32())
1142     return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
1143   if (type.isF64())
1144     return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
1145   if (type.isBF16())
1146     return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
1147 
1148   llvm_unreachable("Not supported data type");
1149 }
1150 
1151 struct NVGPUTmaCreateDescriptorOpLowering
1152     : public ConvertOpToLLVMPattern<nvgpu::TmaCreateDescriptorOp> {
1153   using ConvertOpToLLVMPattern<
1154       nvgpu::TmaCreateDescriptorOp>::ConvertOpToLLVMPattern;
1155   LogicalResult
1156   matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
1157                   ConversionPatternRewriter &rewriter) const override {
1158     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1159     auto llvmPointerType = LLVM::LLVMPointerType::get(op->getContext());
1160     Type llvmInt64Type = IntegerType::get(op->getContext(), 64);
1161 
1162     Value tensorElementType =
1163         elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
1164     auto promotedOperands = getTypeConverter()->promoteOperands(
1165         b.getLoc(), op->getOperands(), adaptor.getOperands(), b);
1166 
1167     Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type,
1168                                                  makeI64Const(b, 5));
1169     for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
1170       Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType,
1171                                         boxArrayPtr, makeI64Const(b, index));
1172       b.create<LLVM::StoreOp>(value, gep);
1173     }
1174 
1175     nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
1176     // Set Arguments for the function call
1177     SmallVector<Value> arguments;
1178     arguments.push_back(promotedOperands[0]); // rank
1179     arguments.push_back(promotedOperands[1]); // descriptor
1180     arguments.push_back(tensorElementType);   // data type
1181     arguments.push_back(
1182         makeI64Const(b, (int)desc.getInterleave()));              // interleave
1183     arguments.push_back(makeI64Const(b, (int)desc.getSwizzle())); // swizzle
1184     arguments.push_back(makeI64Const(b, (int)desc.getL2promo())); // l2promo
1185     arguments.push_back(makeI64Const(b, (int)desc.getOob()));     // oob
1186     arguments.push_back(boxArrayPtr); // box dimensions
1187 
1188     // Set data types of the arguments
1189     SmallVector<Type> argTypes = {
1190         llvmInt64Type,   /* int64_t tensorRank */
1191         llvmPointerType, /* ptr */
1192         llvmInt64Type,   /* int64_t */
1193         llvmInt64Type,   /* int64_t */
1194         llvmInt64Type,   /* int64_t */
1195         llvmInt64Type,   /* int64_t */
1196         llvmInt64Type,   /* int64_t */
1197         llvmPointerType  /* ptr  */
1198     };
1199     FunctionCallBuilder hostRegisterCallBuilder = {
1200         "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
1201     Value tensorMap =
1202         hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();
1203 
1204     rewriter.replaceOp(op, tensorMap);
1205     return success();
1206   }
1207 };
1208 
1209 struct NVGPUWarpgroupMmaOpLowering
1210     : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
1211   using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
1212 
1213   /// This is a helper class to generate required NVVM Ops for warp-group level
1214   /// matrix multiplication.
1215   /// When the given GEMM shape is larger than the shape of
1216   /// a wgmma instrution in PTX, it can generate multiple NVVM::WgmmaMmaAsyncOp
1217   /// Op(s), group and execute them asynchronously. The class also handles
1218   /// waiting for completion and iterates through WarpgroupMatrixDescriptor to
1219   /// create descriptors for each instruction.
1220   ///
1221   /// For example this is the case when the shape of GEMM is 128x128x128
1222   ///
1223   ///    nvvm.wgmma.fence.aligned
1224   ///
1225   ///    nvvm.wgmma.mma.async descA, descB
1226   ///    iterate(descA, descB)
1227   ///    nvvm.wgmma.mma.async descA, descB
1228   ///    [6x times more]
1229   ///
1230   ///    nvvm.wgmma.group.sync.aligned
1231   ///    nvvm.wgmma.wait.group.sync [groupId]
1232   ///
1233   class WarpgroupGemm {
1234     nvgpu::WarpgroupMmaOp op;
1235     ImplicitLocOpBuilder b;
1236     OpAdaptor adaptor;
1237 
1238     // Entire shape of the given Op
1239     int64_t totalM, totalN, totalK;
1240 
1241     // Shape of one wgmma instruction
1242     int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
1243 
1244     // Iteration counts for GEMM
1245     int iterationM = 0, iterationN = 0, iterationK = 0;
1246 
1247     /// The function returns the shape of wgmma instruction that is defined in
1248     /// PTX programming guide.
1249     /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
1250     void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
1251       wgmmaM = 64;
1252       wgmmaN = sizeN;
1253       if (inputElemType.isTF32()) {
1254         wgmmaK = 8;
1255       } else if (inputElemType.isF16() || inputElemType.isBF16()) {
1256         wgmmaK = 16;
1257       } else if (isa<Float8E4M3FNType, Float8E5M2Type>(inputElemType) ||
1258                  inputElemType.isInteger(16)) {
1259         wgmmaK = 32;
1260       } else if (inputElemType.isInteger(1)) {
1261         wgmmaK = 256;
1262       } else {
1263         llvm_unreachable("msg: not supported K shape");
1264       }
1265       LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
1266                         << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
1267     }
1268 
1269     /// Generates WGMMATypesAttr from MLIR Type
1270     NVVM::WGMMATypesAttr generateWgmmaType(Type type,
1271                                            bool useF32 = false) const {
1272       auto getWgmmaType = [=](Type elemType) {
1273         if (elemType.isF32() || elemType.isTF32())
1274           return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
1275         if (elemType.isF16())
1276           return NVVM::WGMMATypes::f16;
1277         if (elemType.isBF16())
1278           return NVVM::WGMMATypes::bf16;
1279         if (isa<Float8E4M3FNType>(elemType))
1280           return NVVM::WGMMATypes::e4m3;
1281         if (isa<Float8E5M2Type>(elemType))
1282           return NVVM::WGMMATypes::e5m2;
1283         if (elemType.isInteger(1))
1284           return NVVM::WGMMATypes::b1;
1285         if (elemType.isInteger(8))
1286           return NVVM::WGMMATypes::s8;
1287         if (elemType.isUnsignedInteger(8))
1288           return NVVM::WGMMATypes::u8;
1289         if (elemType.isInteger(32))
1290           return NVVM::WGMMATypes::s32;
1291         llvm_unreachable("unsupported type");
1292       };
1293       return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
1294     }
1295 
1296     /// Generates layout attribute for the input matrix for wgmma instruction
1297     NVVM::MMALayoutAttr
1298     generateWgmmaLayout(std::optional<bool> transpose) const {
1299       if (transpose.value_or(false))
1300         return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
1301       return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
1302     }
1303 
1304     /// Generates shape attribute for wgmma instruction
1305     NVVM::MMAShapeAttr generateWgmmaShape() const {
1306       return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
1307     }
1308 
1309     /// Generates scale attributes of output matrix for wgmma instruction
1310     NVVM::WGMMAScaleOutAttr generateScaleOut() const {
1311       return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
1312                                           NVVM::WGMMAScaleOut::one);
1313     }
1314     /// Generates scale attributes of input matrix for wgmma instruction
1315     NVVM::WGMMAScaleInAttr generateScaleIn() const {
1316       return NVVM::WGMMAScaleInAttr::get(op->getContext(),
1317                                          NVVM::WGMMAScaleIn::one);
1318     }
1319 
1320     /// Basic function to generate Add
1321     Value makeAdd(Value lhs, Value rhs) {
1322       return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
1323     };
1324 
1325     /// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
1326     /// Currently, it only handles row-major.
1327     ///
1328     /// It moves the pointer like below for [128][64] size:
1329     ///                 +2 +4 +6
1330     ///                  ↓  ↓  ↓
1331     /// descA    ---> +--+--+--+--+
1332     ///               |->|->|->|->|
1333     ///               |  |  |  |  |
1334     ///               |  |  |  |  |
1335     ///               |  |  |  |  |
1336     /// descA+512---> +-----------+
1337     ///               |  |  |  |  |
1338     ///               |  |  |  |  |
1339     ///               |  |  |  |  |
1340     ///               |  |  |  |  |
1341     ///               +-----------+
1342     ///
1343     Value iterateDescriptorA(Value desc, int i, int j, int k) {
1344       MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
1345       Type elemA = matrixTypeA.getElementType();
1346       int byte = elemA.getIntOrFloatBitWidth() / 8;
1347       int tileShapeA = matrixTypeA.getDimSize(1);
1348       int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
1349       incrementVal = incrementVal >> exclude4LSB;
1350       LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
1351                         << "] [wgmma descriptors] Descriptor A + "
1352                         << incrementVal << " | \t ");
1353       if (!incrementVal)
1354         return desc;
1355       return makeAdd(desc, makeI64Const(b, incrementVal));
1356     }
1357 
1358     /// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
1359     /// Currently, it only handles column-major.
1360     ///
1361     /// It moves the pointer like below for [128][64] size:
1362     /// descB     ---> +--+--+--+--+--+--+--+--+
1363     ///                |↓ |  |  |  |  |  |  |  |
1364     ///                |↓ |  |  |  |  |  |  |  |
1365     ///                |↓ |  |  |  |  |  |  |  |
1366     ///                |↓ |  |  |  |  |  |  |  |
1367     ///                +--+--+--+--+--+--+--+--+
1368     ///
1369     Value iterateDescriptorB(Value desc, int i, int j, int k) {
1370       MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
1371       Type elemB = matrixTypeB.getElementType();
1372       int byte = elemB.getIntOrFloatBitWidth() / 8;
1373       int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
1374       incrementVal = incrementVal >> exclude4LSB;
1375       LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
1376       if (!incrementVal)
1377         return desc;
1378       return makeAdd(desc, makeI64Const(b, incrementVal));
1379     }
1380 
1381     /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
1382     /// descriptors and arranges them based on induction variables: i, j, and k.
1383     Value generateWgmma(int i, int j, int k, Value matrixC) {
1384       LLVM_DEBUG(DBGS() << "\t wgmma."
1385                         << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
1386                         << "(A[" << (iterationM * wgmmaM) << ":"
1387                         << (iterationM * wgmmaM) + wgmmaM << "]["
1388                         << (iterationK * wgmmaK) << ":"
1389                         << (iterationK * wgmmaK + wgmmaK) << "] * "
1390                         << " B[" << (iterationK * wgmmaK) << ":"
1391                         << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
1392                         << wgmmaN << "])\n");
1393 
1394       Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
1395       Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
1396 
1397       Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
1398       NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);
1399 
1400       Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
1401       NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
1402 
1403       Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
1404       NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true);
1405 
1406       NVVM::MMAShapeAttr shape = generateWgmmaShape();
1407       NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
1408       NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
1409       NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
1410       NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB());
1411 
1412       auto overflow = NVVM::MMAIntOverflowAttr::get(
1413           op->getContext(), NVVM::MMAIntOverflow::wrapped);
1414 
1415       return b.create<NVVM::WgmmaMmaAsyncOp>(
1416           matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
1417           itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
1418           overflow);
1419     }
1420 
1421     /// Generates multiple wgmma instructions to complete the given GEMM shape
1422     Value generateWgmmaGroup() {
1423       Value wgmmaResult =
1424           b.create<LLVM::UndefOp>(adaptor.getMatrixC().getType());
1425 
1426       // Perform GEMM
1427       SmallVector<Value> wgmmaResults;
1428       for (int i = 0; i < iterationM; ++i) {
1429         Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i);
1430         for (int j = 0; j < iterationN; ++j)
1431           for (int k = 0; k < iterationK; ++k)
1432             matrixC = generateWgmma(i, j, k, matrixC);
1433         wgmmaResults.push_back(matrixC);
1434       }
1435       for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
1436         wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(),
1437                                                     wgmmaResult, matrix, idx);
1438       }
1439       return wgmmaResult;
1440     }
1441 
1442   public:
1443     WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
1444                   OpAdaptor adaptor)
1445         : op(op), b(b), adaptor(adaptor) {
1446       // Find the entire GEMM Shape
1447       totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
1448       totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
1449       totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
1450       LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
1451                         << "] += A[" << totalM << "][" << totalK << "] * B["
1452                         << totalK << "][" << totalN << "] ---===\n");
1453 
1454       // Find the shape for one wgmma instruction
1455       findWgmmaShape(
1456           totalM, totalN,
1457           op.getDescriptorA().getType().getTensor().getElementType());
1458 
1459       // Iterations counts to complete the given shape with wgmma shape
1460       iterationM = totalM / wgmmaM;
1461       iterationN = totalN / wgmmaN;
1462       iterationK = totalK / wgmmaK;
1463     }
1464 
1465     /// Generates WgmmaMmaAsync Ops to complete the specified GEMM  shape. It
1466     /// includes generating a fence Op (WgmmaFenceAlignedOp) before the
1467     /// instructions and group synchronization, as well as waiting
1468     /// (WgmmaGroupSyncAlignedOp) for group synchronization
1469     /// (WgmmaWaitGroupSyncOp) after the instructions.
1470     Value generateWarpgroupMma() {
1471       b.create<NVVM::WgmmaFenceAlignedOp>();
1472       Value wgmmaResult = generateWgmmaGroup();
1473       b.create<NVVM::WgmmaGroupSyncAlignedOp>();
1474       b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
1475       return wgmmaResult;
1476     }
1477   };
1478   LogicalResult
1479   matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
1480                   ConversionPatternRewriter &rewriter) const override {
1481     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1482 
1483     // Step 1. Build a helper class
1484     WarpgroupGemm warpgroupGemm(op, b, adaptor);
1485 
1486     // Step 2. Get the entire GEMM Shape
1487     Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();
1488 
1489     // Step 3. Replace fragmented result struct with the op results
1490     rewriter.replaceOp(op, wgmmaResult);
1491     return success();
1492   }
1493 };
1494 
1495 struct NVGPUWarpgroupMmaStoreOpLowering
1496     : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
1497   using ConvertOpToLLVMPattern<
1498       nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;
1499 
1500   /// This function stores a fragmented register matrix owned by a warp group
1501   /// (128 threads) into a memref. Each thread has 64 registers, each the size
1502   /// of a struct.
1503   /// Here is what each threads (T) holds, each `d` is struct value with a
1504   /// number.
1505   ///
1506   /// Threads in warp-group (128 threads) and what they owns in the matrixD:
1507   /// 0-31 	  Warp-0  -> MatrixD[0:15 ][0:N]
1508   /// 32-63 	Warp-1  -> MatrixD[16:31][0:N]
1509   /// 64-95 	Warp-2  -> MatrixD[32:47][0:N]
1510   /// 96-127 	Warp-3  -> MatrixD[48:64][0:N]
1511   ///
1512   /// Matrix-D:
1513   ///   +______________________________________________________________________+
1514   ///   |     0-1  |    2-3  |    4-5  |    6-7  |   8-9  |   10-11|..|N-8,N-7 |
1515   /// 0 | T0:d0-d1 |T1:d0-d1 |T2:d0-d1 |T3:d0-d1 |T0:d4-d5| T1:d4-d5..|T0:dX-dY|
1516   /// 1 | T4:d0-d1 |T5:d0-d1 |T6:d0-d1 |T7:d0-d1 |T4:d4-d5| T5:d4-d5..|T4:dX-dY|
1517   /// ..| .........|.........|.........|.........|........|...........|........|
1518   /// 8 | T0:d2-d3 |T1:d2-d3 |T2:d2-d3 |T3:d2-d3 |T0:d6-d7|T1:d6-d7,..|T0:dZ-dW|
1519   /// 9 | T4:d2-d3 |T5:d2-d3 |T6:d2-d3 |T7:d2-d3 |T4:d6-d7| T5:d6-d7..|T4:dZ-dW|
1520   /// ..| .........|.........|.........|.........|........|...........|........|
1521   /// 15| T28:d2-d3|T29:d2-d3|T30:d2-d3|T31:d2-d3|........|...........|........|
1522   /// 16| T32:d2-d3|T33:d2-d3|T34:d2-d3|T35:d2-d3|........|...........|........|
1523   /// ..| .........|.........|.........|.........|........|...........|........|
1524   /// 32| T64:d2-d3|T65:d2-d3|T66:d2-d3|T67:d2-d3|........|...........|........|
1525   /// ..| .........|.........|.........|.........|........|...........|........|
1526   /// 48| T96:d2-d3|T97:d2-d3|T98:d2-d3|T99:d2-d3|........|...........|........|
1527   /// ..| .........|.........|.........|.........|........|...........|........|
1528   ///   +______________________________________________________________________+
1529   ///
1530   /// \param rewriter: The pattern rewriter.
1531   /// \param matrixD: Result of the warp-group MMA operation (fragmented
1532   /// matrix). It is holded by a thread and a struct with 64 elements.
1533   /// \param dstMemref: The memref where the registers will be stored.
1534   /// \param offset: the offset within the memref where the registers will be
1535   /// stored.
1536   void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
1537                              TypedValue<MemRefType> dstMemref,
1538                              int offset) const {
1539     Type i32 = b.getI32Type();
1540 
1541     auto makeConst = [&](int32_t index) -> Value {
1542       return b.create<LLVM::ConstantOp>(i32, b.getI32IntegerAttr(index));
1543     };
1544     Value c1 = makeConst(1);
1545     Value c2 = makeConst(2);
1546     Value c4 = makeConst(4);
1547     Value c8 = makeConst(8);
1548     Value c16 = makeConst(16);
1549     Value warpSize = makeConst(kWarpSize);
1550 
1551     auto makeMul = [&](Value lhs, Value rhs) -> Value {
1552       return b.create<LLVM::MulOp>(lhs.getType(), lhs, rhs);
1553     };
1554     auto makeAdd = [&](Value lhs, Value rhs) -> Value {
1555       return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
1556     };
1557 
1558     auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
1559                                    TypedValue<::mlir::MemRefType> memref) {
1560       Type it = b.getIndexType();
1561       Value idx = b.create<arith::IndexCastOp>(it, x);
1562       Value idy0 = b.create<arith::IndexCastOp>(it, y);
1563       Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1));
1564       Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i);
1565       Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1);
1566       b.create<memref::StoreOp>(d0, memref, ValueRange{idx, idy0});
1567       b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1});
1568     };
1569 
1570     Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
1571     Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
1572     Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
1573     Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
1574     Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);
1575 
1576     Value tj = makeMul(lane4modId, c2);
1577     Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
1578     if (offset)
1579       ti = makeAdd(ti, makeConst(offset));
1580 
1581     auto structType = cast<LLVM::LLVMStructType>(matrixD.getType());
1582 
1583     // Number of 32-bit registers owns per thread
1584     constexpr unsigned numAdjacentRegisters = 2;
1585     // Number of 8x8 matrices one below another per warp
1586     constexpr unsigned numStackedMatrices = 2;
1587 
1588     size_t storeCount = (structType.getBody().size() /
1589                          (numStackedMatrices * numAdjacentRegisters));
1590 
1591     for (size_t i = 0; i < numStackedMatrices; ++i) {
1592       Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
1593       for (size_t j = 0; j < storeCount; ++j) {
1594         Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
1595         size_t structIndex = (i * numAdjacentRegisters) +
1596                              (j * (numStackedMatrices * numAdjacentRegisters));
1597         makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
1598       }
1599     }
1600   }
1601 
1602   LogicalResult
1603   matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
1604                   ConversionPatternRewriter &rewriter) const override {
1605     int offset = 0;
1606     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1607     Value matriDValue = adaptor.getMatrixD();
1608     auto stype = cast<LLVM::LLVMStructType>(matriDValue.getType());
1609     for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
1610       auto structType = cast<LLVM::LLVMStructType>(matrixD);
1611       Value innerStructValue = b.create<LLVM::ExtractValueOp>(matriDValue, idx);
1612       storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
1613       offset += structType.getBody().size();
1614     }
1615     rewriter.eraseOp(op);
1616     return success();
1617   }
1618 };
1619 
1620 struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
1621     : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> {
1622   using ConvertOpToLLVMPattern<
1623       nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern;
1624   LogicalResult
1625   matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
1626                   ConversionPatternRewriter &rewriter) const override {
1627     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1628     LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>(
1629         getTypeConverter()->convertType(op.getMatrixC().getType()));
1630     Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front())
1631                         .getBody()
1632                         .front();
1633     Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
1634     Value packStruct = b.create<LLVM::UndefOp>(packStructType);
1635     SmallVector<Value> innerStructs;
1636     // Unpack the structs and set all values to zero
1637     for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
1638       auto structType = cast<LLVM::LLVMStructType>(s);
1639       Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx);
1640       for (unsigned i = 0; i < structType.getBody().size(); ++i) {
1641         structValue = b.create<LLVM::InsertValueOp>(
1642             structType, structValue, zero, ArrayRef<int64_t>({i}));
1643       }
1644       innerStructs.push_back(structValue);
1645     }
1646     // Pack the inner structs into a single struct
1647     for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
1648       packStruct = b.create<LLVM::InsertValueOp>(packStruct.getType(),
1649                                                  packStruct, matrix, idx);
1650     }
1651     rewriter.replaceOp(op, packStruct);
1652     return success();
1653   }
1654 };
1655 
1656 struct NVGPUTmaPrefetchOpLowering
1657     : public ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp> {
1658   using ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp>::ConvertOpToLLVMPattern;
1659   LogicalResult
1660   matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
1661                   ConversionPatternRewriter &rewriter) const override {
1662     rewriter.replaceOpWithNewOp<NVVM::PrefetchTensorMapOp>(
1663         op, adaptor.getTensorMapDescriptor(), adaptor.getPredicate());
1664     return success();
1665   }
1666 };
1667 
1668 struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> {
1669   using ConvertOpToLLVMPattern<nvgpu::RcpOp>::ConvertOpToLLVMPattern;
1670   LogicalResult
1671   matchAndRewrite(nvgpu::RcpOp op, OpAdaptor adaptor,
1672                   ConversionPatternRewriter &rewriter) const override {
1673     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1674     auto i64Ty = b.getI64Type();
1675     auto f32Ty = b.getF32Type();
1676     VectorType inTy = op.getIn().getType();
1677     // apply rcp.approx.ftz.f on each element in vector.
1678     auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) {
1679       Value ret1DVec = b.create<LLVM::UndefOp>(llvm1DVectorTy);
1680       int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements();
1681       for (int i = 0; i < numElems; i++) {
1682         Value idx = b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(i));
1683         Value elem = b.create<LLVM::ExtractElementOp>(inVec, idx);
1684         Value dst = b.create<NVVM::RcpApproxFtzF32Op>(f32Ty, elem);
1685         ret1DVec = b.create<LLVM::InsertElementOp>(ret1DVec, dst, idx);
1686       }
1687       return ret1DVec;
1688     };
1689     if (inTy.getRank() == 1) {
1690       rewriter.replaceOp(op, convert1DVec(inTy, adaptor.getIn()));
1691       return success();
1692     }
1693     return LLVM::detail::handleMultidimensionalVectors(
1694         op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
1695         [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
1696           OpAdaptor adaptor(operands);
1697           return convert1DVec(llvm1DVectorTy, adaptor.getIn());
1698         },
1699         rewriter);
1700   }
1701 };
1702 } // namespace
1703 
1704 void mlir::populateNVGPUToNVVMConversionPatterns(
1705     const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1706   patterns.add<
1707       NVGPUMBarrierCreateLowering,           // nvgpu.mbarrier.create
1708       NVGPUMBarrierInitLowering,             // nvgpu.mbarrier.init
1709       NVGPUMBarrierArriveLowering,           // nvgpu.mbarrier.arrive
1710       NVGPUMBarrierArriveNoCompleteLowering, // nvgpu.mbarrier.arrive.no_complete
1711       NVGPUMBarrierTestWaitLowering,         // nvgpu.mbarrier.test_wait_parity
1712       NVGPUMBarrierTryWaitParityLowering,    // nvgpu.mbarrier.try_wait_parity
1713       NVGPUTmaAsyncLoadOpLowering,           // nvgpu.tma.async.load
1714       NVGPUTmaAsyncStoreOpLowering,          // nvgpu.tma.async.store
1715       NVGPUTmaCreateDescriptorOpLowering,    // nvgpu.tma.create.descriptor
1716       NVGPUTmaPrefetchOpLowering,            // nvgpu.tma.prefetch.descriptor
1717       NVGPUMBarrierArriveExpectTxLowering,   // nvgpu.mbarrier.arrive.expect_tx
1718       NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
1719       NVGPUWarpgroupMmaOpLowering,              // nvgpu.warpgroup.mma
1720       NVGPUWarpgroupMmaStoreOpLowering,         // nvgpu.warpgroup.mma.store
1721       NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator
1722       MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
1723       NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
1724       NVGPUMmaSparseSyncLowering, NVGPURcpOpLowering>(converter);
1725 }
1726