xref: /openbsd-src/gnu/llvm/llvm/lib/Transforms/Utils/MatrixUtils.cpp (revision d415bd752c734aee168c4ee86ff32e8cc249eb16)
173471bf0Spatrick //===- MatrixUtils.cpp - Utilities to lower matrix intrinsics ---*- C++ -*-===//
273471bf0Spatrick //
373471bf0Spatrick // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
473471bf0Spatrick // See https://llvm.org/LICENSE.txt for license information.
573471bf0Spatrick // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
673471bf0Spatrick //
773471bf0Spatrick //===----------------------------------------------------------------------===//
873471bf0Spatrick //
973471bf0Spatrick // Utilities for generating tiled loops for matrix operations.
1073471bf0Spatrick //
1173471bf0Spatrick //===----------------------------------------------------------------------===//
1273471bf0Spatrick 
1373471bf0Spatrick #include "llvm/Transforms/Utils/MatrixUtils.h"
1473471bf0Spatrick #include "llvm/Analysis/DomTreeUpdater.h"
1573471bf0Spatrick #include "llvm/Analysis/LoopInfo.h"
1673471bf0Spatrick #include "llvm/IR/BasicBlock.h"
1773471bf0Spatrick #include "llvm/IR/Dominators.h"
1873471bf0Spatrick #include "llvm/IR/IRBuilder.h"
1973471bf0Spatrick #include "llvm/IR/Type.h"
2073471bf0Spatrick 
2173471bf0Spatrick using namespace llvm;
2273471bf0Spatrick 
CreateLoop(BasicBlock * Preheader,BasicBlock * Exit,Value * Bound,Value * Step,StringRef Name,IRBuilderBase & B,DomTreeUpdater & DTU,Loop * L,LoopInfo & LI)2373471bf0Spatrick BasicBlock *TileInfo::CreateLoop(BasicBlock *Preheader, BasicBlock *Exit,
2473471bf0Spatrick                                  Value *Bound, Value *Step, StringRef Name,
2573471bf0Spatrick                                  IRBuilderBase &B, DomTreeUpdater &DTU, Loop *L,
2673471bf0Spatrick                                  LoopInfo &LI) {
2773471bf0Spatrick   LLVMContext &Ctx = Preheader->getContext();
2873471bf0Spatrick   BasicBlock *Header = BasicBlock::Create(
2973471bf0Spatrick       Preheader->getContext(), Name + ".header", Preheader->getParent(), Exit);
3073471bf0Spatrick   BasicBlock *Body = BasicBlock::Create(Header->getContext(), Name + ".body",
3173471bf0Spatrick                                         Header->getParent(), Exit);
3273471bf0Spatrick   BasicBlock *Latch = BasicBlock::Create(Header->getContext(), Name + ".latch",
3373471bf0Spatrick                                          Header->getParent(), Exit);
3473471bf0Spatrick 
3573471bf0Spatrick   Type *I32Ty = Type::getInt64Ty(Ctx);
3673471bf0Spatrick   BranchInst::Create(Body, Header);
3773471bf0Spatrick   BranchInst::Create(Latch, Body);
3873471bf0Spatrick   PHINode *IV =
3973471bf0Spatrick       PHINode::Create(I32Ty, 2, Name + ".iv", Header->getTerminator());
4073471bf0Spatrick   IV->addIncoming(ConstantInt::get(I32Ty, 0), Preheader);
4173471bf0Spatrick 
4273471bf0Spatrick   B.SetInsertPoint(Latch);
4373471bf0Spatrick   Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
4473471bf0Spatrick   Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond");
4573471bf0Spatrick   BranchInst::Create(Header, Exit, Cond, Latch);
4673471bf0Spatrick   IV->addIncoming(Inc, Latch);
4773471bf0Spatrick 
4873471bf0Spatrick   BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
4973471bf0Spatrick   BasicBlock *Tmp = PreheaderBr->getSuccessor(0);
5073471bf0Spatrick   PreheaderBr->setSuccessor(0, Header);
5173471bf0Spatrick   DTU.applyUpdatesPermissive({
5273471bf0Spatrick       {DominatorTree::Delete, Preheader, Tmp},
5373471bf0Spatrick       {DominatorTree::Insert, Header, Body},
5473471bf0Spatrick       {DominatorTree::Insert, Body, Latch},
5573471bf0Spatrick       {DominatorTree::Insert, Latch, Header},
5673471bf0Spatrick       {DominatorTree::Insert, Latch, Exit},
5773471bf0Spatrick       {DominatorTree::Insert, Preheader, Header},
5873471bf0Spatrick   });
5973471bf0Spatrick 
6073471bf0Spatrick   L->addBasicBlockToLoop(Header, LI);
6173471bf0Spatrick   L->addBasicBlockToLoop(Body, LI);
6273471bf0Spatrick   L->addBasicBlockToLoop(Latch, LI);
6373471bf0Spatrick   return Body;
6473471bf0Spatrick }
6573471bf0Spatrick 
6673471bf0Spatrick // Creates the following loop nest skeleton:
6773471bf0Spatrick //  for C = 0; C < NumColumns; C += TileSize
6873471bf0Spatrick //    for R = 0; R < NumRows; R += TileSize
6973471bf0Spatrick //      for K = 0; K < Inner ; K += TileSize
CreateTiledLoops(BasicBlock * Start,BasicBlock * End,IRBuilderBase & B,DomTreeUpdater & DTU,LoopInfo & LI)7073471bf0Spatrick BasicBlock *TileInfo::CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
7173471bf0Spatrick                                        IRBuilderBase &B, DomTreeUpdater &DTU,
7273471bf0Spatrick                                        LoopInfo &LI) {
73*d415bd75Srobert   Loop *ColumnLoopInfo = LI.AllocateLoop();
74*d415bd75Srobert   Loop *RowLoopInfo = LI.AllocateLoop();
75*d415bd75Srobert   Loop *KLoopInfo = LI.AllocateLoop();
76*d415bd75Srobert   RowLoopInfo->addChildLoop(KLoopInfo);
77*d415bd75Srobert   ColumnLoopInfo->addChildLoop(RowLoopInfo);
7873471bf0Spatrick   if (Loop *ParentL = LI.getLoopFor(Start))
79*d415bd75Srobert     ParentL->addChildLoop(ColumnLoopInfo);
8073471bf0Spatrick   else
81*d415bd75Srobert     LI.addTopLevelLoop(ColumnLoopInfo);
8273471bf0Spatrick 
8373471bf0Spatrick   BasicBlock *ColBody =
8473471bf0Spatrick       CreateLoop(Start, End, B.getInt64(NumColumns), B.getInt64(TileSize),
85*d415bd75Srobert                  "cols", B, DTU, ColumnLoopInfo, LI);
86*d415bd75Srobert   ColumnLoop.Latch = ColBody->getSingleSuccessor();
8773471bf0Spatrick   BasicBlock *RowBody =
88*d415bd75Srobert       CreateLoop(ColBody, ColumnLoop.Latch, B.getInt64(NumRows),
89*d415bd75Srobert                  B.getInt64(TileSize), "rows", B, DTU, RowLoopInfo, LI);
90*d415bd75Srobert   RowLoop.Latch = RowBody->getSingleSuccessor();
9173471bf0Spatrick 
9273471bf0Spatrick   BasicBlock *InnerBody =
93*d415bd75Srobert       CreateLoop(RowBody, RowLoop.Latch, B.getInt64(NumInner),
94*d415bd75Srobert                  B.getInt64(TileSize), "inner", B, DTU, KLoopInfo, LI);
95*d415bd75Srobert   KLoop.Latch = InnerBody->getSingleSuccessor();
96*d415bd75Srobert   ColumnLoop.Header = ColBody->getSinglePredecessor();
97*d415bd75Srobert   RowLoop.Header = RowBody->getSinglePredecessor();
98*d415bd75Srobert   KLoop.Header = InnerBody->getSinglePredecessor();
99*d415bd75Srobert   RowLoop.Index = &*RowLoop.Header->begin();
100*d415bd75Srobert   ColumnLoop.Index = &*ColumnLoop.Header->begin();
101*d415bd75Srobert   KLoop.Index = &*KLoop.Header->begin();
10273471bf0Spatrick 
10373471bf0Spatrick   return InnerBody;
10473471bf0Spatrick }
105