xref: /llvm-project/polly/lib/Support/SCEVValidator.cpp (revision bcc0a0d56059a366c0c0d964fd760ee4bc83809e)
1 
2 #include "polly/Support/SCEVValidator.h"
3 
4 #include "llvm/Analysis/ScalarEvolution.h"
5 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
6 #include "llvm/Analysis/RegionInfo.h"
7 
8 #include <vector>
9 
10 using namespace llvm;
11 
/// Classification of a SCEV expression as computed by SCEVValidator.
///
/// The numeric order of the enumerators is significant: the validator merges
/// the classifications of operands with std::max, so a "larger" kind always
/// dominates a "smaller" one (INVALID dominates everything).
namespace SCEVType {
  enum TYPE {
    INT = 0,     ///< An integer constant.
    PARAM = 1,   ///< Treated as a parameter of the SCoP.
    IV = 2,      ///< Involves an induction variable of a loop in the region.
    INVALID = 3  ///< Not representable; the expression is rejected.
  };
}
15 
16 struct ValidatorResult {
17   SCEVType::TYPE type;
18   std::vector<const SCEV*> Parameters;
19 
20   ValidatorResult() : type(SCEVType::INVALID) {};
21 
22   ValidatorResult(const ValidatorResult &vres) {
23     type = vres.type;
24     Parameters = vres.Parameters;
25   };
26 
27   ValidatorResult(SCEVType::TYPE type) : type(type) {};
28   ValidatorResult(SCEVType::TYPE type, const SCEV *Expr) : type(type) {
29     Parameters.push_back(Expr);
30   };
31 
32   bool isConstant() {
33     return type == SCEVType::INT || type == SCEVType::PARAM;
34   }
35 
36   bool isValid() {
37     return type != SCEVType::INVALID;
38   }
39 
40   bool isIV() {
41     return type == SCEVType::IV;
42   }
43 
44   bool isINT() {
45     return type == SCEVType::INT;
46   }
47 
48   void addParamsFrom(struct ValidatorResult &Source) {
49     Parameters.insert(Parameters.end(), Source.Parameters.begin(),
50                       Source.Parameters.end());
51   }
52 };
53 
54 /// Check if a SCEV is valid in a SCoP.
55 struct SCEVValidator
56   : public SCEVVisitor<SCEVValidator, struct ValidatorResult> {
57 private:
58   const Region *R;
59   ScalarEvolution &SE;
60   const Value *BaseAddress;
61 
62 public:
63   SCEVValidator(const Region *R, ScalarEvolution &SE,
64                 const Value *BaseAddress) : R(R), SE(SE),
65     BaseAddress(BaseAddress) {};
66 
67   struct ValidatorResult visitConstant(const SCEVConstant *Constant) {
68     return ValidatorResult(SCEVType::INT);
69   }
70 
71   struct ValidatorResult visitTruncateExpr(const SCEVTruncateExpr *Expr) {
72     ValidatorResult Op = visit(Expr->getOperand());
73 
74     // We currently do not represent a truncate expression as an affine
75     // expression. If it is constant during Scop execution, we treat it as a
76     // parameter, otherwise we bail out.
77     if (Op.isConstant())
78       return ValidatorResult(SCEVType::PARAM, Expr);
79 
80     return ValidatorResult (SCEVType::INVALID);
81   }
82 
83   struct ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
84     ValidatorResult Op = visit(Expr->getOperand());
85 
86     // We currently do not represent a zero extend expression as an affine
87     // expression. If it is constant during Scop execution, we treat it as a
88     // parameter, otherwise we bail out.
89     if (Op.isConstant())
90       return ValidatorResult (SCEVType::PARAM, Expr);
91 
92     return ValidatorResult(SCEVType::INVALID);
93   }
94 
95   struct ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
96     // We currently allow only signed SCEV expressions. In the case of a
97     // signed value, a sign extend is a noop.
98     //
99     // TODO: Reconsider this when we add support for unsigned values.
100     return visit(Expr->getOperand());
101   }
102 
103   struct ValidatorResult visitAddExpr(const SCEVAddExpr *Expr) {
104     ValidatorResult Return(SCEVType::INT);
105 
106     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
107       ValidatorResult Op = visit(Expr->getOperand(i));
108 
109       if (!Op.isValid())
110         return ValidatorResult(SCEVType::INVALID);
111 
112       Return.type = std::max(Return.type, Op.type);
113       Return.addParamsFrom(Op);
114     }
115 
116     // TODO: Check for NSW and NUW.
117     return Return;
118   }
119 
120   struct ValidatorResult visitMulExpr(const SCEVMulExpr *Expr) {
121     ValidatorResult Return(SCEVType::INT);
122 
123     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
124       ValidatorResult Op = visit(Expr->getOperand(i));
125 
126       if (Op.type == SCEVType::INT)
127         continue;
128 
129       if (Op.type == SCEVType::INVALID || Return.type != SCEVType::INT)
130         return ValidatorResult(SCEVType::INVALID);
131 
132       Return.type = Op.type;
133       Return.addParamsFrom(Op);
134     }
135 
136     // TODO: Check for NSW and NUW.
137     return Return;
138   }
139 
140   struct ValidatorResult visitUDivExpr(const SCEVUDivExpr *Expr) {
141     ValidatorResult LHS = visit(Expr->getLHS());
142     ValidatorResult RHS = visit(Expr->getRHS());
143 
144     // We currently do not represent an unsigned devision as an affine
145     // expression. If the division is constant during Scop execution we treat it
146     // as a parameter, otherwise we bail out.
147     if (LHS.isConstant() && RHS.isConstant())
148       return ValidatorResult(SCEVType::PARAM, Expr);
149 
150     return ValidatorResult(SCEVType::INVALID);
151   }
152 
153   struct ValidatorResult visitAddRecExpr(const SCEVAddRecExpr *Expr) {
154     if (!Expr->isAffine())
155       return ValidatorResult(SCEVType::INVALID);
156 
157     ValidatorResult Start = visit(Expr->getStart());
158     ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE));
159 
160     if (!Start.isValid() || !Recurrence.isConstant())
161       return ValidatorResult(SCEVType::INVALID);
162 
163     if (R->contains(Expr->getLoop())) {
164       if (Recurrence.isINT()) {
165         ValidatorResult Result(SCEVType::IV);
166         Result.addParamsFrom(Start);
167         return Result;
168       }
169 
170       return ValidatorResult(SCEVType::INVALID);
171     }
172 
173     if (Start.isConstant())
174       return ValidatorResult(SCEVType::PARAM, Expr);
175 
176     return ValidatorResult(SCEVType::INVALID);
177   }
178 
179   struct ValidatorResult visitSMaxExpr(const SCEVSMaxExpr *Expr) {
180     ValidatorResult Return(SCEVType::INT);
181 
182     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
183       ValidatorResult Op = visit(Expr->getOperand(i));
184 
185       if (!Op.isValid())
186         return ValidatorResult(SCEVType::INVALID);
187 
188       Return.type = std::max(Return.type, Op.type);
189       Return.addParamsFrom(Op);
190     }
191 
192     return Return;
193   }
194 
195   struct ValidatorResult visitUMaxExpr(const SCEVUMaxExpr *Expr) {
196     ValidatorResult Return(SCEVType::PARAM);
197 
198     // We do not support unsigned operations. If 'Expr' is constant during Scop
199     // execution we treat this as a parameter, otherwise we bail out.
200     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
201       ValidatorResult Op = visit(Expr->getOperand(i));
202 
203       if (!Op.isConstant())
204         return ValidatorResult(SCEVType::INVALID);
205 
206       Return.addParamsFrom(Op);
207     }
208 
209     return Return;
210   }
211 
212   ValidatorResult visitUnknown(const SCEVUnknown *Expr) {
213     Value *V = Expr->getValue();
214 
215     if (isa<UndefValue>(V))
216       return ValidatorResult(SCEVType::INVALID);
217 
218     if (Instruction *I = dyn_cast<Instruction>(Expr->getValue()))
219       if (R->contains(I))
220         return ValidatorResult(SCEVType::INVALID);
221 
222     if (BaseAddress == V)
223       return ValidatorResult(SCEVType::INVALID);
224 
225     return ValidatorResult(SCEVType::PARAM, Expr);
226   }
227 };
228 
229 namespace polly {
230   bool isAffineExpr(const Region *R, const SCEV *Expr, ScalarEvolution &SE,
231                     const Value *BaseAddress) {
232     if (isa<SCEVCouldNotCompute>(Expr))
233       return false;
234 
235     SCEVValidator Validator(R, SE, BaseAddress);
236     ValidatorResult Result = Validator.visit(Expr);
237 
238     return Result.isValid();
239   }
240 
241   std::vector<const SCEV*> getParamsInAffineExpr(const Region *R,
242                                                  const SCEV *Expr,
243                                                  ScalarEvolution &SE,
244                                                  const Value *BaseAddress) {
245     if (isa<SCEVCouldNotCompute>(Expr))
246       return std::vector<const SCEV*>();
247 
248     SCEVValidator Validator(R, SE, BaseAddress);
249     ValidatorResult Result = Validator.visit(Expr);
250 
251     return Result.Parameters;
252   }
253 }
254 
255 
256