1894a591cSThomas Raoux //===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===// 2894a591cSThomas Raoux // 3894a591cSThomas Raoux // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4894a591cSThomas Raoux // See https://llvm.org/LICENSE.txt for license information. 5894a591cSThomas Raoux // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6894a591cSThomas Raoux // 7894a591cSThomas Raoux //===----------------------------------------------------------------------===// 8894a591cSThomas Raoux 9894a591cSThomas Raoux #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h" 1067d0d7acSMichele Scuttari 11e56d6745SGuray Ozen #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" 12894a591cSThomas Raoux #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 13894a591cSThomas Raoux #include "mlir/Conversion/LLVMCommon/Pattern.h" 142b23e6c8SObserver007 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 15d20fbc90SGuray Ozen #include "mlir/Dialect/Arith/IR/Arith.h" 16d7ef488bSMogball #include "mlir/Dialect/GPU/IR/GPUDialect.h" 17708185f0SChristopher Bate #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 18e56d6745SGuray Ozen #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 19894a591cSThomas Raoux #include "mlir/Dialect/LLVMIR/NVVMDialect.h" 20affcfccdSGuray Ozen #include "mlir/Dialect/MemRef/IR/MemRef.h" 2151b925dfSChristopher Bate #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" 2223882226SGuray Ozen #include "mlir/Dialect/SCF/Transforms/Patterns.h" 2317649a77SGuray Ozen #include "mlir/IR/BuiltinTypes.h" 24ee49cda7SGuray Ozen #include "mlir/IR/ImplicitLocOpBuilder.h" 25e56d6745SGuray Ozen #include "mlir/IR/PatternMatch.h" 26708185f0SChristopher Bate #include "mlir/IR/TypeUtilities.h" 2717649a77SGuray Ozen #include "mlir/IR/Value.h" 2867d0d7acSMichele Scuttari #include "mlir/Pass/Pass.h" 29b96d0693SGuray Ozen #include "llvm/Support/Debug.h" 3023882226SGuray Ozen #include "llvm/Support/ErrorHandling.h" 31e56d6745SGuray Ozen #include 
"llvm/Support/raw_ostream.h" 3263389326SGuray Ozen #include <optional> 3367d0d7acSMichele Scuttari 34b96d0693SGuray Ozen #define DEBUG_TYPE "nvgpu-to-nvvm" 35b96d0693SGuray Ozen #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") 36b96d0693SGuray Ozen #define DBGSE() (llvm::dbgs()) 37b96d0693SGuray Ozen 3867d0d7acSMichele Scuttari namespace mlir { 3953689fdfSMarkus Böck #define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS 4067d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc" 4167d0d7acSMichele Scuttari } // namespace mlir 42894a591cSThomas Raoux 43894a591cSThomas Raoux using namespace mlir; 44894a591cSThomas Raoux 45b74cfc13SGuray Ozen /// Number of bits that needs to be excluded when building matrix descriptor for 4623882226SGuray Ozen /// wgmma operations. 4723882226SGuray Ozen constexpr int exclude4LSB = 4; 4823882226SGuray Ozen 49836dbb85SGuray Ozen /// GPU has 32 bit registers, this function truncates values when larger width 50836dbb85SGuray Ozen /// is not needed. 51ee49cda7SGuray Ozen static Value truncToI32(ImplicitLocOpBuilder &b, Value value) { 52836dbb85SGuray Ozen Type type = value.getType(); 53836dbb85SGuray Ozen assert(llvm::isa<IntegerType>(type) && "expected an integer Value"); 54836dbb85SGuray Ozen if (type.getIntOrFloatBitWidth() <= 32) 55836dbb85SGuray Ozen return value; 56ee49cda7SGuray Ozen return b.create<LLVM::TruncOp>(b.getI32Type(), value); 57836dbb85SGuray Ozen } 58836dbb85SGuray Ozen 59894a591cSThomas Raoux /// Returns the type for the intrinsic given the vectorResultType of the 60894a591cSThomas Raoux /// `gpu.mma.sync` operation. 
61894a591cSThomas Raoux static Type inferIntrinsicResultType(Type vectorResultType) { 62894a591cSThomas Raoux MLIRContext *ctx = vectorResultType.getContext(); 635550c821STres Popp auto a = cast<LLVM::LLVMArrayType>(vectorResultType); 64894a591cSThomas Raoux auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2); 65894a591cSThomas Raoux auto i32Ty = IntegerType::get(ctx, 32); 66894a591cSThomas Raoux auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2); 67894a591cSThomas Raoux Type f64Ty = Float64Type::get(ctx); 68894a591cSThomas Raoux Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2); 6998798073SChristopher Bate Type f32Ty = Float32Type::get(ctx); 7098798073SChristopher Bate Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2); 71894a591cSThomas Raoux if (a.getElementType() == f16x2Ty) { 72894a591cSThomas Raoux return LLVM::LLVMStructType::getLiteral( 73894a591cSThomas Raoux ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty)); 74894a591cSThomas Raoux } 75894a591cSThomas Raoux if (a.getElementType() == i32x2Ty) { 76894a591cSThomas Raoux return LLVM::LLVMStructType::getLiteral( 77894a591cSThomas Raoux ctx, 78894a591cSThomas Raoux SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty)); 79894a591cSThomas Raoux } 80894a591cSThomas Raoux if (a.getElementType() == f64x2Ty) { 81894a591cSThomas Raoux return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty}); 82894a591cSThomas Raoux } 8398798073SChristopher Bate if (a.getElementType() == f32x2Ty) { 8498798073SChristopher Bate return LLVM::LLVMStructType::getLiteral( 8598798073SChristopher Bate ctx, 8698798073SChristopher Bate SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty)); 8798798073SChristopher Bate } 8898798073SChristopher Bate if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) { 8998798073SChristopher Bate return LLVM::LLVMStructType::getLiteral( 9098798073SChristopher Bate ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty)); 
9198798073SChristopher Bate } 92894a591cSThomas Raoux return vectorResultType; 93894a591cSThomas Raoux } 94894a591cSThomas Raoux 95894a591cSThomas Raoux /// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is 96894a591cSThomas Raoux /// always an LLVM struct) into a fragment that is compatible with the vector 97894a591cSThomas Raoux /// type of this operation. This involves extracting elements from the struct 98894a591cSThomas Raoux /// and inserting them into an LLVM array. These extra data-movement 99894a591cSThomas Raoux /// operations should be canonicalized away by the LLVM backend. 100894a591cSThomas Raoux static Value convertIntrinsicResult(Location loc, Type intrinsicResultType, 101894a591cSThomas Raoux Type resultType, Value intrinsicResult, 102894a591cSThomas Raoux RewriterBase &rewriter) { 103894a591cSThomas Raoux MLIRContext *ctx = rewriter.getContext(); 1045550c821STres Popp auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType); 1055550c821STres Popp auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType); 106894a591cSThomas Raoux Type i32Ty = rewriter.getI32Type(); 10798798073SChristopher Bate Type f32Ty = rewriter.getF32Type(); 108894a591cSThomas Raoux Type f64Ty = rewriter.getF64Type(); 109894a591cSThomas Raoux Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2); 110894a591cSThomas Raoux Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2); 111894a591cSThomas Raoux Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2); 11298798073SChristopher Bate Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2); 11398798073SChristopher Bate Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1); 114894a591cSThomas Raoux 115894a591cSThomas Raoux auto makeConst = [&](int32_t index) -> Value { 116894a591cSThomas Raoux return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32), 117894a591cSThomas Raoux rewriter.getI32IntegerAttr(index)); 118894a591cSThomas Raoux }; 119894a591cSThomas Raoux 120894a591cSThomas Raoux 
if (arrayType) { 121894a591cSThomas Raoux SmallVector<Value, 4> elements; 122894a591cSThomas Raoux 12398798073SChristopher Bate // The intrinsic returns 32-bit wide elements in a form which can be 12498798073SChristopher Bate // directly bitcasted and inserted into the result vector. 12598798073SChristopher Bate if (arrayType.getElementType() == f16x2Ty || 12698798073SChristopher Bate arrayType.getElementType() == f32x1Ty) { 127894a591cSThomas Raoux for (unsigned i = 0; i < structType.getBody().size(); i++) { 1285c5af910SJeff Niu Value el = 1295c5af910SJeff Niu rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i); 13098798073SChristopher Bate el = rewriter.createOrFold<LLVM::BitcastOp>( 13198798073SChristopher Bate loc, arrayType.getElementType(), el); 13298798073SChristopher Bate elements.push_back(el); 133894a591cSThomas Raoux } 134894a591cSThomas Raoux } 135894a591cSThomas Raoux 13698798073SChristopher Bate // The intrinsic returns i32, f64, and f32 values as individual scalars, 13798798073SChristopher Bate // even when the result is notionally a 64-bit wide element (e.g. f32x2). We 13898798073SChristopher Bate // need to extract them from the struct and pack them into the 64-bit wide 13998798073SChristopher Bate // rows of the vector result. 
140894a591cSThomas Raoux if (arrayType.getElementType() == i32x2Ty || 14198798073SChristopher Bate arrayType.getElementType() == f64x2Ty || 14298798073SChristopher Bate arrayType.getElementType() == f32x2Ty) { 14398798073SChristopher Bate 14498798073SChristopher Bate for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) { 145894a591cSThomas Raoux Value vec = 146894a591cSThomas Raoux rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType()); 1475c5af910SJeff Niu Value x1 = 1485c5af910SJeff Niu rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i * 2); 1495c5af910SJeff Niu Value x2 = rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, 1505c5af910SJeff Niu i * 2 + 1); 151894a591cSThomas Raoux vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec, 152894a591cSThomas Raoux x1, makeConst(0)); 153894a591cSThomas Raoux vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec, 154894a591cSThomas Raoux x2, makeConst(1)); 155894a591cSThomas Raoux elements.push_back(vec); 156894a591cSThomas Raoux } 15798798073SChristopher Bate } 158894a591cSThomas Raoux 159894a591cSThomas Raoux // Create the final vectorized result. 160894a591cSThomas Raoux Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType); 161894a591cSThomas Raoux for (const auto &el : llvm::enumerate(elements)) { 1625c5af910SJeff Niu result = rewriter.create<LLVM::InsertValueOp>(loc, result, el.value(), 1635c5af910SJeff Niu el.index()); 164894a591cSThomas Raoux } 165894a591cSThomas Raoux return result; 166894a591cSThomas Raoux } 167894a591cSThomas Raoux 168894a591cSThomas Raoux return intrinsicResult; 169894a591cSThomas Raoux } 170894a591cSThomas Raoux 171894a591cSThomas Raoux /// The `gpu.mma.sync` converter below expects matrix fragment operands to be 172894a591cSThomas Raoux /// given as 2D `vectors` where the rows are 32b or 64b wide. 
The 173894a591cSThomas Raoux /// `nvvm.mma.sync` op expects these argments to be a given in a long list of 174894a591cSThomas Raoux /// scalars of certain types. This function helps unpack the `vector` arguments 175894a591cSThomas Raoux /// and cast them to the types expected by `nvvm.mma.sync`. 176ee49cda7SGuray Ozen static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b, 177ee49cda7SGuray Ozen Value operand, 17898798073SChristopher Bate NVVM::MMATypes operandPtxType) { 179894a591cSThomas Raoux SmallVector<Value> result; 180ee49cda7SGuray Ozen Type i32Ty = b.getI32Type(); 181ee49cda7SGuray Ozen Type f64Ty = b.getF64Type(); 182ee49cda7SGuray Ozen Type f32Ty = b.getF32Type(); 183ee49cda7SGuray Ozen Type i64Ty = b.getI64Type(); 184ee49cda7SGuray Ozen Type i8x4Ty = LLVM::getFixedVectorType(b.getI8Type(), 4); 185ee49cda7SGuray Ozen Type i4x8Ty = LLVM::getFixedVectorType(b.getIntegerType(4), 8); 18698798073SChristopher Bate Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1); 1875550c821STres Popp auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType()); 188894a591cSThomas Raoux 189894a591cSThomas Raoux for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) { 190ee49cda7SGuray Ozen Value toUse = b.create<LLVM::ExtractValueOp>(operand, i); 191894a591cSThomas Raoux 192894a591cSThomas Raoux // For 4xi8 vectors, the intrinsic expects these to be provided as i32 193894a591cSThomas Raoux // scalar types. 
19498798073SChristopher Bate if (arrayTy.getElementType() == i8x4Ty || 195334f63e7SChristopher Bate arrayTy.getElementType() == i4x8Ty || 19698798073SChristopher Bate (arrayTy.getElementType() == f32x1Ty && 19798798073SChristopher Bate operandPtxType == NVVM::MMATypes::tf32)) { 198ee49cda7SGuray Ozen result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse)); 199894a591cSThomas Raoux continue; 200894a591cSThomas Raoux } 201894a591cSThomas Raoux 20298798073SChristopher Bate // For some element types (i32, f32, f64), we need to unpack the inner 203894a591cSThomas Raoux // vector/array type as well because the intrinsic expects individual 204894a591cSThomas Raoux // scalars to be provided. 2055550c821STres Popp VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType()); 206894a591cSThomas Raoux if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty || 20798798073SChristopher Bate innerArrayTy.getElementType() == f64Ty || 20898798073SChristopher Bate innerArrayTy.getElementType() == f32Ty)) { 209894a591cSThomas Raoux for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements(); 210894a591cSThomas Raoux idx < innerSize; idx++) { 211ee49cda7SGuray Ozen result.push_back(b.create<LLVM::ExtractElementOp>( 212ee49cda7SGuray Ozen toUse, 213ee49cda7SGuray Ozen b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(idx)))); 214894a591cSThomas Raoux } 215894a591cSThomas Raoux continue; 216894a591cSThomas Raoux } 217894a591cSThomas Raoux result.push_back(toUse); 218894a591cSThomas Raoux } 219894a591cSThomas Raoux return result; 220894a591cSThomas Raoux } 221894a591cSThomas Raoux 22299475f5bSNicolas Vasilache /// Returns whether mbarrier object has shared memory address space. 
22317649a77SGuray Ozen static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType) { 22499475f5bSNicolas Vasilache return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace( 22599475f5bSNicolas Vasilache barrierType.getMemorySpace())); 22699475f5bSNicolas Vasilache } 22799475f5bSNicolas Vasilache 22899475f5bSNicolas Vasilache /// Returns the memory space attribute of the mbarrier object. 22999475f5bSNicolas Vasilache Attribute nvgpu::getMbarrierMemorySpace(MLIRContext *context, 23017649a77SGuray Ozen nvgpu::MBarrierGroupType barrierType) { 23199475f5bSNicolas Vasilache Attribute memorySpace = {}; 23299475f5bSNicolas Vasilache if (isMbarrierShared(barrierType)) { 23399475f5bSNicolas Vasilache memorySpace = 23499475f5bSNicolas Vasilache IntegerAttr::get(IntegerType::get(context, 64), 23599475f5bSNicolas Vasilache nvgpu::NVGPUDialect::kSharedMemoryAddressSpace); 23699475f5bSNicolas Vasilache } 23799475f5bSNicolas Vasilache return memorySpace; 23899475f5bSNicolas Vasilache } 23999475f5bSNicolas Vasilache 24099475f5bSNicolas Vasilache /// Returns memref type of the mbarrier object. The type is defined in the 24117649a77SGuray Ozen /// MBarrierGroupType. 
24299475f5bSNicolas Vasilache MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context, 24317649a77SGuray Ozen nvgpu::MBarrierGroupType barrierType) { 24499475f5bSNicolas Vasilache Attribute memorySpace = nvgpu::getMbarrierMemorySpace(context, barrierType); 24599475f5bSNicolas Vasilache MemRefLayoutAttrInterface layout; 24617649a77SGuray Ozen return MemRefType::get({barrierType.getNumBarriers()}, 24717649a77SGuray Ozen IntegerType::get(context, 64), layout, memorySpace); 24899475f5bSNicolas Vasilache } 24999475f5bSNicolas Vasilache 250894a591cSThomas Raoux namespace { 251894a591cSThomas Raoux 252894a591cSThomas Raoux struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> { 253894a591cSThomas Raoux using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern; 254894a591cSThomas Raoux 255894a591cSThomas Raoux LogicalResult 256894a591cSThomas Raoux matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor, 257894a591cSThomas Raoux ConversionPatternRewriter &rewriter) const override { 258894a591cSThomas Raoux MLIRContext *ctx = getContext(); 259ee49cda7SGuray Ozen ImplicitLocOpBuilder b(op.getLoc(), rewriter); 260894a591cSThomas Raoux 261894a591cSThomas Raoux // The result type of ldmatrix will always be a struct of 32bit integer 262894a591cSThomas Raoux // registers if more than one 32bit value is returned. Otherwise, the result 263894a591cSThomas Raoux // is a single i32. The result type of the GPU operation is always a vector 264894a591cSThomas Raoux // of shape (NumRegisters, VectorRegister) where VectorRegister is the 265894a591cSThomas Raoux // vector type of the result and always 32 bits long. We bitcast the result 266894a591cSThomas Raoux // of the NVVM::LdMatrix to this vector type. 
2675550c821STres Popp auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]); 268894a591cSThomas Raoux if (!vectorResultType) { 269894a591cSThomas Raoux return failure(); 270894a591cSThomas Raoux } 271894a591cSThomas Raoux Type innerVectorType = LLVM::getFixedVectorType( 272894a591cSThomas Raoux vectorResultType.getElementType(), vectorResultType.getDimSize(1)); 273894a591cSThomas Raoux 274894a591cSThomas Raoux int64_t num32BitRegs = vectorResultType.getDimSize(0); 275894a591cSThomas Raoux 276894a591cSThomas Raoux Type ldMatrixResultType; 277894a591cSThomas Raoux if (num32BitRegs > 1) { 278894a591cSThomas Raoux ldMatrixResultType = LLVM::LLVMStructType::getLiteral( 279894a591cSThomas Raoux ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type())); 280894a591cSThomas Raoux } else { 281894a591cSThomas Raoux ldMatrixResultType = rewriter.getI32Type(); 282894a591cSThomas Raoux } 283894a591cSThomas Raoux 2845550c821STres Popp auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType()); 2858df54a6aSJacques Pienaar Value srcPtr = 286ee49cda7SGuray Ozen getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(), 2878df54a6aSJacques Pienaar adaptor.getIndices(), rewriter); 288ee49cda7SGuray Ozen Value ldMatrixResult = b.create<NVVM::LdMatrixOp>( 289ee49cda7SGuray Ozen ldMatrixResultType, srcPtr, 2908df54a6aSJacques Pienaar /*num=*/op.getNumTiles(), 2918df54a6aSJacques Pienaar /*layout=*/op.getTranspose() ? NVVM::MMALayout::col 292894a591cSThomas Raoux : NVVM::MMALayout::row); 293894a591cSThomas Raoux 294894a591cSThomas Raoux // The ldmatrix operation returns either a single i32 value or a struct of 295894a591cSThomas Raoux // i32 values. Here we unpack those values and cast them back to their 296894a591cSThomas Raoux // actual vector type (still of width 32b) and repack them into a result 297894a591cSThomas Raoux // struct. 
298894a591cSThomas Raoux Type finalResultType = typeConverter->convertType(vectorResultType); 299ee49cda7SGuray Ozen Value result = b.create<LLVM::UndefOp>(finalResultType); 300894a591cSThomas Raoux for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) { 3015c5af910SJeff Niu Value i32Register = 302ee49cda7SGuray Ozen num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i) 303894a591cSThomas Raoux : ldMatrixResult; 304ee49cda7SGuray Ozen Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register); 305ee49cda7SGuray Ozen result = b.create<LLVM::InsertValueOp>(result, casted, i); 306894a591cSThomas Raoux } 307894a591cSThomas Raoux 308894a591cSThomas Raoux rewriter.replaceOp(op, result); 309894a591cSThomas Raoux return success(); 310894a591cSThomas Raoux } 311894a591cSThomas Raoux }; 312894a591cSThomas Raoux 313708185f0SChristopher Bate /// Convert the given type into the corresponding PTX type (NVVM::MMATypes 314708185f0SChristopher Bate /// enum). 315708185f0SChristopher Bate static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) { 316708185f0SChristopher Bate Type elType = getElementTypeOrSelf(t); 317708185f0SChristopher Bate if (elType.isInteger(8)) 318708185f0SChristopher Bate return NVVM::MMATypes::s8; 319708185f0SChristopher Bate if (elType.isInteger(4)) 320708185f0SChristopher Bate return NVVM::MMATypes::s4; 321708185f0SChristopher Bate if (elType.isF16()) 322708185f0SChristopher Bate return NVVM::MMATypes::f16; 323708185f0SChristopher Bate if (elType.isF64()) 324708185f0SChristopher Bate return NVVM::MMATypes::f64; 325708185f0SChristopher Bate if (elType.isF32()) 326708185f0SChristopher Bate return NVVM::MMATypes::tf32; 327708185f0SChristopher Bate return failure(); 328708185f0SChristopher Bate } 329708185f0SChristopher Bate 330894a591cSThomas Raoux struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> { 331894a591cSThomas Raoux using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern; 
332894a591cSThomas Raoux 333894a591cSThomas Raoux LogicalResult 334894a591cSThomas Raoux matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor, 335894a591cSThomas Raoux ConversionPatternRewriter &rewriter) const override { 336ee49cda7SGuray Ozen ImplicitLocOpBuilder b(op.getLoc(), rewriter); 337894a591cSThomas Raoux // Get the shapes of the MMAMatrix type being used. The shapes will 338894a591cSThomas Raoux // choose which intrinsic this op will be lowered to. 339708185f0SChristopher Bate VectorType aType = op.getMatrixA().getType(); 340708185f0SChristopher Bate VectorType bType = op.getMatrixA().getType(); 341708185f0SChristopher Bate VectorType cType = op.getMatrixC().getType(); 342894a591cSThomas Raoux 343708185f0SChristopher Bate std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray(); 34414d79afeSManish Gupta 34514d79afeSManish Gupta // Tensor Cores (mma.sync) on F32 works only with TensorFloat32 (TF32). 34614d79afeSManish Gupta bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName()); 34714d79afeSManish Gupta if (aType.getElementType().isF32() && !tf32Enabled) 34814d79afeSManish Gupta return failure(); 34998798073SChristopher Bate 350708185f0SChristopher Bate FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType); 351708185f0SChristopher Bate if (failed(ptxTypeA)) 352708185f0SChristopher Bate return op->emitOpError("failed to deduce operand PTX types"); 353708185f0SChristopher Bate FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType); 354708185f0SChristopher Bate if (failed(ptxTypeB)) 355708185f0SChristopher Bate return op->emitOpError("failed to deduce operand PTX types"); 35622426110SRamkumar Ramachandra std::optional<NVVM::MMATypes> ptxTypeC = 35722426110SRamkumar Ramachandra NVVM::MmaOp::inferOperandMMAType(cType.getElementType(), 35822426110SRamkumar Ramachandra /*isAccumulator=*/true); 359708185f0SChristopher Bate if (!ptxTypeC) 360708185f0SChristopher Bate return op->emitError( 361708185f0SChristopher Bate "could not infer the PTX type 
for the accumulator/result"); 362708185f0SChristopher Bate 363708185f0SChristopher Bate // TODO: add an attribute to the op to customize this behavior. 36422426110SRamkumar Ramachandra std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt); 3655550c821STres Popp if (isa<IntegerType>(aType.getElementType())) 366894a591cSThomas Raoux overflow = NVVM::MMAIntOverflow::satfinite; 367894a591cSThomas Raoux 36898798073SChristopher Bate SmallVector<Value> matA = 369ee49cda7SGuray Ozen unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA); 37098798073SChristopher Bate SmallVector<Value> matB = 371ee49cda7SGuray Ozen unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB); 37298798073SChristopher Bate SmallVector<Value> matC = 373ee49cda7SGuray Ozen unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC); 37498798073SChristopher Bate 375894a591cSThomas Raoux Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]); 376894a591cSThomas Raoux Type intrinsicResTy = inferIntrinsicResultType( 377894a591cSThomas Raoux typeConverter->convertType(op->getResultTypes()[0])); 378ee49cda7SGuray Ozen Value intrinsicResult = b.create<NVVM::MmaOp>( 379ee49cda7SGuray Ozen intrinsicResTy, matA, matB, matC, 380894a591cSThomas Raoux /*shape=*/gemmShape, 3811a36588eSKazu Hirata /*b1Op=*/std::nullopt, 382894a591cSThomas Raoux /*intOverflow=*/overflow, 383894a591cSThomas Raoux /*multiplicandPtxTypes=*/ 384708185f0SChristopher Bate std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB}, 385894a591cSThomas Raoux /*multiplicandLayouts=*/ 386894a591cSThomas Raoux std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row, 387894a591cSThomas Raoux NVVM::MMALayout::col}); 388894a591cSThomas Raoux rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, 389894a591cSThomas Raoux desiredRetTy, intrinsicResult, 390894a591cSThomas Raoux rewriter)); 391894a591cSThomas Raoux return success(); 392894a591cSThomas Raoux } 393894a591cSThomas Raoux }; 394894a591cSThomas Raoux 
395894a591cSThomas Raoux struct ConvertNVGPUToNVVMPass 39653689fdfSMarkus Böck : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> { 39753689fdfSMarkus Böck using Base::Base; 398894a591cSThomas Raoux 399affcfccdSGuray Ozen void getDependentDialects(DialectRegistry ®istry) const override { 400d20fbc90SGuray Ozen registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect, 401d20fbc90SGuray Ozen arith::ArithDialect>(); 402affcfccdSGuray Ozen } 403affcfccdSGuray Ozen 404894a591cSThomas Raoux void runOnOperation() override { 40553689fdfSMarkus Böck LowerToLLVMOptions options(&getContext()); 406894a591cSThomas Raoux RewritePatternSet patterns(&getContext()); 40753689fdfSMarkus Böck LLVMTypeConverter converter(&getContext(), options); 408affcfccdSGuray Ozen IRRewriter rewriter(&getContext()); 4093a03da37SGuray Ozen populateGpuMemorySpaceAttributeConversions( 4103a03da37SGuray Ozen converter, [](gpu::AddressSpace space) -> unsigned { 4113a03da37SGuray Ozen switch (space) { 4123a03da37SGuray Ozen case gpu::AddressSpace::Global: 4133a03da37SGuray Ozen return static_cast<unsigned>( 4143a03da37SGuray Ozen NVVM::NVVMMemorySpace::kGlobalMemorySpace); 4153a03da37SGuray Ozen case gpu::AddressSpace::Workgroup: 4163a03da37SGuray Ozen return static_cast<unsigned>( 4173a03da37SGuray Ozen NVVM::NVVMMemorySpace::kSharedMemorySpace); 4183a03da37SGuray Ozen case gpu::AddressSpace::Private: 4193a03da37SGuray Ozen return 0; 4203a03da37SGuray Ozen } 4213a03da37SGuray Ozen llvm_unreachable("unknown address space enum value"); 4223a03da37SGuray Ozen return 0; 4233a03da37SGuray Ozen }); 424affcfccdSGuray Ozen /// device-side async tokens cannot be materialized in nvvm. We just 425affcfccdSGuray Ozen /// convert them to a dummy i32 type in order to easily drop them during 426affcfccdSGuray Ozen /// conversion. 
42715bcc36eSThomas Raoux converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type { 42815bcc36eSThomas Raoux return converter.convertType(IntegerType::get(type.getContext(), 32)); 42915bcc36eSThomas Raoux }); 43023882226SGuray Ozen converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type { 43152db7e27SGuray Ozen Type elemType = type.getFragmented().getElementType(); 43252db7e27SGuray Ozen int64_t sizeM = type.getFragmented().getDimSize(0); 43352db7e27SGuray Ozen int64_t sizeN = type.getFragmented().getDimSize(1); 43452db7e27SGuray Ozen 43552db7e27SGuray Ozen unsigned numMembers; 43652db7e27SGuray Ozen if (elemType.isF32() || elemType.isInteger(32)) 43752db7e27SGuray Ozen numMembers = sizeN / 2; 43852db7e27SGuray Ozen else if (elemType.isF16()) 43952db7e27SGuray Ozen numMembers = sizeN / 4; 44052db7e27SGuray Ozen else 44152db7e27SGuray Ozen llvm_unreachable("unsupported type for warpgroup accumulator"); 44252db7e27SGuray Ozen 44352db7e27SGuray Ozen SmallVector<Type> innerStructBody; 44452db7e27SGuray Ozen for (unsigned i = 0; i < numMembers; i++) 44552db7e27SGuray Ozen innerStructBody.push_back(elemType); 44652db7e27SGuray Ozen auto innerStructType = 44752db7e27SGuray Ozen LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody); 44852db7e27SGuray Ozen 44923882226SGuray Ozen SmallVector<Type> structBody; 45052db7e27SGuray Ozen for (int i = 0; i < sizeM; i += kWgmmaSizeM) 45152db7e27SGuray Ozen structBody.push_back(innerStructType); 45252db7e27SGuray Ozen 45323882226SGuray Ozen auto convertedType = 45423882226SGuray Ozen LLVM::LLVMStructType::getLiteral(type.getContext(), structBody); 45523882226SGuray Ozen return converter.convertType(convertedType); 45623882226SGuray Ozen }); 457affcfccdSGuray Ozen converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type { 458affcfccdSGuray Ozen return converter.convertType(IntegerType::get(type.getContext(), 64)); 459affcfccdSGuray Ozen }); 46050ab427aSGuray Ozen 
converter.addConversion( 46150ab427aSGuray Ozen [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type { 46250ab427aSGuray Ozen return converter.convertType(IntegerType::get(type.getContext(), 64)); 46350ab427aSGuray Ozen }); 46417649a77SGuray Ozen converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type { 46599475f5bSNicolas Vasilache return converter.convertType( 46699475f5bSNicolas Vasilache nvgpu::getMBarrierMemrefType(rewriter.getContext(), type)); 467affcfccdSGuray Ozen }); 46870c2e061SGuray Ozen converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type { 4692f17c9f6SChristian Ulmann return LLVM::LLVMPointerType::get(type.getContext()); 47070c2e061SGuray Ozen }); 471894a591cSThomas Raoux populateNVGPUToNVVMConversionPatterns(converter, patterns); 472894a591cSThomas Raoux LLVMConversionTarget target(getContext()); 473894a591cSThomas Raoux target.addLegalDialect<::mlir::LLVM::LLVMDialect>(); 474d20fbc90SGuray Ozen target.addLegalDialect<::mlir::arith::ArithDialect>(); 475affcfccdSGuray Ozen target.addLegalDialect<::mlir::memref::MemRefDialect>(); 476894a591cSThomas Raoux target.addLegalDialect<::mlir::NVVM::NVVMDialect>(); 47723882226SGuray Ozen mlir::scf::populateSCFStructuralTypeConversionsAndLegality( 47823882226SGuray Ozen converter, patterns, target); 479894a591cSThomas Raoux if (failed(applyPartialConversion(getOperation(), target, 480894a591cSThomas Raoux std::move(patterns)))) 481894a591cSThomas Raoux signalPassFailure(); 482894a591cSThomas Raoux } 483894a591cSThomas Raoux }; 484894a591cSThomas Raoux 485708185f0SChristopher Bate /// Returns the constraints for the sparse MMA inline assembly instruction. 
static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
                                                     unsigned matBSize,
                                                     unsigned matCSize) {
  std::string constraints;

  // One "=r" output constraint per accumulator (matC) element.
  for (unsigned out = 0; out < matCSize; ++out)
    constraints += "=r,";

  // One "r" input constraint per matA, matB and matC element.
  const unsigned numInputs = matASize + matBSize + matCSize;
  for (unsigned in = 0; in < numInputs; ++in)
    constraints += "r,";

  // The final operand is for the sparsity metadata.
  // The sparsity selector appears as direct literal.
  constraints += "r";
  return constraints;
}

