17c836512SeopXD //===-- RISCVInsertReadWriteCSR.cpp - Insert Read/Write of RISC-V CSR -----===//
27c836512SeopXD //
37c836512SeopXD // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
47c836512SeopXD // See https://llvm.org/LICENSE.txt for license information.
57c836512SeopXD // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
67c836512SeopXD //
77c836512SeopXD //===----------------------------------------------------------------------===//
// This file implements the machine function pass that inserts reads and
// writes of CSRs required by RISC-V instructions.
//
// Currently the pass implements:
// - Saving frm and writing it before an RVV floating-point instruction with a
//   static rounding mode, then restoring the saved value afterwards.
147c836512SeopXD //
157c836512SeopXD //===----------------------------------------------------------------------===//
167c836512SeopXD
1776482078SeopXD #include "MCTargetDesc/RISCVBaseInfo.h"
187c836512SeopXD #include "RISCV.h"
197c836512SeopXD #include "RISCVSubtarget.h"
207c836512SeopXD #include "llvm/CodeGen/MachineFunctionPass.h"
217c836512SeopXD using namespace llvm;
227c836512SeopXD
237c836512SeopXD #define DEBUG_TYPE "riscv-insert-read-write-csr"
247c836512SeopXD #define RISCV_INSERT_READ_WRITE_CSR_NAME "RISC-V Insert Read/Write CSR Pass"
257c836512SeopXD
26df08350dSYeting Kuo static cl::opt<bool>
27df08350dSYeting Kuo DisableFRMInsertOpt("riscv-disable-frm-insert-opt", cl::init(false),
28df08350dSYeting Kuo cl::Hidden,
29df08350dSYeting Kuo cl::desc("Disable optimized frm insertion."));
30df08350dSYeting Kuo
317c836512SeopXD namespace {
327c836512SeopXD
337c836512SeopXD class RISCVInsertReadWriteCSR : public MachineFunctionPass {
347c836512SeopXD const TargetInstrInfo *TII;
357c836512SeopXD
367c836512SeopXD public:
377c836512SeopXD static char ID;
387c836512SeopXD
RISCVInsertReadWriteCSR()394162a9bcSCraig Topper RISCVInsertReadWriteCSR() : MachineFunctionPass(ID) {}
407c836512SeopXD
417c836512SeopXD bool runOnMachineFunction(MachineFunction &MF) override;
427c836512SeopXD
getAnalysisUsage(AnalysisUsage & AU) const437c836512SeopXD void getAnalysisUsage(AnalysisUsage &AU) const override {
447c836512SeopXD AU.setPreservesCFG();
457c836512SeopXD MachineFunctionPass::getAnalysisUsage(AU);
467c836512SeopXD }
477c836512SeopXD
getPassName() const487c836512SeopXD StringRef getPassName() const override {
497c836512SeopXD return RISCV_INSERT_READ_WRITE_CSR_NAME;
507c836512SeopXD }
517c836512SeopXD
527c836512SeopXD private:
5376482078SeopXD bool emitWriteRoundingMode(MachineBasicBlock &MBB);
54df08350dSYeting Kuo bool emitWriteRoundingModeOpt(MachineBasicBlock &MBB);
557c836512SeopXD };
567c836512SeopXD
577c836512SeopXD } // end anonymous namespace
587c836512SeopXD
// Definition of the pass's static ID member, which is passed to the
// MachineFunctionPass base-class constructor.
char RISCVInsertReadWriteCSR::ID = 0;

// Pass registration, using DEBUG_TYPE as the argument name and
// RISCV_INSERT_READ_WRITE_CSR_NAME as the description.
// NOTE(review): the two trailing `false` flags are presumably "CFG-only" and
// "is analysis" — confirm against the INITIALIZE_PASS macro definition.
INITIALIZE_PASS(RISCVInsertReadWriteCSR, DEBUG_TYPE,
                RISCV_INSERT_READ_WRITE_CSR_NAME, false, false)
