xref: /llvm-project/llvm/lib/Target/X86/X86InstrFMA3Info.cpp (revision ee2722fc882ed5dbc7609686bd998b023c6645b2)
1 //===-- X86InstrFMA3Info.cpp - X86 FMA3 Instruction Information -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the implementation of the classes providing information
10 // about existing X86 FMA3 opcodes, classifying and grouping them.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "X86InstrFMA3Info.h"
15 #include "X86InstrInfo.h"
16 #include <atomic>
17 #include <cassert>
18 #include <cstdint>
19 
20 using namespace llvm;
21 
// The macros below build the initializer lists for the FMA3 group tables.
// Each macro expands, in the order written, to zero or more
// X86InstrFMA3Group initializers. The resulting tables are checked for
// sortedness by verifyTables(), so the order of the sub-macro invocations
// inside every macro is significant.
//
// No macro body ends with a line continuation: a trailing backslash would
// silently splice whatever follows the macro into its replacement list.

// Emits a single group initializer holding the 132, 213 and 231 forms (in that
// order) of the opcode X86::<Name><form><Suf>, followed by the group's Attrs.
#define FMA3GROUP(Name, Suf, Attrs) \
  { { X86::Name##132##Suf, X86::Name##213##Suf, X86::Name##231##Suf }, Attrs },

// Emits the plain group followed by its "k" and "kz" suffixed variants, which
// are additionally tagged KMergeMasked and KZeroMasked respectively.
#define FMA3GROUP_MASKED(Name, Suf, Attrs) \
  FMA3GROUP(Name, Suf, Attrs) \
  FMA3GROUP(Name, Suf##k, Attrs | X86InstrFMA3Group::KMergeMasked) \
  FMA3GROUP(Name, Suf##kz, Attrs | X86InstrFMA3Group::KZeroMasked)

// Same as FMA3GROUP_MASKED, but for the "_Int" opcode variants: emits
// <Suf>_Int, <Suf>k_Int and <Suf>kz_Int.
#define FMA3GROUP_MASKED_INT(Name, Suf, Attrs) \
  FMA3GROUP(Name, Suf##_Int, Attrs) \
  FMA3GROUP(Name, Suf##k_Int, Attrs | X86InstrFMA3Group::KMergeMasked) \
  FMA3GROUP(Name, Suf##kz_Int, Attrs | X86InstrFMA3Group::KZeroMasked)

// Emits the masked groups for the Z128, Z256 and Z variants of a packed type,
// each in its "m" and "r" flavour.
#define FMA3GROUP_PACKED_WIDTHS_Z(Name, Suf, Attrs) \
  FMA3GROUP_MASKED(Name, Suf##Z128m, Attrs) \
  FMA3GROUP_MASKED(Name, Suf##Z128r, Attrs) \
  FMA3GROUP_MASKED(Name, Suf##Z256m, Attrs) \
  FMA3GROUP_MASKED(Name, Suf##Z256r, Attrs) \
  FMA3GROUP_MASKED(Name, Suf##Zm, Attrs) \
  FMA3GROUP_MASKED(Name, Suf##Zr, Attrs)

// Emits every packed variant of a type: the Y groups, then all Z variants,
// then the unsuffixed "m"/"r" groups.
#define FMA3GROUP_PACKED_WIDTHS_ALL(Name, Suf, Attrs) \
  FMA3GROUP(Name, Suf##Ym, Attrs) \
  FMA3GROUP(Name, Suf##Yr, Attrs) \
  FMA3GROUP_PACKED_WIDTHS_Z(Name, Suf, Attrs) \
  FMA3GROUP(Name, Suf##m, Attrs) \
  FMA3GROUP(Name, Suf##r, Attrs)

// Emits the packed PD, PH and PS groups. PH only gets the Z variants.
#define FMA3GROUP_PACKED_DHS(Name, Attrs) \
  FMA3GROUP_PACKED_WIDTHS_ALL(Name, PD, Attrs) \
  FMA3GROUP_PACKED_WIDTHS_Z(Name, PH, Attrs) \
  FMA3GROUP_PACKED_WIDTHS_ALL(Name, PS, Attrs)

// Emits the packed BF16 groups (Z variants only).
#define FMA3GROUP_PACKED_BF16(Name, Attrs) \
  FMA3GROUP_PACKED_WIDTHS_Z(Name, BF16, Attrs)

// Emits the Zm/Zr groups of a scalar type together with their masked "_Int"
// variants; the "_Int" variants are additionally tagged Intrinsic.
#define FMA3GROUP_SCALAR_WIDTHS_Z(Name, Suf, Attrs) \
  FMA3GROUP(Name, Suf##Zm, Attrs) \
  FMA3GROUP_MASKED_INT(Name, Suf##Zm, Attrs | X86InstrFMA3Group::Intrinsic) \
  FMA3GROUP(Name, Suf##Zr, Attrs) \
  FMA3GROUP_MASKED_INT(Name, Suf##Zr, Attrs | X86InstrFMA3Group::Intrinsic)

// Emits every scalar variant of a type: the Z groups, then the unsuffixed
// "m"/"r" groups with their Intrinsic-tagged "_Int" counterparts.
#define FMA3GROUP_SCALAR_WIDTHS_ALL(Name, Suf, Attrs) \
  FMA3GROUP_SCALAR_WIDTHS_Z(Name, Suf, Attrs) \
  FMA3GROUP(Name, Suf##m, Attrs) \
  FMA3GROUP(Name, Suf##m_Int, Attrs | X86InstrFMA3Group::Intrinsic) \
  FMA3GROUP(Name, Suf##r, Attrs) \
  FMA3GROUP(Name, Suf##r_Int, Attrs | X86InstrFMA3Group::Intrinsic)

// Emits the scalar SD, SH and SS groups. SH only gets the Z variants.
#define FMA3GROUP_SCALAR(Name, Attrs) \
  FMA3GROUP_SCALAR_WIDTHS_ALL(Name, SD, Attrs) \
  FMA3GROUP_SCALAR_WIDTHS_Z(Name, SH, Attrs) \
  FMA3GROUP_SCALAR_WIDTHS_ALL(Name, SS, Attrs)

// Emits all groups of an operation: packed BF16, packed PD/PH/PS, then scalar.
#define FMA3GROUP_FULL(Name, Attrs) \
  FMA3GROUP_PACKED_BF16(Name, Attrs) \
  FMA3GROUP_PACKED_DHS(Name, Attrs) \
  FMA3GROUP_SCALAR(Name, Attrs)
80 
// FMA3 groups searched by getFMA3Group() when neither X86II::EVEX_RC nor
// X86II::EVEX_B is set in the instruction's TSFlags.
//
// getFMA3Group() binary-searches this table with partition_point and
// verifyTables() asserts that it is sorted, so the order of the entries below
// (and of the expansions of the macros they use) must be preserved.
static const X86InstrFMA3Group Groups[] = {
  FMA3GROUP_FULL(VFMADD, 0)
  FMA3GROUP_PACKED_DHS(VFMADDSUB, 0)
  FMA3GROUP_FULL(VFMSUB, 0)
  FMA3GROUP_PACKED_DHS(VFMSUBADD, 0)
  FMA3GROUP_FULL(VFNMADD, 0)
  FMA3GROUP_FULL(VFNMSUB, 0)
};
89 
// Helper macros for the AVX512-only tables. Each one expands, in the order
// written, to a run of X86InstrFMA3Group initializers built from FMA3GROUP,
// FMA3GROUP_MASKED and FMA3GROUP_MASKED_INT.

// Emits the masked groups for the Z128, Z256 and Z variants of packed type
// Type, each followed by the opcode suffix Suf.
#define FMA3GROUP_PACKED_AVX512_WIDTHS(Name, Type, Suf, Attrs) \
  FMA3GROUP_MASKED(Name, Type##Z128##Suf, Attrs) \
  FMA3GROUP_MASKED(Name, Type##Z256##Suf, Attrs) \
  FMA3GROUP_MASKED(Name, Type##Z##Suf, Attrs)

// Emits all widths for the packed PD, PH and PS types.
#define FMA3GROUP_PACKED_AVX512_DHS(Name, Suf, Attrs) \
  FMA3GROUP_PACKED_AVX512_WIDTHS(Name, PD, Suf, Attrs) \
  FMA3GROUP_PACKED_AVX512_WIDTHS(Name, PH, Suf, Attrs) \
  FMA3GROUP_PACKED_AVX512_WIDTHS(Name, PS, Suf, Attrs)

// Emits all widths for packed BF16, followed by the PD/PH/PS groups.
#define FMA3GROUP_PACKED_AVX512_ALL(Name, Suf, Attrs) \
  FMA3GROUP_PACKED_AVX512_WIDTHS(Name, BF16, Suf, Attrs) \
  FMA3GROUP_PACKED_AVX512_DHS(Name, Suf, Attrs)

// Emits the masked Z256 and Z groups (there is no Z128 entry) for the packed
// PD, PH and PS types.
#define FMA3GROUP_PACKED_AVX512_ROUND(Name, Suf, Attrs) \
  FMA3GROUP_MASKED(Name, PDZ256##Suf, Attrs) \
  FMA3GROUP_MASKED(Name, PDZ##Suf, Attrs) \
  FMA3GROUP_MASKED(Name, PHZ256##Suf, Attrs) \
  FMA3GROUP_MASKED(Name, PHZ##Suf, Attrs) \
  FMA3GROUP_MASKED(Name, PSZ256##Suf, Attrs) \
  FMA3GROUP_MASKED(Name, PSZ##Suf, Attrs)

// Emits, for each of the scalar SD, SH and SS types, the plain Z group
// followed by its masked "_Int" variants.
#define FMA3GROUP_SCALAR_AVX512_ROUND(Name, Suf, Attrs) \
  FMA3GROUP(Name, SDZ##Suf, Attrs) \
  FMA3GROUP_MASKED_INT(Name, SDZ##Suf, Attrs) \
  FMA3GROUP(Name, SHZ##Suf, Attrs) \
  FMA3GROUP_MASKED_INT(Name, SHZ##Suf, Attrs) \
  FMA3GROUP(Name, SSZ##Suf, Attrs) \
  FMA3GROUP_MASKED_INT(Name, SSZ##Suf, Attrs)
121 
// FMA3 groups for the "mb" opcode variants. getFMA3Group() searches this table
// when X86II::EVEX_B is set in TSFlags and X86II::EVEX_RC is not.
//
// The table is binary-searched and verifyTables() asserts that it is sorted,
// so the order of the entries below must be preserved.
static const X86InstrFMA3Group BroadcastGroups[] = {
  FMA3GROUP_PACKED_AVX512_ALL(VFMADD, mb, 0)
  FMA3GROUP_PACKED_AVX512_DHS(VFMADDSUB, mb, 0)
  FMA3GROUP_PACKED_AVX512_ALL(VFMSUB, mb, 0)
  FMA3GROUP_PACKED_AVX512_DHS(VFMSUBADD, mb, 0)
  FMA3GROUP_PACKED_AVX512_ALL(VFNMADD, mb, 0)
  FMA3GROUP_PACKED_AVX512_ALL(VFNMSUB, mb, 0)
};
130 
// FMA3 groups for the "rb" opcode variants. getFMA3Group() searches this table
// when X86II::EVEX_RC is set in TSFlags. The scalar entries are all tagged
// Intrinsic.
//
// The table is binary-searched and verifyTables() asserts that it is sorted,
// so the order of the entries below must be preserved.
static const X86InstrFMA3Group RoundGroups[] = {
  FMA3GROUP_PACKED_AVX512_ROUND(VFMADD, rb, 0)
  FMA3GROUP_SCALAR_AVX512_ROUND(VFMADD, rb, X86InstrFMA3Group::Intrinsic)
  FMA3GROUP_PACKED_AVX512_ROUND(VFMADDSUB, rb, 0)
  FMA3GROUP_PACKED_AVX512_ROUND(VFMSUB, rb, 0)
  FMA3GROUP_SCALAR_AVX512_ROUND(VFMSUB, rb, X86InstrFMA3Group::Intrinsic)
  FMA3GROUP_PACKED_AVX512_ROUND(VFMSUBADD, rb, 0)
  FMA3GROUP_PACKED_AVX512_ROUND(VFNMADD, rb, 0)
  FMA3GROUP_SCALAR_AVX512_ROUND(VFNMADD, rb, X86InstrFMA3Group::Intrinsic)
  FMA3GROUP_PACKED_AVX512_ROUND(VFNMSUB, rb, 0)
  FMA3GROUP_SCALAR_AVX512_ROUND(VFNMSUB, rb, X86InstrFMA3Group::Intrinsic)
};
143 
144 static void verifyTables() {
145 #ifndef NDEBUG
146   static std::atomic<bool> TableChecked(false);
147   if (!TableChecked.load(std::memory_order_relaxed)) {
148     assert(llvm::is_sorted(Groups) && llvm::is_sorted(RoundGroups) &&
149            llvm::is_sorted(BroadcastGroups) && "FMA3 tables not sorted!");
150     TableChecked.store(true, std::memory_order_relaxed);
151   }
152 #endif
153 }
154 
155 /// Returns a reference to a group of FMA3 opcodes to where the given
156 /// \p Opcode is included. If the given \p Opcode is not recognized as FMA3
157 /// and not included into any FMA3 group, then nullptr is returned.
158 const X86InstrFMA3Group *llvm::getFMA3Group(unsigned Opcode, uint64_t TSFlags) {
159 
160   // FMA3 instructions have a well defined encoding pattern we can exploit.
161   uint8_t BaseOpcode = X86II::getBaseOpcodeFor(TSFlags);
162   bool IsFMA3Opcode = ((BaseOpcode >= 0x96 && BaseOpcode <= 0x9F) ||
163                        (BaseOpcode >= 0xA6 && BaseOpcode <= 0xAF) ||
164                        (BaseOpcode >= 0xB6 && BaseOpcode <= 0xBF));
165   bool IsFMA3Encoding = ((TSFlags & X86II::EncodingMask) == X86II::VEX &&
166                          (TSFlags & X86II::OpMapMask) == X86II::T8) ||
167                         ((TSFlags & X86II::EncodingMask) == X86II::EVEX &&
168                          ((TSFlags & X86II::OpMapMask) == X86II::T8 ||
169                           (TSFlags & X86II::OpMapMask) == X86II::T_MAP6));
170   bool IsFMA3Prefix = (TSFlags & X86II::OpPrefixMask) == X86II::PD ||
171                       (TSFlags & X86II::OpPrefixMask) == 0; // X86II::PS
172   if (!IsFMA3Opcode || !IsFMA3Encoding || !IsFMA3Prefix)
173     return nullptr;
174 
175   verifyTables();
176 
177   ArrayRef<X86InstrFMA3Group> Table;
178   if (TSFlags & X86II::EVEX_RC)
179     Table = ArrayRef(RoundGroups);
180   else if (TSFlags & X86II::EVEX_B)
181     Table = ArrayRef(BroadcastGroups);
182   else
183     Table = ArrayRef(Groups);
184 
185   // FMA 132 instructions have an opcode of 0x96-0x9F
186   // FMA 213 instructions have an opcode of 0xA6-0xAF
187   // FMA 231 instructions have an opcode of 0xB6-0xBF
188   unsigned FormIndex = ((BaseOpcode - 0x90) >> 4) & 0x3;
189 
190   auto I = partition_point(Table, [=](const X86InstrFMA3Group &Group) {
191     return Group.Opcodes[FormIndex] < Opcode;
192   });
193   assert(I != Table.end() && I->Opcodes[FormIndex] == Opcode &&
194          "Couldn't find FMA3 opcode!");
195   return I;
196 }
197