// xref: /llvm-project/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir (revision 1387ba48a312b6e9b174d850f8c9a1322f44c623)
// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx940" | FileCheck %s

// A scalar f8E5M2FNUZ -> f16 extension lowers to a packed-fp8 extract of lane
// 0 (yielding f32) followed by an f32 -> f16 truncation.
// CHECK-LABEL: func.func @scalar_ext
// CHECK-SAME: ([[V:%.+]]: f8E5M2FNUZ)
// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2FNUZ to f32
// CHECK: [[W:%.+]] = arith.truncf [[FLOAT]] : f32 to f16
// CHECK: return [[W]]
func.func @scalar_ext(%v: f8E5M2FNUZ) -> f16 {
  %w = arith.extf %v : f8E5M2FNUZ to f16
  return %w : f16
}

// No 0-D test because arith.extf hasn't been extended to support it.

// -----

// A short (2-element) fp8 -> f64 vector extension: each lane is extracted via
// amdgpu.ext_packed_fp8 (to f32), extended to f64, and inserted into a
// zero-initialized result vector.
// CHECK-LABEL: func.func @vector_ext_short
// CHECK-SAME: ([[V:%.+]]: vector<2xf8E5M2FNUZ>)
// CHECK-DAG: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<2xf64>
// CHECK: [[FLOAT0:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : vector<2xf8E5M2FNUZ> to f32
// CHECK: [[EXT0:%.+]] = arith.extf [[FLOAT0]] : f32 to f64
// CHECK: [[W0:%.+]] = vector.insert [[EXT0]], [[ZEROES]] [0]
// CHECK: [[FLOAT1:%.+]] = amdgpu.ext_packed_fp8 [[V]][1] : vector<2xf8E5M2FNUZ> to f32
// CHECK: [[EXT1:%.+]] = arith.extf [[FLOAT1]]
// CHECK: [[W1:%.+]] = vector.insert [[EXT1]], [[W0]] [1]
// CHECK: return [[W1]] : vector<2xf64>

func.func @vector_ext_short(%v: vector<2xf8E5M2FNUZ>) -> vector<2xf64> {
  %w = arith.extf %v : vector<2xf8E5M2FNUZ> to vector<2xf64>
  return %w : vector<2xf64>
}

// -----

// A 9-element fp8 -> f32 extension is split into 4/4/1 chunks with
// vector.extract_strided_slice; each lane of each chunk goes through
// amdgpu.ext_packed_fp8 and is inserted into the result.
// CHECK-LABEL: func.func @vector_ext_long
// CHECK-SAME: ([[V:%.+]]: vector<9xf8E4M3FNUZ>)
// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[V]] {offsets = [0], sizes = [4], strides = [1]}
// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0]
// CHECK: [[W0:%.+]] = vector.insert [[F0]]
// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1]
// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]]
// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2]
// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]]
// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3]
// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]]

// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[V]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ>
// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0]
// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]]
// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1]
// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]]
// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2]
// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]]
// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3]
// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]]

// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ>
// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0]
// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]]
// CHECK: return [[W8]]
func.func @vector_ext_long(%v: vector<9xf8E4M3FNUZ>) -> vector<9xf32> {
  %w = arith.extf %v : vector<9xf8E4M3FNUZ> to vector<9xf32>
  return %w : vector<9xf32>
}

// -----

// A scalar f16 -> f8E5M2FNUZ truncation lowers to an f16 -> f32 extension,
// one amdgpu.packed_trunc_2xfp8 (second operand undef), and an extract of
// lane 0 of the packed result.
// CHECK-LABEL: func.func @scalar_trunc
// CHECK-SAME: ([[V:%.+]]: f16)
// CHECK: [[FLOAT:%.+]] = arith.extf [[V]] : f16 to f32
// CHECK: [[TRUNCV:%.+]] = amdgpu.packed_trunc_2xfp8 [[FLOAT]], undef into undef[word 0] : f32 to vector<4xf8E5M2FNUZ>
// CHECK: [[W:%.+]] = vector.extract [[TRUNCV]][0] : f8E5M2FNUZ from vector<4xf8E5M2FNUZ>
// CHECK: return [[W]] : f8E5M2FNUZ
func.func @scalar_trunc(%v: f16) -> f8E5M2FNUZ {
  %w = arith.truncf %v : f16 to f8E5M2FNUZ
  return %w : f8E5M2FNUZ
}

// No 0-D test because arith.truncf hasn't been extended to support it.

// -----

