xref: /llvm-project/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (revision 7a77f14c0abfbecbfb800ea8d974e66d81ee516a)
1 //===- ArithToAMDGPU.cpp - Arith to AMDGPU dialect conversion ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
10 
11 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
12 #include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
13 #include "mlir/Dialect/Arith/IR/Arith.h"
14 #include "mlir/Dialect/Arith/Utils/Utils.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
17 #include "mlir/Dialect/Vector/IR/VectorOps.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 
24 namespace mlir {
25 #define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
26 #include "mlir/Conversion/Passes.h.inc"
27 } // namespace mlir
28 
29 using namespace mlir;
30 using namespace mlir::amdgpu;
31 
namespace {
// Pass that greedily applies the arith-to-AMDGPU rewrite patterns below to
// the operation it is run on. Pass options (chipset string, FP8 saturation,
// packed f16 round-to-zero) come from the tablegen-generated base class.
struct ArithToAMDGPUConversionPass final
    : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
  using impl::ArithToAMDGPUConversionPassBase<
      ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase;

  void runOnOperation() override;
};

// Rewrites arith.extf whose source element type is an 8-bit FNUZ float
// (f8E5M2FNUZ or f8E4M3FNUZ) into amdgpu.ext_packed_fp8 plus whatever
// f32 -> destination-type cast is needed. Scalable vectors are rejected.
struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult match(arith::ExtFOp op) const override;
  void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
};

// Rewrites arith.truncf to an 8-bit FNUZ float type into
// amdgpu.packed_trunc_2xfp8, optionally clamping finite inputs to the
// destination type's representable range first (see clampInput below).
struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
  // When true, clamp finite inputs so the hardware conversion does not turn
  // out-of-range values into NaN.
  bool saturateFP8 = false;
  TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
                               Chipset chipset)
      : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult match(arith::TruncFOp op) const override;
  void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
};

