10b57cec5SDimitry Andric //===- ScalarEvolutionNormalization.cpp - See below -----------------------===//
20b57cec5SDimitry Andric //
30b57cec5SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
40b57cec5SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
50b57cec5SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60b57cec5SDimitry Andric //
70b57cec5SDimitry Andric //===----------------------------------------------------------------------===//
80b57cec5SDimitry Andric //
90b57cec5SDimitry Andric // This file implements utilities for working with "normalized" expressions.
100b57cec5SDimitry Andric // See the comments at the top of ScalarEvolutionNormalization.h for details.
110b57cec5SDimitry Andric //
120b57cec5SDimitry Andric //===----------------------------------------------------------------------===//
130b57cec5SDimitry Andric
140b57cec5SDimitry Andric #include "llvm/Analysis/ScalarEvolutionNormalization.h"
150b57cec5SDimitry Andric #include "llvm/Analysis/LoopInfo.h"
1681ad6265SDimitry Andric #include "llvm/Analysis/ScalarEvolution.h"
170b57cec5SDimitry Andric #include "llvm/Analysis/ScalarEvolutionExpressions.h"
180b57cec5SDimitry Andric using namespace llvm;
190b57cec5SDimitry Andric
namespace {
/// TransformKind - The two directions in which NormalizeDenormalizeRewriter
/// can rewrite the add recurrences selected by its predicate.
///
/// This type is private to this file, so it lives in an anonymous namespace.
/// That gives the type and its enumerators internal linkage instead of
/// placing them in the global namespace.
enum TransformKind {
  /// Normalize - "Partial decrement": step each selected add recurrence back
  /// by one iteration of its loop.
  Normalize,
  /// Denormalize - "Partial increment": the inverse of Normalize, stepping
  /// each selected add recurrence forward by one iteration of its loop.
  Denormalize
};
} // namespace
290b57cec5SDimitry Andric
300b57cec5SDimitry Andric namespace {
310b57cec5SDimitry Andric struct NormalizeDenormalizeRewriter
320b57cec5SDimitry Andric : public SCEVRewriteVisitor<NormalizeDenormalizeRewriter> {
330b57cec5SDimitry Andric const TransformKind Kind;
340b57cec5SDimitry Andric
350b57cec5SDimitry Andric // NB! Pred is a function_ref. Storing it here is okay only because
360b57cec5SDimitry Andric // we're careful about the lifetime of NormalizeDenormalizeRewriter.
370b57cec5SDimitry Andric const NormalizePredTy Pred;
380b57cec5SDimitry Andric
NormalizeDenormalizeRewriter__anon7777de170111::NormalizeDenormalizeRewriter390b57cec5SDimitry Andric NormalizeDenormalizeRewriter(TransformKind Kind, NormalizePredTy Pred,
400b57cec5SDimitry Andric ScalarEvolution &SE)
410b57cec5SDimitry Andric : SCEVRewriteVisitor<NormalizeDenormalizeRewriter>(SE), Kind(Kind),
420b57cec5SDimitry Andric Pred(Pred) {}
430b57cec5SDimitry Andric const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr);
440b57cec5SDimitry Andric };
450b57cec5SDimitry Andric } // namespace
460b57cec5SDimitry Andric
470b57cec5SDimitry Andric const SCEV *
visitAddRecExpr(const SCEVAddRecExpr * AR)480b57cec5SDimitry Andric NormalizeDenormalizeRewriter::visitAddRecExpr(const SCEVAddRecExpr *AR) {
490b57cec5SDimitry Andric SmallVector<const SCEV *, 8> Operands;
500b57cec5SDimitry Andric
510b57cec5SDimitry Andric transform(AR->operands(), std::back_inserter(Operands),
520b57cec5SDimitry Andric [&](const SCEV *Op) { return visit(Op); });
530b57cec5SDimitry Andric
540b57cec5SDimitry Andric if (!Pred(AR))
550b57cec5SDimitry Andric return SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap);
560b57cec5SDimitry Andric
570b57cec5SDimitry Andric // Normalization and denormalization are fancy names for decrementing and
580b57cec5SDimitry Andric // incrementing a SCEV expression with respect to a set of loops. Since
590b57cec5SDimitry Andric // Pred(AR) has returned true, we know we need to normalize or denormalize AR
600b57cec5SDimitry Andric // with respect to its loop.
610b57cec5SDimitry Andric
620b57cec5SDimitry Andric if (Kind == Denormalize) {
630b57cec5SDimitry Andric // Denormalization / "partial increment" is essentially the same as \c
640b57cec5SDimitry Andric // SCEVAddRecExpr::getPostIncExpr. Here we use an explicit loop to make the
650b57cec5SDimitry Andric // symmetry with Normalization clear.
660b57cec5SDimitry Andric for (int i = 0, e = Operands.size() - 1; i < e; i++)
670b57cec5SDimitry Andric Operands[i] = SE.getAddExpr(Operands[i], Operands[i + 1]);
680b57cec5SDimitry Andric } else {
690b57cec5SDimitry Andric assert(Kind == Normalize && "Only two possibilities!");
700b57cec5SDimitry Andric
710b57cec5SDimitry Andric // Normalization / "partial decrement" is a bit more subtle. Since
720b57cec5SDimitry Andric // incrementing a SCEV expression (in general) changes the step of the SCEV
730b57cec5SDimitry Andric // expression as well, we cannot use the step of the current expression.
740b57cec5SDimitry Andric // Instead, we have to use the step of the very expression we're trying to
750b57cec5SDimitry Andric // compute!
760b57cec5SDimitry Andric //
770b57cec5SDimitry Andric // We solve the issue by recursively building up the result, starting from
780b57cec5SDimitry Andric // the "least significant" operand in the add recurrence:
790b57cec5SDimitry Andric //
800b57cec5SDimitry Andric // Base case:
810b57cec5SDimitry Andric // Single operand add recurrence. It's its own normalization.
820b57cec5SDimitry Andric //
830b57cec5SDimitry Andric // N-operand case:
840b57cec5SDimitry Andric // {S_{N-1},+,S_{N-2},+,...,+,S_0} = S
850b57cec5SDimitry Andric //
860b57cec5SDimitry Andric // Since the step recurrence of S is {S_{N-2},+,...,+,S_0}, we know its
870b57cec5SDimitry Andric // normalization by induction. We subtract the normalized step
880b57cec5SDimitry Andric // recurrence from S_{N-1} to get the normalization of S.
890b57cec5SDimitry Andric
900b57cec5SDimitry Andric for (int i = Operands.size() - 2; i >= 0; i--)
910b57cec5SDimitry Andric Operands[i] = SE.getMinusSCEV(Operands[i], Operands[i + 1]);
920b57cec5SDimitry Andric }
930b57cec5SDimitry Andric
940b57cec5SDimitry Andric return SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap);
950b57cec5SDimitry Andric }
960b57cec5SDimitry Andric
/// Normalize S with respect to every add recurrence whose loop is in Loops.
///
/// \param S               Expression to normalize.
/// \param Loops           Loops whose add recurrences are normalized; if the
///                        set is empty, S is returned unchanged.
/// \param SE              The ScalarEvolution instance owning S.
/// \param CheckInvertible If true, verify that denormalizing the result with
///                        the same loop set gives back exactly S.
/// \returns The normalized expression, or nullptr if CheckInvertible is set
///          and the round trip does not reproduce S.
const SCEV *llvm::normalizeForPostIncUse(const SCEV *S,
                                         const PostIncLoopSet &Loops,
                                         ScalarEvolution &SE,
                                         bool CheckInvertible) {
  if (Loops.empty())
    return S;

  auto Pred = [&](const SCEVAddRecExpr *AR) {
    return Loops.count(AR->getLoop());
  };
  const SCEV *Normalized =
      NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S);

  // Reject a normalization that denormalization cannot undo. The round trip
  // builds a whole new expression, so only pay for it when the caller
  // actually asked for the check; otherwise it would be computed and thrown
  // away.
  if (CheckInvertible && denormalizeForPostIncUse(Normalized, Loops, SE) != S)
    return nullptr;
  return Normalized;
}
1140b57cec5SDimitry Andric
/// Normalize S with respect to every add recurrence for which Pred returns
/// true.
const SCEV *llvm::normalizeForPostIncUseIf(const SCEV *S, NormalizePredTy Pred,
                                           ScalarEvolution &SE) {
  // The rewriter only borrows Pred; both live until this function returns.
  NormalizeDenormalizeRewriter Rewriter(Normalize, Pred, SE);
  return Rewriter.visit(S);
}
1190b57cec5SDimitry Andric
/// Denormalize S (the inverse of normalization) with respect to every add
/// recurrence whose loop is in Loops. An empty loop set returns S untouched.
const SCEV *llvm::denormalizeForPostIncUse(const SCEV *S,
                                           const PostIncLoopSet &Loops,
                                           ScalarEvolution &SE) {
  if (!Loops.empty()) {
    auto IsPostIncLoop = [&Loops](const SCEVAddRecExpr *AR) -> bool {
      return Loops.count(AR->getLoop()) != 0;
    };
    NormalizeDenormalizeRewriter Rewriter(Denormalize, IsPostIncLoop, SE);
    return Rewriter.visit(S);
  }
  // No loops to denormalize against.
  return S;
}
130