xref: /llvm-project/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (revision 6aaa8f25b66dc1fef4e465f274ee40b82d632988)
//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"

#include "llvm/ADT/STLExtras.h"
#include <limits>
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDL
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::amdgpu;

/// Convert an unsigned number `val` to i32.
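/// Wider values are truncated and narrower values are zero-extended, since
/// `val` is treated as unsigned.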
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
                                  Location loc, Value val) {
  IntegerType i32 = rewriter.getI32Type();
  // Assert that `val` is of integer type.
  auto valTy = cast<IntegerType>(val.getType());
  if (i32 == valTy)
    return val;
  return valTy.getWidth() > 32
             ? Value(rewriter.create<LLVM::TruncOp>(loc, i32, val))
             : Value(rewriter.create<LLVM::ZExtOp>(loc, i32, val));
}

static Value createI32Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int32_t value) {
  Type i32 = rewriter.getI32Type();
  return rewriter.create<LLVM::ConstantOp>(loc, i32, value);
}

static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
                              bool value) {
  Type llvmI1 = rewriter.getI1Type();
  return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value);
}

/// Returns the linear index used to access an element in the memref.
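/// For example (illustrative), given strides [s0, s1] and indices [i0, i1],
/// this emits the i32 computation i0 * s0 + i1 * s1, skipping the multiply
/// for any dimension whose static stride is 1.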
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
                               Location loc, MemRefDescriptor &memRefDescriptor,
                               ValueRange indices, ArrayRef<int64_t> strides) {
  IntegerType i32 = rewriter.getI32Type();
  Value index;
  for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
    if (stride != 1) { // Skip if stride is 1.
      Value strideValue =
          ShapedType::isDynamic(stride)
              ? convertUnsignedToI32(rewriter, loc,
                                     memRefDescriptor.stride(rewriter, loc, i))
              : rewriter.create<LLVM::ConstantOp>(loc, i32, stride);
      increment = rewriter.create<LLVM::MulOp>(loc, increment, strideValue);
    }
    index =
        index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
  }
  return index ? index : createI32Constant(rewriter, loc, 0);
}

namespace {
// Define commonly used chipset versions for convenience.
constexpr Chipset kGfx908 = Chipset(9, 0, 8);
constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
constexpr Chipset kGfx940 = Chipset(9, 4, 0);

/// Define lowering patterns for raw buffer ops
template <typename GpuOp, typename Intrinsic>
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
  RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}

  Chipset chipset;
  static constexpr uint32_t maxVectorOpWidth = 128;

  LogicalResult
  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gpuOp.getLoc();
    Value memref = adaptor.getMemref();
    Value unconvertedMemref = gpuOp.getMemref();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());

    if (chipset.majorVersion < 9)
      return gpuOp.emitOpError("raw buffer ops require GCN or higher");

    Value storeData = adaptor.getODSOperands(0)[0];
    if (storeData == memref) // no write component to this op
      storeData = Value();
    Type wantedDataType;
    if (storeData)
      wantedDataType = storeData.getType();
    else
      wantedDataType = gpuOp.getODSResults(0)[0].getType();

    Value atomicCmpData = Value();
    // For loads, operand index 1 is the indices; trying to read them as data
    // can crash.
    if (storeData) {
      Value maybeCmpData = adaptor.getODSOperands(1)[0];
      if (maybeCmpData != memref)
        atomicCmpData = maybeCmpData;
    }

    Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);

    Type i32 = rewriter.getI32Type();
    Type i16 = rewriter.getI16Type();

    // Get the type size in bytes.
    DataLayout dataLayout = DataLayout::closest(gpuOp);
    int64_t elementByteWidth =
        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
    Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);

    // If we want to load a vector<NxT> with total size <= 32 bits, use a
    // scalar load and bitcast it. Similarly, if bitsize(T) < 32 and the
    // total load size is > 32 bits, use a vector load of
    // (N * bitsize(T)) / 32 x i32 and bitcast. Also, the CAS intrinsic
    // requires integer operands, so bitcast any floats to integers.
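    // For example (illustrative): a vector<2xi8> (16 bits total) is loaded as
    // a single i16, while a vector<8xi8> (64 bits total) is loaded as a
    // vector<2xi32>.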
    Type llvmBufferValType = llvmWantedDataType;
    if (atomicCmpData) {
      if (auto floatType = dyn_cast<FloatType>(wantedDataType))
        llvmBufferValType = this->getTypeConverter()->convertType(
            rewriter.getIntegerType(floatType.getWidth()));
    }
    if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
      uint32_t vecLen = dataVector.getNumElements();
      uint32_t elemBits =
          dataLayout.getTypeSizeInBits(dataVector.getElementType());
      uint32_t totalBits = elemBits * vecLen;
      bool usePackedFp16 =
          isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
      if (totalBits > maxVectorOpWidth)
        return gpuOp.emitOpError(
            "Total width of loads or stores must be no more than " +
            Twine(maxVectorOpWidth) + " bits, but we call for " +
            Twine(totalBits) +
            " bits. This should've been caught in validation");
      if (!usePackedFp16 && elemBits < 32) {
        if (totalBits > 32) {
          if (totalBits % 32 != 0)
            return gpuOp.emitOpError("Load or store of more than 32-bits that "
                                     "doesn't fit into words. Can't happen\n");
          llvmBufferValType = this->typeConverter->convertType(
              VectorType::get(totalBits / 32, i32));
        } else {
          llvmBufferValType = this->typeConverter->convertType(
              rewriter.getIntegerType(totalBits));
        }
      }
    }

    SmallVector<Value, 6> args;
    if (storeData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForStore =
            rewriter.create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
        args.push_back(castForStore);
      } else {
        args.push_back(storeData);
      }
    }

    if (atomicCmpData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForCmp = rewriter.create<LLVM::BitcastOp>(
            loc, llvmBufferValType, atomicCmpData);
        args.push_back(castForCmp);
      } else {
        args.push_back(atomicCmpData);
      }
    }

    // Construct the buffer descriptor from the memref and attributes.
    int64_t offset = 0;
    SmallVector<int64_t, 5> strides;
    if (failed(memrefType.getStridesAndOffset(strides, offset)))
      return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");

    MemRefDescriptor memrefDescriptor(memref);

    Value ptr = memrefDescriptor.bufferPtr(
        rewriter, loc, *this->getTypeConverter(), memrefType);
    // The stride value is always 0 for raw buffers. This also disables
    // swizzling.
    Value stride = rewriter.create<LLVM::ConstantOp>(
        loc, i16, rewriter.getI16IntegerAttr(0));
    // Compute the number of records (the buffer extent in bytes).
    Value numRecords;
    if (memrefType.hasStaticShape() &&
        !llvm::any_of(strides, ShapedType::isDynamic)) {
      int64_t size = memrefType.getRank() == 0 ? 1 : 0;
      ArrayRef<int64_t> shape = memrefType.getShape();
      for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
        size = std::max(shape[i] * strides[i], size);
      size = size * elementByteWidth;
      assert(size < std::numeric_limits<uint32_t>::max() &&
             "the memref buffer is too large");
      numRecords = createI32Constant(rewriter, loc, static_cast<int32_t>(size));
    } else {
      Value maxIndex;
      for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
        Value size = memrefDescriptor.size(rewriter, loc, i);
        Value stride = memrefDescriptor.stride(rewriter, loc, i);
        Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
        maxIndex =
            maxIndex ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
                     : maxThisDim;
      }
      numRecords = rewriter.create<LLVM::MulOp>(
          loc, convertUnsignedToI32(rewriter, loc, maxIndex), byteWidthConst);
    }

    // Flag word:
    // bits 0-11: dst sel, ignored by these intrinsics
    // bits 12-14: num format (ignored, must be nonzero, 7=float)
    // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
    // bit 19: In nested heap (0 here)
    // bit 20: Behavior on unmap (0 means "return 0 / ignore")
    // bits 21-22: Index stride for swizzles (N/A)
    // bit 23: Add thread ID (0)
    // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
    // bits 25-26: Reserved (0)
    // bit 27: Buffer is non-volatile (CDNA only)
    // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
    //  none, 3 = either swizzles or testing against offset field) RDNA only
    // bits 30-31: Type (must be 0)
    uint32_t flags = (7 << 12) | (4 << 15);
    if (chipset.majorVersion >= 10) {
      flags |= (1 << 24);
      uint32_t oob = adaptor.getBoundsCheck() ? 3 : 2;
      flags |= (oob << 28);
    }
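    // For example (illustrative values): on CDNA chips the flag word is
    // (7 << 12) | (4 << 15) == 0x27000, while on RDNA chips with bounds
    // checking enabled it becomes 0x31027000.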
    Value flagsConst = createI32Constant(rewriter, loc, flags);
    Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), 8);
    Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
        loc, rsrcType, ptr, stride, numRecords, flagsConst);
    args.push_back(resource);

    // Indexing (voffset)
    Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
                                      adaptor.getIndices(), strides);
    if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
        indexOffset && *indexOffset > 0) {
      Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
      voffset =
          voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
                  : extraOffsetConst;
    }
    voffset = rewriter.create<LLVM::MulOp>(loc, voffset, byteWidthConst);
    args.push_back(voffset);

    // SGPR offset.
    Value sgprOffset = adaptor.getSgprOffset();
    if (!sgprOffset)
      sgprOffset = createI32Constant(rewriter, loc, 0);
    sgprOffset = rewriter.create<LLVM::MulOp>(loc, sgprOffset, byteWidthConst);
    args.push_back(sgprOffset);

    // bit 0: GLC = 0 (atomics drop value, less coherency)
    // bits 1-2: SLC, DLC = 0 (similarly)
    // bit 3: swizzled (0 for raw)
    args.push_back(createI32Constant(rewriter, loc, 0));

    llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
                                           llvmBufferValType);
    Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args,
                                                    ArrayRef<NamedAttribute>());
    if (lowered->getNumResults() == 1) {
      Value replacement = lowered->getResult(0);
      if (llvmBufferValType != llvmWantedDataType) {
        replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType,
                                                       replacement);
      }
      rewriter.replaceOp(gpuOp, replacement);
    } else {
      rewriter.eraseOp(gpuOp);
    }
    return success();
  }
};

struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
  LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    bool requiresInlineAsm = chipset < kGfx90a || chipset.majorVersion == 11;

    if (requiresInlineAsm) {
      auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                      LLVM::AsmDialect::AD_ATT);
      const char *asmStr =
          ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier";
      const char *constraints = "";
      rewriter.replaceOpWithNewOp<LLVM::InlineAsmOp>(
          op,
          /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
          /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
          /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
          /*operand_attrs=*/ArrayAttr());
      return success();
    }
    if (chipset.majorVersion < 12) {
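      // These masks assume the s_waitcnt immediate layout, where the lgkmcnt
      // field (which counts outstanding LDS operations) occupies bits 8-12 on
      // gfx9 and earlier and bits 8-13 on gfx10: clearing only that field
      // waits for all LDS accesses while leaving the other counters
      // unconstrained.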
      constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8);
      constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8);
      // Left in place in case someone disables the inline ASM path or future
      // chipsets use the same bit pattern.
      constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4);

      int32_t ldsOnlyBits;
      if (chipset.majorVersion == 11)
        ldsOnlyBits = ldsOnlyBitsGfx11;
      else if (chipset.majorVersion == 10)
        ldsOnlyBits = ldsOnlyBitsGfx10;
      else if (chipset.majorVersion <= 9)
        ldsOnlyBits = ldsOnlyBitsGfx6789;
      else
        return op.emitOpError(
                   "don't know how to lower this for chipset major version")
               << chipset.majorVersion;

      Location loc = op->getLoc();
      rewriter.create<ROCDL::WaitcntOp>(loc, ldsOnlyBits);
      rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op);
    } else {
      Location loc = op->getLoc();
      rewriter.create<ROCDL::WaitDscntOp>(loc, 0);
      rewriter.create<ROCDL::BarrierSignalOp>(loc, -1);
      rewriter.replaceOpWithNewOp<ROCDL::BarrierWaitOp>(op, -1);
    }

    return success();
  }
};

struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
  SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
                                                     (uint32_t)op.getOpts());
    return success();
  }
};

} // namespace

/// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
/// and LLVM AMDGPU intrinsics convention.
///
/// Specifically:
/// 1. If `input` is a vector of N bytes, bitcast it to a (N * 8)-bit integer.
/// 2. If the element type is bfloat16, bitcast it to i16.
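/// For example (illustrative), vector<4xi8> becomes i32 and vector<4xbf16>
/// becomes vector<4xi16>.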
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
                                      Location loc, Value input) {
  Type inputType = input.getType();
  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
    if (vectorType.getElementType().isBF16())
      return rewriter.create<LLVM::BitcastOp>(
          loc, vectorType.clone(rewriter.getI16Type()), input);
    if (vectorType.getElementType().isInteger(8)) {
      return rewriter.create<LLVM::BitcastOp>(
          loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
    }
  }
  return input;
}

/// Push an input operand. If it is a float type, nothing to do. If it is
/// an integer type, then we need to also push its signedness (1 for signed, 0
/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
/// vector. We also need to convert bfloat inputs to i16 to account for the
/// lack of bfloat support in the WMMA intrinsics themselves.
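/// For example (illustrative), a vector<16xi8> input is pushed as an i1
/// signedness flag followed by the same bytes bitcast to vector<4xi32>.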
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
                                 Location loc,
                                 const TypeConverter *typeConverter,
                                 bool isUnsigned, Value llvmInput,
                                 Value mlirInput,
                                 SmallVector<Value, 4> &operands) {
  Type inputType = llvmInput.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();

  if (elemType.isBF16())
    llvmInput = rewriter.create<LLVM::BitcastOp>(
        loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
  if (!elemType.isInteger(8)) {
    operands.push_back(llvmInput);
    return;
  }

  // We need to check the type of the input before conversion to properly test
  // for int8. This is because, in LLVM, fp8 type is converted to int8, so the
  // fp8/int8 information is lost during the conversion process.
  auto mlirInputType = cast<VectorType>(mlirInput.getType());
  bool isInputInt8 = mlirInputType.getElementType().isInteger(8);
  if (isInputInt8) {
    // If the element type is explicitly 8-bit signed or unsigned, ignore the
    // isUnsigned flag.
    bool localIsUnsigned = isUnsigned;
    if (elemType.isUnsignedInteger(8)) {
      localIsUnsigned = true;
    } else if (elemType.isSignedInteger(8)) {
      localIsUnsigned = false;
    }
    Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
    operands.push_back(sign);
  }

  int64_t numBytes = vectorType.getNumElements();
  Type i32 = rewriter.getI32Type();
  VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32);
  auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits);
  Value result = rewriter.createOrFold<LLVM::BitcastOp>(
      loc, llvmVectorType32bits, llvmInput);
  operands.push_back(result);
}

/// Push the output operand. For many cases this is only pushing the output in
/// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
/// since the same number of VGPRs is used, we need to decide whether to store
/// the result in the upper or lower 16 bits of the VGPRs. To store the result
/// in the lower 16 bits, set subwordOffset to 1; otherwise the result will be
/// stored in the upper part.
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                  Location loc,
                                  const TypeConverter *typeConverter,
                                  Value output, int32_t subwordOffset,
                                  bool clamp, SmallVector<Value, 4> &operands) {
  Type inputType = output.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();
  if (elemType.isBF16())
    output = rewriter.create<LLVM::BitcastOp>(
        loc, vectorType.clone(rewriter.getI16Type()), output);
  operands.push_back(output);
  if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
    operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
  } else if (elemType.isInteger(32)) {
    operands.push_back(createI1Constant(rewriter, loc, clamp));
  }
}

/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
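/// For example (illustrative), a 32x32x1 single-precision MFMA with 2 blocks
/// maps to the rocdl.mfma.f32.32x32x1f32 operation.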
static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
                                                  Chipset chipset) {
  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
           b = mfma.getBlocks();
  Type sourceElem = mfma.getSourceA().getType();
  if (auto sourceType = dyn_cast<VectorType>(sourceElem))
    sourceElem = sourceType.getElementType();
  Type destElem = mfma.getDestC().getType();
  if (auto destType = dyn_cast<VectorType>(destElem))
    destElem = destType.getElementType();

  if (sourceElem.isF32() && destElem.isF32()) {
    if (mfma.getReducePrecision() && chipset >= kGfx940) {
      if (m == 32 && n == 32 && k == 4 && b == 1)
        return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
      if (m == 16 && n == 16 && k == 8 && b == 1)
        return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
    }
    if (m == 32 && n == 32 && k == 1 && b == 2)
      return ROCDL::mfma_f32_32x32x1f32::getOperationName();
    if (m == 16 && n == 16 && k == 1 && b == 4)
      return ROCDL::mfma_f32_16x16x1f32::getOperationName();
    if (m == 4 && n == 4 && k == 1 && b == 16)
      return ROCDL::mfma_f32_4x4x1f32::getOperationName();
    if (m == 32 && n == 32 && k == 2 && b == 1)
      return ROCDL::mfma_f32_32x32x2f32::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f32_16x16x4f32::getOperationName();
  }

  if (sourceElem.isF16() && destElem.isF32()) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4f16::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4f16::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4f16::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8f16::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16f16::getOperationName();
  }

  if (sourceElem.isBF16() && destElem.isF32() && chipset >= kGfx90a) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
  }

  if (sourceElem.isBF16() && destElem.isF32()) {
    if (m == 32 && n == 32 && k == 2 && b == 2)
      return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
    if (m == 16 && n == 16 && k == 2 && b == 4)
      return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
    if (m == 4 && n == 4 && k == 2 && b == 16)
      return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
    if (m == 32 && n == 32 && k == 4 && b == 1)
      return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
    if (m == 16 && n == 16 && k == 8 && b == 1)
      return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
  }

  if (isa<IntegerType>(sourceElem) && destElem.isInteger(32)) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_i32_32x32x4i8::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_i32_16x16x4i8::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_i32_4x4x4i8::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_i32_32x32x8i8::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_i32_16x16x16i8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx940)
      return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
    if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx940)
      return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
  }

  if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f64_16x16x4f64::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 4)
      return ROCDL::mfma_f64_4x4x4f64::getOperationName();
  }

  if (isa<Float8E5M2FNUZType>(sourceElem) && destElem.isF32() &&
      chipset >= kGfx940) {
    // Known to be correct because there are no scalar f8 instructions and
    // because a length mismatch will have been caught by the verifier.
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (isa<Float8E5M2FNUZType>(sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
      if (isa<Float8E4M3FNUZType>(sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (isa<Float8E5M2FNUZType>(sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
      if (isa<Float8E4M3FNUZType>(sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
    }
  }

  if (isa<Float8E4M3FNUZType>(sourceElem) && destElem.isF32() &&
      chipset >= kGfx940) {
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (isa<Float8E5M2FNUZType>(sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
      if (isa<Float8E4M3FNUZType>(sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (isa<Float8E5M2FNUZType>(sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
      if (isa<Float8E4M3FNUZType>(sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
    }
  }

  return std::nullopt;
}

/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
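/// For example (illustrative), f16 sources accumulating into f32 map to the
/// rocdl.wmma.f32.16x16x16.f16 operation.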
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                  Chipset chipset) {
  auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
  auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
  auto elemSourceType = sourceVectorType.getElementType();
  auto elemDestType = destVectorType.getElementType();

  if (elemSourceType.isF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
  if (elemSourceType.isF16() && elemDestType.isF16())
    return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isBF16())
    return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
  if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
    return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
  if (isa<Float8E4M3FNType>(elemSourceType) && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_fp8::getOperationName();
  if (isa<Float8E5M2Type>(elemSourceType) && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_bf8::getOperationName();
  return std::nullopt;
}

namespace {
struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
  MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type outType = typeConverter->convertType(op.getDestD().getType());
    Type intrinsicOutType = outType;
    if (auto outVecType = dyn_cast<VectorType>(outType))
      if (outVecType.getElementType().isBF16())
        intrinsicOutType = outVecType.clone(rewriter.getI16Type());

    if (chipset.majorVersion != 9 || chipset < kGfx908)
      return op->emitOpError("MFMA only supported on gfx908+");
    uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
    if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
      if (chipset < kGfx940)
        return op.emitOpError("negation unsupported on older than gfx940");
      getBlgpField |=
          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
    }
    std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching MFMA size on given chipset");
    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands(
        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
         adaptor.getDestC(), createI32Constant(rewriter, loc, op.getCbsz()),
         createI32Constant(rewriter, loc, op.getAbid()),
         createI32Constant(rewriter, loc, getBlgpField)});
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    if (outType != intrinsicOutType)
      lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
  WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto outType =
        typeConverter->convertType<VectorType>(op.getDestD().getType());
    if (!outType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
      return op->emitOpError("WMMA only supported on gfx11 and gfx12");

    // The WMMA operations represent vectors of bf16s as vectors of i16s, so we
    // need to bitcast bfloats to i16 and then bitcast them back.
    VectorType rawOutType = outType;
    if (outType.getElementType().isBF16())
      rawOutType = outType.clone(rewriter.getI16Type());

    std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);

    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching WMMA on the given chipset");

    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(rawOutType);

    SmallVector<Value, 4> operands;
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
                         adaptor.getSourceA(), op.getSourceA(), operands);
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
                         adaptor.getSourceB(), op.getSourceB(), operands);
    wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
                          op.getSubwordOffset(), op.getClamp(), operands);

    loweredOp.addOperands(operands);
    Operation *lowered = rewriter.create(loweredOp);

    Operation *maybeCastBack = lowered;
    if (rawOutType != outType)
      maybeCastBack =
          rewriter.create<LLVM::BitcastOp>(loc, outType, lowered->getResult(0));
    rewriter.replaceOp(op, maybeCastBack->getResults());

    return success();
  }
};

namespace {
struct ExtPackedFp8OpLowering final
    : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
  ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedTrunc2xFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
  PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
                             Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedStochRoundFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
  PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
                                Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedStochRoundFp8Op op,
                  PackedStochRoundFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // end namespace

LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
    ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset.majorVersion != 9 || chipset < kGfx940)
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type v4i8 =
      getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
  Type f32 = getTypeConverter()->convertType(op.getResult().getType());

  Value source = adaptor.getSource();
  auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
  Type sourceElemType = getElementTypeOrSelf(op.getSource());
  // Extend to a v4i8
  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
    Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8);
    if (!sourceVecType) {
      longVec = rewriter.create<LLVM::InsertElementOp>(
          loc, longVec, source, createI32Constant(rewriter, loc, 0));
    } else {
      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
        Value idx = createI32Constant(rewriter, loc, i);
        Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
        longVec =
            rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
      }
    }
    source = longVec;
  }
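  // The source is now a v4i8, so it can be reinterpreted as the i32 word the
  // ROCDL conversion intrinsics expect; wordSel then selects the source byte
  // within that word.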
  Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
  Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
  if (isa<Float8E5M2FNUZType>(sourceElemType)) {
    rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
                                                    wordSel);
  } else if (isa<Float8E4M3FNUZType>(sourceElemType)) {
    rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
                                                    wordSel);
  }
  return success();
}

LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
    PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset.majorVersion != 9 || chipset < kGfx940)
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  Value sourceA = adaptor.getSourceA();
  Value sourceB = adaptor.getSourceB();
  if (!sourceB)
    sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType());
  Value existing = adaptor.getExisting();
  if (existing)
    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
  else
    existing = rewriter.create<LLVM::UndefOp>(loc, i32);
  Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());

  Value result;
  if (isa<Float8E5M2FNUZType>(resultElemType))
    result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
                                                   existing, wordSel);
  else if (isa<Float8E4M3FNUZType>(resultElemType))
    result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
                                                   existing, wordSel);

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}

LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
    PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset.majorVersion != 9 || chipset < kGfx940)
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  Value source = adaptor.getSource();
  Value stoch = adaptor.getStochiasticParam();
  Value existing = adaptor.getExisting();
  if (existing)
    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
  else
    existing = rewriter.create<LLVM::UndefOp>(loc, i32);
  Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());

  Value result;
  if (isa<Float8E5M2FNUZType>(resultElemType))
    result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
                                                   existing, byteSel);
  else if (isa<Float8E4M3FNUZType>(resultElemType))
    result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
                                                   existing, byteSel);

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}

// Convert the amdgpu.dpp operation into the corresponding ROCDL DPP
// (data-parallel primitives) instruction.
struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
  AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Convert the source operand to the corresponding LLVM type.
    Location loc = DppOp.getLoc();
    Value src = adaptor.getSrc();
    Value old = adaptor.getOld();
    Type srcType = src.getType();
    Type oldType = old.getType();
    Type llvmType = nullptr;
    if (srcType.getIntOrFloatBitWidth() < 32) {
      llvmType = rewriter.getI32Type();
    } else if (isa<FloatType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getF32Type()
                     : rewriter.getF64Type();
    } else if (isa<IntegerType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getI32Type()
                     : rewriter.getI64Type();
    }
    auto llvmSrcIntType = typeConverter->convertType(
        rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));

    // If the operand is narrower than 32 bits, widen it to an i32 via
    // bitcasts and vector insertion.
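    // For example (illustrative), an f16 operand is bitcast to i16, inserted
    // into element 0 of an undef vector<2xi16>, and then bitcast to i32.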
    auto convertOperand = [&](Value operand, Type operandType) {
      if (operandType.getIntOrFloatBitWidth() <= 16) {
        if (llvm::isa<FloatType>(operandType)) {
          operand =
              rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
        }
        auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
            32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
        Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
        operand = rewriter.create<LLVM::InsertElementOp>(
            loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
        operand = rewriter.create<LLVM::BitcastOp>(loc, llvmType, operand);
      }
      return operand;
    };

    src = convertOperand(src, srcType);
    old = convertOperand(old, oldType);

    // Taken from llvm/lib/Target/AMDGPU/SIDefines.h.
    enum DppCtrl : unsigned {
      ROW_SHL0 = 0x100,
      ROW_SHR0 = 0x110,
      ROW_ROR0 = 0x120,
      WAVE_SHL1 = 0x130,
      WAVE_ROL1 = 0x134,
      WAVE_SHR1 = 0x138,
      WAVE_ROR1 = 0x13C,
      ROW_MIRROR = 0x140,
      ROW_HALF_MIRROR = 0x141,
      BCAST15 = 0x142,
      BCAST31 = 0x143,
    };
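    // The row_shl/row_shr/row_ror forms encode their shift amount relative to
    // these bases; e.g., a row_shl of 3 is emitted as ROW_SHL0 + 3 == 0x103.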

    auto kind = DppOp.getKind();
    auto permArgument = DppOp.getPermArgument();
    uint32_t DppCtrl = 0;

    switch (kind) {

    case DPPPerm::quad_perm:
      if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
        int32_t i = 0;
        for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
          uint32_t num = elem.getInt();
          DppCtrl |= num << (i * 2);
          i++;
        }
      }
      break;
    case DPPPerm::row_shl:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
      }
      break;
    case DPPPerm::row_shr:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
      }
      break;
    case DPPPerm::row_ror:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
      }
      break;
    case DPPPerm::wave_shl:
      DppCtrl = DppCtrl::WAVE_SHL1;
      break;
    case DPPPerm::wave_shr:
      DppCtrl = DppCtrl::WAVE_SHR1;
      break;
    case DPPPerm::wave_rol:
      DppCtrl = DppCtrl::WAVE_ROL1;
      break;
    case DPPPerm::wave_ror:
      DppCtrl = DppCtrl::WAVE_ROR1;
      break;
    case DPPPerm::row_mirror:
      DppCtrl = DppCtrl::ROW_MIRROR;
      break;
    case DPPPerm::row_half_mirror:
      DppCtrl = DppCtrl::ROW_HALF_MIRROR;
      break;
    case DPPPerm::row_bcast_15:
      DppCtrl = DppCtrl::BCAST15;
      break;
    case DPPPerm::row_bcast_31:
      DppCtrl = DppCtrl::BCAST31;
      break;
    }

    // Read the row_mask, bank_mask, and bound_ctrl attributes so they can be
    // passed through as constants.
    auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
    auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
    bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();

    // Create a ROCDL::DPPUpdateOp (DPP mov) with the appropriate attributes.
    auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
        loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);

    Value result = dppMovOp.getRes();
    if (srcType.getIntOrFloatBitWidth() < 32) {
      result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
      if (!llvm::isa<IntegerType>(srcType)) {
        result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
      }
    }

    // Replace the amdgpu.dpp op with the new ROCDL DPP op.
    rewriter.replaceOp(DppOp, ValueRange(result));
    return success();
  }
};

struct ConvertAMDGPUToROCDLPass
    : public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
  ConvertAMDGPUToROCDLPass() = default;

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
    if (failed(maybeChipset)) {
      emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
      return signalPassFailure();
    }

    RewritePatternSet patterns(ctx);
    LLVMTypeConverter converter(ctx);
    populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
    LLVMConversionTarget target(getContext());
    target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

void mlir::populateAMDGPUToROCDLConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    Chipset chipset) {
  patterns
      .add<RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
           RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
           RawBufferOpLowering<RawBufferAtomicFaddOp,
                               ROCDL::RawPtrBufferAtomicFaddOp>,
           RawBufferOpLowering<RawBufferAtomicFmaxOp,
                               ROCDL::RawPtrBufferAtomicFmaxOp>,
           RawBufferOpLowering<RawBufferAtomicSmaxOp,
                               ROCDL::RawPtrBufferAtomicSmaxOp>,
           RawBufferOpLowering<RawBufferAtomicUminOp,
                               ROCDL::RawPtrBufferAtomicUminOp>,
           RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                               ROCDL::RawPtrBufferAtomicCmpSwap>,
           AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
           MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
           PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter,
                                                                      chipset);
}
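
// A typical driver invocation of this pass (assumed syntax) is:
//   mlir-opt --convert-amdgpu-to-rocdl=chipset=gfx90a input.mlir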

std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {
  return std::make_unique<ConvertAMDGPUToROCDLPass>();
}