xref: /freebsd-src/contrib/llvm-project/llvm/include/llvm/Transforms/Scalar/GVNExpression.h (revision 0b57cec536236d46e3dba9bd041533462f33dbb7)
1 //===- GVNExpression.h - GVN Expression classes -----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 ///
11 /// The header file for the GVN pass that contains expression handling
12 /// classes
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #ifndef LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
17 #define LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
18 
19 #include "llvm/ADT/Hashing.h"
20 #include "llvm/ADT/iterator_range.h"
21 #include "llvm/Analysis/MemorySSA.h"
22 #include "llvm/IR/Constant.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/Value.h"
25 #include "llvm/Support/Allocator.h"
26 #include "llvm/Support/ArrayRecycler.h"
27 #include "llvm/Support/Casting.h"
28 #include "llvm/Support/Compiler.h"
29 #include "llvm/Support/raw_ostream.h"
30 #include <algorithm>
31 #include <cassert>
32 #include <iterator>
33 #include <utility>
34 
35 namespace llvm {
36 
37 class BasicBlock;
38 class Type;
39 
40 namespace GVNExpression {
41 
/// Discriminator for the GVN expression class hierarchy.
///
/// The *Start / *End enumerators are range markers: the classof() hooks of
/// BasicExpression and MemoryExpression test membership with strict
/// comparisons against them, so the markers themselves are not the type of
/// any class declared in this header.
enum ExpressionType {
  ET_Base,
  ET_Constant,
  ET_Variable,
  ET_Dead,
  ET_Unknown,
  // Everything strictly between ET_BasicStart and ET_BasicEnd is a
  // BasicExpression.
  ET_BasicStart,
  ET_Basic,
  ET_AggregateValue,
  ET_Phi,
  // Everything strictly between ET_MemoryStart and ET_MemoryEnd is a
  // MemoryExpression (and therefore also a BasicExpression).
  ET_MemoryStart,
  ET_Call,
  ET_Load,
  ET_Store,
  ET_MemoryEnd,
  ET_BasicEnd,
};
59 
60 class Expression {
61 private:
62   ExpressionType EType;
63   unsigned Opcode;
64   mutable hash_code HashVal = 0;
65 
66 public:
67   Expression(ExpressionType ET = ET_Base, unsigned O = ~2U)
68       : EType(ET), Opcode(O) {}
69   Expression(const Expression &) = delete;
70   Expression &operator=(const Expression &) = delete;
71   virtual ~Expression();
72 
73   static unsigned getEmptyKey() { return ~0U; }
74   static unsigned getTombstoneKey() { return ~1U; }
75 
76   bool operator!=(const Expression &Other) const { return !(*this == Other); }
77   bool operator==(const Expression &Other) const {
78     if (getOpcode() != Other.getOpcode())
79       return false;
80     if (getOpcode() == getEmptyKey() || getOpcode() == getTombstoneKey())
81       return true;
82     // Compare the expression type for anything but load and store.
83     // For load and store we set the opcode to zero to make them equal.
84     if (getExpressionType() != ET_Load && getExpressionType() != ET_Store &&
85         getExpressionType() != Other.getExpressionType())
86       return false;
87 
88     return equals(Other);
89   }
90 
91   hash_code getComputedHash() const {
92     // It's theoretically possible for a thing to hash to zero.  In that case,
93     // we will just compute the hash a few extra times, which is no worse that
94     // we did before, which was to compute it always.
95     if (static_cast<unsigned>(HashVal) == 0)
96       HashVal = getHashValue();
97     return HashVal;
98   }
99 
100   virtual bool equals(const Expression &Other) const { return true; }
101 
102   // Return true if the two expressions are exactly the same, including the
103   // normally ignored fields.
104   virtual bool exactlyEquals(const Expression &Other) const {
105     return getExpressionType() == Other.getExpressionType() && equals(Other);
106   }
107 
108   unsigned getOpcode() const { return Opcode; }
109   void setOpcode(unsigned opcode) { Opcode = opcode; }
110   ExpressionType getExpressionType() const { return EType; }
111 
112   // We deliberately leave the expression type out of the hash value.
113   virtual hash_code getHashValue() const { return getOpcode(); }
114 
115   // Debugging support
116   virtual void printInternal(raw_ostream &OS, bool PrintEType) const {
117     if (PrintEType)
118       OS << "etype = " << getExpressionType() << ",";
119     OS << "opcode = " << getOpcode() << ", ";
120   }
121 
122   void print(raw_ostream &OS) const {
123     OS << "{ ";
124     printInternal(OS, true);
125     OS << "}";
126   }
127 
128   LLVM_DUMP_METHOD void dump() const;
129 };
130 
131 inline raw_ostream &operator<<(raw_ostream &OS, const Expression &E) {
132   E.print(OS);
133   return OS;
134 }
135 
136 class BasicExpression : public Expression {
137 private:
138   using RecyclerType = ArrayRecycler<Value *>;
139   using RecyclerCapacity = RecyclerType::Capacity;
140 
141   Value **Operands = nullptr;
142   unsigned MaxOperands;
143   unsigned NumOperands = 0;
144   Type *ValueType = nullptr;
145 
146 public:
147   BasicExpression(unsigned NumOperands)
148       : BasicExpression(NumOperands, ET_Basic) {}
149   BasicExpression(unsigned NumOperands, ExpressionType ET)
150       : Expression(ET), MaxOperands(NumOperands) {}
151   BasicExpression() = delete;
152   BasicExpression(const BasicExpression &) = delete;
153   BasicExpression &operator=(const BasicExpression &) = delete;
154   ~BasicExpression() override;
155 
156   static bool classof(const Expression *EB) {
157     ExpressionType ET = EB->getExpressionType();
158     return ET > ET_BasicStart && ET < ET_BasicEnd;
159   }
160 
161   /// Swap two operands. Used during GVN to put commutative operands in
162   /// order.
163   void swapOperands(unsigned First, unsigned Second) {
164     std::swap(Operands[First], Operands[Second]);
165   }
166 
167   Value *getOperand(unsigned N) const {
168     assert(Operands && "Operands not allocated");
169     assert(N < NumOperands && "Operand out of range");
170     return Operands[N];
171   }
172 
173   void setOperand(unsigned N, Value *V) {
174     assert(Operands && "Operands not allocated before setting");
175     assert(N < NumOperands && "Operand out of range");
176     Operands[N] = V;
177   }
178 
179   unsigned getNumOperands() const { return NumOperands; }
180 
181   using op_iterator = Value **;
182   using const_op_iterator = Value *const *;
183 
184   op_iterator op_begin() { return Operands; }
185   op_iterator op_end() { return Operands + NumOperands; }
186   const_op_iterator op_begin() const { return Operands; }
187   const_op_iterator op_end() const { return Operands + NumOperands; }
188   iterator_range<op_iterator> operands() {
189     return iterator_range<op_iterator>(op_begin(), op_end());
190   }
191   iterator_range<const_op_iterator> operands() const {
192     return iterator_range<const_op_iterator>(op_begin(), op_end());
193   }
194 
195   void op_push_back(Value *Arg) {
196     assert(NumOperands < MaxOperands && "Tried to add too many operands");
197     assert(Operands && "Operandss not allocated before pushing");
198     Operands[NumOperands++] = Arg;
199   }
200   bool op_empty() const { return getNumOperands() == 0; }
201 
202   void allocateOperands(RecyclerType &Recycler, BumpPtrAllocator &Allocator) {
203     assert(!Operands && "Operands already allocated");
204     Operands = Recycler.allocate(RecyclerCapacity::get(MaxOperands), Allocator);
205   }
206   void deallocateOperands(RecyclerType &Recycler) {
207     Recycler.deallocate(RecyclerCapacity::get(MaxOperands), Operands);
208   }
209 
210   void setType(Type *T) { ValueType = T; }
211   Type *getType() const { return ValueType; }
212 
213   bool equals(const Expression &Other) const override {
214     if (getOpcode() != Other.getOpcode())
215       return false;
216 
217     const auto &OE = cast<BasicExpression>(Other);
218     return getType() == OE.getType() && NumOperands == OE.NumOperands &&
219            std::equal(op_begin(), op_end(), OE.op_begin());
220   }
221 
222   hash_code getHashValue() const override {
223     return hash_combine(this->Expression::getHashValue(), ValueType,
224                         hash_combine_range(op_begin(), op_end()));
225   }
226 
227   // Debugging support
228   void printInternal(raw_ostream &OS, bool PrintEType) const override {
229     if (PrintEType)
230       OS << "ExpressionTypeBasic, ";
231 
232     this->Expression::printInternal(OS, false);
233     OS << "operands = {";
234     for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
235       OS << "[" << i << "] = ";
236       Operands[i]->printAsOperand(OS);
237       OS << "  ";
238     }
239     OS << "} ";
240   }
241 };
242 
243 class op_inserter
244     : public std::iterator<std::output_iterator_tag, void, void, void, void> {
245 private:
246   using Container = BasicExpression;
247 
248   Container *BE;
249 
250 public:
251   explicit op_inserter(BasicExpression &E) : BE(&E) {}
252   explicit op_inserter(BasicExpression *E) : BE(E) {}
253 
254   op_inserter &operator=(Value *val) {
255     BE->op_push_back(val);
256     return *this;
257   }
258   op_inserter &operator*() { return *this; }
259   op_inserter &operator++() { return *this; }
260   op_inserter &operator++(int) { return *this; }
261 };
262 
263 class MemoryExpression : public BasicExpression {
264 private:
265   const MemoryAccess *MemoryLeader;
266 
267 public:
268   MemoryExpression(unsigned NumOperands, enum ExpressionType EType,
269                    const MemoryAccess *MemoryLeader)
270       : BasicExpression(NumOperands, EType), MemoryLeader(MemoryLeader) {}
271   MemoryExpression() = delete;
272   MemoryExpression(const MemoryExpression &) = delete;
273   MemoryExpression &operator=(const MemoryExpression &) = delete;
274 
275   static bool classof(const Expression *EB) {
276     return EB->getExpressionType() > ET_MemoryStart &&
277            EB->getExpressionType() < ET_MemoryEnd;
278   }
279 
280   hash_code getHashValue() const override {
281     return hash_combine(this->BasicExpression::getHashValue(), MemoryLeader);
282   }
283 
284   bool equals(const Expression &Other) const override {
285     if (!this->BasicExpression::equals(Other))
286       return false;
287     const MemoryExpression &OtherMCE = cast<MemoryExpression>(Other);
288 
289     return MemoryLeader == OtherMCE.MemoryLeader;
290   }
291 
292   const MemoryAccess *getMemoryLeader() const { return MemoryLeader; }
293   void setMemoryLeader(const MemoryAccess *ML) { MemoryLeader = ML; }
294 };
295 
296 class CallExpression final : public MemoryExpression {
297 private:
298   CallInst *Call;
299 
300 public:
301   CallExpression(unsigned NumOperands, CallInst *C,
302                  const MemoryAccess *MemoryLeader)
303       : MemoryExpression(NumOperands, ET_Call, MemoryLeader), Call(C) {}
304   CallExpression() = delete;
305   CallExpression(const CallExpression &) = delete;
306   CallExpression &operator=(const CallExpression &) = delete;
307   ~CallExpression() override;
308 
309   static bool classof(const Expression *EB) {
310     return EB->getExpressionType() == ET_Call;
311   }
312 
313   // Debugging support
314   void printInternal(raw_ostream &OS, bool PrintEType) const override {
315     if (PrintEType)
316       OS << "ExpressionTypeCall, ";
317     this->BasicExpression::printInternal(OS, false);
318     OS << " represents call at ";
319     Call->printAsOperand(OS);
320   }
321 };
322 
323 class LoadExpression final : public MemoryExpression {
324 private:
325   LoadInst *Load;
326   unsigned Alignment;
327 
328 public:
329   LoadExpression(unsigned NumOperands, LoadInst *L,
330                  const MemoryAccess *MemoryLeader)
331       : LoadExpression(ET_Load, NumOperands, L, MemoryLeader) {}
332 
333   LoadExpression(enum ExpressionType EType, unsigned NumOperands, LoadInst *L,
334                  const MemoryAccess *MemoryLeader)
335       : MemoryExpression(NumOperands, EType, MemoryLeader), Load(L) {
336     Alignment = L ? L->getAlignment() : 0;
337   }
338 
339   LoadExpression() = delete;
340   LoadExpression(const LoadExpression &) = delete;
341   LoadExpression &operator=(const LoadExpression &) = delete;
342   ~LoadExpression() override;
343 
344   static bool classof(const Expression *EB) {
345     return EB->getExpressionType() == ET_Load;
346   }
347 
348   LoadInst *getLoadInst() const { return Load; }
349   void setLoadInst(LoadInst *L) { Load = L; }
350 
351   unsigned getAlignment() const { return Alignment; }
352   void setAlignment(unsigned Align) { Alignment = Align; }
353 
354   bool equals(const Expression &Other) const override;
355   bool exactlyEquals(const Expression &Other) const override {
356     return Expression::exactlyEquals(Other) &&
357            cast<LoadExpression>(Other).getLoadInst() == getLoadInst();
358   }
359 
360   // Debugging support
361   void printInternal(raw_ostream &OS, bool PrintEType) const override {
362     if (PrintEType)
363       OS << "ExpressionTypeLoad, ";
364     this->BasicExpression::printInternal(OS, false);
365     OS << " represents Load at ";
366     Load->printAsOperand(OS);
367     OS << " with MemoryLeader " << *getMemoryLeader();
368   }
369 };
370 
371 class StoreExpression final : public MemoryExpression {
372 private:
373   StoreInst *Store;
374   Value *StoredValue;
375 
376 public:
377   StoreExpression(unsigned NumOperands, StoreInst *S, Value *StoredValue,
378                   const MemoryAccess *MemoryLeader)
379       : MemoryExpression(NumOperands, ET_Store, MemoryLeader), Store(S),
380         StoredValue(StoredValue) {}
381   StoreExpression() = delete;
382   StoreExpression(const StoreExpression &) = delete;
383   StoreExpression &operator=(const StoreExpression &) = delete;
384   ~StoreExpression() override;
385 
386   static bool classof(const Expression *EB) {
387     return EB->getExpressionType() == ET_Store;
388   }
389 
390   StoreInst *getStoreInst() const { return Store; }
391   Value *getStoredValue() const { return StoredValue; }
392 
393   bool equals(const Expression &Other) const override;
394 
395   bool exactlyEquals(const Expression &Other) const override {
396     return Expression::exactlyEquals(Other) &&
397            cast<StoreExpression>(Other).getStoreInst() == getStoreInst();
398   }
399 
400   // Debugging support
401   void printInternal(raw_ostream &OS, bool PrintEType) const override {
402     if (PrintEType)
403       OS << "ExpressionTypeStore, ";
404     this->BasicExpression::printInternal(OS, false);
405     OS << " represents Store  " << *Store;
406     OS << " with StoredValue ";
407     StoredValue->printAsOperand(OS);
408     OS << " and MemoryLeader " << *getMemoryLeader();
409   }
410 };
411 
412 class AggregateValueExpression final : public BasicExpression {
413 private:
414   unsigned MaxIntOperands;
415   unsigned NumIntOperands = 0;
416   unsigned *IntOperands = nullptr;
417 
418 public:
419   AggregateValueExpression(unsigned NumOperands, unsigned NumIntOperands)
420       : BasicExpression(NumOperands, ET_AggregateValue),
421         MaxIntOperands(NumIntOperands) {}
422   AggregateValueExpression() = delete;
423   AggregateValueExpression(const AggregateValueExpression &) = delete;
424   AggregateValueExpression &
425   operator=(const AggregateValueExpression &) = delete;
426   ~AggregateValueExpression() override;
427 
428   static bool classof(const Expression *EB) {
429     return EB->getExpressionType() == ET_AggregateValue;
430   }
431 
432   using int_arg_iterator = unsigned *;
433   using const_int_arg_iterator = const unsigned *;
434 
435   int_arg_iterator int_op_begin() { return IntOperands; }
436   int_arg_iterator int_op_end() { return IntOperands + NumIntOperands; }
437   const_int_arg_iterator int_op_begin() const { return IntOperands; }
438   const_int_arg_iterator int_op_end() const {
439     return IntOperands + NumIntOperands;
440   }
441   unsigned int_op_size() const { return NumIntOperands; }
442   bool int_op_empty() const { return NumIntOperands == 0; }
443   void int_op_push_back(unsigned IntOperand) {
444     assert(NumIntOperands < MaxIntOperands &&
445            "Tried to add too many int operands");
446     assert(IntOperands && "Operands not allocated before pushing");
447     IntOperands[NumIntOperands++] = IntOperand;
448   }
449 
450   virtual void allocateIntOperands(BumpPtrAllocator &Allocator) {
451     assert(!IntOperands && "Operands already allocated");
452     IntOperands = Allocator.Allocate<unsigned>(MaxIntOperands);
453   }
454 
455   bool equals(const Expression &Other) const override {
456     if (!this->BasicExpression::equals(Other))
457       return false;
458     const AggregateValueExpression &OE = cast<AggregateValueExpression>(Other);
459     return NumIntOperands == OE.NumIntOperands &&
460            std::equal(int_op_begin(), int_op_end(), OE.int_op_begin());
461   }
462 
463   hash_code getHashValue() const override {
464     return hash_combine(this->BasicExpression::getHashValue(),
465                         hash_combine_range(int_op_begin(), int_op_end()));
466   }
467 
468   // Debugging support
469   void printInternal(raw_ostream &OS, bool PrintEType) const override {
470     if (PrintEType)
471       OS << "ExpressionTypeAggregateValue, ";
472     this->BasicExpression::printInternal(OS, false);
473     OS << ", intoperands = {";
474     for (unsigned i = 0, e = int_op_size(); i != e; ++i) {
475       OS << "[" << i << "] = " << IntOperands[i] << "  ";
476     }
477     OS << "}";
478   }
479 };
480 
481 class int_op_inserter
482     : public std::iterator<std::output_iterator_tag, void, void, void, void> {
483 private:
484   using Container = AggregateValueExpression;
485 
486   Container *AVE;
487 
488 public:
489   explicit int_op_inserter(AggregateValueExpression &E) : AVE(&E) {}
490   explicit int_op_inserter(AggregateValueExpression *E) : AVE(E) {}
491 
492   int_op_inserter &operator=(unsigned int val) {
493     AVE->int_op_push_back(val);
494     return *this;
495   }
496   int_op_inserter &operator*() { return *this; }
497   int_op_inserter &operator++() { return *this; }
498   int_op_inserter &operator++(int) { return *this; }
499 };
500 
501 class PHIExpression final : public BasicExpression {
502 private:
503   BasicBlock *BB;
504 
505 public:
506   PHIExpression(unsigned NumOperands, BasicBlock *B)
507       : BasicExpression(NumOperands, ET_Phi), BB(B) {}
508   PHIExpression() = delete;
509   PHIExpression(const PHIExpression &) = delete;
510   PHIExpression &operator=(const PHIExpression &) = delete;
511   ~PHIExpression() override;
512 
513   static bool classof(const Expression *EB) {
514     return EB->getExpressionType() == ET_Phi;
515   }
516 
517   bool equals(const Expression &Other) const override {
518     if (!this->BasicExpression::equals(Other))
519       return false;
520     const PHIExpression &OE = cast<PHIExpression>(Other);
521     return BB == OE.BB;
522   }
523 
524   hash_code getHashValue() const override {
525     return hash_combine(this->BasicExpression::getHashValue(), BB);
526   }
527 
528   // Debugging support
529   void printInternal(raw_ostream &OS, bool PrintEType) const override {
530     if (PrintEType)
531       OS << "ExpressionTypePhi, ";
532     this->BasicExpression::printInternal(OS, false);
533     OS << "bb = " << BB;
534   }
535 };
536 
537 class DeadExpression final : public Expression {
538 public:
539   DeadExpression() : Expression(ET_Dead) {}
540   DeadExpression(const DeadExpression &) = delete;
541   DeadExpression &operator=(const DeadExpression &) = delete;
542 
543   static bool classof(const Expression *E) {
544     return E->getExpressionType() == ET_Dead;
545   }
546 };
547 
548 class VariableExpression final : public Expression {
549 private:
550   Value *VariableValue;
551 
552 public:
553   VariableExpression(Value *V) : Expression(ET_Variable), VariableValue(V) {}
554   VariableExpression() = delete;
555   VariableExpression(const VariableExpression &) = delete;
556   VariableExpression &operator=(const VariableExpression &) = delete;
557 
558   static bool classof(const Expression *EB) {
559     return EB->getExpressionType() == ET_Variable;
560   }
561 
562   Value *getVariableValue() const { return VariableValue; }
563   void setVariableValue(Value *V) { VariableValue = V; }
564 
565   bool equals(const Expression &Other) const override {
566     const VariableExpression &OC = cast<VariableExpression>(Other);
567     return VariableValue == OC.VariableValue;
568   }
569 
570   hash_code getHashValue() const override {
571     return hash_combine(this->Expression::getHashValue(),
572                         VariableValue->getType(), VariableValue);
573   }
574 
575   // Debugging support
576   void printInternal(raw_ostream &OS, bool PrintEType) const override {
577     if (PrintEType)
578       OS << "ExpressionTypeVariable, ";
579     this->Expression::printInternal(OS, false);
580     OS << " variable = " << *VariableValue;
581   }
582 };
583 
584 class ConstantExpression final : public Expression {
585 private:
586   Constant *ConstantValue = nullptr;
587 
588 public:
589   ConstantExpression() : Expression(ET_Constant) {}
590   ConstantExpression(Constant *constantValue)
591       : Expression(ET_Constant), ConstantValue(constantValue) {}
592   ConstantExpression(const ConstantExpression &) = delete;
593   ConstantExpression &operator=(const ConstantExpression &) = delete;
594 
595   static bool classof(const Expression *EB) {
596     return EB->getExpressionType() == ET_Constant;
597   }
598 
599   Constant *getConstantValue() const { return ConstantValue; }
600   void setConstantValue(Constant *V) { ConstantValue = V; }
601 
602   bool equals(const Expression &Other) const override {
603     const ConstantExpression &OC = cast<ConstantExpression>(Other);
604     return ConstantValue == OC.ConstantValue;
605   }
606 
607   hash_code getHashValue() const override {
608     return hash_combine(this->Expression::getHashValue(),
609                         ConstantValue->getType(), ConstantValue);
610   }
611 
612   // Debugging support
613   void printInternal(raw_ostream &OS, bool PrintEType) const override {
614     if (PrintEType)
615       OS << "ExpressionTypeConstant, ";
616     this->Expression::printInternal(OS, false);
617     OS << " constant = " << *ConstantValue;
618   }
619 };
620 
621 class UnknownExpression final : public Expression {
622 private:
623   Instruction *Inst;
624 
625 public:
626   UnknownExpression(Instruction *I) : Expression(ET_Unknown), Inst(I) {}
627   UnknownExpression() = delete;
628   UnknownExpression(const UnknownExpression &) = delete;
629   UnknownExpression &operator=(const UnknownExpression &) = delete;
630 
631   static bool classof(const Expression *EB) {
632     return EB->getExpressionType() == ET_Unknown;
633   }
634 
635   Instruction *getInstruction() const { return Inst; }
636   void setInstruction(Instruction *I) { Inst = I; }
637 
638   bool equals(const Expression &Other) const override {
639     const auto &OU = cast<UnknownExpression>(Other);
640     return Inst == OU.Inst;
641   }
642 
643   hash_code getHashValue() const override {
644     return hash_combine(this->Expression::getHashValue(), Inst);
645   }
646 
647   // Debugging support
648   void printInternal(raw_ostream &OS, bool PrintEType) const override {
649     if (PrintEType)
650       OS << "ExpressionTypeUnknown, ";
651     this->Expression::printInternal(OS, false);
652     OS << " inst = " << *Inst;
653   }
654 };
655 
656 } // end namespace GVNExpression
657 
658 } // end namespace llvm
659 
660 #endif // LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
661