xref: /llvm-project/llvm/lib/Target/RISCV/RISCVInsertReadWriteCSR.cpp (revision f6d431f208c0fa48827eac40e7acf788346a9967)
17c836512SeopXD //===-- RISCVInsertReadWriteCSR.cpp - Insert Read/Write of RISC-V CSR -----===//
27c836512SeopXD //
37c836512SeopXD // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
47c836512SeopXD // See https://llvm.org/LICENSE.txt for license information.
57c836512SeopXD // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
67c836512SeopXD //
77c836512SeopXD //===----------------------------------------------------------------------===//
87c836512SeopXD // This file implements the machine function pass to insert read/write of CSR-s
97c836512SeopXD // of the RISC-V instructions.
107c836512SeopXD //
1180ce7ce9SCraig Topper // Currently the pass implements:
1280ce7ce9SCraig Topper // -Writing and saving frm before an RVV floating-point instruction with a
1380ce7ce9SCraig Topper //  static rounding mode and restores the value after.
147c836512SeopXD //
157c836512SeopXD //===----------------------------------------------------------------------===//
167c836512SeopXD 
1776482078SeopXD #include "MCTargetDesc/RISCVBaseInfo.h"
187c836512SeopXD #include "RISCV.h"
197c836512SeopXD #include "RISCVSubtarget.h"
207c836512SeopXD #include "llvm/CodeGen/MachineFunctionPass.h"
217c836512SeopXD using namespace llvm;
227c836512SeopXD 
237c836512SeopXD #define DEBUG_TYPE "riscv-insert-read-write-csr"
247c836512SeopXD #define RISCV_INSERT_READ_WRITE_CSR_NAME "RISC-V Insert Read/Write CSR Pass"
257c836512SeopXD 
26df08350dSYeting Kuo static cl::opt<bool>
27df08350dSYeting Kuo     DisableFRMInsertOpt("riscv-disable-frm-insert-opt", cl::init(false),
28df08350dSYeting Kuo                         cl::Hidden,
29df08350dSYeting Kuo                         cl::desc("Disable optimized frm insertion."));
30df08350dSYeting Kuo 
317c836512SeopXD namespace {
327c836512SeopXD 
337c836512SeopXD class RISCVInsertReadWriteCSR : public MachineFunctionPass {
347c836512SeopXD   const TargetInstrInfo *TII;
357c836512SeopXD 
367c836512SeopXD public:
377c836512SeopXD   static char ID;
387c836512SeopXD 
RISCVInsertReadWriteCSR()394162a9bcSCraig Topper   RISCVInsertReadWriteCSR() : MachineFunctionPass(ID) {}
407c836512SeopXD 
417c836512SeopXD   bool runOnMachineFunction(MachineFunction &MF) override;
427c836512SeopXD 
getAnalysisUsage(AnalysisUsage & AU) const437c836512SeopXD   void getAnalysisUsage(AnalysisUsage &AU) const override {
447c836512SeopXD     AU.setPreservesCFG();
457c836512SeopXD     MachineFunctionPass::getAnalysisUsage(AU);
467c836512SeopXD   }
477c836512SeopXD 
getPassName() const487c836512SeopXD   StringRef getPassName() const override {
497c836512SeopXD     return RISCV_INSERT_READ_WRITE_CSR_NAME;
507c836512SeopXD   }
517c836512SeopXD 
527c836512SeopXD private:
5376482078SeopXD   bool emitWriteRoundingMode(MachineBasicBlock &MBB);
54df08350dSYeting Kuo   bool emitWriteRoundingModeOpt(MachineBasicBlock &MBB);
557c836512SeopXD };
567c836512SeopXD 
577c836512SeopXD } // end anonymous namespace
587c836512SeopXD 
// Storage for the pass ID that the constructor hands to MachineFunctionPass.
char RISCVInsertReadWriteCSR::ID = 0;

// Register the pass under DEBUG_TYPE with its human-readable name.
INITIALIZE_PASS(RISCVInsertReadWriteCSR, DEBUG_TYPE,
                RISCV_INSERT_READ_WRITE_CSR_NAME, false, false)
