//===------ WmmaOpsToNVVM.cpp - WMMA LD/ST/Compute to NVVM lowering -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
// the NVVM dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/TypeUtilities.h"

using namespace mlir;

namespace {

/// Checks if all the operands of the op being lowered are of LLVM types. The
/// types are expected to be converted by the `LLVMTypeConverter` before the op
/// is actually lowered. If the type of an operand is not already converted,
/// that hints at a missing type conversion, and failure is returned.
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
                                     ConversionPatternRewriter &rewriter) {
  if (!llvm::all_of(operands, [](Value value) {
        return LLVM::isCompatibleType(value.getType());
      })) {
    return rewriter.notifyMatchFailure(
        op, "cannot convert if operands aren't of LLVM type.");
  }

  return success();
}

/// Error string to emit when an unimplemented WMMA variant is encountered.
static constexpr StringRef kInvalidCaseStr = "Unsupported WMMA variant.";

/// Maps the operand name of an MMAMatrixType ("AOp", "BOp" or "COp") to the
/// corresponding NVVM matrix fragment kind.
static NVVM::MMAFrag convertOperand(StringRef operandName) {
  if (operandName == "AOp")
    return NVVM::MMAFrag::a;
  if (operandName == "BOp")
    return NVVM::MMAFrag::b;
  if (operandName == "COp")
    return NVVM::MMAFrag::c;
  llvm_unreachable("Unknown operand name");
}

/// Returns the NVVM element type of the given MMAMatrixType. Note that f32
/// maps to tf32 for the "AOp" and "BOp" fragments and to f32 for the "COp"
/// accumulator fragment.
static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
  if (type.getElementType().isF16())
    return NVVM::MMATypes::f16;
  if (type.getElementType().isF32())
    return type.getOperand() == "COp" ? NVVM::MMATypes::f32
                                      : NVVM::MMATypes::tf32;

  if (type.getElementType().isSignedInteger(8))
    return NVVM::MMATypes::s8;
  if (type.getElementType().isUnsignedInteger(8))
    return NVVM::MMATypes::u8;
  // Accumulator type is signless and implies signed.
  if (type.getElementType().isInteger(32))
    return NVVM::MMATypes::s32;
  llvm_unreachable("Unsupported type");
}

/// This class implements the conversion of GPU MMA loadOp to the wmma.load op
/// in the NVVM dialect. Besides emitting the NVVM op, the conversion computes
/// the address of the first element to load from the source memref and infers
/// the missing m/n/k dimension needed to select the matching intrinsic.
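/// For example (illustrative shapes; the exact types depend on the fragment
/// and on the intrinsics available for the target), a load such as
///
///   %0 = gpu.subgroup_mma_load_matrix %src[%i, %j]
///       {leadDimension = 32 : index}
///       : memref<32x32xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
///
/// becomes an nvvm.wmma.load on the strided element pointer, returning an
/// LLVM struct of vector fragments.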
struct WmmaLoadOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaLoadMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *op = subgroupMmaLoadMatrixOp.getOperation();
    if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
      return failure();

    // Get the shape of the MMAMatrix type being returned. The shape
    // determines which intrinsic this op will be lowered to.
    NVVM::MMALayout layout = subgroupMmaLoadMatrixOp.getTranspose()
                                 ? NVVM::MMALayout::col
                                 : NVVM::MMALayout::row;
    gpu::MMAMatrixType retType =
        cast<gpu::MMAMatrixType>(subgroupMmaLoadMatrixOp.getRes().getType());
    ArrayRef<int64_t> retTypeShape = retType.getShape();
    int64_t m = 0;
    int64_t n = 0;
    int64_t k = 0;
    NVVM::MMATypes eltype = getElementType(retType);
    // NVVM intrinsics require all of the m, n and k dimensions; infer the
    // missing one from the valid intrinsics available.
    if (retType.getOperand() == "AOp") {
      m = retTypeShape[0];
      k = retTypeShape[1];
      n = NVVM::WMMALoadOp::inferNDimension(m, k, eltype);
    } else if (retType.getOperand() == "BOp") {
      k = retTypeShape[0];
      n = retTypeShape[1];
      m = NVVM::WMMALoadOp::inferMDimension(k, n, eltype);
    } else if (retType.getOperand() == "COp") {
      m = retTypeShape[0];
      n = retTypeShape[1];
      k = NVVM::WMMALoadOp::inferKDimension(m, n, eltype);
    }
    NVVM::MMAFrag frag = convertOperand(retType.getOperand());
    // Check that an existing intrinsic covers the combination we need.
    if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0)
      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);

    Type resType = convertMMAToLLVMType(retType);
    Location loc = op->getLoc();

    // Create the nvvm.wmma.load op according to the operand types.
    Value dataPtr = getStridedElementPtr(
        loc,
        cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
        adaptor.getSrcMemref(), adaptor.getIndices(), rewriter);

    Value leadingDim = rewriter.create<LLVM::ConstantOp>(
        loc, rewriter.getI32Type(),
        subgroupMmaLoadMatrixOp.getLeadDimensionAttr());
    rewriter.replaceOpWithNewOp<NVVM::WMMALoadOp>(
        op, resType, dataPtr, leadingDim, m, n, k, layout, eltype, frag);
    return success();
  }
};

/// This class implements the conversion of GPU MMA storeOp to the wmma.store
/// op in the NVVM dialect. Besides emitting the NVVM op, the conversion
/// unpacks the source data into the individual values the NVVM op expects.
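/// For example (illustrative), a store such as
///
///   gpu.subgroup_mma_store_matrix %acc, %dst[%i, %j]
///       {leadDimension = 32 : index}
///       : !gpu.mma_matrix<16x16xf32, "COp">, memref<32x32xf32>
///
/// extracts each element of the struct holding the accumulator fragment and
/// forwards them to a single nvvm.wmma.store.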
struct WmmaStoreOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaStoreMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *op = subgroupMmaStoreMatrixOp.getOperation();
    if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
      return failure();

    Location loc = op->getLoc();

    SmallVector<Value, 4> storeOpOperands;
    // Get the shape of the MMAMatrix type being stored. The shape determines
    // which intrinsic this op will be lowered to.
    gpu::MMAMatrixType srcType =
        cast<gpu::MMAMatrixType>(subgroupMmaStoreMatrixOp.getSrc().getType());
    ArrayRef<int64_t> srcTypeShape = srcType.getShape();
    NVVM::MMALayout layout = subgroupMmaStoreMatrixOp.getTranspose()
                                 ? NVVM::MMALayout::col
                                 : NVVM::MMALayout::row;
    NVVM::MMATypes eltype = getElementType(srcType);
    int64_t m = srcTypeShape[0];
    int64_t n = srcTypeShape[1];
    int64_t k = NVVM::WMMAStoreOp::inferKDimension(m, n, eltype);
    if (NVVM::WMMAStoreOp::getIntrinsicID(m, n, k, layout, eltype) == 0)
      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);

    // Unpack the source struct into the individual values the intrinsic
    // expects.
    auto matrixType = cast<LLVM::LLVMStructType>(adaptor.getSrc().getType());
    for (unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) {
      Value toUse =
          rewriter.create<LLVM::ExtractValueOp>(loc, adaptor.getSrc(), i);
      storeOpOperands.push_back(toUse);
    }

    Value dataPtr = getStridedElementPtr(
        loc,
        cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType()),
        adaptor.getDstMemref(), adaptor.getIndices(), rewriter);
    Value leadingDim = rewriter.create<LLVM::ConstantOp>(
        loc, rewriter.getI32Type(),
        subgroupMmaStoreMatrixOp.getLeadDimensionAttr());
    rewriter.replaceOpWithNewOp<NVVM::WMMAStoreOp>(
        op, dataPtr, m, n, k, layout, eltype, storeOpOperands, leadingDim);
    return success();
  }
};

/// This class implements the conversion of GPU MMA computeOp to the wmma.mma
/// op in the NVVM dialect.
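/// For example (illustrative; mixed-precision variants such as f16 inputs
/// with an f32 accumulator are only accepted when a matching intrinsic
/// exists), a compute op such as
///
///   %d = gpu.subgroup_mma_compute %a, %b, %c
///       : !gpu.mma_matrix<16x16xf16, "AOp">,
///         !gpu.mma_matrix<16x16xf16, "BOp">
///       -> !gpu.mma_matrix<16x16xf32, "COp">
///
/// unpacks all three operand structs and lowers to a single nvvm.wmma.mma.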
struct WmmaMmaOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaComputeOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaComputeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *op = subgroupMmaComputeOp.getOperation();
    if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
      return failure();

    Location loc = op->getLoc();

    // The wmma.mma intrinsic in LLVM expects its operands as individual
    // values, so the individual elements need to be extracted from the
    // lowered operands (LLVM structs) and then passed on to the intrinsic
    // call. Emit LLVM ops to extract these individual values.
    SmallVector<Value> unpackedOps;

    auto unpackOp = [&](Value operand) {
      auto structType = cast<LLVM::LLVMStructType>(operand.getType());
      for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
        Value toUse = rewriter.create<LLVM::ExtractValueOp>(loc, operand, i);
        unpackedOps.push_back(toUse);
      }
    };

    // Get the shapes of the MMAMatrix types being used. The shapes determine
    // which intrinsic this op will be lowered to.
    gpu::MMAMatrixType aType =
        cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpA().getType());
    ArrayRef<int64_t> aTypeShape = aType.getShape();
    gpu::MMAMatrixType cType =
        cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpC().getType());
    ArrayRef<int64_t> cTypeShape = cType.getShape();
    int64_t m = cTypeShape[0];
    int64_t n = cTypeShape[1];
    int64_t k = aTypeShape[1];
    NVVM::MMALayout aLayout = subgroupMmaComputeOp.getATranspose()
                                  ? NVVM::MMALayout::col
                                  : NVVM::MMALayout::row;
    NVVM::MMALayout bLayout = subgroupMmaComputeOp.getBTranspose()
                                  ? NVVM::MMALayout::col
                                  : NVVM::MMALayout::row;
    NVVM::MMATypes sourceType = getElementType(aType);
    NVVM::MMATypes destType = getElementType(cType);
    if (NVVM::WMMAMmaOp::getIntrinsicID(m, n, k, aLayout, bLayout, sourceType,
                                        destType) == 0)
      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);

    NVVM::MMATypes bElementType = getElementType(
        cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpB().getType()));
    if (bElementType != sourceType)
      return rewriter.notifyMatchFailure(
          op, "WMMA compute op input matrix element types must match.");

    unpackOp(adaptor.getOpA());
    unpackOp(adaptor.getOpB());
    unpackOp(adaptor.getOpC());

    rewriter.replaceOpWithNewOp<NVVM::WMMAMmaOp>(
        op, adaptor.getOpC().getType(), m, n, k, aLayout, bLayout, sourceType,
        destType, unpackedOps);
    return success();
  }
};

/// Convert GPU MMA ConstantMatrixOp to a chain of InsertValueOp.
struct WmmaConstantOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaConstantMatrixOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaConstantMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(areAllLLVMTypes(subgroupMmaConstantOp.getOperation(),
                               adaptor.getOperands(), rewriter)))
      return failure();
    Location loc = subgroupMmaConstantOp.getLoc();
    Value cst = adaptor.getOperands()[0];
    LLVM::LLVMStructType type = convertMMAToLLVMType(
        cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType()));
    // If the struct elements are vectors, splat the scalar operand into a
    // vector first.
    if (auto vecType = dyn_cast<VectorType>(type.getBody()[0])) {
      Value vecCst = rewriter.create<LLVM::UndefOp>(loc, vecType);
      for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
        Value idx = rewriter.create<LLVM::ConstantOp>(
            loc, rewriter.getI32Type(), vecEl);
        vecCst = rewriter.create<LLVM::InsertElementOp>(loc, vecType, vecCst,
                                                        cst, idx);
      }
      cst = vecCst;
    }
    Value matrixStruct = rewriter.create<LLVM::UndefOp>(loc, type);
    for (size_t i : llvm::seq(size_t(0), type.getBody().size())) {
      matrixStruct =
          rewriter.create<LLVM::InsertValueOp>(loc, matrixStruct, cst, i);
    }
    rewriter.replaceOp(subgroupMmaConstantOp, matrixStruct);
    return success();
  }
};
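/// Creates an LLVM compare-and-select sequence implementing a NaN-propagating
/// fmin/fmax: the ordered compare (olt/ogt) plus select computes the min or
/// max, and if either operand is NaN (the `uno` predicate holds) a quiet NaN
/// is returned instead of the compare-select result.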
static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
                           Value rhs, bool isMin) {
  auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
  Type i1Type = builder.getI1Type();
  if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
    i1Type = VectorType::get(vecType.getShape(), i1Type);
  Value cmp = builder.create<LLVM::FCmpOp>(
      loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
      lhs, rhs);
  Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
  Value isNan = builder.create<LLVM::FCmpOp>(
      loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
  Value nan = builder.create<LLVM::ConstantOp>(
      loc, lhs.getType(),
      builder.getFloatAttr(floatType,
                           APFloat::getQNaN(floatType.getFloatSemantics())));
  return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
}

/// Creates the scalar LLVM op corresponding to the given elementwise MMA op.
static Value createScalarOp(OpBuilder &builder, Location loc,
                            gpu::MMAElementwiseOp op,
                            ArrayRef<Value> operands) {
  switch (op) {
  case gpu::MMAElementwiseOp::ADDF:
    return builder.create<LLVM::FAddOp>(loc, operands[0].getType(), operands);
  case gpu::MMAElementwiseOp::MULF:
    return builder.create<LLVM::FMulOp>(loc, operands[0].getType(), operands);
  case gpu::MMAElementwiseOp::DIVF:
    return builder.create<LLVM::FDivOp>(loc, operands[0].getType(), operands);
  case gpu::MMAElementwiseOp::MAXF:
    return createMinMaxF(builder, loc, operands[0], operands[1],
                         /*isMin=*/false);
  case gpu::MMAElementwiseOp::MINF:
    return createMinMaxF(builder, loc, operands[0], operands[1],
                         /*isMin=*/true);
  default:
    llvm_unreachable("unknown op");
  }
}

/// Convert GPU MMA elementwise ops to extract + op + insert.
struct WmmaElementwiseOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaElementwiseOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaElementwiseOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(areAllLLVMTypes(subgroupMmaElementwiseOp.getOperation(),
                               adaptor.getOperands(), rewriter)))
      return failure();
    Location loc = subgroupMmaElementwiseOp.getLoc();
    size_t numOperands = adaptor.getOperands().size();
    LLVM::LLVMStructType destType = convertMMAToLLVMType(
        cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType()));
    Value matrixStruct = rewriter.create<LLVM::UndefOp>(loc, destType);
    for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) {
      SmallVector<Value> extractedOperands;
      for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {
        extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
            loc, adaptor.getOperands()[opIdx], i));
      }
      Value element =
          createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.getOpType(),
                         extractedOperands);
      matrixStruct =
          rewriter.create<LLVM::InsertValueOp>(loc, matrixStruct, element, i);
    }
    rewriter.replaceOp(subgroupMmaElementwiseOp, matrixStruct);
    return success();
  }
};

} // namespace

/// Return the LLVMStructType corresponding to the MMAMatrixType `type`.
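/// For example (illustrative; the exact fragment layout is decided by
/// NVVM::inferMMAType), a 16x16 f16 "AOp" matrix maps to a literal struct of
/// eight vector<2xf16> elements.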
LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
  NVVM::MMAFrag frag = convertOperand(type.getOperand());
  NVVM::MMATypes eltType = getElementType(type);
  auto nRow = type.getShape()[0];
  auto nCol = type.getShape()[1];
  std::pair<Type, unsigned> typeInfo =
      NVVM::inferMMAType(eltType, frag, nRow, nCol, type.getContext());
  return LLVM::LLVMStructType::getLiteral(
      type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
}

void mlir::populateGpuWMMAToNVVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
               WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering,
               WmmaElementwiseOpToNVVMLowering>(converter);
}