xref: /llvm-project/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (revision 7a77f14c0abfbecbfb800ea8d974e66d81ee516a)
1894a591cSThomas Raoux //===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===//
2894a591cSThomas Raoux //
3894a591cSThomas Raoux // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4894a591cSThomas Raoux // See https://llvm.org/LICENSE.txt for license information.
5894a591cSThomas Raoux // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6894a591cSThomas Raoux //
7894a591cSThomas Raoux //===----------------------------------------------------------------------===//
8894a591cSThomas Raoux 
9894a591cSThomas Raoux #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
1067d0d7acSMichele Scuttari 
11e56d6745SGuray Ozen #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
12894a591cSThomas Raoux #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13894a591cSThomas Raoux #include "mlir/Conversion/LLVMCommon/Pattern.h"
142b23e6c8SObserver007 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
15d20fbc90SGuray Ozen #include "mlir/Dialect/Arith/IR/Arith.h"
16d7ef488bSMogball #include "mlir/Dialect/GPU/IR/GPUDialect.h"
17708185f0SChristopher Bate #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18e56d6745SGuray Ozen #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
19894a591cSThomas Raoux #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
20affcfccdSGuray Ozen #include "mlir/Dialect/MemRef/IR/MemRef.h"
2151b925dfSChristopher Bate #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
2223882226SGuray Ozen #include "mlir/Dialect/SCF/Transforms/Patterns.h"
2317649a77SGuray Ozen #include "mlir/IR/BuiltinTypes.h"
24ee49cda7SGuray Ozen #include "mlir/IR/ImplicitLocOpBuilder.h"
25e56d6745SGuray Ozen #include "mlir/IR/PatternMatch.h"
26708185f0SChristopher Bate #include "mlir/IR/TypeUtilities.h"
2717649a77SGuray Ozen #include "mlir/IR/Value.h"
2867d0d7acSMichele Scuttari #include "mlir/Pass/Pass.h"
29b96d0693SGuray Ozen #include "llvm/Support/Debug.h"
3023882226SGuray Ozen #include "llvm/Support/ErrorHandling.h"
31e56d6745SGuray Ozen #include "llvm/Support/raw_ostream.h"
3263389326SGuray Ozen #include <optional>
3367d0d7acSMichele Scuttari 
34b96d0693SGuray Ozen #define DEBUG_TYPE "nvgpu-to-nvvm"
35b96d0693SGuray Ozen #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
36b96d0693SGuray Ozen #define DBGSE() (llvm::dbgs())
37b96d0693SGuray Ozen 
3867d0d7acSMichele Scuttari namespace mlir {
3953689fdfSMarkus Böck #define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
4067d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc"
4167d0d7acSMichele Scuttari } // namespace mlir
42894a591cSThomas Raoux 
43894a591cSThomas Raoux using namespace mlir;
44894a591cSThomas Raoux 
/// Number of bits (the 4 LSBs, going by the name) that need to be excluded
/// when building the matrix descriptor for wgmma operations.
constexpr int exclude4LSB{4};
4823882226SGuray Ozen 
49836dbb85SGuray Ozen /// GPU has 32 bit registers, this function truncates values when larger width
50836dbb85SGuray Ozen /// is not needed.
51ee49cda7SGuray Ozen static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
52836dbb85SGuray Ozen   Type type = value.getType();
53836dbb85SGuray Ozen   assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
54836dbb85SGuray Ozen   if (type.getIntOrFloatBitWidth() <= 32)
55836dbb85SGuray Ozen     return value;
56ee49cda7SGuray Ozen   return b.create<LLVM::TruncOp>(b.getI32Type(), value);
57836dbb85SGuray Ozen }
58836dbb85SGuray Ozen 
59894a591cSThomas Raoux /// Returns the type for the intrinsic given the vectorResultType of the
60894a591cSThomas Raoux /// `gpu.mma.sync` operation.
61894a591cSThomas Raoux static Type inferIntrinsicResultType(Type vectorResultType) {
62894a591cSThomas Raoux   MLIRContext *ctx = vectorResultType.getContext();
635550c821STres Popp   auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
64894a591cSThomas Raoux   auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
65894a591cSThomas Raoux   auto i32Ty = IntegerType::get(ctx, 32);
66894a591cSThomas Raoux   auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
67894a591cSThomas Raoux   Type f64Ty = Float64Type::get(ctx);
68894a591cSThomas Raoux   Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
6998798073SChristopher Bate   Type f32Ty = Float32Type::get(ctx);
7098798073SChristopher Bate   Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
71894a591cSThomas Raoux   if (a.getElementType() == f16x2Ty) {
72894a591cSThomas Raoux     return LLVM::LLVMStructType::getLiteral(
73894a591cSThomas Raoux         ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
74894a591cSThomas Raoux   }
75894a591cSThomas Raoux   if (a.getElementType() == i32x2Ty) {
76894a591cSThomas Raoux     return LLVM::LLVMStructType::getLiteral(
77894a591cSThomas Raoux         ctx,
78894a591cSThomas Raoux         SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
79894a591cSThomas Raoux   }
80894a591cSThomas Raoux   if (a.getElementType() == f64x2Ty) {
81894a591cSThomas Raoux     return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
82894a591cSThomas Raoux   }
8398798073SChristopher Bate   if (a.getElementType() == f32x2Ty) {
8498798073SChristopher Bate     return LLVM::LLVMStructType::getLiteral(
8598798073SChristopher Bate         ctx,
8698798073SChristopher Bate         SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
8798798073SChristopher Bate   }
8898798073SChristopher Bate   if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) {
8998798073SChristopher Bate     return LLVM::LLVMStructType::getLiteral(
9098798073SChristopher Bate         ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
9198798073SChristopher Bate   }
92894a591cSThomas Raoux   return vectorResultType;
93894a591cSThomas Raoux }
94894a591cSThomas Raoux 
95894a591cSThomas Raoux /// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
96894a591cSThomas Raoux /// always an LLVM struct) into a fragment that is compatible with the vector
97894a591cSThomas Raoux /// type of this operation. This involves extracting elements from the struct
98894a591cSThomas Raoux /// and inserting them into an LLVM array. These extra data-movement
99894a591cSThomas Raoux /// operations should be canonicalized away by the LLVM backend.
100894a591cSThomas Raoux static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
101894a591cSThomas Raoux                                     Type resultType, Value intrinsicResult,
102894a591cSThomas Raoux                                     RewriterBase &rewriter) {
103894a591cSThomas Raoux   MLIRContext *ctx = rewriter.getContext();
1045550c821STres Popp   auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
1055550c821STres Popp   auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);
106894a591cSThomas Raoux   Type i32Ty = rewriter.getI32Type();
10798798073SChristopher Bate   Type f32Ty = rewriter.getF32Type();
108894a591cSThomas Raoux   Type f64Ty = rewriter.getF64Type();
109894a591cSThomas Raoux   Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
110894a591cSThomas Raoux   Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
111894a591cSThomas Raoux   Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
11298798073SChristopher Bate   Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
11398798073SChristopher Bate   Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
114894a591cSThomas Raoux 
115894a591cSThomas Raoux   auto makeConst = [&](int32_t index) -> Value {
116894a591cSThomas Raoux     return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
117894a591cSThomas Raoux                                              rewriter.getI32IntegerAttr(index));
118894a591cSThomas Raoux   };
119894a591cSThomas Raoux 
120894a591cSThomas Raoux   if (arrayType) {
121894a591cSThomas Raoux     SmallVector<Value, 4> elements;
122894a591cSThomas Raoux 
12398798073SChristopher Bate     // The intrinsic returns 32-bit wide elements in a form which can be
12498798073SChristopher Bate     // directly bitcasted and inserted into the result vector.
12598798073SChristopher Bate     if (arrayType.getElementType() == f16x2Ty ||
12698798073SChristopher Bate         arrayType.getElementType() == f32x1Ty) {
127894a591cSThomas Raoux       for (unsigned i = 0; i < structType.getBody().size(); i++) {
1285c5af910SJeff Niu         Value el =
1295c5af910SJeff Niu             rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i);
13098798073SChristopher Bate         el = rewriter.createOrFold<LLVM::BitcastOp>(
13198798073SChristopher Bate             loc, arrayType.getElementType(), el);
13298798073SChristopher Bate         elements.push_back(el);
133894a591cSThomas Raoux       }
134894a591cSThomas Raoux     }
135894a591cSThomas Raoux 
13698798073SChristopher Bate     // The intrinsic returns i32, f64, and f32 values as individual scalars,
13798798073SChristopher Bate     // even when the result is notionally a 64-bit wide element (e.g. f32x2). We
13898798073SChristopher Bate     // need to extract them from the struct and pack them into the 64-bit wide
13998798073SChristopher Bate     // rows of the vector result.
140894a591cSThomas Raoux     if (arrayType.getElementType() == i32x2Ty ||
14198798073SChristopher Bate         arrayType.getElementType() == f64x2Ty ||
14298798073SChristopher Bate         arrayType.getElementType() == f32x2Ty) {
14398798073SChristopher Bate 
14498798073SChristopher Bate       for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
145894a591cSThomas Raoux         Value vec =
146894a591cSThomas Raoux             rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
1475c5af910SJeff Niu         Value x1 =
1485c5af910SJeff Niu             rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i * 2);
1495c5af910SJeff Niu         Value x2 = rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult,
1505c5af910SJeff Niu                                                          i * 2 + 1);
151894a591cSThomas Raoux         vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
152894a591cSThomas Raoux                                                      x1, makeConst(0));
153894a591cSThomas Raoux         vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
154894a591cSThomas Raoux                                                      x2, makeConst(1));
155894a591cSThomas Raoux         elements.push_back(vec);
156894a591cSThomas Raoux       }
15798798073SChristopher Bate     }
158894a591cSThomas Raoux 
159894a591cSThomas Raoux     // Create the final vectorized result.
160894a591cSThomas Raoux     Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
161894a591cSThomas Raoux     for (const auto &el : llvm::enumerate(elements)) {
1625c5af910SJeff Niu       result = rewriter.create<LLVM::InsertValueOp>(loc, result, el.value(),
1635c5af910SJeff Niu                                                     el.index());
164894a591cSThomas Raoux     }
165894a591cSThomas Raoux     return result;
166894a591cSThomas Raoux   }
167894a591cSThomas Raoux 
168894a591cSThomas Raoux   return intrinsicResult;
169894a591cSThomas Raoux }
170894a591cSThomas Raoux 
171894a591cSThomas Raoux /// The `gpu.mma.sync` converter below expects matrix fragment operands to be
172894a591cSThomas Raoux /// given as 2D `vectors` where the rows are 32b or 64b wide. The
173894a591cSThomas Raoux /// `nvvm.mma.sync` op expects these argments to be a given in a long list of
174894a591cSThomas Raoux /// scalars of certain types. This function helps unpack the `vector` arguments
175894a591cSThomas Raoux /// and cast them to the types expected by `nvvm.mma.sync`.
176ee49cda7SGuray Ozen static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
177ee49cda7SGuray Ozen                                               Value operand,
17898798073SChristopher Bate                                               NVVM::MMATypes operandPtxType) {
179894a591cSThomas Raoux   SmallVector<Value> result;
180ee49cda7SGuray Ozen   Type i32Ty = b.getI32Type();
181ee49cda7SGuray Ozen   Type f64Ty = b.getF64Type();
182ee49cda7SGuray Ozen   Type f32Ty = b.getF32Type();
183ee49cda7SGuray Ozen   Type i64Ty = b.getI64Type();
184ee49cda7SGuray Ozen   Type i8x4Ty = LLVM::getFixedVectorType(b.getI8Type(), 4);
185ee49cda7SGuray Ozen   Type i4x8Ty = LLVM::getFixedVectorType(b.getIntegerType(4), 8);
18698798073SChristopher Bate   Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
1875550c821STres Popp   auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
188894a591cSThomas Raoux 
189894a591cSThomas Raoux   for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
190ee49cda7SGuray Ozen     Value toUse = b.create<LLVM::ExtractValueOp>(operand, i);
191894a591cSThomas Raoux 
192894a591cSThomas Raoux     // For 4xi8 vectors, the intrinsic expects these to be provided as i32
193894a591cSThomas Raoux     // scalar types.
19498798073SChristopher Bate     if (arrayTy.getElementType() == i8x4Ty ||
195334f63e7SChristopher Bate         arrayTy.getElementType() == i4x8Ty ||
19698798073SChristopher Bate         (arrayTy.getElementType() == f32x1Ty &&
19798798073SChristopher Bate          operandPtxType == NVVM::MMATypes::tf32)) {
198ee49cda7SGuray Ozen       result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse));
199894a591cSThomas Raoux       continue;
200894a591cSThomas Raoux     }
201894a591cSThomas Raoux 
20298798073SChristopher Bate     // For some element types (i32, f32, f64), we need to unpack the inner
203894a591cSThomas Raoux     // vector/array type as well because the intrinsic expects individual
204894a591cSThomas Raoux     // scalars to be provided.
2055550c821STres Popp     VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
206894a591cSThomas Raoux     if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
20798798073SChristopher Bate                          innerArrayTy.getElementType() == f64Ty ||
20898798073SChristopher Bate                          innerArrayTy.getElementType() == f32Ty)) {
209894a591cSThomas Raoux       for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
210894a591cSThomas Raoux            idx < innerSize; idx++) {
211ee49cda7SGuray Ozen         result.push_back(b.create<LLVM::ExtractElementOp>(
212ee49cda7SGuray Ozen             toUse,
213ee49cda7SGuray Ozen             b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(idx))));
214894a591cSThomas Raoux       }
215894a591cSThomas Raoux       continue;
216894a591cSThomas Raoux     }
217894a591cSThomas Raoux     result.push_back(toUse);
218894a591cSThomas Raoux   }
219894a591cSThomas Raoux   return result;
220894a591cSThomas Raoux }
221894a591cSThomas Raoux 
22299475f5bSNicolas Vasilache /// Returns whether mbarrier object has shared memory address space.
22317649a77SGuray Ozen static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType) {
22499475f5bSNicolas Vasilache   return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
22599475f5bSNicolas Vasilache       barrierType.getMemorySpace()));
22699475f5bSNicolas Vasilache }
22799475f5bSNicolas Vasilache 
22899475f5bSNicolas Vasilache /// Returns the memory space attribute of the mbarrier object.
22999475f5bSNicolas Vasilache Attribute nvgpu::getMbarrierMemorySpace(MLIRContext *context,
23017649a77SGuray Ozen                                         nvgpu::MBarrierGroupType barrierType) {
23199475f5bSNicolas Vasilache   Attribute memorySpace = {};
23299475f5bSNicolas Vasilache   if (isMbarrierShared(barrierType)) {
23399475f5bSNicolas Vasilache     memorySpace =
23499475f5bSNicolas Vasilache         IntegerAttr::get(IntegerType::get(context, 64),
23599475f5bSNicolas Vasilache                          nvgpu::NVGPUDialect::kSharedMemoryAddressSpace);
23699475f5bSNicolas Vasilache   }
23799475f5bSNicolas Vasilache   return memorySpace;
23899475f5bSNicolas Vasilache }
23999475f5bSNicolas Vasilache 
24099475f5bSNicolas Vasilache /// Returns memref type of the mbarrier object. The type is defined in the
24117649a77SGuray Ozen /// MBarrierGroupType.
24299475f5bSNicolas Vasilache MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context,
24317649a77SGuray Ozen                                         nvgpu::MBarrierGroupType barrierType) {
24499475f5bSNicolas Vasilache   Attribute memorySpace = nvgpu::getMbarrierMemorySpace(context, barrierType);
24599475f5bSNicolas Vasilache   MemRefLayoutAttrInterface layout;
24617649a77SGuray Ozen   return MemRefType::get({barrierType.getNumBarriers()},
24717649a77SGuray Ozen                          IntegerType::get(context, 64), layout, memorySpace);
24899475f5bSNicolas Vasilache }
24999475f5bSNicolas Vasilache 
250894a591cSThomas Raoux namespace {
251894a591cSThomas Raoux 
252894a591cSThomas Raoux struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
253894a591cSThomas Raoux   using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;
254894a591cSThomas Raoux 
255894a591cSThomas Raoux   LogicalResult
256894a591cSThomas Raoux   matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
257894a591cSThomas Raoux                   ConversionPatternRewriter &rewriter) const override {
258894a591cSThomas Raoux     MLIRContext *ctx = getContext();
259ee49cda7SGuray Ozen     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
260894a591cSThomas Raoux 
261894a591cSThomas Raoux     // The result type of ldmatrix will always be a struct of 32bit integer
262894a591cSThomas Raoux     // registers if more than one 32bit value is returned. Otherwise, the result
263894a591cSThomas Raoux     // is a single i32. The result type of the GPU operation is always a vector
264894a591cSThomas Raoux     // of shape (NumRegisters, VectorRegister) where VectorRegister is the
265894a591cSThomas Raoux     // vector type of the result and always 32 bits long. We bitcast the result
266894a591cSThomas Raoux     // of the NVVM::LdMatrix to this vector type.
2675550c821STres Popp     auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
268894a591cSThomas Raoux     if (!vectorResultType) {
269894a591cSThomas Raoux       return failure();
270894a591cSThomas Raoux     }
271894a591cSThomas Raoux     Type innerVectorType = LLVM::getFixedVectorType(
272894a591cSThomas Raoux         vectorResultType.getElementType(), vectorResultType.getDimSize(1));
273894a591cSThomas Raoux 
274894a591cSThomas Raoux     int64_t num32BitRegs = vectorResultType.getDimSize(0);
275894a591cSThomas Raoux 
276894a591cSThomas Raoux     Type ldMatrixResultType;
277894a591cSThomas Raoux     if (num32BitRegs > 1) {
278894a591cSThomas Raoux       ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
279894a591cSThomas Raoux           ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
280894a591cSThomas Raoux     } else {
281894a591cSThomas Raoux       ldMatrixResultType = rewriter.getI32Type();
282894a591cSThomas Raoux     }
283894a591cSThomas Raoux 
2845550c821STres Popp     auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
2858df54a6aSJacques Pienaar     Value srcPtr =
286ee49cda7SGuray Ozen         getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(),
2878df54a6aSJacques Pienaar                              adaptor.getIndices(), rewriter);
288ee49cda7SGuray Ozen     Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
289ee49cda7SGuray Ozen         ldMatrixResultType, srcPtr,
2908df54a6aSJacques Pienaar         /*num=*/op.getNumTiles(),
2918df54a6aSJacques Pienaar         /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
292894a591cSThomas Raoux                                      : NVVM::MMALayout::row);
293894a591cSThomas Raoux 
294894a591cSThomas Raoux     // The ldmatrix operation returns either a single i32 value or a struct of
295894a591cSThomas Raoux     // i32 values. Here we unpack those values and cast them back to their
296894a591cSThomas Raoux     // actual vector type (still of width 32b) and repack them into a result
297894a591cSThomas Raoux     // struct.
298894a591cSThomas Raoux     Type finalResultType = typeConverter->convertType(vectorResultType);
299ee49cda7SGuray Ozen     Value result = b.create<LLVM::UndefOp>(finalResultType);
300894a591cSThomas Raoux     for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
3015c5af910SJeff Niu       Value i32Register =
302ee49cda7SGuray Ozen           num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i)
303894a591cSThomas Raoux                            : ldMatrixResult;
304ee49cda7SGuray Ozen       Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register);
305ee49cda7SGuray Ozen       result = b.create<LLVM::InsertValueOp>(result, casted, i);
306894a591cSThomas Raoux     }
307894a591cSThomas Raoux 
308894a591cSThomas Raoux     rewriter.replaceOp(op, result);
309894a591cSThomas Raoux     return success();
310894a591cSThomas Raoux   }
311894a591cSThomas Raoux };
312894a591cSThomas Raoux 
313708185f0SChristopher Bate /// Convert the given type into the corresponding PTX type (NVVM::MMATypes
314708185f0SChristopher Bate /// enum).
315708185f0SChristopher Bate static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
316708185f0SChristopher Bate   Type elType = getElementTypeOrSelf(t);
317708185f0SChristopher Bate   if (elType.isInteger(8))
318708185f0SChristopher Bate     return NVVM::MMATypes::s8;
319708185f0SChristopher Bate   if (elType.isInteger(4))
320708185f0SChristopher Bate     return NVVM::MMATypes::s4;
321708185f0SChristopher Bate   if (elType.isF16())
322708185f0SChristopher Bate     return NVVM::MMATypes::f16;
323708185f0SChristopher Bate   if (elType.isF64())
324708185f0SChristopher Bate     return NVVM::MMATypes::f64;
325708185f0SChristopher Bate   if (elType.isF32())
326708185f0SChristopher Bate     return NVVM::MMATypes::tf32;
327708185f0SChristopher Bate   return failure();
328708185f0SChristopher Bate }
329708185f0SChristopher Bate 
330894a591cSThomas Raoux struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
331894a591cSThomas Raoux   using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;
332894a591cSThomas Raoux 
333894a591cSThomas Raoux   LogicalResult
334894a591cSThomas Raoux   matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
335894a591cSThomas Raoux                   ConversionPatternRewriter &rewriter) const override {
336ee49cda7SGuray Ozen     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
337894a591cSThomas Raoux     // Get the shapes of the MMAMatrix type being used. The shapes will
338894a591cSThomas Raoux     // choose which intrinsic this op will be lowered to.
339708185f0SChristopher Bate     VectorType aType = op.getMatrixA().getType();
340708185f0SChristopher Bate     VectorType bType = op.getMatrixA().getType();
341708185f0SChristopher Bate     VectorType cType = op.getMatrixC().getType();
342894a591cSThomas Raoux 
343708185f0SChristopher Bate     std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();
34414d79afeSManish Gupta 
34514d79afeSManish Gupta     // Tensor Cores (mma.sync) on F32 works only with TensorFloat32 (TF32).
34614d79afeSManish Gupta     bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
34714d79afeSManish Gupta     if (aType.getElementType().isF32() && !tf32Enabled)
34814d79afeSManish Gupta       return failure();
34998798073SChristopher Bate 
350708185f0SChristopher Bate     FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
351708185f0SChristopher Bate     if (failed(ptxTypeA))
352708185f0SChristopher Bate       return op->emitOpError("failed to deduce operand PTX types");
353708185f0SChristopher Bate     FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
354708185f0SChristopher Bate     if (failed(ptxTypeB))
355708185f0SChristopher Bate       return op->emitOpError("failed to deduce operand PTX types");
35622426110SRamkumar Ramachandra     std::optional<NVVM::MMATypes> ptxTypeC =
35722426110SRamkumar Ramachandra         NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
35822426110SRamkumar Ramachandra                                          /*isAccumulator=*/true);
359708185f0SChristopher Bate     if (!ptxTypeC)
360708185f0SChristopher Bate       return op->emitError(
361708185f0SChristopher Bate           "could not infer the PTX type for the accumulator/result");
362708185f0SChristopher Bate 
363708185f0SChristopher Bate     // TODO: add an attribute to the op to customize this behavior.
36422426110SRamkumar Ramachandra     std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
3655550c821STres Popp     if (isa<IntegerType>(aType.getElementType()))
366894a591cSThomas Raoux       overflow = NVVM::MMAIntOverflow::satfinite;
367894a591cSThomas Raoux 
36898798073SChristopher Bate     SmallVector<Value> matA =
369ee49cda7SGuray Ozen         unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
37098798073SChristopher Bate     SmallVector<Value> matB =
371ee49cda7SGuray Ozen         unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
37298798073SChristopher Bate     SmallVector<Value> matC =
373ee49cda7SGuray Ozen         unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
37498798073SChristopher Bate 
375894a591cSThomas Raoux     Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
376894a591cSThomas Raoux     Type intrinsicResTy = inferIntrinsicResultType(
377894a591cSThomas Raoux         typeConverter->convertType(op->getResultTypes()[0]));
378ee49cda7SGuray Ozen     Value intrinsicResult = b.create<NVVM::MmaOp>(
379ee49cda7SGuray Ozen         intrinsicResTy, matA, matB, matC,
380894a591cSThomas Raoux         /*shape=*/gemmShape,
3811a36588eSKazu Hirata         /*b1Op=*/std::nullopt,
382894a591cSThomas Raoux         /*intOverflow=*/overflow,
383894a591cSThomas Raoux         /*multiplicandPtxTypes=*/
384708185f0SChristopher Bate         std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
385894a591cSThomas Raoux         /*multiplicandLayouts=*/
386894a591cSThomas Raoux         std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
387894a591cSThomas Raoux                                        NVVM::MMALayout::col});
388894a591cSThomas Raoux     rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
389894a591cSThomas Raoux                                                   desiredRetTy, intrinsicResult,
390894a591cSThomas Raoux                                                   rewriter));
391894a591cSThomas Raoux     return success();
392894a591cSThomas Raoux   }
393894a591cSThomas Raoux };
394894a591cSThomas Raoux 
395894a591cSThomas Raoux struct ConvertNVGPUToNVVMPass
39653689fdfSMarkus Böck     : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
39753689fdfSMarkus Böck   using Base::Base;
398894a591cSThomas Raoux 
399affcfccdSGuray Ozen   void getDependentDialects(DialectRegistry &registry) const override {
400d20fbc90SGuray Ozen     registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect,
401d20fbc90SGuray Ozen                     arith::ArithDialect>();
402affcfccdSGuray Ozen   }
403affcfccdSGuray Ozen 
404894a591cSThomas Raoux   void runOnOperation() override {
40553689fdfSMarkus Böck     LowerToLLVMOptions options(&getContext());
406894a591cSThomas Raoux     RewritePatternSet patterns(&getContext());
40753689fdfSMarkus Böck     LLVMTypeConverter converter(&getContext(), options);
408affcfccdSGuray Ozen     IRRewriter rewriter(&getContext());
4093a03da37SGuray Ozen     populateGpuMemorySpaceAttributeConversions(
4103a03da37SGuray Ozen         converter, [](gpu::AddressSpace space) -> unsigned {
4113a03da37SGuray Ozen           switch (space) {
4123a03da37SGuray Ozen           case gpu::AddressSpace::Global:
4133a03da37SGuray Ozen             return static_cast<unsigned>(
4143a03da37SGuray Ozen                 NVVM::NVVMMemorySpace::kGlobalMemorySpace);
4153a03da37SGuray Ozen           case gpu::AddressSpace::Workgroup:
4163a03da37SGuray Ozen             return static_cast<unsigned>(
4173a03da37SGuray Ozen                 NVVM::NVVMMemorySpace::kSharedMemorySpace);
4183a03da37SGuray Ozen           case gpu::AddressSpace::Private:
4193a03da37SGuray Ozen             return 0;
4203a03da37SGuray Ozen           }
4213a03da37SGuray Ozen           llvm_unreachable("unknown address space enum value");
4223a03da37SGuray Ozen           return 0;
4233a03da37SGuray Ozen         });
424affcfccdSGuray Ozen     /// device-side async tokens cannot be materialized in nvvm. We just
425affcfccdSGuray Ozen     /// convert them to a dummy i32 type in order to easily drop them during
426affcfccdSGuray Ozen     /// conversion.
42715bcc36eSThomas Raoux     converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
42815bcc36eSThomas Raoux       return converter.convertType(IntegerType::get(type.getContext(), 32));
42915bcc36eSThomas Raoux     });
43023882226SGuray Ozen     converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
43152db7e27SGuray Ozen       Type elemType = type.getFragmented().getElementType();
43252db7e27SGuray Ozen       int64_t sizeM = type.getFragmented().getDimSize(0);
43352db7e27SGuray Ozen       int64_t sizeN = type.getFragmented().getDimSize(1);
43452db7e27SGuray Ozen 
43552db7e27SGuray Ozen       unsigned numMembers;
43652db7e27SGuray Ozen       if (elemType.isF32() || elemType.isInteger(32))
43752db7e27SGuray Ozen         numMembers = sizeN / 2;
43852db7e27SGuray Ozen       else if (elemType.isF16())
43952db7e27SGuray Ozen         numMembers = sizeN / 4;
44052db7e27SGuray Ozen       else
44152db7e27SGuray Ozen         llvm_unreachable("unsupported type for warpgroup accumulator");
44252db7e27SGuray Ozen 
44352db7e27SGuray Ozen       SmallVector<Type> innerStructBody;
44452db7e27SGuray Ozen       for (unsigned i = 0; i < numMembers; i++)
44552db7e27SGuray Ozen         innerStructBody.push_back(elemType);
44652db7e27SGuray Ozen       auto innerStructType =
44752db7e27SGuray Ozen           LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);
44852db7e27SGuray Ozen 
44923882226SGuray Ozen       SmallVector<Type> structBody;
45052db7e27SGuray Ozen       for (int i = 0; i < sizeM; i += kWgmmaSizeM)
45152db7e27SGuray Ozen         structBody.push_back(innerStructType);
45252db7e27SGuray Ozen 
45323882226SGuray Ozen       auto convertedType =
45423882226SGuray Ozen           LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
45523882226SGuray Ozen       return converter.convertType(convertedType);
45623882226SGuray Ozen     });
457affcfccdSGuray Ozen     converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
458affcfccdSGuray Ozen       return converter.convertType(IntegerType::get(type.getContext(), 64));
459affcfccdSGuray Ozen     });
46050ab427aSGuray Ozen     converter.addConversion(
46150ab427aSGuray Ozen         [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
46250ab427aSGuray Ozen           return converter.convertType(IntegerType::get(type.getContext(), 64));
46350ab427aSGuray Ozen         });
46417649a77SGuray Ozen     converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
46599475f5bSNicolas Vasilache       return converter.convertType(
46699475f5bSNicolas Vasilache           nvgpu::getMBarrierMemrefType(rewriter.getContext(), type));
467affcfccdSGuray Ozen     });
46870c2e061SGuray Ozen     converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type {
4692f17c9f6SChristian Ulmann       return LLVM::LLVMPointerType::get(type.getContext());
47070c2e061SGuray Ozen     });
471894a591cSThomas Raoux     populateNVGPUToNVVMConversionPatterns(converter, patterns);
472894a591cSThomas Raoux     LLVMConversionTarget target(getContext());
473894a591cSThomas Raoux     target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
474d20fbc90SGuray Ozen     target.addLegalDialect<::mlir::arith::ArithDialect>();
475affcfccdSGuray Ozen     target.addLegalDialect<::mlir::memref::MemRefDialect>();
476894a591cSThomas Raoux     target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
47723882226SGuray Ozen     mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
47823882226SGuray Ozen         converter, patterns, target);
479894a591cSThomas Raoux     if (failed(applyPartialConversion(getOperation(), target,
480894a591cSThomas Raoux                                       std::move(patterns))))
481894a591cSThomas Raoux       signalPassFailure();
482894a591cSThomas Raoux   }
483894a591cSThomas Raoux };
484894a591cSThomas Raoux 
/// Builds the inline-asm constraint string for the sparse MMA instruction.
///
/// The string consists of one `=r` output constraint per accumulator register
/// (`matCSize`), followed by one `r` input constraint for each register of A,
/// B and C, and a trailing `r` for the sparsity metadata operand. The sparsity
/// selector is written as a literal into the asm string, so it has no
/// constraint of its own.
static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
                                                     unsigned matBSize,
                                                     unsigned matCSize) {
  const unsigned numInputRegs = matASize + matBSize + matCSize;

  std::string constraints;
  // Outputs: one per result register.
  for (unsigned idx = 0; idx != matCSize; ++idx)
    constraints += "=r,";
  // Inputs: registers of A, B and the incoming accumulator.
  for (unsigned idx = 0; idx != numInputRegs; ++idx)
    constraints += "r,";
  // Last input: the sparsity metadata.
  constraints += "r";
  return constraints;
}
500708185f0SChristopher Bate 
501708185f0SChristopher Bate /// Returns the string for the `mma.sp.sync` instruction that corresponds to
5024e4af133SAart Bik /// the given parameters. Note that this function doesn't do any validation,
503708185f0SChristopher Bate /// it's expected that the provided parameters correspond to a valid
504708185f0SChristopher Bate /// instruction.
5054e4af133SAart Bik static std::string buildMmaSparseAsmString(
5064e4af133SAart Bik     const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
5074e4af133SAart Bik     unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
508708185f0SChristopher Bate     NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
5094e4af133SAart Bik     std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
510708185f0SChristopher Bate   auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
511708185f0SChristopher Bate     return NVVM::stringifyMMATypes(ptxType);
512708185f0SChristopher Bate   };
513708185f0SChristopher Bate 
514708185f0SChristopher Bate   std::string asmStr;
515708185f0SChristopher Bate   llvm::raw_string_ostream ss(asmStr);
516708185f0SChristopher Bate   ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
517708185f0SChristopher Bate      << shape[2] << ".row.col.";
518708185f0SChristopher Bate 
519708185f0SChristopher Bate   if (overflow)
520708185f0SChristopher Bate     ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";
521708185f0SChristopher Bate 
522708185f0SChristopher Bate   ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
523708185f0SChristopher Bate      << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
524708185f0SChristopher Bate   unsigned asmArgIdx = 0;
525708185f0SChristopher Bate 
526708185f0SChristopher Bate   // The operand string is structured into sections `{matC elements...},
527708185f0SChristopher Bate   // {matA elements...}, {matB elements...}, {matC elements}`.
528708185f0SChristopher Bate   for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
529708185f0SChristopher Bate     ss << "{";
530708185f0SChristopher Bate     for (unsigned i = 0; i < arrSize; i++)
531708185f0SChristopher Bate       ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
532708185f0SChristopher Bate     ss << "},";
533708185f0SChristopher Bate   }
534b0bbc9b5Srkayaith   ss << "$" << asmArgIdx++ << ",";
5354e4af133SAart Bik   assert(metaDataSelector <= 1);
5364e4af133SAart Bik   ss << "0x" << metaDataSelector << ";";
537708185f0SChristopher Bate   return asmStr;
538708185f0SChristopher Bate }
539708185f0SChristopher Bate 
540708185f0SChristopher Bate /// Builds an inline assembly operation corresponding to the specified MMA
541708185f0SChristopher Bate /// sparse sync operation.
542708185f0SChristopher Bate static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
543ee49cda7SGuray Ozen     ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
544708185f0SChristopher Bate     NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
54522426110SRamkumar Ramachandra     std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
546708185f0SChristopher Bate     ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
547708185f0SChristopher Bate     int64_t metadataSelector, const std::array<int64_t, 3> &shape,
548ee49cda7SGuray Ozen     Type intrinsicResultType) {
549ee49cda7SGuray Ozen   auto asmDialectAttr =
550ee49cda7SGuray Ozen       LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);
551708185f0SChristopher Bate 
5524e4af133SAart Bik   const unsigned matASize = unpackedAData.size();
5534e4af133SAart Bik   const unsigned matBSize = unpackedB.size();
5544e4af133SAart Bik   const unsigned matCSize = unpackedC.size();
555708185f0SChristopher Bate 
5564e4af133SAart Bik   std::string asmStr = buildMmaSparseAsmString(
5574e4af133SAart Bik       shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
5584e4af133SAart Bik       ptxTypeD, overflow, metadataSelector);
5594e4af133SAart Bik   std::string constraintStr =
5604e4af133SAart Bik       buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);
561708185f0SChristopher Bate 
562708185f0SChristopher Bate   SmallVector<Value> asmVals;
5634e4af133SAart Bik   asmVals.reserve(matASize + matBSize + matCSize + 1);
564708185f0SChristopher Bate   for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC})
565708185f0SChristopher Bate     llvm::append_range(asmVals, args);
566708185f0SChristopher Bate   asmVals.push_back(indexData);
567708185f0SChristopher Bate 
568ee49cda7SGuray Ozen   return b.create<LLVM::InlineAsmOp>(
569708185f0SChristopher Bate       /*resultTypes=*/intrinsicResultType,
570708185f0SChristopher Bate       /*operands=*/asmVals,
571708185f0SChristopher Bate       /*asm_string=*/asmStr,
572708185f0SChristopher Bate       /*constraints=*/constraintStr,
573708185f0SChristopher Bate       /*has_side_effects=*/true,
574708185f0SChristopher Bate       /*is_align_stack=*/false,
575708185f0SChristopher Bate       /*asm_dialect=*/asmDialectAttr,
576708185f0SChristopher Bate       /*operand_attrs=*/ArrayAttr());
577708185f0SChristopher Bate }
578708185f0SChristopher Bate 
579708185f0SChristopher Bate /// Lowers `nvgpu.mma.sp.sync` to inline assembly.
580708185f0SChristopher Bate struct NVGPUMmaSparseSyncLowering
581708185f0SChristopher Bate     : public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> {
582708185f0SChristopher Bate   using ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp>::ConvertOpToLLVMPattern;
583708185f0SChristopher Bate 
584708185f0SChristopher Bate   LogicalResult
585708185f0SChristopher Bate   matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
586708185f0SChristopher Bate                   ConversionPatternRewriter &rewriter) const override {
587ee49cda7SGuray Ozen     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
588708185f0SChristopher Bate     // Get the shapes of the MMAMatrix type being used. The shapes will
589708185f0SChristopher Bate     // choose which intrinsic this op will be lowered to.
590708185f0SChristopher Bate     VectorType aType = op.getMatrixA().getType();
591708185f0SChristopher Bate     VectorType bType = op.getMatrixB().getType();
592708185f0SChristopher Bate     VectorType cType = op.getMatrixC().getType();
593708185f0SChristopher Bate 
594708185f0SChristopher Bate     FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
595708185f0SChristopher Bate     if (failed(ptxTypeA))
596708185f0SChristopher Bate       return op->emitOpError("failed to deduce operand PTX types");
597708185f0SChristopher Bate     FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
598708185f0SChristopher Bate     if (failed(ptxTypeB))
599708185f0SChristopher Bate       return op->emitOpError("failed to deduce operand PTX types");
60022426110SRamkumar Ramachandra     std::optional<NVVM::MMATypes> ptxTypeC =
60122426110SRamkumar Ramachandra         NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
60222426110SRamkumar Ramachandra                                          /*isAccumulator=*/true);
603708185f0SChristopher Bate     if (!ptxTypeC)
604708185f0SChristopher Bate       return op->emitError(
605708185f0SChristopher Bate           "could not infer the PTX type for the accumulator/result");
606708185f0SChristopher Bate 
607708185f0SChristopher Bate     // Same as `mma.sync`, F32 works only with TensorFloat32 (TF32).
608708185f0SChristopher Bate     bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
609708185f0SChristopher Bate     if (aType.getElementType().isF32() && !tf32Enabled)
610708185f0SChristopher Bate       return failure();
611708185f0SChristopher Bate 
612708185f0SChristopher Bate     // TODO: add an attribute to the op to customize this behavior.
61322426110SRamkumar Ramachandra     std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
6145550c821STres Popp     if (isa<IntegerType>(aType.getElementType()))
615708185f0SChristopher Bate       overflow = NVVM::MMAIntOverflow::satfinite;
616708185f0SChristopher Bate 
617708185f0SChristopher Bate     SmallVector<Value> matA =
618ee49cda7SGuray Ozen         unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
619708185f0SChristopher Bate     SmallVector<Value> matB =
620ee49cda7SGuray Ozen         unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
621708185f0SChristopher Bate     SmallVector<Value> matC =
622ee49cda7SGuray Ozen         unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
623708185f0SChristopher Bate 
624708185f0SChristopher Bate     Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
625708185f0SChristopher Bate     Type intrinsicResTy = inferIntrinsicResultType(
626708185f0SChristopher Bate         typeConverter->convertType(op->getResultTypes()[0]));
627708185f0SChristopher Bate 
628708185f0SChristopher Bate     // Bitcast the sparse metadata from vector<2xf16> to an i32.
629708185f0SChristopher Bate     Value sparseMetadata = adaptor.getSparseMetadata();
630708185f0SChristopher Bate     if (sparseMetadata.getType() !=
631708185f0SChristopher Bate         LLVM::getFixedVectorType(rewriter.getI16Type(), 2))
632708185f0SChristopher Bate       return op->emitOpError() << "Expected metadata type to be LLVM "
633708185f0SChristopher Bate                                   "VectorType of 2 i16 elements";
634ee49cda7SGuray Ozen     sparseMetadata =
635ee49cda7SGuray Ozen         b.create<LLVM::BitcastOp>(rewriter.getI32Type(), sparseMetadata);
636708185f0SChristopher Bate 
637708185f0SChristopher Bate     FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
638ee49cda7SGuray Ozen         b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
639708185f0SChristopher Bate         matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
640ee49cda7SGuray Ozen         intrinsicResTy);
641708185f0SChristopher Bate     if (failed(intrinsicResult))
642708185f0SChristopher Bate       return failure();
643708185f0SChristopher Bate 
644708185f0SChristopher Bate     assert((*intrinsicResult).getNumResults() == 1 &&
645708185f0SChristopher Bate            "expected inline asm op returns a single LLVM struct type");
646708185f0SChristopher Bate     rewriter.replaceOp(
647708185f0SChristopher Bate         op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy,
648708185f0SChristopher Bate                                    (*intrinsicResult)->getResult(0), rewriter));
649708185f0SChristopher Bate     return success();
650708185f0SChristopher Bate   }
651708185f0SChristopher Bate };
652708185f0SChristopher Bate 
65315bcc36eSThomas Raoux struct NVGPUAsyncCopyLowering
65415bcc36eSThomas Raoux     : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
65515bcc36eSThomas Raoux   using ConvertOpToLLVMPattern<
65615bcc36eSThomas Raoux       nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;
65715bcc36eSThomas Raoux 
65815bcc36eSThomas Raoux   LogicalResult
65915bcc36eSThomas Raoux   matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
66015bcc36eSThomas Raoux                   ConversionPatternRewriter &rewriter) const override {
661ee49cda7SGuray Ozen     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
662ee49cda7SGuray Ozen     Location loc = op.getLoc();
6635550c821STres Popp     auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
664ee49cda7SGuray Ozen     Value dstPtr =
665ee49cda7SGuray Ozen         getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(),
6668df54a6aSJacques Pienaar                              adaptor.getDstIndices(), rewriter);
667499abb24SKrzysztof Drewniak     FailureOr<unsigned> dstAddressSpace =
668499abb24SKrzysztof Drewniak         getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
669499abb24SKrzysztof Drewniak     if (failed(dstAddressSpace))
670499abb24SKrzysztof Drewniak       return rewriter.notifyMatchFailure(
671499abb24SKrzysztof Drewniak           loc, "destination memref address space not convertible to integer");
67215bcc36eSThomas Raoux 
6735550c821STres Popp     auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
674499abb24SKrzysztof Drewniak     FailureOr<unsigned> srcAddressSpace =
675499abb24SKrzysztof Drewniak         getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
676499abb24SKrzysztof Drewniak     if (failed(srcAddressSpace))
677499abb24SKrzysztof Drewniak       return rewriter.notifyMatchFailure(
678499abb24SKrzysztof Drewniak           loc, "source memref address space not convertible to integer");
67915bcc36eSThomas Raoux 
6808df54a6aSJacques Pienaar     Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
6818df54a6aSJacques Pienaar                                         adaptor.getSrcIndices(), rewriter);
68215bcc36eSThomas Raoux     // Intrinsics takes a global pointer so we need an address space cast.
6832f17c9f6SChristian Ulmann     auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
6842f17c9f6SChristian Ulmann         op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace);
685ee49cda7SGuray Ozen     scrPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, scrPtr);
686fbf69f95SManish Gupta     int64_t dstElements = adaptor.getDstElements().getZExtValue();
68715bcc36eSThomas Raoux     int64_t sizeInBytes =
688fbf69f95SManish Gupta         (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
689fbf69f95SManish Gupta     // When the optional SrcElements argument is *not* present, the regular
690fbf69f95SManish Gupta     // CpAsyncOp is generated. CopyAsyncOp reads bytes from source (global
6912c573967SGuray Ozen     // memory) to fill DstElements number of elements in the destination
6922c573967SGuray Ozen     // (shared memory).
6932c573967SGuray Ozen     Value srcBytes = adaptor.getSrcElements();
6942c573967SGuray Ozen     if (srcBytes) {
6952c573967SGuray Ozen       // When the optional SrcElements argument is present, the source (global
6962c573967SGuray Ozen       // memory) of CpAsyncOp is read only for SrcElements number of elements.
6972c573967SGuray Ozen       // The rest of the DstElements in the destination (shared memory) are
6982c573967SGuray Ozen       // filled with zeros.
699ee49cda7SGuray Ozen       Value c3I32 =
700ee49cda7SGuray Ozen           b.create<LLVM::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(3));
701ee49cda7SGuray Ozen       Value bitwidth = b.create<LLVM::ConstantOp>(
702ee49cda7SGuray Ozen           b.getI32Type(),
703ee49cda7SGuray Ozen           b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
704ee49cda7SGuray Ozen       Value srcElementsI32 = b.create<LLVM::TruncOp>(b.getI32Type(), srcBytes);
705ee49cda7SGuray Ozen       srcBytes = b.create<LLVM::LShrOp>(
706ee49cda7SGuray Ozen           b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32);
7072c573967SGuray Ozen     }
7082c573967SGuray Ozen     // Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
7092c573967SGuray Ozen     // 16 dst bytes.
7102c573967SGuray Ozen     NVVM::LoadCacheModifierKind cacheModifier =
7112c573967SGuray Ozen         (op.getBypassL1().value_or(false) && sizeInBytes == 16)
7122c573967SGuray Ozen             ? NVVM::LoadCacheModifierKind::CG
7132c573967SGuray Ozen             : NVVM::LoadCacheModifierKind::CA;
7142c573967SGuray Ozen 
715ee49cda7SGuray Ozen     b.create<NVVM::CpAsyncOp>(
716ee49cda7SGuray Ozen         dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
7172c573967SGuray Ozen         NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
7182c573967SGuray Ozen         srcBytes);
71915bcc36eSThomas Raoux 
72015bcc36eSThomas Raoux     // Drop the result token.
721ee49cda7SGuray Ozen     Value zero = b.create<LLVM::ConstantOp>(
722ee49cda7SGuray Ozen         IntegerType::get(op.getContext(), 32), rewriter.getI32IntegerAttr(0));
72315bcc36eSThomas Raoux     rewriter.replaceOp(op, zero);
72415bcc36eSThomas Raoux     return success();
72515bcc36eSThomas Raoux   }
72615bcc36eSThomas Raoux };
72715bcc36eSThomas Raoux 
72815bcc36eSThomas Raoux struct NVGPUAsyncCreateGroupLowering
72915bcc36eSThomas Raoux     : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
73015bcc36eSThomas Raoux   using ConvertOpToLLVMPattern<
73115bcc36eSThomas Raoux       nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;
73215bcc36eSThomas Raoux 
73315bcc36eSThomas Raoux   LogicalResult
73415bcc36eSThomas Raoux   matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
73515bcc36eSThomas Raoux                   ConversionPatternRewriter &rewriter) const override {
73615bcc36eSThomas Raoux     rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
73715bcc36eSThomas Raoux     // Drop the result token.
73815bcc36eSThomas Raoux     Value zero = rewriter.create<LLVM::ConstantOp>(
73915bcc36eSThomas Raoux         op->getLoc(), IntegerType::get(op.getContext(), 32),
74015bcc36eSThomas Raoux         rewriter.getI32IntegerAttr(0));
74115bcc36eSThomas Raoux     rewriter.replaceOp(op, zero);
74215bcc36eSThomas Raoux     return success();
74315bcc36eSThomas Raoux   }
74415bcc36eSThomas Raoux };
74515bcc36eSThomas Raoux 
74615bcc36eSThomas Raoux struct NVGPUAsyncWaitLowering
74715bcc36eSThomas Raoux     : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
74815bcc36eSThomas Raoux   using ConvertOpToLLVMPattern<
74915bcc36eSThomas Raoux       nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;
75015bcc36eSThomas Raoux 
75115bcc36eSThomas Raoux   LogicalResult
75215bcc36eSThomas Raoux   matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
75315bcc36eSThomas Raoux                   ConversionPatternRewriter &rewriter) const override {
75415bcc36eSThomas Raoux     // If numGroup is not present pick 0 as a conservative correct value.
7552789c4f5SKazu Hirata     int32_t numGroups = adaptor.getNumGroups().value_or(0);
75615bcc36eSThomas Raoux     rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
75715bcc36eSThomas Raoux     rewriter.eraseOp(op);
75815bcc36eSThomas Raoux     return success();
75915bcc36eSThomas Raoux   }
76015bcc36eSThomas Raoux };
76115bcc36eSThomas Raoux 
762affcfccdSGuray Ozen /// Creates mbarrier object in shared memory
763affcfccdSGuray Ozen struct NVGPUMBarrierCreateLowering
764affcfccdSGuray Ozen     : public ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp> {
765affcfccdSGuray Ozen   using ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp>::ConvertOpToLLVMPattern;
766affcfccdSGuray Ozen 
767affcfccdSGuray Ozen   template <typename moduleT>
768affcfccdSGuray Ozen   memref::GlobalOp generateGlobalBarrier(ConversionPatternRewriter &rewriter,
769affcfccdSGuray Ozen                                          Operation *funcOp, moduleT moduleOp,
770affcfccdSGuray Ozen                                          MemRefType barrierType) const {
771affcfccdSGuray Ozen     SymbolTable symbolTable(moduleOp);
772affcfccdSGuray Ozen     OpBuilder::InsertionGuard guard(rewriter);
773affcfccdSGuray Ozen     rewriter.setInsertionPoint(&moduleOp.front());
774affcfccdSGuray Ozen     auto global = rewriter.create<memref::GlobalOp>(
775affcfccdSGuray Ozen         funcOp->getLoc(), "__mbarrier",
776affcfccdSGuray Ozen         /*sym_visibility=*/rewriter.getStringAttr("private"),
777affcfccdSGuray Ozen         /*type=*/barrierType,
778affcfccdSGuray Ozen         /*initial_value=*/ElementsAttr(),
779affcfccdSGuray Ozen         /*constant=*/false,
780affcfccdSGuray Ozen         /*alignment=*/rewriter.getI64IntegerAttr(8));
781affcfccdSGuray Ozen     symbolTable.insert(global);
782affcfccdSGuray Ozen     return global;
783affcfccdSGuray Ozen   }
784affcfccdSGuray Ozen 
785affcfccdSGuray Ozen   LogicalResult
786affcfccdSGuray Ozen   matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
787affcfccdSGuray Ozen                   ConversionPatternRewriter &rewriter) const override {
788affcfccdSGuray Ozen     Operation *funcOp = op->getParentOp();
78999475f5bSNicolas Vasilache     MemRefType barrierType = nvgpu::getMBarrierMemrefType(
79017649a77SGuray Ozen         rewriter.getContext(), op.getBarriers().getType());
791affcfccdSGuray Ozen 
792affcfccdSGuray Ozen     memref::GlobalOp global;
7939dad32cbSGuray Ozen     if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>())
794affcfccdSGuray Ozen       global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
7959dad32cbSGuray Ozen     else if (auto moduleOp = funcOp->getParentOfType<ModuleOp>())
796affcfccdSGuray Ozen       global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
797affcfccdSGuray Ozen 
798affcfccdSGuray Ozen     rewriter.setInsertionPoint(op);
799affcfccdSGuray Ozen     rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, barrierType,
800affcfccdSGuray Ozen                                                      global.getName());
801affcfccdSGuray Ozen     return success();
802affcfccdSGuray Ozen   }
803affcfccdSGuray Ozen };
804affcfccdSGuray Ozen 
80517649a77SGuray Ozen /// Base class for lowering mbarrier operations to nvvm intrinsics.
80617649a77SGuray Ozen template <typename SourceOp>
80717649a77SGuray Ozen struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
80817649a77SGuray Ozen public:
80917649a77SGuray Ozen   using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
81017649a77SGuray Ozen   /// Returns the base pointer of the mbarrier object.
811ee49cda7SGuray Ozen   Value getMbarrierPtr(ImplicitLocOpBuilder &b,
812ee49cda7SGuray Ozen                        nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
813ee49cda7SGuray Ozen                        Value mbarId,
81417649a77SGuray Ozen                        ConversionPatternRewriter &rewriter) const {
81517649a77SGuray Ozen     MemRefType mbarrierMemrefType =
81617649a77SGuray Ozen         nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
81717649a77SGuray Ozen     return ConvertToLLVMPattern::getStridedElementPtr(
818ee49cda7SGuray Ozen         b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
81917649a77SGuray Ozen   }
82017649a77SGuray Ozen };
82117649a77SGuray Ozen 
822affcfccdSGuray Ozen /// Lowers `nvgpu.mbarrier.init` to `nvvm.mbarrier.init`
823affcfccdSGuray Ozen struct NVGPUMBarrierInitLowering
82417649a77SGuray Ozen     : public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
82517649a77SGuray Ozen   using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;
826affcfccdSGuray Ozen 
827affcfccdSGuray Ozen   LogicalResult
828affcfccdSGuray Ozen   matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
829affcfccdSGuray Ozen                   ConversionPatternRewriter &rewriter) const override {
830ee49cda7SGuray Ozen     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
83117649a77SGuray Ozen     nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
832affcfccdSGuray Ozen     rewriter.setInsertionPoint(op);
833ee49cda7SGuray Ozen     Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
83417649a77SGuray Ozen                                    adaptor.getMbarId(), rewriter);
835ee49cda7SGuray Ozen     Value count = truncToI32(b, adaptor.getCount());
83617649a77SGuray Ozen     if (isMbarrierShared(mbarrierType)) {
837192d3320SGuray Ozen       rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(
838192d3320SGuray Ozen           op, barrier, count, adaptor.getPredicate());
839affcfccdSGuray Ozen     } else {
84063389326SGuray Ozen       rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
841192d3320SGuray Ozen                                                         adaptor.getPredicate());
842affcfccdSGuray Ozen     }
843affcfccdSGuray Ozen     return success();
844affcfccdSGuray Ozen   }
845affcfccdSGuray Ozen };
846affcfccdSGuray Ozen 
847affcfccdSGuray Ozen /// Lowers `nvgpu.mbarrier.arrive` to `nvvm.mbarrier.arrive`
848affcfccdSGuray Ozen struct NVGPUMBarrierArriveLowering
84917649a77SGuray Ozen     : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
85017649a77SGuray Ozen   using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;
851affcfccdSGuray Ozen   LogicalResult
852affcfccdSGuray Ozen   matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
853affcfccdSGuray Ozen                   ConversionPatternRewriter &rewriter) const override {
854ee49cda7SGuray Ozen     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
85517649a77SGuray Ozen     Value barrier =
856ee49cda7SGuray Ozen         getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
85717649a77SGuray Ozen                        adaptor.getMbarId(), rewriter);
858affcfccdSGuray Ozen     Type tokenType = getTypeConverter()->convertType(
859affcfccdSGuray Ozen         nvgpu::MBarrierTokenType::get(op->getContext()));
86017649a77SGuray Ozen     if (isMbarrierShared(op.getBarriers().getType())) {
861affcfccdSGuray Ozen       rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveSharedOp>(op, tokenType,
862affcfccdSGuray Ozen                                                                 barrier);
863affcfccdSGuray Ozen     } else {
864affcfccdSGuray Ozen       rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType,
865affcfccdSGuray Ozen                                                           barrier);
866affcfccdSGuray Ozen     }
867affcfccdSGuray Ozen     return success();
868affcfccdSGuray Ozen   }
869affcfccdSGuray Ozen };
870affcfccdSGuray Ozen 
871affcfccdSGuray Ozen /// Lowers `nvgpu.mbarrier.arrive.nocomplete` to
872affcfccdSGuray Ozen /// `nvvm.mbarrier.arrive.nocomplete`
873affcfccdSGuray Ozen struct NVGPUMBarrierArriveNoCompleteLowering
87417649a77SGuray Ozen     : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
87517649a77SGuray Ozen   using MBarrierBasePattern<
87617649a77SGuray Ozen       nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;
877affcfccdSGuray Ozen   LogicalResult
878affcfccdSGuray Ozen   matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
879affcfccdSGuray Ozen                   ConversionPatternRewriter &rewriter) const override {
880ee49cda7SGuray Ozen     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
88117649a77SGuray Ozen     Value barrier =
882ee49cda7SGuray Ozen         getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
88317649a77SGuray Ozen                        adaptor.getMbarId(), rewriter);
884affcfccdSGuray Ozen     Type tokenType = getTypeConverter()->convertType(
885affcfccdSGuray Ozen         nvgpu::MBarrierTokenType::get(op->getContext()));
886ee49cda7SGuray Ozen     Value count = truncToI32(b, adaptor.getCount());
88717649a77SGuray Ozen     if (isMbarrierShared(op.getBarriers().getType())) {
888affcfccdSGuray Ozen       rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>(
889affcfccdSGuray Ozen           op, tokenType, barrier, count);
890affcfccdSGuray Ozen     } else {
891affcfccdSGuray Ozen       rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>(
892affcfccdSGuray Ozen           op, tokenType, barrier, count);
893affcfccdSGuray Ozen     }
894affcfccdSGuray Ozen     return success();
895affcfccdSGuray Ozen   }
896affcfccdSGuray Ozen };
897affcfccdSGuray Ozen 
898affcfccdSGuray Ozen /// Lowers `nvgpu.mbarrier.test.wait` to `nvvm.mbarrier.test.wait`
899affcfccdSGuray Ozen struct NVGPUMBarrierTestWaitLowering
90017649a77SGuray Ozen     : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
90117649a77SGuray Ozen   using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;
902affcfccdSGuray Ozen   LogicalResult
903affcfccdSGuray Ozen   matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
904affcfccdSGuray Ozen                   ConversionPatternRewriter &rewriter) const override {
905ee49cda7SGuray Ozen     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
90617649a77SGuray Ozen     Value barrier =
907ee49cda7SGuray Ozen         getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
90817649a77SGuray Ozen                        adaptor.getMbarId(), rewriter);
909affcfccdSGuray Ozen     Type retType = rewriter.getI1Type();
91017649a77SGuray Ozen     if (isMbarrierShared(op.getBarriers().getType())) {
911affcfccdSGuray Ozen       rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitSharedOp>(
912affcfccdSGuray Ozen           op, retType, barrier, adaptor.getToken());
913affcfccdSGuray Ozen     } else {
914affcfccdSGuray Ozen       rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(
915affcfccdSGuray Ozen           op, retType, barrier, adaptor.getToken());
916affcfccdSGuray Ozen     }
917affcfccdSGuray Ozen     return success();
918affcfccdSGuray Ozen   }
919affcfccdSGuray Ozen };
920affcfccdSGuray Ozen 
921836dbb85SGuray Ozen struct NVGPUMBarrierArriveExpectTxLowering
92217649a77SGuray Ozen     : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
92317649a77SGuray Ozen   using MBarrierBasePattern<
92417649a77SGuray Ozen       nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;
925836dbb85SGuray Ozen   LogicalResult
926836dbb85SGuray Ozen   matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
927836dbb85SGuray Ozen                   ConversionPatternRewriter &rewriter) const override {
928ee49cda7SGuray Ozen     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
92917649a77SGuray Ozen     Value barrier =
930ee49cda7SGuray Ozen         getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
93117649a77SGuray Ozen                        adaptor.getMbarId(), rewriter);
932ee49cda7SGuray Ozen     Value txcount = truncToI32(b, adaptor.getTxcount());
933836dbb85SGuray Ozen 
93417649a77SGuray Ozen     if (isMbarrierShared(op.getBarriers().getType())) {
935836dbb85SGuray Ozen       rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
936192d3320SGuray Ozen           op, barrier, txcount, adaptor.getPredicate());
937836dbb85SGuray Ozen       return success();
938836dbb85SGuray Ozen     }
939836dbb85SGuray Ozen 
94063389326SGuray Ozen     rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
941192d3320SGuray Ozen         op, barrier, txcount, adaptor.getPredicate());
942836dbb85SGuray Ozen     return success();
943836dbb85SGuray Ozen   }
944836dbb85SGuray Ozen };
945836dbb85SGuray Ozen 
946836dbb85SGuray Ozen struct NVGPUMBarrierTryWaitParityLowering
94717649a77SGuray Ozen     : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> {
94817649a77SGuray Ozen   using MBarrierBasePattern<
94917649a77SGuray Ozen       nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern;
950836dbb85SGuray Ozen   LogicalResult
951836dbb85SGuray Ozen   matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
952836dbb85SGuray Ozen                   ConversionPatternRewriter &rewriter) const override {
953ee49cda7SGuray Ozen     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
95417649a77SGuray Ozen     Value barrier =
955ee49cda7SGuray Ozen         getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
95617649a77SGuray Ozen                        adaptor.getMbarId(), rewriter);
957ee49cda7SGuray Ozen     Value ticks = truncToI32(b, adaptor.getTicks());
9580a600c34SGuray Ozen     Value phase =
9590a600c34SGuray Ozen         b.create<LLVM::ZExtOp>(b.getI32Type(), adaptor.getPhaseParity());
960836dbb85SGuray Ozen 
96117649a77SGuray Ozen     if (isMbarrierShared(op.getBarriers().getType())) {
962836dbb85SGuray Ozen       rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
963836dbb85SGuray Ozen           op, barrier, phase, ticks);
964836dbb85SGuray Ozen       return success();
965836dbb85SGuray Ozen     }
966836dbb85SGuray Ozen 
967836dbb85SGuray Ozen     rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
968836dbb85SGuray Ozen                                                                phase, ticks);
969836dbb85SGuray Ozen     return success();
970836dbb85SGuray Ozen   }
971836dbb85SGuray Ozen };
972836dbb85SGuray Ozen 
97370c2e061SGuray Ozen struct NVGPUTmaAsyncLoadOpLowering
97417649a77SGuray Ozen     : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> {
97517649a77SGuray Ozen   using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern;
97670c2e061SGuray Ozen   LogicalResult
97770c2e061SGuray Ozen   matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
97870c2e061SGuray Ozen                   ConversionPatternRewriter &rewriter) const override {
979ee49cda7SGuray Ozen     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
98050a76a7dSGuray Ozen     auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
98150a76a7dSGuray Ozen     Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
98250a76a7dSGuray Ozen                                       adaptor.getDst(), {}, rewriter);
98317649a77SGuray Ozen     Value barrier =
984ee49cda7SGuray Ozen         getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
98517649a77SGuray Ozen                        adaptor.getMbarId(), rewriter);
98670c2e061SGuray Ozen 
98770c2e061SGuray Ozen     SmallVector<Value> coords = adaptor.getCoordinates();
98870c2e061SGuray Ozen     for (auto [index, value] : llvm::enumerate(coords)) {
989ee49cda7SGuray Ozen       coords[index] = truncToI32(b, value);
99070c2e061SGuray Ozen     }
99170c2e061SGuray Ozen     rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
9929ceea088SGuray Ozen         op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
9934319e191SGuray Ozen         ValueRange{}, adaptor.getMulticastMask(), Value{},
9944319e191SGuray Ozen         adaptor.getPredicate());
99570c2e061SGuray Ozen     return success();
99670c2e061SGuray Ozen   }
99770c2e061SGuray Ozen };
9988dd0d95cSGuray Ozen 
9998dd0d95cSGuray Ozen struct NVGPUTmaAsyncStoreOpLowering
10008dd0d95cSGuray Ozen     : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
10018dd0d95cSGuray Ozen   using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
10028dd0d95cSGuray Ozen   LogicalResult
10038dd0d95cSGuray Ozen   matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
10048dd0d95cSGuray Ozen                   ConversionPatternRewriter &rewriter) const override {
10058dd0d95cSGuray Ozen     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
10068dd0d95cSGuray Ozen     auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
10078dd0d95cSGuray Ozen     Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
10088dd0d95cSGuray Ozen                                       adaptor.getSrc(), {}, rewriter);
10098dd0d95cSGuray Ozen     SmallVector<Value> coords = adaptor.getCoordinates();
10108dd0d95cSGuray Ozen     for (auto [index, value] : llvm::enumerate(coords)) {
10118dd0d95cSGuray Ozen       coords[index] = truncToI32(b, value);
10128dd0d95cSGuray Ozen     }
10138dd0d95cSGuray Ozen 
10148dd0d95cSGuray Ozen     rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
10158dd0d95cSGuray Ozen         op, adaptor.getTensorMapDescriptor(), dest, coords,
10168dd0d95cSGuray Ozen         adaptor.getPredicate());
10178dd0d95cSGuray Ozen     return success();
10188dd0d95cSGuray Ozen   }
10198dd0d95cSGuray Ozen };
10208dd0d95cSGuray Ozen 
10216dc7717bSGuray Ozen struct NVGPUGenerateWarpgroupDescriptorLowering
10227eb2b99fSGuray Ozen     : public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> {
1023cce3e8edSGuray Ozen   using ConvertOpToLLVMPattern<
10247eb2b99fSGuray Ozen       nvgpu::WarpgroupGenerateDescriptorOp>::ConvertOpToLLVMPattern;
1025cce3e8edSGuray Ozen 
1026cce3e8edSGuray Ozen   LogicalResult
10277eb2b99fSGuray Ozen   matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor,
1028cce3e8edSGuray Ozen                   ConversionPatternRewriter &rewriter) const override {
1029cce3e8edSGuray Ozen 
1030ee49cda7SGuray Ozen     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1031cce3e8edSGuray Ozen 
1032cce3e8edSGuray Ozen     nvgpu::TensorMapSwizzleKind swizzleKind =
1033cce3e8edSGuray Ozen         op.getTensorMap().getType().getSwizzle();
1034cce3e8edSGuray Ozen 
1035cce3e8edSGuray Ozen     unsigned layout =
1036cce3e8edSGuray Ozen         (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B)  ? 128
1037cce3e8edSGuray Ozen         : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
1038cce3e8edSGuray Ozen         : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
1039cce3e8edSGuray Ozen                                                                     : 1;
1040cce3e8edSGuray Ozen     unsigned swizzle =
1041cce3e8edSGuray Ozen         (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B)  ? 1
1042cce3e8edSGuray Ozen         : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
1043cce3e8edSGuray Ozen         : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
1044cce3e8edSGuray Ozen                                                                     : 0;
1045cce3e8edSGuray Ozen 
1046ee49cda7SGuray Ozen     auto ti64 = b.getIntegerType(64);
1047cce3e8edSGuray Ozen     auto makeConst = [&](uint64_t index) -> Value {
1048ee49cda7SGuray Ozen       return b.create<LLVM::ConstantOp>(ti64, b.getI64IntegerAttr(index));
1049cce3e8edSGuray Ozen     };
1050cce3e8edSGuray Ozen     auto shiftLeft = [&](Value value, unsigned shift) -> Value {
1051ee49cda7SGuray Ozen       return b.create<LLVM::ShlOp>(ti64, value, makeConst(shift));
1052cce3e8edSGuray Ozen     };
1053cce3e8edSGuray Ozen     auto shiftRight = [&](Value value, unsigned shift) -> Value {
1054ee49cda7SGuray Ozen       return b.create<LLVM::LShrOp>(ti64, value, makeConst(shift));
1055cce3e8edSGuray Ozen     };
1056cce3e8edSGuray Ozen     auto insertBit = [&](Value desc, Value val, int startBit) {
1057ee49cda7SGuray Ozen       return b.create<LLVM::OrOp>(ti64, desc, shiftLeft(val, startBit));
1058cce3e8edSGuray Ozen     };
1059cce3e8edSGuray Ozen 
1060cce3e8edSGuray Ozen     int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
106123882226SGuray Ozen     uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
106223882226SGuray Ozen     uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
1063b96d0693SGuray Ozen     uint64_t offsetVal = 0;
1064b96d0693SGuray Ozen 
1065b96d0693SGuray Ozen     Value strideDim = makeConst(strideDimVal);
1066b96d0693SGuray Ozen     Value leadDim = makeConst(leadDimVal);
1067b96d0693SGuray Ozen 
1068cce3e8edSGuray Ozen     Value baseAddr = getStridedElementPtr(
1069cce3e8edSGuray Ozen         op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
1070cce3e8edSGuray Ozen         adaptor.getTensor(), {}, rewriter);
1071ee49cda7SGuray Ozen     Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr);
1072cce3e8edSGuray Ozen     // Just use 14 bits for base address
1073cce3e8edSGuray Ozen     Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
1074cce3e8edSGuray Ozen 
1075cce3e8edSGuray Ozen     int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
1076cce3e8edSGuray Ozen         startLeadBit = 16, startBaseAddrBit = 0;
1077cce3e8edSGuray Ozen     Value dsc = makeConst(0);
1078cce3e8edSGuray Ozen     // // [62,64)  swizzle type
1079cce3e8edSGuray Ozen     dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
1080cce3e8edSGuray Ozen     // // [49,52)  base_offset
1081b96d0693SGuray Ozen     dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
1082cce3e8edSGuray Ozen     // // [32,46)  stride
1083cce3e8edSGuray Ozen     dsc = insertBit(dsc, strideDim, startStrideBit);
1084cce3e8edSGuray Ozen     // // [16,30)  leading dimension
1085cce3e8edSGuray Ozen     dsc = insertBit(dsc, leadDim, startLeadBit);
1086cce3e8edSGuray Ozen     // // [0,14)   start_address
1087cce3e8edSGuray Ozen     dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
1088cce3e8edSGuray Ozen 
10896dc7717bSGuray Ozen     LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
1090b96d0693SGuray Ozen                       << "leading_off:" << leadDimVal << "\t"
1091b96d0693SGuray Ozen                       << "stride_off :" << strideDimVal << "\t"
1092b96d0693SGuray Ozen                       << "base_offset:" << offsetVal << "\t"
1093b96d0693SGuray Ozen                       << "layout_type:" << swizzle << " ("
1094b96d0693SGuray Ozen                       << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
1095b96d0693SGuray Ozen                       << ")\n start_addr :  " << baseAddr << "\n");
1096b96d0693SGuray Ozen 
1097cce3e8edSGuray Ozen     rewriter.replaceOp(op, dsc);
1098cce3e8edSGuray Ozen     return success();
1099cce3e8edSGuray Ozen   }
1100cce3e8edSGuray Ozen };
1101e56d6745SGuray Ozen 
1102ee49cda7SGuray Ozen static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) {
1103ee49cda7SGuray Ozen   return b.create<LLVM::ConstantOp>(b.getIntegerType(64),
1104ee49cda7SGuray Ozen                                     b.getI32IntegerAttr(index));
1105e56d6745SGuray Ozen }
1106e56d6745SGuray Ozen 
1107e56d6745SGuray Ozen /// Returns a Value that holds data type enum that is expected by CUDA driver.
1108ee49cda7SGuray Ozen static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) {
1109e56d6745SGuray Ozen   // Enum is from CUDA driver API
1110e56d6745SGuray Ozen   // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
1111e56d6745SGuray Ozen   enum CUtensorMapDataTypeEnum {
1112e56d6745SGuray Ozen     CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
1113e56d6745SGuray Ozen     CU_TENSOR_MAP_DATA_TYPE_UINT16,
1114e56d6745SGuray Ozen     CU_TENSOR_MAP_DATA_TYPE_UINT32,
1115e56d6745SGuray Ozen     CU_TENSOR_MAP_DATA_TYPE_INT32,
1116e56d6745SGuray Ozen     CU_TENSOR_MAP_DATA_TYPE_UINT64,
1117e56d6745SGuray Ozen     CU_TENSOR_MAP_DATA_TYPE_INT64,
1118e56d6745SGuray Ozen     CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
1119e56d6745SGuray Ozen     CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
1120e56d6745SGuray Ozen     CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
1121e56d6745SGuray Ozen     CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
1122e56d6745SGuray Ozen     CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
1123e56d6745SGuray Ozen     CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
1124e56d6745SGuray Ozen     CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
1125e56d6745SGuray Ozen   };
1126e56d6745SGuray Ozen 
1127e56d6745SGuray Ozen   if (type.isUnsignedInteger(8))
1128ee49cda7SGuray Ozen     return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
1129e56d6745SGuray Ozen   if (type.isUnsignedInteger(16))
1130ee49cda7SGuray Ozen     return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
1131e56d6745SGuray Ozen   if (type.isUnsignedInteger(32))
1132ee49cda7SGuray Ozen     return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
1133e56d6745SGuray Ozen   if (type.isUnsignedInteger(64))
1134ee49cda7SGuray Ozen     return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
1135e56d6745SGuray Ozen   if (type.isSignlessInteger(32))
1136ee49cda7SGuray Ozen     return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
1137e56d6745SGuray Ozen   if (type.isSignlessInteger(64))
1138ee49cda7SGuray Ozen     return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
1139e56d6745SGuray Ozen   if (type.isF16())
1140ee49cda7SGuray Ozen     return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
1141e56d6745SGuray Ozen   if (type.isF32())
1142ee49cda7SGuray Ozen     return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
1143e56d6745SGuray Ozen   if (type.isF64())
1144ee49cda7SGuray Ozen     return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
1145e56d6745SGuray Ozen   if (type.isBF16())
1146ee49cda7SGuray Ozen     return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
1147e56d6745SGuray Ozen 
1148e56d6745SGuray Ozen   llvm_unreachable("Not supported data type");
1149e56d6745SGuray Ozen }
1150e56d6745SGuray Ozen 
1151e56d6745SGuray Ozen struct NVGPUTmaCreateDescriptorOpLowering
1152e56d6745SGuray Ozen     : public ConvertOpToLLVMPattern<nvgpu::TmaCreateDescriptorOp> {
1153e56d6745SGuray Ozen   using ConvertOpToLLVMPattern<
1154e56d6745SGuray Ozen       nvgpu::TmaCreateDescriptorOp>::ConvertOpToLLVMPattern;
1155e56d6745SGuray Ozen   LogicalResult
1156e56d6745SGuray Ozen   matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
1157e56d6745SGuray Ozen                   ConversionPatternRewriter &rewriter) const override {
1158ee49cda7SGuray Ozen     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
11592f17c9f6SChristian Ulmann     auto llvmPointerType = LLVM::LLVMPointerType::get(op->getContext());
1160e56d6745SGuray Ozen     Type llvmInt64Type = IntegerType::get(op->getContext(), 64);
1161e56d6745SGuray Ozen 
1162ee49cda7SGuray Ozen     Value tensorElementType =
1163ee49cda7SGuray Ozen         elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
1164e56d6745SGuray Ozen     auto promotedOperands = getTypeConverter()->promoteOperands(
1165ee49cda7SGuray Ozen         b.getLoc(), op->getOperands(), adaptor.getOperands(), b);
1166e56d6745SGuray Ozen 
1167ee49cda7SGuray Ozen     Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type,
1168ee49cda7SGuray Ozen                                                  makeI64Const(b, 5));
1169e56d6745SGuray Ozen     for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
1170ee49cda7SGuray Ozen       Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType,
1171ee49cda7SGuray Ozen                                         boxArrayPtr, makeI64Const(b, index));
1172ee49cda7SGuray Ozen       b.create<LLVM::StoreOp>(value, gep);
1173e56d6745SGuray Ozen     }
1174e56d6745SGuray Ozen 
1175e56d6745SGuray Ozen     nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
1176e56d6745SGuray Ozen     // Set Arguments for the function call
1177e56d6745SGuray Ozen     SmallVector<Value> arguments;
1178e56d6745SGuray Ozen     arguments.push_back(promotedOperands[0]); // rank
1179e56d6745SGuray Ozen     arguments.push_back(promotedOperands[1]); // descriptor
1180e56d6745SGuray Ozen     arguments.push_back(tensorElementType);   // data type
1181e56d6745SGuray Ozen     arguments.push_back(
1182ee49cda7SGuray Ozen         makeI64Const(b, (int)desc.getInterleave()));              // interleave
1183ee49cda7SGuray Ozen     arguments.push_back(makeI64Const(b, (int)desc.getSwizzle())); // swizzle
1184ee49cda7SGuray Ozen     arguments.push_back(makeI64Const(b, (int)desc.getL2promo())); // l2promo
1185ee49cda7SGuray Ozen     arguments.push_back(makeI64Const(b, (int)desc.getOob()));     // oob
1186e56d6745SGuray Ozen     arguments.push_back(boxArrayPtr); // box dimensions
1187e56d6745SGuray Ozen 
1188e56d6745SGuray Ozen     // Set data types of the arguments
1189e56d6745SGuray Ozen     SmallVector<Type> argTypes = {
1190e56d6745SGuray Ozen         llvmInt64Type,   /* int64_t tensorRank */
1191e56d6745SGuray Ozen         llvmPointerType, /* ptr */
1192e56d6745SGuray Ozen         llvmInt64Type,   /* int64_t */
1193e56d6745SGuray Ozen         llvmInt64Type,   /* int64_t */
1194e56d6745SGuray Ozen         llvmInt64Type,   /* int64_t */
1195e56d6745SGuray Ozen         llvmInt64Type,   /* int64_t */
1196e56d6745SGuray Ozen         llvmInt64Type,   /* int64_t */
1197e56d6745SGuray Ozen         llvmPointerType  /* ptr  */
1198e56d6745SGuray Ozen     };
1199e56d6745SGuray Ozen     FunctionCallBuilder hostRegisterCallBuilder = {
1200e56d6745SGuray Ozen         "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
1201e56d6745SGuray Ozen     Value tensorMap =
1202ee49cda7SGuray Ozen         hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();
1203e56d6745SGuray Ozen 
1204e56d6745SGuray Ozen     rewriter.replaceOp(op, tensorMap);
1205e56d6745SGuray Ozen     return success();
1206e56d6745SGuray Ozen   }
1207e56d6745SGuray Ozen };
1208e56d6745SGuray Ozen 
120923882226SGuray Ozen struct NVGPUWarpgroupMmaOpLowering
121023882226SGuray Ozen     : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
121123882226SGuray Ozen   using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
121223882226SGuray Ozen 
1213b74cfc13SGuray Ozen   /// This is a helper class to generate required NVVM Ops for warp-group level
1214b74cfc13SGuray Ozen   /// matrix multiplication.
1215b74cfc13SGuray Ozen   /// When the given GEMM shape is larger than the shape of
  /// a wgmma instruction in PTX, it can generate multiple NVVM::WgmmaMmaAsyncOp
