xref: /freebsd-src/contrib/llvm-project/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
10b57cec5SDimitry Andric //===-- CrossDSOCFI.cpp - Externalize this module's CFI checks ------------===//
20b57cec5SDimitry Andric //
30b57cec5SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
40b57cec5SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
50b57cec5SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60b57cec5SDimitry Andric //
70b57cec5SDimitry Andric //===----------------------------------------------------------------------===//
80b57cec5SDimitry Andric //
// This pass exports the numeric type identifiers found in the module's type
// metadata (formerly llvm.bitsets) in the form of a __cfi_check function,
// which can be used to verify cross-DSO call targets.
110b57cec5SDimitry Andric //
120b57cec5SDimitry Andric //===----------------------------------------------------------------------===//
130b57cec5SDimitry Andric 
140b57cec5SDimitry Andric #include "llvm/Transforms/IPO/CrossDSOCFI.h"
150b57cec5SDimitry Andric #include "llvm/ADT/SetVector.h"
160b57cec5SDimitry Andric #include "llvm/ADT/Statistic.h"
170b57cec5SDimitry Andric #include "llvm/IR/Constants.h"
180b57cec5SDimitry Andric #include "llvm/IR/Function.h"
190b57cec5SDimitry Andric #include "llvm/IR/GlobalObject.h"
200b57cec5SDimitry Andric #include "llvm/IR/IRBuilder.h"
210b57cec5SDimitry Andric #include "llvm/IR/Instructions.h"
220b57cec5SDimitry Andric #include "llvm/IR/Intrinsics.h"
230b57cec5SDimitry Andric #include "llvm/IR/MDBuilder.h"
240b57cec5SDimitry Andric #include "llvm/IR/Module.h"
2506c3fb27SDimitry Andric #include "llvm/TargetParser/Triple.h"
260b57cec5SDimitry Andric #include "llvm/Transforms/IPO.h"
270b57cec5SDimitry Andric 
280b57cec5SDimitry Andric using namespace llvm;
290b57cec5SDimitry Andric 
300b57cec5SDimitry Andric #define DEBUG_TYPE "cross-dso-cfi"
310b57cec5SDimitry Andric 
320b57cec5SDimitry Andric STATISTIC(NumTypeIds, "Number of unique type identifiers");
330b57cec5SDimitry Andric 
340b57cec5SDimitry Andric namespace {
350b57cec5SDimitry Andric 
3606c3fb27SDimitry Andric struct CrossDSOCFI {
370b57cec5SDimitry Andric   MDNode *VeryLikelyWeights;
380b57cec5SDimitry Andric 
390b57cec5SDimitry Andric   ConstantInt *extractNumericTypeId(MDNode *MD);
400b57cec5SDimitry Andric   void buildCFICheck(Module &M);
4106c3fb27SDimitry Andric   bool runOnModule(Module &M);
420b57cec5SDimitry Andric };
430b57cec5SDimitry Andric 
440b57cec5SDimitry Andric } // anonymous namespace
450b57cec5SDimitry Andric 
460b57cec5SDimitry Andric /// Extracts a numeric type identifier from an MDNode containing type metadata.
470b57cec5SDimitry Andric ConstantInt *CrossDSOCFI::extractNumericTypeId(MDNode *MD) {
480b57cec5SDimitry Andric   // This check excludes vtables for classes inside anonymous namespaces.
490b57cec5SDimitry Andric   auto TM = dyn_cast<ValueAsMetadata>(MD->getOperand(1));
500b57cec5SDimitry Andric   if (!TM)
510b57cec5SDimitry Andric     return nullptr;
520b57cec5SDimitry Andric   auto C = dyn_cast_or_null<ConstantInt>(TM->getValue());
530b57cec5SDimitry Andric   if (!C) return nullptr;
540b57cec5SDimitry Andric   // We are looking for i64 constants.
550b57cec5SDimitry Andric   if (C->getBitWidth() != 64) return nullptr;
560b57cec5SDimitry Andric 
570b57cec5SDimitry Andric   return C;
580b57cec5SDimitry Andric }
590b57cec5SDimitry Andric 
600b57cec5SDimitry Andric /// buildCFICheck - emits __cfi_check for the current module.
610b57cec5SDimitry Andric void CrossDSOCFI::buildCFICheck(Module &M) {
620b57cec5SDimitry Andric   // FIXME: verify that __cfi_check ends up near the end of the code section,
630b57cec5SDimitry Andric   // but before the jump slots created in LowerTypeTests.
640b57cec5SDimitry Andric   SetVector<uint64_t> TypeIds;
650b57cec5SDimitry Andric   SmallVector<MDNode *, 2> Types;
660b57cec5SDimitry Andric   for (GlobalObject &GO : M.global_objects()) {
670b57cec5SDimitry Andric     Types.clear();
680b57cec5SDimitry Andric     GO.getMetadata(LLVMContext::MD_type, Types);
698bcb0991SDimitry Andric     for (MDNode *Type : Types)
700b57cec5SDimitry Andric       if (ConstantInt *TypeId = extractNumericTypeId(Type))
710b57cec5SDimitry Andric         TypeIds.insert(TypeId->getZExtValue());
720b57cec5SDimitry Andric   }
730b57cec5SDimitry Andric 
740b57cec5SDimitry Andric   NamedMDNode *CfiFunctionsMD = M.getNamedMetadata("cfi.functions");
750b57cec5SDimitry Andric   if (CfiFunctionsMD) {
76bdd1243dSDimitry Andric     for (auto *Func : CfiFunctionsMD->operands()) {
770b57cec5SDimitry Andric       assert(Func->getNumOperands() >= 2);
780b57cec5SDimitry Andric       for (unsigned I = 2; I < Func->getNumOperands(); ++I)
790b57cec5SDimitry Andric         if (ConstantInt *TypeId =
800b57cec5SDimitry Andric                 extractNumericTypeId(cast<MDNode>(Func->getOperand(I).get())))
810b57cec5SDimitry Andric           TypeIds.insert(TypeId->getZExtValue());
820b57cec5SDimitry Andric     }
830b57cec5SDimitry Andric   }
840b57cec5SDimitry Andric 
850b57cec5SDimitry Andric   LLVMContext &Ctx = M.getContext();
860b57cec5SDimitry Andric   FunctionCallee C = M.getOrInsertFunction(
870b57cec5SDimitry Andric       "__cfi_check", Type::getVoidTy(Ctx), Type::getInt64Ty(Ctx),
885f757f3fSDimitry Andric       PointerType::getUnqual(Ctx), PointerType::getUnqual(Ctx));
898bcb0991SDimitry Andric   Function *F = cast<Function>(C.getCallee());
900b57cec5SDimitry Andric   // Take over the existing function. The frontend emits a weak stub so that the
910b57cec5SDimitry Andric   // linker knows about the symbol; this pass replaces the function body.
920b57cec5SDimitry Andric   F->deleteBody();
938bcb0991SDimitry Andric   F->setAlignment(Align(4096));
940b57cec5SDimitry Andric 
950b57cec5SDimitry Andric   Triple T(M.getTargetTriple());
960b57cec5SDimitry Andric   if (T.isARM() || T.isThumb())
970b57cec5SDimitry Andric     F->addFnAttr("target-features", "+thumb-mode");
980b57cec5SDimitry Andric 
990b57cec5SDimitry Andric   auto args = F->arg_begin();
1000b57cec5SDimitry Andric   Value &CallSiteTypeId = *(args++);
1010b57cec5SDimitry Andric   CallSiteTypeId.setName("CallSiteTypeId");
1020b57cec5SDimitry Andric   Value &Addr = *(args++);
1030b57cec5SDimitry Andric   Addr.setName("Addr");
1040b57cec5SDimitry Andric   Value &CFICheckFailData = *(args++);
1050b57cec5SDimitry Andric   CFICheckFailData.setName("CFICheckFailData");
1060b57cec5SDimitry Andric   assert(args == F->arg_end());
1070b57cec5SDimitry Andric 
1080b57cec5SDimitry Andric   BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F);
1090b57cec5SDimitry Andric   BasicBlock *ExitBB = BasicBlock::Create(Ctx, "exit", F);
1100b57cec5SDimitry Andric 
1110b57cec5SDimitry Andric   BasicBlock *TrapBB = BasicBlock::Create(Ctx, "fail", F);
1120b57cec5SDimitry Andric   IRBuilder<> IRBFail(TrapBB);
1135f757f3fSDimitry Andric   FunctionCallee CFICheckFailFn = M.getOrInsertFunction(
1145f757f3fSDimitry Andric       "__cfi_check_fail", Type::getVoidTy(Ctx), PointerType::getUnqual(Ctx),
1155f757f3fSDimitry Andric       PointerType::getUnqual(Ctx));
1160b57cec5SDimitry Andric   IRBFail.CreateCall(CFICheckFailFn, {&CFICheckFailData, &Addr});
1170b57cec5SDimitry Andric   IRBFail.CreateBr(ExitBB);
1180b57cec5SDimitry Andric 
1190b57cec5SDimitry Andric   IRBuilder<> IRBExit(ExitBB);
1200b57cec5SDimitry Andric   IRBExit.CreateRetVoid();
1210b57cec5SDimitry Andric 
1220b57cec5SDimitry Andric   IRBuilder<> IRB(BB);
1230b57cec5SDimitry Andric   SwitchInst *SI = IRB.CreateSwitch(&CallSiteTypeId, TrapBB, TypeIds.size());
1240b57cec5SDimitry Andric   for (uint64_t TypeId : TypeIds) {
1250b57cec5SDimitry Andric     ConstantInt *CaseTypeId = ConstantInt::get(Type::getInt64Ty(Ctx), TypeId);
1260b57cec5SDimitry Andric     BasicBlock *TestBB = BasicBlock::Create(Ctx, "test", F);
1270b57cec5SDimitry Andric     IRBuilder<> IRBTest(TestBB);
1280b57cec5SDimitry Andric     Function *BitsetTestFn = Intrinsic::getDeclaration(&M, Intrinsic::type_test);
1290b57cec5SDimitry Andric 
1300b57cec5SDimitry Andric     Value *Test = IRBTest.CreateCall(
1310b57cec5SDimitry Andric         BitsetTestFn, {&Addr, MetadataAsValue::get(
1320b57cec5SDimitry Andric                                   Ctx, ConstantAsMetadata::get(CaseTypeId))});
1330b57cec5SDimitry Andric     BranchInst *BI = IRBTest.CreateCondBr(Test, ExitBB, TrapBB);
1340b57cec5SDimitry Andric     BI->setMetadata(LLVMContext::MD_prof, VeryLikelyWeights);
1350b57cec5SDimitry Andric 
1360b57cec5SDimitry Andric     SI->addCase(CaseTypeId, TestBB);
1370b57cec5SDimitry Andric     ++NumTypeIds;
1380b57cec5SDimitry Andric   }
1390b57cec5SDimitry Andric }
1400b57cec5SDimitry Andric 
1410b57cec5SDimitry Andric bool CrossDSOCFI::runOnModule(Module &M) {
142*0fca6ea1SDimitry Andric   VeryLikelyWeights = MDBuilder(M.getContext()).createLikelyBranchWeights();
1430b57cec5SDimitry Andric   if (M.getModuleFlag("Cross-DSO CFI") == nullptr)
1440b57cec5SDimitry Andric     return false;
1450b57cec5SDimitry Andric   buildCFICheck(M);
1460b57cec5SDimitry Andric   return true;
1470b57cec5SDimitry Andric }
1480b57cec5SDimitry Andric 
1490b57cec5SDimitry Andric PreservedAnalyses CrossDSOCFIPass::run(Module &M, ModuleAnalysisManager &AM) {
1500b57cec5SDimitry Andric   CrossDSOCFI Impl;
1510b57cec5SDimitry Andric   bool Changed = Impl.runOnModule(M);
1520b57cec5SDimitry Andric   if (!Changed)
1530b57cec5SDimitry Andric     return PreservedAnalyses::all();
1540b57cec5SDimitry Andric   return PreservedAnalyses::none();
1550b57cec5SDimitry Andric }
156