/// Returns the string for the `mma.sp.sync` instruction that corresponds to
/// the given parameters. Note that this function doesn't do any validation,
/// it's expected that the provided parameters correspond to a valid
/// instruction.
5054e4af133SAart Bik static std::string buildMmaSparseAsmString( 5064e4af133SAart Bik const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize, 5074e4af133SAart Bik unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB, 508708185f0SChristopher Bate NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD, 5094e4af133SAart Bik std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) { 510708185f0SChristopher Bate auto ptxTypeStr = [](NVVM::MMATypes ptxType) { 511708185f0SChristopher Bate return NVVM::stringifyMMATypes(ptxType); 512708185f0SChristopher Bate }; 513708185f0SChristopher Bate 514708185f0SChristopher Bate std::string asmStr; 515708185f0SChristopher Bate llvm::raw_string_ostream ss(asmStr); 516708185f0SChristopher Bate ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k" 517708185f0SChristopher Bate << shape[2] << ".row.col."; 518708185f0SChristopher Bate 519708185f0SChristopher Bate if (overflow) 520708185f0SChristopher Bate ss << NVVM::stringifyMMAIntOverflow(*overflow) << "."; 521708185f0SChristopher Bate 522708185f0SChristopher Bate ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "." 523708185f0SChristopher Bate << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " "; 524708185f0SChristopher Bate unsigned asmArgIdx = 0; 525708185f0SChristopher Bate 526708185f0SChristopher Bate // The operand string is structured into sections `{matC elements...}, 527708185f0SChristopher Bate // {matA elements...}, {matB elements...}, {matC elements}`. 528708185f0SChristopher Bate for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) { 529708185f0SChristopher Bate ss << "{"; 530708185f0SChristopher Bate for (unsigned i = 0; i < arrSize; i++) 531708185f0SChristopher Bate ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? 
"," : ""); 532708185f0SChristopher Bate ss << "},"; 533708185f0SChristopher Bate } 534b0bbc9b5Srkayaith ss << "$" << asmArgIdx++ << ","; 5354e4af133SAart Bik assert(metaDataSelector <= 1); 5364e4af133SAart Bik ss << "0x" << metaDataSelector << ";"; 537708185f0SChristopher Bate return asmStr; 538708185f0SChristopher Bate } 539708185f0SChristopher Bate 540708185f0SChristopher Bate /// Builds an inline assembly operation corresponding to the specified MMA 541708185f0SChristopher Bate /// sparse sync operation. 542708185f0SChristopher Bate static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm( 543ee49cda7SGuray Ozen ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB, 544708185f0SChristopher Bate NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD, 54522426110SRamkumar Ramachandra std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData, 546708185f0SChristopher Bate ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData, 547708185f0SChristopher Bate int64_t metadataSelector, const std::array<int64_t, 3> &shape, 548ee49cda7SGuray Ozen Type intrinsicResultType) { 549ee49cda7SGuray Ozen auto asmDialectAttr = 550ee49cda7SGuray Ozen LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT); 551708185f0SChristopher Bate 5524e4af133SAart Bik const unsigned matASize = unpackedAData.size(); 5534e4af133SAart Bik const unsigned matBSize = unpackedB.size(); 5544e4af133SAart Bik const unsigned matCSize = unpackedC.size(); 555708185f0SChristopher Bate 5564e4af133SAart Bik std::string asmStr = buildMmaSparseAsmString( 5574e4af133SAart Bik shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC, 5584e4af133SAart Bik ptxTypeD, overflow, metadataSelector); 5594e4af133SAart Bik std::string constraintStr = 5604e4af133SAart Bik buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize); 561708185f0SChristopher Bate 562708185f0SChristopher Bate SmallVector<Value> asmVals; 5634e4af133SAart Bik 
asmVals.reserve(matASize + matBSize + matCSize + 1); 564708185f0SChristopher Bate for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC}) 565708185f0SChristopher Bate llvm::append_range(asmVals, args); 566708185f0SChristopher Bate asmVals.push_back(indexData); 567708185f0SChristopher Bate 568ee49cda7SGuray Ozen return b.create<LLVM::InlineAsmOp>( 569708185f0SChristopher Bate /*resultTypes=*/intrinsicResultType, 570708185f0SChristopher Bate /*operands=*/asmVals, 571708185f0SChristopher Bate /*asm_string=*/asmStr, 572708185f0SChristopher Bate /*constraints=*/constraintStr, 573708185f0SChristopher Bate /*has_side_effects=*/true, 574708185f0SChristopher Bate /*is_align_stack=*/false, 575708185f0SChristopher Bate /*asm_dialect=*/asmDialectAttr, 576708185f0SChristopher Bate /*operand_attrs=*/ArrayAttr()); 577708185f0SChristopher Bate } 578708185f0SChristopher Bate 579708185f0SChristopher Bate /// Lowers `nvgpu.mma.sp.sync` to inline assembly. 580708185f0SChristopher Bate struct NVGPUMmaSparseSyncLowering 581708185f0SChristopher Bate : public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> { 582708185f0SChristopher Bate using ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp>::ConvertOpToLLVMPattern; 583708185f0SChristopher Bate 584708185f0SChristopher Bate LogicalResult 585708185f0SChristopher Bate matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor, 586708185f0SChristopher Bate ConversionPatternRewriter &rewriter) const override { 587ee49cda7SGuray Ozen ImplicitLocOpBuilder b(op.getLoc(), rewriter); 588708185f0SChristopher Bate // Get the shapes of the MMAMatrix type being used. The shapes will 589708185f0SChristopher Bate // choose which intrinsic this op will be lowered to. 
590708185f0SChristopher Bate VectorType aType = op.getMatrixA().getType(); 591708185f0SChristopher Bate VectorType bType = op.getMatrixB().getType(); 592708185f0SChristopher Bate VectorType cType = op.getMatrixC().getType(); 593708185f0SChristopher Bate 594708185f0SChristopher Bate FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType); 595708185f0SChristopher Bate if (failed(ptxTypeA)) 596708185f0SChristopher Bate return op->emitOpError("failed to deduce operand PTX types"); 597708185f0SChristopher Bate FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType); 598708185f0SChristopher Bate if (failed(ptxTypeB)) 599708185f0SChristopher Bate return op->emitOpError("failed to deduce operand PTX types"); 60022426110SRamkumar Ramachandra std::optional<NVVM::MMATypes> ptxTypeC = 60122426110SRamkumar Ramachandra NVVM::MmaOp::inferOperandMMAType(cType.getElementType(), 60222426110SRamkumar Ramachandra /*isAccumulator=*/true); 603708185f0SChristopher Bate if (!ptxTypeC) 604708185f0SChristopher Bate return op->emitError( 605708185f0SChristopher Bate "could not infer the PTX type for the accumulator/result"); 606708185f0SChristopher Bate 607708185f0SChristopher Bate // Same as `mma.sync`, F32 works only with TensorFloat32 (TF32). 608708185f0SChristopher Bate bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName()); 609708185f0SChristopher Bate if (aType.getElementType().isF32() && !tf32Enabled) 610708185f0SChristopher Bate return failure(); 611708185f0SChristopher Bate 612708185f0SChristopher Bate // TODO: add an attribute to the op to customize this behavior. 
61322426110SRamkumar Ramachandra std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt); 6145550c821STres Popp if (isa<IntegerType>(aType.getElementType())) 615708185f0SChristopher Bate overflow = NVVM::MMAIntOverflow::satfinite; 616708185f0SChristopher Bate 617708185f0SChristopher Bate SmallVector<Value> matA = 618ee49cda7SGuray Ozen unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA); 619708185f0SChristopher Bate SmallVector<Value> matB = 620ee49cda7SGuray Ozen unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB); 621708185f0SChristopher Bate SmallVector<Value> matC = 622ee49cda7SGuray Ozen unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC); 623708185f0SChristopher Bate 624708185f0SChristopher Bate Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]); 625708185f0SChristopher Bate Type intrinsicResTy = inferIntrinsicResultType( 626708185f0SChristopher Bate typeConverter->convertType(op->getResultTypes()[0])); 627708185f0SChristopher Bate 628708185f0SChristopher Bate // Bitcast the sparse metadata from vector<2xf16> to an i32. 
629708185f0SChristopher Bate Value sparseMetadata = adaptor.getSparseMetadata(); 630708185f0SChristopher Bate if (sparseMetadata.getType() != 631708185f0SChristopher Bate LLVM::getFixedVectorType(rewriter.getI16Type(), 2)) 632708185f0SChristopher Bate return op->emitOpError() << "Expected metadata type to be LLVM " 633708185f0SChristopher Bate "VectorType of 2 i16 elements"; 634ee49cda7SGuray Ozen sparseMetadata = 635ee49cda7SGuray Ozen b.create<LLVM::BitcastOp>(rewriter.getI32Type(), sparseMetadata); 636708185f0SChristopher Bate 637708185f0SChristopher Bate FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm( 638ee49cda7SGuray Ozen b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB, 639708185f0SChristopher Bate matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(), 640ee49cda7SGuray Ozen intrinsicResTy); 641708185f0SChristopher Bate if (failed(intrinsicResult)) 642708185f0SChristopher Bate return failure(); 643708185f0SChristopher Bate 644708185f0SChristopher Bate assert((*intrinsicResult).getNumResults() == 1 && 645708185f0SChristopher Bate "expected inline asm op returns a single LLVM struct type"); 646708185f0SChristopher Bate rewriter.replaceOp( 647708185f0SChristopher Bate op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy, 648708185f0SChristopher Bate (*intrinsicResult)->getResult(0), rewriter)); 649708185f0SChristopher Bate return success(); 650708185f0SChristopher Bate } 651708185f0SChristopher Bate }; 652708185f0SChristopher Bate 65315bcc36eSThomas Raoux struct NVGPUAsyncCopyLowering 65415bcc36eSThomas Raoux : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> { 65515bcc36eSThomas Raoux using ConvertOpToLLVMPattern< 65615bcc36eSThomas Raoux nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern; 65715bcc36eSThomas Raoux 65815bcc36eSThomas Raoux LogicalResult 65915bcc36eSThomas Raoux matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor, 66015bcc36eSThomas Raoux 
ConversionPatternRewriter &rewriter) const override { 661ee49cda7SGuray Ozen ImplicitLocOpBuilder b(op.getLoc(), rewriter); 662ee49cda7SGuray Ozen Location loc = op.getLoc(); 6635550c821STres Popp auto dstMemrefType = cast<MemRefType>(op.getDst().getType()); 664ee49cda7SGuray Ozen Value dstPtr = 665ee49cda7SGuray Ozen getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(), 6668df54a6aSJacques Pienaar adaptor.getDstIndices(), rewriter); 667499abb24SKrzysztof Drewniak FailureOr<unsigned> dstAddressSpace = 668499abb24SKrzysztof Drewniak getTypeConverter()->getMemRefAddressSpace(dstMemrefType); 669499abb24SKrzysztof Drewniak if (failed(dstAddressSpace)) 670499abb24SKrzysztof Drewniak return rewriter.notifyMatchFailure( 671499abb24SKrzysztof Drewniak loc, "destination memref address space not convertible to integer"); 67215bcc36eSThomas Raoux 6735550c821STres Popp auto srcMemrefType = cast<MemRefType>(op.getSrc().getType()); 674499abb24SKrzysztof Drewniak FailureOr<unsigned> srcAddressSpace = 675499abb24SKrzysztof Drewniak getTypeConverter()->getMemRefAddressSpace(srcMemrefType); 676499abb24SKrzysztof Drewniak if (failed(srcAddressSpace)) 677499abb24SKrzysztof Drewniak return rewriter.notifyMatchFailure( 678499abb24SKrzysztof Drewniak loc, "source memref address space not convertible to integer"); 67915bcc36eSThomas Raoux 6808df54a6aSJacques Pienaar Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(), 6818df54a6aSJacques Pienaar adaptor.getSrcIndices(), rewriter); 68215bcc36eSThomas Raoux // Intrinsics takes a global pointer so we need an address space cast. 
6832f17c9f6SChristian Ulmann auto srcPointerGlobalType = LLVM::LLVMPointerType::get( 6842f17c9f6SChristian Ulmann op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace); 685ee49cda7SGuray Ozen scrPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, scrPtr); 686fbf69f95SManish Gupta int64_t dstElements = adaptor.getDstElements().getZExtValue(); 68715bcc36eSThomas Raoux int64_t sizeInBytes = 688fbf69f95SManish Gupta (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8; 689fbf69f95SManish Gupta // When the optional SrcElements argument is *not* present, the regular 690fbf69f95SManish Gupta // CpAsyncOp is generated. CopyAsyncOp reads bytes from source (global 6912c573967SGuray Ozen // memory) to fill DstElements number of elements in the destination 6922c573967SGuray Ozen // (shared memory). 6932c573967SGuray Ozen Value srcBytes = adaptor.getSrcElements(); 6942c573967SGuray Ozen if (srcBytes) { 6952c573967SGuray Ozen // When the optional SrcElements argument is present, the source (global 6962c573967SGuray Ozen // memory) of CpAsyncOp is read only for SrcElements number of elements. 6972c573967SGuray Ozen // The rest of the DstElements in the destination (shared memory) are 6982c573967SGuray Ozen // filled with zeros. 699ee49cda7SGuray Ozen Value c3I32 = 700ee49cda7SGuray Ozen b.create<LLVM::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(3)); 701ee49cda7SGuray Ozen Value bitwidth = b.create<LLVM::ConstantOp>( 702ee49cda7SGuray Ozen b.getI32Type(), 703ee49cda7SGuray Ozen b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth())); 704ee49cda7SGuray Ozen Value srcElementsI32 = b.create<LLVM::TruncOp>(b.getI32Type(), srcBytes); 705ee49cda7SGuray Ozen srcBytes = b.create<LLVM::LShrOp>( 706ee49cda7SGuray Ozen b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32); 7072c573967SGuray Ozen } 7082c573967SGuray Ozen // Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than 7092c573967SGuray Ozen // 16 dst bytes. 
7102c573967SGuray Ozen NVVM::LoadCacheModifierKind cacheModifier = 7112c573967SGuray Ozen (op.getBypassL1().value_or(false) && sizeInBytes == 16) 7122c573967SGuray Ozen ? NVVM::LoadCacheModifierKind::CG 7132c573967SGuray Ozen : NVVM::LoadCacheModifierKind::CA; 7142c573967SGuray Ozen 715ee49cda7SGuray Ozen b.create<NVVM::CpAsyncOp>( 716ee49cda7SGuray Ozen dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes), 7172c573967SGuray Ozen NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier), 7182c573967SGuray Ozen srcBytes); 71915bcc36eSThomas Raoux 72015bcc36eSThomas Raoux // Drop the result token. 721ee49cda7SGuray Ozen Value zero = b.create<LLVM::ConstantOp>( 722ee49cda7SGuray Ozen IntegerType::get(op.getContext(), 32), rewriter.getI32IntegerAttr(0)); 72315bcc36eSThomas Raoux rewriter.replaceOp(op, zero); 72415bcc36eSThomas Raoux return success(); 72515bcc36eSThomas Raoux } 72615bcc36eSThomas Raoux }; 72715bcc36eSThomas Raoux 72815bcc36eSThomas Raoux struct NVGPUAsyncCreateGroupLowering 72915bcc36eSThomas Raoux : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> { 73015bcc36eSThomas Raoux using ConvertOpToLLVMPattern< 73115bcc36eSThomas Raoux nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern; 73215bcc36eSThomas Raoux 73315bcc36eSThomas Raoux LogicalResult 73415bcc36eSThomas Raoux matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor, 73515bcc36eSThomas Raoux ConversionPatternRewriter &rewriter) const override { 73615bcc36eSThomas Raoux rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc()); 73715bcc36eSThomas Raoux // Drop the result token. 
73815bcc36eSThomas Raoux Value zero = rewriter.create<LLVM::ConstantOp>( 73915bcc36eSThomas Raoux op->getLoc(), IntegerType::get(op.getContext(), 32), 74015bcc36eSThomas Raoux rewriter.getI32IntegerAttr(0)); 74115bcc36eSThomas Raoux rewriter.replaceOp(op, zero); 74215bcc36eSThomas Raoux return success(); 74315bcc36eSThomas Raoux } 74415bcc36eSThomas Raoux }; 74515bcc36eSThomas Raoux 74615bcc36eSThomas Raoux struct NVGPUAsyncWaitLowering 74715bcc36eSThomas Raoux : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> { 74815bcc36eSThomas Raoux using ConvertOpToLLVMPattern< 74915bcc36eSThomas Raoux nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern; 75015bcc36eSThomas Raoux 75115bcc36eSThomas Raoux LogicalResult 75215bcc36eSThomas Raoux matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor, 75315bcc36eSThomas Raoux ConversionPatternRewriter &rewriter) const override { 75415bcc36eSThomas Raoux // If numGroup is not present pick 0 as a conservative correct value. 7552789c4f5SKazu Hirata int32_t numGroups = adaptor.getNumGroups().value_or(0); 75615bcc36eSThomas Raoux rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups); 75715bcc36eSThomas Raoux rewriter.eraseOp(op); 75815bcc36eSThomas Raoux return success(); 75915bcc36eSThomas Raoux } 76015bcc36eSThomas Raoux }; 76115bcc36eSThomas Raoux 762affcfccdSGuray Ozen /// Creates mbarrier object in shared memory 763affcfccdSGuray Ozen struct NVGPUMBarrierCreateLowering 764affcfccdSGuray Ozen : public ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp> { 765affcfccdSGuray Ozen using ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp>::ConvertOpToLLVMPattern; 766affcfccdSGuray Ozen 767affcfccdSGuray Ozen template <typename moduleT> 768affcfccdSGuray Ozen memref::GlobalOp generateGlobalBarrier(ConversionPatternRewriter &rewriter, 769affcfccdSGuray Ozen Operation *funcOp, moduleT moduleOp, 770affcfccdSGuray Ozen MemRefType barrierType) const { 771affcfccdSGuray Ozen SymbolTable symbolTable(moduleOp); 772affcfccdSGuray 
Ozen OpBuilder::InsertionGuard guard(rewriter); 773affcfccdSGuray Ozen rewriter.setInsertionPoint(&moduleOp.front()); 774affcfccdSGuray Ozen auto global = rewriter.create<memref::GlobalOp>( 775affcfccdSGuray Ozen funcOp->getLoc(), "__mbarrier", 776affcfccdSGuray Ozen /*sym_visibility=*/rewriter.getStringAttr("private"), 777affcfccdSGuray Ozen /*type=*/barrierType, 778affcfccdSGuray Ozen /*initial_value=*/ElementsAttr(), 779affcfccdSGuray Ozen /*constant=*/false, 780affcfccdSGuray Ozen /*alignment=*/rewriter.getI64IntegerAttr(8)); 781affcfccdSGuray Ozen symbolTable.insert(global); 782affcfccdSGuray Ozen return global; 783affcfccdSGuray Ozen } 784affcfccdSGuray Ozen 785affcfccdSGuray Ozen LogicalResult 786affcfccdSGuray Ozen matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor, 787affcfccdSGuray Ozen ConversionPatternRewriter &rewriter) const override { 788affcfccdSGuray Ozen Operation *funcOp = op->getParentOp(); 78999475f5bSNicolas Vasilache MemRefType barrierType = nvgpu::getMBarrierMemrefType( 79017649a77SGuray Ozen rewriter.getContext(), op.getBarriers().getType()); 791affcfccdSGuray Ozen 792affcfccdSGuray Ozen memref::GlobalOp global; 7939dad32cbSGuray Ozen if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>()) 794affcfccdSGuray Ozen global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType); 7959dad32cbSGuray Ozen else if (auto moduleOp = funcOp->getParentOfType<ModuleOp>()) 796affcfccdSGuray Ozen global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType); 797affcfccdSGuray Ozen 798affcfccdSGuray Ozen rewriter.setInsertionPoint(op); 799affcfccdSGuray Ozen rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, barrierType, 800affcfccdSGuray Ozen global.getName()); 801affcfccdSGuray Ozen return success(); 802affcfccdSGuray Ozen } 803affcfccdSGuray Ozen }; 804affcfccdSGuray Ozen 80517649a77SGuray Ozen /// Base class for lowering mbarrier operations to nvvm intrinsics. 
80617649a77SGuray Ozen template <typename SourceOp> 80717649a77SGuray Ozen struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> { 80817649a77SGuray Ozen public: 80917649a77SGuray Ozen using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern; 81017649a77SGuray Ozen /// Returns the base pointer of the mbarrier object. 811ee49cda7SGuray Ozen Value getMbarrierPtr(ImplicitLocOpBuilder &b, 812ee49cda7SGuray Ozen nvgpu::MBarrierGroupType mbarType, Value memrefDesc, 813ee49cda7SGuray Ozen Value mbarId, 81417649a77SGuray Ozen ConversionPatternRewriter &rewriter) const { 81517649a77SGuray Ozen MemRefType mbarrierMemrefType = 81617649a77SGuray Ozen nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType); 81717649a77SGuray Ozen return ConvertToLLVMPattern::getStridedElementPtr( 818ee49cda7SGuray Ozen b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter); 81917649a77SGuray Ozen } 82017649a77SGuray Ozen }; 82117649a77SGuray Ozen 822affcfccdSGuray Ozen /// Lowers `nvgpu.mbarrier.init` to `nvvm.mbarrier.init` 823affcfccdSGuray Ozen struct NVGPUMBarrierInitLowering 82417649a77SGuray Ozen : public MBarrierBasePattern<nvgpu::MBarrierInitOp> { 82517649a77SGuray Ozen using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern; 826affcfccdSGuray Ozen 827affcfccdSGuray Ozen LogicalResult 828affcfccdSGuray Ozen matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor, 829affcfccdSGuray Ozen ConversionPatternRewriter &rewriter) const override { 830ee49cda7SGuray Ozen ImplicitLocOpBuilder b(op->getLoc(), rewriter); 83117649a77SGuray Ozen nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType(); 832affcfccdSGuray Ozen rewriter.setInsertionPoint(op); 833ee49cda7SGuray Ozen Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(), 83417649a77SGuray Ozen adaptor.getMbarId(), rewriter); 835ee49cda7SGuray Ozen Value count = truncToI32(b, adaptor.getCount()); 83617649a77SGuray Ozen if (isMbarrierShared(mbarrierType)) { 
837192d3320SGuray Ozen rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>( 838192d3320SGuray Ozen op, barrier, count, adaptor.getPredicate()); 839affcfccdSGuray Ozen } else { 84063389326SGuray Ozen rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count, 841192d3320SGuray Ozen adaptor.getPredicate()); 842affcfccdSGuray Ozen } 843affcfccdSGuray Ozen return success(); 844affcfccdSGuray Ozen } 845affcfccdSGuray Ozen }; 846affcfccdSGuray Ozen 847affcfccdSGuray Ozen /// Lowers `nvgpu.mbarrier.arrive` to `nvvm.mbarrier.arrive` 848affcfccdSGuray Ozen struct NVGPUMBarrierArriveLowering 84917649a77SGuray Ozen : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> { 85017649a77SGuray Ozen using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern; 851affcfccdSGuray Ozen LogicalResult 852affcfccdSGuray Ozen matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor, 853affcfccdSGuray Ozen ConversionPatternRewriter &rewriter) const override { 854ee49cda7SGuray Ozen ImplicitLocOpBuilder b(op->getLoc(), rewriter); 85517649a77SGuray Ozen Value barrier = 856ee49cda7SGuray Ozen getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), 85717649a77SGuray Ozen adaptor.getMbarId(), rewriter); 858affcfccdSGuray Ozen Type tokenType = getTypeConverter()->convertType( 859affcfccdSGuray Ozen nvgpu::MBarrierTokenType::get(op->getContext())); 86017649a77SGuray Ozen if (isMbarrierShared(op.getBarriers().getType())) { 861affcfccdSGuray Ozen rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveSharedOp>(op, tokenType, 862affcfccdSGuray Ozen barrier); 863affcfccdSGuray Ozen } else { 864affcfccdSGuray Ozen rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType, 865affcfccdSGuray Ozen barrier); 866affcfccdSGuray Ozen } 867affcfccdSGuray Ozen return success(); 868affcfccdSGuray Ozen } 869affcfccdSGuray Ozen }; 870affcfccdSGuray Ozen 871affcfccdSGuray Ozen /// Lowers `nvgpu.mbarrier.arrive.nocomplete` to 872affcfccdSGuray Ozen /// 
`nvvm.mbarrier.arrive.nocomplete` 873affcfccdSGuray Ozen struct NVGPUMBarrierArriveNoCompleteLowering 87417649a77SGuray Ozen : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> { 87517649a77SGuray Ozen using MBarrierBasePattern< 87617649a77SGuray Ozen nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern; 877affcfccdSGuray Ozen LogicalResult 878affcfccdSGuray Ozen matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor, 879affcfccdSGuray Ozen ConversionPatternRewriter &rewriter) const override { 880ee49cda7SGuray Ozen ImplicitLocOpBuilder b(op->getLoc(), rewriter); 88117649a77SGuray Ozen Value barrier = 882ee49cda7SGuray Ozen getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), 88317649a77SGuray Ozen adaptor.getMbarId(), rewriter); 884affcfccdSGuray Ozen Type tokenType = getTypeConverter()->convertType( 885affcfccdSGuray Ozen nvgpu::MBarrierTokenType::get(op->getContext())); 886ee49cda7SGuray Ozen Value count = truncToI32(b, adaptor.getCount()); 88717649a77SGuray Ozen if (isMbarrierShared(op.getBarriers().getType())) { 888affcfccdSGuray Ozen rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>( 889affcfccdSGuray Ozen op, tokenType, barrier, count); 890affcfccdSGuray Ozen } else { 891affcfccdSGuray Ozen rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>( 892affcfccdSGuray Ozen op, tokenType, barrier, count); 893affcfccdSGuray Ozen } 894affcfccdSGuray Ozen return success(); 895affcfccdSGuray Ozen } 896affcfccdSGuray Ozen }; 897affcfccdSGuray Ozen 898affcfccdSGuray Ozen /// Lowers `nvgpu.mbarrier.test.wait` to `nvvm.mbarrier.test.wait` 899affcfccdSGuray Ozen struct NVGPUMBarrierTestWaitLowering 90017649a77SGuray Ozen : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> { 90117649a77SGuray Ozen using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern; 902affcfccdSGuray Ozen LogicalResult 903affcfccdSGuray Ozen matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor, 
904affcfccdSGuray Ozen ConversionPatternRewriter &rewriter) const override { 905ee49cda7SGuray Ozen ImplicitLocOpBuilder b(op->getLoc(), rewriter); 90617649a77SGuray Ozen Value barrier = 907ee49cda7SGuray Ozen getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), 90817649a77SGuray Ozen adaptor.getMbarId(), rewriter); 909affcfccdSGuray Ozen Type retType = rewriter.getI1Type(); 91017649a77SGuray Ozen if (isMbarrierShared(op.getBarriers().getType())) { 911affcfccdSGuray Ozen rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitSharedOp>( 912affcfccdSGuray Ozen op, retType, barrier, adaptor.getToken()); 913affcfccdSGuray Ozen } else { 914affcfccdSGuray Ozen rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>( 915affcfccdSGuray Ozen op, retType, barrier, adaptor.getToken()); 916affcfccdSGuray Ozen } 917affcfccdSGuray Ozen return success(); 918affcfccdSGuray Ozen } 919affcfccdSGuray Ozen }; 920affcfccdSGuray Ozen 921836dbb85SGuray Ozen struct NVGPUMBarrierArriveExpectTxLowering 92217649a77SGuray Ozen : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> { 92317649a77SGuray Ozen using MBarrierBasePattern< 92417649a77SGuray Ozen nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern; 925836dbb85SGuray Ozen LogicalResult 926836dbb85SGuray Ozen matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor, 927836dbb85SGuray Ozen ConversionPatternRewriter &rewriter) const override { 928ee49cda7SGuray Ozen ImplicitLocOpBuilder b(op->getLoc(), rewriter); 92917649a77SGuray Ozen Value barrier = 930ee49cda7SGuray Ozen getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), 93117649a77SGuray Ozen adaptor.getMbarId(), rewriter); 932ee49cda7SGuray Ozen Value txcount = truncToI32(b, adaptor.getTxcount()); 933836dbb85SGuray Ozen 93417649a77SGuray Ozen if (isMbarrierShared(op.getBarriers().getType())) { 935836dbb85SGuray Ozen rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>( 936192d3320SGuray Ozen op, barrier, txcount, 
adaptor.getPredicate()); 937836dbb85SGuray Ozen return success(); 938836dbb85SGuray Ozen } 939836dbb85SGuray Ozen 94063389326SGuray Ozen rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>( 941192d3320SGuray Ozen op, barrier, txcount, adaptor.getPredicate()); 942836dbb85SGuray Ozen return success(); 943836dbb85SGuray Ozen } 944836dbb85SGuray Ozen }; 945836dbb85SGuray Ozen 946836dbb85SGuray Ozen struct NVGPUMBarrierTryWaitParityLowering 94717649a77SGuray Ozen : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> { 94817649a77SGuray Ozen using MBarrierBasePattern< 94917649a77SGuray Ozen nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern; 950836dbb85SGuray Ozen LogicalResult 951836dbb85SGuray Ozen matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor, 952836dbb85SGuray Ozen ConversionPatternRewriter &rewriter) const override { 953ee49cda7SGuray Ozen ImplicitLocOpBuilder b(op->getLoc(), rewriter); 95417649a77SGuray Ozen Value barrier = 955ee49cda7SGuray Ozen getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), 95617649a77SGuray Ozen adaptor.getMbarId(), rewriter); 957ee49cda7SGuray Ozen Value ticks = truncToI32(b, adaptor.getTicks()); 9580a600c34SGuray Ozen Value phase = 9590a600c34SGuray Ozen b.create<LLVM::ZExtOp>(b.getI32Type(), adaptor.getPhaseParity()); 960836dbb85SGuray Ozen 96117649a77SGuray Ozen if (isMbarrierShared(op.getBarriers().getType())) { 962836dbb85SGuray Ozen rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>( 963836dbb85SGuray Ozen op, barrier, phase, ticks); 964836dbb85SGuray Ozen return success(); 965836dbb85SGuray Ozen } 966836dbb85SGuray Ozen 967836dbb85SGuray Ozen rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier, 968836dbb85SGuray Ozen phase, ticks); 969836dbb85SGuray Ozen return success(); 970836dbb85SGuray Ozen } 971836dbb85SGuray Ozen }; 972836dbb85SGuray Ozen 97370c2e061SGuray Ozen struct NVGPUTmaAsyncLoadOpLowering 97417649a77SGuray Ozen : public 
MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> { 97517649a77SGuray Ozen using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern; 97670c2e061SGuray Ozen LogicalResult 97770c2e061SGuray Ozen matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor, 97870c2e061SGuray Ozen ConversionPatternRewriter &rewriter) const override { 979ee49cda7SGuray Ozen ImplicitLocOpBuilder b(op->getLoc(), rewriter); 98050a76a7dSGuray Ozen auto srcMemrefType = cast<MemRefType>(op.getDst().getType()); 98150a76a7dSGuray Ozen Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType, 98250a76a7dSGuray Ozen adaptor.getDst(), {}, rewriter); 98317649a77SGuray Ozen Value barrier = 984ee49cda7SGuray Ozen getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), 98517649a77SGuray Ozen adaptor.getMbarId(), rewriter); 98670c2e061SGuray Ozen 98770c2e061SGuray Ozen SmallVector<Value> coords = adaptor.getCoordinates(); 98870c2e061SGuray Ozen for (auto [index, value] : llvm::enumerate(coords)) { 989ee49cda7SGuray Ozen coords[index] = truncToI32(b, value); 99070c2e061SGuray Ozen } 99170c2e061SGuray Ozen rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>( 9929ceea088SGuray Ozen op, dest, adaptor.getTensorMapDescriptor(), coords, barrier, 9934319e191SGuray Ozen ValueRange{}, adaptor.getMulticastMask(), Value{}, 9944319e191SGuray Ozen adaptor.getPredicate()); 99570c2e061SGuray Ozen return success(); 99670c2e061SGuray Ozen } 99770c2e061SGuray Ozen }; 9988dd0d95cSGuray Ozen 9998dd0d95cSGuray Ozen struct NVGPUTmaAsyncStoreOpLowering 10008dd0d95cSGuray Ozen : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> { 10018dd0d95cSGuray Ozen using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern; 10028dd0d95cSGuray Ozen LogicalResult 10038dd0d95cSGuray Ozen matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor, 10048dd0d95cSGuray Ozen ConversionPatternRewriter &rewriter) const override { 10058dd0d95cSGuray Ozen ImplicitLocOpBuilder 
b(op->getLoc(), rewriter); 10068dd0d95cSGuray Ozen auto srcMemrefType = cast<MemRefType>(op.getSrc().getType()); 10078dd0d95cSGuray Ozen Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType, 10088dd0d95cSGuray Ozen adaptor.getSrc(), {}, rewriter); 10098dd0d95cSGuray Ozen SmallVector<Value> coords = adaptor.getCoordinates(); 10108dd0d95cSGuray Ozen for (auto [index, value] : llvm::enumerate(coords)) { 10118dd0d95cSGuray Ozen coords[index] = truncToI32(b, value); 10128dd0d95cSGuray Ozen } 10138dd0d95cSGuray Ozen 10148dd0d95cSGuray Ozen rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>( 10158dd0d95cSGuray Ozen op, adaptor.getTensorMapDescriptor(), dest, coords, 10168dd0d95cSGuray Ozen adaptor.getPredicate()); 10178dd0d95cSGuray Ozen return success(); 10188dd0d95cSGuray Ozen } 10198dd0d95cSGuray Ozen }; 10208dd0d95cSGuray Ozen 10216dc7717bSGuray Ozen struct NVGPUGenerateWarpgroupDescriptorLowering 10227eb2b99fSGuray Ozen : public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> { 1023cce3e8edSGuray Ozen using ConvertOpToLLVMPattern< 10247eb2b99fSGuray Ozen nvgpu::WarpgroupGenerateDescriptorOp>::ConvertOpToLLVMPattern; 1025cce3e8edSGuray Ozen 1026cce3e8edSGuray Ozen LogicalResult 10277eb2b99fSGuray Ozen matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor, 1028cce3e8edSGuray Ozen ConversionPatternRewriter &rewriter) const override { 1029cce3e8edSGuray Ozen 1030ee49cda7SGuray Ozen ImplicitLocOpBuilder b(op->getLoc(), rewriter); 1031cce3e8edSGuray Ozen 1032cce3e8edSGuray Ozen nvgpu::TensorMapSwizzleKind swizzleKind = 1033cce3e8edSGuray Ozen op.getTensorMap().getType().getSwizzle(); 1034cce3e8edSGuray Ozen 1035cce3e8edSGuray Ozen unsigned layout = 1036cce3e8edSGuray Ozen (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128 1037cce3e8edSGuray Ozen : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64 1038cce3e8edSGuray Ozen : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 
32 1039cce3e8edSGuray Ozen : 1; 1040cce3e8edSGuray Ozen unsigned swizzle = 1041cce3e8edSGuray Ozen (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1 1042cce3e8edSGuray Ozen : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2 1043cce3e8edSGuray Ozen : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3 1044cce3e8edSGuray Ozen : 0; 1045cce3e8edSGuray Ozen 1046ee49cda7SGuray Ozen auto ti64 = b.getIntegerType(64); 1047cce3e8edSGuray Ozen auto makeConst = [&](uint64_t index) -> Value { 1048ee49cda7SGuray Ozen return b.create<LLVM::ConstantOp>(ti64, b.getI64IntegerAttr(index)); 1049cce3e8edSGuray Ozen }; 1050cce3e8edSGuray Ozen auto shiftLeft = [&](Value value, unsigned shift) -> Value { 1051ee49cda7SGuray Ozen return b.create<LLVM::ShlOp>(ti64, value, makeConst(shift)); 1052cce3e8edSGuray Ozen }; 1053cce3e8edSGuray Ozen auto shiftRight = [&](Value value, unsigned shift) -> Value { 1054ee49cda7SGuray Ozen return b.create<LLVM::LShrOp>(ti64, value, makeConst(shift)); 1055cce3e8edSGuray Ozen }; 1056cce3e8edSGuray Ozen auto insertBit = [&](Value desc, Value val, int startBit) { 1057ee49cda7SGuray Ozen return b.create<LLVM::OrOp>(ti64, desc, shiftLeft(val, startBit)); 1058cce3e8edSGuray Ozen }; 1059cce3e8edSGuray Ozen 1060cce3e8edSGuray Ozen int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0); 106123882226SGuray Ozen uint64_t strideDimVal = (layout << 3) >> exclude4LSB; 106223882226SGuray Ozen uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB; 1063b96d0693SGuray Ozen uint64_t offsetVal = 0; 1064b96d0693SGuray Ozen 1065b96d0693SGuray Ozen Value strideDim = makeConst(strideDimVal); 1066b96d0693SGuray Ozen Value leadDim = makeConst(leadDimVal); 1067b96d0693SGuray Ozen 1068cce3e8edSGuray Ozen Value baseAddr = getStridedElementPtr( 1069cce3e8edSGuray Ozen op->getLoc(), cast<MemRefType>(op.getTensor().getType()), 1070cce3e8edSGuray Ozen adaptor.getTensor(), {}, rewriter); 1071ee49cda7SGuray Ozen Value basePtr = 
b.create<LLVM::PtrToIntOp>(ti64, baseAddr); 1072cce3e8edSGuray Ozen // Just use 14 bits for base address 1073cce3e8edSGuray Ozen Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50); 1074cce3e8edSGuray Ozen 1075cce3e8edSGuray Ozen int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32, 1076cce3e8edSGuray Ozen startLeadBit = 16, startBaseAddrBit = 0; 1077cce3e8edSGuray Ozen Value dsc = makeConst(0); 1078cce3e8edSGuray Ozen // // [62,64) swizzle type 1079cce3e8edSGuray Ozen dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit); 1080cce3e8edSGuray Ozen // // [49,52) base_offset 1081b96d0693SGuray Ozen dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit); 1082cce3e8edSGuray Ozen // // [32,46) stride 1083cce3e8edSGuray Ozen dsc = insertBit(dsc, strideDim, startStrideBit); 1084cce3e8edSGuray Ozen // // [16,30) leading dimension 1085cce3e8edSGuray Ozen dsc = insertBit(dsc, leadDim, startLeadBit); 1086cce3e8edSGuray Ozen // // [0,14) start_address 1087cce3e8edSGuray Ozen dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit); 1088cce3e8edSGuray Ozen 10896dc7717bSGuray Ozen LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: " 1090b96d0693SGuray Ozen << "leading_off:" << leadDimVal << "\t" 1091b96d0693SGuray Ozen << "stride_off :" << strideDimVal << "\t" 1092b96d0693SGuray Ozen << "base_offset:" << offsetVal << "\t" 1093b96d0693SGuray Ozen << "layout_type:" << swizzle << " (" 1094b96d0693SGuray Ozen << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind) 1095b96d0693SGuray Ozen << ")\n start_addr : " << baseAddr << "\n"); 1096b96d0693SGuray Ozen 1097cce3e8edSGuray Ozen rewriter.replaceOp(op, dsc); 1098cce3e8edSGuray Ozen return success(); 1099cce3e8edSGuray Ozen } 1100cce3e8edSGuray Ozen }; 1101e56d6745SGuray Ozen 1102ee49cda7SGuray Ozen static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) { 1103ee49cda7SGuray Ozen return b.create<LLVM::ConstantOp>(b.getIntegerType(64), 1104ee49cda7SGuray Ozen b.getI32IntegerAttr(index)); 
1105e56d6745SGuray Ozen } 1106e56d6745SGuray Ozen 1107e56d6745SGuray Ozen /// Returns a Value that holds data type enum that is expected by CUDA driver. 1108ee49cda7SGuray Ozen static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) { 1109e56d6745SGuray Ozen // Enum is from CUDA driver API 1110e56d6745SGuray Ozen // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html 1111e56d6745SGuray Ozen enum CUtensorMapDataTypeEnum { 1112e56d6745SGuray Ozen CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0, 1113e56d6745SGuray Ozen CU_TENSOR_MAP_DATA_TYPE_UINT16, 1114e56d6745SGuray Ozen CU_TENSOR_MAP_DATA_TYPE_UINT32, 1115e56d6745SGuray Ozen CU_TENSOR_MAP_DATA_TYPE_INT32, 1116e56d6745SGuray Ozen CU_TENSOR_MAP_DATA_TYPE_UINT64, 1117e56d6745SGuray Ozen CU_TENSOR_MAP_DATA_TYPE_INT64, 1118e56d6745SGuray Ozen CU_TENSOR_MAP_DATA_TYPE_FLOAT16, 1119e56d6745SGuray Ozen CU_TENSOR_MAP_DATA_TYPE_FLOAT32, 1120e56d6745SGuray Ozen CU_TENSOR_MAP_DATA_TYPE_FLOAT64, 1121e56d6745SGuray Ozen CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 1122e56d6745SGuray Ozen CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ, 1123e56d6745SGuray Ozen CU_TENSOR_MAP_DATA_TYPE_TFLOAT32, 1124e56d6745SGuray Ozen CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ 1125e56d6745SGuray Ozen }; 1126e56d6745SGuray Ozen 1127e56d6745SGuray Ozen if (type.isUnsignedInteger(8)) 1128ee49cda7SGuray Ozen return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8); 1129e56d6745SGuray Ozen if (type.isUnsignedInteger(16)) 1130ee49cda7SGuray Ozen return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16); 1131e56d6745SGuray Ozen if (type.isUnsignedInteger(32)) 1132ee49cda7SGuray Ozen return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32); 1133e56d6745SGuray Ozen if (type.isUnsignedInteger(64)) 1134ee49cda7SGuray Ozen return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64); 1135e56d6745SGuray Ozen if (type.isSignlessInteger(32)) 1136ee49cda7SGuray Ozen return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32); 1137e56d6745SGuray Ozen if (type.isSignlessInteger(64)) 
1138ee49cda7SGuray Ozen return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64); 1139e56d6745SGuray Ozen if (type.isF16()) 1140ee49cda7SGuray Ozen return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16); 1141e56d6745SGuray Ozen if (type.isF32()) 1142ee49cda7SGuray Ozen return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32); 1143e56d6745SGuray Ozen if (type.isF64()) 1144ee49cda7SGuray Ozen return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64); 1145e56d6745SGuray Ozen if (type.isBF16()) 1146ee49cda7SGuray Ozen return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16); 1147e56d6745SGuray Ozen 1148e56d6745SGuray Ozen llvm_unreachable("Not supported data type"); 1149e56d6745SGuray Ozen } 1150e56d6745SGuray Ozen 1151e56d6745SGuray Ozen struct NVGPUTmaCreateDescriptorOpLowering 1152e56d6745SGuray Ozen : public ConvertOpToLLVMPattern<nvgpu::TmaCreateDescriptorOp> { 1153e56d6745SGuray Ozen using ConvertOpToLLVMPattern< 1154e56d6745SGuray Ozen nvgpu::TmaCreateDescriptorOp>::ConvertOpToLLVMPattern; 1155e56d6745SGuray Ozen LogicalResult 1156e56d6745SGuray Ozen matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor, 1157e56d6745SGuray Ozen ConversionPatternRewriter &rewriter) const override { 1158ee49cda7SGuray Ozen ImplicitLocOpBuilder b(op->getLoc(), rewriter); 11592f17c9f6SChristian Ulmann auto llvmPointerType = LLVM::LLVMPointerType::get(op->getContext()); 1160e56d6745SGuray Ozen Type llvmInt64Type = IntegerType::get(op->getContext(), 64); 1161e56d6745SGuray Ozen 1162ee49cda7SGuray Ozen Value tensorElementType = 1163ee49cda7SGuray Ozen elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType()); 1164e56d6745SGuray Ozen auto promotedOperands = getTypeConverter()->promoteOperands( 1165ee49cda7SGuray Ozen b.getLoc(), op->getOperands(), adaptor.getOperands(), b); 1166e56d6745SGuray Ozen 1167ee49cda7SGuray Ozen Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type, 1168ee49cda7SGuray Ozen makeI64Const(b, 5)); 1169e56d6745SGuray Ozen 
for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) { 1170ee49cda7SGuray Ozen Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType, 1171ee49cda7SGuray Ozen boxArrayPtr, makeI64Const(b, index)); 1172ee49cda7SGuray Ozen b.create<LLVM::StoreOp>(value, gep); 1173e56d6745SGuray Ozen } 1174e56d6745SGuray Ozen 1175e56d6745SGuray Ozen nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType(); 1176e56d6745SGuray Ozen // Set Arguments for the function call 1177e56d6745SGuray Ozen SmallVector<Value> arguments; 1178e56d6745SGuray Ozen arguments.push_back(promotedOperands[0]); // rank 1179e56d6745SGuray Ozen arguments.push_back(promotedOperands[1]); // descriptor 1180e56d6745SGuray Ozen arguments.push_back(tensorElementType); // data type 1181e56d6745SGuray Ozen arguments.push_back( 1182ee49cda7SGuray Ozen makeI64Const(b, (int)desc.getInterleave())); // interleave 1183ee49cda7SGuray Ozen arguments.push_back(makeI64Const(b, (int)desc.getSwizzle())); // swizzle 1184ee49cda7SGuray Ozen arguments.push_back(makeI64Const(b, (int)desc.getL2promo())); // l2promo 1185ee49cda7SGuray Ozen arguments.push_back(makeI64Const(b, (int)desc.getOob())); // oob 1186e56d6745SGuray Ozen arguments.push_back(boxArrayPtr); // box dimensions 1187e56d6745SGuray Ozen 1188e56d6745SGuray Ozen // Set data types of the arguments 1189e56d6745SGuray Ozen SmallVector<Type> argTypes = { 1190e56d6745SGuray Ozen llvmInt64Type, /* int64_t tensorRank */ 1191e56d6745SGuray Ozen llvmPointerType, /* ptr */ 1192e56d6745SGuray Ozen llvmInt64Type, /* int64_t */ 1193e56d6745SGuray Ozen llvmInt64Type, /* int64_t */ 1194e56d6745SGuray Ozen llvmInt64Type, /* int64_t */ 1195e56d6745SGuray Ozen llvmInt64Type, /* int64_t */ 1196e56d6745SGuray Ozen llvmInt64Type, /* int64_t */ 1197e56d6745SGuray Ozen llvmPointerType /* ptr */ 1198e56d6745SGuray Ozen }; 1199e56d6745SGuray Ozen FunctionCallBuilder hostRegisterCallBuilder = { 1200e56d6745SGuray Ozen "mgpuTensorMapEncodeTiledMemref", 
llvmPointerType, argTypes}; 1201e56d6745SGuray Ozen Value tensorMap = 1202ee49cda7SGuray Ozen hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult(); 1203e56d6745SGuray Ozen 1204e56d6745SGuray Ozen rewriter.replaceOp(op, tensorMap); 1205e56d6745SGuray Ozen return success(); 1206e56d6745SGuray Ozen } 1207e56d6745SGuray Ozen }; 1208e56d6745SGuray Ozen 120923882226SGuray Ozen struct NVGPUWarpgroupMmaOpLowering 121023882226SGuray Ozen : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> { 121123882226SGuray Ozen using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern; 121223882226SGuray Ozen 1213b74cfc13SGuray Ozen /// This is a helper class to generate required NVVM Ops for warp-group level 1214b74cfc13SGuray Ozen /// matrix multiplication. 1215b74cfc13SGuray Ozen /// When the given GEMM shape is larger than the shape of 1216b74cfc13SGuray Ozen /// a wgmma instrution in PTX, it can generate multiple NVVM::WgmmaMmaAsyncOp 1217b74cfc13SGuray Ozen /// Op(s), group and execute them asynchronously. The class also handles 1218b74cfc13SGuray Ozen /// waiting for completion and iterates through WarpgroupMatrixDescriptor to 1219b74cfc13SGuray Ozen /// create descriptors for each instruction. 
1220b74cfc13SGuray Ozen /// 1221b74cfc13SGuray Ozen /// For example this is the case when the shape of GEMM is 128x128x128 1222b74cfc13SGuray Ozen /// 1223b74cfc13SGuray Ozen /// nvvm.wgmma.fence.aligned 1224b74cfc13SGuray Ozen /// 1225b74cfc13SGuray Ozen /// nvvm.wgmma.mma.async descA, descB 1226b74cfc13SGuray Ozen /// iterate(descA, descB) 1227b74cfc13SGuray Ozen /// nvvm.wgmma.mma.async descA, descB 1228b74cfc13SGuray Ozen /// [6x times more] 1229b74cfc13SGuray Ozen /// 1230b74cfc13SGuray Ozen /// nvvm.wgmma.group.sync.aligned 1231b74cfc13SGuray Ozen /// nvvm.wgmma.wait.group.sync [groupId] 1232b74cfc13SGuray Ozen /// 1233b74cfc13SGuray Ozen class WarpgroupGemm { 1234b74cfc13SGuray Ozen nvgpu::WarpgroupMmaOp op; 1235b74cfc13SGuray Ozen ImplicitLocOpBuilder b; 1236b74cfc13SGuray Ozen OpAdaptor adaptor; 1237b74cfc13SGuray Ozen 1238b74cfc13SGuray Ozen // Entire shape of the given Op 1239b74cfc13SGuray Ozen int64_t totalM, totalN, totalK; 1240b74cfc13SGuray Ozen 1241b74cfc13SGuray Ozen // Shape of one wgmma instruction 1242b74cfc13SGuray Ozen int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0; 1243b74cfc13SGuray Ozen 1244b74cfc13SGuray Ozen // Iteration counts for GEMM 1245b74cfc13SGuray Ozen int iterationM = 0, iterationN = 0, iterationK = 0; 1246b74cfc13SGuray Ozen 1247b74cfc13SGuray Ozen /// The function returns the shape of wgmma instruction that is defined in 1248b74cfc13SGuray Ozen /// PTX programming guide. 
1249b74cfc13SGuray Ozen /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape 1250b74cfc13SGuray Ozen void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) { 1251b74cfc13SGuray Ozen wgmmaM = 64; 1252b74cfc13SGuray Ozen wgmmaN = sizeN; 125323882226SGuray Ozen if (inputElemType.isTF32()) { 1254b74cfc13SGuray Ozen wgmmaK = 8; 125523882226SGuray Ozen } else if (inputElemType.isF16() || inputElemType.isBF16()) { 1256b74cfc13SGuray Ozen wgmmaK = 16; 1257*7a77f14cSMatthias Springer } else if (isa<Float8E4M3FNType, Float8E5M2Type>(inputElemType) || 1258*7a77f14cSMatthias Springer inputElemType.isInteger(16)) { 1259b74cfc13SGuray Ozen wgmmaK = 32; 126023882226SGuray Ozen } else if (inputElemType.isInteger(1)) { 1261b74cfc13SGuray Ozen wgmmaK = 256; 126223882226SGuray Ozen } else { 126323882226SGuray Ozen llvm_unreachable("msg: not supported K shape"); 126423882226SGuray Ozen } 1265b74cfc13SGuray Ozen LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM 1266b74cfc13SGuray Ozen << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n"); 126723882226SGuray Ozen } 126823882226SGuray Ozen 1269b74cfc13SGuray Ozen /// Generates WGMMATypesAttr from MLIR Type 127012c241b3SGuray Ozen NVVM::WGMMATypesAttr generateWgmmaType(Type type, 127112c241b3SGuray Ozen bool useF32 = false) const { 127212c241b3SGuray Ozen auto getWgmmaType = [=](Type elemType) { 1273b74cfc13SGuray Ozen if (elemType.isF32() || elemType.isTF32()) 127412c241b3SGuray Ozen return useF32 ? 
NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32; 1275b74cfc13SGuray Ozen if (elemType.isF16()) 1276b74cfc13SGuray Ozen return NVVM::WGMMATypes::f16; 1277b74cfc13SGuray Ozen if (elemType.isBF16()) 1278b74cfc13SGuray Ozen return NVVM::WGMMATypes::bf16; 1279*7a77f14cSMatthias Springer if (isa<Float8E4M3FNType>(elemType)) 1280b74cfc13SGuray Ozen return NVVM::WGMMATypes::e4m3; 1281*7a77f14cSMatthias Springer if (isa<Float8E5M2Type>(elemType)) 1282b74cfc13SGuray Ozen return NVVM::WGMMATypes::e5m2; 1283b74cfc13SGuray Ozen if (elemType.isInteger(1)) 1284b74cfc13SGuray Ozen return NVVM::WGMMATypes::b1; 1285b74cfc13SGuray Ozen if (elemType.isInteger(8)) 1286b74cfc13SGuray Ozen return NVVM::WGMMATypes::s8; 1287b74cfc13SGuray Ozen if (elemType.isUnsignedInteger(8)) 1288b74cfc13SGuray Ozen return NVVM::WGMMATypes::u8; 128912c241b3SGuray Ozen if (elemType.isInteger(32)) 129012c241b3SGuray Ozen return NVVM::WGMMATypes::s32; 1291b74cfc13SGuray Ozen llvm_unreachable("unsupported type"); 1292b74cfc13SGuray Ozen }; 1293b74cfc13SGuray Ozen return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type)); 129423882226SGuray Ozen } 129523882226SGuray Ozen 1296b74cfc13SGuray Ozen /// Generates layout attribute for the input matrix for wgmma instruction 1297b74cfc13SGuray Ozen NVVM::MMALayoutAttr 1298b74cfc13SGuray Ozen generateWgmmaLayout(std::optional<bool> transpose) const { 1299b74cfc13SGuray Ozen if (transpose.value_or(false)) 1300b74cfc13SGuray Ozen return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col); 1301b74cfc13SGuray Ozen return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row); 130223882226SGuray Ozen } 130323882226SGuray Ozen 1304b74cfc13SGuray Ozen /// Generates shape attribute for wgmma instruction 1305b74cfc13SGuray Ozen NVVM::MMAShapeAttr generateWgmmaShape() const { 1306b74cfc13SGuray Ozen return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK); 1307b74cfc13SGuray Ozen } 130823882226SGuray Ozen 1309b74cfc13SGuray 
Ozen /// Generates scale attributes of output matrix for wgmma instruction 1310b74cfc13SGuray Ozen NVVM::WGMMAScaleOutAttr generateScaleOut() const { 1311b74cfc13SGuray Ozen return NVVM::WGMMAScaleOutAttr::get(op->getContext(), 1312b74cfc13SGuray Ozen NVVM::WGMMAScaleOut::one); 1313b74cfc13SGuray Ozen } 1314b74cfc13SGuray Ozen /// Generates scale attributes of input matrix for wgmma instruction 1315b74cfc13SGuray Ozen NVVM::WGMMAScaleInAttr generateScaleIn() const { 1316b74cfc13SGuray Ozen return NVVM::WGMMAScaleInAttr::get(op->getContext(), 1317b74cfc13SGuray Ozen NVVM::WGMMAScaleIn::one); 1318b74cfc13SGuray Ozen } 131923882226SGuray Ozen 1320b74cfc13SGuray Ozen /// Basic function to generate Add 1321b74cfc13SGuray Ozen Value makeAdd(Value lhs, Value rhs) { 1322ee49cda7SGuray Ozen return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs); 132323882226SGuray Ozen }; 132423882226SGuray Ozen 1325b74cfc13SGuray Ozen /// Moves the descriptor pointer of matrix-A for the next wgmma instruction. 1326b74cfc13SGuray Ozen /// Currently, it only handles row-major. 
1327b74cfc13SGuray Ozen /// 1328b74cfc13SGuray Ozen /// It moves the pointer like below for [128][64] size: 1329b74cfc13SGuray Ozen /// +2 +4 +6 1330b74cfc13SGuray Ozen /// ↓ ↓ ↓ 1331b74cfc13SGuray Ozen /// descA ---> +--+--+--+--+ 1332b74cfc13SGuray Ozen /// |->|->|->|->| 1333b74cfc13SGuray Ozen /// | | | | | 1334b74cfc13SGuray Ozen /// | | | | | 1335b74cfc13SGuray Ozen /// | | | | | 1336b74cfc13SGuray Ozen /// descA+512---> +-----------+ 1337b74cfc13SGuray Ozen /// | | | | | 1338b74cfc13SGuray Ozen /// | | | | | 1339b74cfc13SGuray Ozen /// | | | | | 1340b74cfc13SGuray Ozen /// | | | | | 1341b74cfc13SGuray Ozen /// +-----------+ 1342b74cfc13SGuray Ozen /// 1343b74cfc13SGuray Ozen Value iterateDescriptorA(Value desc, int i, int j, int k) { 1344b74cfc13SGuray Ozen MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor(); 1345b74cfc13SGuray Ozen Type elemA = matrixTypeA.getElementType(); 1346b74cfc13SGuray Ozen int byte = elemA.getIntOrFloatBitWidth() / 8; 1347b74cfc13SGuray Ozen int tileShapeA = matrixTypeA.getDimSize(1); 1348b74cfc13SGuray Ozen int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte; 134923882226SGuray Ozen incrementVal = incrementVal >> exclude4LSB; 1350b74cfc13SGuray Ozen LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k 1351b74cfc13SGuray Ozen << "] [wgmma descriptors] Descriptor A + " 135223882226SGuray Ozen << incrementVal << " | \t "); 135323882226SGuray Ozen if (!incrementVal) 135423882226SGuray Ozen return desc; 1355ee49cda7SGuray Ozen return makeAdd(desc, makeI64Const(b, incrementVal)); 1356b74cfc13SGuray Ozen } 135723882226SGuray Ozen 1358b74cfc13SGuray Ozen /// Moves the descriptor pointer of matrix-B for the next wgmma instruction. 1359b74cfc13SGuray Ozen /// Currently, it only handles column-major. 
1360b74cfc13SGuray Ozen /// 1361b74cfc13SGuray Ozen /// It moves the pointer like below for [128][64] size: 1362b74cfc13SGuray Ozen /// descB ---> +--+--+--+--+--+--+--+--+ 1363b74cfc13SGuray Ozen /// |↓ | | | | | | | | 1364b74cfc13SGuray Ozen /// |↓ | | | | | | | | 1365b74cfc13SGuray Ozen /// |↓ | | | | | | | | 1366b74cfc13SGuray Ozen /// |↓ | | | | | | | | 1367b74cfc13SGuray Ozen /// +--+--+--+--+--+--+--+--+ 1368b74cfc13SGuray Ozen /// 1369b74cfc13SGuray Ozen Value iterateDescriptorB(Value desc, int i, int j, int k) { 1370b74cfc13SGuray Ozen MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor(); 1371b74cfc13SGuray Ozen Type elemB = matrixTypeB.getElementType(); 1372b74cfc13SGuray Ozen int byte = elemB.getIntOrFloatBitWidth() / 8; 1373b74cfc13SGuray Ozen int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte; 137423882226SGuray Ozen incrementVal = incrementVal >> exclude4LSB; 137523882226SGuray Ozen LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n"); 137623882226SGuray Ozen if (!incrementVal) 137723882226SGuray Ozen return desc; 1378ee49cda7SGuray Ozen return makeAdd(desc, makeI64Const(b, incrementVal)); 137923882226SGuray Ozen } 1380b74cfc13SGuray Ozen 1381b74cfc13SGuray Ozen /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix 1382b74cfc13SGuray Ozen /// descriptors and arranges them based on induction variables: i, j, and k. 138352db7e27SGuray Ozen Value generateWgmma(int i, int j, int k, Value matrixC) { 1384b74cfc13SGuray Ozen LLVM_DEBUG(DBGS() << "\t wgmma." 
1385b74cfc13SGuray Ozen << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK 1386b74cfc13SGuray Ozen << "(A[" << (iterationM * wgmmaM) << ":" 1387b74cfc13SGuray Ozen << (iterationM * wgmmaM) + wgmmaM << "][" 1388b74cfc13SGuray Ozen << (iterationK * wgmmaK) << ":" 1389b74cfc13SGuray Ozen << (iterationK * wgmmaK + wgmmaK) << "] * " 1390b74cfc13SGuray Ozen << " B[" << (iterationK * wgmmaK) << ":" 1391b74cfc13SGuray Ozen << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" 1392b74cfc13SGuray Ozen << wgmmaN << "])\n"); 1393b74cfc13SGuray Ozen 1394b74cfc13SGuray Ozen Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k); 1395b74cfc13SGuray Ozen Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k); 1396b74cfc13SGuray Ozen 1397b74cfc13SGuray Ozen Type elemA = op.getDescriptorA().getType().getTensor().getElementType(); 1398b74cfc13SGuray Ozen NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA); 1399b74cfc13SGuray Ozen 1400b74cfc13SGuray Ozen Type elemB = op.getDescriptorB().getType().getTensor().getElementType(); 1401b74cfc13SGuray Ozen NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB); 1402b74cfc13SGuray Ozen 140312c241b3SGuray Ozen Type elemD = op.getMatrixC().getType().getFragmented().getElementType(); 140412c241b3SGuray Ozen NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true); 140512c241b3SGuray Ozen 1406b74cfc13SGuray Ozen NVVM::MMAShapeAttr shape = generateWgmmaShape(); 1407b74cfc13SGuray Ozen NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut(); 1408b74cfc13SGuray Ozen NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn(); 1409b74cfc13SGuray Ozen NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA()); 1410fa13c3eeSGuray Ozen NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB()); 1411b74cfc13SGuray Ozen 1412b74cfc13SGuray Ozen auto overflow = NVVM::MMAIntOverflowAttr::get( 1413b74cfc13SGuray Ozen op->getContext(), NVVM::MMAIntOverflow::wrapped); 1414b74cfc13SGuray Ozen 
1415b74cfc13SGuray Ozen return b.create<NVVM::WgmmaMmaAsyncOp>( 141652db7e27SGuray Ozen matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA, 141712c241b3SGuray Ozen itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB, 141812c241b3SGuray Ozen overflow); 1419b74cfc13SGuray Ozen } 1420b74cfc13SGuray Ozen 1421b74cfc13SGuray Ozen /// Generates multiple wgmma instructions to complete the given GEMM shape 142252db7e27SGuray Ozen Value generateWgmmaGroup() { 142352db7e27SGuray Ozen Value wgmmaResult = 142452db7e27SGuray Ozen b.create<LLVM::UndefOp>(adaptor.getMatrixC().getType()); 1425b74cfc13SGuray Ozen 1426b74cfc13SGuray Ozen // Perform GEMM 142752db7e27SGuray Ozen SmallVector<Value> wgmmaResults; 1428b74cfc13SGuray Ozen for (int i = 0; i < iterationM; ++i) { 142952db7e27SGuray Ozen Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i); 1430b74cfc13SGuray Ozen for (int j = 0; j < iterationN; ++j) 1431b74cfc13SGuray Ozen for (int k = 0; k < iterationK; ++k) 143252db7e27SGuray Ozen matrixC = generateWgmma(i, j, k, matrixC); 143323882226SGuray Ozen wgmmaResults.push_back(matrixC); 143423882226SGuray Ozen } 143552db7e27SGuray Ozen for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) { 143652db7e27SGuray Ozen wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(), 143752db7e27SGuray Ozen wgmmaResult, matrix, idx); 143852db7e27SGuray Ozen } 143952db7e27SGuray Ozen return wgmmaResult; 1440b74cfc13SGuray Ozen } 1441b74cfc13SGuray Ozen 1442b74cfc13SGuray Ozen public: 1443b74cfc13SGuray Ozen WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b, 144452db7e27SGuray Ozen OpAdaptor adaptor) 144552db7e27SGuray Ozen : op(op), b(b), adaptor(adaptor) { 1446b74cfc13SGuray Ozen // Find the entire GEMM Shape 1447b74cfc13SGuray Ozen totalM = op.getDescriptorA().getType().getTensor().getDimSize(0); 1448b74cfc13SGuray Ozen totalN = op.getDescriptorB().getType().getTensor().getDimSize(1); 1449b74cfc13SGuray Ozen totalK = 
op.getDescriptorA().getType().getTensor().getDimSize(1); 1450b74cfc13SGuray Ozen LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN 1451b74cfc13SGuray Ozen << "] += A[" << totalM << "][" << totalK << "] * B[" 1452b74cfc13SGuray Ozen << totalK << "][" << totalN << "] ---===\n"); 1453b74cfc13SGuray Ozen 1454b74cfc13SGuray Ozen // Find the shape for one wgmma instruction 1455b74cfc13SGuray Ozen findWgmmaShape( 1456b74cfc13SGuray Ozen totalM, totalN, 1457b74cfc13SGuray Ozen op.getDescriptorA().getType().getTensor().getElementType()); 1458b74cfc13SGuray Ozen 1459b74cfc13SGuray Ozen // Iterations counts to complete the given shape with wgmma shape 1460b74cfc13SGuray Ozen iterationM = totalM / wgmmaM; 1461b74cfc13SGuray Ozen iterationN = totalN / wgmmaN; 1462b74cfc13SGuray Ozen iterationK = totalK / wgmmaK; 1463b74cfc13SGuray Ozen } 1464b74cfc13SGuray Ozen 1465b74cfc13SGuray Ozen /// Generates WgmmaMmaAsync Ops to complete the specified GEMM shape. It 1466b74cfc13SGuray Ozen /// includes generating a fence Op (WgmmaFenceAlignedOp) before the 1467b74cfc13SGuray Ozen /// instructions and group synchronization, as well as waiting 1468b74cfc13SGuray Ozen /// (WgmmaGroupSyncAlignedOp) for group synchronization 1469b74cfc13SGuray Ozen /// (WgmmaWaitGroupSyncOp) after the instructions. 
147052db7e27SGuray Ozen Value generateWarpgroupMma() { 1471b74cfc13SGuray Ozen b.create<NVVM::WgmmaFenceAlignedOp>(); 147252db7e27SGuray Ozen Value wgmmaResult = generateWgmmaGroup(); 1473ee49cda7SGuray Ozen b.create<NVVM::WgmmaGroupSyncAlignedOp>(); 1474ee49cda7SGuray Ozen b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup()); 147552db7e27SGuray Ozen return wgmmaResult; 1476b74cfc13SGuray Ozen } 1477b74cfc13SGuray Ozen }; 1478b74cfc13SGuray Ozen LogicalResult 1479b74cfc13SGuray Ozen matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor, 1480b74cfc13SGuray Ozen ConversionPatternRewriter &rewriter) const override { 1481b74cfc13SGuray Ozen ImplicitLocOpBuilder b(op->getLoc(), rewriter); 148252db7e27SGuray Ozen 1483b74cfc13SGuray Ozen // Step 1. Build a helper class 148452db7e27SGuray Ozen WarpgroupGemm warpgroupGemm(op, b, adaptor); 1485b74cfc13SGuray Ozen 1486b74cfc13SGuray Ozen // Step 2. Get the entire GEMM Shape 148752db7e27SGuray Ozen Value wgmmaResult = warpgroupGemm.generateWarpgroupMma(); 1488b74cfc13SGuray Ozen 1489b74cfc13SGuray Ozen // Step 3. Replace fragmented result struct with the op results 149052db7e27SGuray Ozen rewriter.replaceOp(op, wgmmaResult); 149123882226SGuray Ozen return success(); 149223882226SGuray Ozen } 149323882226SGuray Ozen }; 149423882226SGuray Ozen 1495d20fbc90SGuray Ozen struct NVGPUWarpgroupMmaStoreOpLowering 1496d20fbc90SGuray Ozen : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> { 1497d20fbc90SGuray Ozen using ConvertOpToLLVMPattern< 1498d20fbc90SGuray Ozen nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern; 1499d20fbc90SGuray Ozen 1500d20fbc90SGuray Ozen /// This function stores a fragmented register matrix owned by a warp group 1501d20fbc90SGuray Ozen /// (128 threads) into a memref. Each thread has 64 registers, each the size 1502d20fbc90SGuray Ozen /// of a struct. 1503d20fbc90SGuray Ozen /// Here is what each threads (T) holds, each `d` is struct value with a 1504d20fbc90SGuray Ozen /// number. 
1505d20fbc90SGuray Ozen /// 1506d20fbc90SGuray Ozen /// Threads in warp-group (128 threads) and what they owns in the matrixD: 1507d20fbc90SGuray Ozen /// 0-31 Warp-0 -> MatrixD[0:15 ][0:N] 1508d20fbc90SGuray Ozen /// 32-63 Warp-1 -> MatrixD[16:31][0:N] 1509d20fbc90SGuray Ozen /// 64-95 Warp-2 -> MatrixD[32:47][0:N] 1510d20fbc90SGuray Ozen /// 96-127 Warp-3 -> MatrixD[48:64][0:N] 1511d20fbc90SGuray Ozen /// 1512d20fbc90SGuray Ozen /// Matrix-D: 1513d20fbc90SGuray Ozen /// +______________________________________________________________________+ 1514d20fbc90SGuray Ozen /// | 0-1 | 2-3 | 4-5 | 6-7 | 8-9 | 10-11|..|N-8,N-7 | 1515d20fbc90SGuray Ozen /// 0 | T0:d0-d1 |T1:d0-d1 |T2:d0-d1 |T3:d0-d1 |T0:d4-d5| T1:d4-d5..|T0:dX-dY| 1516d20fbc90SGuray Ozen /// 1 | T4:d0-d1 |T5:d0-d1 |T6:d0-d1 |T7:d0-d1 |T4:d4-d5| T5:d4-d5..|T4:dX-dY| 1517d20fbc90SGuray Ozen /// ..| .........|.........|.........|.........|........|...........|........| 1518d20fbc90SGuray Ozen /// 8 | T0:d2-d3 |T1:d2-d3 |T2:d2-d3 |T3:d2-d3 |T0:d6-d7|T1:d6-d7,..|T0:dZ-dW| 1519d20fbc90SGuray Ozen /// 9 | T4:d2-d3 |T5:d2-d3 |T6:d2-d3 |T7:d2-d3 |T4:d6-d7| T5:d6-d7..|T4:dZ-dW| 1520d20fbc90SGuray Ozen /// ..| .........|.........|.........|.........|........|...........|........| 1521d20fbc90SGuray Ozen /// 15| T28:d2-d3|T29:d2-d3|T30:d2-d3|T31:d2-d3|........|...........|........| 1522d20fbc90SGuray Ozen /// 16| T32:d2-d3|T33:d2-d3|T34:d2-d3|T35:d2-d3|........|...........|........| 1523d20fbc90SGuray Ozen /// ..| .........|.........|.........|.........|........|...........|........| 1524d20fbc90SGuray Ozen /// 32| T64:d2-d3|T65:d2-d3|T66:d2-d3|T67:d2-d3|........|...........|........| 1525d20fbc90SGuray Ozen /// ..| .........|.........|.........|.........|........|...........|........| 1526d20fbc90SGuray Ozen /// 48| T96:d2-d3|T97:d2-d3|T98:d2-d3|T99:d2-d3|........|...........|........| 1527d20fbc90SGuray Ozen /// ..| .........|.........|.........|.........|........|...........|........| 1528d20fbc90SGuray Ozen /// 
+______________________________________________________________________+ 1529d20fbc90SGuray Ozen /// 1530d20fbc90SGuray Ozen /// \param rewriter: The pattern rewriter. 1531d20fbc90SGuray Ozen /// \param matrixD: Result of the warp-group MMA operation (fragmented 1532d20fbc90SGuray Ozen /// matrix). It is holded by a thread and a struct with 64 elements. 1533d20fbc90SGuray Ozen /// \param dstMemref: The memref where the registers will be stored. 1534d20fbc90SGuray Ozen /// \param offset: the offset within the memref where the registers will be 1535d20fbc90SGuray Ozen /// stored. 1536d20fbc90SGuray Ozen void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD, 1537d20fbc90SGuray Ozen TypedValue<MemRefType> dstMemref, 1538d20fbc90SGuray Ozen int offset) const { 1539d20fbc90SGuray Ozen Type i32 = b.getI32Type(); 1540d20fbc90SGuray Ozen 1541d20fbc90SGuray Ozen auto makeConst = [&](int32_t index) -> Value { 1542d20fbc90SGuray Ozen return b.create<LLVM::ConstantOp>(i32, b.getI32IntegerAttr(index)); 1543d20fbc90SGuray Ozen }; 1544d20fbc90SGuray Ozen Value c1 = makeConst(1); 1545d20fbc90SGuray Ozen Value c2 = makeConst(2); 1546d20fbc90SGuray Ozen Value c4 = makeConst(4); 1547d20fbc90SGuray Ozen Value c8 = makeConst(8); 1548d20fbc90SGuray Ozen Value c16 = makeConst(16); 1549d20fbc90SGuray Ozen Value warpSize = makeConst(kWarpSize); 1550d20fbc90SGuray Ozen 1551d20fbc90SGuray Ozen auto makeMul = [&](Value lhs, Value rhs) -> Value { 1552d20fbc90SGuray Ozen return b.create<LLVM::MulOp>(lhs.getType(), lhs, rhs); 1553d20fbc90SGuray Ozen }; 1554d20fbc90SGuray Ozen auto makeAdd = [&](Value lhs, Value rhs) -> Value { 1555d20fbc90SGuray Ozen return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs); 1556d20fbc90SGuray Ozen }; 1557d20fbc90SGuray Ozen 1558d20fbc90SGuray Ozen auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y, 1559d20fbc90SGuray Ozen TypedValue<::mlir::MemRefType> memref) { 1560d20fbc90SGuray Ozen Type it = b.getIndexType(); 1561d20fbc90SGuray 
Ozen Value idx = b.create<arith::IndexCastOp>(it, x); 1562d20fbc90SGuray Ozen Value idy0 = b.create<arith::IndexCastOp>(it, y); 1563d20fbc90SGuray Ozen Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1)); 1564d20fbc90SGuray Ozen Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i); 1565d20fbc90SGuray Ozen Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1); 1566d20fbc90SGuray Ozen b.create<memref::StoreOp>(d0, memref, ValueRange{idx, idy0}); 1567d20fbc90SGuray Ozen b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1}); 1568d20fbc90SGuray Ozen }; 1569d20fbc90SGuray Ozen 157021830c91SGuray Ozen Value tidx = b.create<NVVM::ThreadIdXOp>(i32); 157121830c91SGuray Ozen Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize); 157221830c91SGuray Ozen Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize); 157321830c91SGuray Ozen Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4); 157421830c91SGuray Ozen Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4); 157521830c91SGuray Ozen 1576d20fbc90SGuray Ozen Value tj = makeMul(lane4modId, c2); 1577d20fbc90SGuray Ozen Value ti = makeAdd(lane4Id, makeMul(warpId, c16)); 1578d20fbc90SGuray Ozen if (offset) 1579d20fbc90SGuray Ozen ti = makeAdd(ti, makeConst(offset)); 158021830c91SGuray Ozen 1581a5757c5bSChristian Sigg auto structType = cast<LLVM::LLVMStructType>(matrixD.getType()); 158221830c91SGuray Ozen 158321830c91SGuray Ozen // Number of 32-bit registers owns per thread 158421830c91SGuray Ozen constexpr unsigned numAdjacentRegisters = 2; 158521830c91SGuray Ozen // Number of 8x8 matrices one below another per warp 158621830c91SGuray Ozen constexpr unsigned numStackedMatrices = 2; 158721830c91SGuray Ozen 158821830c91SGuray Ozen size_t storeCount = (structType.getBody().size() / 158921830c91SGuray Ozen (numStackedMatrices * numAdjacentRegisters)); 159021830c91SGuray Ozen 159121830c91SGuray Ozen for (size_t i = 0; i < numStackedMatrices; ++i) { 1592d20fbc90SGuray Ozen Value idx = 
makeAdd(ti, makeMul(makeConst(i), c8)); 159321830c91SGuray Ozen for (size_t j = 0; j < storeCount; ++j) { 1594d20fbc90SGuray Ozen Value idy = makeAdd(tj, makeMul(makeConst(j), c8)); 159521830c91SGuray Ozen size_t structIndex = (i * numAdjacentRegisters) + 159621830c91SGuray Ozen (j * (numStackedMatrices * numAdjacentRegisters)); 159721830c91SGuray Ozen makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref); 1598d20fbc90SGuray Ozen } 1599d20fbc90SGuray Ozen } 1600d20fbc90SGuray Ozen } 1601d20fbc90SGuray Ozen 1602d20fbc90SGuray Ozen LogicalResult 1603d20fbc90SGuray Ozen matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor, 1604d20fbc90SGuray Ozen ConversionPatternRewriter &rewriter) const override { 1605d20fbc90SGuray Ozen int offset = 0; 160652db7e27SGuray Ozen ImplicitLocOpBuilder b(op->getLoc(), rewriter); 160752db7e27SGuray Ozen Value matriDValue = adaptor.getMatrixD(); 1608a5757c5bSChristian Sigg auto stype = cast<LLVM::LLVMStructType>(matriDValue.getType()); 160952db7e27SGuray Ozen for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) { 1610a5757c5bSChristian Sigg auto structType = cast<LLVM::LLVMStructType>(matrixD); 161152db7e27SGuray Ozen Value innerStructValue = b.create<LLVM::ExtractValueOp>(matriDValue, idx); 161252db7e27SGuray Ozen storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset); 1613d20fbc90SGuray Ozen offset += structType.getBody().size(); 1614d20fbc90SGuray Ozen } 1615d20fbc90SGuray Ozen rewriter.eraseOp(op); 1616d20fbc90SGuray Ozen return success(); 1617d20fbc90SGuray Ozen } 1618d20fbc90SGuray Ozen }; 1619d20fbc90SGuray Ozen 1620315ab3c4SGuray Ozen struct NVGPUWarpgroupMmaInitAccumulatorOpLowering 1621315ab3c4SGuray Ozen : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> { 1622315ab3c4SGuray Ozen using ConvertOpToLLVMPattern< 1623315ab3c4SGuray Ozen nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern; 1624315ab3c4SGuray Ozen LogicalResult 1625315ab3c4SGuray Ozen 
matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor, 1626315ab3c4SGuray Ozen ConversionPatternRewriter &rewriter) const override { 1627315ab3c4SGuray Ozen ImplicitLocOpBuilder b(op->getLoc(), rewriter); 1628a5757c5bSChristian Sigg LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>( 1629a5757c5bSChristian Sigg getTypeConverter()->convertType(op.getMatrixC().getType())); 1630a5757c5bSChristian Sigg Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front()) 163152db7e27SGuray Ozen .getBody() 163252db7e27SGuray Ozen .front(); 163352db7e27SGuray Ozen Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType)); 1634c4ba84d6SGuray Ozen Value packStruct = b.create<LLVM::UndefOp>(packStructType); 1635c4ba84d6SGuray Ozen SmallVector<Value> innerStructs; 1636c4ba84d6SGuray Ozen // Unpack the structs and set all values to zero 1637c4ba84d6SGuray Ozen for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) { 1638a5757c5bSChristian Sigg auto structType = cast<LLVM::LLVMStructType>(s); 1639c4ba84d6SGuray Ozen Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx); 1640c4ba84d6SGuray Ozen for (unsigned i = 0; i < structType.getBody().size(); ++i) { 1641c4ba84d6SGuray Ozen structValue = b.create<LLVM::InsertValueOp>( 1642c4ba84d6SGuray Ozen structType, structValue, zero, ArrayRef<int64_t>({i})); 1643315ab3c4SGuray Ozen } 1644c4ba84d6SGuray Ozen innerStructs.push_back(structValue); 1645315ab3c4SGuray Ozen } 1646c4ba84d6SGuray Ozen // Pack the inner structs into a single struct 1647c4ba84d6SGuray Ozen for (auto [idx, matrix] : llvm::enumerate(innerStructs)) { 1648c4ba84d6SGuray Ozen packStruct = b.create<LLVM::InsertValueOp>(packStruct.getType(), 1649c4ba84d6SGuray Ozen packStruct, matrix, idx); 1650c4ba84d6SGuray Ozen } 1651c4ba84d6SGuray Ozen rewriter.replaceOp(op, packStruct); 1652315ab3c4SGuray Ozen return success(); 1653315ab3c4SGuray Ozen } 1654315ab3c4SGuray Ozen }; 1655315ab3c4SGuray 
Ozen 165639cdefb5SGuray Ozen struct NVGPUTmaPrefetchOpLowering 165739cdefb5SGuray Ozen : public ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp> { 165839cdefb5SGuray Ozen using ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp>::ConvertOpToLLVMPattern; 165939cdefb5SGuray Ozen LogicalResult 166039cdefb5SGuray Ozen matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor, 166139cdefb5SGuray Ozen ConversionPatternRewriter &rewriter) const override { 166239cdefb5SGuray Ozen rewriter.replaceOpWithNewOp<NVVM::PrefetchTensorMapOp>( 166339cdefb5SGuray Ozen op, adaptor.getTensorMapDescriptor(), adaptor.getPredicate()); 166439cdefb5SGuray Ozen return success(); 166539cdefb5SGuray Ozen } 166639cdefb5SGuray Ozen }; 166739cdefb5SGuray Ozen 16682b23e6c8SObserver007 struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> { 16692b23e6c8SObserver007 using ConvertOpToLLVMPattern<nvgpu::RcpOp>::ConvertOpToLLVMPattern; 16702b23e6c8SObserver007 LogicalResult 16712b23e6c8SObserver007 matchAndRewrite(nvgpu::RcpOp op, OpAdaptor adaptor, 16722b23e6c8SObserver007 ConversionPatternRewriter &rewriter) const override { 16732b23e6c8SObserver007 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 16742b23e6c8SObserver007 auto i64Ty = b.getI64Type(); 16752b23e6c8SObserver007 auto f32Ty = b.getF32Type(); 16762b23e6c8SObserver007 VectorType inTy = op.getIn().getType(); 16772b23e6c8SObserver007 // apply rcp.approx.ftz.f on each element in vector. 
16782b23e6c8SObserver007 auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) { 16792b23e6c8SObserver007 Value ret1DVec = b.create<LLVM::UndefOp>(llvm1DVectorTy); 16802b23e6c8SObserver007 int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements(); 16812b23e6c8SObserver007 for (int i = 0; i < numElems; i++) { 16822b23e6c8SObserver007 Value idx = b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(i)); 16832b23e6c8SObserver007 Value elem = b.create<LLVM::ExtractElementOp>(inVec, idx); 16842b23e6c8SObserver007 Value dst = b.create<NVVM::RcpApproxFtzF32Op>(f32Ty, elem); 16852b23e6c8SObserver007 ret1DVec = b.create<LLVM::InsertElementOp>(ret1DVec, dst, idx); 16862b23e6c8SObserver007 } 16872b23e6c8SObserver007 return ret1DVec; 16882b23e6c8SObserver007 }; 16892b23e6c8SObserver007 if (inTy.getRank() == 1) { 16902b23e6c8SObserver007 rewriter.replaceOp(op, convert1DVec(inTy, adaptor.getIn())); 16912b23e6c8SObserver007 return success(); 16922b23e6c8SObserver007 } 16932b23e6c8SObserver007 return LLVM::detail::handleMultidimensionalVectors( 16942b23e6c8SObserver007 op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()), 16952b23e6c8SObserver007 [&](Type llvm1DVectorTy, ValueRange operands) -> Value { 16962b23e6c8SObserver007 OpAdaptor adaptor(operands); 16972b23e6c8SObserver007 return convert1DVec(llvm1DVectorTy, adaptor.getIn()); 16982b23e6c8SObserver007 }, 16992b23e6c8SObserver007 rewriter); 17002b23e6c8SObserver007 } 17012b23e6c8SObserver007 }; 1702894a591cSThomas Raoux } // namespace 170315bcc36eSThomas Raoux 1704206fad0eSMatthias Springer void mlir::populateNVGPUToNVVMConversionPatterns( 1705206fad0eSMatthias Springer const LLVMTypeConverter &converter, RewritePatternSet &patterns) { 1706affcfccdSGuray Ozen patterns.add< 1707affcfccdSGuray Ozen NVGPUMBarrierCreateLowering, // nvgpu.mbarrier.create 1708affcfccdSGuray Ozen NVGPUMBarrierInitLowering, // nvgpu.mbarrier.init 1709affcfccdSGuray Ozen NVGPUMBarrierArriveLowering, // 
nvgpu.mbarrier.arrive 1710affcfccdSGuray Ozen NVGPUMBarrierArriveNoCompleteLowering, // nvgpu.mbarrier.arrive.no_complete 1711836dbb85SGuray Ozen NVGPUMBarrierTestWaitLowering, // nvgpu.mbarrier.test_wait_parity 1712836dbb85SGuray Ozen NVGPUMBarrierTryWaitParityLowering, // nvgpu.mbarrier.try_wait_parity 1713e56d6745SGuray Ozen NVGPUTmaAsyncLoadOpLowering, // nvgpu.tma.async.load 17148dd0d95cSGuray Ozen NVGPUTmaAsyncStoreOpLowering, // nvgpu.tma.async.store 1715e56d6745SGuray Ozen NVGPUTmaCreateDescriptorOpLowering, // nvgpu.tma.create.descriptor 171639cdefb5SGuray Ozen NVGPUTmaPrefetchOpLowering, // nvgpu.tma.prefetch.descriptor 1717836dbb85SGuray Ozen NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx 17186dc7717bSGuray Ozen NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor 171923882226SGuray Ozen NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma 1720d20fbc90SGuray Ozen NVGPUWarpgroupMmaStoreOpLowering, // nvgpu.warpgroup.mma.store 1721315ab3c4SGuray Ozen NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator 1722affcfccdSGuray Ozen MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering, 1723708185f0SChristopher Bate NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering, 17242b23e6c8SObserver007 NVGPUMmaSparseSyncLowering, NVGPURcpOpLowering>(converter); 1725894a591cSThomas Raoux } 1726