1 //===- ArithToAMDGPU.cpp - Arith to AMDGPU dialect conversion ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h" 10 11 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" 12 #include "mlir/Dialect/AMDGPU/Utils/Chipset.h" 13 #include "mlir/Dialect/Arith/IR/Arith.h" 14 #include "mlir/Dialect/Arith/Utils/Utils.h" 15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 16 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" 17 #include "mlir/Dialect/Vector/IR/VectorOps.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/IR/TypeUtilities.h" 21 #include "mlir/Pass/Pass.h" 22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 23 24 namespace mlir { 25 #define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS 26 #include "mlir/Conversion/Passes.h.inc" 27 } // namespace mlir 28 29 using namespace mlir; 30 using namespace mlir::amdgpu; 31 32 namespace { 33 struct ArithToAMDGPUConversionPass final 34 : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> { 35 using impl::ArithToAMDGPUConversionPassBase< 36 ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase; 37 38 void runOnOperation() override; 39 }; 40 41 struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> { 42 using OpRewritePattern::OpRewritePattern; 43 44 LogicalResult match(arith::ExtFOp op) const override; 45 void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override; 46 }; 47 48 struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> { 49 bool saturateFP8 = false; 50 TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8, 51 Chipset chipset) 52 : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8), 53 chipset(chipset) {} 54 Chipset chipset; 55 56 LogicalResult match(arith::TruncFOp op) const override; 57 void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override; 58 }; 59 60 struct TruncfToFloat16RewritePattern final 61 : public OpRewritePattern<arith::TruncFOp> { 62 63 using OpRewritePattern<arith::TruncFOp>::OpRewritePattern; 64 65 LogicalResult match(arith::TruncFOp op) const override; 66 void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override; 67 }; 68 69 } // end namespace 70 71 static Value castF32To(Type elementType, Value f32, Location loc, 72 PatternRewriter &rewriter) { 73 if (elementType.isF32()) 74 return f32; 75 if (elementType.getIntOrFloatBitWidth() < 32) 76 return rewriter.create<arith::TruncFOp>(loc, elementType, f32); 77 if (elementType.getIntOrFloatBitWidth() > 32) 78 return rewriter.create<arith::ExtFOp>(loc, elementType, f32); 79 llvm_unreachable("The only 32-bit float type is f32"); 80 } 81 82 LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const { 83 Type inType = op.getIn().getType(); 84 if (auto inVecType = dyn_cast<VectorType>(inType)) { 85 if (inVecType.isScalable()) 86 return failure(); 87 inType = inVecType.getElementType(); 88 } 89 return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(inType)); 90 } 91 92 void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op, 93 PatternRewriter &rewriter) const { 94 Location loc = op.getLoc(); 95 Value in = op.getIn(); 96 Type outElemType = getElementTypeOrSelf(op.getOut().getType()); 97 auto inType = dyn_cast<VectorType>(in.getType()); 98 if (!inType) { 99 Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>( 100 loc, rewriter.getF32Type(), in, 0); 101 Value result = castF32To(outElemType, asFloat, loc, rewriter); 102 return rewriter.replaceOp(op, result); 103 } 104 int64_t numElements = inType.getNumElements(); 105 Value zero = rewriter.create<arith::ConstantOp>( 106 loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); 107 if (inType.getShape().empty()) { 108 Value scalarIn = 109 rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{}); 110 // Recurse to send the 0-D vector case to the 1-D vector case 111 Value scalarExt = 112 rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn); 113 Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero, 114 ArrayRef<int64_t>{}); 115 return rewriter.replaceOp(op, result); 116 } 117 118 VectorType outType = cast<VectorType>(op.getOut().getType()); 119 VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements}, 120 outType.getElementType()); 121 Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero); 122 123 if (inType.getRank() > 1) { 124 inType = VectorType::get(SmallVector<int64_t>{numElements}, 125 inType.getElementType()); 126 in = rewriter.create<vector::ShapeCastOp>(loc, inType, in); 127 } 128 129 for (int64_t i = 0; i < numElements; i += 4) { 130 int64_t elemsThisOp = std::min(numElements, i + 4) - i; 131 Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>( 132 loc, in, i, elemsThisOp, 1); 133 for (int64_t j = 0; j < elemsThisOp; ++j) { 134 Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>( 135 loc, rewriter.getF32Type(), inSlice, j); 136 Value asType = castF32To(outElemType, asFloat, loc, rewriter); 137 result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j); 138 } 139 } 140 141 if (inType.getRank() != outType.getRank()) { 142 result = rewriter.create<vector::ShapeCastOp>(loc, outType, result); 143 } 144 145 rewriter.replaceOp(op, result); 146 } 147 148 static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) { 149 Type type = value.getType(); 150 if (type.isF32()) 151 return value; 152 if (type.getIntOrFloatBitWidth() < 32) 153 return rewriter.create<arith::ExtFOp>(loc, rewriter.getF32Type(), value); 154 if (type.getIntOrFloatBitWidth() > 32) 155 return rewriter.create<arith::TruncFOp>(loc, rewriter.getF32Type(), value); 156 llvm_unreachable("The only 32-bit float type is f32"); 157 } 158 159 // If `in` is a finite value, clamp it between the maximum and minimum values 160 // of `outElemType` so that subsequent conversion instructions don't 161 // overflow those out-of-range values to NaN. These semantics are commonly 162 // used in machine-learning contexts where failure to clamp would lead to 163 // excessive NaN production. 164 static Value clampInput(PatternRewriter &rewriter, Location loc, 165 Type outElemType, Value source) { 166 Type sourceType = source.getType(); 167 const llvm::fltSemantics &sourceSem = 168 cast<FloatType>(getElementTypeOrSelf(sourceType)).getFloatSemantics(); 169 const llvm::fltSemantics &targetSem = 170 cast<FloatType>(outElemType).getFloatSemantics(); 171 172 APFloat min = APFloat::getLargest(targetSem, /*Negative=*/true); 173 APFloat max = APFloat::getLargest(targetSem, /*Negative=*/false); 174 bool ignoredLosesInfo = false; 175 // We can ignore conversion failures here because this conversion promotes 176 // from a smaller type to a larger one - ex. there can be no loss of precision 177 // when casting fp8 to f16. 178 (void)min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo); 179 (void)max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo); 180 181 Value minCst = createScalarOrSplatConstant(rewriter, loc, sourceType, min); 182 Value maxCst = createScalarOrSplatConstant(rewriter, loc, sourceType, max); 183 184 Value inf = createScalarOrSplatConstant( 185 rewriter, loc, sourceType, 186 APFloat::getInf(sourceSem, /*Negative=*/false)); 187 Value negInf = createScalarOrSplatConstant( 188 rewriter, loc, sourceType, APFloat::getInf(sourceSem, /*Negative=*/true)); 189 Value isInf = rewriter.createOrFold<arith::CmpFOp>( 190 loc, arith::CmpFPredicate::OEQ, source, inf); 191 Value isNegInf = rewriter.createOrFold<arith::CmpFOp>( 192 loc, arith::CmpFPredicate::OEQ, source, negInf); 193 Value isNan = rewriter.createOrFold<arith::CmpFOp>( 194 loc, arith::CmpFPredicate::UNO, source, source); 195 Value isNonFinite = rewriter.create<arith::OrIOp>( 196 loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan); 197 198 Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst); 199 Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst); 200 Value res = 201 rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped); 202 return res; 203 } 204 205 LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const { 206 // Only supporting default rounding mode as of now. 207 if (op.getRoundingmodeAttr()) 208 return failure(); 209 Type outType = op.getOut().getType(); 210 if (auto outVecType = dyn_cast<VectorType>(outType)) { 211 if (outVecType.isScalable()) 212 return failure(); 213 outType = outVecType.getElementType(); 214 } 215 auto inType = dyn_cast<FloatType>(getElementTypeOrSelf(op.getIn().getType())); 216 if (inType && inType.getWidth() <= 8 && saturateFP8) 217 // Conversion between 8-bit floats is not supported with truncation enabled. 218 return failure(); 219 return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(outType)); 220 } 221 222 void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op, 223 PatternRewriter &rewriter) const { 224 Location loc = op.getLoc(); 225 Value in = op.getIn(); 226 Type outElemType = getElementTypeOrSelf(op.getOut().getType()); 227 if (saturateFP8) 228 in = clampInput(rewriter, loc, outElemType, in); 229 auto inVectorTy = dyn_cast<VectorType>(in.getType()); 230 VectorType truncResType = VectorType::get(4, outElemType); 231 if (!inVectorTy) { 232 Value asFloat = castToF32(in, loc, rewriter); 233 Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>( 234 loc, truncResType, asFloat, /*sourceB=*/nullptr, 0, 235 /*existing=*/nullptr); 236 Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0); 237 return rewriter.replaceOp(op, result); 238 } 239 VectorType outType = cast<VectorType>(op.getOut().getType()); 240 int64_t numElements = outType.getNumElements(); 241 Value zero = rewriter.create<arith::ConstantOp>( 242 loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); 243 if (outType.getShape().empty()) { 244 Value scalarIn = 245 rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{}); 246 // Recurse to send the 0-D vector case to the 1-D vector case 247 Value scalarTrunc = 248 rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn); 249 Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero, 250 ArrayRef<int64_t>{}); 251 return rewriter.replaceOp(op, result); 252 } 253 254 VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements}, 255 outType.getElementType()); 256 Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero); 257 258 if (inVectorTy.getRank() > 1) { 259 inVectorTy = VectorType::get(SmallVector<int64_t>{numElements}, 260 inVectorTy.getElementType()); 261 in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in); 262 } 263 264 for (int64_t i = 0; i < numElements; i += 4) { 265 int64_t elemsThisOp = std::min(numElements, i + 4) - i; 266 Value thisResult = nullptr; 267 for (int64_t j = 0; j < elemsThisOp; j += 2) { 268 Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i + j); 269 Value asFloatA = castToF32(elemA, loc, rewriter); 270 Value asFloatB = nullptr; 271 if (j + 1 < elemsThisOp) { 272 Value elemB = rewriter.create<vector::ExtractOp>(loc, in, i + j + 1); 273 asFloatB = castToF32(elemB, loc, rewriter); 274 } 275 thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>( 276 loc, truncResType, asFloatA, asFloatB, j / 2, thisResult); 277 } 278 if (elemsThisOp < 4) 279 thisResult = rewriter.create<vector::ExtractStridedSliceOp>( 280 loc, thisResult, 0, elemsThisOp, 1); 281 result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult, 282 result, i, 1); 283 } 284 285 if (inVectorTy.getRank() != outType.getRank()) { 286 result = rewriter.create<vector::ShapeCastOp>(loc, outType, result); 287 } 288 289 rewriter.replaceOp(op, result); 290 } 291 292 LogicalResult TruncfToFloat16RewritePattern::match(arith::TruncFOp op) const { 293 Type outType = op.getOut().getType(); 294 Type inputType = getElementTypeOrSelf(op.getIn()); 295 if (auto outVecType = dyn_cast<VectorType>(outType)) { 296 if (outVecType.isScalable()) 297 return failure(); 298 outType = outVecType.getElementType(); 299 } 300 return success(outType.isF16() && inputType.isF32()); 301 } 302 303 void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op, 304 PatternRewriter &rewriter) const { 305 Location loc = op.getLoc(); 306 Value in = op.getIn(); 307 Type outElemType = getElementTypeOrSelf(op.getOut().getType()); 308 VectorType truncResType = VectorType::get(2, outElemType); 309 auto inVectorTy = dyn_cast<VectorType>(in.getType()); 310 311 // Handle the case where input type is not a vector type 312 if (!inVectorTy) { 313 auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type()); 314 Value asF16s = 315 rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB); 316 Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0); 317 return rewriter.replaceOp(op, result); 318 } 319 VectorType outType = cast<VectorType>(op.getOut().getType()); 320 int64_t numElements = outType.getNumElements(); 321 Value zero = rewriter.createOrFold<arith::ConstantOp>( 322 loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); 323 Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero); 324 325 if (inVectorTy.getRank() > 1) { 326 inVectorTy = VectorType::get(SmallVector<int64_t>{numElements}, 327 inVectorTy.getElementType()); 328 in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in); 329 } 330 331 // Handle the vector case. We also handle the (uncommon) case where the vector 332 // length is odd 333 for (int64_t i = 0; i < numElements; i += 2) { 334 int64_t elemsThisOp = std::min(numElements, i + 2) - i; 335 Value thisResult = nullptr; 336 Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i); 337 Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type()); 338 339 if (elemsThisOp == 2) { 340 elemB = rewriter.create<vector::ExtractOp>(loc, in, i + 1); 341 } 342 343 thisResult = 344 rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, elemA, elemB); 345 // Place back the truncated result into the possibly larger vector. If we 346 // are operating on a size 2 vector, these operations should be folded away 347 thisResult = rewriter.create<vector::ExtractStridedSliceOp>( 348 loc, thisResult, 0, elemsThisOp, 1); 349 result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult, 350 result, i, 1); 351 } 352 353 if (inVectorTy.getRank() != outType.getRank()) { 354 result = rewriter.create<vector::ShapeCastOp>(loc, outType, result); 355 } 356 357 rewriter.replaceOp(op, result); 358 } 359 360 void mlir::arith::populateArithToAMDGPUConversionPatterns( 361 RewritePatternSet &patterns, bool convertFP8Arithmetic, 362 bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) { 363 364 if (convertFP8Arithmetic) { 365 patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext()); 366 patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(), 367 saturateFP8Truncf, chipset); 368 } 369 if (allowPackedF16Rtz) 370 patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext()); 371 } 372 373 void ArithToAMDGPUConversionPass::runOnOperation() { 374 Operation *op = getOperation(); 375 MLIRContext *ctx = &getContext(); 376 RewritePatternSet patterns(op->getContext()); 377 FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset); 378 if (failed(maybeChipset)) { 379 emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset); 380 return signalPassFailure(); 381 } 382 383 bool convertFP8Arithmetic = 384 maybeChipset->majorVersion == 9 && *maybeChipset >= Chipset(9, 4, 0); 385 arith::populateArithToAMDGPUConversionPatterns( 386 patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz, 387 *maybeChipset); 388 if (failed(applyPatternsGreedily(op, std::move(patterns)))) 389 return signalPassFailure(); 390 } 391