1 //===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h" 10 11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 12 #include "mlir/Conversion/LLVMCommon/Pattern.h" 13 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 14 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" 15 #include "mlir/Dialect/AMDGPU/Utils/Chipset.h" 16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 17 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "mlir/Pass/Pass.h" 21 22 #include "llvm/ADT/STLExtras.h" 23 #include <optional> 24 25 namespace mlir { 26 #define GEN_PASS_DEF_CONVERTAMDGPUTOROCDL 27 #include "mlir/Conversion/Passes.h.inc" 28 } // namespace mlir 29 30 using namespace mlir; 31 using namespace mlir::amdgpu; 32 33 /// Convert an unsigned number `val` to i32. 34 static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, 35 Location loc, Value val) { 36 IntegerType i32 = rewriter.getI32Type(); 37 // Force check that `val` is of int type. 38 auto valTy = cast<IntegerType>(val.getType()); 39 if (i32 == valTy) 40 return val; 41 return valTy.getWidth() > 32 42 ? Value(rewriter.create<LLVM::TruncOp>(loc, i32, val)) 43 : Value(rewriter.create<LLVM::ZExtOp>(loc, i32, val)); 44 } 45 46 static Value createI32Constant(ConversionPatternRewriter &rewriter, 47 Location loc, int32_t value) { 48 Type i32 = rewriter.getI32Type(); 49 return rewriter.create<LLVM::ConstantOp>(loc, i32, value); 50 } 51 52 static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc, 53 bool value) { 54 Type llvmI1 = rewriter.getI1Type(); 55 return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value); 56 } 57 58 /// Returns the linear index used to access an element in the memref. 59 static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, 60 Location loc, MemRefDescriptor &memRefDescriptor, 61 ValueRange indices, ArrayRef<int64_t> strides) { 62 IntegerType i32 = rewriter.getI32Type(); 63 Value index; 64 for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) { 65 if (stride != 1) { // Skip if stride is 1. 66 Value strideValue = 67 ShapedType::isDynamic(stride) 68 ? convertUnsignedToI32(rewriter, loc, 69 memRefDescriptor.stride(rewriter, loc, i)) 70 : rewriter.create<LLVM::ConstantOp>(loc, i32, stride); 71 increment = rewriter.create<LLVM::MulOp>(loc, increment, strideValue); 72 } 73 index = 74 index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment; 75 } 76 return index ? index : createI32Constant(rewriter, loc, 0); 77 } 78 79 namespace { 80 // Define commonly used chipsets versions for convenience. 81 constexpr Chipset kGfx908 = Chipset(9, 0, 8); 82 constexpr Chipset kGfx90a = Chipset(9, 0, 0xa); 83 constexpr Chipset kGfx940 = Chipset(9, 4, 0); 84 85 /// Define lowering patterns for raw buffer ops 86 template <typename GpuOp, typename Intrinsic> 87 struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> { 88 RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset) 89 : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {} 90 91 Chipset chipset; 92 static constexpr uint32_t maxVectorOpWidth = 128; 93 94 LogicalResult 95 matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor, 96 ConversionPatternRewriter &rewriter) const override { 97 Location loc = gpuOp.getLoc(); 98 Value memref = adaptor.getMemref(); 99 Value unconvertedMemref = gpuOp.getMemref(); 100 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType()); 101 102 if (chipset.majorVersion < 9) 103 return gpuOp.emitOpError("raw buffer ops require GCN or higher"); 104 105 Value storeData = adaptor.getODSOperands(0)[0]; 106 if (storeData == memref) // no write component to this op 107 storeData = Value(); 108 Type wantedDataType; 109 if (storeData) 110 wantedDataType = storeData.getType(); 111 else 112 wantedDataType = gpuOp.getODSResults(0)[0].getType(); 113 114 Value atomicCmpData = Value(); 115 // Operand index 1 of a load is the indices, trying to read them can crash. 116 if (storeData) { 117 Value maybeCmpData = adaptor.getODSOperands(1)[0]; 118 if (maybeCmpData != memref) 119 atomicCmpData = maybeCmpData; 120 } 121 122 Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType); 123 124 Type i32 = rewriter.getI32Type(); 125 Type i16 = rewriter.getI16Type(); 126 127 // Get the type size in bytes. 128 DataLayout dataLayout = DataLayout::closest(gpuOp); 129 int64_t elementByteWidth = 130 dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8; 131 Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth); 132 133 // If we want to load a vector<NxT> with total size <= 32 134 // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32 135 // and the total load size is >= 32, use a vector load of N / (bitsize(T) / 136 // 32) x i32 and bitcast. Also, the CAS intrinsic requires integer operands, 137 // so bitcast any floats to integers. 138 Type llvmBufferValType = llvmWantedDataType; 139 if (atomicCmpData) { 140 if (auto floatType = dyn_cast<FloatType>(wantedDataType)) 141 llvmBufferValType = this->getTypeConverter()->convertType( 142 rewriter.getIntegerType(floatType.getWidth())); 143 } 144 if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) { 145 uint32_t vecLen = dataVector.getNumElements(); 146 uint32_t elemBits = 147 dataLayout.getTypeSizeInBits(dataVector.getElementType()); 148 uint32_t totalBits = elemBits * vecLen; 149 bool usePackedFp16 = 150 isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2; 151 if (totalBits > maxVectorOpWidth) 152 return gpuOp.emitOpError( 153 "Total width of loads or stores must be no more than " + 154 Twine(maxVectorOpWidth) + " bits, but we call for " + 155 Twine(totalBits) + 156 " bits. This should've been caught in validation"); 157 if (!usePackedFp16 && elemBits < 32) { 158 if (totalBits > 32) { 159 if (totalBits % 32 != 0) 160 return gpuOp.emitOpError("Load or store of more than 32-bits that " 161 "doesn't fit into words. Can't happen\n"); 162 llvmBufferValType = this->typeConverter->convertType( 163 VectorType::get(totalBits / 32, i32)); 164 } else { 165 llvmBufferValType = this->typeConverter->convertType( 166 rewriter.getIntegerType(totalBits)); 167 } 168 } 169 } 170 171 SmallVector<Value, 6> args; 172 if (storeData) { 173 if (llvmBufferValType != llvmWantedDataType) { 174 Value castForStore = 175 rewriter.create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData); 176 args.push_back(castForStore); 177 } else { 178 args.push_back(storeData); 179 } 180 } 181 182 if (atomicCmpData) { 183 if (llvmBufferValType != llvmWantedDataType) { 184 Value castForCmp = rewriter.create<LLVM::BitcastOp>( 185 loc, llvmBufferValType, atomicCmpData); 186 args.push_back(castForCmp); 187 } else { 188 args.push_back(atomicCmpData); 189 } 190 } 191 192 // Construct buffer descriptor from memref, attributes 193 int64_t offset = 0; 194 SmallVector<int64_t, 5> strides; 195 if (failed(memrefType.getStridesAndOffset(strides, offset))) 196 return gpuOp.emitOpError("Can't lower non-stride-offset memrefs"); 197 198 MemRefDescriptor memrefDescriptor(memref); 199 200 Value ptr = memrefDescriptor.bufferPtr( 201 rewriter, loc, *this->getTypeConverter(), memrefType); 202 // The stride value is always 0 for raw buffers. This also disables 203 // swizling. 204 Value stride = rewriter.create<LLVM::ConstantOp>( 205 loc, i16, rewriter.getI16IntegerAttr(0)); 206 // Get the number of elements. 207 Value numRecords; 208 if (memrefType.hasStaticShape() && 209 !llvm::any_of(strides, ShapedType::isDynamic)) { 210 int64_t size = memrefType.getRank() == 0 ? 1 : 0; 211 ArrayRef<int64_t> shape = memrefType.getShape(); 212 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) 213 size = std::max(shape[i] * strides[i], size); 214 size = size * elementByteWidth; 215 assert(size < std::numeric_limits<uint32_t>::max() && 216 "the memref buffer is too large"); 217 numRecords = createI32Constant(rewriter, loc, static_cast<int32_t>(size)); 218 } else { 219 Value maxIndex; 220 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) { 221 Value size = memrefDescriptor.size(rewriter, loc, i); 222 Value stride = memrefDescriptor.stride(rewriter, loc, i); 223 Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride); 224 maxIndex = 225 maxIndex ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim) 226 : maxThisDim; 227 } 228 numRecords = rewriter.create<LLVM::MulOp>( 229 loc, convertUnsignedToI32(rewriter, loc, maxIndex), byteWidthConst); 230 } 231 232 // Flag word: 233 // bits 0-11: dst sel, ignored by these intrinsics 234 // bits 12-14: data format (ignored, must be nonzero, 7=float) 235 // bits 15-18: data format (ignored, must be nonzero, 4=32bit) 236 // bit 19: In nested heap (0 here) 237 // bit 20: Behavior on unmap (0 means "return 0 / ignore") 238 // bits 21-22: Index stride for swizzles (N/A) 239 // bit 23: Add thread ID (0) 240 // bit 24: Reserved to 1 (RDNA) or 0 (CDNA) 241 // bits 25-26: Reserved (0) 242 // bit 27: Buffer is non-volatile (CDNA only) 243 // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 = 244 // none, 3 = either swizzles or testing against offset field) RDNA only 245 // bits 30-31: Type (must be 0) 246 uint32_t flags = (7 << 12) | (4 << 15); 247 if (chipset.majorVersion >= 10) { 248 flags |= (1 << 24); 249 uint32_t oob = adaptor.getBoundsCheck() ? 3 : 2; 250 flags |= (oob << 28); 251 } 252 Value flagsConst = createI32Constant(rewriter, loc, flags); 253 Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), 8); 254 Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>( 255 loc, rsrcType, ptr, stride, numRecords, flagsConst); 256 args.push_back(resource); 257 258 // Indexing (voffset) 259 Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor, 260 adaptor.getIndices(), strides); 261 if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset(); 262 indexOffset && *indexOffset > 0) { 263 Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset); 264 voffset = 265 voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst) 266 : extraOffsetConst; 267 } 268 voffset = rewriter.create<LLVM::MulOp>(loc, voffset, byteWidthConst); 269 args.push_back(voffset); 270 271 // SGPR offset. 272 Value sgprOffset = adaptor.getSgprOffset(); 273 if (!sgprOffset) 274 sgprOffset = createI32Constant(rewriter, loc, 0); 275 sgprOffset = rewriter.create<LLVM::MulOp>(loc, sgprOffset, byteWidthConst); 276 args.push_back(sgprOffset); 277 278 // bit 0: GLC = 0 (atomics drop value, less coherency) 279 // bits 1-2: SLC, DLC = 0 (similarly) 280 // bit 3: swizzled (0 for raw) 281 args.push_back(createI32Constant(rewriter, loc, 0)); 282 283 llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(), 284 llvmBufferValType); 285 Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args, 286 ArrayRef<NamedAttribute>()); 287 if (lowered->getNumResults() == 1) { 288 Value replacement = lowered->getResult(0); 289 if (llvmBufferValType != llvmWantedDataType) { 290 replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType, 291 replacement); 292 } 293 rewriter.replaceOp(gpuOp, replacement); 294 } else { 295 rewriter.eraseOp(gpuOp); 296 } 297 return success(); 298 } 299 }; 300 301 struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> { 302 LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset) 303 : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {} 304 305 Chipset chipset; 306 307 LogicalResult 308 matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor, 309 ConversionPatternRewriter &rewriter) const override { 310 bool requiresInlineAsm = chipset < kGfx90a || chipset.majorVersion == 11; 311 312 if (requiresInlineAsm) { 313 auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(), 314 LLVM::AsmDialect::AD_ATT); 315 const char *asmStr = 316 ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier"; 317 const char *constraints = ""; 318 rewriter.replaceOpWithNewOp<LLVM::InlineAsmOp>( 319 op, 320 /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(), 321 /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true, 322 /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr, 323 /*operand_attrs=*/ArrayAttr()); 324 return success(); 325 } 326 if (chipset.majorVersion < 12) { 327 constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8); 328 constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8); 329 // Left in place in case someone disables the inline ASM path or future 330 // chipsets use the same bit pattern. 331 constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4); 332 333 int32_t ldsOnlyBits; 334 if (chipset.majorVersion == 11) 335 ldsOnlyBits = ldsOnlyBitsGfx11; 336 else if (chipset.majorVersion == 10) 337 ldsOnlyBits = ldsOnlyBitsGfx10; 338 else if (chipset.majorVersion <= 9) 339 ldsOnlyBits = ldsOnlyBitsGfx6789; 340 else 341 return op.emitOpError( 342 "don't know how to lower this for chipset major version") 343 << chipset.majorVersion; 344 345 Location loc = op->getLoc(); 346 rewriter.create<ROCDL::WaitcntOp>(loc, ldsOnlyBits); 347 rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op); 348 } else { 349 Location loc = op->getLoc(); 350 rewriter.create<ROCDL::WaitDscntOp>(loc, 0); 351 rewriter.create<ROCDL::BarrierSignalOp>(loc, -1); 352 rewriter.replaceOpWithNewOp<ROCDL::BarrierWaitOp>(op, -1); 353 } 354 355 return success(); 356 } 357 }; 358 359 struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> { 360 SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset) 361 : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {} 362 363 Chipset chipset; 364 365 LogicalResult 366 matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor, 367 ConversionPatternRewriter &rewriter) const override { 368 rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op, 369 (uint32_t)op.getOpts()); 370 return success(); 371 } 372 }; 373 374 } // namespace 375 376 /// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL 377 /// and LLVM AMDGPU intrinsics convention. 378 /// 379 /// Specifically: 380 /// 1. If `input` is a vector of N bytes, bitcast it to a (N * 8)-bit integer. 381 /// 2. If the element type is bfloat16, bitcast it to i16. 382 static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter, 383 Location loc, Value input) { 384 Type inputType = input.getType(); 385 if (auto vectorType = dyn_cast<VectorType>(inputType)) { 386 if (vectorType.getElementType().isBF16()) 387 return rewriter.create<LLVM::BitcastOp>( 388 loc, vectorType.clone(rewriter.getI16Type()), input); 389 if (vectorType.getElementType().isInteger(8)) { 390 return rewriter.create<LLVM::BitcastOp>( 391 loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input); 392 } 393 } 394 return input; 395 } 396 397 /// Push an input operand. If it is a float type, nothing to do. If it is 398 /// an integer type, then we need to also push its signdness (1 for signed, 0 399 /// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32 400 /// vector. We also need to convert bfloat inputs to i16 to account for the lack 401 /// of bfloat support in the WMMA intrinsics themselves. 402 static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, 403 Location loc, 404 const TypeConverter *typeConverter, 405 bool isUnsigned, Value llvmInput, 406 Value mlirInput, 407 SmallVector<Value, 4> &operands) { 408 Type inputType = llvmInput.getType(); 409 auto vectorType = dyn_cast<VectorType>(inputType); 410 Type elemType = vectorType.getElementType(); 411 412 if (elemType.isBF16()) 413 llvmInput = rewriter.create<LLVM::BitcastOp>( 414 loc, vectorType.clone(rewriter.getI16Type()), llvmInput); 415 if (!elemType.isInteger(8)) { 416 operands.push_back(llvmInput); 417 return; 418 } 419 420 // We need to check the type of the input before conversion to properly test 421 // for int8. This is because, in LLVM, fp8 type is converted to int8, so the 422 // fp8/int8 information is lost during the conversion process. 423 auto mlirInputType = cast<VectorType>(mlirInput.getType()); 424 bool isInputInt8 = mlirInputType.getElementType().isInteger(8); 425 if (isInputInt8) { 426 // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag 427 bool localIsUnsigned = isUnsigned; 428 if (elemType.isUnsignedInteger(8)) { 429 localIsUnsigned = true; 430 } else if (elemType.isSignedInteger(8)) { 431 localIsUnsigned = false; 432 } 433 Value sign = createI1Constant(rewriter, loc, !localIsUnsigned); 434 operands.push_back(sign); 435 } 436 437 int64_t numBytes = vectorType.getNumElements(); 438 Type i32 = rewriter.getI32Type(); 439 VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32); 440 auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits); 441 Value result = rewriter.createOrFold<LLVM::BitcastOp>( 442 loc, llvmVectorType32bits, llvmInput); 443 operands.push_back(result); 444 } 445 446 /// Push the output operand. For many cases this is only pushing the output in 447 /// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics, 448 /// since the same numbers of VGPRs is used, we need to decide if to store the 449 /// result in the upper 16 bits of the VGPRs or in the lower part. To store the 450 /// result in the lower 16 bits, set subwordOffset to 1, otherwise result will 451 /// be stored it in the upper part 452 static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, 453 Location loc, 454 const TypeConverter *typeConverter, 455 Value output, int32_t subwordOffset, 456 bool clamp, SmallVector<Value, 4> &operands) { 457 Type inputType = output.getType(); 458 auto vectorType = dyn_cast<VectorType>(inputType); 459 Type elemType = vectorType.getElementType(); 460 if (elemType.isBF16()) 461 output = rewriter.create<LLVM::BitcastOp>( 462 loc, vectorType.clone(rewriter.getI16Type()), output); 463 operands.push_back(output); 464 if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) { 465 operands.push_back(createI1Constant(rewriter, loc, subwordOffset)); 466 } else if (elemType.isInteger(32)) { 467 operands.push_back(createI1Constant(rewriter, loc, clamp)); 468 } 469 } 470 471 /// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma` 472 /// if one exists. This includes checking to ensure the intrinsic is supported 473 /// on the architecture you are compiling for. 474 static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma, 475 Chipset chipset) { 476 uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(), 477 b = mfma.getBlocks(); 478 Type sourceElem = mfma.getSourceA().getType(); 479 if (auto sourceType = dyn_cast<VectorType>(sourceElem)) 480 sourceElem = sourceType.getElementType(); 481 Type destElem = mfma.getDestC().getType(); 482 if (auto destType = dyn_cast<VectorType>(destElem)) 483 destElem = destType.getElementType(); 484 485 if (sourceElem.isF32() && destElem.isF32()) { 486 if (mfma.getReducePrecision() && chipset >= kGfx940) { 487 if (m == 32 && n == 32 && k == 4 && b == 1) 488 return ROCDL::mfma_f32_32x32x4_xf32::getOperationName(); 489 if (m == 16 && n == 16 && k == 8 && b == 1) 490 return ROCDL::mfma_f32_16x16x8_xf32::getOperationName(); 491 } 492 if (m == 32 && n == 32 && k == 1 && b == 2) 493 return ROCDL::mfma_f32_32x32x1f32::getOperationName(); 494 if (m == 16 && n == 16 && k == 1 && b == 4) 495 return ROCDL::mfma_f32_16x16x1f32::getOperationName(); 496 if (m == 4 && n == 4 && k == 1 && b == 16) 497 return ROCDL::mfma_f32_4x4x1f32::getOperationName(); 498 if (m == 32 && n == 32 && k == 2 && b == 1) 499 return ROCDL::mfma_f32_32x32x2f32::getOperationName(); 500 if (m == 16 && n == 16 && k == 4 && b == 1) 501 return ROCDL::mfma_f32_16x16x4f32::getOperationName(); 502 } 503 504 if (sourceElem.isF16() && destElem.isF32()) { 505 if (m == 32 && n == 32 && k == 4 && b == 2) 506 return ROCDL::mfma_f32_32x32x4f16::getOperationName(); 507 if (m == 16 && n == 16 && k == 4 && b == 4) 508 return ROCDL::mfma_f32_16x16x4f16::getOperationName(); 509 if (m == 4 && n == 4 && k == 4 && b == 16) 510 return ROCDL::mfma_f32_4x4x4f16::getOperationName(); 511 if (m == 32 && n == 32 && k == 8 && b == 1) 512 return ROCDL::mfma_f32_32x32x8f16::getOperationName(); 513 if (m == 16 && n == 16 && k == 16 && b == 1) 514 return ROCDL::mfma_f32_16x16x16f16::getOperationName(); 515 } 516 517 if (sourceElem.isBF16() && destElem.isF32() && chipset >= kGfx90a) { 518 if (m == 32 && n == 32 && k == 4 && b == 2) 519 return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName(); 520 if (m == 16 && n == 16 && k == 4 && b == 4) 521 return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName(); 522 if (m == 4 && n == 4 && k == 4 && b == 16) 523 return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName(); 524 if (m == 32 && n == 32 && k == 8 && b == 1) 525 return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName(); 526 if (m == 16 && n == 16 && k == 16 && b == 1) 527 return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName(); 528 } 529 530 if (sourceElem.isBF16() && destElem.isF32()) { 531 if (m == 32 && n == 32 && k == 2 && b == 2) 532 return ROCDL::mfma_f32_32x32x2bf16::getOperationName(); 533 if (m == 16 && n == 16 && k == 2 && b == 4) 534 return ROCDL::mfma_f32_16x16x2bf16::getOperationName(); 535 if (m == 4 && n == 4 && k == 2 && b == 16) 536 return ROCDL::mfma_f32_4x4x2bf16::getOperationName(); 537 if (m == 32 && n == 32 && k == 4 && b == 1) 538 return ROCDL::mfma_f32_32x32x4bf16::getOperationName(); 539 if (m == 16 && n == 16 && k == 8 && b == 1) 540 return ROCDL::mfma_f32_16x16x8bf16::getOperationName(); 541 } 542 543 if (isa<IntegerType>(sourceElem) && destElem.isInteger(32)) { 544 if (m == 32 && n == 32 && k == 4 && b == 2) 545 return ROCDL::mfma_i32_32x32x4i8::getOperationName(); 546 if (m == 16 && n == 16 && k == 4 && b == 4) 547 return ROCDL::mfma_i32_16x16x4i8::getOperationName(); 548 if (m == 4 && n == 4 && k == 4 && b == 16) 549 return ROCDL::mfma_i32_4x4x4i8::getOperationName(); 550 if (m == 32 && n == 32 && k == 8 && b == 1) 551 return ROCDL::mfma_i32_32x32x8i8::getOperationName(); 552 if (m == 16 && n == 16 && k == 16 && b == 1) 553 return ROCDL::mfma_i32_16x16x16i8::getOperationName(); 554 if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx940) 555 return ROCDL::mfma_i32_32x32x16_i8::getOperationName(); 556 if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx940) 557 return ROCDL::mfma_i32_16x16x32_i8::getOperationName(); 558 } 559 560 if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) { 561 if (m == 16 && n == 16 && k == 4 && b == 1) 562 return ROCDL::mfma_f64_16x16x4f64::getOperationName(); 563 if (m == 4 && n == 4 && k == 4 && b == 4) 564 return ROCDL::mfma_f64_4x4x4f64::getOperationName(); 565 } 566 567 if (isa<Float8E5M2FNUZType>(sourceElem) && destElem.isF32() && 568 chipset >= kGfx940) { 569 // Known to be correct because there are no scalar f8 instructions and 570 // because a length mismatch will have been caught by the verifier. 571 Type sourceBElem = 572 cast<VectorType>(mfma.getSourceB().getType()).getElementType(); 573 if (m == 16 && n == 16 && k == 32 && b == 1) { 574 if (isa<Float8E5M2FNUZType>(sourceBElem)) 575 return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName(); 576 if (isa<Float8E4M3FNUZType>(sourceBElem)) 577 return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName(); 578 } 579 if (m == 32 && n == 32 && k == 16 && b == 1) { 580 if (isa<Float8E5M2FNUZType>(sourceBElem)) 581 return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName(); 582 if (isa<Float8E4M3FNUZType>(sourceBElem)) 583 return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName(); 584 } 585 } 586 587 if (isa<Float8E4M3FNUZType>(sourceElem) && destElem.isF32() && 588 chipset >= kGfx940) { 589 Type sourceBElem = 590 cast<VectorType>(mfma.getSourceB().getType()).getElementType(); 591 if (m == 16 && n == 16 && k == 32 && b == 1) { 592 if (isa<Float8E5M2FNUZType>(sourceBElem)) 593 return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName(); 594 if (isa<Float8E4M3FNUZType>(sourceBElem)) 595 return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName(); 596 } 597 if (m == 32 && n == 32 && k == 16 && b == 1) { 598 if (isa<Float8E5M2FNUZType>(sourceBElem)) 599 return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName(); 600 if (isa<Float8E4M3FNUZType>(sourceBElem)) 601 return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName(); 602 } 603 } 604 605 return std::nullopt; 606 } 607 608 /// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma` 609 /// if one exists. This includes checking to ensure the intrinsic is supported 610 /// on the architecture you are compiling for. 611 static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma, 612 Chipset chipset) { 613 auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType()); 614 auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType()); 615 auto elemSourceType = sourceVectorType.getElementType(); 616 auto elemDestType = destVectorType.getElementType(); 617 618 if (elemSourceType.isF16() && elemDestType.isF32()) 619 return ROCDL::wmma_f32_16x16x16_f16::getOperationName(); 620 if (elemSourceType.isBF16() && elemDestType.isF32()) 621 return ROCDL::wmma_f32_16x16x16_bf16::getOperationName(); 622 if (elemSourceType.isF16() && elemDestType.isF16()) 623 return ROCDL::wmma_f16_16x16x16_f16::getOperationName(); 624 if (elemSourceType.isBF16() && elemDestType.isBF16()) 625 return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName(); 626 if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) 627 return ROCDL::wmma_i32_16x16x16_iu8::getOperationName(); 628 if (isa<Float8E4M3FNType>(elemSourceType) && elemDestType.isF32()) 629 return ROCDL::wmma_f32_16x16x16_fp8::getOperationName(); 630 if (isa<Float8E5M2Type>(elemSourceType) && elemDestType.isF32()) 631 return ROCDL::wmma_f32_16x16x16_bf8::getOperationName(); 632 return std::nullopt; 633 } 634 635 namespace { 636 struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> { 637 MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset) 638 : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {} 639 640 Chipset chipset; 641 642 LogicalResult 643 matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor, 644 ConversionPatternRewriter &rewriter) const override { 645 Location loc = op.getLoc(); 646 Type outType = typeConverter->convertType(op.getDestD().getType()); 647 Type intrinsicOutType = outType; 648 if (auto outVecType = dyn_cast<VectorType>(outType)) 649 if (outVecType.getElementType().isBF16()) 650 intrinsicOutType = outVecType.clone(rewriter.getI16Type()); 651 652 if (chipset.majorVersion != 9 || chipset < kGfx908) 653 return op->emitOpError("MFMA only supported on gfx908+"); 654 uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp()); 655 if (op.getNegateA() || op.getNegateB() || op.getNegateC()) { 656 if (chipset < kGfx940) 657 return op.emitOpError("negation unsupported on older than gfx940"); 658 getBlgpField |= 659 op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2); 660 } 661 std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset); 662 if (!maybeIntrinsic.has_value()) 663 return op.emitOpError("no intrinsic matching MFMA size on given chipset"); 664 OperationState loweredOp(loc, *maybeIntrinsic); 665 loweredOp.addTypes(intrinsicOutType); 666 loweredOp.addOperands( 667 {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()), 668 convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()), 669 adaptor.getDestC(), createI32Constant(rewriter, loc, op.getCbsz()), 670 createI32Constant(rewriter, loc, op.getAbid()), 671 createI32Constant(rewriter, loc, getBlgpField)}); 672 Value lowered = rewriter.create(loweredOp)->getResult(0); 673 if (outType != intrinsicOutType) 674 lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered); 675 rewriter.replaceOp(op, lowered); 676 return success(); 677 } 678 }; 679 680 struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> { 681 WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset) 682 : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {} 683 684 Chipset chipset; 685 686 LogicalResult 687 matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor, 688 ConversionPatternRewriter &rewriter) const override { 689 Location loc = op.getLoc(); 690 auto outType = 691 typeConverter->convertType<VectorType>(op.getDestD().getType()); 692 if (!outType) 693 return rewriter.notifyMatchFailure(op, "type conversion failed"); 694 695 if (chipset.majorVersion != 11 && chipset.majorVersion != 12) 696 return op->emitOpError("WMMA only supported on gfx11 and gfx12"); 697 698 // The WMMA operations represent vectors of bf16s as vectors of i16s, so we 699 // need to bitcast bfloats to i16 and then bitcast them back. 700 VectorType rawOutType = outType; 701 if (outType.getElementType().isBF16()) 702 rawOutType = outType.clone(rewriter.getI16Type()); 703 704 std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset); 705 706 if (!maybeIntrinsic.has_value()) 707 return op.emitOpError("no intrinsic matching WMMA on the given chipset"); 708 709 OperationState loweredOp(loc, *maybeIntrinsic); 710 loweredOp.addTypes(rawOutType); 711 712 SmallVector<Value, 4> operands; 713 wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), 714 adaptor.getSourceA(), op.getSourceA(), operands); 715 wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), 716 adaptor.getSourceB(), op.getSourceB(), operands); 717 wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(), 718 op.getSubwordOffset(), op.getClamp(), operands); 719 720 loweredOp.addOperands(operands); 721 Operation *lowered = rewriter.create(loweredOp); 722 723 Operation *maybeCastBack = lowered; 724 if (rawOutType != outType) 725 maybeCastBack = 726 rewriter.create<LLVM::BitcastOp>(loc, outType, lowered->getResult(0)); 727 rewriter.replaceOp(op, maybeCastBack->getResults()); 728 729 return success(); 730 } 731 }; 732 733 namespace { 734 struct ExtPackedFp8OpLowering final 735 : public ConvertOpToLLVMPattern<ExtPackedFp8Op> { 736 ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset) 737 : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter), 738 chipset(chipset) {} 739 Chipset chipset; 740 741 LogicalResult 742 matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor, 743 ConversionPatternRewriter &rewriter) const override; 744 }; 745 746 struct PackedTrunc2xFp8OpLowering final 747 : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> { 748 PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter, 749 Chipset chipset) 750 : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter), 751 chipset(chipset) {} 752 Chipset chipset; 753 754 LogicalResult 755 matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor, 756 ConversionPatternRewriter &rewriter) const override; 757 }; 758 759 struct PackedStochRoundFp8OpLowering final 760 : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> { 761 PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter, 762 Chipset chipset) 763 : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter), 764 chipset(chipset) {} 765 Chipset chipset; 766 767 LogicalResult 768 matchAndRewrite(PackedStochRoundFp8Op op, 769 PackedStochRoundFp8OpAdaptor adaptor, 770 ConversionPatternRewriter &rewriter) const override; 771 }; 772 } // end namespace 773 774 LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( 775 ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor, 776 ConversionPatternRewriter &rewriter) const { 777 Location loc = op.getLoc(); 778 if (chipset.majorVersion != 9 || chipset < kGfx940) 779 return rewriter.notifyMatchFailure( 780 loc, "Fp8 conversion instructions are not available on target " 781 "architecture and their emulation is not implemented"); 782 Type v4i8 = 783 getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type())); 784 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type()); 785 Type f32 = getTypeConverter()->convertType(op.getResult().getType()); 786 787 Value source = adaptor.getSource(); 788 auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType()); 789 Type sourceElemType = getElementTypeOrSelf(op.getSource()); 790 // Extend to a v4i8 791 if (!sourceVecType || sourceVecType.getNumElements() < 4) { 792 Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8); 793 if (!sourceVecType) { 794 longVec = rewriter.create<LLVM::InsertElementOp>( 795 loc, longVec, source, createI32Constant(rewriter, loc, 0)); 796 } else { 797 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) { 798 Value idx = createI32Constant(rewriter, loc, i); 799 Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx); 800 longVec = 801 rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx); 802 } 803 } 804 source = longVec; 805 } 806 Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source); 807 Value wordSel = createI32Constant(rewriter, loc, op.getIndex()); 808 if (isa<Float8E5M2FNUZType>(sourceElemType)) { 809 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source, 810 wordSel); 811 } else if (isa<Float8E4M3FNUZType>(sourceElemType)) { 812 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source, 813 wordSel); 814 } 815 return success(); 816 } 817 818 LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite( 819 PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor, 820 ConversionPatternRewriter &rewriter) const { 821 Location loc = op.getLoc(); 822 if (chipset.majorVersion != 9 || chipset < kGfx940) 823 return rewriter.notifyMatchFailure( 824 loc, "Fp8 conversion instructions are not available on target " 825 "architecture and their emulation is not implemented"); 826 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type()); 827 828 Type resultType = op.getResult().getType(); 829 Type resultElemType = getElementTypeOrSelf(resultType); 830 831 Value sourceA = adaptor.getSourceA(); 832 Value sourceB = adaptor.getSourceB(); 833 if (!sourceB) 834 sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType()); 835 Value existing = adaptor.getExisting(); 836 if (existing) 837 existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing); 838 else 839 existing = rewriter.create<LLVM::UndefOp>(loc, i32); 840 Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex()); 841 842 Value result; 843 if (isa<Float8E5M2FNUZType>(resultElemType)) 844 result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB, 845 existing, wordSel); 846 else if (isa<Float8E4M3FNUZType>(resultElemType)) 847 result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB, 848 existing, wordSel); 849 850 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>( 851 op, getTypeConverter()->convertType(resultType), result); 852 return success(); 853 } 854 855 LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite( 856 PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor, 857 ConversionPatternRewriter &rewriter) const { 858 Location loc = op.getLoc(); 859 if (chipset.majorVersion != 9 || chipset < kGfx940) 860 return rewriter.notifyMatchFailure( 861 loc, "Fp8 conversion instructions are not available on target " 862 "architecture and their emulation is not implemented"); 863 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type()); 864 865 Type resultType = op.getResult().getType(); 866 Type resultElemType = getElementTypeOrSelf(resultType); 867 868 Value source = adaptor.getSource(); 869 Value stoch = adaptor.getStochiasticParam(); 870 Value existing = adaptor.getExisting(); 871 if (existing) 872 existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing); 873 else 874 existing = rewriter.create<LLVM::UndefOp>(loc, i32); 875 Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex()); 876 877 Value result; 878 if (isa<Float8E5M2FNUZType>(resultElemType)) 879 result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch, 880 existing, byteSel); 881 else if (isa<Float8E4M3FNUZType>(resultElemType)) 882 result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch, 883 existing, byteSel); 884 885 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>( 886 op, getTypeConverter()->convertType(resultType), result); 887 return success(); 888 } 889 890 // Implement the AMDGPU_DPPLowering class that will convert the amdgpu.dpp 891 // operation into the corresponding ROCDL instructions. 892 struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> { 893 AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset) 894 : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {} 895 Chipset chipset; 896 897 LogicalResult 898 matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor, 899 ConversionPatternRewriter &rewriter) const override { 900 901 // Convert the source operand to the corresponding LLVM type 902 Location loc = DppOp.getLoc(); 903 Value src = adaptor.getSrc(); 904 Value old = adaptor.getOld(); 905 Type srcType = src.getType(); 906 Type oldType = old.getType(); 907 Type llvmType = nullptr; 908 if (srcType.getIntOrFloatBitWidth() < 32) { 909 llvmType = rewriter.getI32Type(); 910 } else if (isa<FloatType>(srcType)) { 911 llvmType = (srcType.getIntOrFloatBitWidth() == 32) 912 ? rewriter.getF32Type() 913 : rewriter.getF64Type(); 914 } else if (isa<IntegerType>(srcType)) { 915 llvmType = (srcType.getIntOrFloatBitWidth() == 32) 916 ? rewriter.getI32Type() 917 : rewriter.getI64Type(); 918 } 919 auto llvmSrcIntType = typeConverter->convertType( 920 rewriter.getIntegerType(srcType.getIntOrFloatBitWidth())); 921 922 // If the source type is less of 32, use bitcast to convert it to i32. 923 auto convertOperand = [&](Value operand, Type operandType) { 924 if (operandType.getIntOrFloatBitWidth() <= 16) { 925 if (llvm::isa<FloatType>(operandType)) { 926 operand = 927 rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand); 928 } 929 auto llvmVecType = typeConverter->convertType(mlir::VectorType::get( 930 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType)); 931 Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType); 932 operand = rewriter.create<LLVM::InsertElementOp>( 933 loc, undefVec, operand, createI32Constant(rewriter, loc, 0)); 934 operand = rewriter.create<LLVM::BitcastOp>(loc, llvmType, operand); 935 } 936 return operand; 937 }; 938 939 src = convertOperand(src, srcType); 940 old = convertOperand(old, oldType); 941 942 // This is taken from the following file llvm/lib/Target/AMDGPU/SIDefines.h 943 enum DppCtrl : unsigned { 944 ROW_SHL0 = 0x100, 945 ROW_SHR0 = 0x110, 946 ROW_ROR0 = 0x120, 947 WAVE_SHL1 = 0x130, 948 WAVE_ROL1 = 0x134, 949 WAVE_SHR1 = 0x138, 950 WAVE_ROR1 = 0x13C, 951 ROW_MIRROR = 0x140, 952 ROW_HALF_MIRROR = 0x141, 953 BCAST15 = 0x142, 954 BCAST31 = 0x143, 955 }; 956 957 auto kind = DppOp.getKind(); 958 auto permArgument = DppOp.getPermArgument(); 959 uint32_t DppCtrl = 0; 960 961 switch (kind) { 962 963 case DPPPerm::quad_perm: 964 if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) { 965 int32_t i = 0; 966 for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) { 967 uint32_t num = elem.getInt(); 968 DppCtrl |= num << (i * 2); 969 i++; 970 } 971 } 972 break; 973 case DPPPerm::row_shl: 974 if (auto intAttr = cast<IntegerAttr>(*permArgument)) { 975 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0; 976 } 977 break; 978 case DPPPerm::row_shr: 979 if (auto intAttr = cast<IntegerAttr>(*permArgument)) { 980 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0; 981 } 982 break; 983 case DPPPerm::row_ror: 984 if (auto intAttr = cast<IntegerAttr>(*permArgument)) { 985 DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0; 986 } 987 break; 988 case DPPPerm::wave_shl: 989 DppCtrl = DppCtrl::WAVE_SHL1; 990 break; 991 case DPPPerm::wave_shr: 992 DppCtrl = DppCtrl::WAVE_SHR1; 993 break; 994 case DPPPerm::wave_rol: 995 DppCtrl = DppCtrl::WAVE_ROL1; 996 break; 997 case DPPPerm::wave_ror: 998 DppCtrl = DppCtrl::WAVE_ROR1; 999 break; 1000 case DPPPerm::row_mirror: 1001 DppCtrl = DppCtrl::ROW_MIRROR; 1002 break; 1003 case DPPPerm::row_half_mirror: 1004 DppCtrl = DppCtrl::ROW_HALF_MIRROR; 1005 break; 1006 case DPPPerm::row_bcast_15: 1007 DppCtrl = DppCtrl::BCAST15; 1008 break; 1009 case DPPPerm::row_bcast_31: 1010 DppCtrl = DppCtrl::BCAST31; 1011 break; 1012 } 1013 1014 // Check for row_mask, bank_mask, bound_ctrl if they exist and create 1015 // constants 1016 auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt(); 1017 auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt(); 1018 bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue(); 1019 1020 // create a ROCDL_DPPMovOp instruction with the appropriate attributes 1021 auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>( 1022 loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl); 1023 1024 Value result = dppMovOp.getRes(); 1025 if (srcType.getIntOrFloatBitWidth() < 32) { 1026 result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result); 1027 if (!llvm::isa<IntegerType>(srcType)) { 1028 result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result); 1029 } 1030 } 1031 1032 // We are replacing the AMDGPU_DPPOp instruction with the new 1033 // ROCDL_DPPMovOp instruction 1034 rewriter.replaceOp(DppOp, ValueRange(result)); 1035 return success(); 1036 } 1037 }; 1038 1039 struct ConvertAMDGPUToROCDLPass 1040 : public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> { 1041 ConvertAMDGPUToROCDLPass() = default; 1042 1043 void runOnOperation() override { 1044 MLIRContext *ctx = &getContext(); 1045 FailureOr<Chipset> maybeChipset = Chipset::parse(chipset); 1046 if (failed(maybeChipset)) { 1047 emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset); 1048 return signalPassFailure(); 1049 } 1050 1051 RewritePatternSet patterns(ctx); 1052 LLVMTypeConverter converter(ctx); 1053 populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset); 1054 LLVMConversionTarget target(getContext()); 1055 target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>(); 1056 target.addLegalDialect<::mlir::LLVM::LLVMDialect>(); 1057 target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>(); 1058 if (failed(applyPartialConversion(getOperation(), target, 1059 std::move(patterns)))) 1060 signalPassFailure(); 1061 } 1062 }; 1063 } // namespace 1064 1065 void mlir::populateAMDGPUToROCDLConversionPatterns( 1066 const LLVMTypeConverter &converter, RewritePatternSet &patterns, 1067 Chipset chipset) { 1068 patterns 1069 .add<RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>, 1070 RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>, 1071 RawBufferOpLowering<RawBufferAtomicFaddOp, 1072 ROCDL::RawPtrBufferAtomicFaddOp>, 1073 RawBufferOpLowering<RawBufferAtomicFmaxOp, 1074 ROCDL::RawPtrBufferAtomicFmaxOp>, 1075 RawBufferOpLowering<RawBufferAtomicSmaxOp, 1076 ROCDL::RawPtrBufferAtomicSmaxOp>, 1077 RawBufferOpLowering<RawBufferAtomicUminOp, 1078 ROCDL::RawPtrBufferAtomicUminOp>, 1079 RawBufferOpLowering<RawBufferAtomicCmpswapOp, 1080 ROCDL::RawPtrBufferAtomicCmpSwap>, 1081 AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering, 1082 MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering, 1083 PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter, 1084 chipset); 1085 } 1086 1087 std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() { 1088 return std::make_unique<ConvertAMDGPUToROCDLPass>(); 1089 } 1090