xref: /llvm-project/llvm/include/llvm/Transforms/Scalar/GVNExpression.h (revision 984bca9d1faaa1fa5c694f8f2a5524b2374d204a)
1 //===- GVNExpression.h - GVN Expression classes -----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 ///
11 /// The header file for the GVN pass that contains expression handling
12 /// classes
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #ifndef LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
17 #define LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
18 
19 #include "llvm/ADT/Hashing.h"
20 #include "llvm/ADT/iterator_range.h"
21 #include "llvm/Analysis/MemorySSA.h"
22 #include "llvm/IR/Constant.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/Value.h"
25 #include "llvm/Support/Allocator.h"
26 #include "llvm/Support/ArrayRecycler.h"
27 #include "llvm/Support/Casting.h"
28 #include "llvm/Support/Compiler.h"
29 #include "llvm/Support/raw_ostream.h"
30 #include <algorithm>
31 #include <cassert>
32 #include <iterator>
33 #include <utility>
34 
35 namespace llvm {
36 
37 class BasicBlock;
38 class Type;
39 
40 namespace GVNExpression {
41 
/// Discriminator for the Expression class hierarchy, consumed by the
/// LLVM-style RTTI predicates (classof()) of the classes below.
///
/// Enumerator order matters: BasicExpression::classof() accepts every kind
/// strictly between ET_BasicStart and ET_BasicEnd, and
/// MemoryExpression::classof() accepts every kind strictly between
/// ET_MemoryStart and ET_MemoryEnd. A new kind must be placed inside the
/// range(s) of every class it derives from.
enum ExpressionType {
  // Kinds deriving directly from Expression.
  ET_Base,
  ET_Constant,
  ET_Variable,
  ET_Dead,
  ET_Unknown,

  // BasicExpression and its subclasses: (ET_BasicStart, ET_BasicEnd).
  ET_BasicStart,
  ET_Basic,
  ET_AggregateValue,
  ET_Phi,

  // MemoryExpression subclasses: (ET_MemoryStart, ET_MemoryEnd), nested
  // inside the basic range above.
  ET_MemoryStart,
  ET_Call,
  ET_Load,
  ET_Store,
  ET_MemoryEnd,

  ET_BasicEnd,
};
59 
60 class Expression {
61 private:
62   ExpressionType EType;
63   unsigned Opcode;
64   mutable hash_code HashVal = 0;
65 
66 public:
67   Expression(ExpressionType ET = ET_Base, unsigned O = ~2U)
68       : EType(ET), Opcode(O) {}
69   Expression(const Expression &) = delete;
70   Expression &operator=(const Expression &) = delete;
71   virtual ~Expression();
72 
73   static unsigned getEmptyKey() { return ~0U; }
74   static unsigned getTombstoneKey() { return ~1U; }
75 
76   bool operator!=(const Expression &Other) const { return !(*this == Other); }
77   bool operator==(const Expression &Other) const {
78     if (getOpcode() != Other.getOpcode())
79       return false;
80     if (getOpcode() == getEmptyKey() || getOpcode() == getTombstoneKey())
81       return true;
82     // Compare the expression type for anything but load and store.
83     // For load and store we set the opcode to zero to make them equal.
84     if (getExpressionType() != ET_Load && getExpressionType() != ET_Store &&
85         getExpressionType() != Other.getExpressionType())
86       return false;
87 
88     return equals(Other);
89   }
90 
91   hash_code getComputedHash() const {
92     // It's theoretically possible for a thing to hash to zero.  In that case,
93     // we will just compute the hash a few extra times, which is no worse that
94     // we did before, which was to compute it always.
95     if (static_cast<unsigned>(HashVal) == 0)
96       HashVal = getHashValue();
97     return HashVal;
98   }
99 
100   virtual bool equals(const Expression &Other) const { return true; }
101 
102   // Return true if the two expressions are exactly the same, including the
103   // normally ignored fields.
104   virtual bool exactlyEquals(const Expression &Other) const {
105     return getExpressionType() == Other.getExpressionType() && equals(Other);
106   }
107 
108   unsigned getOpcode() const { return Opcode; }
109   void setOpcode(unsigned opcode) { Opcode = opcode; }
110   ExpressionType getExpressionType() const { return EType; }
111 
112   // We deliberately leave the expression type out of the hash value.
113   virtual hash_code getHashValue() const { return getOpcode(); }
114 
115   // Debugging support
116   virtual void printInternal(raw_ostream &OS, bool PrintEType) const {
117     if (PrintEType)
118       OS << "etype = " << getExpressionType() << ",";
119     OS << "opcode = " << getOpcode() << ", ";
120   }
121 
122   void print(raw_ostream &OS) const {
123     OS << "{ ";
124     printInternal(OS, true);
125     OS << "}";
126   }
127 
128   LLVM_DUMP_METHOD void dump() const;
129 };
130 
131 inline raw_ostream &operator<<(raw_ostream &OS, const Expression &E) {
132   E.print(OS);
133   return OS;
134 }
135 
136 class BasicExpression : public Expression {
137 private:
138   using RecyclerType = ArrayRecycler<Value *>;
139   using RecyclerCapacity = RecyclerType::Capacity;
140 
141   Value **Operands = nullptr;
142   unsigned MaxOperands;
143   unsigned NumOperands = 0;
144   Type *ValueType = nullptr;
145 
146 public:
147   BasicExpression(unsigned NumOperands)
148       : BasicExpression(NumOperands, ET_Basic) {}
149   BasicExpression(unsigned NumOperands, ExpressionType ET)
150       : Expression(ET), MaxOperands(NumOperands) {}
151   BasicExpression() = delete;
152   BasicExpression(const BasicExpression &) = delete;
153   BasicExpression &operator=(const BasicExpression &) = delete;
154   ~BasicExpression() override;
155 
156   static bool classof(const Expression *EB) {
157     ExpressionType ET = EB->getExpressionType();
158     return ET > ET_BasicStart && ET < ET_BasicEnd;
159   }
160 
161   /// Swap two operands. Used during GVN to put commutative operands in
162   /// order.
163   void swapOperands(unsigned First, unsigned Second) {
164     std::swap(Operands[First], Operands[Second]);
165   }
166 
167   Value *getOperand(unsigned N) const {
168     assert(Operands && "Operands not allocated");
169     assert(N < NumOperands && "Operand out of range");
170     return Operands[N];
171   }
172 
173   void setOperand(unsigned N, Value *V) {
174     assert(Operands && "Operands not allocated before setting");
175     assert(N < NumOperands && "Operand out of range");
176     Operands[N] = V;
177   }
178 
179   unsigned getNumOperands() const { return NumOperands; }
180 
181   using op_iterator = Value **;
182   using const_op_iterator = Value *const *;
183 
184   op_iterator op_begin() { return Operands; }
185   op_iterator op_end() { return Operands + NumOperands; }
186   const_op_iterator op_begin() const { return Operands; }
187   const_op_iterator op_end() const { return Operands + NumOperands; }
188   iterator_range<op_iterator> operands() {
189     return iterator_range<op_iterator>(op_begin(), op_end());
190   }
191   iterator_range<const_op_iterator> operands() const {
192     return iterator_range<const_op_iterator>(op_begin(), op_end());
193   }
194 
195   void op_push_back(Value *Arg) {
196     assert(NumOperands < MaxOperands && "Tried to add too many operands");
197     assert(Operands && "Operandss not allocated before pushing");
198     Operands[NumOperands++] = Arg;
199   }
200   bool op_empty() const { return getNumOperands() == 0; }
201 
202   void allocateOperands(RecyclerType &Recycler, BumpPtrAllocator &Allocator) {
203     assert(!Operands && "Operands already allocated");
204     Operands = Recycler.allocate(RecyclerCapacity::get(MaxOperands), Allocator);
205   }
206   void deallocateOperands(RecyclerType &Recycler) {
207     Recycler.deallocate(RecyclerCapacity::get(MaxOperands), Operands);
208   }
209 
210   void setType(Type *T) { ValueType = T; }
211   Type *getType() const { return ValueType; }
212 
213   bool equals(const Expression &Other) const override {
214     if (getOpcode() != Other.getOpcode())
215       return false;
216 
217     const auto &OE = cast<BasicExpression>(Other);
218     return getType() == OE.getType() && NumOperands == OE.NumOperands &&
219            std::equal(op_begin(), op_end(), OE.op_begin());
220   }
221 
222   hash_code getHashValue() const override {
223     return hash_combine(this->Expression::getHashValue(), ValueType,
224                         hash_combine_range(op_begin(), op_end()));
225   }
226 
227   // Debugging support
228   void printInternal(raw_ostream &OS, bool PrintEType) const override {
229     if (PrintEType)
230       OS << "ExpressionTypeBasic, ";
231 
232     this->Expression::printInternal(OS, false);
233     OS << "operands = {";
234     for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
235       OS << "[" << i << "] = ";
236       Operands[i]->printAsOperand(OS);
237       OS << "  ";
238     }
239     OS << "} ";
240   }
241 };
242 
243 class op_inserter {
244 private:
245   using Container = BasicExpression;
246 
247   Container *BE;
248 
249 public:
250   using iterator_category = std::output_iterator_tag;
251   using value_type = void;
252   using difference_type = void;
253   using pointer = void;
254   using reference = void;
255 
256   explicit op_inserter(BasicExpression &E) : BE(&E) {}
257   explicit op_inserter(BasicExpression *E) : BE(E) {}
258 
259   op_inserter &operator=(Value *val) {
260     BE->op_push_back(val);
261     return *this;
262   }
263   op_inserter &operator*() { return *this; }
264   op_inserter &operator++() { return *this; }
265   op_inserter &operator++(int) { return *this; }
266 };
267 
268 class MemoryExpression : public BasicExpression {
269 private:
270   const MemoryAccess *MemoryLeader;
271 
272 public:
273   MemoryExpression(unsigned NumOperands, enum ExpressionType EType,
274                    const MemoryAccess *MemoryLeader)
275       : BasicExpression(NumOperands, EType), MemoryLeader(MemoryLeader) {}
276   MemoryExpression() = delete;
277   MemoryExpression(const MemoryExpression &) = delete;
278   MemoryExpression &operator=(const MemoryExpression &) = delete;
279 
280   static bool classof(const Expression *EB) {
281     return EB->getExpressionType() > ET_MemoryStart &&
282            EB->getExpressionType() < ET_MemoryEnd;
283   }
284 
285   hash_code getHashValue() const override {
286     return hash_combine(this->BasicExpression::getHashValue(), MemoryLeader);
287   }
288 
289   bool equals(const Expression &Other) const override {
290     if (!this->BasicExpression::equals(Other))
291       return false;
292     const MemoryExpression &OtherMCE = cast<MemoryExpression>(Other);
293 
294     return MemoryLeader == OtherMCE.MemoryLeader;
295   }
296 
297   const MemoryAccess *getMemoryLeader() const { return MemoryLeader; }
298   void setMemoryLeader(const MemoryAccess *ML) { MemoryLeader = ML; }
299 };
300 
301 class CallExpression final : public MemoryExpression {
302 private:
303   CallInst *Call;
304 
305 public:
306   CallExpression(unsigned NumOperands, CallInst *C,
307                  const MemoryAccess *MemoryLeader)
308       : MemoryExpression(NumOperands, ET_Call, MemoryLeader), Call(C) {}
309   CallExpression() = delete;
310   CallExpression(const CallExpression &) = delete;
311   CallExpression &operator=(const CallExpression &) = delete;
312   ~CallExpression() override;
313 
314   static bool classof(const Expression *EB) {
315     return EB->getExpressionType() == ET_Call;
316   }
317 
318   bool equals(const Expression &Other) const override;
319   bool exactlyEquals(const Expression &Other) const override {
320     return Expression::exactlyEquals(Other) &&
321            cast<CallExpression>(Other).Call == Call;
322   }
323 
324   // Debugging support
325   void printInternal(raw_ostream &OS, bool PrintEType) const override {
326     if (PrintEType)
327       OS << "ExpressionTypeCall, ";
328     this->BasicExpression::printInternal(OS, false);
329     OS << " represents call at ";
330     Call->printAsOperand(OS);
331   }
332 };
333 
334 class LoadExpression final : public MemoryExpression {
335 private:
336   LoadInst *Load;
337 
338 public:
339   LoadExpression(unsigned NumOperands, LoadInst *L,
340                  const MemoryAccess *MemoryLeader)
341       : LoadExpression(ET_Load, NumOperands, L, MemoryLeader) {}
342 
343   LoadExpression(enum ExpressionType EType, unsigned NumOperands, LoadInst *L,
344                  const MemoryAccess *MemoryLeader)
345       : MemoryExpression(NumOperands, EType, MemoryLeader), Load(L) {}
346 
347   LoadExpression() = delete;
348   LoadExpression(const LoadExpression &) = delete;
349   LoadExpression &operator=(const LoadExpression &) = delete;
350   ~LoadExpression() override;
351 
352   static bool classof(const Expression *EB) {
353     return EB->getExpressionType() == ET_Load;
354   }
355 
356   LoadInst *getLoadInst() const { return Load; }
357   void setLoadInst(LoadInst *L) { Load = L; }
358 
359   bool equals(const Expression &Other) const override;
360   bool exactlyEquals(const Expression &Other) const override {
361     return Expression::exactlyEquals(Other) &&
362            cast<LoadExpression>(Other).getLoadInst() == getLoadInst();
363   }
364 
365   // Debugging support
366   void printInternal(raw_ostream &OS, bool PrintEType) const override {
367     if (PrintEType)
368       OS << "ExpressionTypeLoad, ";
369     this->BasicExpression::printInternal(OS, false);
370     OS << " represents Load at ";
371     Load->printAsOperand(OS);
372     OS << " with MemoryLeader " << *getMemoryLeader();
373   }
374 };
375 
376 class StoreExpression final : public MemoryExpression {
377 private:
378   StoreInst *Store;
379   Value *StoredValue;
380 
381 public:
382   StoreExpression(unsigned NumOperands, StoreInst *S, Value *StoredValue,
383                   const MemoryAccess *MemoryLeader)
384       : MemoryExpression(NumOperands, ET_Store, MemoryLeader), Store(S),
385         StoredValue(StoredValue) {}
386   StoreExpression() = delete;
387   StoreExpression(const StoreExpression &) = delete;
388   StoreExpression &operator=(const StoreExpression &) = delete;
389   ~StoreExpression() override;
390 
391   static bool classof(const Expression *EB) {
392     return EB->getExpressionType() == ET_Store;
393   }
394 
395   StoreInst *getStoreInst() const { return Store; }
396   Value *getStoredValue() const { return StoredValue; }
397 
398   bool equals(const Expression &Other) const override;
399 
400   bool exactlyEquals(const Expression &Other) const override {
401     return Expression::exactlyEquals(Other) &&
402            cast<StoreExpression>(Other).getStoreInst() == getStoreInst();
403   }
404 
405   // Debugging support
406   void printInternal(raw_ostream &OS, bool PrintEType) const override {
407     if (PrintEType)
408       OS << "ExpressionTypeStore, ";
409     this->BasicExpression::printInternal(OS, false);
410     OS << " represents Store  " << *Store;
411     OS << " with StoredValue ";
412     StoredValue->printAsOperand(OS);
413     OS << " and MemoryLeader " << *getMemoryLeader();
414   }
415 };
416 
417 class AggregateValueExpression final : public BasicExpression {
418 private:
419   unsigned MaxIntOperands;
420   unsigned NumIntOperands = 0;
421   unsigned *IntOperands = nullptr;
422 
423 public:
424   AggregateValueExpression(unsigned NumOperands, unsigned NumIntOperands)
425       : BasicExpression(NumOperands, ET_AggregateValue),
426         MaxIntOperands(NumIntOperands) {}
427   AggregateValueExpression() = delete;
428   AggregateValueExpression(const AggregateValueExpression &) = delete;
429   AggregateValueExpression &
430   operator=(const AggregateValueExpression &) = delete;
431   ~AggregateValueExpression() override;
432 
433   static bool classof(const Expression *EB) {
434     return EB->getExpressionType() == ET_AggregateValue;
435   }
436 
437   using int_arg_iterator = unsigned *;
438   using const_int_arg_iterator = const unsigned *;
439 
440   int_arg_iterator int_op_begin() { return IntOperands; }
441   int_arg_iterator int_op_end() { return IntOperands + NumIntOperands; }
442   const_int_arg_iterator int_op_begin() const { return IntOperands; }
443   const_int_arg_iterator int_op_end() const {
444     return IntOperands + NumIntOperands;
445   }
446   unsigned int_op_size() const { return NumIntOperands; }
447   bool int_op_empty() const { return NumIntOperands == 0; }
448   void int_op_push_back(unsigned IntOperand) {
449     assert(NumIntOperands < MaxIntOperands &&
450            "Tried to add too many int operands");
451     assert(IntOperands && "Operands not allocated before pushing");
452     IntOperands[NumIntOperands++] = IntOperand;
453   }
454 
455   virtual void allocateIntOperands(BumpPtrAllocator &Allocator) {
456     assert(!IntOperands && "Operands already allocated");
457     IntOperands = Allocator.Allocate<unsigned>(MaxIntOperands);
458   }
459 
460   bool equals(const Expression &Other) const override {
461     if (!this->BasicExpression::equals(Other))
462       return false;
463     const AggregateValueExpression &OE = cast<AggregateValueExpression>(Other);
464     return NumIntOperands == OE.NumIntOperands &&
465            std::equal(int_op_begin(), int_op_end(), OE.int_op_begin());
466   }
467 
468   hash_code getHashValue() const override {
469     return hash_combine(this->BasicExpression::getHashValue(),
470                         hash_combine_range(int_op_begin(), int_op_end()));
471   }
472 
473   // Debugging support
474   void printInternal(raw_ostream &OS, bool PrintEType) const override {
475     if (PrintEType)
476       OS << "ExpressionTypeAggregateValue, ";
477     this->BasicExpression::printInternal(OS, false);
478     OS << ", intoperands = {";
479     for (unsigned i = 0, e = int_op_size(); i != e; ++i) {
480       OS << "[" << i << "] = " << IntOperands[i] << "  ";
481     }
482     OS << "}";
483   }
484 };
485 
486 class int_op_inserter {
487 private:
488   using Container = AggregateValueExpression;
489 
490   Container *AVE;
491 
492 public:
493   using iterator_category = std::output_iterator_tag;
494   using value_type = void;
495   using difference_type = void;
496   using pointer = void;
497   using reference = void;
498 
499   explicit int_op_inserter(AggregateValueExpression &E) : AVE(&E) {}
500   explicit int_op_inserter(AggregateValueExpression *E) : AVE(E) {}
501 
502   int_op_inserter &operator=(unsigned int val) {
503     AVE->int_op_push_back(val);
504     return *this;
505   }
506   int_op_inserter &operator*() { return *this; }
507   int_op_inserter &operator++() { return *this; }
508   int_op_inserter &operator++(int) { return *this; }
509 };
510 
511 class PHIExpression final : public BasicExpression {
512 private:
513   BasicBlock *BB;
514 
515 public:
516   PHIExpression(unsigned NumOperands, BasicBlock *B)
517       : BasicExpression(NumOperands, ET_Phi), BB(B) {}
518   PHIExpression() = delete;
519   PHIExpression(const PHIExpression &) = delete;
520   PHIExpression &operator=(const PHIExpression &) = delete;
521   ~PHIExpression() override;
522 
523   static bool classof(const Expression *EB) {
524     return EB->getExpressionType() == ET_Phi;
525   }
526 
527   bool equals(const Expression &Other) const override {
528     if (!this->BasicExpression::equals(Other))
529       return false;
530     const PHIExpression &OE = cast<PHIExpression>(Other);
531     return BB == OE.BB;
532   }
533 
534   hash_code getHashValue() const override {
535     return hash_combine(this->BasicExpression::getHashValue(), BB);
536   }
537 
538   // Debugging support
539   void printInternal(raw_ostream &OS, bool PrintEType) const override {
540     if (PrintEType)
541       OS << "ExpressionTypePhi, ";
542     this->BasicExpression::printInternal(OS, false);
543     OS << "bb = " << BB;
544   }
545 };
546 
547 class DeadExpression final : public Expression {
548 public:
549   DeadExpression() : Expression(ET_Dead) {}
550   DeadExpression(const DeadExpression &) = delete;
551   DeadExpression &operator=(const DeadExpression &) = delete;
552 
553   static bool classof(const Expression *E) {
554     return E->getExpressionType() == ET_Dead;
555   }
556 };
557 
558 class VariableExpression final : public Expression {
559 private:
560   Value *VariableValue;
561 
562 public:
563   VariableExpression(Value *V) : Expression(ET_Variable), VariableValue(V) {}
564   VariableExpression() = delete;
565   VariableExpression(const VariableExpression &) = delete;
566   VariableExpression &operator=(const VariableExpression &) = delete;
567 
568   static bool classof(const Expression *EB) {
569     return EB->getExpressionType() == ET_Variable;
570   }
571 
572   Value *getVariableValue() const { return VariableValue; }
573   void setVariableValue(Value *V) { VariableValue = V; }
574 
575   bool equals(const Expression &Other) const override {
576     const VariableExpression &OC = cast<VariableExpression>(Other);
577     return VariableValue == OC.VariableValue;
578   }
579 
580   hash_code getHashValue() const override {
581     return hash_combine(this->Expression::getHashValue(),
582                         VariableValue->getType(), VariableValue);
583   }
584 
585   // Debugging support
586   void printInternal(raw_ostream &OS, bool PrintEType) const override {
587     if (PrintEType)
588       OS << "ExpressionTypeVariable, ";
589     this->Expression::printInternal(OS, false);
590     OS << " variable = " << *VariableValue;
591   }
592 };
593 
594 class ConstantExpression final : public Expression {
595 private:
596   Constant *ConstantValue = nullptr;
597 
598 public:
599   ConstantExpression() : Expression(ET_Constant) {}
600   ConstantExpression(Constant *constantValue)
601       : Expression(ET_Constant), ConstantValue(constantValue) {}
602   ConstantExpression(const ConstantExpression &) = delete;
603   ConstantExpression &operator=(const ConstantExpression &) = delete;
604 
605   static bool classof(const Expression *EB) {
606     return EB->getExpressionType() == ET_Constant;
607   }
608 
609   Constant *getConstantValue() const { return ConstantValue; }
610   void setConstantValue(Constant *V) { ConstantValue = V; }
611 
612   bool equals(const Expression &Other) const override {
613     const ConstantExpression &OC = cast<ConstantExpression>(Other);
614     return ConstantValue == OC.ConstantValue;
615   }
616 
617   hash_code getHashValue() const override {
618     return hash_combine(this->Expression::getHashValue(),
619                         ConstantValue->getType(), ConstantValue);
620   }
621 
622   // Debugging support
623   void printInternal(raw_ostream &OS, bool PrintEType) const override {
624     if (PrintEType)
625       OS << "ExpressionTypeConstant, ";
626     this->Expression::printInternal(OS, false);
627     OS << " constant = " << *ConstantValue;
628   }
629 };
630 
631 class UnknownExpression final : public Expression {
632 private:
633   Instruction *Inst;
634 
635 public:
636   UnknownExpression(Instruction *I) : Expression(ET_Unknown), Inst(I) {}
637   UnknownExpression() = delete;
638   UnknownExpression(const UnknownExpression &) = delete;
639   UnknownExpression &operator=(const UnknownExpression &) = delete;
640 
641   static bool classof(const Expression *EB) {
642     return EB->getExpressionType() == ET_Unknown;
643   }
644 
645   Instruction *getInstruction() const { return Inst; }
646   void setInstruction(Instruction *I) { Inst = I; }
647 
648   bool equals(const Expression &Other) const override {
649     const auto &OU = cast<UnknownExpression>(Other);
650     return Inst == OU.Inst;
651   }
652 
653   hash_code getHashValue() const override {
654     return hash_combine(this->Expression::getHashValue(), Inst);
655   }
656 
657   // Debugging support
658   void printInternal(raw_ostream &OS, bool PrintEType) const override {
659     if (PrintEType)
660       OS << "ExpressionTypeUnknown, ";
661     this->Expression::printInternal(OS, false);
662     OS << " inst = " << *Inst;
663   }
664 };
665 
666 } // end namespace GVNExpression
667 
668 } // end namespace llvm
669 
670 #endif // LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
671