xref: /llvm-project/llvm/lib/Target/BPF/BPFASpaceCastSimplifyPass.cpp (revision 8e702735090388a3231a863e343f880d0f96fecb)
//===-- BPFASpaceCastSimplifyPass.cpp - BPF addrspacecast simplifications -===//
22aacb56eS4ast //
32aacb56eS4ast // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42aacb56eS4ast // See https://llvm.org/LICENSE.txt for license information.
52aacb56eS4ast // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
62aacb56eS4ast //
72aacb56eS4ast //===----------------------------------------------------------------------===//
82aacb56eS4ast 
92aacb56eS4ast #include "BPF.h"
102aacb56eS4ast #include <optional>
112aacb56eS4ast 
122aacb56eS4ast #define DEBUG_TYPE "bpf-aspace-simplify"
132aacb56eS4ast 
142aacb56eS4ast using namespace llvm;
152aacb56eS4ast 
162aacb56eS4ast namespace {
172aacb56eS4ast 
182aacb56eS4ast struct CastGEPCast {
192aacb56eS4ast   AddrSpaceCastInst *OuterCast;
202aacb56eS4ast 
212aacb56eS4ast   // Match chain of instructions:
222aacb56eS4ast   //   %inner = addrspacecast N->M
232aacb56eS4ast   //   %gep   = getelementptr %inner, ...
242aacb56eS4ast   //   %outer = addrspacecast M->N %gep
252aacb56eS4ast   // Where I is %outer.
262aacb56eS4ast   static std::optional<CastGEPCast> match(Value *I) {
272aacb56eS4ast     auto *OuterCast = dyn_cast<AddrSpaceCastInst>(I);
282aacb56eS4ast     if (!OuterCast)
292aacb56eS4ast       return std::nullopt;
302aacb56eS4ast     auto *GEP = dyn_cast<GetElementPtrInst>(OuterCast->getPointerOperand());
312aacb56eS4ast     if (!GEP)
322aacb56eS4ast       return std::nullopt;
332aacb56eS4ast     auto *InnerCast = dyn_cast<AddrSpaceCastInst>(GEP->getPointerOperand());
342aacb56eS4ast     if (!InnerCast)
352aacb56eS4ast       return std::nullopt;
362aacb56eS4ast     if (InnerCast->getSrcAddressSpace() != OuterCast->getDestAddressSpace())
372aacb56eS4ast       return std::nullopt;
382aacb56eS4ast     if (InnerCast->getDestAddressSpace() != OuterCast->getSrcAddressSpace())
392aacb56eS4ast       return std::nullopt;
402aacb56eS4ast     return CastGEPCast{OuterCast};
412aacb56eS4ast   }
422aacb56eS4ast 
432aacb56eS4ast   static PointerType *changeAddressSpace(PointerType *Ty, unsigned AS) {
442aacb56eS4ast     return Ty->get(Ty->getContext(), AS);
452aacb56eS4ast   }
462aacb56eS4ast 
472aacb56eS4ast   // Assuming match(this->OuterCast) is true, convert:
482aacb56eS4ast   //   (addrspacecast M->N (getelementptr (addrspacecast N->M ptr) ...))
492aacb56eS4ast   // To:
502aacb56eS4ast   //   (getelementptr ptr ...)
512aacb56eS4ast   GetElementPtrInst *rewrite() {
522aacb56eS4ast     auto *GEP = cast<GetElementPtrInst>(OuterCast->getPointerOperand());
532aacb56eS4ast     auto *InnerCast = cast<AddrSpaceCastInst>(GEP->getPointerOperand());
542aacb56eS4ast     unsigned AS = OuterCast->getDestAddressSpace();
552aacb56eS4ast     auto *NewGEP = cast<GetElementPtrInst>(GEP->clone());
562aacb56eS4ast     NewGEP->setName(GEP->getName());
57*8e702735SJeremy Morse     NewGEP->insertAfter(OuterCast->getIterator());
582aacb56eS4ast     NewGEP->setOperand(0, InnerCast->getPointerOperand());
592aacb56eS4ast     auto *GEPTy = cast<PointerType>(GEP->getType());
602aacb56eS4ast     NewGEP->mutateType(changeAddressSpace(GEPTy, AS));
612aacb56eS4ast     OuterCast->replaceAllUsesWith(NewGEP);
622aacb56eS4ast     OuterCast->eraseFromParent();
632aacb56eS4ast     if (GEP->use_empty())
642aacb56eS4ast       GEP->eraseFromParent();
652aacb56eS4ast     if (InnerCast->use_empty())
662aacb56eS4ast       InnerCast->eraseFromParent();
672aacb56eS4ast     return NewGEP;
682aacb56eS4ast   }
692aacb56eS4ast };
702aacb56eS4ast 
712aacb56eS4ast } // anonymous namespace
722aacb56eS4ast 
732aacb56eS4ast PreservedAnalyses BPFASpaceCastSimplifyPass::run(Function &F,
742aacb56eS4ast                                                  FunctionAnalysisManager &AM) {
752aacb56eS4ast   SmallVector<CastGEPCast, 16> WorkList;
762aacb56eS4ast   bool Changed = false;
772aacb56eS4ast   for (BasicBlock &BB : F) {
782aacb56eS4ast     for (Instruction &I : BB)
792aacb56eS4ast       if (auto It = CastGEPCast::match(&I))
802aacb56eS4ast         WorkList.push_back(It.value());
812aacb56eS4ast     Changed |= !WorkList.empty();
822aacb56eS4ast 
832aacb56eS4ast     while (!WorkList.empty()) {
842aacb56eS4ast       CastGEPCast InsnChain = WorkList.pop_back_val();
852aacb56eS4ast       GetElementPtrInst *NewGEP = InsnChain.rewrite();
862aacb56eS4ast       for (User *U : NewGEP->users())
872aacb56eS4ast         if (auto It = CastGEPCast::match(U))
882aacb56eS4ast           WorkList.push_back(It.value());
892aacb56eS4ast     }
902aacb56eS4ast   }
912aacb56eS4ast   return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
922aacb56eS4ast }
93