//===- ArithToAMDGPU.cpp - Arith to AMDGPU dialect conversion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
#define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::amdgpu;
312ebd633fSKrzysztof Drewniak 322ebd633fSKrzysztof Drewniak namespace { 332ebd633fSKrzysztof Drewniak struct ArithToAMDGPUConversionPass final 342ebd633fSKrzysztof Drewniak : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> { 352ebd633fSKrzysztof Drewniak using impl::ArithToAMDGPUConversionPassBase< 362ebd633fSKrzysztof Drewniak ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase; 372ebd633fSKrzysztof Drewniak 382ebd633fSKrzysztof Drewniak void runOnOperation() override; 392ebd633fSKrzysztof Drewniak }; 402ebd633fSKrzysztof Drewniak 41750e90e4SKrzysztof Drewniak struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> { 42750e90e4SKrzysztof Drewniak using OpRewritePattern::OpRewritePattern; 432ebd633fSKrzysztof Drewniak 442ebd633fSKrzysztof Drewniak LogicalResult match(arith::ExtFOp op) const override; 452ebd633fSKrzysztof Drewniak void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override; 462ebd633fSKrzysztof Drewniak }; 472ebd633fSKrzysztof Drewniak 48750e90e4SKrzysztof Drewniak struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> { 49750e90e4SKrzysztof Drewniak bool saturateFP8 = false; 501387ba48SGiuseppe Rossini TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8, 511387ba48SGiuseppe Rossini Chipset chipset) 521387ba48SGiuseppe Rossini : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8), 531387ba48SGiuseppe Rossini chipset(chipset) {} 541387ba48SGiuseppe Rossini Chipset chipset; 552ebd633fSKrzysztof Drewniak 562ebd633fSKrzysztof Drewniak LogicalResult match(arith::TruncFOp op) const override; 572ebd633fSKrzysztof Drewniak void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override; 582ebd633fSKrzysztof Drewniak }; 591387ba48SGiuseppe Rossini 601387ba48SGiuseppe Rossini struct TruncfToFloat16RewritePattern final 611387ba48SGiuseppe Rossini : public OpRewritePattern<arith::TruncFOp> { 621387ba48SGiuseppe Rossini 631387ba48SGiuseppe 
Rossini using OpRewritePattern<arith::TruncFOp>::OpRewritePattern; 641387ba48SGiuseppe Rossini 651387ba48SGiuseppe Rossini LogicalResult match(arith::TruncFOp op) const override; 661387ba48SGiuseppe Rossini void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override; 671387ba48SGiuseppe Rossini }; 681387ba48SGiuseppe Rossini 692ebd633fSKrzysztof Drewniak } // end namespace 702ebd633fSKrzysztof Drewniak 712ebd633fSKrzysztof Drewniak static Value castF32To(Type elementType, Value f32, Location loc, 722ebd633fSKrzysztof Drewniak PatternRewriter &rewriter) { 732ebd633fSKrzysztof Drewniak if (elementType.isF32()) 742ebd633fSKrzysztof Drewniak return f32; 752ebd633fSKrzysztof Drewniak if (elementType.getIntOrFloatBitWidth() < 32) 762ebd633fSKrzysztof Drewniak return rewriter.create<arith::TruncFOp>(loc, elementType, f32); 772ebd633fSKrzysztof Drewniak if (elementType.getIntOrFloatBitWidth() > 32) 782ebd633fSKrzysztof Drewniak return rewriter.create<arith::ExtFOp>(loc, elementType, f32); 792ebd633fSKrzysztof Drewniak llvm_unreachable("The only 32-bit float type is f32"); 802ebd633fSKrzysztof Drewniak } 812ebd633fSKrzysztof Drewniak 82750e90e4SKrzysztof Drewniak LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const { 832ebd633fSKrzysztof Drewniak Type inType = op.getIn().getType(); 84a5757c5bSChristian Sigg if (auto inVecType = dyn_cast<VectorType>(inType)) { 852ebd633fSKrzysztof Drewniak if (inVecType.isScalable()) 862ebd633fSKrzysztof Drewniak return failure(); 872ebd633fSKrzysztof Drewniak inType = inVecType.getElementType(); 882ebd633fSKrzysztof Drewniak } 89*7a77f14cSMatthias Springer return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(inType)); 902ebd633fSKrzysztof Drewniak } 912ebd633fSKrzysztof Drewniak 92750e90e4SKrzysztof Drewniak void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op, 932ebd633fSKrzysztof Drewniak PatternRewriter &rewriter) const { 942ebd633fSKrzysztof Drewniak Location loc = op.getLoc(); 
952ebd633fSKrzysztof Drewniak Value in = op.getIn(); 962ebd633fSKrzysztof Drewniak Type outElemType = getElementTypeOrSelf(op.getOut().getType()); 97f35318e8SRob Suderman auto inType = dyn_cast<VectorType>(in.getType()); 98f35318e8SRob Suderman if (!inType) { 992ebd633fSKrzysztof Drewniak Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>( 1002ebd633fSKrzysztof Drewniak loc, rewriter.getF32Type(), in, 0); 1012ebd633fSKrzysztof Drewniak Value result = castF32To(outElemType, asFloat, loc, rewriter); 1022ebd633fSKrzysztof Drewniak return rewriter.replaceOp(op, result); 1032ebd633fSKrzysztof Drewniak } 1042ebd633fSKrzysztof Drewniak int64_t numElements = inType.getNumElements(); 10565066c02SHugo Trachino Value zero = rewriter.create<arith::ConstantOp>( 1062ebd633fSKrzysztof Drewniak loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); 1072ebd633fSKrzysztof Drewniak if (inType.getShape().empty()) { 108750e90e4SKrzysztof Drewniak Value scalarIn = 109750e90e4SKrzysztof Drewniak rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{}); 1102ebd633fSKrzysztof Drewniak // Recurse to send the 0-D vector case to the 1-D vector case 1112ebd633fSKrzysztof Drewniak Value scalarExt = 1122ebd633fSKrzysztof Drewniak rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn); 113f35318e8SRob Suderman Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero, 114750e90e4SKrzysztof Drewniak ArrayRef<int64_t>{}); 1152ebd633fSKrzysztof Drewniak return rewriter.replaceOp(op, result); 1162ebd633fSKrzysztof Drewniak } 117f35318e8SRob Suderman 118f35318e8SRob Suderman VectorType outType = cast<VectorType>(op.getOut().getType()); 119f35318e8SRob Suderman VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements}, 120f35318e8SRob Suderman outType.getElementType()); 121f35318e8SRob Suderman Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero); 122f35318e8SRob Suderman 123f35318e8SRob Suderman if (inType.getRank() > 1) { 
124f35318e8SRob Suderman inType = VectorType::get(SmallVector<int64_t>{numElements}, 125f35318e8SRob Suderman inType.getElementType()); 126f35318e8SRob Suderman in = rewriter.create<vector::ShapeCastOp>(loc, inType, in); 127f35318e8SRob Suderman } 128f35318e8SRob Suderman 1292ebd633fSKrzysztof Drewniak for (int64_t i = 0; i < numElements; i += 4) { 1302ebd633fSKrzysztof Drewniak int64_t elemsThisOp = std::min(numElements, i + 4) - i; 1312ebd633fSKrzysztof Drewniak Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>( 1322ebd633fSKrzysztof Drewniak loc, in, i, elemsThisOp, 1); 1332ebd633fSKrzysztof Drewniak for (int64_t j = 0; j < elemsThisOp; ++j) { 1342ebd633fSKrzysztof Drewniak Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>( 1352ebd633fSKrzysztof Drewniak loc, rewriter.getF32Type(), inSlice, j); 1362ebd633fSKrzysztof Drewniak Value asType = castF32To(outElemType, asFloat, loc, rewriter); 137750e90e4SKrzysztof Drewniak result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j); 1382ebd633fSKrzysztof Drewniak } 1392ebd633fSKrzysztof Drewniak } 140f35318e8SRob Suderman 141f35318e8SRob Suderman if (inType.getRank() != outType.getRank()) { 142f35318e8SRob Suderman result = rewriter.create<vector::ShapeCastOp>(loc, outType, result); 143f35318e8SRob Suderman } 144f35318e8SRob Suderman 1452ebd633fSKrzysztof Drewniak rewriter.replaceOp(op, result); 1462ebd633fSKrzysztof Drewniak } 1472ebd633fSKrzysztof Drewniak 1482ebd633fSKrzysztof Drewniak static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) { 1492ebd633fSKrzysztof Drewniak Type type = value.getType(); 1502ebd633fSKrzysztof Drewniak if (type.isF32()) 1512ebd633fSKrzysztof Drewniak return value; 1522ebd633fSKrzysztof Drewniak if (type.getIntOrFloatBitWidth() < 32) 1532ebd633fSKrzysztof Drewniak return rewriter.create<arith::ExtFOp>(loc, rewriter.getF32Type(), value); 1542ebd633fSKrzysztof Drewniak if (type.getIntOrFloatBitWidth() > 32) 1552ebd633fSKrzysztof Drewniak 
return rewriter.create<arith::TruncFOp>(loc, rewriter.getF32Type(), value); 1562ebd633fSKrzysztof Drewniak llvm_unreachable("The only 32-bit float type is f32"); 1572ebd633fSKrzysztof Drewniak } 1582ebd633fSKrzysztof Drewniak 159750e90e4SKrzysztof Drewniak // If `in` is a finite value, clamp it between the maximum and minimum values 160750e90e4SKrzysztof Drewniak // of `outElemType` so that subsequent conversion instructions don't 161750e90e4SKrzysztof Drewniak // overflow those out-of-range values to NaN. These semantics are commonly 162750e90e4SKrzysztof Drewniak // used in machine-learning contexts where failure to clamp would lead to 163750e90e4SKrzysztof Drewniak // excessive NaN production. 164750e90e4SKrzysztof Drewniak static Value clampInput(PatternRewriter &rewriter, Location loc, 165750e90e4SKrzysztof Drewniak Type outElemType, Value source) { 166750e90e4SKrzysztof Drewniak Type sourceType = source.getType(); 167750e90e4SKrzysztof Drewniak const llvm::fltSemantics &sourceSem = 168750e90e4SKrzysztof Drewniak cast<FloatType>(getElementTypeOrSelf(sourceType)).getFloatSemantics(); 169750e90e4SKrzysztof Drewniak const llvm::fltSemantics &targetSem = 170750e90e4SKrzysztof Drewniak cast<FloatType>(outElemType).getFloatSemantics(); 171750e90e4SKrzysztof Drewniak 172750e90e4SKrzysztof Drewniak APFloat min = APFloat::getLargest(targetSem, /*Negative=*/true); 173750e90e4SKrzysztof Drewniak APFloat max = APFloat::getLargest(targetSem, /*Negative=*/false); 174750e90e4SKrzysztof Drewniak bool ignoredLosesInfo = false; 175750e90e4SKrzysztof Drewniak // We can ignore conversion failures here because this conversion promotes 176750e90e4SKrzysztof Drewniak // from a smaller type to a larger one - ex. there can be no loss of precision 177750e90e4SKrzysztof Drewniak // when casting fp8 to f16. 
178750e90e4SKrzysztof Drewniak (void)min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo); 179750e90e4SKrzysztof Drewniak (void)max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo); 180750e90e4SKrzysztof Drewniak 181750e90e4SKrzysztof Drewniak Value minCst = createScalarOrSplatConstant(rewriter, loc, sourceType, min); 182750e90e4SKrzysztof Drewniak Value maxCst = createScalarOrSplatConstant(rewriter, loc, sourceType, max); 183750e90e4SKrzysztof Drewniak 184750e90e4SKrzysztof Drewniak Value inf = createScalarOrSplatConstant( 185750e90e4SKrzysztof Drewniak rewriter, loc, sourceType, 186750e90e4SKrzysztof Drewniak APFloat::getInf(sourceSem, /*Negative=*/false)); 187750e90e4SKrzysztof Drewniak Value negInf = createScalarOrSplatConstant( 188750e90e4SKrzysztof Drewniak rewriter, loc, sourceType, APFloat::getInf(sourceSem, /*Negative=*/true)); 189750e90e4SKrzysztof Drewniak Value isInf = rewriter.createOrFold<arith::CmpFOp>( 190750e90e4SKrzysztof Drewniak loc, arith::CmpFPredicate::OEQ, source, inf); 191750e90e4SKrzysztof Drewniak Value isNegInf = rewriter.createOrFold<arith::CmpFOp>( 192750e90e4SKrzysztof Drewniak loc, arith::CmpFPredicate::OEQ, source, negInf); 193750e90e4SKrzysztof Drewniak Value isNan = rewriter.createOrFold<arith::CmpFOp>( 194750e90e4SKrzysztof Drewniak loc, arith::CmpFPredicate::UNO, source, source); 195750e90e4SKrzysztof Drewniak Value isNonFinite = rewriter.create<arith::OrIOp>( 196750e90e4SKrzysztof Drewniak loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan); 197750e90e4SKrzysztof Drewniak 198750e90e4SKrzysztof Drewniak Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst); 199750e90e4SKrzysztof Drewniak Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst); 200750e90e4SKrzysztof Drewniak Value res = 201750e90e4SKrzysztof Drewniak rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped); 202750e90e4SKrzysztof Drewniak return res; 
203750e90e4SKrzysztof Drewniak } 204750e90e4SKrzysztof Drewniak 205750e90e4SKrzysztof Drewniak LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const { 2068827ff92SVictor Perez // Only supporting default rounding mode as of now. 2078827ff92SVictor Perez if (op.getRoundingmodeAttr()) 2088827ff92SVictor Perez return failure(); 2092ebd633fSKrzysztof Drewniak Type outType = op.getOut().getType(); 210a5757c5bSChristian Sigg if (auto outVecType = dyn_cast<VectorType>(outType)) { 2112ebd633fSKrzysztof Drewniak if (outVecType.isScalable()) 2122ebd633fSKrzysztof Drewniak return failure(); 2132ebd633fSKrzysztof Drewniak outType = outVecType.getElementType(); 2142ebd633fSKrzysztof Drewniak } 215750e90e4SKrzysztof Drewniak auto inType = dyn_cast<FloatType>(getElementTypeOrSelf(op.getIn().getType())); 216750e90e4SKrzysztof Drewniak if (inType && inType.getWidth() <= 8 && saturateFP8) 217750e90e4SKrzysztof Drewniak // Conversion between 8-bit floats is not supported with truncation enabled. 
218750e90e4SKrzysztof Drewniak return failure(); 219*7a77f14cSMatthias Springer return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(outType)); 2202ebd633fSKrzysztof Drewniak } 2212ebd633fSKrzysztof Drewniak 222750e90e4SKrzysztof Drewniak void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op, 2232ebd633fSKrzysztof Drewniak PatternRewriter &rewriter) const { 2242ebd633fSKrzysztof Drewniak Location loc = op.getLoc(); 2252ebd633fSKrzysztof Drewniak Value in = op.getIn(); 2262ebd633fSKrzysztof Drewniak Type outElemType = getElementTypeOrSelf(op.getOut().getType()); 227750e90e4SKrzysztof Drewniak if (saturateFP8) 228750e90e4SKrzysztof Drewniak in = clampInput(rewriter, loc, outElemType, in); 229f35318e8SRob Suderman auto inVectorTy = dyn_cast<VectorType>(in.getType()); 2302ebd633fSKrzysztof Drewniak VectorType truncResType = VectorType::get(4, outElemType); 231f35318e8SRob Suderman if (!inVectorTy) { 2322ebd633fSKrzysztof Drewniak Value asFloat = castToF32(in, loc, rewriter); 2332ebd633fSKrzysztof Drewniak Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>( 2342ebd633fSKrzysztof Drewniak loc, truncResType, asFloat, /*sourceB=*/nullptr, 0, 2352ebd633fSKrzysztof Drewniak /*existing=*/nullptr); 236750e90e4SKrzysztof Drewniak Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0); 2372ebd633fSKrzysztof Drewniak return rewriter.replaceOp(op, result); 2382ebd633fSKrzysztof Drewniak } 239a5757c5bSChristian Sigg VectorType outType = cast<VectorType>(op.getOut().getType()); 2402ebd633fSKrzysztof Drewniak int64_t numElements = outType.getNumElements(); 24165066c02SHugo Trachino Value zero = rewriter.create<arith::ConstantOp>( 2422ebd633fSKrzysztof Drewniak loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); 2432ebd633fSKrzysztof Drewniak if (outType.getShape().empty()) { 244750e90e4SKrzysztof Drewniak Value scalarIn = 245750e90e4SKrzysztof Drewniak rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{}); 2462ebd633fSKrzysztof 
Drewniak // Recurse to send the 0-D vector case to the 1-D vector case 2472ebd633fSKrzysztof Drewniak Value scalarTrunc = 2482ebd633fSKrzysztof Drewniak rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn); 249f35318e8SRob Suderman Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero, 250750e90e4SKrzysztof Drewniak ArrayRef<int64_t>{}); 2512ebd633fSKrzysztof Drewniak return rewriter.replaceOp(op, result); 2522ebd633fSKrzysztof Drewniak } 2532ebd633fSKrzysztof Drewniak 254f35318e8SRob Suderman VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements}, 255f35318e8SRob Suderman outType.getElementType()); 256f35318e8SRob Suderman Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero); 257f35318e8SRob Suderman 258f35318e8SRob Suderman if (inVectorTy.getRank() > 1) { 259f35318e8SRob Suderman inVectorTy = VectorType::get(SmallVector<int64_t>{numElements}, 260f35318e8SRob Suderman inVectorTy.getElementType()); 261f35318e8SRob Suderman in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in); 262f35318e8SRob Suderman } 263f35318e8SRob Suderman 2642ebd633fSKrzysztof Drewniak for (int64_t i = 0; i < numElements; i += 4) { 2652ebd633fSKrzysztof Drewniak int64_t elemsThisOp = std::min(numElements, i + 4) - i; 2662ebd633fSKrzysztof Drewniak Value thisResult = nullptr; 2672ebd633fSKrzysztof Drewniak for (int64_t j = 0; j < elemsThisOp; j += 2) { 268750e90e4SKrzysztof Drewniak Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i + j); 2692ebd633fSKrzysztof Drewniak Value asFloatA = castToF32(elemA, loc, rewriter); 2702ebd633fSKrzysztof Drewniak Value asFloatB = nullptr; 2712ebd633fSKrzysztof Drewniak if (j + 1 < elemsThisOp) { 272750e90e4SKrzysztof Drewniak Value elemB = rewriter.create<vector::ExtractOp>(loc, in, i + j + 1); 2732ebd633fSKrzysztof Drewniak asFloatB = castToF32(elemB, loc, rewriter); 2742ebd633fSKrzysztof Drewniak } 2752ebd633fSKrzysztof Drewniak thisResult = 
rewriter.create<amdgpu::PackedTrunc2xFp8Op>( 2762ebd633fSKrzysztof Drewniak loc, truncResType, asFloatA, asFloatB, j / 2, thisResult); 2772ebd633fSKrzysztof Drewniak } 2782ebd633fSKrzysztof Drewniak if (elemsThisOp < 4) 2792ebd633fSKrzysztof Drewniak thisResult = rewriter.create<vector::ExtractStridedSliceOp>( 2802ebd633fSKrzysztof Drewniak loc, thisResult, 0, elemsThisOp, 1); 2812ebd633fSKrzysztof Drewniak result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult, 2822ebd633fSKrzysztof Drewniak result, i, 1); 2832ebd633fSKrzysztof Drewniak } 284f35318e8SRob Suderman 285f35318e8SRob Suderman if (inVectorTy.getRank() != outType.getRank()) { 286f35318e8SRob Suderman result = rewriter.create<vector::ShapeCastOp>(loc, outType, result); 287f35318e8SRob Suderman } 288f35318e8SRob Suderman 2892ebd633fSKrzysztof Drewniak rewriter.replaceOp(op, result); 2902ebd633fSKrzysztof Drewniak } 2912ebd633fSKrzysztof Drewniak 2921387ba48SGiuseppe Rossini LogicalResult TruncfToFloat16RewritePattern::match(arith::TruncFOp op) const { 2931387ba48SGiuseppe Rossini Type outType = op.getOut().getType(); 2941387ba48SGiuseppe Rossini Type inputType = getElementTypeOrSelf(op.getIn()); 2951387ba48SGiuseppe Rossini if (auto outVecType = dyn_cast<VectorType>(outType)) { 2961387ba48SGiuseppe Rossini if (outVecType.isScalable()) 2971387ba48SGiuseppe Rossini return failure(); 2981387ba48SGiuseppe Rossini outType = outVecType.getElementType(); 2991387ba48SGiuseppe Rossini } 3001387ba48SGiuseppe Rossini return success(outType.isF16() && inputType.isF32()); 3011387ba48SGiuseppe Rossini } 3021387ba48SGiuseppe Rossini 3031387ba48SGiuseppe Rossini void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op, 3041387ba48SGiuseppe Rossini PatternRewriter &rewriter) const { 3051387ba48SGiuseppe Rossini Location loc = op.getLoc(); 3061387ba48SGiuseppe Rossini Value in = op.getIn(); 3071387ba48SGiuseppe Rossini Type outElemType = getElementTypeOrSelf(op.getOut().getType()); 3081387ba48SGiuseppe 
Rossini VectorType truncResType = VectorType::get(2, outElemType); 3091387ba48SGiuseppe Rossini auto inVectorTy = dyn_cast<VectorType>(in.getType()); 3101387ba48SGiuseppe Rossini 3111387ba48SGiuseppe Rossini // Handle the case where input type is not a vector type 3121387ba48SGiuseppe Rossini if (!inVectorTy) { 3131387ba48SGiuseppe Rossini auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type()); 3141387ba48SGiuseppe Rossini Value asF16s = 3151387ba48SGiuseppe Rossini rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB); 3168e663039SKunwar Grover Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0); 3171387ba48SGiuseppe Rossini return rewriter.replaceOp(op, result); 3181387ba48SGiuseppe Rossini } 3191387ba48SGiuseppe Rossini VectorType outType = cast<VectorType>(op.getOut().getType()); 3201387ba48SGiuseppe Rossini int64_t numElements = outType.getNumElements(); 3211387ba48SGiuseppe Rossini Value zero = rewriter.createOrFold<arith::ConstantOp>( 3221387ba48SGiuseppe Rossini loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); 3231387ba48SGiuseppe Rossini Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero); 3241387ba48SGiuseppe Rossini 3251387ba48SGiuseppe Rossini if (inVectorTy.getRank() > 1) { 3261387ba48SGiuseppe Rossini inVectorTy = VectorType::get(SmallVector<int64_t>{numElements}, 3271387ba48SGiuseppe Rossini inVectorTy.getElementType()); 3281387ba48SGiuseppe Rossini in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in); 3291387ba48SGiuseppe Rossini } 3301387ba48SGiuseppe Rossini 3311387ba48SGiuseppe Rossini // Handle the vector case. 
We also handle the (uncommon) case where the vector 3321387ba48SGiuseppe Rossini // length is odd 3331387ba48SGiuseppe Rossini for (int64_t i = 0; i < numElements; i += 2) { 3341387ba48SGiuseppe Rossini int64_t elemsThisOp = std::min(numElements, i + 2) - i; 3351387ba48SGiuseppe Rossini Value thisResult = nullptr; 3368e663039SKunwar Grover Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i); 3371387ba48SGiuseppe Rossini Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type()); 3381387ba48SGiuseppe Rossini 3391387ba48SGiuseppe Rossini if (elemsThisOp == 2) { 3408e663039SKunwar Grover elemB = rewriter.create<vector::ExtractOp>(loc, in, i + 1); 3411387ba48SGiuseppe Rossini } 3421387ba48SGiuseppe Rossini 3431387ba48SGiuseppe Rossini thisResult = 3441387ba48SGiuseppe Rossini rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, elemA, elemB); 3451387ba48SGiuseppe Rossini // Place back the truncated result into the possibly larger vector. If we 3461387ba48SGiuseppe Rossini // are operating on a size 2 vector, these operations should be folded away 3471387ba48SGiuseppe Rossini thisResult = rewriter.create<vector::ExtractStridedSliceOp>( 3481387ba48SGiuseppe Rossini loc, thisResult, 0, elemsThisOp, 1); 3491387ba48SGiuseppe Rossini result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult, 3501387ba48SGiuseppe Rossini result, i, 1); 3511387ba48SGiuseppe Rossini } 3521387ba48SGiuseppe Rossini 3531387ba48SGiuseppe Rossini if (inVectorTy.getRank() != outType.getRank()) { 3541387ba48SGiuseppe Rossini result = rewriter.create<vector::ShapeCastOp>(loc, outType, result); 3551387ba48SGiuseppe Rossini } 3561387ba48SGiuseppe Rossini 3571387ba48SGiuseppe Rossini rewriter.replaceOp(op, result); 3581387ba48SGiuseppe Rossini } 3591387ba48SGiuseppe Rossini 3602ebd633fSKrzysztof Drewniak void mlir::arith::populateArithToAMDGPUConversionPatterns( 3611387ba48SGiuseppe Rossini RewritePatternSet &patterns, bool convertFP8Arithmetic, 3621387ba48SGiuseppe 
Rossini bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) { 3631387ba48SGiuseppe Rossini 3641387ba48SGiuseppe Rossini if (convertFP8Arithmetic) { 365750e90e4SKrzysztof Drewniak patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext()); 366750e90e4SKrzysztof Drewniak patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(), 3671387ba48SGiuseppe Rossini saturateFP8Truncf, chipset); 3681387ba48SGiuseppe Rossini } 3691387ba48SGiuseppe Rossini if (allowPackedF16Rtz) 3701387ba48SGiuseppe Rossini patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext()); 3712ebd633fSKrzysztof Drewniak } 3722ebd633fSKrzysztof Drewniak 3732ebd633fSKrzysztof Drewniak void ArithToAMDGPUConversionPass::runOnOperation() { 3742ebd633fSKrzysztof Drewniak Operation *op = getOperation(); 3751387ba48SGiuseppe Rossini MLIRContext *ctx = &getContext(); 3762ebd633fSKrzysztof Drewniak RewritePatternSet patterns(op->getContext()); 3771387ba48SGiuseppe Rossini FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset); 3781387ba48SGiuseppe Rossini if (failed(maybeChipset)) { 3791387ba48SGiuseppe Rossini emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset); 3801387ba48SGiuseppe Rossini return signalPassFailure(); 3811387ba48SGiuseppe Rossini } 3821387ba48SGiuseppe Rossini 3831387ba48SGiuseppe Rossini bool convertFP8Arithmetic = 384763bc924SJakub Kuderski maybeChipset->majorVersion == 9 && *maybeChipset >= Chipset(9, 4, 0); 3851387ba48SGiuseppe Rossini arith::populateArithToAMDGPUConversionPatterns( 3861387ba48SGiuseppe Rossini patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz, 3871387ba48SGiuseppe Rossini *maybeChipset); 38809dfc571SJacques Pienaar if (failed(applyPatternsGreedily(op, std::move(patterns)))) 3892ebd633fSKrzysztof Drewniak return signalPassFailure(); 3902ebd633fSKrzysztof Drewniak } 391