xref: /netbsd-src/external/apache2/llvm/dist/llvm/lib/Analysis/ScalarEvolutionNormalization.cpp (revision 7330f729ccf0bd976a06f95fad452fe774fc7fd1)
1*7330f729Sjoerg //===- ScalarEvolutionNormalization.cpp - See below -----------------------===//
2*7330f729Sjoerg //
3*7330f729Sjoerg // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*7330f729Sjoerg // See https://llvm.org/LICENSE.txt for license information.
5*7330f729Sjoerg // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*7330f729Sjoerg //
7*7330f729Sjoerg //===----------------------------------------------------------------------===//
8*7330f729Sjoerg //
9*7330f729Sjoerg // This file implements utilities for working with "normalized" expressions.
10*7330f729Sjoerg // See the comments at the top of ScalarEvolutionNormalization.h for details.
11*7330f729Sjoerg //
12*7330f729Sjoerg //===----------------------------------------------------------------------===//
13*7330f729Sjoerg 
14*7330f729Sjoerg #include "llvm/Analysis/ScalarEvolutionNormalization.h"
15*7330f729Sjoerg #include "llvm/Analysis/LoopInfo.h"
16*7330f729Sjoerg #include "llvm/Analysis/ScalarEvolutionExpressions.h"
17*7330f729Sjoerg using namespace llvm;
18*7330f729Sjoerg 
namespace {
/// TransformKind - Direction in which NormalizeDenormalizeRewriter shifts the
/// add recurrences it rewrites.
///
/// This type is private to this file, so it lives in an anonymous namespace
/// rather than at global scope; that keeps it from colliding (ODR-wise) with
/// any other translation unit's ::TransformKind.
enum TransformKind {
  /// Normalize - Rewrite the selected add recurrences into their "partial
  /// decrement" form with respect to their loop.
  Normalize,
  /// Denormalize - Perform the inverse transform ("partial increment") on the
  /// selected add recurrences with respect to their loop.
  Denormalize
};
} // namespace
28*7330f729Sjoerg 
29*7330f729Sjoerg namespace {
30*7330f729Sjoerg struct NormalizeDenormalizeRewriter
31*7330f729Sjoerg     : public SCEVRewriteVisitor<NormalizeDenormalizeRewriter> {
32*7330f729Sjoerg   const TransformKind Kind;
33*7330f729Sjoerg 
34*7330f729Sjoerg   // NB! Pred is a function_ref.  Storing it here is okay only because
35*7330f729Sjoerg   // we're careful about the lifetime of NormalizeDenormalizeRewriter.
36*7330f729Sjoerg   const NormalizePredTy Pred;
37*7330f729Sjoerg 
NormalizeDenormalizeRewriter__anon70131a300111::NormalizeDenormalizeRewriter38*7330f729Sjoerg   NormalizeDenormalizeRewriter(TransformKind Kind, NormalizePredTy Pred,
39*7330f729Sjoerg                                ScalarEvolution &SE)
40*7330f729Sjoerg       : SCEVRewriteVisitor<NormalizeDenormalizeRewriter>(SE), Kind(Kind),
41*7330f729Sjoerg         Pred(Pred) {}
42*7330f729Sjoerg   const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr);
43*7330f729Sjoerg };
44*7330f729Sjoerg } // namespace
45*7330f729Sjoerg 
46*7330f729Sjoerg const SCEV *
visitAddRecExpr(const SCEVAddRecExpr * AR)47*7330f729Sjoerg NormalizeDenormalizeRewriter::visitAddRecExpr(const SCEVAddRecExpr *AR) {
48*7330f729Sjoerg   SmallVector<const SCEV *, 8> Operands;
49*7330f729Sjoerg 
50*7330f729Sjoerg   transform(AR->operands(), std::back_inserter(Operands),
51*7330f729Sjoerg             [&](const SCEV *Op) { return visit(Op); });
52*7330f729Sjoerg 
53*7330f729Sjoerg   if (!Pred(AR))
54*7330f729Sjoerg     return SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap);
55*7330f729Sjoerg 
56*7330f729Sjoerg   // Normalization and denormalization are fancy names for decrementing and
57*7330f729Sjoerg   // incrementing a SCEV expression with respect to a set of loops.  Since
58*7330f729Sjoerg   // Pred(AR) has returned true, we know we need to normalize or denormalize AR
59*7330f729Sjoerg   // with respect to its loop.
60*7330f729Sjoerg 
61*7330f729Sjoerg   if (Kind == Denormalize) {
62*7330f729Sjoerg     // Denormalization / "partial increment" is essentially the same as \c
63*7330f729Sjoerg     // SCEVAddRecExpr::getPostIncExpr.  Here we use an explicit loop to make the
64*7330f729Sjoerg     // symmetry with Normalization clear.
65*7330f729Sjoerg     for (int i = 0, e = Operands.size() - 1; i < e; i++)
66*7330f729Sjoerg       Operands[i] = SE.getAddExpr(Operands[i], Operands[i + 1]);
67*7330f729Sjoerg   } else {
68*7330f729Sjoerg     assert(Kind == Normalize && "Only two possibilities!");
69*7330f729Sjoerg 
70*7330f729Sjoerg     // Normalization / "partial decrement" is a bit more subtle.  Since
71*7330f729Sjoerg     // incrementing a SCEV expression (in general) changes the step of the SCEV
72*7330f729Sjoerg     // expression as well, we cannot use the step of the current expression.
73*7330f729Sjoerg     // Instead, we have to use the step of the very expression we're trying to
74*7330f729Sjoerg     // compute!
75*7330f729Sjoerg     //
76*7330f729Sjoerg     // We solve the issue by recursively building up the result, starting from
77*7330f729Sjoerg     // the "least significant" operand in the add recurrence:
78*7330f729Sjoerg     //
79*7330f729Sjoerg     // Base case:
80*7330f729Sjoerg     //   Single operand add recurrence.  It's its own normalization.
81*7330f729Sjoerg     //
82*7330f729Sjoerg     // N-operand case:
83*7330f729Sjoerg     //   {S_{N-1},+,S_{N-2},+,...,+,S_0} = S
84*7330f729Sjoerg     //
85*7330f729Sjoerg     //   Since the step recurrence of S is {S_{N-2},+,...,+,S_0}, we know its
86*7330f729Sjoerg     //   normalization by induction.  We subtract the normalized step
87*7330f729Sjoerg     //   recurrence from S_{N-1} to get the normalization of S.
88*7330f729Sjoerg 
89*7330f729Sjoerg     for (int i = Operands.size() - 2; i >= 0; i--)
90*7330f729Sjoerg       Operands[i] = SE.getMinusSCEV(Operands[i], Operands[i + 1]);
91*7330f729Sjoerg   }
92*7330f729Sjoerg 
93*7330f729Sjoerg   return SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap);
94*7330f729Sjoerg }
95*7330f729Sjoerg 
/// Normalize \p S with respect to \p Loops: every add recurrence whose loop is
/// a member of \p Loops is shifted back by one iteration of that loop.
const SCEV *llvm::normalizeForPostIncUse(const SCEV *S,
                                         const PostIncLoopSet &Loops,
                                         ScalarEvolution &SE) {
  // The rewriter only keeps a function_ref to this lambda, so the lambda is a
  // named local that outlives the rewriter.
  auto IsPostIncLoop = [&Loops](const SCEVAddRecExpr *AR) {
    return Loops.count(AR->getLoop()) != 0;
  };
  NormalizeDenormalizeRewriter Rewriter(Normalize, IsPostIncLoop, SE);
  return Rewriter.visit(S);
}
104*7330f729Sjoerg 
/// Normalize \p S, shifting back by one iteration exactly those add
/// recurrences for which \p Pred returns true.
const SCEV *llvm::normalizeForPostIncUseIf(const SCEV *S, NormalizePredTy Pred,
                                           ScalarEvolution &SE) {
  // Pred is passed through unchanged; whatever it refers to is alive for the
  // whole duration of this call, which is as long as the rewriter lives.
  NormalizeDenormalizeRewriter Rewriter(Normalize, Pred, SE);
  return Rewriter.visit(S);
}
109*7330f729Sjoerg 
/// Denormalize \p S with respect to \p Loops: the inverse of
/// normalizeForPostIncUse.  Every add recurrence whose loop is a member of
/// \p Loops is shifted forward by one iteration of that loop.
const SCEV *llvm::denormalizeForPostIncUse(const SCEV *S,
                                           const PostIncLoopSet &Loops,
                                           ScalarEvolution &SE) {
  // The rewriter only keeps a function_ref to this lambda, so the lambda is a
  // named local that outlives the rewriter.
  auto IsPostIncLoop = [&Loops](const SCEVAddRecExpr *AR) {
    return Loops.count(AR->getLoop()) != 0;
  };
  NormalizeDenormalizeRewriter Rewriter(Denormalize, IsPostIncLoop, SE);
  return Rewriter.visit(S);
}
118