xref: /freebsd-src/contrib/llvm-project/llvm/lib/Transforms/Utils/MatrixUtils.cpp (revision e8d8bef961a50d4dc22501cde4fb9fb0be1b2532)
1*e8d8bef9SDimitry Andric //===- MatrixUtils.cpp - Utilities to lower matrix intrinsics ---*- C++ -*-===//
2*e8d8bef9SDimitry Andric //
3*e8d8bef9SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*e8d8bef9SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5*e8d8bef9SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*e8d8bef9SDimitry Andric //
7*e8d8bef9SDimitry Andric //===----------------------------------------------------------------------===//
8*e8d8bef9SDimitry Andric //
9*e8d8bef9SDimitry Andric // Utilities for generating tiled loops for matrix operations.
10*e8d8bef9SDimitry Andric //
11*e8d8bef9SDimitry Andric //===----------------------------------------------------------------------===//
12*e8d8bef9SDimitry Andric 
13*e8d8bef9SDimitry Andric #include "llvm/Transforms/Utils/MatrixUtils.h"
14*e8d8bef9SDimitry Andric #include "llvm/Analysis/DomTreeUpdater.h"
15*e8d8bef9SDimitry Andric #include "llvm/Analysis/LoopInfo.h"
16*e8d8bef9SDimitry Andric #include "llvm/IR/BasicBlock.h"
17*e8d8bef9SDimitry Andric #include "llvm/IR/Dominators.h"
18*e8d8bef9SDimitry Andric #include "llvm/IR/IRBuilder.h"
19*e8d8bef9SDimitry Andric #include "llvm/IR/Type.h"
20*e8d8bef9SDimitry Andric 
21*e8d8bef9SDimitry Andric using namespace llvm;
22*e8d8bef9SDimitry Andric 
23*e8d8bef9SDimitry Andric BasicBlock *TileInfo::CreateLoop(BasicBlock *Preheader, BasicBlock *Exit,
24*e8d8bef9SDimitry Andric                                  Value *Bound, Value *Step, StringRef Name,
25*e8d8bef9SDimitry Andric                                  IRBuilderBase &B, DomTreeUpdater &DTU, Loop *L,
26*e8d8bef9SDimitry Andric                                  LoopInfo &LI) {
27*e8d8bef9SDimitry Andric   LLVMContext &Ctx = Preheader->getContext();
28*e8d8bef9SDimitry Andric   BasicBlock *Header = BasicBlock::Create(
29*e8d8bef9SDimitry Andric       Preheader->getContext(), Name + ".header", Preheader->getParent(), Exit);
30*e8d8bef9SDimitry Andric   BasicBlock *Body = BasicBlock::Create(Header->getContext(), Name + ".body",
31*e8d8bef9SDimitry Andric                                         Header->getParent(), Exit);
32*e8d8bef9SDimitry Andric   BasicBlock *Latch = BasicBlock::Create(Header->getContext(), Name + ".latch",
33*e8d8bef9SDimitry Andric                                          Header->getParent(), Exit);
34*e8d8bef9SDimitry Andric 
35*e8d8bef9SDimitry Andric   Type *I32Ty = Type::getInt64Ty(Ctx);
36*e8d8bef9SDimitry Andric   BranchInst::Create(Body, Header);
37*e8d8bef9SDimitry Andric   BranchInst::Create(Latch, Body);
38*e8d8bef9SDimitry Andric   PHINode *IV =
39*e8d8bef9SDimitry Andric       PHINode::Create(I32Ty, 2, Name + ".iv", Header->getTerminator());
40*e8d8bef9SDimitry Andric   IV->addIncoming(ConstantInt::get(I32Ty, 0), Preheader);
41*e8d8bef9SDimitry Andric 
42*e8d8bef9SDimitry Andric   B.SetInsertPoint(Latch);
43*e8d8bef9SDimitry Andric   Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
44*e8d8bef9SDimitry Andric   Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond");
45*e8d8bef9SDimitry Andric   BranchInst::Create(Header, Exit, Cond, Latch);
46*e8d8bef9SDimitry Andric   IV->addIncoming(Inc, Latch);
47*e8d8bef9SDimitry Andric 
48*e8d8bef9SDimitry Andric   BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
49*e8d8bef9SDimitry Andric   BasicBlock *Tmp = PreheaderBr->getSuccessor(0);
50*e8d8bef9SDimitry Andric   PreheaderBr->setSuccessor(0, Header);
51*e8d8bef9SDimitry Andric   DTU.applyUpdatesPermissive({
52*e8d8bef9SDimitry Andric       {DominatorTree::Delete, Preheader, Tmp},
53*e8d8bef9SDimitry Andric       {DominatorTree::Insert, Header, Body},
54*e8d8bef9SDimitry Andric       {DominatorTree::Insert, Body, Latch},
55*e8d8bef9SDimitry Andric       {DominatorTree::Insert, Latch, Header},
56*e8d8bef9SDimitry Andric       {DominatorTree::Insert, Latch, Exit},
57*e8d8bef9SDimitry Andric       {DominatorTree::Insert, Preheader, Header},
58*e8d8bef9SDimitry Andric   });
59*e8d8bef9SDimitry Andric 
60*e8d8bef9SDimitry Andric   L->addBasicBlockToLoop(Header, LI);
61*e8d8bef9SDimitry Andric   L->addBasicBlockToLoop(Body, LI);
62*e8d8bef9SDimitry Andric   L->addBasicBlockToLoop(Latch, LI);
63*e8d8bef9SDimitry Andric   return Body;
64*e8d8bef9SDimitry Andric }
65*e8d8bef9SDimitry Andric 
66*e8d8bef9SDimitry Andric // Creates the following loop nest skeleton:
67*e8d8bef9SDimitry Andric //  for C = 0; C < NumColumns; C += TileSize
68*e8d8bef9SDimitry Andric //    for R = 0; R < NumRows; R += TileSize
69*e8d8bef9SDimitry Andric //      for K = 0; K < Inner ; K += TileSize
70*e8d8bef9SDimitry Andric BasicBlock *TileInfo::CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
71*e8d8bef9SDimitry Andric                                        IRBuilderBase &B, DomTreeUpdater &DTU,
72*e8d8bef9SDimitry Andric                                        LoopInfo &LI) {
73*e8d8bef9SDimitry Andric   Loop *ColLoop = LI.AllocateLoop();
74*e8d8bef9SDimitry Andric   Loop *RowLoop = LI.AllocateLoop();
75*e8d8bef9SDimitry Andric   Loop *InnerLoop = LI.AllocateLoop();
76*e8d8bef9SDimitry Andric   RowLoop->addChildLoop(InnerLoop);
77*e8d8bef9SDimitry Andric   ColLoop->addChildLoop(RowLoop);
78*e8d8bef9SDimitry Andric   if (Loop *ParentL = LI.getLoopFor(Start))
79*e8d8bef9SDimitry Andric     ParentL->addChildLoop(ColLoop);
80*e8d8bef9SDimitry Andric   else
81*e8d8bef9SDimitry Andric     LI.addTopLevelLoop(ColLoop);
82*e8d8bef9SDimitry Andric 
83*e8d8bef9SDimitry Andric   BasicBlock *ColBody =
84*e8d8bef9SDimitry Andric       CreateLoop(Start, End, B.getInt64(NumColumns), B.getInt64(TileSize),
85*e8d8bef9SDimitry Andric                  "cols", B, DTU, ColLoop, LI);
86*e8d8bef9SDimitry Andric   BasicBlock *ColLatch = ColBody->getSingleSuccessor();
87*e8d8bef9SDimitry Andric   BasicBlock *RowBody =
88*e8d8bef9SDimitry Andric       CreateLoop(ColBody, ColLatch, B.getInt64(NumRows), B.getInt64(TileSize),
89*e8d8bef9SDimitry Andric                  "rows", B, DTU, RowLoop, LI);
90*e8d8bef9SDimitry Andric   RowLoopLatch = RowBody->getSingleSuccessor();
91*e8d8bef9SDimitry Andric 
92*e8d8bef9SDimitry Andric   BasicBlock *InnerBody =
93*e8d8bef9SDimitry Andric       CreateLoop(RowBody, RowLoopLatch, B.getInt64(NumInner),
94*e8d8bef9SDimitry Andric                  B.getInt64(TileSize), "inner", B, DTU, InnerLoop, LI);
95*e8d8bef9SDimitry Andric   InnerLoopLatch = InnerBody->getSingleSuccessor();
96*e8d8bef9SDimitry Andric   ColumnLoopHeader = ColBody->getSinglePredecessor();
97*e8d8bef9SDimitry Andric   RowLoopHeader = RowBody->getSinglePredecessor();
98*e8d8bef9SDimitry Andric   InnerLoopHeader = InnerBody->getSinglePredecessor();
99*e8d8bef9SDimitry Andric   CurrentRow = &*RowLoopHeader->begin();
100*e8d8bef9SDimitry Andric   CurrentCol = &*ColumnLoopHeader->begin();
101*e8d8bef9SDimitry Andric   CurrentK = &*InnerLoopHeader->begin();
102*e8d8bef9SDimitry Andric 
103*e8d8bef9SDimitry Andric   return InnerBody;
104*e8d8bef9SDimitry Andric }
105