637c836512SeopXD 
64df08350dSYeting Kuo // TODO: Use more accurate rounding mode at the start of MBB.
65df08350dSYeting Kuo bool RISCVInsertReadWriteCSR::emitWriteRoundingModeOpt(MachineBasicBlock &MBB) {
66df08350dSYeting Kuo   bool Changed = false;
67df08350dSYeting Kuo   MachineInstr *LastFRMChanger = nullptr;
68df08350dSYeting Kuo   unsigned CurrentRM = RISCVFPRndMode::DYN;
69df08350dSYeting Kuo   Register SavedFRM;
70df08350dSYeting Kuo 
71df08350dSYeting Kuo   for (MachineInstr &MI : MBB) {
72df08350dSYeting Kuo     if (MI.getOpcode() == RISCV::SwapFRMImm ||
73df08350dSYeting Kuo         MI.getOpcode() == RISCV::WriteFRMImm) {
74df08350dSYeting Kuo       CurrentRM = MI.getOperand(0).getImm();
75df08350dSYeting Kuo       SavedFRM = Register();
76df08350dSYeting Kuo       continue;
77df08350dSYeting Kuo     }
78df08350dSYeting Kuo 
79df08350dSYeting Kuo     if (MI.getOpcode() == RISCV::WriteFRM) {
80df08350dSYeting Kuo       CurrentRM = RISCVFPRndMode::DYN;
81df08350dSYeting Kuo       SavedFRM = Register();
82df08350dSYeting Kuo       continue;
83df08350dSYeting Kuo     }
84df08350dSYeting Kuo 
85*f6d431f2SXu Zhang     if (MI.isCall() || MI.isInlineAsm() ||
86*f6d431f2SXu Zhang         MI.readsRegister(RISCV::FRM, /*TRI=*/nullptr)) {
87df08350dSYeting Kuo       // Restore FRM before unknown operations.
88df08350dSYeting Kuo       if (SavedFRM.isValid())
89df08350dSYeting Kuo         BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteFRM))
90df08350dSYeting Kuo             .addReg(SavedFRM);
91df08350dSYeting Kuo       CurrentRM = RISCVFPRndMode::DYN;
92df08350dSYeting Kuo       SavedFRM = Register();
93df08350dSYeting Kuo       continue;
94df08350dSYeting Kuo     }
95df08350dSYeting Kuo 
96*f6d431f2SXu Zhang     assert(!MI.modifiesRegister(RISCV::FRM, /*TRI=*/nullptr) &&
97df08350dSYeting Kuo            "Expected that MI could not modify FRM.");
98df08350dSYeting Kuo 
99df08350dSYeting Kuo     int FRMIdx = RISCVII::getFRMOpNum(MI.getDesc());
100df08350dSYeting Kuo     if (FRMIdx < 0)
101df08350dSYeting Kuo       continue;
102df08350dSYeting Kuo     unsigned InstrRM = MI.getOperand(FRMIdx).getImm();
103df08350dSYeting Kuo 
104df08350dSYeting Kuo     LastFRMChanger = &MI;
105df08350dSYeting Kuo 
106df08350dSYeting Kuo     // Make MI implicit use FRM.
107df08350dSYeting Kuo     MI.addOperand(MachineOperand::CreateReg(RISCV::FRM, /*IsDef*/ false,
108df08350dSYeting Kuo                                             /*IsImp*/ true));
109df08350dSYeting Kuo     Changed = true;
110df08350dSYeting Kuo 
111df08350dSYeting Kuo     // Skip if MI uses same rounding mode as FRM.
112df08350dSYeting Kuo     if (InstrRM == CurrentRM)
113df08350dSYeting Kuo       continue;
114df08350dSYeting Kuo 
115df08350dSYeting Kuo     if (!SavedFRM.isValid()) {
116df08350dSYeting Kuo       // Save current FRM value to SavedFRM.
117df08350dSYeting Kuo       MachineRegisterInfo *MRI = &MBB.getParent()->getRegInfo();
118df08350dSYeting Kuo       SavedFRM = MRI->createVirtualRegister(&RISCV::GPRRegClass);
119df08350dSYeting Kuo       BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::SwapFRMImm), SavedFRM)
120df08350dSYeting Kuo           .addImm(InstrRM);
121df08350dSYeting Kuo     } else {
122df08350dSYeting Kuo       // Don't need to save current FRM when SavedFRM having value.
123df08350dSYeting Kuo       BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteFRMImm))
124df08350dSYeting Kuo           .addImm(InstrRM);
125df08350dSYeting Kuo     }
126df08350dSYeting Kuo     CurrentRM = InstrRM;
127df08350dSYeting Kuo   }
128df08350dSYeting Kuo 
129df08350dSYeting Kuo   // Restore FRM if needed.
130df08350dSYeting Kuo   if (SavedFRM.isValid()) {
131df08350dSYeting Kuo     assert(LastFRMChanger && "Expected valid pointer.");
132df08350dSYeting Kuo     MachineInstrBuilder MIB =
133df08350dSYeting Kuo         BuildMI(*MBB.getParent(), {}, TII->get(RISCV::WriteFRM))
134df08350dSYeting Kuo             .addReg(SavedFRM);
135df08350dSYeting Kuo     MBB.insertAfter(LastFRMChanger, MIB);
136df08350dSYeting Kuo   }
137df08350dSYeting Kuo 
138df08350dSYeting Kuo   return Changed;
139df08350dSYeting Kuo }
140df08350dSYeting Kuo 
141014390d9SCraig Topper // This function also swaps frm and restores it when encountering an RVV
142014390d9SCraig Topper // floating point instruction with a static rounding mode.
emitWriteRoundingMode(MachineBasicBlock & MBB)14376482078SeopXD bool RISCVInsertReadWriteCSR::emitWriteRoundingMode(MachineBasicBlock &MBB) {
1447c836512SeopXD   bool Changed = false;
1457c836512SeopXD   for (MachineInstr &MI : MBB) {
146b441fd60SCraig Topper     int FRMIdx = RISCVII::getFRMOpNum(MI.getDesc());
147b441fd60SCraig Topper     if (FRMIdx < 0)
148b441fd60SCraig Topper       continue;
149b441fd60SCraig Topper 
150b441fd60SCraig Topper     unsigned FRMImm = MI.getOperand(FRMIdx).getImm();
15176482078SeopXD 
15276482078SeopXD     // The value is a hint to this pass to not alter the frm value.
15376482078SeopXD     if (FRMImm == RISCVFPRndMode::DYN)
15476482078SeopXD       continue;
15576482078SeopXD 
15676482078SeopXD     Changed = true;
15776482078SeopXD 
15876482078SeopXD     // Save
15976482078SeopXD     MachineRegisterInfo *MRI = &MBB.getParent()->getRegInfo();
16076482078SeopXD     Register SavedFRM = MRI->createVirtualRegister(&RISCV::GPRRegClass);
16176482078SeopXD     BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::SwapFRMImm),
16276482078SeopXD             SavedFRM)
16376482078SeopXD         .addImm(FRMImm);
16476482078SeopXD     MI.addOperand(MachineOperand::CreateReg(RISCV::FRM, /*IsDef*/ false,
16576482078SeopXD                                             /*IsImp*/ true));
16676482078SeopXD     // Restore
16776482078SeopXD     MachineInstrBuilder MIB =
16876482078SeopXD         BuildMI(*MBB.getParent(), {}, TII->get(RISCV::WriteFRM))
16976482078SeopXD             .addReg(SavedFRM);
17076482078SeopXD     MBB.insertAfter(MI, MIB);
17176482078SeopXD   }
1727c836512SeopXD   return Changed;
1737c836512SeopXD }
1747c836512SeopXD 
runOnMachineFunction(MachineFunction & MF)1757c836512SeopXD bool RISCVInsertReadWriteCSR::runOnMachineFunction(MachineFunction &MF) {
1767c836512SeopXD   // Skip if the vector extension is not enabled.
1777c836512SeopXD   const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
1787c836512SeopXD   if (!ST.hasVInstructions())
1797c836512SeopXD     return false;
1807c836512SeopXD 
1817c836512SeopXD   TII = ST.getInstrInfo();
1827c836512SeopXD 
1837c836512SeopXD   bool Changed = false;
1847c836512SeopXD 
185df08350dSYeting Kuo   for (MachineBasicBlock &MBB : MF) {
186df08350dSYeting Kuo     if (DisableFRMInsertOpt)
18776482078SeopXD       Changed |= emitWriteRoundingMode(MBB);
188df08350dSYeting Kuo     else
189df08350dSYeting Kuo       Changed |= emitWriteRoundingModeOpt(MBB);
190df08350dSYeting Kuo   }
1917c836512SeopXD 
1927c836512SeopXD   return Changed;
1937c836512SeopXD }
1947c836512SeopXD 
createRISCVInsertReadWriteCSRPass()1957c836512SeopXD FunctionPass *llvm::createRISCVInsertReadWriteCSRPass() {
1967c836512SeopXD   return new RISCVInsertReadWriteCSR();
1977c836512SeopXD }
198