xref: /llvm-project/flang/include/flang/Optimizer/Builder/PPCIntrinsicCall.h (revision e6a4346b5a105c2f28349270c3a82935c9a84d16)
1 //==-- Builder/PPCIntrinsicCall.h - lowering of PowerPC intrinsics -*-C++-*-==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef FORTRAN_LOWER_PPCINTRINSICCALL_H
10 #define FORTRAN_LOWER_PPCINTRINSICCALL_H
11 
12 #include "flang/Common/static-multimap-view.h"
13 #include "flang/Optimizer/Builder/IntrinsicCall.h"
14 #include "mlir/Dialect/Math/IR/Math.h"
15 
16 namespace fir {
17 
/// Selector for the PowerPC vector intrinsics whose lowering is shared by the
/// templated genVec* generators of PPCIntrinsicLibrary; the enumerator is used
/// as the VecOp template argument. Vector intrinsics that do not share a
/// generator are deliberately not listed here.
///
/// Enumerators are listed in alphabetical order, several per line.
enum class VecOp {
  Abs, Add, And, Anyge,
  Cmpge, Cmpgt, Cmple, Cmplt,
  Convert, Ctf, Cvf,
  Ld, Lde, Ldl, Lvsl, Lvsr, Lxv, Lxvp,
  Mergeh, Mergel, Msub, Mul, Nmadd,
  Perm, Permi, Sel,
  Sl, Sld, Sldw, Sll, Slo,
  Splat, Splat_s32, Splats,
  Sr, Srl, Sro,
  St, Ste, Stxv, Stxvp,
  Sub,
  Xl, Xlbe, Xld2, Xlw4, Xor,
  Xst, Xst_be, Xstd2, Xstw4
};
73 
/// Selector for the PowerPC MMA intrinsics whose lowering is shared through
/// PPCIntrinsicLibrary::genMmaIntr (passed as its first template argument).
enum class MMAOp {
  // Assemble / disassemble and xx* accumulator operations.
  AssembleAcc, AssemblePair, DisassembleAcc, DisassemblePair,
  Xxmfacc, Xxmtacc, Xxsetaccz,

  // Pmxv* bf16 / f16 variants.
  Pmxvbf16ger2, Pmxvbf16ger2nn, Pmxvbf16ger2np, Pmxvbf16ger2pn, Pmxvbf16ger2pp,
  Pmxvf16ger2, Pmxvf16ger2nn, Pmxvf16ger2np, Pmxvf16ger2pn, Pmxvf16ger2pp,

  // Pmxv* f32 / f64 variants.
  Pmxvf32ger, Pmxvf32gernn, Pmxvf32gernp, Pmxvf32gerpn, Pmxvf32gerpp,
  Pmxvf64ger, Pmxvf64gernn, Pmxvf64gernp, Pmxvf64gerpn, Pmxvf64gerpp,

  // Pmxv* integer variants.
  Pmxvi16ger2, Pmxvi16ger2pp, Pmxvi16ger2s, Pmxvi16ger2spp,
  Pmxvi4ger8, Pmxvi4ger8pp,
  Pmxvi8ger4, Pmxvi8ger4pp, Pmxvi8ger4spp,

  // Xv* variants.
  Xvbf16ger2, Xvbf16ger2nn, Xvbf16ger2np, Xvbf16ger2pn, Xvbf16ger2pp,
  Xvf16ger2, Xvf16ger2nn, Xvf16ger2np, Xvf16ger2pn, Xvf16ger2pp,
  Xvf32ger, Xvf32gernn, Xvf32gernp, Xvf32gerpn, Xvf32gerpp,
  Xvf64ger, Xvf64gernn, Xvf64gernp, Xvf64gerpn, Xvf64gerpp,
  Xvi16ger2, Xvi16ger2pp, Xvi16ger2s, Xvi16ger2spp,
  Xvi4ger8, Xvi4ger8pp,
  Xvi8ger4, Xvi8ger4pp, Xvi8ger4spp,
};
145 
/// Handler variant for the shared MMA generator: passed as the second template
/// argument (HandlerOp) of PPCIntrinsicLibrary::genMmaIntr.
/// NOTE(review): the exact effect of each variant is defined by genMmaIntr's
/// implementation (not visible here). The names suggest: no adaptation,
/// subroutine-to-function conversion, the same with argument reversal on
/// little-endian targets, and using the first argument as the result —
/// confirm against the definition.
enum class MMAHandlerOp {
  NoOp = 0,
  SubToFunc = 1,
  SubToFuncReverseArgOnLE = 2,
  FirstArgIsResult = 3,
};
152 
153 // Wrapper struct to encapsulate information for a vector type. Preserves
154 // sign of eleTy if eleTy is signed/unsigned integer. Helps with vector type
155 // conversions.
156 struct VecTypeInfo {
157   mlir::Type eleTy;
158   uint64_t len;
159 
160   mlir::Type toFirVectorType() { return fir::VectorType::get(len, eleTy); }
161 
162   // We need a builder to do the signless element conversion.
163   mlir::Type toMlirVectorType(mlir::MLIRContext *context) {
164     // Will convert to eleTy to signless int if eleTy is signed/unsigned int.
165     auto convEleTy{getConvertedElementType(context, eleTy)};
166     return mlir::VectorType::get(len, convEleTy);
167   }
168 
169   bool isFloat32() { return mlir::isa<mlir::Float32Type>(eleTy); }
170 
171   bool isFloat64() { return mlir::isa<mlir::Float64Type>(eleTy); }
172 
173   bool isFloat() { return isFloat32() || isFloat64(); }
174 };
175 
176 //===----------------------------------------------------------------------===//
177 // Helper functions for argument handling in vector intrinsics.
178 //===----------------------------------------------------------------------===//
179 
180 // Returns a VecTypeInfo with element type and length of given fir vector type.
181 // Preserves signness of fir vector type if element type of integer.
182 static inline VecTypeInfo getVecTypeFromFirType(mlir::Type firTy) {
183   assert(mlir::isa<fir::VectorType>(firTy));
184   VecTypeInfo vecTyInfo;
185   vecTyInfo.eleTy = mlir::dyn_cast<fir::VectorType>(firTy).getElementType();
186   vecTyInfo.len = mlir::dyn_cast<fir::VectorType>(firTy).getLen();
187   return vecTyInfo;
188 }
189 
190 static inline VecTypeInfo getVecTypeFromFir(mlir::Value firVec) {
191   return getVecTypeFromFirType(firVec.getType());
192 }
193 
194 // Calculates the vector length and returns a VecTypeInfo with element type and
195 // length.
196 static inline VecTypeInfo getVecTypeFromEle(mlir::Value ele) {
197   VecTypeInfo vecTyInfo;
198   vecTyInfo.eleTy = ele.getType();
199   vecTyInfo.len = 16 / (vecTyInfo.eleTy.getIntOrFloatBitWidth() / 8);
200   return vecTyInfo;
201 }
202 
203 // Converts array of fir vectors to mlir vectors.
204 static inline llvm::SmallVector<mlir::Value, 4>
205 convertVecArgs(fir::FirOpBuilder &builder, mlir::Location loc,
206                VecTypeInfo vecTyInfo, llvm::SmallVector<mlir::Value, 4> args) {
207   llvm::SmallVector<mlir::Value, 4> newArgs;
208   auto ty{vecTyInfo.toMlirVectorType(builder.getContext())};
209   assert(ty && "unknown mlir vector type");
210   for (size_t i = 0; i < args.size(); i++)
211     newArgs.push_back(builder.createConvert(loc, ty, args[i]));
212   return newArgs;
213 }
214 
215 // This overload method is used only if arguments are of different types.
216 static inline llvm::SmallVector<mlir::Value, 4>
217 convertVecArgs(fir::FirOpBuilder &builder, mlir::Location loc,
218                llvm::SmallVectorImpl<VecTypeInfo> &vecTyInfo,
219                llvm::SmallVector<mlir::Value, 4> args) {
220   llvm::SmallVector<mlir::Value, 4> newArgs;
221   for (size_t i = 0; i < args.size(); i++) {
222     mlir::Type ty{vecTyInfo[i].toMlirVectorType(builder.getContext())};
223     assert(ty && "unknown mlir vector type");
224     newArgs.push_back(builder.createConvert(loc, ty, args[i]));
225   }
226   return newArgs;
227 }
228 
229 struct PPCIntrinsicLibrary : IntrinsicLibrary {
230 
231   // Constructors.
232   explicit PPCIntrinsicLibrary(fir::FirOpBuilder &builder, mlir::Location loc)
233       : IntrinsicLibrary(builder, loc) {}
234   PPCIntrinsicLibrary() = delete;
235   PPCIntrinsicLibrary(const PPCIntrinsicLibrary &) = delete;
236 
237   // Helper functions for vector element ordering.
238   bool isBEVecElemOrderOnLE();
239   bool isNativeVecElemOrderOnLE();
240   bool changeVecElemOrder();
241 
242   // PPC MMA intrinsic generic handler
243   template <MMAOp IntrId, MMAHandlerOp HandlerOp>
244   void genMmaIntr(llvm::ArrayRef<fir::ExtendedValue>);
245 
246   // PPC intrinsic handlers.
247   template <bool isImm>
248   void genMtfsf(llvm::ArrayRef<fir::ExtendedValue>);
249 
250   fir::ExtendedValue genVecAbs(mlir::Type resultType,
251                                llvm::ArrayRef<fir::ExtendedValue> args);
252   template <VecOp>
253   fir::ExtendedValue
254   genVecAddAndMulSubXor(mlir::Type resultType,
255                         llvm::ArrayRef<fir::ExtendedValue> args);
256 
257   template <VecOp>
258   fir::ExtendedValue genVecCmp(mlir::Type resultType,
259                                llvm::ArrayRef<fir::ExtendedValue> args);
260 
261   template <VecOp>
262   fir::ExtendedValue genVecConvert(mlir::Type resultType,
263                                    llvm::ArrayRef<fir::ExtendedValue> args);
264 
265   template <VecOp>
266   fir::ExtendedValue genVecAnyCompare(mlir::Type resultType,
267                                       llvm::ArrayRef<fir::ExtendedValue> args);
268 
269   fir::ExtendedValue genVecExtract(mlir::Type resultType,
270                                    llvm::ArrayRef<fir::ExtendedValue> args);
271 
272   fir::ExtendedValue genVecInsert(mlir::Type resultType,
273                                   llvm::ArrayRef<fir::ExtendedValue> args);
274 
275   template <VecOp>
276   fir::ExtendedValue genVecMerge(mlir::Type resultType,
277                                  llvm::ArrayRef<fir::ExtendedValue> args);
278 
279   template <VecOp>
280   fir::ExtendedValue genVecPerm(mlir::Type resultType,
281                                 llvm::ArrayRef<fir::ExtendedValue> args);
282 
283   fir::ExtendedValue genVecXlGrp(mlir::Type resultType,
284                                  llvm::ArrayRef<fir::ExtendedValue> args);
285 
286   template <VecOp>
287   fir::ExtendedValue genVecLdCallGrp(mlir::Type resultType,
288                                      llvm::ArrayRef<fir::ExtendedValue> args);
289 
290   template <VecOp>
291   fir::ExtendedValue genVecLdNoCallGrp(mlir::Type resultType,
292                                        llvm::ArrayRef<fir::ExtendedValue> args);
293 
294   template <VecOp>
295   fir::ExtendedValue genVecLvsGrp(mlir::Type resultType,
296                                   llvm::ArrayRef<fir::ExtendedValue> args);
297 
298   template <VecOp>
299   fir::ExtendedValue genVecNmaddMsub(mlir::Type resultType,
300                                      llvm::ArrayRef<fir::ExtendedValue> args);
301 
302   template <VecOp>
303   fir::ExtendedValue genVecShift(mlir::Type,
304                                  llvm::ArrayRef<fir::ExtendedValue>);
305 
306   fir::ExtendedValue genVecSel(mlir::Type resultType,
307                                llvm::ArrayRef<fir::ExtendedValue> args);
308 
309   template <VecOp>
310   void genVecStore(llvm::ArrayRef<fir::ExtendedValue>);
311 
312   template <VecOp>
313   void genVecXStore(llvm::ArrayRef<fir::ExtendedValue>);
314 
315   template <VecOp vop>
316   fir::ExtendedValue genVecSplat(mlir::Type resultType,
317                                  llvm::ArrayRef<fir::ExtendedValue> args);
318 
319   fir::ExtendedValue genVecXlds(mlir::Type resultType,
320                                 llvm::ArrayRef<fir::ExtendedValue> args);
321 };
322 
323 const IntrinsicHandler *findPPCIntrinsicHandler(llvm::StringRef name);
324 
325 std::pair<const MathOperation *, const MathOperation *>
326 checkPPCMathOperationsRange(llvm::StringRef name);
327 
328 } // namespace fir
329 
330 #endif // FORTRAN_LOWER_PPCINTRINSICCALL_H
331