xref: /llvm-project/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp (revision 3ed2a81358e11a582eb5cc3edf711447767036e6)
1 //===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the targeting of the MachineLegalizer class for SPIR-V.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "SPIRVLegalizerInfo.h"
14 #include "SPIRV.h"
15 #include "SPIRVGlobalRegistry.h"
16 #include "SPIRVSubtarget.h"
17 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
18 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
19 #include "llvm/CodeGen/MachineInstr.h"
20 #include "llvm/CodeGen/MachineRegisterInfo.h"
21 #include "llvm/CodeGen/TargetOpcodes.h"
22 
23 using namespace llvm;
24 using namespace llvm::LegalizeActions;
25 using namespace llvm::LegalityPredicates;
26 
27 // clang-format off
28 static const std::set<unsigned> TypeFoldingSupportingOpcs = {
29     TargetOpcode::G_ADD,
30     TargetOpcode::G_FADD,
31     TargetOpcode::G_STRICT_FADD,
32     TargetOpcode::G_SUB,
33     TargetOpcode::G_FSUB,
34     TargetOpcode::G_STRICT_FSUB,
35     TargetOpcode::G_MUL,
36     TargetOpcode::G_FMUL,
37     TargetOpcode::G_STRICT_FMUL,
38     TargetOpcode::G_SDIV,
39     TargetOpcode::G_UDIV,
40     TargetOpcode::G_FDIV,
41     TargetOpcode::G_STRICT_FDIV,
42     TargetOpcode::G_SREM,
43     TargetOpcode::G_UREM,
44     TargetOpcode::G_FREM,
45     TargetOpcode::G_STRICT_FREM,
46     TargetOpcode::G_FNEG,
47     TargetOpcode::G_CONSTANT,
48     TargetOpcode::G_FCONSTANT,
49     TargetOpcode::G_AND,
50     TargetOpcode::G_OR,
51     TargetOpcode::G_XOR,
52     TargetOpcode::G_SHL,
53     TargetOpcode::G_ASHR,
54     TargetOpcode::G_LSHR,
55     TargetOpcode::G_SELECT,
56     TargetOpcode::G_EXTRACT_VECTOR_ELT,
57 };
58 // clang-format on
59 
60 bool isTypeFoldingSupported(unsigned Opcode) {
61   return TypeFoldingSupportingOpcs.count(Opcode) > 0;
62 }
63 
64 LegalityPredicate typeOfExtendedScalars(unsigned TypeIdx, bool IsExtendedInts) {
65   return [IsExtendedInts, TypeIdx](const LegalityQuery &Query) {
66     const LLT Ty = Query.Types[TypeIdx];
67     return IsExtendedInts && Ty.isValid() && Ty.isScalar();
68   };
69 }
70 
71 SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
72   using namespace TargetOpcode;
73 
74   this->ST = &ST;
75   GR = ST.getSPIRVGlobalRegistry();
76 
77   const LLT s1 = LLT::scalar(1);
78   const LLT s8 = LLT::scalar(8);
79   const LLT s16 = LLT::scalar(16);
80   const LLT s32 = LLT::scalar(32);
81   const LLT s64 = LLT::scalar(64);
82 
83   const LLT v16s64 = LLT::fixed_vector(16, 64);
84   const LLT v16s32 = LLT::fixed_vector(16, 32);
85   const LLT v16s16 = LLT::fixed_vector(16, 16);
86   const LLT v16s8 = LLT::fixed_vector(16, 8);
87   const LLT v16s1 = LLT::fixed_vector(16, 1);
88 
89   const LLT v8s64 = LLT::fixed_vector(8, 64);
90   const LLT v8s32 = LLT::fixed_vector(8, 32);
91   const LLT v8s16 = LLT::fixed_vector(8, 16);
92   const LLT v8s8 = LLT::fixed_vector(8, 8);
93   const LLT v8s1 = LLT::fixed_vector(8, 1);
94 
95   const LLT v4s64 = LLT::fixed_vector(4, 64);
96   const LLT v4s32 = LLT::fixed_vector(4, 32);
97   const LLT v4s16 = LLT::fixed_vector(4, 16);
98   const LLT v4s8 = LLT::fixed_vector(4, 8);
99   const LLT v4s1 = LLT::fixed_vector(4, 1);
100 
101   const LLT v3s64 = LLT::fixed_vector(3, 64);
102   const LLT v3s32 = LLT::fixed_vector(3, 32);
103   const LLT v3s16 = LLT::fixed_vector(3, 16);
104   const LLT v3s8 = LLT::fixed_vector(3, 8);
105   const LLT v3s1 = LLT::fixed_vector(3, 1);
106 
107   const LLT v2s64 = LLT::fixed_vector(2, 64);
108   const LLT v2s32 = LLT::fixed_vector(2, 32);
109   const LLT v2s16 = LLT::fixed_vector(2, 16);
110   const LLT v2s8 = LLT::fixed_vector(2, 8);
111   const LLT v2s1 = LLT::fixed_vector(2, 1);
112 
113   const unsigned PSize = ST.getPointerSize();
114   const LLT p0 = LLT::pointer(0, PSize); // Function
115   const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup
116   const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
117   const LLT p3 = LLT::pointer(3, PSize); // Workgroup
118   const LLT p4 = LLT::pointer(4, PSize); // Generic
119   const LLT p5 =
120       LLT::pointer(5, PSize); // Input, SPV_INTEL_usm_storage_classes (Device)
121   const LLT p6 = LLT::pointer(6, PSize); // SPV_INTEL_usm_storage_classes (Host)
122   const LLT p7 = LLT::pointer(7, PSize); // Input
123   const LLT p8 = LLT::pointer(8, PSize); // Output
124   const LLT p10 = LLT::pointer(10, PSize); // Private
125 
126   // TODO: remove copy-pasting here by using concatenation in some way.
127   auto allPtrsScalarsAndVectors = {
128       p0,   p1,   p2,    p3,    p4,    p5,    p6,    p7,     p8,     p10,
129       s1,   s8,   s16,   s32,   s64,   v2s1,  v2s8,  v2s16,  v2s32,  v2s64,
130       v3s1, v3s8, v3s16, v3s32, v3s64, v4s1,  v4s8,  v4s16,  v4s32,  v4s64,
131       v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
132 
133   auto allVectors = {v2s1,  v2s8,   v2s16,  v2s32, v2s64, v3s1,  v3s8,
134                      v3s16, v3s32,  v3s64,  v4s1,  v4s8,  v4s16, v4s32,
135                      v4s64, v8s1,   v8s8,   v8s16, v8s32, v8s64, v16s1,
136                      v16s8, v16s16, v16s32, v16s64};
137 
138   auto allScalarsAndVectors = {
139       s1,   s8,   s16,   s32,   s64,   v2s1,  v2s8,  v2s16,  v2s32,  v2s64,
140       v3s1, v3s8, v3s16, v3s32, v3s64, v4s1,  v4s8,  v4s16,  v4s32,  v4s64,
141       v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
142 
143   auto allIntScalarsAndVectors = {s8,    s16,   s32,   s64,    v2s8,   v2s16,
144                                   v2s32, v2s64, v3s8,  v3s16,  v3s32,  v3s64,
145                                   v4s8,  v4s16, v4s32, v4s64,  v8s8,   v8s16,
146                                   v8s32, v8s64, v16s8, v16s16, v16s32, v16s64};
147 
148   auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1};
149 
150   auto allIntScalars = {s8, s16, s32, s64};
151 
152   auto allFloatScalars = {s16, s32, s64};
153 
154   auto allFloatScalarsAndVectors = {
155       s16,   s32,   s64,   v2s16, v2s32, v2s64, v3s16,  v3s32,  v3s64,
156       v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
157 
158   auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0, p1, p2,
159                                        p3, p4,  p5,  p6,  p7, p8, p10};
160 
161   auto allPtrs = {p0, p1, p2, p3, p4, p5, p6, p7, p8, p10};
162 
163   bool IsExtendedInts =
164       ST.canUseExtension(
165           SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
166       ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
167   auto extendedScalarsAndVectors =
168       [IsExtendedInts](const LegalityQuery &Query) {
169         const LLT Ty = Query.Types[0];
170         return IsExtendedInts && Ty.isValid() && !Ty.isPointerOrPointerVector();
171       };
172   auto extendedScalarsAndVectorsProduct = [IsExtendedInts](
173                                               const LegalityQuery &Query) {
174     const LLT Ty1 = Query.Types[0], Ty2 = Query.Types[1];
175     return IsExtendedInts && Ty1.isValid() && Ty2.isValid() &&
176            !Ty1.isPointerOrPointerVector() && !Ty2.isPointerOrPointerVector();
177   };
178   auto extendedPtrsScalarsAndVectors =
179       [IsExtendedInts](const LegalityQuery &Query) {
180         const LLT Ty = Query.Types[0];
181         return IsExtendedInts && Ty.isValid();
182       };
183 
184   for (auto Opc : TypeFoldingSupportingOpcs)
185     getActionDefinitionsBuilder(Opc).custom();
186 
187   getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();
188 
189   // TODO: add proper rules for vectors legalization.
190   getActionDefinitionsBuilder(
191       {G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR})
192       .alwaysLegal();
193 
194   // Vector Reduction Operations
195   getActionDefinitionsBuilder(
196       {G_VECREDUCE_SMIN, G_VECREDUCE_SMAX, G_VECREDUCE_UMIN, G_VECREDUCE_UMAX,
197        G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN,
198        G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM,
199        G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR})
200       .legalFor(allVectors)
201       .scalarize(1)
202       .lower();
203 
204   getActionDefinitionsBuilder({G_VECREDUCE_SEQ_FADD, G_VECREDUCE_SEQ_FMUL})
205       .scalarize(2)
206       .lower();
207 
208   // Merge/Unmerge
209   // TODO: add proper legalization rules.
210   getActionDefinitionsBuilder(G_UNMERGE_VALUES).alwaysLegal();
211 
212   getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
213       .legalIf(all(typeInSet(0, allPtrs), typeInSet(1, allPtrs)));
214 
215   getActionDefinitionsBuilder(G_MEMSET).legalIf(
216       all(typeInSet(0, allPtrs), typeInSet(1, allIntScalars)));
217 
218   getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
219       .legalForCartesianProduct(allPtrs, allPtrs);
220 
221   getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs));
222 
223   getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS,
224                                G_BITREVERSE, G_SADDSAT, G_UADDSAT, G_SSUBSAT,
225                                G_USUBSAT, G_SCMP, G_UCMP})
226       .legalFor(allIntScalarsAndVectors)
227       .legalIf(extendedScalarsAndVectors);
228 
229   getActionDefinitionsBuilder({G_FMA, G_STRICT_FMA})
230       .legalFor(allFloatScalarsAndVectors);
231 
232   getActionDefinitionsBuilder(G_STRICT_FLDEXP)
233       .legalForCartesianProduct(allFloatScalarsAndVectors, allIntScalars);
234 
235   getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
236       .legalForCartesianProduct(allIntScalarsAndVectors,
237                                 allFloatScalarsAndVectors);
238 
239   getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
240       .legalForCartesianProduct(allFloatScalarsAndVectors,
241                                 allScalarsAndVectors);
242 
243   getActionDefinitionsBuilder(G_CTPOP)
244       .legalForCartesianProduct(allIntScalarsAndVectors)
245       .legalIf(extendedScalarsAndVectorsProduct);
246 
247   // Extensions.
248   getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})
249       .legalForCartesianProduct(allScalarsAndVectors)
250       .legalIf(extendedScalarsAndVectorsProduct);
251 
252   getActionDefinitionsBuilder(G_PHI)
253       .legalFor(allPtrsScalarsAndVectors)
254       .legalIf(extendedPtrsScalarsAndVectors);
255 
256   getActionDefinitionsBuilder(G_BITCAST).legalIf(
257       all(typeInSet(0, allPtrsScalarsAndVectors),
258           typeInSet(1, allPtrsScalarsAndVectors)));
259 
260   getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}).alwaysLegal();
261 
262   getActionDefinitionsBuilder({G_STACKSAVE, G_STACKRESTORE}).alwaysLegal();
263 
264   getActionDefinitionsBuilder(G_INTTOPTR)
265       .legalForCartesianProduct(allPtrs, allIntScalars)
266       .legalIf(
267           all(typeInSet(0, allPtrs), typeOfExtendedScalars(1, IsExtendedInts)));
268   getActionDefinitionsBuilder(G_PTRTOINT)
269       .legalForCartesianProduct(allIntScalars, allPtrs)
270       .legalIf(
271           all(typeOfExtendedScalars(0, IsExtendedInts), typeInSet(1, allPtrs)));
272   getActionDefinitionsBuilder(G_PTR_ADD)
273       .legalForCartesianProduct(allPtrs, allIntScalars)
274       .legalIf(
275           all(typeInSet(0, allPtrs), typeOfExtendedScalars(1, IsExtendedInts)));
276 
277   // ST.canDirectlyComparePointers() for pointer args is supported in
278   // legalizeCustom().
279   getActionDefinitionsBuilder(G_ICMP).customIf(
280       all(typeInSet(0, allBoolScalarsAndVectors),
281           typeInSet(1, allPtrsScalarsAndVectors)));
282 
283   getActionDefinitionsBuilder(G_FCMP).legalIf(
284       all(typeInSet(0, allBoolScalarsAndVectors),
285           typeInSet(1, allFloatScalarsAndVectors)));
286 
287   getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,
288                                G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,
289                                G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,
290                                G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
291       .legalForCartesianProduct(allIntScalars, allPtrs);
292 
293   getActionDefinitionsBuilder(
294       {G_ATOMICRMW_FADD, G_ATOMICRMW_FSUB, G_ATOMICRMW_FMIN, G_ATOMICRMW_FMAX})
295       .legalForCartesianProduct(allFloatScalars, allPtrs);
296 
297   getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
298       .legalForCartesianProduct(allFloatAndIntScalarsAndPtrs, allPtrs);
299 
300   getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
301   // TODO: add proper legalization rules.
302   getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal();
303 
304   getActionDefinitionsBuilder(
305       {G_UADDO, G_SADDO, G_USUBO, G_SSUBO, G_UMULO, G_SMULO})
306       .alwaysLegal();
307 
308   // FP conversions.
309   getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT})
310       .legalForCartesianProduct(allFloatScalarsAndVectors);
311 
312   // Pointer-handling.
313   getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
314 
315   // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
316   getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});
317 
318   // TODO: Review the target OpenCL and GLSL Extended Instruction Set specs to
319   // tighten these requirements. Many of these math functions are only legal on
320   // specific bitwidths, so they are not selectable for
321   // allFloatScalarsAndVectors.
322   getActionDefinitionsBuilder({G_STRICT_FSQRT,
323                                G_FPOW,
324                                G_FEXP,
325                                G_FEXP2,
326                                G_FLOG,
327                                G_FLOG2,
328                                G_FLOG10,
329                                G_FABS,
330                                G_FMINNUM,
331                                G_FMAXNUM,
332                                G_FCEIL,
333                                G_FCOS,
334                                G_FSIN,
335                                G_FTAN,
336                                G_FACOS,
337                                G_FASIN,
338                                G_FATAN,
339                                G_FATAN2,
340                                G_FCOSH,
341                                G_FSINH,
342                                G_FTANH,
343                                G_FSQRT,
344                                G_FFLOOR,
345                                G_FRINT,
346                                G_FNEARBYINT,
347                                G_INTRINSIC_ROUND,
348                                G_INTRINSIC_TRUNC,
349                                G_FMINIMUM,
350                                G_FMAXIMUM,
351                                G_INTRINSIC_ROUNDEVEN})
352       .legalFor(allFloatScalarsAndVectors);
353 
354   getActionDefinitionsBuilder(G_FCOPYSIGN)
355       .legalForCartesianProduct(allFloatScalarsAndVectors,
356                                 allFloatScalarsAndVectors);
357 
358   getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct(
359       allFloatScalarsAndVectors, allIntScalarsAndVectors);
360 
361   if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
362     getActionDefinitionsBuilder(
363         {G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF})
364         .legalForCartesianProduct(allIntScalarsAndVectors,
365                                   allIntScalarsAndVectors);
366 
367     // Struct return types become a single scalar, so cannot easily legalize.
368     getActionDefinitionsBuilder({G_SMULH, G_UMULH}).alwaysLegal();
369   }
370 
371   getLegacyLegalizerInfo().computeTables();
372   verify(*ST.getInstrInfo());
373 }
374 
375 static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpvType,
376                                 LegalizerHelper &Helper,
377                                 MachineRegisterInfo &MRI,
378                                 SPIRVGlobalRegistry *GR) {
379   Register ConvReg = MRI.createGenericVirtualRegister(ConvTy);
380   MRI.setRegClass(ConvReg, GR->getRegClass(SpvType));
381   GR->assignSPIRVTypeToVReg(SpvType, ConvReg, Helper.MIRBuilder.getMF());
382   Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
383       .addDef(ConvReg)
384       .addUse(Reg);
385   return ConvReg;
386 }
387 
388 bool SPIRVLegalizerInfo::legalizeCustom(
389     LegalizerHelper &Helper, MachineInstr &MI,
390     LostDebugLocObserver &LocObserver) const {
391   auto Opc = MI.getOpcode();
392   MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
393   if (!isTypeFoldingSupported(Opc)) {
394     assert(Opc == TargetOpcode::G_ICMP);
395     assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
396     auto &Op0 = MI.getOperand(2);
397     auto &Op1 = MI.getOperand(3);
398     Register Reg0 = Op0.getReg();
399     Register Reg1 = Op1.getReg();
400     CmpInst::Predicate Cond =
401         static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
402     if ((!ST->canDirectlyComparePointers() ||
403          (Cond != CmpInst::ICMP_EQ && Cond != CmpInst::ICMP_NE)) &&
404         MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) {
405       LLT ConvT = LLT::scalar(ST->getPointerSize());
406       Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(),
407                                       ST->getPointerSize());
408       SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder);
409       Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR));
410       Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR));
411     }
412     return true;
413   }
414   // TODO: implement legalization for other opcodes.
415   return true;
416 }
417