// Rewrites f32 -> f16 arith.truncf into the ROCDL packed round-to-zero
// conversion intrinsic (rocdl.cvt.pkrtz), two elements at a time.
struct TruncfToFloat16RewritePattern final
    : public OpRewritePattern<arith::TruncFOp> {

  using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;

  LogicalResult match(arith::TruncFOp op) const override;
  void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
};

} // end namespace
70 
71 static Value castF32To(Type elementType, Value f32, Location loc,
72                        PatternRewriter &rewriter) {
73   if (elementType.isF32())
74     return f32;
75   if (elementType.getIntOrFloatBitWidth() < 32)
76     return rewriter.create<arith::TruncFOp>(loc, elementType, f32);
77   if (elementType.getIntOrFloatBitWidth() > 32)
78     return rewriter.create<arith::ExtFOp>(loc, elementType, f32);
79   llvm_unreachable("The only 32-bit float type is f32");
80 }
81 
82 LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
83   Type inType = op.getIn().getType();
84   if (auto inVecType = dyn_cast<VectorType>(inType)) {
85     if (inVecType.isScalable())
86       return failure();
87     inType = inVecType.getElementType();
88   }
89   return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(inType));
90 }
91 
// Expand an fp8 extf into amdgpu.ext_packed_fp8 lane extractions followed by
// a cast from f32 to the requested destination element type.
void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
                                         PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  Value in = op.getIn();
  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
  auto inType = dyn_cast<VectorType>(in.getType());
  // Scalar case: extend the single fp8 value to f32, then cast to the
  // destination type.
  if (!inType) {
    Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
        loc, rewriter.getF32Type(), in, 0);
    Value result = castF32To(outElemType, asFloat, loc, rewriter);
    return rewriter.replaceOp(op, result);
  }
  int64_t numElements = inType.getNumElements();
  Value zero = rewriter.create<arith::ConstantOp>(
      loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
  if (inType.getShape().empty()) {
    Value scalarIn =
        rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
    // Recurse to send the 0-D vector case to the 1-D vector case
    Value scalarExt =
        rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
    Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero,
                                                     ArrayRef<int64_t>{});
    return rewriter.replaceOp(op, result);
  }

  VectorType outType = cast<VectorType>(op.getOut().getType());
  // Operate on a rank-1 (flattened) vector; shape-cast back to the original
  // rank at the end if needed.
  VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
                                      outType.getElementType());
  Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);

  if (inType.getRank() > 1) {
    inType = VectorType::get(SmallVector<int64_t>{numElements},
                             inType.getElementType());
    in = rewriter.create<vector::ShapeCastOp>(loc, inType, in);
  }

  // Process the input in groups of up to 4 fp8 lanes (the width the packed
  // extension op indexes into), expanding each lane to f32 individually.
  for (int64_t i = 0; i < numElements; i += 4) {
    int64_t elemsThisOp = std::min(numElements, i + 4) - i;
    Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
        loc, in, i, elemsThisOp, 1);
    for (int64_t j = 0; j < elemsThisOp; ++j) {
      Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
          loc, rewriter.getF32Type(), inSlice, j);
      Value asType = castF32To(outElemType, asFloat, loc, rewriter);
      result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
    }
  }

  // Restore the original multi-dimensional shape if the input was rank > 1.
  if (inType.getRank() != outType.getRank()) {
    result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
  }

  rewriter.replaceOp(op, result);
}
147 
148 static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
149   Type type = value.getType();
150   if (type.isF32())
151     return value;
152   if (type.getIntOrFloatBitWidth() < 32)
153     return rewriter.create<arith::ExtFOp>(loc, rewriter.getF32Type(), value);
154   if (type.getIntOrFloatBitWidth() > 32)
155     return rewriter.create<arith::TruncFOp>(loc, rewriter.getF32Type(), value);
156   llvm_unreachable("The only 32-bit float type is f32");
157 }
158 
// If `in` is a finite value, clamp it between the maximum and minimum values
// of `outElemType` so that subsequent conversion instructions don't
// overflow those out-of-range values to NaN. These semantics are commonly
// used in machine-learning contexts where failure to clamp would lead to
// excessive NaN production.
static Value clampInput(PatternRewriter &rewriter, Location loc,
                        Type outElemType, Value source) {
  Type sourceType = source.getType();
  const llvm::fltSemantics &sourceSem =
      cast<FloatType>(getElementTypeOrSelf(sourceType)).getFloatSemantics();
  const llvm::fltSemantics &targetSem =
      cast<FloatType>(outElemType).getFloatSemantics();

  // Clamp bounds: the largest finite magnitudes representable in the target
  // type, converted into the source type's semantics.
  APFloat min = APFloat::getLargest(targetSem, /*Negative=*/true);
  APFloat max = APFloat::getLargest(targetSem, /*Negative=*/false);
  bool ignoredLosesInfo = false;
  // We can ignore conversion failures here because this conversion promotes
  // from a smaller type to a larger one - ex. there can be no loss of precision
  // when casting fp8 to f16.
  (void)min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
  (void)max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);

  Value minCst = createScalarOrSplatConstant(rewriter, loc, sourceType, min);
  Value maxCst = createScalarOrSplatConstant(rewriter, loc, sourceType, max);

  // Detect +inf, -inf, and NaN so they can be passed through unclamped.
  Value inf = createScalarOrSplatConstant(
      rewriter, loc, sourceType,
      APFloat::getInf(sourceSem, /*Negative=*/false));
  Value negInf = createScalarOrSplatConstant(
      rewriter, loc, sourceType, APFloat::getInf(sourceSem, /*Negative=*/true));
  Value isInf = rewriter.createOrFold<arith::CmpFOp>(
      loc, arith::CmpFPredicate::OEQ, source, inf);
  Value isNegInf = rewriter.createOrFold<arith::CmpFOp>(
      loc, arith::CmpFPredicate::OEQ, source, negInf);
  // UNO(x, x) holds exactly when x is NaN.
  Value isNan = rewriter.createOrFold<arith::CmpFOp>(
      loc, arith::CmpFPredicate::UNO, source, source);
  Value isNonFinite = rewriter.create<arith::OrIOp>(
      loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan);

  // Clamp finite values to [min, max]; select the original value when the
  // input is non-finite.
  Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst);
  Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
  Value res =
      rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped);
  return res;
}
204 
205 LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
206   // Only supporting default rounding mode as of now.
207   if (op.getRoundingmodeAttr())
208     return failure();
209   Type outType = op.getOut().getType();
210   if (auto outVecType = dyn_cast<VectorType>(outType)) {
211     if (outVecType.isScalable())
212       return failure();
213     outType = outVecType.getElementType();
214   }
215   auto inType = dyn_cast<FloatType>(getElementTypeOrSelf(op.getIn().getType()));
216   if (inType && inType.getWidth() <= 8 && saturateFP8)
217     // Conversion between 8-bit floats is not supported with truncation enabled.
218     return failure();
219   return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(outType));
220 }
221 
// Lower a truncf-to-fp8 into amdgpu.packed_trunc_2xfp8 ops, packing two f32
// inputs per word and up to four results per op.
void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
                                           PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  Value in = op.getIn();
  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
  // Optionally clamp finite inputs into the target type's range so the
  // conversion doesn't turn out-of-range values into NaN.
  if (saturateFP8)
    in = clampInput(rewriter, loc, outElemType, in);
  auto inVectorTy = dyn_cast<VectorType>(in.getType());
  // The packed truncation op always yields a 4-element fp8 vector.
  VectorType truncResType = VectorType::get(4, outElemType);
  // Scalar case: pack a single value and extract lane 0 of the result.
  if (!inVectorTy) {
    Value asFloat = castToF32(in, loc, rewriter);
    Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
        loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
        /*existing=*/nullptr);
    Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
    return rewriter.replaceOp(op, result);
  }
  VectorType outType = cast<VectorType>(op.getOut().getType());
  int64_t numElements = outType.getNumElements();
  Value zero = rewriter.create<arith::ConstantOp>(
      loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
  if (outType.getShape().empty()) {
    Value scalarIn =
        rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
    // Recurse to send the 0-D vector case to the 1-D vector case
    Value scalarTrunc =
        rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
    Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
                                                     ArrayRef<int64_t>{});
    return rewriter.replaceOp(op, result);
  }

  // Operate on a rank-1 (flattened) vector; shape-cast back to the original
  // rank at the end if needed.
  VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
                                      outType.getElementType());
  Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);

  if (inVectorTy.getRank() > 1) {
    inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
                                 inVectorTy.getElementType());
    in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
  }

  // Pack up to 4 elements per iteration: two f32 sources per
  // PackedTrunc2xFp8Op call, with j/2 selecting the destination word and
  // `thisResult` threaded through as the existing partial result.
  for (int64_t i = 0; i < numElements; i += 4) {
    int64_t elemsThisOp = std::min(numElements, i + 4) - i;
    Value thisResult = nullptr;
    for (int64_t j = 0; j < elemsThisOp; j += 2) {
      Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i + j);
      Value asFloatA = castToF32(elemA, loc, rewriter);
      Value asFloatB = nullptr;
      // sourceB stays null when packing an odd trailing element.
      if (j + 1 < elemsThisOp) {
        Value elemB = rewriter.create<vector::ExtractOp>(loc, in, i + j + 1);
        asFloatB = castToF32(elemB, loc, rewriter);
      }
      thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
          loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
    }
    // Trim the fixed 4-wide result when fewer than 4 elements remain.
    if (elemsThisOp < 4)
      thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, thisResult, 0, elemsThisOp, 1);
    result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
                                                           result, i, 1);
  }

  // Restore the original multi-dimensional shape if the input was rank > 1.
  if (inVectorTy.getRank() != outType.getRank()) {
    result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
  }

  rewriter.replaceOp(op, result);
}
291 
292 LogicalResult TruncfToFloat16RewritePattern::match(arith::TruncFOp op) const {
293   Type outType = op.getOut().getType();
294   Type inputType = getElementTypeOrSelf(op.getIn());
295   if (auto outVecType = dyn_cast<VectorType>(outType)) {
296     if (outVecType.isScalable())
297       return failure();
298     outType = outVecType.getElementType();
299   }
300   return success(outType.isF16() && inputType.isF32());
301 }
302 
// Lower f32 -> f16 truncation to rocdl.cvt.pkrtz, which converts two f32
// values to a packed pair of f16 values with round-to-zero semantics.
void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
                                            PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  Value in = op.getIn();
  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
  // CvtPkRtz always produces a 2-element f16 vector.
  VectorType truncResType = VectorType::get(2, outElemType);
  auto inVectorTy = dyn_cast<VectorType>(in.getType());

  // Handle the case where input type is not a vector type
  if (!inVectorTy) {
    // The second source lane is unused for a scalar input, so feed poison.
    auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
    Value asF16s =
        rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
    Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0);
    return rewriter.replaceOp(op, result);
  }
  VectorType outType = cast<VectorType>(op.getOut().getType());
  int64_t numElements = outType.getNumElements();
  Value zero = rewriter.createOrFold<arith::ConstantOp>(
      loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);

  // Flatten multi-dimensional inputs to rank 1; the original shape is
  // restored by the shape-cast after the loop.
  if (inVectorTy.getRank() > 1) {
    inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
                                 inVectorTy.getElementType());
    in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
  }

  // Handle the vector case. We also handle the (uncommon) case where the vector
  // length is odd
  for (int64_t i = 0; i < numElements; i += 2) {
    int64_t elemsThisOp = std::min(numElements, i + 2) - i;
    Value thisResult = nullptr;
    Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i);
    // Default the second lane to poison; it is only read when a real
    // element exists at i + 1.
    Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());

    if (elemsThisOp == 2) {
      elemB = rewriter.create<vector::ExtractOp>(loc, in, i + 1);
    }

    thisResult =
        rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, elemA, elemB);
    // Place back the truncated result into the possibly larger vector. If we
    // are operating on a size 2 vector, these operations should be folded away
    thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
        loc, thisResult, 0, elemsThisOp, 1);
    result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
                                                           result, i, 1);
  }

  if (inVectorTy.getRank() != outType.getRank()) {
    result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
  }

  rewriter.replaceOp(op, result);
}
359 
360 void mlir::arith::populateArithToAMDGPUConversionPatterns(
361     RewritePatternSet &patterns, bool convertFP8Arithmetic,
362     bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
363 
364   if (convertFP8Arithmetic) {
365     patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
366     patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
367                                                saturateFP8Truncf, chipset);
368   }
369   if (allowPackedF16Rtz)
370     patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext());
371 }
372 
373 void ArithToAMDGPUConversionPass::runOnOperation() {
374   Operation *op = getOperation();
375   MLIRContext *ctx = &getContext();
376   RewritePatternSet patterns(op->getContext());
377   FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
378   if (failed(maybeChipset)) {
379     emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
380     return signalPassFailure();
381   }
382 
383   bool convertFP8Arithmetic =
384       maybeChipset->majorVersion == 9 && *maybeChipset >= Chipset(9, 4, 0);
385   arith::populateArithToAMDGPUConversionPatterns(
386       patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
387       *maybeChipset);
388   if (failed(applyPatternsGreedily(op, std::move(patterns))))
389     return signalPassFailure();
390 }
391