1 //==-- X86LoadValueInjectionLoadHardening.cpp - LVI load hardening for x86 --=// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// Description: This pass finds Load Value Injection (LVI) gadgets consisting 10 /// of a load from memory (i.e., SOURCE), and any operation that may transmit 11 /// the value loaded from memory over a covert channel, or use the value loaded 12 /// from memory to determine a branch/call target (i.e., SINK). After finding 13 /// all such gadgets in a given function, the pass minimally inserts LFENCE 14 /// instructions in such a manner that the following property is satisfied: for 15 /// all SOURCE+SINK pairs, all paths in the CFG from SOURCE to SINK contain at 16 /// least one LFENCE instruction. The algorithm that implements this minimal 17 /// insertion is influenced by an academic paper that minimally inserts memory 18 /// fences for high-performance concurrent programs: 19 /// http://www.cs.ucr.edu/~lesani/companion/oopsla15/OOPSLA15.pdf 20 /// The algorithm implemented in this pass is as follows: 21 /// 1. Build a condensed CFG (i.e., a GadgetGraph) consisting only of the 22 /// following components: 23 /// - SOURCE instructions (also includes function arguments) 24 /// - SINK instructions 25 /// - Basic block entry points 26 /// - Basic block terminators 27 /// - LFENCE instructions 28 /// 2. Analyze the GadgetGraph to determine which SOURCE+SINK pairs (i.e., 29 /// gadgets) are already mitigated by existing LFENCEs. If all gadgets have been 30 /// mitigated, go to step 6. 31 /// 3. Use a heuristic or plugin to approximate minimal LFENCE insertion. 32 /// 4. Insert one LFENCE along each CFG edge that was cut in step 3. 33 /// 5. Go to step 2. 34 /// 6. If any LFENCEs were inserted, return `true` from runOnMachineFunction() 35 /// to tell LLVM that the function was modified. 36 /// 37 //===----------------------------------------------------------------------===// 38 39 #include "ImmutableGraph.h" 40 #include "X86.h" 41 #include "X86Subtarget.h" 42 #include "X86TargetMachine.h" 43 #include "llvm/ADT/DenseMap.h" 44 #include "llvm/ADT/STLExtras.h" 45 #include "llvm/ADT/SmallSet.h" 46 #include "llvm/ADT/Statistic.h" 47 #include "llvm/ADT/StringRef.h" 48 #include "llvm/CodeGen/MachineBasicBlock.h" 49 #include "llvm/CodeGen/MachineDominanceFrontier.h" 50 #include "llvm/CodeGen/MachineDominators.h" 51 #include "llvm/CodeGen/MachineFunction.h" 52 #include "llvm/CodeGen/MachineFunctionPass.h" 53 #include "llvm/CodeGen/MachineInstr.h" 54 #include "llvm/CodeGen/MachineInstrBuilder.h" 55 #include "llvm/CodeGen/MachineLoopInfo.h" 56 #include "llvm/CodeGen/RDFGraph.h" 57 #include "llvm/CodeGen/RDFLiveness.h" 58 #include "llvm/InitializePasses.h" 59 #include "llvm/Support/CommandLine.h" 60 #include "llvm/Support/DOTGraphTraits.h" 61 #include "llvm/Support/Debug.h" 62 #include "llvm/Support/DynamicLibrary.h" 63 #include "llvm/Support/GraphWriter.h" 64 #include "llvm/Support/raw_ostream.h" 65 66 using namespace llvm; 67 68 #define PASS_KEY "x86-lvi-load" 69 #define DEBUG_TYPE PASS_KEY 70 71 STATISTIC(NumFences, "Number of LFENCEs inserted for LVI mitigation"); 72 STATISTIC(NumFunctionsConsidered, "Number of functions analyzed"); 73 STATISTIC(NumFunctionsMitigated, "Number of functions for which mitigations " 74 "were deployed"); 75 STATISTIC(NumGadgets, "Number of LVI gadgets detected during analysis"); 76 77 static cl::opt<std::string> OptimizePluginPath( 78 PASS_KEY "-opt-plugin", 79 cl::desc("Specify a plugin to optimize LFENCE insertion"), cl::Hidden); 80 81 static cl::opt<bool> NoConditionalBranches( 82 PASS_KEY "-no-cbranch", 83 cl::desc("Don't treat conditional branches as disclosure gadgets. This " 84 "may improve performance, at the cost of security."), 85 cl::init(false), cl::Hidden); 86 87 static cl::opt<bool> EmitDot( 88 PASS_KEY "-dot", 89 cl::desc( 90 "For each function, emit a dot graph depicting potential LVI gadgets"), 91 cl::init(false), cl::Hidden); 92 93 static cl::opt<bool> EmitDotOnly( 94 PASS_KEY "-dot-only", 95 cl::desc("For each function, emit a dot graph depicting potential LVI " 96 "gadgets, and do not insert any fences"), 97 cl::init(false), cl::Hidden); 98 99 static cl::opt<bool> EmitDotVerify( 100 PASS_KEY "-dot-verify", 101 cl::desc("For each function, emit a dot graph to stdout depicting " 102 "potential LVI gadgets, used for testing purposes only"), 103 cl::init(false), cl::Hidden); 104 105 static llvm::sys::DynamicLibrary OptimizeDL; 106 typedef int (*OptimizeCutT)(unsigned int *Nodes, unsigned int NodesSize, 107 unsigned int *Edges, int *EdgeValues, 108 int *CutEdges /* out */, unsigned int EdgesSize); 109 static OptimizeCutT OptimizeCut = nullptr; 110 111 namespace { 112 113 struct MachineGadgetGraph : ImmutableGraph<MachineInstr *, int> { 114 static constexpr int GadgetEdgeSentinel = -1; 115 static constexpr MachineInstr *const ArgNodeSentinel = nullptr; 116 117 using GraphT = ImmutableGraph<MachineInstr *, int>; 118 using Node = typename GraphT::Node; 119 using Edge = typename GraphT::Edge; 120 using size_type = typename GraphT::size_type; 121 MachineGadgetGraph(std::unique_ptr<Node[]> Nodes, 122 std::unique_ptr<Edge[]> Edges, size_type NodesSize, 123 size_type EdgesSize, int NumFences = 0, int NumGadgets = 0) 124 : GraphT(std::move(Nodes), std::move(Edges), NodesSize, EdgesSize), 125 NumFences(NumFences), NumGadgets(NumGadgets) {} 126 static inline bool isCFGEdge(const Edge &E) { 127 return E.getValue() != GadgetEdgeSentinel; 128 } 129 static inline bool isGadgetEdge(const Edge &E) { 130 return E.getValue() == GadgetEdgeSentinel; 131 } 132 int NumFences; 133 int NumGadgets; 134 }; 135 136 class X86LoadValueInjectionLoadHardeningPass : public MachineFunctionPass { 137 public: 138 X86LoadValueInjectionLoadHardeningPass() : MachineFunctionPass(ID) {} 139 140 StringRef getPassName() const override { 141 return "X86 Load Value Injection (LVI) Load Hardening"; 142 } 143 void getAnalysisUsage(AnalysisUsage &AU) const override; 144 bool runOnMachineFunction(MachineFunction &MF) override; 145 146 static char ID; 147 148 private: 149 using GraphBuilder = ImmutableGraphBuilder<MachineGadgetGraph>; 150 using Edge = MachineGadgetGraph::Edge; 151 using Node = MachineGadgetGraph::Node; 152 using EdgeSet = MachineGadgetGraph::EdgeSet; 153 using NodeSet = MachineGadgetGraph::NodeSet; 154 155 const X86Subtarget *STI = nullptr; 156 const TargetInstrInfo *TII = nullptr; 157 const TargetRegisterInfo *TRI = nullptr; 158 159 std::unique_ptr<MachineGadgetGraph> 160 getGadgetGraph(MachineFunction &MF, const MachineLoopInfo &MLI, 161 const MachineDominatorTree &MDT, 162 const MachineDominanceFrontier &MDF) const; 163 int hardenLoadsWithPlugin(MachineFunction &MF, 164 std::unique_ptr<MachineGadgetGraph> Graph) const; 165 int hardenLoadsWithHeuristic(MachineFunction &MF, 166 std::unique_ptr<MachineGadgetGraph> Graph) const; 167 int elimMitigatedEdgesAndNodes(MachineGadgetGraph &G, 168 EdgeSet &ElimEdges /* in, out */, 169 NodeSet &ElimNodes /* in, out */) const; 170 std::unique_ptr<MachineGadgetGraph> 171 trimMitigatedEdges(std::unique_ptr<MachineGadgetGraph> Graph) const; 172 int insertFences(MachineFunction &MF, MachineGadgetGraph &G, 173 EdgeSet &CutEdges /* in, out */) const; 174 bool instrUsesRegToAccessMemory(const MachineInstr &I, unsigned Reg) const; 175 bool instrUsesRegToBranch(const MachineInstr &I, unsigned Reg) const; 176 inline bool isFence(const MachineInstr *MI) const { 177 return MI && (MI->getOpcode() == X86::LFENCE || 178 (STI->useLVIControlFlowIntegrity() && MI->isCall())); 179 } 180 }; 181 182 } // end anonymous namespace 183 184 namespace llvm { 185 186 template <> 187 struct GraphTraits<MachineGadgetGraph *> 188 : GraphTraits<ImmutableGraph<MachineInstr *, int> *> {}; 189 190 template <> 191 struct DOTGraphTraits<MachineGadgetGraph *> : DefaultDOTGraphTraits { 192 using GraphType = MachineGadgetGraph; 193 using Traits = llvm::GraphTraits<GraphType *>; 194 using NodeRef = typename Traits::NodeRef; 195 using EdgeRef = typename Traits::EdgeRef; 196 using ChildIteratorType = typename Traits::ChildIteratorType; 197 using ChildEdgeIteratorType = typename Traits::ChildEdgeIteratorType; 198 199 DOTGraphTraits(bool IsSimple = false) : DefaultDOTGraphTraits(IsSimple) {} 200 201 std::string getNodeLabel(NodeRef Node, GraphType *) { 202 if (Node->getValue() == MachineGadgetGraph::ArgNodeSentinel) 203 return "ARGS"; 204 205 std::string Str; 206 raw_string_ostream OS(Str); 207 OS << *Node->getValue(); 208 return OS.str(); 209 } 210 211 static std::string getNodeAttributes(NodeRef Node, GraphType *) { 212 MachineInstr *MI = Node->getValue(); 213 if (MI == MachineGadgetGraph::ArgNodeSentinel) 214 return "color = blue"; 215 if (MI->getOpcode() == X86::LFENCE) 216 return "color = green"; 217 return ""; 218 } 219 220 static std::string getEdgeAttributes(NodeRef, ChildIteratorType E, 221 GraphType *) { 222 int EdgeVal = (*E.getCurrent()).getValue(); 223 return EdgeVal >= 0 ? "label = " + std::to_string(EdgeVal) 224 : "color = red, style = \"dashed\""; 225 } 226 }; 227 228 } // end namespace llvm 229 230 constexpr MachineInstr *MachineGadgetGraph::ArgNodeSentinel; 231 constexpr int MachineGadgetGraph::GadgetEdgeSentinel; 232 233 char X86LoadValueInjectionLoadHardeningPass::ID = 0; 234 235 void X86LoadValueInjectionLoadHardeningPass::getAnalysisUsage( 236 AnalysisUsage &AU) const { 237 MachineFunctionPass::getAnalysisUsage(AU); 238 AU.addRequired<MachineLoopInfoWrapperPass>(); 239 AU.addRequired<MachineDominatorTreeWrapperPass>(); 240 AU.addRequired<MachineDominanceFrontier>(); 241 AU.setPreservesCFG(); 242 } 243 244 static void writeGadgetGraph(raw_ostream &OS, MachineFunction &MF, 245 MachineGadgetGraph *G) { 246 WriteGraph(OS, G, /*ShortNames*/ false, 247 "Speculative gadgets for \"" + MF.getName() + "\" function"); 248 } 249 250 bool X86LoadValueInjectionLoadHardeningPass::runOnMachineFunction( 251 MachineFunction &MF) { 252 LLVM_DEBUG(dbgs() << "***** " << getPassName() << " : " << MF.getName() 253 << " *****\n"); 254 STI = &MF.getSubtarget<X86Subtarget>(); 255 if (!STI->useLVILoadHardening()) 256 return false; 257 258 // FIXME: support 32-bit 259 if (!STI->is64Bit()) 260 report_fatal_error("LVI load hardening is only supported on 64-bit", false); 261 262 // Don't skip functions with the "optnone" attr but participate in opt-bisect. 263 const Function &F = MF.getFunction(); 264 if (!F.hasOptNone() && skipFunction(F)) 265 return false; 266 267 ++NumFunctionsConsidered; 268 TII = STI->getInstrInfo(); 269 TRI = STI->getRegisterInfo(); 270 LLVM_DEBUG(dbgs() << "Building gadget graph...\n"); 271 const auto &MLI = getAnalysis<MachineLoopInfoWrapperPass>().getLI(); 272 const auto &MDT = getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree(); 273 const auto &MDF = getAnalysis<MachineDominanceFrontier>(); 274 std::unique_ptr<MachineGadgetGraph> Graph = getGadgetGraph(MF, MLI, MDT, MDF); 275 LLVM_DEBUG(dbgs() << "Building gadget graph... Done\n"); 276 if (Graph == nullptr) 277 return false; // didn't find any gadgets 278 279 if (EmitDotVerify) { 280 writeGadgetGraph(outs(), MF, Graph.get()); 281 return false; 282 } 283 284 if (EmitDot || EmitDotOnly) { 285 LLVM_DEBUG(dbgs() << "Emitting gadget graph...\n"); 286 std::error_code FileError; 287 std::string FileName = "lvi."; 288 FileName += MF.getName(); 289 FileName += ".dot"; 290 raw_fd_ostream FileOut(FileName, FileError); 291 if (FileError) 292 errs() << FileError.message(); 293 writeGadgetGraph(FileOut, MF, Graph.get()); 294 FileOut.close(); 295 LLVM_DEBUG(dbgs() << "Emitting gadget graph... Done\n"); 296 if (EmitDotOnly) 297 return false; 298 } 299 300 int FencesInserted; 301 if (!OptimizePluginPath.empty()) { 302 if (!OptimizeDL.isValid()) { 303 std::string ErrorMsg; 304 OptimizeDL = llvm::sys::DynamicLibrary::getPermanentLibrary( 305 OptimizePluginPath.c_str(), &ErrorMsg); 306 if (!ErrorMsg.empty()) 307 report_fatal_error(Twine("Failed to load opt plugin: \"") + ErrorMsg + 308 "\""); 309 OptimizeCut = (OptimizeCutT)OptimizeDL.getAddressOfSymbol("optimize_cut"); 310 if (!OptimizeCut) 311 report_fatal_error("Invalid optimization plugin"); 312 } 313 FencesInserted = hardenLoadsWithPlugin(MF, std::move(Graph)); 314 } else { // Use the default greedy heuristic 315 FencesInserted = hardenLoadsWithHeuristic(MF, std::move(Graph)); 316 } 317 318 if (FencesInserted > 0) 319 ++NumFunctionsMitigated; 320 NumFences += FencesInserted; 321 return (FencesInserted > 0); 322 } 323 324 std::unique_ptr<MachineGadgetGraph> 325 X86LoadValueInjectionLoadHardeningPass::getGadgetGraph( 326 MachineFunction &MF, const MachineLoopInfo &MLI, 327 const MachineDominatorTree &MDT, 328 const MachineDominanceFrontier &MDF) const { 329 using namespace rdf; 330 331 // Build the Register Dataflow Graph using the RDF framework 332 DataFlowGraph DFG{MF, *TII, *TRI, MDT, MDF}; 333 DFG.build(); 334 Liveness L{MF.getRegInfo(), DFG}; 335 L.computePhiInfo(); 336 337 GraphBuilder Builder; 338 using GraphIter = typename GraphBuilder::BuilderNodeRef; 339 DenseMap<MachineInstr *, GraphIter> NodeMap; 340 int FenceCount = 0, GadgetCount = 0; 341 auto MaybeAddNode = [&NodeMap, &Builder](MachineInstr *MI) { 342 auto Ref = NodeMap.find(MI); 343 if (Ref == NodeMap.end()) { 344 auto I = Builder.addVertex(MI); 345 NodeMap[MI] = I; 346 return std::pair<GraphIter, bool>{I, true}; 347 } 348 return std::pair<GraphIter, bool>{Ref->getSecond(), false}; 349 }; 350 351 // The `Transmitters` map memoizes transmitters found for each def. If a def 352 // has not yet been analyzed, then it will not appear in the map. If a def 353 // has been analyzed and was determined not to have any transmitters, then 354 // its list of transmitters will be empty. 355 DenseMap<NodeId, std::vector<NodeId>> Transmitters; 356 357 // Analyze all machine instructions to find gadgets and LFENCEs, adding 358 // each interesting value to `Nodes` 359 auto AnalyzeDef = [&](NodeAddr<DefNode *> SourceDef) { 360 SmallSet<NodeId, 8> UsesVisited, DefsVisited; 361 std::function<void(NodeAddr<DefNode *>)> AnalyzeDefUseChain = 362 [&](NodeAddr<DefNode *> Def) { 363 if (Transmitters.contains(Def.Id)) 364 return; // Already analyzed `Def` 365 366 // Use RDF to find all the uses of `Def` 367 rdf::NodeSet Uses; 368 RegisterRef DefReg = Def.Addr->getRegRef(DFG); 369 for (auto UseID : L.getAllReachedUses(DefReg, Def)) { 370 auto Use = DFG.addr<UseNode *>(UseID); 371 if (Use.Addr->getFlags() & NodeAttrs::PhiRef) { // phi node 372 NodeAddr<PhiNode *> Phi = Use.Addr->getOwner(DFG); 373 for (const auto& I : L.getRealUses(Phi.Id)) { 374 if (DFG.getPRI().alias(RegisterRef(I.first), DefReg)) { 375 for (const auto &UA : I.second) 376 Uses.emplace(UA.first); 377 } 378 } 379 } else { // not a phi node 380 Uses.emplace(UseID); 381 } 382 } 383 384 // For each use of `Def`, we want to know whether: 385 // (1) The use can leak the Def'ed value, 386 // (2) The use can further propagate the Def'ed value to more defs 387 for (auto UseID : Uses) { 388 if (!UsesVisited.insert(UseID).second) 389 continue; // Already visited this use of `Def` 390 391 auto Use = DFG.addr<UseNode *>(UseID); 392 assert(!(Use.Addr->getFlags() & NodeAttrs::PhiRef)); 393 MachineOperand &UseMO = Use.Addr->getOp(); 394 MachineInstr &UseMI = *UseMO.getParent(); 395 assert(UseMO.isReg()); 396 397 // We naively assume that an instruction propagates any loaded 398 // uses to all defs unless the instruction is a call, in which 399 // case all arguments will be treated as gadget sources during 400 // analysis of the callee function. 401 if (UseMI.isCall()) 402 continue; 403 404 // Check whether this use can transmit (leak) its value. 405 if (instrUsesRegToAccessMemory(UseMI, UseMO.getReg()) || 406 (!NoConditionalBranches && 407 instrUsesRegToBranch(UseMI, UseMO.getReg()))) { 408 Transmitters[Def.Id].push_back(Use.Addr->getOwner(DFG).Id); 409 if (UseMI.mayLoad()) 410 continue; // Found a transmitting load -- no need to continue 411 // traversing its defs (i.e., this load will become 412 // a new gadget source anyways). 413 } 414 415 // Check whether the use propagates to more defs. 416 NodeAddr<InstrNode *> Owner{Use.Addr->getOwner(DFG)}; 417 rdf::NodeList AnalyzedChildDefs; 418 for (const auto &ChildDef : 419 Owner.Addr->members_if(DataFlowGraph::IsDef, DFG)) { 420 if (!DefsVisited.insert(ChildDef.Id).second) 421 continue; // Already visited this def 422 if (Def.Addr->getAttrs() & NodeAttrs::Dead) 423 continue; 424 if (Def.Id == ChildDef.Id) 425 continue; // `Def` uses itself (e.g., increment loop counter) 426 427 AnalyzeDefUseChain(ChildDef); 428 429 // `Def` inherits all of its child defs' transmitters. 430 for (auto TransmitterId : Transmitters[ChildDef.Id]) 431 Transmitters[Def.Id].push_back(TransmitterId); 432 } 433 } 434 435 // Note that this statement adds `Def.Id` to the map if no 436 // transmitters were found for `Def`. 437 auto &DefTransmitters = Transmitters[Def.Id]; 438 439 // Remove duplicate transmitters 440 llvm::sort(DefTransmitters); 441 DefTransmitters.erase(llvm::unique(DefTransmitters), 442 DefTransmitters.end()); 443 }; 444 445 // Find all of the transmitters 446 AnalyzeDefUseChain(SourceDef); 447 auto &SourceDefTransmitters = Transmitters[SourceDef.Id]; 448 if (SourceDefTransmitters.empty()) 449 return; // No transmitters for `SourceDef` 450 451 MachineInstr *Source = SourceDef.Addr->getFlags() & NodeAttrs::PhiRef 452 ? MachineGadgetGraph::ArgNodeSentinel 453 : SourceDef.Addr->getOp().getParent(); 454 auto GadgetSource = MaybeAddNode(Source); 455 // Each transmitter is a sink for `SourceDef`. 456 for (auto TransmitterId : SourceDefTransmitters) { 457 MachineInstr *Sink = DFG.addr<StmtNode *>(TransmitterId).Addr->getCode(); 458 auto GadgetSink = MaybeAddNode(Sink); 459 // Add the gadget edge to the graph. 460 Builder.addEdge(MachineGadgetGraph::GadgetEdgeSentinel, 461 GadgetSource.first, GadgetSink.first); 462 ++GadgetCount; 463 } 464 }; 465 466 LLVM_DEBUG(dbgs() << "Analyzing def-use chains to find gadgets\n"); 467 // Analyze function arguments 468 NodeAddr<BlockNode *> EntryBlock = DFG.getFunc().Addr->getEntryBlock(DFG); 469 for (NodeAddr<PhiNode *> ArgPhi : 470 EntryBlock.Addr->members_if(DataFlowGraph::IsPhi, DFG)) { 471 NodeList Defs = ArgPhi.Addr->members_if(DataFlowGraph::IsDef, DFG); 472 llvm::for_each(Defs, AnalyzeDef); 473 } 474 // Analyze every instruction in MF 475 for (NodeAddr<BlockNode *> BA : DFG.getFunc().Addr->members(DFG)) { 476 for (NodeAddr<StmtNode *> SA : 477 BA.Addr->members_if(DataFlowGraph::IsCode<NodeAttrs::Stmt>, DFG)) { 478 MachineInstr *MI = SA.Addr->getCode(); 479 if (isFence(MI)) { 480 MaybeAddNode(MI); 481 ++FenceCount; 482 } else if (MI->mayLoad()) { 483 NodeList Defs = SA.Addr->members_if(DataFlowGraph::IsDef, DFG); 484 llvm::for_each(Defs, AnalyzeDef); 485 } 486 } 487 } 488 LLVM_DEBUG(dbgs() << "Found " << FenceCount << " fences\n"); 489 LLVM_DEBUG(dbgs() << "Found " << GadgetCount << " gadgets\n"); 490 if (GadgetCount == 0) 491 return nullptr; 492 NumGadgets += GadgetCount; 493 494 // Traverse CFG to build the rest of the graph 495 SmallSet<MachineBasicBlock *, 8> BlocksVisited; 496 std::function<void(MachineBasicBlock *, GraphIter, unsigned)> TraverseCFG = 497 [&](MachineBasicBlock *MBB, GraphIter GI, unsigned ParentDepth) { 498 unsigned LoopDepth = MLI.getLoopDepth(MBB); 499 if (!MBB->empty()) { 500 // Always add the first instruction in each block 501 auto NI = MBB->begin(); 502 auto BeginBB = MaybeAddNode(&*NI); 503 Builder.addEdge(ParentDepth, GI, BeginBB.first); 504 if (!BlocksVisited.insert(MBB).second) 505 return; 506 507 // Add any instructions within the block that are gadget components 508 GI = BeginBB.first; 509 while (++NI != MBB->end()) { 510 auto Ref = NodeMap.find(&*NI); 511 if (Ref != NodeMap.end()) { 512 Builder.addEdge(LoopDepth, GI, Ref->getSecond()); 513 GI = Ref->getSecond(); 514 } 515 } 516 517 // Always add the terminator instruction, if one exists 518 auto T = MBB->getFirstTerminator(); 519 if (T != MBB->end()) { 520 auto EndBB = MaybeAddNode(&*T); 521 if (EndBB.second) 522 Builder.addEdge(LoopDepth, GI, EndBB.first); 523 GI = EndBB.first; 524 } 525 } 526 for (MachineBasicBlock *Succ : MBB->successors()) 527 TraverseCFG(Succ, GI, LoopDepth); 528 }; 529 // ArgNodeSentinel is a pseudo-instruction that represents MF args in the 530 // GadgetGraph 531 GraphIter ArgNode = MaybeAddNode(MachineGadgetGraph::ArgNodeSentinel).first; 532 TraverseCFG(&MF.front(), ArgNode, 0); 533 std::unique_ptr<MachineGadgetGraph> G{Builder.get(FenceCount, GadgetCount)}; 534 LLVM_DEBUG(dbgs() << "Found " << G->nodes_size() << " nodes\n"); 535 return G; 536 } 537 538 // Returns the number of remaining gadget edges that could not be eliminated 539 int X86LoadValueInjectionLoadHardeningPass::elimMitigatedEdgesAndNodes( 540 MachineGadgetGraph &G, EdgeSet &ElimEdges /* in, out */, 541 NodeSet &ElimNodes /* in, out */) const { 542 if (G.NumFences > 0) { 543 // Eliminate fences and CFG edges that ingress and egress the fence, as 544 // they are trivially mitigated. 545 for (const Edge &E : G.edges()) { 546 const Node *Dest = E.getDest(); 547 if (isFence(Dest->getValue())) { 548 ElimNodes.insert(*Dest); 549 ElimEdges.insert(E); 550 for (const Edge &DE : Dest->edges()) 551 ElimEdges.insert(DE); 552 } 553 } 554 } 555 556 // Find and eliminate gadget edges that have been mitigated. 557 int RemainingGadgets = 0; 558 NodeSet ReachableNodes{G}; 559 for (const Node &RootN : G.nodes()) { 560 if (llvm::none_of(RootN.edges(), MachineGadgetGraph::isGadgetEdge)) 561 continue; // skip this node if it isn't a gadget source 562 563 // Find all of the nodes that are CFG-reachable from RootN using DFS 564 ReachableNodes.clear(); 565 std::function<void(const Node *, bool)> FindReachableNodes = 566 [&](const Node *N, bool FirstNode) { 567 if (!FirstNode) 568 ReachableNodes.insert(*N); 569 for (const Edge &E : N->edges()) { 570 const Node *Dest = E.getDest(); 571 if (MachineGadgetGraph::isCFGEdge(E) && !ElimEdges.contains(E) && 572 !ReachableNodes.contains(*Dest)) 573 FindReachableNodes(Dest, false); 574 } 575 }; 576 FindReachableNodes(&RootN, true); 577 578 // Any gadget whose sink is unreachable has been mitigated 579 for (const Edge &E : RootN.edges()) { 580 if (MachineGadgetGraph::isGadgetEdge(E)) { 581 if (ReachableNodes.contains(*E.getDest())) { 582 // This gadget's sink is reachable 583 ++RemainingGadgets; 584 } else { // This gadget's sink is unreachable, and therefore mitigated 585 ElimEdges.insert(E); 586 } 587 } 588 } 589 } 590 return RemainingGadgets; 591 } 592 593 std::unique_ptr<MachineGadgetGraph> 594 X86LoadValueInjectionLoadHardeningPass::trimMitigatedEdges( 595 std::unique_ptr<MachineGadgetGraph> Graph) const { 596 NodeSet ElimNodes{*Graph}; 597 EdgeSet ElimEdges{*Graph}; 598 int RemainingGadgets = 599 elimMitigatedEdgesAndNodes(*Graph, ElimEdges, ElimNodes); 600 if (ElimEdges.empty() && ElimNodes.empty()) { 601 Graph->NumFences = 0; 602 Graph->NumGadgets = RemainingGadgets; 603 } else { 604 Graph = GraphBuilder::trim(*Graph, ElimNodes, ElimEdges, 0 /* NumFences */, 605 RemainingGadgets); 606 } 607 return Graph; 608 } 609 610 int X86LoadValueInjectionLoadHardeningPass::hardenLoadsWithPlugin( 611 MachineFunction &MF, std::unique_ptr<MachineGadgetGraph> Graph) const { 612 int FencesInserted = 0; 613 614 do { 615 LLVM_DEBUG(dbgs() << "Eliminating mitigated paths...\n"); 616 Graph = trimMitigatedEdges(std::move(Graph)); 617 LLVM_DEBUG(dbgs() << "Eliminating mitigated paths... Done\n"); 618 if (Graph->NumGadgets == 0) 619 break; 620 621 LLVM_DEBUG(dbgs() << "Cutting edges...\n"); 622 EdgeSet CutEdges{*Graph}; 623 auto Nodes = std::make_unique<unsigned int[]>(Graph->nodes_size() + 624 1 /* terminator node */); 625 auto Edges = std::make_unique<unsigned int[]>(Graph->edges_size()); 626 auto EdgeCuts = std::make_unique<int[]>(Graph->edges_size()); 627 auto EdgeValues = std::make_unique<int[]>(Graph->edges_size()); 628 for (const Node &N : Graph->nodes()) { 629 Nodes[Graph->getNodeIndex(N)] = Graph->getEdgeIndex(*N.edges_begin()); 630 } 631 Nodes[Graph->nodes_size()] = Graph->edges_size(); // terminator node 632 for (const Edge &E : Graph->edges()) { 633 Edges[Graph->getEdgeIndex(E)] = Graph->getNodeIndex(*E.getDest()); 634 EdgeValues[Graph->getEdgeIndex(E)] = E.getValue(); 635 } 636 OptimizeCut(Nodes.get(), Graph->nodes_size(), Edges.get(), EdgeValues.get(), 637 EdgeCuts.get(), Graph->edges_size()); 638 for (int I = 0; I < Graph->edges_size(); ++I) 639 if (EdgeCuts[I]) 640 CutEdges.set(I); 641 LLVM_DEBUG(dbgs() << "Cutting edges... Done\n"); 642 LLVM_DEBUG(dbgs() << "Cut " << CutEdges.count() << " edges\n"); 643 644 LLVM_DEBUG(dbgs() << "Inserting LFENCEs...\n"); 645 FencesInserted += insertFences(MF, *Graph, CutEdges); 646 LLVM_DEBUG(dbgs() << "Inserting LFENCEs... Done\n"); 647 LLVM_DEBUG(dbgs() << "Inserted " << FencesInserted << " fences\n"); 648 649 Graph = GraphBuilder::trim(*Graph, NodeSet{*Graph}, CutEdges); 650 } while (true); 651 652 return FencesInserted; 653 } 654 655 int X86LoadValueInjectionLoadHardeningPass::hardenLoadsWithHeuristic( 656 MachineFunction &MF, std::unique_ptr<MachineGadgetGraph> Graph) const { 657 // If `MF` does not have any fences, then no gadgets would have been 658 // mitigated at this point. 659 if (Graph->NumFences > 0) { 660 LLVM_DEBUG(dbgs() << "Eliminating mitigated paths...\n"); 661 Graph = trimMitigatedEdges(std::move(Graph)); 662 LLVM_DEBUG(dbgs() << "Eliminating mitigated paths... Done\n"); 663 } 664 665 if (Graph->NumGadgets == 0) 666 return 0; 667 668 LLVM_DEBUG(dbgs() << "Cutting edges...\n"); 669 EdgeSet CutEdges{*Graph}; 670 671 // Begin by collecting all ingress CFG edges for each node 672 DenseMap<const Node *, SmallVector<const Edge *, 2>> IngressEdgeMap; 673 for (const Edge &E : Graph->edges()) 674 if (MachineGadgetGraph::isCFGEdge(E)) 675 IngressEdgeMap[E.getDest()].push_back(&E); 676 677 // For each gadget edge, make cuts that guarantee the gadget will be 678 // mitigated. A computationally efficient way to achieve this is to either: 679 // (a) cut all egress CFG edges from the gadget source, or 680 // (b) cut all ingress CFG edges to the gadget sink. 681 // 682 // Moreover, the algorithm tries not to make a cut into a loop by preferring 683 // to make a (b)-type cut if the gadget source resides at a greater loop depth 684 // than the gadget sink, or an (a)-type cut otherwise. 685 for (const Node &N : Graph->nodes()) { 686 for (const Edge &E : N.edges()) { 687 if (!MachineGadgetGraph::isGadgetEdge(E)) 688 continue; 689 690 SmallVector<const Edge *, 2> EgressEdges; 691 SmallVector<const Edge *, 2> &IngressEdges = IngressEdgeMap[E.getDest()]; 692 for (const Edge &EgressEdge : N.edges()) 693 if (MachineGadgetGraph::isCFGEdge(EgressEdge)) 694 EgressEdges.push_back(&EgressEdge); 695 696 int EgressCutCost = 0, IngressCutCost = 0; 697 for (const Edge *EgressEdge : EgressEdges) 698 if (!CutEdges.contains(*EgressEdge)) 699 EgressCutCost += EgressEdge->getValue(); 700 for (const Edge *IngressEdge : IngressEdges) 701 if (!CutEdges.contains(*IngressEdge)) 702 IngressCutCost += IngressEdge->getValue(); 703 704 auto &EdgesToCut = 705 IngressCutCost < EgressCutCost ? IngressEdges : EgressEdges; 706 for (const Edge *E : EdgesToCut) 707 CutEdges.insert(*E); 708 } 709 } 710 LLVM_DEBUG(dbgs() << "Cutting edges... Done\n"); 711 LLVM_DEBUG(dbgs() << "Cut " << CutEdges.count() << " edges\n"); 712 713 LLVM_DEBUG(dbgs() << "Inserting LFENCEs...\n"); 714 int FencesInserted = insertFences(MF, *Graph, CutEdges); 715 LLVM_DEBUG(dbgs() << "Inserting LFENCEs... Done\n"); 716 LLVM_DEBUG(dbgs() << "Inserted " << FencesInserted << " fences\n"); 717 718 return FencesInserted; 719 } 720 721 int X86LoadValueInjectionLoadHardeningPass::insertFences( 722 MachineFunction &MF, MachineGadgetGraph &G, 723 EdgeSet &CutEdges /* in, out */) const { 724 int FencesInserted = 0; 725 for (const Node &N : G.nodes()) { 726 for (const Edge &E : N.edges()) { 727 if (CutEdges.contains(E)) { 728 MachineInstr *MI = N.getValue(), *Prev; 729 MachineBasicBlock *MBB; // Insert an LFENCE in this MBB 730 MachineBasicBlock::iterator InsertionPt; // ...at this point 731 if (MI == MachineGadgetGraph::ArgNodeSentinel) { 732 // insert LFENCE at beginning of entry block 733 MBB = &MF.front(); 734 InsertionPt = MBB->begin(); 735 Prev = nullptr; 736 } else if (MI->isBranch()) { // insert the LFENCE before the branch 737 MBB = MI->getParent(); 738 InsertionPt = MI; 739 Prev = MI->getPrevNode(); 740 // Remove all egress CFG edges from this branch because the inserted 741 // LFENCE prevents gadgets from crossing the branch. 742 for (const Edge &E : N.edges()) { 743 if (MachineGadgetGraph::isCFGEdge(E)) 744 CutEdges.insert(E); 745 } 746 } else { // insert the LFENCE after the instruction 747 MBB = MI->getParent(); 748 InsertionPt = MI->getNextNode() ? MI->getNextNode() : MBB->end(); 749 Prev = InsertionPt == MBB->end() 750 ? (MBB->empty() ? nullptr : &MBB->back()) 751 : InsertionPt->getPrevNode(); 752 } 753 // Ensure this insertion is not redundant (two LFENCEs in sequence). 754 if ((InsertionPt == MBB->end() || !isFence(&*InsertionPt)) && 755 (!Prev || !isFence(Prev))) { 756 BuildMI(*MBB, InsertionPt, DebugLoc(), TII->get(X86::LFENCE)); 757 ++FencesInserted; 758 } 759 } 760 } 761 } 762 return FencesInserted; 763 } 764 765 bool X86LoadValueInjectionLoadHardeningPass::instrUsesRegToAccessMemory( 766 const MachineInstr &MI, unsigned Reg) const { 767 if (!MI.mayLoadOrStore() || MI.getOpcode() == X86::MFENCE || 768 MI.getOpcode() == X86::SFENCE || MI.getOpcode() == X86::LFENCE) 769 return false; 770 771 const int MemRefBeginIdx = X86::getFirstAddrOperandIdx(MI); 772 if (MemRefBeginIdx < 0) { 773 LLVM_DEBUG(dbgs() << "Warning: unable to obtain memory operand for loading " 774 "instruction:\n"; 775 MI.print(dbgs()); dbgs() << '\n';); 776 return false; 777 } 778 779 const MachineOperand &BaseMO = 780 MI.getOperand(MemRefBeginIdx + X86::AddrBaseReg); 781 const MachineOperand &IndexMO = 782 MI.getOperand(MemRefBeginIdx + X86::AddrIndexReg); 783 return (BaseMO.isReg() && BaseMO.getReg() != X86::NoRegister && 784 TRI->regsOverlap(BaseMO.getReg(), Reg)) || 785 (IndexMO.isReg() && IndexMO.getReg() != X86::NoRegister && 786 TRI->regsOverlap(IndexMO.getReg(), Reg)); 787 } 788 789 bool X86LoadValueInjectionLoadHardeningPass::instrUsesRegToBranch( 790 const MachineInstr &MI, unsigned Reg) const { 791 if (!MI.isConditionalBranch()) 792 return false; 793 for (const MachineOperand &Use : MI.uses()) 794 if (Use.isReg() && Use.getReg() == Reg) 795 return true; 796 return false; 797 } 798 799 INITIALIZE_PASS_BEGIN(X86LoadValueInjectionLoadHardeningPass, PASS_KEY, 800 "X86 LVI load hardening", false, false) 801 INITIALIZE_PASS_DEPENDENCY(MachineLoopInfoWrapperPass) 802 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass) 803 INITIALIZE_PASS_DEPENDENCY(MachineDominanceFrontier) 804 INITIALIZE_PASS_END(X86LoadValueInjectionLoadHardeningPass, PASS_KEY, 805 "X86 LVI load hardening", false, false) 806 807 FunctionPass *llvm::createX86LoadValueInjectionLoadHardeningPass() { 808 return new X86LoadValueInjectionLoadHardeningPass(); 809 } 810