xref: /llvm-project/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp (revision de6d0d2de0e2df72bd77f29d27addf13ebfbc997)
1 //===-- RISCVRegisterBankInfo.cpp -------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// This file implements the targeting of the RegisterBankInfo class for RISC-V.
10 /// \todo This should be generated by TableGen.
11 //===----------------------------------------------------------------------===//
12 
13 #include "RISCVRegisterBankInfo.h"
14 #include "MCTargetDesc/RISCVMCTargetDesc.h"
15 #include "RISCVSubtarget.h"
16 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
17 #include "llvm/CodeGen/MachineRegisterInfo.h"
18 #include "llvm/CodeGen/RegisterBank.h"
19 #include "llvm/CodeGen/RegisterBankInfo.h"
20 #include "llvm/CodeGen/TargetRegisterInfo.h"
21 
22 #define GET_TARGET_REGBANK_IMPL
23 #include "RISCVGenRegisterBank.inc"
24 
25 namespace llvm {
26 namespace RISCV {
27 
28 const RegisterBankInfo::PartialMapping PartMappings[] = {
29     // clang-format off
30     {0, 32, GPRBRegBank},
31     {0, 64, GPRBRegBank},
32     {0, 16, FPRBRegBank},
33     {0, 32, FPRBRegBank},
34     {0, 64, FPRBRegBank},
35     {0, 64, VRBRegBank},
36     {0, 128, VRBRegBank},
37     {0, 256, VRBRegBank},
38     {0, 512, VRBRegBank},
39     // clang-format on
40 };
41 
42 enum PartialMappingIdx {
43   PMI_GPRB32 = 0,
44   PMI_GPRB64 = 1,
45   PMI_FPRB16 = 2,
46   PMI_FPRB32 = 3,
47   PMI_FPRB64 = 4,
48   PMI_VRB64 = 5,
49   PMI_VRB128 = 6,
50   PMI_VRB256 = 7,
51   PMI_VRB512 = 8,
52 };
53 
54 const RegisterBankInfo::ValueMapping ValueMappings[] = {
55     // Invalid value mapping.
56     {nullptr, 0},
57     // Maximum 3 GPR operands; 32 bit.
58     {&PartMappings[PMI_GPRB32], 1},
59     {&PartMappings[PMI_GPRB32], 1},
60     {&PartMappings[PMI_GPRB32], 1},
61     // Maximum 3 GPR operands; 64 bit.
62     {&PartMappings[PMI_GPRB64], 1},
63     {&PartMappings[PMI_GPRB64], 1},
64     {&PartMappings[PMI_GPRB64], 1},
65     // Maximum 3 FPR operands; 16 bit.
66     {&PartMappings[PMI_FPRB16], 1},
67     {&PartMappings[PMI_FPRB16], 1},
68     {&PartMappings[PMI_FPRB16], 1},
69     // Maximum 3 FPR operands; 32 bit.
70     {&PartMappings[PMI_FPRB32], 1},
71     {&PartMappings[PMI_FPRB32], 1},
72     {&PartMappings[PMI_FPRB32], 1},
73     // Maximum 3 FPR operands; 64 bit.
74     {&PartMappings[PMI_FPRB64], 1},
75     {&PartMappings[PMI_FPRB64], 1},
76     {&PartMappings[PMI_FPRB64], 1},
77     // Maximum 3 VR LMUL={1, MF2, MF4, MF8} operands.
78     {&PartMappings[PMI_VRB64], 1},
79     {&PartMappings[PMI_VRB64], 1},
80     {&PartMappings[PMI_VRB64], 1},
81     // Maximum 3 VR LMUL=2 operands.
82     {&PartMappings[PMI_VRB128], 1},
83     {&PartMappings[PMI_VRB128], 1},
84     {&PartMappings[PMI_VRB128], 1},
85     // Maximum 3 VR LMUL=4 operands.
86     {&PartMappings[PMI_VRB256], 1},
87     {&PartMappings[PMI_VRB256], 1},
88     {&PartMappings[PMI_VRB256], 1},
89     // Maximum 3 VR LMUL=8 operands.
90     {&PartMappings[PMI_VRB512], 1},
91     {&PartMappings[PMI_VRB512], 1},
92     {&PartMappings[PMI_VRB512], 1},
93 };
94 
95 enum ValueMappingIdx {
96   InvalidIdx = 0,
97   GPRB32Idx = 1,
98   GPRB64Idx = 4,
99   FPRB16Idx = 7,
100   FPRB32Idx = 10,
101   FPRB64Idx = 13,
102   VRB64Idx = 16,
103   VRB128Idx = 19,
104   VRB256Idx = 22,
105   VRB512Idx = 25,
106 };
107 } // namespace RISCV
108 } // namespace llvm
109 
110 using namespace llvm;
111 
112 RISCVRegisterBankInfo::RISCVRegisterBankInfo(unsigned HwMode)
113     : RISCVGenRegisterBankInfo(HwMode) {}
114 
115 static const RegisterBankInfo::ValueMapping *getFPValueMapping(unsigned Size) {
116   unsigned Idx;
117   switch (Size) {
118   default:
119     llvm_unreachable("Unexpected size");
120   case 16:
121     Idx = RISCV::FPRB16Idx;
122     break;
123   case 32:
124     Idx = RISCV::FPRB32Idx;
125     break;
126   case 64:
127     Idx = RISCV::FPRB64Idx;
128     break;
129   }
130   return &RISCV::ValueMappings[Idx];
131 }
132 
133 // TODO: Make this more like AArch64?
134 bool RISCVRegisterBankInfo::hasFPConstraints(
135     const MachineInstr &MI, const MachineRegisterInfo &MRI,
136     const TargetRegisterInfo &TRI) const {
137   if (isPreISelGenericFloatingPointOpcode(MI.getOpcode()))
138     return true;
139 
140   // If we have a copy instruction, we could be feeding floating point
141   // instructions.
142   if (MI.getOpcode() != TargetOpcode::COPY)
143     return false;
144 
145   return getRegBank(MI.getOperand(0).getReg(), MRI, TRI) == &RISCV::FPRBRegBank;
146 }
147 
148 bool RISCVRegisterBankInfo::onlyUsesFP(const MachineInstr &MI,
149                                        const MachineRegisterInfo &MRI,
150                                        const TargetRegisterInfo &TRI) const {
151   switch (MI.getOpcode()) {
152   case RISCV::G_FCVT_W_RV64:
153   case RISCV::G_FCVT_WU_RV64:
154   case RISCV::G_FCLASS:
155   case TargetOpcode::G_FPTOSI:
156   case TargetOpcode::G_FPTOUI:
157   case TargetOpcode::G_FCMP:
158     return true;
159   default:
160     break;
161   }
162 
163   return hasFPConstraints(MI, MRI, TRI);
164 }
165 
166 bool RISCVRegisterBankInfo::onlyDefinesFP(const MachineInstr &MI,
167                                           const MachineRegisterInfo &MRI,
168                                           const TargetRegisterInfo &TRI) const {
169   switch (MI.getOpcode()) {
170   case TargetOpcode::G_SITOFP:
171   case TargetOpcode::G_UITOFP:
172     return true;
173   default:
174     break;
175   }
176 
177   return hasFPConstraints(MI, MRI, TRI);
178 }
179 
180 bool RISCVRegisterBankInfo::anyUseOnlyUseFP(
181     Register Def, const MachineRegisterInfo &MRI,
182     const TargetRegisterInfo &TRI) const {
183   return any_of(
184       MRI.use_nodbg_instructions(Def),
185       [&](const MachineInstr &UseMI) { return onlyUsesFP(UseMI, MRI, TRI); });
186 }
187 
188 static const RegisterBankInfo::ValueMapping *getVRBValueMapping(unsigned Size) {
189   unsigned Idx;
190 
191   if (Size <= 64)
192     Idx = RISCV::VRB64Idx;
193   else if (Size == 128)
194     Idx = RISCV::VRB128Idx;
195   else if (Size == 256)
196     Idx = RISCV::VRB256Idx;
197   else if (Size == 512)
198     Idx = RISCV::VRB512Idx;
199   else
200     llvm::report_fatal_error("Invalid Size");
201 
202   return &RISCV::ValueMappings[Idx];
203 }
204 
205 const RegisterBankInfo::InstructionMapping &
206 RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
207   const unsigned Opc = MI.getOpcode();
208 
209   // Try the default logic for non-generic instructions that are either copies
210   // or already have some operands assigned to banks.
211   if (!isPreISelGenericOpcode(Opc) || Opc == TargetOpcode::G_PHI) {
212     const InstructionMapping &Mapping = getInstrMappingImpl(MI);
213     if (Mapping.isValid())
214       return Mapping;
215   }
216 
217   const MachineFunction &MF = *MI.getParent()->getParent();
218   const MachineRegisterInfo &MRI = MF.getRegInfo();
219   const TargetSubtargetInfo &STI = MF.getSubtarget();
220   const TargetRegisterInfo &TRI = *STI.getRegisterInfo();
221 
222   unsigned GPRSize = getMaximumSize(RISCV::GPRBRegBankID);
223   assert((GPRSize == 32 || GPRSize == 64) && "Unexpected GPR size");
224 
225   unsigned NumOperands = MI.getNumOperands();
226   const ValueMapping *GPRValueMapping =
227       &RISCV::ValueMappings[GPRSize == 64 ? RISCV::GPRB64Idx
228                                           : RISCV::GPRB32Idx];
229 
230   switch (Opc) {
231   case TargetOpcode::G_ADD:
232   case TargetOpcode::G_SUB:
233   case TargetOpcode::G_SHL:
234   case TargetOpcode::G_ASHR:
235   case TargetOpcode::G_LSHR:
236   case TargetOpcode::G_AND:
237   case TargetOpcode::G_OR:
238   case TargetOpcode::G_XOR:
239   case TargetOpcode::G_MUL:
240   case TargetOpcode::G_SDIV:
241   case TargetOpcode::G_SREM:
242   case TargetOpcode::G_SMULH:
243   case TargetOpcode::G_SMAX:
244   case TargetOpcode::G_SMIN:
245   case TargetOpcode::G_UDIV:
246   case TargetOpcode::G_UREM:
247   case TargetOpcode::G_UMULH:
248   case TargetOpcode::G_UMAX:
249   case TargetOpcode::G_UMIN:
250   case TargetOpcode::G_PTR_ADD:
251   case TargetOpcode::G_PTRTOINT:
252   case TargetOpcode::G_INTTOPTR:
253   case TargetOpcode::G_FADD:
254   case TargetOpcode::G_FSUB:
255   case TargetOpcode::G_FMUL:
256   case TargetOpcode::G_FDIV:
257   case TargetOpcode::G_FABS:
258   case TargetOpcode::G_FNEG:
259   case TargetOpcode::G_FSQRT:
260   case TargetOpcode::G_FMAXNUM:
261   case TargetOpcode::G_FMINNUM: {
262     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
263     TypeSize Size = Ty.getSizeInBits();
264 
265     const ValueMapping *Mapping;
266     if (Ty.isVector())
267       Mapping = getVRBValueMapping(Size.getKnownMinValue());
268     else if (isPreISelGenericFloatingPointOpcode(Opc))
269       Mapping = getFPValueMapping(Size.getFixedValue());
270     else
271       Mapping = GPRValueMapping;
272 
273 #ifndef NDEBUG
274     // Make sure all the operands are using similar size and type.
275     for (unsigned Idx = 1; Idx != NumOperands; ++Idx) {
276       LLT OpTy = MRI.getType(MI.getOperand(Idx).getReg());
277       assert(Ty.isVector() == OpTy.isVector() &&
278              "Operand has incompatible type");
279       // Don't check size for GPR.
280       if (OpTy.isVector() || isPreISelGenericFloatingPointOpcode(Opc))
281         assert(Size == OpTy.getSizeInBits() && "Operand has incompatible size");
282     }
283 #endif // End NDEBUG
284 
285     return getInstructionMapping(DefaultMappingID, 1, Mapping, NumOperands);
286   }
287   case TargetOpcode::G_SEXTLOAD:
288   case TargetOpcode::G_ZEXTLOAD:
289     return getInstructionMapping(DefaultMappingID, /*Cost=*/1, GPRValueMapping,
290                                  NumOperands);
291   case TargetOpcode::G_IMPLICIT_DEF: {
292     Register Dst = MI.getOperand(0).getReg();
293     LLT DstTy = MRI.getType(Dst);
294     unsigned DstMinSize = DstTy.getSizeInBits().getKnownMinValue();
295     auto Mapping = GPRValueMapping;
296     // FIXME: May need to do a better job determining when to use FPRB.
297     // For example, the look through COPY case:
298     // %0:_(s32) = G_IMPLICIT_DEF
299     // %1:_(s32) = COPY %0
300     // $f10_d = COPY %1(s32)
301     if (DstTy.isVector())
302       Mapping = getVRBValueMapping(DstMinSize);
303     else if (anyUseOnlyUseFP(Dst, MRI, TRI))
304       Mapping = getFPValueMapping(DstMinSize);
305 
306     return getInstructionMapping(DefaultMappingID, /*Cost=*/1, Mapping,
307                                  NumOperands);
308   }
309   }
310 
311   SmallVector<const ValueMapping *, 4> OpdsMapping(NumOperands);
312 
313   switch (Opc) {
314   case TargetOpcode::G_LOAD: {
315     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
316     TypeSize Size = Ty.getSizeInBits();
317 
318     OpdsMapping[1] = GPRValueMapping;
319 
320     if (Ty.isVector()) {
321       OpdsMapping[0] = getVRBValueMapping(Size.getKnownMinValue());
322       break;
323     }
324 
325     OpdsMapping[0] = GPRValueMapping;
326 
327     // Use FPR64 for s64 loads on rv32.
328     if (GPRSize == 32 && Size.getFixedValue() == 64) {
329       assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
330       OpdsMapping[0] = getFPValueMapping(Size);
331       break;
332     }
333 
334     // Check if that load feeds fp instructions.
335     // In that case, we want the default mapping to be on FPR
336     // instead of blind map every scalar to GPR.
337     if (anyUseOnlyUseFP(MI.getOperand(0).getReg(), MRI, TRI)) {
338       // If we have at least one direct use in a FP instruction,
339       // assume this was a floating point load in the IR. If it was
340       // not, we would have had a bitcast before reaching that
341       // instruction.
342       OpdsMapping[0] = getFPValueMapping(Size);
343       break;
344     }
345 
346     break;
347   }
348   case TargetOpcode::G_STORE: {
349     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
350     TypeSize Size = Ty.getSizeInBits();
351 
352     OpdsMapping[1] = GPRValueMapping;
353 
354     if (Ty.isVector()) {
355       OpdsMapping[0] = getVRBValueMapping(Size.getKnownMinValue());
356       break;
357     }
358 
359     OpdsMapping[0] = GPRValueMapping;
360 
361     // Use FPR64 for s64 stores on rv32.
362     if (GPRSize == 32 && Size.getFixedValue() == 64) {
363       assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
364       OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
365       break;
366     }
367 
368     MachineInstr *DefMI = MRI.getVRegDef(MI.getOperand(0).getReg());
369     if (onlyDefinesFP(*DefMI, MRI, TRI))
370       OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
371     break;
372   }
373   case TargetOpcode::G_SELECT: {
374     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
375 
376     if (Ty.isVector()) {
377       auto &Sel = cast<GSelect>(MI);
378       LLT TestTy = MRI.getType(Sel.getCondReg());
379       assert(TestTy.isVector() && "Unexpected condition argument type");
380       OpdsMapping[0] = OpdsMapping[2] = OpdsMapping[3] =
381           getVRBValueMapping(Ty.getSizeInBits().getKnownMinValue());
382       OpdsMapping[1] =
383           getVRBValueMapping(TestTy.getSizeInBits().getKnownMinValue());
384       break;
385     }
386 
387     // Try to minimize the number of copies. If we have more floating point
388     // constrained values than not, then we'll put everything on FPR. Otherwise,
389     // everything has to be on GPR.
390     unsigned NumFP = 0;
391 
392     // Use FPR64 for s64 select on rv32.
393     if (GPRSize == 32 && Ty.getSizeInBits() == 64) {
394       NumFP = 3;
395     } else {
396       // Check if the uses of the result always produce floating point values.
397       //
398       // For example:
399       //
400       // %z = G_SELECT %cond %x %y
401       // fpr = G_FOO %z ...
402       if (any_of(MRI.use_nodbg_instructions(MI.getOperand(0).getReg()),
403                  [&](const MachineInstr &UseMI) {
404                    return onlyUsesFP(UseMI, MRI, TRI);
405                  }))
406         ++NumFP;
407 
408       // Check if the defs of the source values always produce floating point
409       // values.
410       //
411       // For example:
412       //
413       // %x = G_SOMETHING_ALWAYS_FLOAT %a ...
414       // %z = G_SELECT %cond %x %y
415       //
416       // Also check whether or not the sources have already been decided to be
417       // FPR. Keep track of this.
418       //
419       // This doesn't check the condition, since the condition is always an
420       // integer.
421       for (unsigned Idx = 2; Idx < 4; ++Idx) {
422         Register VReg = MI.getOperand(Idx).getReg();
423         MachineInstr *DefMI = MRI.getVRegDef(VReg);
424         if (getRegBank(VReg, MRI, TRI) == &RISCV::FPRBRegBank ||
425             onlyDefinesFP(*DefMI, MRI, TRI))
426           ++NumFP;
427       }
428     }
429 
430     // Condition operand is always GPR.
431     OpdsMapping[1] = GPRValueMapping;
432 
433     const ValueMapping *Mapping = GPRValueMapping;
434     if (NumFP >= 2)
435       Mapping = getFPValueMapping(Ty.getSizeInBits());
436 
437     OpdsMapping[0] = OpdsMapping[2] = OpdsMapping[3] = Mapping;
438     break;
439   }
440   case RISCV::G_FCVT_W_RV64:
441   case RISCV::G_FCVT_WU_RV64:
442   case TargetOpcode::G_FPTOSI:
443   case TargetOpcode::G_FPTOUI:
444   case RISCV::G_FCLASS: {
445     LLT Ty = MRI.getType(MI.getOperand(1).getReg());
446     OpdsMapping[0] = GPRValueMapping;
447     OpdsMapping[1] = getFPValueMapping(Ty.getSizeInBits());
448     break;
449   }
450   case TargetOpcode::G_SITOFP:
451   case TargetOpcode::G_UITOFP: {
452     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
453     OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
454     OpdsMapping[1] = GPRValueMapping;
455     break;
456   }
457   case TargetOpcode::G_FCMP: {
458     LLT Ty = MRI.getType(MI.getOperand(2).getReg());
459 
460     unsigned Size = Ty.getSizeInBits();
461 
462     OpdsMapping[0] = GPRValueMapping;
463     OpdsMapping[2] = OpdsMapping[3] = getFPValueMapping(Size);
464     break;
465   }
466   case TargetOpcode::G_MERGE_VALUES: {
467     // Use FPR64 for s64 merge on rv32.
468     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
469     if (GPRSize == 32 && Ty.getSizeInBits() == 64) {
470       assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
471       OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
472       OpdsMapping[1] = GPRValueMapping;
473       OpdsMapping[2] = GPRValueMapping;
474     }
475     break;
476   }
477   case TargetOpcode::G_UNMERGE_VALUES: {
478     // Use FPR64 for s64 unmerge on rv32.
479     LLT Ty = MRI.getType(MI.getOperand(2).getReg());
480     if (GPRSize == 32 && Ty.getSizeInBits() == 64) {
481       assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
482       OpdsMapping[0] = GPRValueMapping;
483       OpdsMapping[1] = GPRValueMapping;
484       OpdsMapping[2] = getFPValueMapping(Ty.getSizeInBits());
485     }
486     break;
487   }
488   case TargetOpcode::G_SPLAT_VECTOR: {
489     OpdsMapping[0] = getVRBValueMapping(MRI.getType(MI.getOperand(0).getReg())
490                                             .getSizeInBits()
491                                             .getKnownMinValue());
492 
493     LLT ScalarTy = MRI.getType(MI.getOperand(1).getReg());
494     MachineInstr *DefMI = MRI.getVRegDef(MI.getOperand(1).getReg());
495     if ((GPRSize == 32 && ScalarTy.getSizeInBits() == 64) ||
496         onlyDefinesFP(*DefMI, MRI, TRI)) {
497       assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
498       OpdsMapping[1] = getFPValueMapping(ScalarTy.getSizeInBits());
499     } else
500       OpdsMapping[1] = GPRValueMapping;
501     break;
502   }
503   default:
504     // By default map all scalars to GPR.
505     for (unsigned Idx = 0; Idx < NumOperands; ++Idx) {
506        auto &MO = MI.getOperand(Idx);
507        if (!MO.isReg() || !MO.getReg())
508          continue;
509        LLT Ty = MRI.getType(MO.getReg());
510        if (!Ty.isValid())
511          continue;
512 
513        if (Ty.isVector())
514          OpdsMapping[Idx] =
515              getVRBValueMapping(Ty.getSizeInBits().getKnownMinValue());
516        else if (isPreISelGenericFloatingPointOpcode(Opc))
517          OpdsMapping[Idx] = getFPValueMapping(Ty.getSizeInBits());
518        else
519          OpdsMapping[Idx] = GPRValueMapping;
520     }
521     break;
522   }
523 
524   return getInstructionMapping(DefaultMappingID, /*Cost=*/1,
525                                getOperandsMapping(OpdsMapping), NumOperands);
526 }
527