xref: /llvm-project/polly/lib/Support/SCEVValidator.cpp (revision 60b54f19e6a88c43abc6148c38d0a0ecc69a55f6)
1 
2 #include "polly/Support/SCEVValidator.h"
3 
4 #include "llvm/Analysis/ScalarEvolution.h"
5 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
6 #include "llvm/Analysis/RegionInfo.h"
7 
8 #include <vector>
9 
10 using namespace llvm;
11 
namespace SCEVType {
  /// Classification the validator assigns to a SCEV.
  ///
  /// The enumerator order is significant: the validator in this file merges
  /// the classifications of several operands with std::max(), so a later
  /// enumerator always wins over an earlier one.
  enum TYPE {
    // An integer constant.
    INT,

    // A value that is constant during Scop execution.
    PARAM,

    // An expression that depends on an induction variable of a loop that is
    // contained in the region.
    IV,

    // An expression the validator rejects.
    INVALID
  };
}
15 
16 struct ValidatorResult {
17   SCEVType::TYPE type;
18   std::vector<const SCEV*> Parameters;
19 
20   ValidatorResult() : type(SCEVType::INVALID) {};
21 
22   ValidatorResult(const ValidatorResult &vres) {
23     type = vres.type;
24     Parameters = vres.Parameters;
25   };
26 
27   ValidatorResult(SCEVType::TYPE type) : type(type) {};
28   ValidatorResult(SCEVType::TYPE type, const SCEV* Expr) : type(type) {
29     Parameters.push_back(Expr);
30   };
31 
32   bool isConstant() {
33     return type == SCEVType::INT || type == SCEVType::PARAM;
34   }
35 
36   bool isValid() {
37     return type != SCEVType::INVALID;
38   }
39 
40   bool isIV() {
41     return type == SCEVType::IV;
42   }
43 
44   bool isINT() {
45     return type == SCEVType::INT;
46   }
47 
48   void addParamsFrom(struct ValidatorResult &Source) {
49     Parameters.insert(Parameters.end(), Source.Parameters.begin(),
50                       Source.Parameters.end());
51   }
52 };
53 
54 /// Check if a SCEV is valid in a SCoP.
55 struct SCEVValidator
56   : public SCEVVisitor<SCEVValidator, struct ValidatorResult> {
57 private:
58   const Region *R;
59   ScalarEvolution &SE;
60   Value **BaseAddress;
61 
62 public:
63   SCEVValidator(const Region *R, ScalarEvolution &SE,
64                 Value **BaseAddress) : R(R), SE(SE),
65     BaseAddress(BaseAddress) {};
66 
67   struct ValidatorResult visitConstant(const SCEVConstant *Constant) {
68     return ValidatorResult(SCEVType::INT);
69   }
70 
71   struct ValidatorResult visitTruncateExpr(const SCEVTruncateExpr* Expr) {
72     ValidatorResult Op = visit(Expr->getOperand());
73 
74     // We currently do not represent a truncate expression as an affine
75     // expression. If it is constant during Scop execution, we treat it as a
76     // parameter, otherwise we bail out.
77     if (Op.isConstant())
78       return ValidatorResult(SCEVType::PARAM, Expr);
79 
80     return ValidatorResult (SCEVType::INVALID);
81   }
82 
83   struct ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr * Expr) {
84     ValidatorResult Op = visit(Expr->getOperand());
85 
86     // We currently do not represent a zero extend expression as an affine
87     // expression. If it is constant during Scop execution, we treat it as a
88     // parameter, otherwise we bail out.
89     if (Op.isConstant())
90       return ValidatorResult (SCEVType::PARAM, Expr);
91 
92     return ValidatorResult(SCEVType::INVALID);
93   }
94 
95   struct ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr* Expr) {
96     // We currently allow only signed SCEV expressions. In the case of a
97     // signed value, a sign extend is a noop.
98     //
99     // TODO: Reconsider this when we add support for unsigned values.
100     return visit(Expr->getOperand());
101   }
102 
103   struct ValidatorResult visitAddExpr(const SCEVAddExpr* Expr) {
104     ValidatorResult Return(SCEVType::INT);
105 
106     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
107       ValidatorResult Op = visit(Expr->getOperand(i));
108 
109       if (!Op.isValid())
110         return ValidatorResult(SCEVType::INVALID);
111 
112       Return.type = std::max(Return.type, Op.type);
113       Return.addParamsFrom(Op);
114     }
115 
116     // TODO: Check for NSW and NUW.
117     return Return;
118   }
119 
120   struct ValidatorResult visitMulExpr(const SCEVMulExpr* Expr) {
121     ValidatorResult Return(SCEVType::INT);
122 
123     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
124       ValidatorResult Op = visit(Expr->getOperand(i));
125 
126       if (Op.type == SCEVType::INT)
127         continue;
128 
129       if (Op.type == SCEVType::INVALID || Return.type != SCEVType::INT)
130         return ValidatorResult(SCEVType::INVALID);
131 
132       Return.type = Op.type;
133       Return.addParamsFrom(Op);
134     }
135 
136     // TODO: Check for NSW and NUW.
137     return Return;
138   }
139 
140   struct ValidatorResult visitUDivExpr(const SCEVUDivExpr* Expr) {
141     ValidatorResult LHS = visit(Expr->getLHS());
142     ValidatorResult RHS = visit(Expr->getRHS());
143 
144     // We currently do not represent a unsigned devision as an affine
145     // expression. If the division is constant during Scop execution we treat it
146     // as a parameter, otherwise we bail out.
147     if (LHS.isConstant() && RHS.isConstant())
148       return ValidatorResult(SCEVType::PARAM, Expr);
149 
150     return ValidatorResult(SCEVType::INVALID);
151   }
152 
153   struct ValidatorResult visitAddRecExpr(const SCEVAddRecExpr* Expr) {
154     if (!Expr->isAffine())
155       return ValidatorResult(SCEVType::INVALID);
156 
157     ValidatorResult Start = visit(Expr->getStart());
158     ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE));
159 
160     if (!Start.isValid() || !Recurrence.isValid() || Recurrence.isIV())
161       return ValidatorResult(SCEVType::INVALID);
162 
163 
164     if (!R->contains(Expr->getLoop())) {
165       if (Start.isIV())
166         return ValidatorResult(SCEVType::INVALID);
167       else
168         return ValidatorResult(SCEVType::PARAM, Expr);
169     }
170 
171     if (!Recurrence.isINT())
172       return ValidatorResult(SCEVType::INVALID);
173 
174     ValidatorResult Result(SCEVType::IV);
175     Result.addParamsFrom(Start);
176     return Result;
177   }
178 
179   struct ValidatorResult visitSMaxExpr(const SCEVSMaxExpr* Expr) {
180     ValidatorResult Return(SCEVType::INT);
181 
182     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
183       ValidatorResult Op = visit(Expr->getOperand(i));
184 
185       if (!Op.isValid())
186         return ValidatorResult(SCEVType::INVALID);
187 
188       Return.type = std::max(Return.type, Op.type);
189       Return.addParamsFrom(Op);
190     }
191 
192     return Return;
193   }
194 
195   struct ValidatorResult visitUMaxExpr(const SCEVUMaxExpr* Expr) {
196     ValidatorResult Return(SCEVType::PARAM);
197 
198     // We do not support unsigned operations. If 'Expr' is constant during Scop
199     // execution we treat this as a parameter, otherwise we bail out.
200     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
201       ValidatorResult Op = visit(Expr->getOperand(i));
202 
203       if (!Op.isConstant())
204         return ValidatorResult(SCEVType::INVALID);
205 
206       Return.addParamsFrom(Op);
207     }
208 
209     return Return;
210   }
211 
212   ValidatorResult visitUnknown(const SCEVUnknown* Expr) {
213     Value *V = Expr->getValue();
214 
215     if (isa<UndefValue>(V))
216       return ValidatorResult(SCEVType::INVALID);
217 
218     if (BaseAddress) {
219       if (*BaseAddress)
220         return ValidatorResult(SCEVType::INVALID);
221       else
222         *BaseAddress = V;
223     }
224 
225     if (Instruction *I = dyn_cast<Instruction>(Expr->getValue()))
226       if (R->contains(I))
227         return ValidatorResult(SCEVType::INVALID);
228 
229     if (BaseAddress)
230       return ValidatorResult(SCEVType::PARAM);
231     else
232       return ValidatorResult(SCEVType::PARAM, Expr);
233   }
234 };
235 
236 namespace polly {
237   bool isAffineExpr(const Region *R, const SCEV *Expr, ScalarEvolution &SE,
238                     Value **BaseAddress) {
239     if (isa<SCEVCouldNotCompute>(Expr))
240       return false;
241 
242     if (BaseAddress)
243       *BaseAddress = NULL;
244 
245     SCEVValidator Validator(R, SE, BaseAddress);
246     ValidatorResult Result = Validator.visit(Expr);
247 
248     return Result.isValid();
249   }
250 
251   std::vector<const SCEV*> getParamsInAffineExpr(const Region *R,
252                                                  const SCEV *Expr,
253                                                  ScalarEvolution &SE,
254                                                  Value **BaseAddress) {
255     if (isa<SCEVCouldNotCompute>(Expr))
256       return std::vector<const SCEV*>();
257 
258     if (BaseAddress)
259       *BaseAddress = NULL;
260 
261     SCEVValidator Validator(R, SE, BaseAddress);
262     ValidatorResult Result = Validator.visit(Expr);
263 
264     return Result.Parameters;
265   }
266 }
267 
268 
269