1217b74cfc13SGuray Ozen   /// Op(s), group and execute them asynchronously. The class also handles
1218b74cfc13SGuray Ozen   /// waiting for completion and iterates through WarpgroupMatrixDescriptor to
1219b74cfc13SGuray Ozen   /// create descriptors for each instruction.
1220b74cfc13SGuray Ozen   ///
1221b74cfc13SGuray Ozen   /// For example this is the case when the shape of GEMM is 128x128x128
1222b74cfc13SGuray Ozen   ///
1223b74cfc13SGuray Ozen   ///    nvvm.wgmma.fence.aligned
1224b74cfc13SGuray Ozen   ///
1225b74cfc13SGuray Ozen   ///    nvvm.wgmma.mma.async descA, descB
1226b74cfc13SGuray Ozen   ///    iterate(descA, descB)
1227b74cfc13SGuray Ozen   ///    nvvm.wgmma.mma.async descA, descB
1228b74cfc13SGuray Ozen   ///    [6x times more]
1229b74cfc13SGuray Ozen   ///
1230b74cfc13SGuray Ozen   ///    nvvm.wgmma.group.sync.aligned
1231b74cfc13SGuray Ozen   ///    nvvm.wgmma.wait.group.sync [groupId]
1232b74cfc13SGuray Ozen   ///
1233b74cfc13SGuray Ozen   class WarpgroupGemm {
1234b74cfc13SGuray Ozen     nvgpu::WarpgroupMmaOp op;
1235b74cfc13SGuray Ozen     ImplicitLocOpBuilder b;
1236b74cfc13SGuray Ozen     OpAdaptor adaptor;
1237b74cfc13SGuray Ozen 
1238b74cfc13SGuray Ozen     // Entire shape of the given Op
1239b74cfc13SGuray Ozen     int64_t totalM, totalN, totalK;
1240b74cfc13SGuray Ozen 
1241b74cfc13SGuray Ozen     // Shape of one wgmma instruction
1242b74cfc13SGuray Ozen     int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
1243b74cfc13SGuray Ozen 
1244b74cfc13SGuray Ozen     // Iteration counts for GEMM
1245b74cfc13SGuray Ozen     int iterationM = 0, iterationN = 0, iterationK = 0;
1246b74cfc13SGuray Ozen 
1247b74cfc13SGuray Ozen     /// The function returns the shape of wgmma instruction that is defined in
1248b74cfc13SGuray Ozen     /// PTX programming guide.
1249b74cfc13SGuray Ozen     /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
1250b74cfc13SGuray Ozen     void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
1251b74cfc13SGuray Ozen       wgmmaM = 64;
1252b74cfc13SGuray Ozen       wgmmaN = sizeN;
125323882226SGuray Ozen       if (inputElemType.isTF32()) {
1254b74cfc13SGuray Ozen         wgmmaK = 8;
125523882226SGuray Ozen       } else if (inputElemType.isF16() || inputElemType.isBF16()) {
1256b74cfc13SGuray Ozen         wgmmaK = 16;
1257*7a77f14cSMatthias Springer       } else if (isa<Float8E4M3FNType, Float8E5M2Type>(inputElemType) ||
1258*7a77f14cSMatthias Springer                  inputElemType.isInteger(16)) {
1259b74cfc13SGuray Ozen         wgmmaK = 32;
126023882226SGuray Ozen       } else if (inputElemType.isInteger(1)) {
1261b74cfc13SGuray Ozen         wgmmaK = 256;
126223882226SGuray Ozen       } else {
126323882226SGuray Ozen         llvm_unreachable("msg: not supported K shape");
126423882226SGuray Ozen       }
1265b74cfc13SGuray Ozen       LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
1266b74cfc13SGuray Ozen                         << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
126723882226SGuray Ozen     }
126823882226SGuray Ozen 
1269b74cfc13SGuray Ozen     /// Generates WGMMATypesAttr from MLIR Type
127012c241b3SGuray Ozen     NVVM::WGMMATypesAttr generateWgmmaType(Type type,
127112c241b3SGuray Ozen                                            bool useF32 = false) const {
127212c241b3SGuray Ozen       auto getWgmmaType = [=](Type elemType) {
1273b74cfc13SGuray Ozen         if (elemType.isF32() || elemType.isTF32())
127412c241b3SGuray Ozen           return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
1275b74cfc13SGuray Ozen         if (elemType.isF16())
1276b74cfc13SGuray Ozen           return NVVM::WGMMATypes::f16;
1277b74cfc13SGuray Ozen         if (elemType.isBF16())
1278b74cfc13SGuray Ozen           return NVVM::WGMMATypes::bf16;
1279*7a77f14cSMatthias Springer         if (isa<Float8E4M3FNType>(elemType))
1280b74cfc13SGuray Ozen           return NVVM::WGMMATypes::e4m3;
1281*7a77f14cSMatthias Springer         if (isa<Float8E5M2Type>(elemType))
1282b74cfc13SGuray Ozen           return NVVM::WGMMATypes::e5m2;
1283b74cfc13SGuray Ozen         if (elemType.isInteger(1))
1284b74cfc13SGuray Ozen           return NVVM::WGMMATypes::b1;
1285b74cfc13SGuray Ozen         if (elemType.isInteger(8))
1286b74cfc13SGuray Ozen           return NVVM::WGMMATypes::s8;
1287b74cfc13SGuray Ozen         if (elemType.isUnsignedInteger(8))
1288b74cfc13SGuray Ozen           return NVVM::WGMMATypes::u8;
128912c241b3SGuray Ozen         if (elemType.isInteger(32))
129012c241b3SGuray Ozen           return NVVM::WGMMATypes::s32;
1291b74cfc13SGuray Ozen         llvm_unreachable("unsupported type");
1292b74cfc13SGuray Ozen       };
1293b74cfc13SGuray Ozen       return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
129423882226SGuray Ozen     }
129523882226SGuray Ozen 
1296b74cfc13SGuray Ozen     /// Generates layout attribute for the input matrix for wgmma instruction
1297b74cfc13SGuray Ozen     NVVM::MMALayoutAttr
1298b74cfc13SGuray Ozen     generateWgmmaLayout(std::optional<bool> transpose) const {
1299b74cfc13SGuray Ozen       if (transpose.value_or(false))
1300b74cfc13SGuray Ozen         return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
1301b74cfc13SGuray Ozen       return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
130223882226SGuray Ozen     }
130323882226SGuray Ozen 
1304b74cfc13SGuray Ozen     /// Generates shape attribute for wgmma instruction
1305b74cfc13SGuray Ozen     NVVM::MMAShapeAttr generateWgmmaShape() const {
1306b74cfc13SGuray Ozen       return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
1307b74cfc13SGuray Ozen     }
130823882226SGuray Ozen 
1309b74cfc13SGuray Ozen     /// Generates scale attributes of output matrix for wgmma instruction
1310b74cfc13SGuray Ozen     NVVM::WGMMAScaleOutAttr generateScaleOut() const {
1311b74cfc13SGuray Ozen       return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
1312b74cfc13SGuray Ozen                                           NVVM::WGMMAScaleOut::one);
1313b74cfc13SGuray Ozen     }
1314b74cfc13SGuray Ozen     /// Generates scale attributes of input matrix for wgmma instruction
1315b74cfc13SGuray Ozen     NVVM::WGMMAScaleInAttr generateScaleIn() const {
1316b74cfc13SGuray Ozen       return NVVM::WGMMAScaleInAttr::get(op->getContext(),
1317b74cfc13SGuray Ozen                                          NVVM::WGMMAScaleIn::one);
1318b74cfc13SGuray Ozen     }
131923882226SGuray Ozen 
1320b74cfc13SGuray Ozen     /// Basic function to generate Add
1321b74cfc13SGuray Ozen     Value makeAdd(Value lhs, Value rhs) {
1322ee49cda7SGuray Ozen       return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
132323882226SGuray Ozen     };
132423882226SGuray Ozen 
1325b74cfc13SGuray Ozen     /// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
1326b74cfc13SGuray Ozen     /// Currently, it only handles row-major.
1327b74cfc13SGuray Ozen     ///
1328b74cfc13SGuray Ozen     /// It moves the pointer like below for [128][64] size:
1329b74cfc13SGuray Ozen     ///                 +2 +4 +6
1330b74cfc13SGuray Ozen     ///                  ↓  ↓  ↓
1331b74cfc13SGuray Ozen     /// descA    ---> +--+--+--+--+
1332b74cfc13SGuray Ozen     ///               |->|->|->|->|
1333b74cfc13SGuray Ozen     ///               |  |  |  |  |
1334b74cfc13SGuray Ozen     ///               |  |  |  |  |
1335b74cfc13SGuray Ozen     ///               |  |  |  |  |
1336b74cfc13SGuray Ozen     /// descA+512---> +-----------+
1337b74cfc13SGuray Ozen     ///               |  |  |  |  |
1338b74cfc13SGuray Ozen     ///               |  |  |  |  |
1339b74cfc13SGuray Ozen     ///               |  |  |  |  |
1340b74cfc13SGuray Ozen     ///               |  |  |  |  |
1341b74cfc13SGuray Ozen     ///               +-----------+
1342b74cfc13SGuray Ozen     ///
1343b74cfc13SGuray Ozen     Value iterateDescriptorA(Value desc, int i, int j, int k) {
1344b74cfc13SGuray Ozen       MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
1345b74cfc13SGuray Ozen       Type elemA = matrixTypeA.getElementType();
1346b74cfc13SGuray Ozen       int byte = elemA.getIntOrFloatBitWidth() / 8;
1347b74cfc13SGuray Ozen       int tileShapeA = matrixTypeA.getDimSize(1);
1348b74cfc13SGuray Ozen       int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
134923882226SGuray Ozen       incrementVal = incrementVal >> exclude4LSB;
1350b74cfc13SGuray Ozen       LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
1351b74cfc13SGuray Ozen                         << "] [wgmma descriptors] Descriptor A + "
135223882226SGuray Ozen                         << incrementVal << " | \t ");
135323882226SGuray Ozen       if (!incrementVal)
135423882226SGuray Ozen         return desc;
1355ee49cda7SGuray Ozen       return makeAdd(desc, makeI64Const(b, incrementVal));
1356b74cfc13SGuray Ozen     }
135723882226SGuray Ozen 
1358b74cfc13SGuray Ozen     /// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
1359b74cfc13SGuray Ozen     /// Currently, it only handles column-major.
1360b74cfc13SGuray Ozen     ///
1361b74cfc13SGuray Ozen     /// It moves the pointer like below for [128][64] size:
1362b74cfc13SGuray Ozen     /// descB     ---> +--+--+--+--+--+--+--+--+
1363b74cfc13SGuray Ozen     ///                |↓ |  |  |  |  |  |  |  |
1364b74cfc13SGuray Ozen     ///                |↓ |  |  |  |  |  |  |  |
1365b74cfc13SGuray Ozen     ///                |↓ |  |  |  |  |  |  |  |
1366b74cfc13SGuray Ozen     ///                |↓ |  |  |  |  |  |  |  |
1367b74cfc13SGuray Ozen     ///                +--+--+--+--+--+--+--+--+
1368b74cfc13SGuray Ozen     ///
1369b74cfc13SGuray Ozen     Value iterateDescriptorB(Value desc, int i, int j, int k) {
1370b74cfc13SGuray Ozen       MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
1371b74cfc13SGuray Ozen       Type elemB = matrixTypeB.getElementType();
1372b74cfc13SGuray Ozen       int byte = elemB.getIntOrFloatBitWidth() / 8;
1373b74cfc13SGuray Ozen       int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
137423882226SGuray Ozen       incrementVal = incrementVal >> exclude4LSB;
137523882226SGuray Ozen       LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
137623882226SGuray Ozen       if (!incrementVal)
137723882226SGuray Ozen         return desc;
1378ee49cda7SGuray Ozen       return makeAdd(desc, makeI64Const(b, incrementVal));
137923882226SGuray Ozen     }
1380b74cfc13SGuray Ozen 
1381b74cfc13SGuray Ozen     /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
1382b74cfc13SGuray Ozen     /// descriptors and arranges them based on induction variables: i, j, and k.
138352db7e27SGuray Ozen     Value generateWgmma(int i, int j, int k, Value matrixC) {
1384b74cfc13SGuray Ozen       LLVM_DEBUG(DBGS() << "\t wgmma."
1385b74cfc13SGuray Ozen                         << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
1386b74cfc13SGuray Ozen                         << "(A[" << (iterationM * wgmmaM) << ":"
1387b74cfc13SGuray Ozen                         << (iterationM * wgmmaM) + wgmmaM << "]["
1388b74cfc13SGuray Ozen                         << (iterationK * wgmmaK) << ":"
1389b74cfc13SGuray Ozen                         << (iterationK * wgmmaK + wgmmaK) << "] * "
1390b74cfc13SGuray Ozen                         << " B[" << (iterationK * wgmmaK) << ":"
1391b74cfc13SGuray Ozen                         << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
1392b74cfc13SGuray Ozen                         << wgmmaN << "])\n");
1393b74cfc13SGuray Ozen 
1394b74cfc13SGuray Ozen       Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
1395b74cfc13SGuray Ozen       Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
1396b74cfc13SGuray Ozen 
1397b74cfc13SGuray Ozen       Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
1398b74cfc13SGuray Ozen       NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);
1399b74cfc13SGuray Ozen 
1400b74cfc13SGuray Ozen       Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
1401b74cfc13SGuray Ozen       NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
1402b74cfc13SGuray Ozen 
140312c241b3SGuray Ozen       Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
140412c241b3SGuray Ozen       NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true);
140512c241b3SGuray Ozen 
1406b74cfc13SGuray Ozen       NVVM::MMAShapeAttr shape = generateWgmmaShape();
1407b74cfc13SGuray Ozen       NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
1408b74cfc13SGuray Ozen       NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
1409b74cfc13SGuray Ozen       NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
1410fa13c3eeSGuray Ozen       NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB());
1411b74cfc13SGuray Ozen 
1412b74cfc13SGuray Ozen       auto overflow = NVVM::MMAIntOverflowAttr::get(
1413b74cfc13SGuray Ozen           op->getContext(), NVVM::MMAIntOverflow::wrapped);
1414b74cfc13SGuray Ozen 
1415b74cfc13SGuray Ozen       return b.create<NVVM::WgmmaMmaAsyncOp>(
141652db7e27SGuray Ozen           matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
141712c241b3SGuray Ozen           itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
141812c241b3SGuray Ozen           overflow);
1419b74cfc13SGuray Ozen     }
1420b74cfc13SGuray Ozen 
1421b74cfc13SGuray Ozen     /// Generates multiple wgmma instructions to complete the given GEMM shape
142252db7e27SGuray Ozen     Value generateWgmmaGroup() {
142352db7e27SGuray Ozen       Value wgmmaResult =
142452db7e27SGuray Ozen           b.create<LLVM::UndefOp>(adaptor.getMatrixC().getType());
1425b74cfc13SGuray Ozen 
1426b74cfc13SGuray Ozen       // Perform GEMM
142752db7e27SGuray Ozen       SmallVector<Value> wgmmaResults;
1428b74cfc13SGuray Ozen       for (int i = 0; i < iterationM; ++i) {
142952db7e27SGuray Ozen         Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i);
1430b74cfc13SGuray Ozen         for (int j = 0; j < iterationN; ++j)
1431b74cfc13SGuray Ozen           for (int k = 0; k < iterationK; ++k)
143252db7e27SGuray Ozen             matrixC = generateWgmma(i, j, k, matrixC);
143323882226SGuray Ozen         wgmmaResults.push_back(matrixC);
143423882226SGuray Ozen       }
143552db7e27SGuray Ozen       for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
143652db7e27SGuray Ozen         wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(),
143752db7e27SGuray Ozen                                                     wgmmaResult, matrix, idx);
143852db7e27SGuray Ozen       }
143952db7e27SGuray Ozen       return wgmmaResult;
1440b74cfc13SGuray Ozen     }
1441b74cfc13SGuray Ozen 
1442b74cfc13SGuray Ozen   public:
1443b74cfc13SGuray Ozen     WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
144452db7e27SGuray Ozen                   OpAdaptor adaptor)
144552db7e27SGuray Ozen         : op(op), b(b), adaptor(adaptor) {
1446b74cfc13SGuray Ozen       // Find the entire GEMM Shape
1447b74cfc13SGuray Ozen       totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
1448b74cfc13SGuray Ozen       totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
1449b74cfc13SGuray Ozen       totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
1450b74cfc13SGuray Ozen       LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
1451b74cfc13SGuray Ozen                         << "] += A[" << totalM << "][" << totalK << "] * B["
1452b74cfc13SGuray Ozen                         << totalK << "][" << totalN << "] ---===\n");
1453b74cfc13SGuray Ozen 
1454b74cfc13SGuray Ozen       // Find the shape for one wgmma instruction
1455b74cfc13SGuray Ozen       findWgmmaShape(
1456b74cfc13SGuray Ozen           totalM, totalN,
1457b74cfc13SGuray Ozen           op.getDescriptorA().getType().getTensor().getElementType());
1458b74cfc13SGuray Ozen 
1459b74cfc13SGuray Ozen       // Iterations counts to complete the given shape with wgmma shape
1460b74cfc13SGuray Ozen       iterationM = totalM / wgmmaM;
1461b74cfc13SGuray Ozen       iterationN = totalN / wgmmaN;
1462b74cfc13SGuray Ozen       iterationK = totalK / wgmmaK;
1463b74cfc13SGuray Ozen     }
1464b74cfc13SGuray Ozen 
1465b74cfc13SGuray Ozen     /// Generates WgmmaMmaAsync Ops to complete the specified GEMM  shape. It
1466b74cfc13SGuray Ozen     /// includes generating a fence Op (WgmmaFenceAlignedOp) before the
1467b74cfc13SGuray Ozen     /// instructions and group synchronization, as well as waiting
1468b74cfc13SGuray Ozen     /// (WgmmaGroupSyncAlignedOp) for group synchronization
1469b74cfc13SGuray Ozen     /// (WgmmaWaitGroupSyncOp) after the instructions.
147052db7e27SGuray Ozen     Value generateWarpgroupMma() {
1471b74cfc13SGuray Ozen       b.create<NVVM::WgmmaFenceAlignedOp>();
147252db7e27SGuray Ozen       Value wgmmaResult = generateWgmmaGroup();
1473ee49cda7SGuray Ozen       b.create<NVVM::WgmmaGroupSyncAlignedOp>();
1474ee49cda7SGuray Ozen       b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
147552db7e27SGuray Ozen       return wgmmaResult;
1476b74cfc13SGuray Ozen     }
1477b74cfc13SGuray Ozen   };
1478b74cfc13SGuray Ozen   LogicalResult
1479b74cfc13SGuray Ozen   matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
1480b74cfc13SGuray Ozen                   ConversionPatternRewriter &rewriter) const override {
1481b74cfc13SGuray Ozen     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
148252db7e27SGuray Ozen 
1483b74cfc13SGuray Ozen     // Step 1. Build a helper class
148452db7e27SGuray Ozen     WarpgroupGemm warpgroupGemm(op, b, adaptor);
1485b74cfc13SGuray Ozen 
1486b74cfc13SGuray Ozen     // Step 2. Get the entire GEMM Shape
148752db7e27SGuray Ozen     Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();
1488b74cfc13SGuray Ozen 
1489b74cfc13SGuray Ozen     // Step 3. Replace fragmented result struct with the op results
149052db7e27SGuray Ozen     rewriter.replaceOp(op, wgmmaResult);
149123882226SGuray Ozen     return success();
149223882226SGuray Ozen   }
149323882226SGuray Ozen };
149423882226SGuray Ozen 
1495d20fbc90SGuray Ozen struct NVGPUWarpgroupMmaStoreOpLowering
1496d20fbc90SGuray Ozen     : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
1497d20fbc90SGuray Ozen   using ConvertOpToLLVMPattern<
1498d20fbc90SGuray Ozen       nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;
1499d20fbc90SGuray Ozen 
1500d20fbc90SGuray Ozen   /// This function stores a fragmented register matrix owned by a warp group
1501d20fbc90SGuray Ozen   /// (128 threads) into a memref. Each thread has 64 registers, each the size
1502d20fbc90SGuray Ozen   /// of a struct.
1503d20fbc90SGuray Ozen   /// Here is what each threads (T) holds, each `d` is struct value with a
1504d20fbc90SGuray Ozen   /// number.
1505d20fbc90SGuray Ozen   ///
1506d20fbc90SGuray Ozen   /// Threads in warp-group (128 threads) and what they owns in the matrixD:
1507d20fbc90SGuray Ozen   /// 0-31 	  Warp-0  -> MatrixD[0:15 ][0:N]
1508d20fbc90SGuray Ozen   /// 32-63 	Warp-1  -> MatrixD[16:31][0:N]
1509d20fbc90SGuray Ozen   /// 64-95 	Warp-2  -> MatrixD[32:47][0:N]
1510d20fbc90SGuray Ozen   /// 96-127 	Warp-3  -> MatrixD[48:64][0:N]
1511d20fbc90SGuray Ozen   ///
1512d20fbc90SGuray Ozen   /// Matrix-D:
1513d20fbc90SGuray Ozen   ///   +______________________________________________________________________+
1514d20fbc90SGuray Ozen   ///   |     0-1  |    2-3  |    4-5  |    6-7  |   8-9  |   10-11|..|N-8,N-7 |
1515d20fbc90SGuray Ozen   /// 0 | T0:d0-d1 |T1:d0-d1 |T2:d0-d1 |T3:d0-d1 |T0:d4-d5| T1:d4-d5..|T0:dX-dY|
1516d20fbc90SGuray Ozen   /// 1 | T4:d0-d1 |T5:d0-d1 |T6:d0-d1 |T7:d0-d1 |T4:d4-d5| T5:d4-d5..|T4:dX-dY|
1517d20fbc90SGuray Ozen   /// ..| .........|.........|.........|.........|........|...........|........|
1518d20fbc90SGuray Ozen   /// 8 | T0:d2-d3 |T1:d2-d3 |T2:d2-d3 |T3:d2-d3 |T0:d6-d7|T1:d6-d7,..|T0:dZ-dW|
1519d20fbc90SGuray Ozen   /// 9 | T4:d2-d3 |T5:d2-d3 |T6:d2-d3 |T7:d2-d3 |T4:d6-d7| T5:d6-d7..|T4:dZ-dW|
1520d20fbc90SGuray Ozen   /// ..| .........|.........|.........|.........|........|...........|........|
1521d20fbc90SGuray Ozen   /// 15| T28:d2-d3|T29:d2-d3|T30:d2-d3|T31:d2-d3|........|...........|........|
1522d20fbc90SGuray Ozen   /// 16| T32:d2-d3|T33:d2-d3|T34:d2-d3|T35:d2-d3|........|...........|........|
1523d20fbc90SGuray Ozen   /// ..| .........|.........|.........|.........|........|...........|........|
1524d20fbc90SGuray Ozen   /// 32| T64:d2-d3|T65:d2-d3|T66:d2-d3|T67:d2-d3|........|...........|........|
1525d20fbc90SGuray Ozen   /// ..| .........|.........|.........|.........|........|...........|........|
1526d20fbc90SGuray Ozen   /// 48| T96:d2-d3|T97:d2-d3|T98:d2-d3|T99:d2-d3|........|...........|........|
1527d20fbc90SGuray Ozen   /// ..| .........|.........|.........|.........|........|...........|........|
1528d20fbc90SGuray Ozen   ///   +______________________________________________________________________+
1529d20fbc90SGuray Ozen   ///
1530d20fbc90SGuray Ozen   /// \param rewriter: The pattern rewriter.
1531d20fbc90SGuray Ozen   /// \param matrixD: Result of the warp-group MMA operation (fragmented
1532d20fbc90SGuray Ozen   /// matrix). It is holded by a thread and a struct with 64 elements.
1533d20fbc90SGuray Ozen   /// \param dstMemref: The memref where the registers will be stored.
1534d20fbc90SGuray Ozen   /// \param offset: the offset within the memref where the registers will be
1535d20fbc90SGuray Ozen   /// stored.
1536d20fbc90SGuray Ozen   void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
1537d20fbc90SGuray Ozen                              TypedValue<MemRefType> dstMemref,
1538d20fbc90SGuray Ozen                              int offset) const {
1539d20fbc90SGuray Ozen     Type i32 = b.getI32Type();
1540d20fbc90SGuray Ozen 
1541d20fbc90SGuray Ozen     auto makeConst = [&](int32_t index) -> Value {
1542d20fbc90SGuray Ozen       return b.create<LLVM::ConstantOp>(i32, b.getI32IntegerAttr(index));
1543d20fbc90SGuray Ozen     };
1544d20fbc90SGuray Ozen     Value c1 = makeConst(1);
1545d20fbc90SGuray Ozen     Value c2 = makeConst(2);
1546d20fbc90SGuray Ozen     Value c4 = makeConst(4);
1547d20fbc90SGuray Ozen     Value c8 = makeConst(8);
1548d20fbc90SGuray Ozen     Value c16 = makeConst(16);
1549d20fbc90SGuray Ozen     Value warpSize = makeConst(kWarpSize);
1550d20fbc90SGuray Ozen 
1551d20fbc90SGuray Ozen     auto makeMul = [&](Value lhs, Value rhs) -> Value {
1552d20fbc90SGuray Ozen       return b.create<LLVM::MulOp>(lhs.getType(), lhs, rhs);
1553d20fbc90SGuray Ozen     };
1554d20fbc90SGuray Ozen     auto makeAdd = [&](Value lhs, Value rhs) -> Value {
1555d20fbc90SGuray Ozen       return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
1556d20fbc90SGuray Ozen     };
1557d20fbc90SGuray Ozen 
1558d20fbc90SGuray Ozen     auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
1559d20fbc90SGuray Ozen                                    TypedValue<::mlir::MemRefType> memref) {
1560d20fbc90SGuray Ozen       Type it = b.getIndexType();
1561d20fbc90SGuray Ozen       Value idx = b.create<arith::IndexCastOp>(it, x);
1562d20fbc90SGuray Ozen       Value idy0 = b.create<arith::IndexCastOp>(it, y);
1563d20fbc90SGuray Ozen       Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1));
1564d20fbc90SGuray Ozen       Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i);
1565d20fbc90SGuray Ozen       Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1);
1566d20fbc90SGuray Ozen       b.create<memref::StoreOp>(d0, memref, ValueRange{idx, idy0});
1567d20fbc90SGuray Ozen       b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1});
1568d20fbc90SGuray Ozen     };
1569d20fbc90SGuray Ozen 
157021830c91SGuray Ozen     Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
157121830c91SGuray Ozen     Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
157221830c91SGuray Ozen     Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
157321830c91SGuray Ozen     Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
157421830c91SGuray Ozen     Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);
157521830c91SGuray Ozen 
1576d20fbc90SGuray Ozen     Value tj = makeMul(lane4modId, c2);
1577d20fbc90SGuray Ozen     Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
1578d20fbc90SGuray Ozen     if (offset)
1579d20fbc90SGuray Ozen       ti = makeAdd(ti, makeConst(offset));
158021830c91SGuray Ozen 
1581a5757c5bSChristian Sigg     auto structType = cast<LLVM::LLVMStructType>(matrixD.getType());
158221830c91SGuray Ozen 
158321830c91SGuray Ozen     // Number of 32-bit registers owns per thread
158421830c91SGuray Ozen     constexpr unsigned numAdjacentRegisters = 2;
158521830c91SGuray Ozen     // Number of 8x8 matrices one below another per warp
158621830c91SGuray Ozen     constexpr unsigned numStackedMatrices = 2;
158721830c91SGuray Ozen 
158821830c91SGuray Ozen     size_t storeCount = (structType.getBody().size() /
158921830c91SGuray Ozen                          (numStackedMatrices * numAdjacentRegisters));
159021830c91SGuray Ozen 
159121830c91SGuray Ozen     for (size_t i = 0; i < numStackedMatrices; ++i) {
1592d20fbc90SGuray Ozen       Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
159321830c91SGuray Ozen       for (size_t j = 0; j < storeCount; ++j) {
1594d20fbc90SGuray Ozen         Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
159521830c91SGuray Ozen         size_t structIndex = (i * numAdjacentRegisters) +
159621830c91SGuray Ozen                              (j * (numStackedMatrices * numAdjacentRegisters));
159721830c91SGuray Ozen         makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
1598d20fbc90SGuray Ozen       }
1599d20fbc90SGuray Ozen     }
1600d20fbc90SGuray Ozen   }
1601d20fbc90SGuray Ozen 
1602d20fbc90SGuray Ozen   LogicalResult
1603d20fbc90SGuray Ozen   matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
1604d20fbc90SGuray Ozen                   ConversionPatternRewriter &rewriter) const override {
1605d20fbc90SGuray Ozen     int offset = 0;
160652db7e27SGuray Ozen     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
160752db7e27SGuray Ozen     Value matriDValue = adaptor.getMatrixD();
1608a5757c5bSChristian Sigg     auto stype = cast<LLVM::LLVMStructType>(matriDValue.getType());
160952db7e27SGuray Ozen     for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
1610a5757c5bSChristian Sigg       auto structType = cast<LLVM::LLVMStructType>(matrixD);
161152db7e27SGuray Ozen       Value innerStructValue = b.create<LLVM::ExtractValueOp>(matriDValue, idx);
161252db7e27SGuray Ozen       storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
1613d20fbc90SGuray Ozen       offset += structType.getBody().size();
1614d20fbc90SGuray Ozen     }
1615d20fbc90SGuray Ozen     rewriter.eraseOp(op);
1616d20fbc90SGuray Ozen     return success();
1617d20fbc90SGuray Ozen   }
1618d20fbc90SGuray Ozen };
1619d20fbc90SGuray Ozen 
1620315ab3c4SGuray Ozen struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
1621315ab3c4SGuray Ozen     : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> {
1622315ab3c4SGuray Ozen   using ConvertOpToLLVMPattern<
1623315ab3c4SGuray Ozen       nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern;
1624315ab3c4SGuray Ozen   LogicalResult
1625315ab3c4SGuray Ozen   matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
1626315ab3c4SGuray Ozen                   ConversionPatternRewriter &rewriter) const override {
1627315ab3c4SGuray Ozen     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1628a5757c5bSChristian Sigg     LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>(
1629a5757c5bSChristian Sigg         getTypeConverter()->convertType(op.getMatrixC().getType()));
1630a5757c5bSChristian Sigg     Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front())
163152db7e27SGuray Ozen                         .getBody()
163252db7e27SGuray Ozen                         .front();
163352db7e27SGuray Ozen     Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
1634c4ba84d6SGuray Ozen     Value packStruct = b.create<LLVM::UndefOp>(packStructType);
1635c4ba84d6SGuray Ozen     SmallVector<Value> innerStructs;
1636c4ba84d6SGuray Ozen     // Unpack the structs and set all values to zero
1637c4ba84d6SGuray Ozen     for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
1638a5757c5bSChristian Sigg       auto structType = cast<LLVM::LLVMStructType>(s);
1639c4ba84d6SGuray Ozen       Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx);
1640c4ba84d6SGuray Ozen       for (unsigned i = 0; i < structType.getBody().size(); ++i) {
1641c4ba84d6SGuray Ozen         structValue = b.create<LLVM::InsertValueOp>(
1642c4ba84d6SGuray Ozen             structType, structValue, zero, ArrayRef<int64_t>({i}));
1643315ab3c4SGuray Ozen       }
1644c4ba84d6SGuray Ozen       innerStructs.push_back(structValue);
1645315ab3c4SGuray Ozen     }
1646c4ba84d6SGuray Ozen     // Pack the inner structs into a single struct
1647c4ba84d6SGuray Ozen     for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
1648c4ba84d6SGuray Ozen       packStruct = b.create<LLVM::InsertValueOp>(packStruct.getType(),
1649c4ba84d6SGuray Ozen                                                  packStruct, matrix, idx);
1650c4ba84d6SGuray Ozen     }
1651c4ba84d6SGuray Ozen     rewriter.replaceOp(op, packStruct);
1652315ab3c4SGuray Ozen     return success();
1653315ab3c4SGuray Ozen   }
1654315ab3c4SGuray Ozen };
1655315ab3c4SGuray Ozen 
165639cdefb5SGuray Ozen struct NVGPUTmaPrefetchOpLowering
165739cdefb5SGuray Ozen     : public ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp> {
165839cdefb5SGuray Ozen   using ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp>::ConvertOpToLLVMPattern;
165939cdefb5SGuray Ozen   LogicalResult
166039cdefb5SGuray Ozen   matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
166139cdefb5SGuray Ozen                   ConversionPatternRewriter &rewriter) const override {
166239cdefb5SGuray Ozen     rewriter.replaceOpWithNewOp<NVVM::PrefetchTensorMapOp>(
166339cdefb5SGuray Ozen         op, adaptor.getTensorMapDescriptor(), adaptor.getPredicate());
166439cdefb5SGuray Ozen     return success();
166539cdefb5SGuray Ozen   }
166639cdefb5SGuray Ozen };
166739cdefb5SGuray Ozen 
16682b23e6c8SObserver007 struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> {
16692b23e6c8SObserver007   using ConvertOpToLLVMPattern<nvgpu::RcpOp>::ConvertOpToLLVMPattern;
16702b23e6c8SObserver007   LogicalResult
16712b23e6c8SObserver007   matchAndRewrite(nvgpu::RcpOp op, OpAdaptor adaptor,
16722b23e6c8SObserver007                   ConversionPatternRewriter &rewriter) const override {
16732b23e6c8SObserver007     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
16742b23e6c8SObserver007     auto i64Ty = b.getI64Type();
16752b23e6c8SObserver007     auto f32Ty = b.getF32Type();
16762b23e6c8SObserver007     VectorType inTy = op.getIn().getType();
16772b23e6c8SObserver007     // apply rcp.approx.ftz.f on each element in vector.
16782b23e6c8SObserver007     auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) {
16792b23e6c8SObserver007       Value ret1DVec = b.create<LLVM::UndefOp>(llvm1DVectorTy);
16802b23e6c8SObserver007       int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements();
16812b23e6c8SObserver007       for (int i = 0; i < numElems; i++) {
16822b23e6c8SObserver007         Value idx = b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(i));
16832b23e6c8SObserver007         Value elem = b.create<LLVM::ExtractElementOp>(inVec, idx);
16842b23e6c8SObserver007         Value dst = b.create<NVVM::RcpApproxFtzF32Op>(f32Ty, elem);
16852b23e6c8SObserver007         ret1DVec = b.create<LLVM::InsertElementOp>(ret1DVec, dst, idx);
16862b23e6c8SObserver007       }
16872b23e6c8SObserver007       return ret1DVec;
16882b23e6c8SObserver007     };
16892b23e6c8SObserver007     if (inTy.getRank() == 1) {
16902b23e6c8SObserver007       rewriter.replaceOp(op, convert1DVec(inTy, adaptor.getIn()));
16912b23e6c8SObserver007       return success();
16922b23e6c8SObserver007     }
16932b23e6c8SObserver007     return LLVM::detail::handleMultidimensionalVectors(
16942b23e6c8SObserver007         op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
16952b23e6c8SObserver007         [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
16962b23e6c8SObserver007           OpAdaptor adaptor(operands);
16972b23e6c8SObserver007           return convert1DVec(llvm1DVectorTy, adaptor.getIn());
16982b23e6c8SObserver007         },
16992b23e6c8SObserver007         rewriter);
17002b23e6c8SObserver007   }
17012b23e6c8SObserver007 };
1702894a591cSThomas Raoux } // namespace
170315bcc36eSThomas Raoux 
1704206fad0eSMatthias Springer void mlir::populateNVGPUToNVVMConversionPatterns(
1705206fad0eSMatthias Springer     const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1706affcfccdSGuray Ozen   patterns.add<
1707affcfccdSGuray Ozen       NVGPUMBarrierCreateLowering,           // nvgpu.mbarrier.create
1708affcfccdSGuray Ozen       NVGPUMBarrierInitLowering,             // nvgpu.mbarrier.init
1709affcfccdSGuray Ozen       NVGPUMBarrierArriveLowering,           // nvgpu.mbarrier.arrive
1710affcfccdSGuray Ozen       NVGPUMBarrierArriveNoCompleteLowering, // nvgpu.mbarrier.arrive.no_complete
1711836dbb85SGuray Ozen       NVGPUMBarrierTestWaitLowering,         // nvgpu.mbarrier.test_wait_parity
1712836dbb85SGuray Ozen       NVGPUMBarrierTryWaitParityLowering,    // nvgpu.mbarrier.try_wait_parity
1713e56d6745SGuray Ozen       NVGPUTmaAsyncLoadOpLowering,           // nvgpu.tma.async.load
17148dd0d95cSGuray Ozen       NVGPUTmaAsyncStoreOpLowering,          // nvgpu.tma.async.store
1715e56d6745SGuray Ozen       NVGPUTmaCreateDescriptorOpLowering,    // nvgpu.tma.create.descriptor
171639cdefb5SGuray Ozen       NVGPUTmaPrefetchOpLowering,            // nvgpu.tma.prefetch.descriptor
1717836dbb85SGuray Ozen       NVGPUMBarrierArriveExpectTxLowering,   // nvgpu.mbarrier.arrive.expect_tx
17186dc7717bSGuray Ozen       NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
171923882226SGuray Ozen       NVGPUWarpgroupMmaOpLowering,              // nvgpu.warpgroup.mma
1720d20fbc90SGuray Ozen       NVGPUWarpgroupMmaStoreOpLowering,         // nvgpu.warpgroup.mma.store
1721315ab3c4SGuray Ozen       NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator
1722affcfccdSGuray Ozen       MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
1723708185f0SChristopher Bate       NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
17242b23e6c8SObserver007       NVGPUMmaSparseSyncLowering, NVGPURcpOpLowering>(converter);
1725894a591cSThomas Raoux }
1726