xref: /freebsd-src/contrib/llvm-project/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===-- RISCVRegisterBankInfo.cpp -------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// This file implements the targeting of the RegisterBankInfo class for RISC-V.
10 /// \todo This should be generated by TableGen.
11 //===----------------------------------------------------------------------===//
12 
13 #include "RISCVRegisterBankInfo.h"
14 #include "MCTargetDesc/RISCVMCTargetDesc.h"
15 #include "RISCVSubtarget.h"
16 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
17 #include "llvm/CodeGen/MachineRegisterInfo.h"
18 #include "llvm/CodeGen/RegisterBank.h"
19 #include "llvm/CodeGen/RegisterBankInfo.h"
20 #include "llvm/CodeGen/TargetRegisterInfo.h"
21 
22 #define GET_TARGET_REGBANK_IMPL
23 #include "RISCVGenRegisterBank.inc"
24 
25 namespace llvm {
26 namespace RISCV {
27 
28 const RegisterBankInfo::PartialMapping PartMappings[] = {
29     // clang-format off
30     {0, 32, GPRBRegBank},
31     {0, 64, GPRBRegBank},
32     {0, 16, FPRBRegBank},
33     {0, 32, FPRBRegBank},
34     {0, 64, FPRBRegBank},
35     {0, 64, VRBRegBank},
36     {0, 128, VRBRegBank},
37     {0, 256, VRBRegBank},
38     {0, 512, VRBRegBank},
39     // clang-format on
40 };
41 
42 enum PartialMappingIdx {
43   PMI_GPRB32 = 0,
44   PMI_GPRB64 = 1,
45   PMI_FPRB16 = 2,
46   PMI_FPRB32 = 3,
47   PMI_FPRB64 = 4,
48   PMI_VRB64 = 5,
49   PMI_VRB128 = 6,
50   PMI_VRB256 = 7,
51   PMI_VRB512 = 8,
52 };
53 
54 const RegisterBankInfo::ValueMapping ValueMappings[] = {
55     // Invalid value mapping.
56     {nullptr, 0},
57     // Maximum 3 GPR operands; 32 bit.
58     {&PartMappings[PMI_GPRB32], 1},
59     {&PartMappings[PMI_GPRB32], 1},
60     {&PartMappings[PMI_GPRB32], 1},
61     // Maximum 3 GPR operands; 64 bit.
62     {&PartMappings[PMI_GPRB64], 1},
63     {&PartMappings[PMI_GPRB64], 1},
64     {&PartMappings[PMI_GPRB64], 1},
65     // Maximum 3 FPR operands; 16 bit.
66     {&PartMappings[PMI_FPRB16], 1},
67     {&PartMappings[PMI_FPRB16], 1},
68     {&PartMappings[PMI_FPRB16], 1},
69     // Maximum 3 FPR operands; 32 bit.
70     {&PartMappings[PMI_FPRB32], 1},
71     {&PartMappings[PMI_FPRB32], 1},
72     {&PartMappings[PMI_FPRB32], 1},
73     // Maximum 3 FPR operands; 64 bit.
74     {&PartMappings[PMI_FPRB64], 1},
75     {&PartMappings[PMI_FPRB64], 1},
76     {&PartMappings[PMI_FPRB64], 1},
77     // Maximum 3 VR LMUL={1, MF2, MF4, MF8} operands.
78     {&PartMappings[PMI_VRB64], 1},
79     {&PartMappings[PMI_VRB64], 1},
80     {&PartMappings[PMI_VRB64], 1},
81     // Maximum 3 VR LMUL=2 operands.
82     {&PartMappings[PMI_VRB128], 1},
83     {&PartMappings[PMI_VRB128], 1},
84     {&PartMappings[PMI_VRB128], 1},
85     // Maximum 3 VR LMUL=4 operands.
86     {&PartMappings[PMI_VRB256], 1},
87     {&PartMappings[PMI_VRB256], 1},
88     {&PartMappings[PMI_VRB256], 1},
89     // Maximum 3 VR LMUL=8 operands.
90     {&PartMappings[PMI_VRB512], 1},
91     {&PartMappings[PMI_VRB512], 1},
92     {&PartMappings[PMI_VRB512], 1},
93 };
94 
95 enum ValueMappingIdx {
96   InvalidIdx = 0,
97   GPRB32Idx = 1,
98   GPRB64Idx = 4,
99   FPRB16Idx = 7,
100   FPRB32Idx = 10,
101   FPRB64Idx = 13,
102   VRB64Idx = 16,
103   VRB128Idx = 19,
104   VRB256Idx = 22,
105   VRB512Idx = 25,
106 };
107 } // namespace RISCV
108 } // namespace llvm
109 
110 using namespace llvm;
111 
112 RISCVRegisterBankInfo::RISCVRegisterBankInfo(unsigned HwMode)
113     : RISCVGenRegisterBankInfo(HwMode) {}
114 
115 const RegisterBank &
116 RISCVRegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC,
117                                               LLT Ty) const {
118   switch (RC.getID()) {
119   default:
120     llvm_unreachable("Register class not supported");
121   case RISCV::GPRRegClassID:
122   case RISCV::GPRF16RegClassID:
123   case RISCV::GPRF32RegClassID:
124   case RISCV::GPRNoX0RegClassID:
125   case RISCV::GPRNoX0X2RegClassID:
126   case RISCV::GPRJALRRegClassID:
127   case RISCV::GPRJALRNonX7RegClassID:
128   case RISCV::GPRTCRegClassID:
129   case RISCV::GPRTCNonX7RegClassID:
130   case RISCV::GPRC_and_GPRTCRegClassID:
131   case RISCV::GPRCRegClassID:
132   case RISCV::GPRC_and_SR07RegClassID:
133   case RISCV::SR07RegClassID:
134   case RISCV::SPRegClassID:
135   case RISCV::GPRX0RegClassID:
136     return getRegBank(RISCV::GPRBRegBankID);
137   case RISCV::FPR64RegClassID:
138   case RISCV::FPR16RegClassID:
139   case RISCV::FPR32RegClassID:
140   case RISCV::FPR64CRegClassID:
141   case RISCV::FPR32CRegClassID:
142     return getRegBank(RISCV::FPRBRegBankID);
143   case RISCV::VMRegClassID:
144   case RISCV::VRRegClassID:
145   case RISCV::VRNoV0RegClassID:
146   case RISCV::VRM2RegClassID:
147   case RISCV::VRM2NoV0RegClassID:
148   case RISCV::VRM4RegClassID:
149   case RISCV::VRM4NoV0RegClassID:
150   case RISCV::VMV0RegClassID:
151   case RISCV::VRM2_with_sub_vrm1_0_in_VMV0RegClassID:
152   case RISCV::VRM4_with_sub_vrm1_0_in_VMV0RegClassID:
153   case RISCV::VRM8RegClassID:
154   case RISCV::VRM8NoV0RegClassID:
155   case RISCV::VRM8_with_sub_vrm1_0_in_VMV0RegClassID:
156     return getRegBank(RISCV::VRBRegBankID);
157   }
158 }
159 
160 static const RegisterBankInfo::ValueMapping *getFPValueMapping(unsigned Size) {
161   unsigned Idx;
162   switch (Size) {
163   default:
164     llvm_unreachable("Unexpected size");
165   case 16:
166     Idx = RISCV::FPRB16Idx;
167     break;
168   case 32:
169     Idx = RISCV::FPRB32Idx;
170     break;
171   case 64:
172     Idx = RISCV::FPRB64Idx;
173     break;
174   }
175   return &RISCV::ValueMappings[Idx];
176 }
177 
178 // TODO: Make this more like AArch64?
179 bool RISCVRegisterBankInfo::hasFPConstraints(
180     const MachineInstr &MI, const MachineRegisterInfo &MRI,
181     const TargetRegisterInfo &TRI) const {
182   if (isPreISelGenericFloatingPointOpcode(MI.getOpcode()))
183     return true;
184 
185   // If we have a copy instruction, we could be feeding floating point
186   // instructions.
187   if (MI.getOpcode() != TargetOpcode::COPY)
188     return false;
189 
190   return getRegBank(MI.getOperand(0).getReg(), MRI, TRI) == &RISCV::FPRBRegBank;
191 }
192 
193 bool RISCVRegisterBankInfo::onlyUsesFP(const MachineInstr &MI,
194                                        const MachineRegisterInfo &MRI,
195                                        const TargetRegisterInfo &TRI) const {
196   switch (MI.getOpcode()) {
197   case TargetOpcode::G_FPTOSI:
198   case TargetOpcode::G_FPTOUI:
199   case TargetOpcode::G_FCMP:
200     return true;
201   default:
202     break;
203   }
204 
205   return hasFPConstraints(MI, MRI, TRI);
206 }
207 
208 bool RISCVRegisterBankInfo::onlyDefinesFP(const MachineInstr &MI,
209                                           const MachineRegisterInfo &MRI,
210                                           const TargetRegisterInfo &TRI) const {
211   switch (MI.getOpcode()) {
212   case TargetOpcode::G_SITOFP:
213   case TargetOpcode::G_UITOFP:
214     return true;
215   default:
216     break;
217   }
218 
219   return hasFPConstraints(MI, MRI, TRI);
220 }
221 
222 bool RISCVRegisterBankInfo::anyUseOnlyUseFP(
223     Register Def, const MachineRegisterInfo &MRI,
224     const TargetRegisterInfo &TRI) const {
225   return any_of(
226       MRI.use_nodbg_instructions(Def),
227       [&](const MachineInstr &UseMI) { return onlyUsesFP(UseMI, MRI, TRI); });
228 }
229 
230 static const RegisterBankInfo::ValueMapping *getVRBValueMapping(unsigned Size) {
231   unsigned Idx;
232 
233   if (Size <= 64)
234     Idx = RISCV::VRB64Idx;
235   else if (Size == 128)
236     Idx = RISCV::VRB128Idx;
237   else if (Size == 256)
238     Idx = RISCV::VRB256Idx;
239   else if (Size == 512)
240     Idx = RISCV::VRB512Idx;
241   else
242     llvm::report_fatal_error("Invalid Size");
243 
244   return &RISCV::ValueMappings[Idx];
245 }
246 
247 const RegisterBankInfo::InstructionMapping &
248 RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
249   const unsigned Opc = MI.getOpcode();
250 
251   // Try the default logic for non-generic instructions that are either copies
252   // or already have some operands assigned to banks.
253   if (!isPreISelGenericOpcode(Opc) || Opc == TargetOpcode::G_PHI) {
254     const InstructionMapping &Mapping = getInstrMappingImpl(MI);
255     if (Mapping.isValid())
256       return Mapping;
257   }
258 
259   const MachineFunction &MF = *MI.getParent()->getParent();
260   const MachineRegisterInfo &MRI = MF.getRegInfo();
261   const TargetSubtargetInfo &STI = MF.getSubtarget();
262   const TargetRegisterInfo &TRI = *STI.getRegisterInfo();
263 
264   unsigned GPRSize = getMaximumSize(RISCV::GPRBRegBankID);
265   assert((GPRSize == 32 || GPRSize == 64) && "Unexpected GPR size");
266 
267   unsigned NumOperands = MI.getNumOperands();
268   const ValueMapping *GPRValueMapping =
269       &RISCV::ValueMappings[GPRSize == 64 ? RISCV::GPRB64Idx
270                                           : RISCV::GPRB32Idx];
271 
272   switch (Opc) {
273   case TargetOpcode::G_ADD:
274   case TargetOpcode::G_SUB:
275   case TargetOpcode::G_SHL:
276   case TargetOpcode::G_ASHR:
277   case TargetOpcode::G_LSHR:
278   case TargetOpcode::G_AND:
279   case TargetOpcode::G_OR:
280   case TargetOpcode::G_XOR:
281   case TargetOpcode::G_MUL:
282   case TargetOpcode::G_SDIV:
283   case TargetOpcode::G_SREM:
284   case TargetOpcode::G_SMULH:
285   case TargetOpcode::G_SMAX:
286   case TargetOpcode::G_SMIN:
287   case TargetOpcode::G_UDIV:
288   case TargetOpcode::G_UREM:
289   case TargetOpcode::G_UMULH:
290   case TargetOpcode::G_UMAX:
291   case TargetOpcode::G_UMIN:
292   case TargetOpcode::G_PTR_ADD:
293   case TargetOpcode::G_PTRTOINT:
294   case TargetOpcode::G_INTTOPTR:
295   case TargetOpcode::G_FADD:
296   case TargetOpcode::G_FSUB:
297   case TargetOpcode::G_FMUL:
298   case TargetOpcode::G_FDIV:
299   case TargetOpcode::G_FABS:
300   case TargetOpcode::G_FNEG:
301   case TargetOpcode::G_FSQRT:
302   case TargetOpcode::G_FMAXNUM:
303   case TargetOpcode::G_FMINNUM: {
304     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
305     TypeSize Size = Ty.getSizeInBits();
306 
307     const ValueMapping *Mapping;
308     if (Ty.isVector())
309       Mapping = getVRBValueMapping(Size.getKnownMinValue());
310     else if (isPreISelGenericFloatingPointOpcode(Opc))
311       Mapping = getFPValueMapping(Size.getFixedValue());
312     else
313       Mapping = GPRValueMapping;
314 
315 #ifndef NDEBUG
316     // Make sure all the operands are using similar size and type.
317     for (unsigned Idx = 1; Idx != NumOperands; ++Idx) {
318       LLT OpTy = MRI.getType(MI.getOperand(Idx).getReg());
319       assert(Ty.isVector() == OpTy.isVector() &&
320              "Operand has incompatible type");
321       // Don't check size for GPR.
322       if (OpTy.isVector() || isPreISelGenericFloatingPointOpcode(Opc))
323         assert(Size == OpTy.getSizeInBits() && "Operand has incompatible size");
324     }
325 #endif // End NDEBUG
326 
327     return getInstructionMapping(DefaultMappingID, 1, Mapping, NumOperands);
328   }
329   case TargetOpcode::G_SEXTLOAD:
330   case TargetOpcode::G_ZEXTLOAD:
331     return getInstructionMapping(DefaultMappingID, /*Cost=*/1, GPRValueMapping,
332                                  NumOperands);
333   case TargetOpcode::G_IMPLICIT_DEF: {
334     Register Dst = MI.getOperand(0).getReg();
335     LLT DstTy = MRI.getType(Dst);
336     unsigned DstMinSize = DstTy.getSizeInBits().getKnownMinValue();
337     auto Mapping = GPRValueMapping;
338     // FIXME: May need to do a better job determining when to use FPRB.
339     // For example, the look through COPY case:
340     // %0:_(s32) = G_IMPLICIT_DEF
341     // %1:_(s32) = COPY %0
342     // $f10_d = COPY %1(s32)
343     if (DstTy.isVector())
344       Mapping = getVRBValueMapping(DstMinSize);
345     else if (anyUseOnlyUseFP(Dst, MRI, TRI))
346       Mapping = getFPValueMapping(DstMinSize);
347 
348     return getInstructionMapping(DefaultMappingID, /*Cost=*/1, Mapping,
349                                  NumOperands);
350   }
351   }
352 
353   SmallVector<const ValueMapping *, 4> OpdsMapping(NumOperands);
354 
355   switch (Opc) {
356   case TargetOpcode::G_LOAD: {
357     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
358     OpdsMapping[0] = GPRValueMapping;
359     OpdsMapping[1] = GPRValueMapping;
360     // Use FPR64 for s64 loads on rv32.
361     if (GPRSize == 32 && Ty.getSizeInBits() == 64) {
362       assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
363       OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
364       break;
365     }
366 
367     // Check if that load feeds fp instructions.
368     // In that case, we want the default mapping to be on FPR
369     // instead of blind map every scalar to GPR.
370     if (anyUseOnlyUseFP(MI.getOperand(0).getReg(), MRI, TRI))
371       // If we have at least one direct use in a FP instruction,
372       // assume this was a floating point load in the IR. If it was
373       // not, we would have had a bitcast before reaching that
374       // instruction.
375       OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
376 
377     break;
378   }
379   case TargetOpcode::G_STORE: {
380     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
381     OpdsMapping[0] = GPRValueMapping;
382     OpdsMapping[1] = GPRValueMapping;
383     // Use FPR64 for s64 stores on rv32.
384     if (GPRSize == 32 && Ty.getSizeInBits() == 64) {
385       assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
386       OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
387       break;
388     }
389 
390     MachineInstr *DefMI = MRI.getVRegDef(MI.getOperand(0).getReg());
391     if (onlyDefinesFP(*DefMI, MRI, TRI))
392       OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
393     break;
394   }
395   case TargetOpcode::G_SELECT: {
396     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
397 
398     if (Ty.isVector()) {
399       auto &Sel = cast<GSelect>(MI);
400       LLT TestTy = MRI.getType(Sel.getCondReg());
401       assert(TestTy.isVector() && "Unexpected condition argument type");
402       OpdsMapping[0] = OpdsMapping[2] = OpdsMapping[3] =
403           getVRBValueMapping(Ty.getSizeInBits().getKnownMinValue());
404       OpdsMapping[1] =
405           getVRBValueMapping(TestTy.getSizeInBits().getKnownMinValue());
406       break;
407     }
408 
409     // Try to minimize the number of copies. If we have more floating point
410     // constrained values than not, then we'll put everything on FPR. Otherwise,
411     // everything has to be on GPR.
412     unsigned NumFP = 0;
413 
414     // Use FPR64 for s64 select on rv32.
415     if (GPRSize == 32 && Ty.getSizeInBits() == 64) {
416       NumFP = 3;
417     } else {
418       // Check if the uses of the result always produce floating point values.
419       //
420       // For example:
421       //
422       // %z = G_SELECT %cond %x %y
423       // fpr = G_FOO %z ...
424       if (any_of(MRI.use_nodbg_instructions(MI.getOperand(0).getReg()),
425                  [&](const MachineInstr &UseMI) {
426                    return onlyUsesFP(UseMI, MRI, TRI);
427                  }))
428         ++NumFP;
429 
430       // Check if the defs of the source values always produce floating point
431       // values.
432       //
433       // For example:
434       //
435       // %x = G_SOMETHING_ALWAYS_FLOAT %a ...
436       // %z = G_SELECT %cond %x %y
437       //
438       // Also check whether or not the sources have already been decided to be
439       // FPR. Keep track of this.
440       //
441       // This doesn't check the condition, since the condition is always an
442       // integer.
443       for (unsigned Idx = 2; Idx < 4; ++Idx) {
444         Register VReg = MI.getOperand(Idx).getReg();
445         MachineInstr *DefMI = MRI.getVRegDef(VReg);
446         if (getRegBank(VReg, MRI, TRI) == &RISCV::FPRBRegBank ||
447             onlyDefinesFP(*DefMI, MRI, TRI))
448           ++NumFP;
449       }
450     }
451 
452     // Condition operand is always GPR.
453     OpdsMapping[1] = GPRValueMapping;
454 
455     const ValueMapping *Mapping = GPRValueMapping;
456     if (NumFP >= 2)
457       Mapping = getFPValueMapping(Ty.getSizeInBits());
458 
459     OpdsMapping[0] = OpdsMapping[2] = OpdsMapping[3] = Mapping;
460     break;
461   }
462   case TargetOpcode::G_FPTOSI:
463   case TargetOpcode::G_FPTOUI:
464   case RISCV::G_FCLASS: {
465     LLT Ty = MRI.getType(MI.getOperand(1).getReg());
466     OpdsMapping[0] = GPRValueMapping;
467     OpdsMapping[1] = getFPValueMapping(Ty.getSizeInBits());
468     break;
469   }
470   case TargetOpcode::G_SITOFP:
471   case TargetOpcode::G_UITOFP: {
472     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
473     OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
474     OpdsMapping[1] = GPRValueMapping;
475     break;
476   }
477   case TargetOpcode::G_FCMP: {
478     LLT Ty = MRI.getType(MI.getOperand(2).getReg());
479 
480     unsigned Size = Ty.getSizeInBits();
481 
482     OpdsMapping[0] = GPRValueMapping;
483     OpdsMapping[2] = OpdsMapping[3] = getFPValueMapping(Size);
484     break;
485   }
486   case TargetOpcode::G_MERGE_VALUES: {
487     // Use FPR64 for s64 merge on rv32.
488     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
489     if (GPRSize == 32 && Ty.getSizeInBits() == 64) {
490       assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
491       OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
492       OpdsMapping[1] = GPRValueMapping;
493       OpdsMapping[2] = GPRValueMapping;
494     }
495     break;
496   }
497   case TargetOpcode::G_UNMERGE_VALUES: {
498     // Use FPR64 for s64 unmerge on rv32.
499     LLT Ty = MRI.getType(MI.getOperand(2).getReg());
500     if (GPRSize == 32 && Ty.getSizeInBits() == 64) {
501       assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
502       OpdsMapping[0] = GPRValueMapping;
503       OpdsMapping[1] = GPRValueMapping;
504       OpdsMapping[2] = getFPValueMapping(Ty.getSizeInBits());
505     }
506     break;
507   }
508   default:
509     // By default map all scalars to GPR.
510     for (unsigned Idx = 0; Idx < NumOperands; ++Idx) {
511        auto &MO = MI.getOperand(Idx);
512        if (!MO.isReg() || !MO.getReg())
513          continue;
514        LLT Ty = MRI.getType(MO.getReg());
515        if (!Ty.isValid())
516          continue;
517 
518        if (Ty.isVector())
519          OpdsMapping[Idx] =
520              getVRBValueMapping(Ty.getSizeInBits().getKnownMinValue());
521        else if (isPreISelGenericFloatingPointOpcode(Opc))
522          OpdsMapping[Idx] = getFPValueMapping(Ty.getSizeInBits());
523        else
524          OpdsMapping[Idx] = GPRValueMapping;
525     }
526     break;
527   }
528 
529   return getInstructionMapping(DefaultMappingID, /*Cost=*/1,
530                                getOperandsMapping(OpdsMapping), NumOperands);
531 }
532