637c836512SeopXD
64df08350dSYeting Kuo // TODO: Use more accurate rounding mode at the start of MBB.
65df08350dSYeting Kuo bool RISCVInsertReadWriteCSR::emitWriteRoundingModeOpt(MachineBasicBlock &MBB) {
66df08350dSYeting Kuo bool Changed = false;
67df08350dSYeting Kuo MachineInstr *LastFRMChanger = nullptr;
68df08350dSYeting Kuo unsigned CurrentRM = RISCVFPRndMode::DYN;
69df08350dSYeting Kuo Register SavedFRM;
70df08350dSYeting Kuo
71df08350dSYeting Kuo for (MachineInstr &MI : MBB) {
72df08350dSYeting Kuo if (MI.getOpcode() == RISCV::SwapFRMImm ||
73df08350dSYeting Kuo MI.getOpcode() == RISCV::WriteFRMImm) {
74df08350dSYeting Kuo CurrentRM = MI.getOperand(0).getImm();
75df08350dSYeting Kuo SavedFRM = Register();
76df08350dSYeting Kuo continue;
77df08350dSYeting Kuo }
78df08350dSYeting Kuo
79df08350dSYeting Kuo if (MI.getOpcode() == RISCV::WriteFRM) {
80df08350dSYeting Kuo CurrentRM = RISCVFPRndMode::DYN;
81df08350dSYeting Kuo SavedFRM = Register();
82df08350dSYeting Kuo continue;
83df08350dSYeting Kuo }
84df08350dSYeting Kuo
85*f6d431f2SXu Zhang if (MI.isCall() || MI.isInlineAsm() ||
86*f6d431f2SXu Zhang MI.readsRegister(RISCV::FRM, /*TRI=*/nullptr)) {
87df08350dSYeting Kuo // Restore FRM before unknown operations.
88df08350dSYeting Kuo if (SavedFRM.isValid())
89df08350dSYeting Kuo BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteFRM))
90df08350dSYeting Kuo .addReg(SavedFRM);
91df08350dSYeting Kuo CurrentRM = RISCVFPRndMode::DYN;
92df08350dSYeting Kuo SavedFRM = Register();
93df08350dSYeting Kuo continue;
94df08350dSYeting Kuo }
95df08350dSYeting Kuo
96*f6d431f2SXu Zhang assert(!MI.modifiesRegister(RISCV::FRM, /*TRI=*/nullptr) &&
97df08350dSYeting Kuo "Expected that MI could not modify FRM.");
98df08350dSYeting Kuo
99df08350dSYeting Kuo int FRMIdx = RISCVII::getFRMOpNum(MI.getDesc());
100df08350dSYeting Kuo if (FRMIdx < 0)
101df08350dSYeting Kuo continue;
102df08350dSYeting Kuo unsigned InstrRM = MI.getOperand(FRMIdx).getImm();
103df08350dSYeting Kuo
104df08350dSYeting Kuo LastFRMChanger = &MI;
105df08350dSYeting Kuo
106df08350dSYeting Kuo // Make MI implicit use FRM.
107df08350dSYeting Kuo MI.addOperand(MachineOperand::CreateReg(RISCV::FRM, /*IsDef*/ false,
108df08350dSYeting Kuo /*IsImp*/ true));
109df08350dSYeting Kuo Changed = true;
110df08350dSYeting Kuo
111df08350dSYeting Kuo // Skip if MI uses same rounding mode as FRM.
112df08350dSYeting Kuo if (InstrRM == CurrentRM)
113df08350dSYeting Kuo continue;
114df08350dSYeting Kuo
115df08350dSYeting Kuo if (!SavedFRM.isValid()) {
116df08350dSYeting Kuo // Save current FRM value to SavedFRM.
117df08350dSYeting Kuo MachineRegisterInfo *MRI = &MBB.getParent()->getRegInfo();
118df08350dSYeting Kuo SavedFRM = MRI->createVirtualRegister(&RISCV::GPRRegClass);
119df08350dSYeting Kuo BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::SwapFRMImm), SavedFRM)
120df08350dSYeting Kuo .addImm(InstrRM);
121df08350dSYeting Kuo } else {
122df08350dSYeting Kuo // Don't need to save current FRM when SavedFRM having value.
123df08350dSYeting Kuo BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteFRMImm))
124df08350dSYeting Kuo .addImm(InstrRM);
125df08350dSYeting Kuo }
126df08350dSYeting Kuo CurrentRM = InstrRM;
127df08350dSYeting Kuo }
128df08350dSYeting Kuo
129df08350dSYeting Kuo // Restore FRM if needed.
130df08350dSYeting Kuo if (SavedFRM.isValid()) {
131df08350dSYeting Kuo assert(LastFRMChanger && "Expected valid pointer.");
132df08350dSYeting Kuo MachineInstrBuilder MIB =
133df08350dSYeting Kuo BuildMI(*MBB.getParent(), {}, TII->get(RISCV::WriteFRM))
134df08350dSYeting Kuo .addReg(SavedFRM);
135df08350dSYeting Kuo MBB.insertAfter(LastFRMChanger, MIB);
136df08350dSYeting Kuo }
137df08350dSYeting Kuo
138df08350dSYeting Kuo return Changed;
139df08350dSYeting Kuo }
140df08350dSYeting Kuo
141014390d9SCraig Topper // This function also swaps frm and restores it when encountering an RVV
142014390d9SCraig Topper // floating point instruction with a static rounding mode.
emitWriteRoundingMode(MachineBasicBlock & MBB)14376482078SeopXD bool RISCVInsertReadWriteCSR::emitWriteRoundingMode(MachineBasicBlock &MBB) {
1447c836512SeopXD bool Changed = false;
1457c836512SeopXD for (MachineInstr &MI : MBB) {
146b441fd60SCraig Topper int FRMIdx = RISCVII::getFRMOpNum(MI.getDesc());
147b441fd60SCraig Topper if (FRMIdx < 0)
148b441fd60SCraig Topper continue;
149b441fd60SCraig Topper
150b441fd60SCraig Topper unsigned FRMImm = MI.getOperand(FRMIdx).getImm();
15176482078SeopXD
15276482078SeopXD // The value is a hint to this pass to not alter the frm value.
15376482078SeopXD if (FRMImm == RISCVFPRndMode::DYN)
15476482078SeopXD continue;
15576482078SeopXD
15676482078SeopXD Changed = true;
15776482078SeopXD
15876482078SeopXD // Save
15976482078SeopXD MachineRegisterInfo *MRI = &MBB.getParent()->getRegInfo();
16076482078SeopXD Register SavedFRM = MRI->createVirtualRegister(&RISCV::GPRRegClass);
16176482078SeopXD BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::SwapFRMImm),
16276482078SeopXD SavedFRM)
16376482078SeopXD .addImm(FRMImm);
16476482078SeopXD MI.addOperand(MachineOperand::CreateReg(RISCV::FRM, /*IsDef*/ false,
16576482078SeopXD /*IsImp*/ true));
16676482078SeopXD // Restore
16776482078SeopXD MachineInstrBuilder MIB =
16876482078SeopXD BuildMI(*MBB.getParent(), {}, TII->get(RISCV::WriteFRM))
16976482078SeopXD .addReg(SavedFRM);
17076482078SeopXD MBB.insertAfter(MI, MIB);
17176482078SeopXD }
1727c836512SeopXD return Changed;
1737c836512SeopXD }
1747c836512SeopXD
runOnMachineFunction(MachineFunction & MF)1757c836512SeopXD bool RISCVInsertReadWriteCSR::runOnMachineFunction(MachineFunction &MF) {
1767c836512SeopXD // Skip if the vector extension is not enabled.
1777c836512SeopXD const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
1787c836512SeopXD if (!ST.hasVInstructions())
1797c836512SeopXD return false;
1807c836512SeopXD
1817c836512SeopXD TII = ST.getInstrInfo();
1827c836512SeopXD
1837c836512SeopXD bool Changed = false;
1847c836512SeopXD
185df08350dSYeting Kuo for (MachineBasicBlock &MBB : MF) {
186df08350dSYeting Kuo if (DisableFRMInsertOpt)
18776482078SeopXD Changed |= emitWriteRoundingMode(MBB);
188df08350dSYeting Kuo else
189df08350dSYeting Kuo Changed |= emitWriteRoundingModeOpt(MBB);
190df08350dSYeting Kuo }
1917c836512SeopXD
1927c836512SeopXD return Changed;
1937c836512SeopXD }
1947c836512SeopXD
/// Factory for the pass: returns a newly heap-allocated
/// RISCVInsertReadWriteCSR instance.
/// NOTE(review): the caller (presumably the pass manager) is expected to take
/// ownership of the returned pointer — confirm.
FunctionPass *llvm::createRISCVInsertReadWriteCSRPass() {
  return new RISCVInsertReadWriteCSR();
}
198