// A 2-element f64 -> fp8 truncation: both lanes are truncated to f32, packed
// by a single amdgpu.packed_trunc_2xfp8, and the first two lanes of the
// 4-wide packed vector are sliced out as the result.
// CHECK-LABEL: func.func @vector_trunc_short
// CHECK-SAME: ([[V:%.+]]: vector<2xf64>) -> vector<2xf8E5M2FNUZ> {
// CHECK: [[V0:%.+]] = vector.extract [[V]][0]
// CHECK: [[F0:%.+]] = arith.truncf [[V0]] : f64 to f32
// CHECK: [[V1:%.+]] = vector.extract [[V]][1]
// CHECK: [[F1:%.+]] = arith.truncf [[V1]] : f64 to f32
// CHECK: [[W0:%.+]] = amdgpu.packed_trunc_2xfp8 [[F0]], [[F1]] into undef[word 0] : f32 to vector<4xf8E5M2FNUZ>
// CHECK: [[W:%.+]] = vector.extract_strided_slice [[W0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2FNUZ> to vector<2xf8E5M2FNUZ>
// CHECK: return [[W]] : vector<2xf8E5M2FNUZ>
func.func @vector_trunc_short(%v: vector<2xf64>) -> vector<2xf8E5M2FNUZ> {
  %w = arith.truncf %v : vector<2xf64> to vector<2xf8E5M2FNUZ>
  return %w : vector<2xf8E5M2FNUZ>
}

// -----

// A 9-element f32 -> fp8 truncation: lanes are packed two at a time into
// 4-wide words (word 0 then word 1), and each 4-wide piece is inserted into a
// zero-initialized 9-wide result at offsets 0, 4, and 8; the final odd lane
// uses an undef second operand and a 1-element slice.
// CHECK-LABEL: func.func @vector_trunc_long
// CHECK-SAME: ([[V:%.+]]: vector<9xf32>)
// CHECK: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf8E4M3FNUZ>
// CHECK: [[T0:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
// CHECK: [[T1:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T0]][word 1]
// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[T1]], [[ZEROES]] {offsets = [0], strides = [1]}

// CHECK: [[T2:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
// CHECK: [[T3:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T2]][word 1]
// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[T3]], [[W0]] {offsets = [4], strides = [1]}

// CHECK: [[T4:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, undef into undef[word 0]
// CHECK: [[T4_SHORT:%.+]] = vector.extract_strided_slice [[T4]] {offsets = [0], sizes = [1], strides = [1]}
// CHECK: [[W:%.+]] = vector.insert_strided_slice [[T4_SHORT]], [[W1]] {offsets = [8], strides = [1]}
// CHECK: return [[W]]
func.func @vector_trunc_long(%v: vector<9xf32>) -> vector<9xf8E4M3FNUZ> {
  %w = arith.truncf %v : vector<9xf32> to vector<9xf8E4M3FNUZ>
  return %w : vector<9xf8E4M3FNUZ>
}

// -----

// Same lowering as @vector_trunc_long, but for a 2-D (1x9) input: the work is
// done on a 1-D 9-wide vector and the result is shape_cast back to 1x9.
// CHECK-LABEL: func.func @vector_trunc_long_2d
// CHECK-SAME: ([[V:%.+]]: vector<1x9xf32>)
// CHECK: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf8E4M3FNUZ>
// CHECK: [[T0:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
// CHECK: [[T1:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T0]][word 1]
// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[T1]], [[ZEROES]] {offsets = [0], strides = [1]}

// CHECK: [[T2:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
// CHECK: [[T3:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T2]][word 1]
// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[T3]], [[W0]] {offsets = [4], strides = [1]}

// CHECK: [[T4:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, undef into undef[word 0]
// CHECK: [[T4_SHORT:%.+]] = vector.extract_strided_slice [[T4]] {offsets = [0], sizes = [1], strides = [1]}
// CHECK: [[W:%.+]] = vector.insert_strided_slice [[T4_SHORT]], [[W1]] {offsets = [8], strides = [1]}
// CHECK: [[RE:%.+]] = vector.shape_cast [[W]] : vector<9xf8E4M3FNUZ> to vector<1x9xf8E4M3FNUZ>
// CHECK: return [[RE]]
func.func @vector_trunc_long_2d(%v: vector<1x9xf32>) -> vector<1x9xf8E4M3FNUZ> {
  %w = arith.truncf %v : vector<1x9xf32> to vector<1x9xf8E4M3FNUZ>
  return %w : vector<1x9xf8E4M3FNUZ>
}

// -----

// Same lowering as @vector_ext_long, but for a 2-D (1x9) input: the input is
// shape_cast to a 1-D 9-wide vector, extended chunk by chunk, and the result
// is shape_cast back to 1x9.
// CHECK-LABEL: func.func @vector_ext_long_2d
// CHECK-SAME: ([[V:%.+]]: vector<1x9xf8E4M3FNUZ>)
// CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x9xf8E4M3FNUZ> to vector<9xf8E4M3FNUZ>
// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [0], sizes = [4], strides = [1]}
// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0]
// CHECK: [[W0:%.+]] = vector.insert [[F0]]
// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1]
// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]]
// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2]
// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]]
// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3]
// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]]

// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ>
// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0]
// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]]
// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1]
// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]]
// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2]
// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]]
// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3]
// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]]

// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ>
// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0]
// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]]
// CHECK: [[CAST:%.+]] = vector.shape_cast [[W8]] : vector<9xf32> to vector<1x9xf32>
// CHECK: return [[CAST]]
func.func @vector_ext_long_2d(%v: vector<1x9xf8E4M3FNUZ>) -> vector<1x9xf32> {
  %w = arith.extf %v : vector<1x9xf8E4M3FNUZ> to vector<1x9xf32>
  return %w : vector<1x9xf32>
}
