xref: /llvm-project/llvm/lib/Transforms/Utils/MatrixUtils.cpp (revision 6b62a9135a28bd001263e5a9db08d4cff1123126)
//===- MatrixUtils.cpp - Utilities to lower matrix intrinsics ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Utilities for generating tiled loops for matrix operations.
//
//===----------------------------------------------------------------------===//
12e1270b16SFlorian Hahn 
13e1270b16SFlorian Hahn #include "llvm/Transforms/Utils/MatrixUtils.h"
14e1270b16SFlorian Hahn #include "llvm/Analysis/DomTreeUpdater.h"
15e1270b16SFlorian Hahn #include "llvm/Analysis/LoopInfo.h"
16e1270b16SFlorian Hahn #include "llvm/IR/BasicBlock.h"
17e1270b16SFlorian Hahn #include "llvm/IR/Dominators.h"
18e1270b16SFlorian Hahn #include "llvm/IR/IRBuilder.h"
19e1270b16SFlorian Hahn #include "llvm/IR/Type.h"
20e1270b16SFlorian Hahn 
21e1270b16SFlorian Hahn using namespace llvm;
22e1270b16SFlorian Hahn 
CreateLoop(BasicBlock * Preheader,BasicBlock * Exit,Value * Bound,Value * Step,StringRef Name,IRBuilderBase & B,DomTreeUpdater & DTU,Loop * L,LoopInfo & LI)23e1270b16SFlorian Hahn BasicBlock *TileInfo::CreateLoop(BasicBlock *Preheader, BasicBlock *Exit,
24e1270b16SFlorian Hahn                                  Value *Bound, Value *Step, StringRef Name,
25e1270b16SFlorian Hahn                                  IRBuilderBase &B, DomTreeUpdater &DTU, Loop *L,
26e1270b16SFlorian Hahn                                  LoopInfo &LI) {
27e1270b16SFlorian Hahn   LLVMContext &Ctx = Preheader->getContext();
28e1270b16SFlorian Hahn   BasicBlock *Header = BasicBlock::Create(
29e1270b16SFlorian Hahn       Preheader->getContext(), Name + ".header", Preheader->getParent(), Exit);
30e1270b16SFlorian Hahn   BasicBlock *Body = BasicBlock::Create(Header->getContext(), Name + ".body",
31e1270b16SFlorian Hahn                                         Header->getParent(), Exit);
32e1270b16SFlorian Hahn   BasicBlock *Latch = BasicBlock::Create(Header->getContext(), Name + ".latch",
33e1270b16SFlorian Hahn                                          Header->getParent(), Exit);
34e1270b16SFlorian Hahn 
35e1270b16SFlorian Hahn   Type *I32Ty = Type::getInt64Ty(Ctx);
36e1270b16SFlorian Hahn   BranchInst::Create(Body, Header);
37e1270b16SFlorian Hahn   BranchInst::Create(Latch, Body);
38e1270b16SFlorian Hahn   PHINode *IV =
39*6b62a913SJeremy Morse       PHINode::Create(I32Ty, 2, Name + ".iv", Header->getTerminator()->getIterator());
40e1270b16SFlorian Hahn   IV->addIncoming(ConstantInt::get(I32Ty, 0), Preheader);
41e1270b16SFlorian Hahn 
42e1270b16SFlorian Hahn   B.SetInsertPoint(Latch);
43e1270b16SFlorian Hahn   Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
44e1270b16SFlorian Hahn   Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond");
45e1270b16SFlorian Hahn   BranchInst::Create(Header, Exit, Cond, Latch);
46e1270b16SFlorian Hahn   IV->addIncoming(Inc, Latch);
47e1270b16SFlorian Hahn 
48e1270b16SFlorian Hahn   BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
49e1270b16SFlorian Hahn   BasicBlock *Tmp = PreheaderBr->getSuccessor(0);
50e1270b16SFlorian Hahn   PreheaderBr->setSuccessor(0, Header);
51e1270b16SFlorian Hahn   DTU.applyUpdatesPermissive({
52e1270b16SFlorian Hahn       {DominatorTree::Delete, Preheader, Tmp},
53e1270b16SFlorian Hahn       {DominatorTree::Insert, Header, Body},
54e1270b16SFlorian Hahn       {DominatorTree::Insert, Body, Latch},
55e1270b16SFlorian Hahn       {DominatorTree::Insert, Latch, Header},
56e1270b16SFlorian Hahn       {DominatorTree::Insert, Latch, Exit},
57e1270b16SFlorian Hahn       {DominatorTree::Insert, Preheader, Header},
58e1270b16SFlorian Hahn   });
59e1270b16SFlorian Hahn 
60e1270b16SFlorian Hahn   L->addBasicBlockToLoop(Header, LI);
61e1270b16SFlorian Hahn   L->addBasicBlockToLoop(Body, LI);
62e1270b16SFlorian Hahn   L->addBasicBlockToLoop(Latch, LI);
63e1270b16SFlorian Hahn   return Body;
64e1270b16SFlorian Hahn }
65e1270b16SFlorian Hahn 
66e1270b16SFlorian Hahn // Creates the following loop nest skeleton:
67e1270b16SFlorian Hahn //  for C = 0; C < NumColumns; C += TileSize
68e1270b16SFlorian Hahn //    for R = 0; R < NumRows; R += TileSize
69e1270b16SFlorian Hahn //      for K = 0; K < Inner ; K += TileSize
CreateTiledLoops(BasicBlock * Start,BasicBlock * End,IRBuilderBase & B,DomTreeUpdater & DTU,LoopInfo & LI)70e1270b16SFlorian Hahn BasicBlock *TileInfo::CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
71e1270b16SFlorian Hahn                                        IRBuilderBase &B, DomTreeUpdater &DTU,
72e1270b16SFlorian Hahn                                        LoopInfo &LI) {
732c6e8b46SFrancis Visoiu Mistrih   Loop *ColumnLoopInfo = LI.AllocateLoop();
742c6e8b46SFrancis Visoiu Mistrih   Loop *RowLoopInfo = LI.AllocateLoop();
752c6e8b46SFrancis Visoiu Mistrih   Loop *KLoopInfo = LI.AllocateLoop();
762c6e8b46SFrancis Visoiu Mistrih   RowLoopInfo->addChildLoop(KLoopInfo);
772c6e8b46SFrancis Visoiu Mistrih   ColumnLoopInfo->addChildLoop(RowLoopInfo);
78e1270b16SFlorian Hahn   if (Loop *ParentL = LI.getLoopFor(Start))
792c6e8b46SFrancis Visoiu Mistrih     ParentL->addChildLoop(ColumnLoopInfo);
80e1270b16SFlorian Hahn   else
812c6e8b46SFrancis Visoiu Mistrih     LI.addTopLevelLoop(ColumnLoopInfo);
82e1270b16SFlorian Hahn 
83e1270b16SFlorian Hahn   BasicBlock *ColBody =
84e1270b16SFlorian Hahn       CreateLoop(Start, End, B.getInt64(NumColumns), B.getInt64(TileSize),
852c6e8b46SFrancis Visoiu Mistrih                  "cols", B, DTU, ColumnLoopInfo, LI);
862c6e8b46SFrancis Visoiu Mistrih   ColumnLoop.Latch = ColBody->getSingleSuccessor();
87e1270b16SFlorian Hahn   BasicBlock *RowBody =
882c6e8b46SFrancis Visoiu Mistrih       CreateLoop(ColBody, ColumnLoop.Latch, B.getInt64(NumRows),
892c6e8b46SFrancis Visoiu Mistrih                  B.getInt64(TileSize), "rows", B, DTU, RowLoopInfo, LI);
902c6e8b46SFrancis Visoiu Mistrih   RowLoop.Latch = RowBody->getSingleSuccessor();
91e1270b16SFlorian Hahn 
92e1270b16SFlorian Hahn   BasicBlock *InnerBody =
932c6e8b46SFrancis Visoiu Mistrih       CreateLoop(RowBody, RowLoop.Latch, B.getInt64(NumInner),
942c6e8b46SFrancis Visoiu Mistrih                  B.getInt64(TileSize), "inner", B, DTU, KLoopInfo, LI);
952c6e8b46SFrancis Visoiu Mistrih   KLoop.Latch = InnerBody->getSingleSuccessor();
962c6e8b46SFrancis Visoiu Mistrih   ColumnLoop.Header = ColBody->getSinglePredecessor();
972c6e8b46SFrancis Visoiu Mistrih   RowLoop.Header = RowBody->getSinglePredecessor();
982c6e8b46SFrancis Visoiu Mistrih   KLoop.Header = InnerBody->getSinglePredecessor();
992c6e8b46SFrancis Visoiu Mistrih   RowLoop.Index = &*RowLoop.Header->begin();
1002c6e8b46SFrancis Visoiu Mistrih   ColumnLoop.Index = &*ColumnLoop.Header->begin();
1012c6e8b46SFrancis Visoiu Mistrih   KLoop.Index = &*KLoop.Header->begin();
102e1270b16SFlorian Hahn 
103e1270b16SFlorian Hahn   return InnerBody;
104e1270b16SFlorian Hahn }
105