xref: /llvm-project/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (revision 7a77f14c0abfbecbfb800ea8d974e66d81ee516a)
12ebd633fSKrzysztof Drewniak //===- ArithToAMDGPU.cpp - Arith to AMDGPU dialect conversion ---------===//
22ebd633fSKrzysztof Drewniak //
32ebd633fSKrzysztof Drewniak // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42ebd633fSKrzysztof Drewniak // See https://llvm.org/LICENSE.txt for license information.
52ebd633fSKrzysztof Drewniak // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
62ebd633fSKrzysztof Drewniak //
72ebd633fSKrzysztof Drewniak //===----------------------------------------------------------------------===//
82ebd633fSKrzysztof Drewniak 
92ebd633fSKrzysztof Drewniak #include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
102ebd633fSKrzysztof Drewniak 
112ebd633fSKrzysztof Drewniak #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
121387ba48SGiuseppe Rossini #include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
132ebd633fSKrzysztof Drewniak #include "mlir/Dialect/Arith/IR/Arith.h"
14750e90e4SKrzysztof Drewniak #include "mlir/Dialect/Arith/Utils/Utils.h"
151387ba48SGiuseppe Rossini #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
161387ba48SGiuseppe Rossini #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
172ebd633fSKrzysztof Drewniak #include "mlir/Dialect/Vector/IR/VectorOps.h"
182ebd633fSKrzysztof Drewniak #include "mlir/IR/BuiltinTypes.h"
192ebd633fSKrzysztof Drewniak #include "mlir/IR/PatternMatch.h"
202ebd633fSKrzysztof Drewniak #include "mlir/IR/TypeUtilities.h"
212ebd633fSKrzysztof Drewniak #include "mlir/Pass/Pass.h"
222ebd633fSKrzysztof Drewniak #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
232ebd633fSKrzysztof Drewniak 
242ebd633fSKrzysztof Drewniak namespace mlir {
252ebd633fSKrzysztof Drewniak #define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
262ebd633fSKrzysztof Drewniak #include "mlir/Conversion/Passes.h.inc"
272ebd633fSKrzysztof Drewniak } // namespace mlir
282ebd633fSKrzysztof Drewniak 
292ebd633fSKrzysztof Drewniak using namespace mlir;
301387ba48SGiuseppe Rossini using namespace mlir::amdgpu;
312ebd633fSKrzysztof Drewniak 
322ebd633fSKrzysztof Drewniak namespace {
332ebd633fSKrzysztof Drewniak struct ArithToAMDGPUConversionPass final
342ebd633fSKrzysztof Drewniak     : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
352ebd633fSKrzysztof Drewniak   using impl::ArithToAMDGPUConversionPassBase<
362ebd633fSKrzysztof Drewniak       ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase;
372ebd633fSKrzysztof Drewniak 
382ebd633fSKrzysztof Drewniak   void runOnOperation() override;
392ebd633fSKrzysztof Drewniak };
402ebd633fSKrzysztof Drewniak 
41750e90e4SKrzysztof Drewniak struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
42750e90e4SKrzysztof Drewniak   using OpRewritePattern::OpRewritePattern;
432ebd633fSKrzysztof Drewniak 
442ebd633fSKrzysztof Drewniak   LogicalResult match(arith::ExtFOp op) const override;
452ebd633fSKrzysztof Drewniak   void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
462ebd633fSKrzysztof Drewniak };
472ebd633fSKrzysztof Drewniak 
48750e90e4SKrzysztof Drewniak struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
49750e90e4SKrzysztof Drewniak   bool saturateFP8 = false;
501387ba48SGiuseppe Rossini   TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
511387ba48SGiuseppe Rossini                                Chipset chipset)
521387ba48SGiuseppe Rossini       : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8),
531387ba48SGiuseppe Rossini         chipset(chipset) {}
541387ba48SGiuseppe Rossini   Chipset chipset;
552ebd633fSKrzysztof Drewniak 
562ebd633fSKrzysztof Drewniak   LogicalResult match(arith::TruncFOp op) const override;
572ebd633fSKrzysztof Drewniak   void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
582ebd633fSKrzysztof Drewniak };
591387ba48SGiuseppe Rossini 
601387ba48SGiuseppe Rossini struct TruncfToFloat16RewritePattern final
611387ba48SGiuseppe Rossini     : public OpRewritePattern<arith::TruncFOp> {
621387ba48SGiuseppe Rossini 
631387ba48SGiuseppe Rossini   using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;
641387ba48SGiuseppe Rossini 
651387ba48SGiuseppe Rossini   LogicalResult match(arith::TruncFOp op) const override;
661387ba48SGiuseppe Rossini   void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
671387ba48SGiuseppe Rossini };
681387ba48SGiuseppe Rossini 
692ebd633fSKrzysztof Drewniak } // end namespace
702ebd633fSKrzysztof Drewniak 
712ebd633fSKrzysztof Drewniak static Value castF32To(Type elementType, Value f32, Location loc,
722ebd633fSKrzysztof Drewniak                        PatternRewriter &rewriter) {
732ebd633fSKrzysztof Drewniak   if (elementType.isF32())
742ebd633fSKrzysztof Drewniak     return f32;
752ebd633fSKrzysztof Drewniak   if (elementType.getIntOrFloatBitWidth() < 32)
762ebd633fSKrzysztof Drewniak     return rewriter.create<arith::TruncFOp>(loc, elementType, f32);
772ebd633fSKrzysztof Drewniak   if (elementType.getIntOrFloatBitWidth() > 32)
782ebd633fSKrzysztof Drewniak     return rewriter.create<arith::ExtFOp>(loc, elementType, f32);
792ebd633fSKrzysztof Drewniak   llvm_unreachable("The only 32-bit float type is f32");
802ebd633fSKrzysztof Drewniak }
812ebd633fSKrzysztof Drewniak 
82750e90e4SKrzysztof Drewniak LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
832ebd633fSKrzysztof Drewniak   Type inType = op.getIn().getType();
84a5757c5bSChristian Sigg   if (auto inVecType = dyn_cast<VectorType>(inType)) {
852ebd633fSKrzysztof Drewniak     if (inVecType.isScalable())
862ebd633fSKrzysztof Drewniak       return failure();
872ebd633fSKrzysztof Drewniak     inType = inVecType.getElementType();
882ebd633fSKrzysztof Drewniak   }
89*7a77f14cSMatthias Springer   return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(inType));
902ebd633fSKrzysztof Drewniak }
912ebd633fSKrzysztof Drewniak 
92750e90e4SKrzysztof Drewniak void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
932ebd633fSKrzysztof Drewniak                                          PatternRewriter &rewriter) const {
942ebd633fSKrzysztof Drewniak   Location loc = op.getLoc();
952ebd633fSKrzysztof Drewniak   Value in = op.getIn();
962ebd633fSKrzysztof Drewniak   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
97f35318e8SRob Suderman   auto inType = dyn_cast<VectorType>(in.getType());
98f35318e8SRob Suderman   if (!inType) {
992ebd633fSKrzysztof Drewniak     Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
1002ebd633fSKrzysztof Drewniak         loc, rewriter.getF32Type(), in, 0);
1012ebd633fSKrzysztof Drewniak     Value result = castF32To(outElemType, asFloat, loc, rewriter);
1022ebd633fSKrzysztof Drewniak     return rewriter.replaceOp(op, result);
1032ebd633fSKrzysztof Drewniak   }
1042ebd633fSKrzysztof Drewniak   int64_t numElements = inType.getNumElements();
10565066c02SHugo Trachino   Value zero = rewriter.create<arith::ConstantOp>(
1062ebd633fSKrzysztof Drewniak       loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
1072ebd633fSKrzysztof Drewniak   if (inType.getShape().empty()) {
108750e90e4SKrzysztof Drewniak     Value scalarIn =
109750e90e4SKrzysztof Drewniak         rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
1102ebd633fSKrzysztof Drewniak     // Recurse to send the 0-D vector case to the 1-D vector case
1112ebd633fSKrzysztof Drewniak     Value scalarExt =
1122ebd633fSKrzysztof Drewniak         rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
113f35318e8SRob Suderman     Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero,
114750e90e4SKrzysztof Drewniak                                                      ArrayRef<int64_t>{});
1152ebd633fSKrzysztof Drewniak     return rewriter.replaceOp(op, result);
1162ebd633fSKrzysztof Drewniak   }
117f35318e8SRob Suderman 
118f35318e8SRob Suderman   VectorType outType = cast<VectorType>(op.getOut().getType());
119f35318e8SRob Suderman   VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
120f35318e8SRob Suderman                                       outType.getElementType());
121f35318e8SRob Suderman   Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
122f35318e8SRob Suderman 
123f35318e8SRob Suderman   if (inType.getRank() > 1) {
124f35318e8SRob Suderman     inType = VectorType::get(SmallVector<int64_t>{numElements},
125f35318e8SRob Suderman                              inType.getElementType());
126f35318e8SRob Suderman     in = rewriter.create<vector::ShapeCastOp>(loc, inType, in);
127f35318e8SRob Suderman   }
128f35318e8SRob Suderman 
1292ebd633fSKrzysztof Drewniak   for (int64_t i = 0; i < numElements; i += 4) {
1302ebd633fSKrzysztof Drewniak     int64_t elemsThisOp = std::min(numElements, i + 4) - i;
1312ebd633fSKrzysztof Drewniak     Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
1322ebd633fSKrzysztof Drewniak         loc, in, i, elemsThisOp, 1);
1332ebd633fSKrzysztof Drewniak     for (int64_t j = 0; j < elemsThisOp; ++j) {
1342ebd633fSKrzysztof Drewniak       Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
1352ebd633fSKrzysztof Drewniak           loc, rewriter.getF32Type(), inSlice, j);
1362ebd633fSKrzysztof Drewniak       Value asType = castF32To(outElemType, asFloat, loc, rewriter);
137750e90e4SKrzysztof Drewniak       result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
1382ebd633fSKrzysztof Drewniak     }
1392ebd633fSKrzysztof Drewniak   }
140f35318e8SRob Suderman 
141f35318e8SRob Suderman   if (inType.getRank() != outType.getRank()) {
142f35318e8SRob Suderman     result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
143f35318e8SRob Suderman   }
144f35318e8SRob Suderman 
1452ebd633fSKrzysztof Drewniak   rewriter.replaceOp(op, result);
1462ebd633fSKrzysztof Drewniak }
1472ebd633fSKrzysztof Drewniak 
1482ebd633fSKrzysztof Drewniak static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
1492ebd633fSKrzysztof Drewniak   Type type = value.getType();
1502ebd633fSKrzysztof Drewniak   if (type.isF32())
1512ebd633fSKrzysztof Drewniak     return value;
1522ebd633fSKrzysztof Drewniak   if (type.getIntOrFloatBitWidth() < 32)
1532ebd633fSKrzysztof Drewniak     return rewriter.create<arith::ExtFOp>(loc, rewriter.getF32Type(), value);
1542ebd633fSKrzysztof Drewniak   if (type.getIntOrFloatBitWidth() > 32)
1552ebd633fSKrzysztof Drewniak     return rewriter.create<arith::TruncFOp>(loc, rewriter.getF32Type(), value);
1562ebd633fSKrzysztof Drewniak   llvm_unreachable("The only 32-bit float type is f32");
1572ebd633fSKrzysztof Drewniak }
1582ebd633fSKrzysztof Drewniak 
159750e90e4SKrzysztof Drewniak // If `in` is a finite value, clamp it between the maximum and minimum values
160750e90e4SKrzysztof Drewniak // of `outElemType` so that subsequent conversion instructions don't
161750e90e4SKrzysztof Drewniak // overflow those out-of-range values to NaN. These semantics are commonly
162750e90e4SKrzysztof Drewniak // used in machine-learning contexts where failure to clamp would lead to
163750e90e4SKrzysztof Drewniak // excessive NaN production.
164750e90e4SKrzysztof Drewniak static Value clampInput(PatternRewriter &rewriter, Location loc,
165750e90e4SKrzysztof Drewniak                         Type outElemType, Value source) {
166750e90e4SKrzysztof Drewniak   Type sourceType = source.getType();
167750e90e4SKrzysztof Drewniak   const llvm::fltSemantics &sourceSem =
168750e90e4SKrzysztof Drewniak       cast<FloatType>(getElementTypeOrSelf(sourceType)).getFloatSemantics();
169750e90e4SKrzysztof Drewniak   const llvm::fltSemantics &targetSem =
170750e90e4SKrzysztof Drewniak       cast<FloatType>(outElemType).getFloatSemantics();
171750e90e4SKrzysztof Drewniak 
172750e90e4SKrzysztof Drewniak   APFloat min = APFloat::getLargest(targetSem, /*Negative=*/true);
173750e90e4SKrzysztof Drewniak   APFloat max = APFloat::getLargest(targetSem, /*Negative=*/false);
174750e90e4SKrzysztof Drewniak   bool ignoredLosesInfo = false;
175750e90e4SKrzysztof Drewniak   // We can ignore conversion failures here because this conversion promotes
176750e90e4SKrzysztof Drewniak   // from a smaller type to a larger one - ex. there can be no loss of precision
177750e90e4SKrzysztof Drewniak   // when casting fp8 to f16.
178750e90e4SKrzysztof Drewniak   (void)min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
179750e90e4SKrzysztof Drewniak   (void)max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
180750e90e4SKrzysztof Drewniak 
181750e90e4SKrzysztof Drewniak   Value minCst = createScalarOrSplatConstant(rewriter, loc, sourceType, min);
182750e90e4SKrzysztof Drewniak   Value maxCst = createScalarOrSplatConstant(rewriter, loc, sourceType, max);
183750e90e4SKrzysztof Drewniak 
184750e90e4SKrzysztof Drewniak   Value inf = createScalarOrSplatConstant(
185750e90e4SKrzysztof Drewniak       rewriter, loc, sourceType,
186750e90e4SKrzysztof Drewniak       APFloat::getInf(sourceSem, /*Negative=*/false));
187750e90e4SKrzysztof Drewniak   Value negInf = createScalarOrSplatConstant(
188750e90e4SKrzysztof Drewniak       rewriter, loc, sourceType, APFloat::getInf(sourceSem, /*Negative=*/true));
189750e90e4SKrzysztof Drewniak   Value isInf = rewriter.createOrFold<arith::CmpFOp>(
190750e90e4SKrzysztof Drewniak       loc, arith::CmpFPredicate::OEQ, source, inf);
191750e90e4SKrzysztof Drewniak   Value isNegInf = rewriter.createOrFold<arith::CmpFOp>(
192750e90e4SKrzysztof Drewniak       loc, arith::CmpFPredicate::OEQ, source, negInf);
193750e90e4SKrzysztof Drewniak   Value isNan = rewriter.createOrFold<arith::CmpFOp>(
194750e90e4SKrzysztof Drewniak       loc, arith::CmpFPredicate::UNO, source, source);
195750e90e4SKrzysztof Drewniak   Value isNonFinite = rewriter.create<arith::OrIOp>(
196750e90e4SKrzysztof Drewniak       loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan);
197750e90e4SKrzysztof Drewniak 
198750e90e4SKrzysztof Drewniak   Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst);
199750e90e4SKrzysztof Drewniak   Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
200750e90e4SKrzysztof Drewniak   Value res =
201750e90e4SKrzysztof Drewniak       rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped);
202750e90e4SKrzysztof Drewniak   return res;
203750e90e4SKrzysztof Drewniak }
204750e90e4SKrzysztof Drewniak 
205750e90e4SKrzysztof Drewniak LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
2068827ff92SVictor Perez   // Only supporting default rounding mode as of now.
2078827ff92SVictor Perez   if (op.getRoundingmodeAttr())
2088827ff92SVictor Perez     return failure();
2092ebd633fSKrzysztof Drewniak   Type outType = op.getOut().getType();
210a5757c5bSChristian Sigg   if (auto outVecType = dyn_cast<VectorType>(outType)) {
2112ebd633fSKrzysztof Drewniak     if (outVecType.isScalable())
2122ebd633fSKrzysztof Drewniak       return failure();
2132ebd633fSKrzysztof Drewniak     outType = outVecType.getElementType();
2142ebd633fSKrzysztof Drewniak   }
215750e90e4SKrzysztof Drewniak   auto inType = dyn_cast<FloatType>(getElementTypeOrSelf(op.getIn().getType()));
216750e90e4SKrzysztof Drewniak   if (inType && inType.getWidth() <= 8 && saturateFP8)
217750e90e4SKrzysztof Drewniak     // Conversion between 8-bit floats is not supported with truncation enabled.
218750e90e4SKrzysztof Drewniak     return failure();
219*7a77f14cSMatthias Springer   return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(outType));
2202ebd633fSKrzysztof Drewniak }
2212ebd633fSKrzysztof Drewniak 
222750e90e4SKrzysztof Drewniak void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
2232ebd633fSKrzysztof Drewniak                                            PatternRewriter &rewriter) const {
2242ebd633fSKrzysztof Drewniak   Location loc = op.getLoc();
2252ebd633fSKrzysztof Drewniak   Value in = op.getIn();
2262ebd633fSKrzysztof Drewniak   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
227750e90e4SKrzysztof Drewniak   if (saturateFP8)
228750e90e4SKrzysztof Drewniak     in = clampInput(rewriter, loc, outElemType, in);
229f35318e8SRob Suderman   auto inVectorTy = dyn_cast<VectorType>(in.getType());
2302ebd633fSKrzysztof Drewniak   VectorType truncResType = VectorType::get(4, outElemType);
231f35318e8SRob Suderman   if (!inVectorTy) {
2322ebd633fSKrzysztof Drewniak     Value asFloat = castToF32(in, loc, rewriter);
2332ebd633fSKrzysztof Drewniak     Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
2342ebd633fSKrzysztof Drewniak         loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
2352ebd633fSKrzysztof Drewniak         /*existing=*/nullptr);
236750e90e4SKrzysztof Drewniak     Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
2372ebd633fSKrzysztof Drewniak     return rewriter.replaceOp(op, result);
2382ebd633fSKrzysztof Drewniak   }
239a5757c5bSChristian Sigg   VectorType outType = cast<VectorType>(op.getOut().getType());
2402ebd633fSKrzysztof Drewniak   int64_t numElements = outType.getNumElements();
24165066c02SHugo Trachino   Value zero = rewriter.create<arith::ConstantOp>(
2422ebd633fSKrzysztof Drewniak       loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
2432ebd633fSKrzysztof Drewniak   if (outType.getShape().empty()) {
244750e90e4SKrzysztof Drewniak     Value scalarIn =
245750e90e4SKrzysztof Drewniak         rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
2462ebd633fSKrzysztof Drewniak     // Recurse to send the 0-D vector case to the 1-D vector case
2472ebd633fSKrzysztof Drewniak     Value scalarTrunc =
2482ebd633fSKrzysztof Drewniak         rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
249f35318e8SRob Suderman     Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
250750e90e4SKrzysztof Drewniak                                                      ArrayRef<int64_t>{});
2512ebd633fSKrzysztof Drewniak     return rewriter.replaceOp(op, result);
2522ebd633fSKrzysztof Drewniak   }
2532ebd633fSKrzysztof Drewniak 
254f35318e8SRob Suderman   VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
255f35318e8SRob Suderman                                       outType.getElementType());
256f35318e8SRob Suderman   Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
257f35318e8SRob Suderman 
258f35318e8SRob Suderman   if (inVectorTy.getRank() > 1) {
259f35318e8SRob Suderman     inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
260f35318e8SRob Suderman                                  inVectorTy.getElementType());
261f35318e8SRob Suderman     in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
262f35318e8SRob Suderman   }
263f35318e8SRob Suderman 
2642ebd633fSKrzysztof Drewniak   for (int64_t i = 0; i < numElements; i += 4) {
2652ebd633fSKrzysztof Drewniak     int64_t elemsThisOp = std::min(numElements, i + 4) - i;
2662ebd633fSKrzysztof Drewniak     Value thisResult = nullptr;
2672ebd633fSKrzysztof Drewniak     for (int64_t j = 0; j < elemsThisOp; j += 2) {
268750e90e4SKrzysztof Drewniak       Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i + j);
2692ebd633fSKrzysztof Drewniak       Value asFloatA = castToF32(elemA, loc, rewriter);
2702ebd633fSKrzysztof Drewniak       Value asFloatB = nullptr;
2712ebd633fSKrzysztof Drewniak       if (j + 1 < elemsThisOp) {
272750e90e4SKrzysztof Drewniak         Value elemB = rewriter.create<vector::ExtractOp>(loc, in, i + j + 1);
2732ebd633fSKrzysztof Drewniak         asFloatB = castToF32(elemB, loc, rewriter);
2742ebd633fSKrzysztof Drewniak       }
2752ebd633fSKrzysztof Drewniak       thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
2762ebd633fSKrzysztof Drewniak           loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
2772ebd633fSKrzysztof Drewniak     }
2782ebd633fSKrzysztof Drewniak     if (elemsThisOp < 4)
2792ebd633fSKrzysztof Drewniak       thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
2802ebd633fSKrzysztof Drewniak           loc, thisResult, 0, elemsThisOp, 1);
2812ebd633fSKrzysztof Drewniak     result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
2822ebd633fSKrzysztof Drewniak                                                            result, i, 1);
2832ebd633fSKrzysztof Drewniak   }
284f35318e8SRob Suderman 
285f35318e8SRob Suderman   if (inVectorTy.getRank() != outType.getRank()) {
286f35318e8SRob Suderman     result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
287f35318e8SRob Suderman   }
288f35318e8SRob Suderman 
2892ebd633fSKrzysztof Drewniak   rewriter.replaceOp(op, result);
2902ebd633fSKrzysztof Drewniak }
2912ebd633fSKrzysztof Drewniak 
2921387ba48SGiuseppe Rossini LogicalResult TruncfToFloat16RewritePattern::match(arith::TruncFOp op) const {
2931387ba48SGiuseppe Rossini   Type outType = op.getOut().getType();
2941387ba48SGiuseppe Rossini   Type inputType = getElementTypeOrSelf(op.getIn());
2951387ba48SGiuseppe Rossini   if (auto outVecType = dyn_cast<VectorType>(outType)) {
2961387ba48SGiuseppe Rossini     if (outVecType.isScalable())
2971387ba48SGiuseppe Rossini       return failure();
2981387ba48SGiuseppe Rossini     outType = outVecType.getElementType();
2991387ba48SGiuseppe Rossini   }
3001387ba48SGiuseppe Rossini   return success(outType.isF16() && inputType.isF32());
3011387ba48SGiuseppe Rossini }
3021387ba48SGiuseppe Rossini 
3031387ba48SGiuseppe Rossini void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
3041387ba48SGiuseppe Rossini                                             PatternRewriter &rewriter) const {
3051387ba48SGiuseppe Rossini   Location loc = op.getLoc();
3061387ba48SGiuseppe Rossini   Value in = op.getIn();
3071387ba48SGiuseppe Rossini   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
3081387ba48SGiuseppe Rossini   VectorType truncResType = VectorType::get(2, outElemType);
3091387ba48SGiuseppe Rossini   auto inVectorTy = dyn_cast<VectorType>(in.getType());
3101387ba48SGiuseppe Rossini 
3111387ba48SGiuseppe Rossini   // Handle the case where input type is not a vector type
3121387ba48SGiuseppe Rossini   if (!inVectorTy) {
3131387ba48SGiuseppe Rossini     auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
3141387ba48SGiuseppe Rossini     Value asF16s =
3151387ba48SGiuseppe Rossini         rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
3168e663039SKunwar Grover     Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0);
3171387ba48SGiuseppe Rossini     return rewriter.replaceOp(op, result);
3181387ba48SGiuseppe Rossini   }
3191387ba48SGiuseppe Rossini   VectorType outType = cast<VectorType>(op.getOut().getType());
3201387ba48SGiuseppe Rossini   int64_t numElements = outType.getNumElements();
3211387ba48SGiuseppe Rossini   Value zero = rewriter.createOrFold<arith::ConstantOp>(
3221387ba48SGiuseppe Rossini       loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
3231387ba48SGiuseppe Rossini   Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
3241387ba48SGiuseppe Rossini 
3251387ba48SGiuseppe Rossini   if (inVectorTy.getRank() > 1) {
3261387ba48SGiuseppe Rossini     inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
3271387ba48SGiuseppe Rossini                                  inVectorTy.getElementType());
3281387ba48SGiuseppe Rossini     in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
3291387ba48SGiuseppe Rossini   }
3301387ba48SGiuseppe Rossini 
3311387ba48SGiuseppe Rossini   // Handle the vector case. We also handle the (uncommon) case where the vector
3321387ba48SGiuseppe Rossini   // length is odd
3331387ba48SGiuseppe Rossini   for (int64_t i = 0; i < numElements; i += 2) {
3341387ba48SGiuseppe Rossini     int64_t elemsThisOp = std::min(numElements, i + 2) - i;
3351387ba48SGiuseppe Rossini     Value thisResult = nullptr;
3368e663039SKunwar Grover     Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i);
3371387ba48SGiuseppe Rossini     Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
3381387ba48SGiuseppe Rossini 
3391387ba48SGiuseppe Rossini     if (elemsThisOp == 2) {
3408e663039SKunwar Grover       elemB = rewriter.create<vector::ExtractOp>(loc, in, i + 1);
3411387ba48SGiuseppe Rossini     }
3421387ba48SGiuseppe Rossini 
3431387ba48SGiuseppe Rossini     thisResult =
3441387ba48SGiuseppe Rossini         rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, elemA, elemB);
3451387ba48SGiuseppe Rossini     // Place back the truncated result into the possibly larger vector. If we
3461387ba48SGiuseppe Rossini     // are operating on a size 2 vector, these operations should be folded away
3471387ba48SGiuseppe Rossini     thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
3481387ba48SGiuseppe Rossini         loc, thisResult, 0, elemsThisOp, 1);
3491387ba48SGiuseppe Rossini     result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
3501387ba48SGiuseppe Rossini                                                            result, i, 1);
3511387ba48SGiuseppe Rossini   }
3521387ba48SGiuseppe Rossini 
3531387ba48SGiuseppe Rossini   if (inVectorTy.getRank() != outType.getRank()) {
3541387ba48SGiuseppe Rossini     result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
3551387ba48SGiuseppe Rossini   }
3561387ba48SGiuseppe Rossini 
3571387ba48SGiuseppe Rossini   rewriter.replaceOp(op, result);
3581387ba48SGiuseppe Rossini }
3591387ba48SGiuseppe Rossini 
3602ebd633fSKrzysztof Drewniak void mlir::arith::populateArithToAMDGPUConversionPatterns(
3611387ba48SGiuseppe Rossini     RewritePatternSet &patterns, bool convertFP8Arithmetic,
3621387ba48SGiuseppe Rossini     bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
3631387ba48SGiuseppe Rossini 
3641387ba48SGiuseppe Rossini   if (convertFP8Arithmetic) {
365750e90e4SKrzysztof Drewniak     patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
366750e90e4SKrzysztof Drewniak     patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
3671387ba48SGiuseppe Rossini                                                saturateFP8Truncf, chipset);
3681387ba48SGiuseppe Rossini   }
3691387ba48SGiuseppe Rossini   if (allowPackedF16Rtz)
3701387ba48SGiuseppe Rossini     patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext());
3712ebd633fSKrzysztof Drewniak }
3722ebd633fSKrzysztof Drewniak 
3732ebd633fSKrzysztof Drewniak void ArithToAMDGPUConversionPass::runOnOperation() {
3742ebd633fSKrzysztof Drewniak   Operation *op = getOperation();
3751387ba48SGiuseppe Rossini   MLIRContext *ctx = &getContext();
3762ebd633fSKrzysztof Drewniak   RewritePatternSet patterns(op->getContext());
3771387ba48SGiuseppe Rossini   FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
3781387ba48SGiuseppe Rossini   if (failed(maybeChipset)) {
3791387ba48SGiuseppe Rossini     emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
3801387ba48SGiuseppe Rossini     return signalPassFailure();
3811387ba48SGiuseppe Rossini   }
3821387ba48SGiuseppe Rossini 
3831387ba48SGiuseppe Rossini   bool convertFP8Arithmetic =
384763bc924SJakub Kuderski       maybeChipset->majorVersion == 9 && *maybeChipset >= Chipset(9, 4, 0);
3851387ba48SGiuseppe Rossini   arith::populateArithToAMDGPUConversionPatterns(
3861387ba48SGiuseppe Rossini       patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
3871387ba48SGiuseppe Rossini       *maybeChipset);
38809dfc571SJacques Pienaar   if (failed(applyPatternsGreedily(op, std::move(patterns))))
3892ebd633fSKrzysztof Drewniak     return signalPassFailure();
3902ebd633fSKrzysztof Drewniak }
391