xref: /freebsd-src/contrib/llvm-project/llvm/lib/Transforms/Utils/MatrixUtils.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1e8d8bef9SDimitry Andric //===- MatrixUtils.cpp - Utilities to lower matrix intrinsics ---*- C++ -*-===//
2e8d8bef9SDimitry Andric //
3e8d8bef9SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4e8d8bef9SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5e8d8bef9SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6e8d8bef9SDimitry Andric //
7e8d8bef9SDimitry Andric //===----------------------------------------------------------------------===//
8e8d8bef9SDimitry Andric //
9e8d8bef9SDimitry Andric // Utilities for generating tiled loops for matrix operations.
10e8d8bef9SDimitry Andric //
11e8d8bef9SDimitry Andric //===----------------------------------------------------------------------===//
12e8d8bef9SDimitry Andric 
13e8d8bef9SDimitry Andric #include "llvm/Transforms/Utils/MatrixUtils.h"
14e8d8bef9SDimitry Andric #include "llvm/Analysis/DomTreeUpdater.h"
15e8d8bef9SDimitry Andric #include "llvm/Analysis/LoopInfo.h"
16e8d8bef9SDimitry Andric #include "llvm/IR/BasicBlock.h"
17e8d8bef9SDimitry Andric #include "llvm/IR/Dominators.h"
18e8d8bef9SDimitry Andric #include "llvm/IR/IRBuilder.h"
19e8d8bef9SDimitry Andric #include "llvm/IR/Type.h"
20e8d8bef9SDimitry Andric 
21e8d8bef9SDimitry Andric using namespace llvm;
22e8d8bef9SDimitry Andric 
23e8d8bef9SDimitry Andric BasicBlock *TileInfo::CreateLoop(BasicBlock *Preheader, BasicBlock *Exit,
24e8d8bef9SDimitry Andric                                  Value *Bound, Value *Step, StringRef Name,
25e8d8bef9SDimitry Andric                                  IRBuilderBase &B, DomTreeUpdater &DTU, Loop *L,
26e8d8bef9SDimitry Andric                                  LoopInfo &LI) {
27e8d8bef9SDimitry Andric   LLVMContext &Ctx = Preheader->getContext();
28e8d8bef9SDimitry Andric   BasicBlock *Header = BasicBlock::Create(
29e8d8bef9SDimitry Andric       Preheader->getContext(), Name + ".header", Preheader->getParent(), Exit);
30e8d8bef9SDimitry Andric   BasicBlock *Body = BasicBlock::Create(Header->getContext(), Name + ".body",
31e8d8bef9SDimitry Andric                                         Header->getParent(), Exit);
32e8d8bef9SDimitry Andric   BasicBlock *Latch = BasicBlock::Create(Header->getContext(), Name + ".latch",
33e8d8bef9SDimitry Andric                                          Header->getParent(), Exit);
34e8d8bef9SDimitry Andric 
35e8d8bef9SDimitry Andric   Type *I32Ty = Type::getInt64Ty(Ctx);
36e8d8bef9SDimitry Andric   BranchInst::Create(Body, Header);
37e8d8bef9SDimitry Andric   BranchInst::Create(Latch, Body);
38e8d8bef9SDimitry Andric   PHINode *IV =
39*0fca6ea1SDimitry Andric       PHINode::Create(I32Ty, 2, Name + ".iv", Header->getTerminator()->getIterator());
40e8d8bef9SDimitry Andric   IV->addIncoming(ConstantInt::get(I32Ty, 0), Preheader);
41e8d8bef9SDimitry Andric 
42e8d8bef9SDimitry Andric   B.SetInsertPoint(Latch);
43e8d8bef9SDimitry Andric   Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
44e8d8bef9SDimitry Andric   Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond");
45e8d8bef9SDimitry Andric   BranchInst::Create(Header, Exit, Cond, Latch);
46e8d8bef9SDimitry Andric   IV->addIncoming(Inc, Latch);
47e8d8bef9SDimitry Andric 
48e8d8bef9SDimitry Andric   BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
49e8d8bef9SDimitry Andric   BasicBlock *Tmp = PreheaderBr->getSuccessor(0);
50e8d8bef9SDimitry Andric   PreheaderBr->setSuccessor(0, Header);
51e8d8bef9SDimitry Andric   DTU.applyUpdatesPermissive({
52e8d8bef9SDimitry Andric       {DominatorTree::Delete, Preheader, Tmp},
53e8d8bef9SDimitry Andric       {DominatorTree::Insert, Header, Body},
54e8d8bef9SDimitry Andric       {DominatorTree::Insert, Body, Latch},
55e8d8bef9SDimitry Andric       {DominatorTree::Insert, Latch, Header},
56e8d8bef9SDimitry Andric       {DominatorTree::Insert, Latch, Exit},
57e8d8bef9SDimitry Andric       {DominatorTree::Insert, Preheader, Header},
58e8d8bef9SDimitry Andric   });
59e8d8bef9SDimitry Andric 
60e8d8bef9SDimitry Andric   L->addBasicBlockToLoop(Header, LI);
61e8d8bef9SDimitry Andric   L->addBasicBlockToLoop(Body, LI);
62e8d8bef9SDimitry Andric   L->addBasicBlockToLoop(Latch, LI);
63e8d8bef9SDimitry Andric   return Body;
64e8d8bef9SDimitry Andric }
65e8d8bef9SDimitry Andric 
66e8d8bef9SDimitry Andric // Creates the following loop nest skeleton:
67e8d8bef9SDimitry Andric //  for C = 0; C < NumColumns; C += TileSize
68e8d8bef9SDimitry Andric //    for R = 0; R < NumRows; R += TileSize
69e8d8bef9SDimitry Andric //      for K = 0; K < Inner ; K += TileSize
70e8d8bef9SDimitry Andric BasicBlock *TileInfo::CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
71e8d8bef9SDimitry Andric                                        IRBuilderBase &B, DomTreeUpdater &DTU,
72e8d8bef9SDimitry Andric                                        LoopInfo &LI) {
73972a253aSDimitry Andric   Loop *ColumnLoopInfo = LI.AllocateLoop();
74972a253aSDimitry Andric   Loop *RowLoopInfo = LI.AllocateLoop();
75972a253aSDimitry Andric   Loop *KLoopInfo = LI.AllocateLoop();
76972a253aSDimitry Andric   RowLoopInfo->addChildLoop(KLoopInfo);
77972a253aSDimitry Andric   ColumnLoopInfo->addChildLoop(RowLoopInfo);
78e8d8bef9SDimitry Andric   if (Loop *ParentL = LI.getLoopFor(Start))
79972a253aSDimitry Andric     ParentL->addChildLoop(ColumnLoopInfo);
80e8d8bef9SDimitry Andric   else
81972a253aSDimitry Andric     LI.addTopLevelLoop(ColumnLoopInfo);
82e8d8bef9SDimitry Andric 
83e8d8bef9SDimitry Andric   BasicBlock *ColBody =
84e8d8bef9SDimitry Andric       CreateLoop(Start, End, B.getInt64(NumColumns), B.getInt64(TileSize),
85972a253aSDimitry Andric                  "cols", B, DTU, ColumnLoopInfo, LI);
86972a253aSDimitry Andric   ColumnLoop.Latch = ColBody->getSingleSuccessor();
87e8d8bef9SDimitry Andric   BasicBlock *RowBody =
88972a253aSDimitry Andric       CreateLoop(ColBody, ColumnLoop.Latch, B.getInt64(NumRows),
89972a253aSDimitry Andric                  B.getInt64(TileSize), "rows", B, DTU, RowLoopInfo, LI);
90972a253aSDimitry Andric   RowLoop.Latch = RowBody->getSingleSuccessor();
91e8d8bef9SDimitry Andric 
92e8d8bef9SDimitry Andric   BasicBlock *InnerBody =
93972a253aSDimitry Andric       CreateLoop(RowBody, RowLoop.Latch, B.getInt64(NumInner),
94972a253aSDimitry Andric                  B.getInt64(TileSize), "inner", B, DTU, KLoopInfo, LI);
95972a253aSDimitry Andric   KLoop.Latch = InnerBody->getSingleSuccessor();
96972a253aSDimitry Andric   ColumnLoop.Header = ColBody->getSinglePredecessor();
97972a253aSDimitry Andric   RowLoop.Header = RowBody->getSinglePredecessor();
98972a253aSDimitry Andric   KLoop.Header = InnerBody->getSinglePredecessor();
99972a253aSDimitry Andric   RowLoop.Index = &*RowLoop.Header->begin();
100972a253aSDimitry Andric   ColumnLoop.Index = &*ColumnLoop.Header->begin();
101972a253aSDimitry Andric   KLoop.Index = &*KLoop.Header->begin();
102e8d8bef9SDimitry Andric 
103e8d8bef9SDimitry Andric   return InnerBody;
104e8d8bef9SDimitry Andric }
105