// Source: llvm-project/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir (revision 7042fcc6389c6c103d501b6f39988eafed0d9b5b)
// RUN: mlir-opt --split-input-file --arith-emulate-unsupported-floats="source-types=bf16,f8E4M3FNUZ target-type=f32" %s | FileCheck %s

// A bf16 arith.addf is emulated in f32. Both operands are widened with
// arith.extf: the function argument and the constant, which is itself still
// materialized as bf16. The add is then performed on f32, and the result is
// narrowed back to bf16 with arith.truncf. The inserted casts carry
// fastmath<contract>.
func.func @basic_expansion(%x: bf16) -> bf16 {
// CHECK-LABEL: @basic_expansion
// CHECK-SAME: [[X:%.+]]: bf16
// CHECK-DAG: [[C:%.+]] = arith.constant {{.*}} : bf16
// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] fastmath<contract> : bf16 to f32
// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] fastmath<contract> : bf16 to f32
// CHECK: [[Y_EXP:%.+]] = arith.addf [[X_EXP]], [[C_EXP]] : f32
// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] fastmath<contract> : f32 to bf16
// CHECK: return [[Y]]
  %c = arith.constant 1.0 : bf16
  %y = arith.addf %x, %c : bf16
  func.return %y : bf16
}

// -----

// Chained bf16 ops are emulated one at a time. The result of each emulated op
// is truncated back to bf16 and then re-extended before its next use
// (P -> P_EXP2); the f32 value P_EXP is not forwarded directly. That single
// re-extended value feeds both the mulf and the cmpf. The cmpf runs on f32 and
// already yields i1, so no truncf follows it.
func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 {
// CHECK-LABEL: @chained
// CHECK-SAME: [[X:%.+]]: bf16, [[Y:%.+]]: bf16, [[Z:%.+]]: bf16
// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] fastmath<contract> : bf16 to f32
// CHECK-DAG: [[Y_EXP:%.+]] = arith.extf [[Y]] fastmath<contract> : bf16 to f32
// CHECK-DAG: [[Z_EXP:%.+]] = arith.extf [[Z]] fastmath<contract> : bf16 to f32
// CHECK: [[P_EXP:%.+]] = arith.addf [[X_EXP]], [[Y_EXP]] : f32
// CHECK: [[P:%.+]] = arith.truncf [[P_EXP]] fastmath<contract> : f32 to bf16
// CHECK: [[P_EXP2:%.+]] = arith.extf [[P]] fastmath<contract> : bf16 to f32
// CHECK: [[Q_EXP:%.+]] = arith.mulf [[P_EXP2]], [[Z_EXP]] : f32
// CHECK: [[Q:%.+]] = arith.truncf [[Q_EXP]] fastmath<contract> : f32 to bf16
// CHECK: [[Q_EXP2:%.+]] = arith.extf [[Q]] fastmath<contract> : bf16 to f32
// CHECK: [[RES:%.+]] = arith.cmpf ole, [[P_EXP2]], [[Q_EXP2]] : f32
// CHECK: return [[RES]]
  %p = arith.addf %x, %y : bf16
  %q = arith.mulf %p, %z : bf16
  %res = arith.cmpf ole, %p, %q : bf16
  func.return %res : i1
}

// -----

// memref.load/memref.store of f8E4M3FNUZ stay in the source type. The value
// stored to %b[0] is the originally loaded value, not an extended copy. Only
// the arith.addf is emulated in f32, and its result is truncated back to
// f8E4M3FNUZ before being stored to %b[1].
func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) {
// CHECK-LABEL: @memops
// CHECK: [[V:%.+]] = memref.load {{.*}} : memref<4xf8E4M3FNUZ>
// CHECK: [[V_EXP:%.+]] = arith.extf [[V]] fastmath<contract> : f8E4M3FNUZ to f32
// CHECK: memref.store [[V]]
// CHECK: [[W:%.+]] = memref.load
// CHECK: [[W_EXP:%.+]] = arith.extf [[W]] fastmath<contract> : f8E4M3FNUZ to f32
// CHECK: [[X_EXP:%.+]] = arith.addf [[V_EXP]], [[W_EXP]] : f32
// CHECK: [[X:%.+]] = arith.truncf [[X_EXP]] fastmath<contract> : f32 to f8E4M3FNUZ
// CHECK: memref.store [[X]]
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %v = memref.load %a[%c0] : memref<4xf8E4M3FNUZ>
  memref.store %v, %b[%c0] : memref<4xf8E4M3FNUZ>
  %w = memref.load %a[%c1] : memref<4xf8E4M3FNUZ>
  %x = arith.addf %v, %w : f8E4M3FNUZ
  memref.store %x, %b[%c1] : memref<4xf8E4M3FNUZ>
  func.return
}

// -----

// Emulation also applies to vector element types: the
// vector<4xf8E4M3FNUZ> mulf is computed on vector<4xf32> and truncated back.
// The user-written arith.extf (which has no fastmath flags) is kept as its own
// op. It is not merged with the truncf produced by the emulation.
func.func @vectors(%a: vector<4xf8E4M3FNUZ>) -> vector<4xf32> {
// CHECK-LABEL: @vectors
// CHECK-SAME: [[A:%.+]]: vector<4xf8E4M3FNUZ>
// CHECK: [[A_EXP:%.+]] = arith.extf [[A]] fastmath<contract> : vector<4xf8E4M3FNUZ> to vector<4xf32>
// CHECK: [[B_EXP:%.+]] = arith.mulf [[A_EXP]], [[A_EXP]] : vector<4xf32>
// CHECK: [[B:%.+]] = arith.truncf [[B_EXP]] fastmath<contract> : vector<4xf32> to vector<4xf8E4M3FNUZ>
// CHECK: [[RET:%.+]] = arith.extf [[B]] : vector<4xf8E4M3FNUZ> to vector<4xf32>
// CHECK: return [[RET]]
  %b = arith.mulf %a, %a : vector<4xf8E4M3FNUZ>
  %ret = arith.extf %b : vector<4xf8E4M3FNUZ> to vector<4xf32>
  func.return %ret : vector<4xf32>
}

// -----

// Ops already on the target type (f32) are left alone. The addf consumes the
// block argument and the f32 constant directly, and its result is returned
// unchanged.
func.func @no_expansion(%x: f32) -> f32 {
// CHECK-LABEL: @no_expansion
// CHECK-SAME: [[X:%.+]]: f32
// CHECK-DAG: [[C:%.+]] = arith.constant {{.*}} : f32
// CHECK: [[Y:%.+]] = arith.addf [[X]], [[C]] : f32
// CHECK: return [[Y]]
  %c = arith.constant 1.0 : f32
  %y = arith.addf %x, %c : f32
  func.return %y : f32
}