1*e8d8bef9SDimitry Andric //===- MatrixUtils.cpp - Utilities to lower matrix intrinsics ---*- C++ -*-===// 2*e8d8bef9SDimitry Andric // 3*e8d8bef9SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*e8d8bef9SDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 5*e8d8bef9SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*e8d8bef9SDimitry Andric // 7*e8d8bef9SDimitry Andric //===----------------------------------------------------------------------===// 8*e8d8bef9SDimitry Andric // 9*e8d8bef9SDimitry Andric // Utilities for generating tiled loops for matrix operations. 10*e8d8bef9SDimitry Andric // 11*e8d8bef9SDimitry Andric //===----------------------------------------------------------------------===// 12*e8d8bef9SDimitry Andric 13*e8d8bef9SDimitry Andric #include "llvm/Transforms/Utils/MatrixUtils.h" 14*e8d8bef9SDimitry Andric #include "llvm/Analysis/DomTreeUpdater.h" 15*e8d8bef9SDimitry Andric #include "llvm/Analysis/LoopInfo.h" 16*e8d8bef9SDimitry Andric #include "llvm/IR/BasicBlock.h" 17*e8d8bef9SDimitry Andric #include "llvm/IR/Dominators.h" 18*e8d8bef9SDimitry Andric #include "llvm/IR/IRBuilder.h" 19*e8d8bef9SDimitry Andric #include "llvm/IR/Type.h" 20*e8d8bef9SDimitry Andric 21*e8d8bef9SDimitry Andric using namespace llvm; 22*e8d8bef9SDimitry Andric 23*e8d8bef9SDimitry Andric BasicBlock *TileInfo::CreateLoop(BasicBlock *Preheader, BasicBlock *Exit, 24*e8d8bef9SDimitry Andric Value *Bound, Value *Step, StringRef Name, 25*e8d8bef9SDimitry Andric IRBuilderBase &B, DomTreeUpdater &DTU, Loop *L, 26*e8d8bef9SDimitry Andric LoopInfo &LI) { 27*e8d8bef9SDimitry Andric LLVMContext &Ctx = Preheader->getContext(); 28*e8d8bef9SDimitry Andric BasicBlock *Header = BasicBlock::Create( 29*e8d8bef9SDimitry Andric Preheader->getContext(), Name + ".header", Preheader->getParent(), Exit); 30*e8d8bef9SDimitry Andric BasicBlock *Body = BasicBlock::Create(Header->getContext(), Name + ".body", 31*e8d8bef9SDimitry Andric Header->getParent(), Exit); 32*e8d8bef9SDimitry Andric BasicBlock *Latch = BasicBlock::Create(Header->getContext(), Name + ".latch", 33*e8d8bef9SDimitry Andric Header->getParent(), Exit); 34*e8d8bef9SDimitry Andric 35*e8d8bef9SDimitry Andric Type *I32Ty = Type::getInt64Ty(Ctx); 36*e8d8bef9SDimitry Andric BranchInst::Create(Body, Header); 37*e8d8bef9SDimitry Andric BranchInst::Create(Latch, Body); 38*e8d8bef9SDimitry Andric PHINode *IV = 39*e8d8bef9SDimitry Andric PHINode::Create(I32Ty, 2, Name + ".iv", Header->getTerminator()); 40*e8d8bef9SDimitry Andric IV->addIncoming(ConstantInt::get(I32Ty, 0), Preheader); 41*e8d8bef9SDimitry Andric 42*e8d8bef9SDimitry Andric B.SetInsertPoint(Latch); 43*e8d8bef9SDimitry Andric Value *Inc = B.CreateAdd(IV, Step, Name + ".step"); 44*e8d8bef9SDimitry Andric Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond"); 45*e8d8bef9SDimitry Andric BranchInst::Create(Header, Exit, Cond, Latch); 46*e8d8bef9SDimitry Andric IV->addIncoming(Inc, Latch); 47*e8d8bef9SDimitry Andric 48*e8d8bef9SDimitry Andric BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator()); 49*e8d8bef9SDimitry Andric BasicBlock *Tmp = PreheaderBr->getSuccessor(0); 50*e8d8bef9SDimitry Andric PreheaderBr->setSuccessor(0, Header); 51*e8d8bef9SDimitry Andric DTU.applyUpdatesPermissive({ 52*e8d8bef9SDimitry Andric {DominatorTree::Delete, Preheader, Tmp}, 53*e8d8bef9SDimitry Andric {DominatorTree::Insert, Header, Body}, 54*e8d8bef9SDimitry Andric {DominatorTree::Insert, Body, Latch}, 55*e8d8bef9SDimitry Andric {DominatorTree::Insert, Latch, Header}, 56*e8d8bef9SDimitry Andric {DominatorTree::Insert, Latch, Exit}, 57*e8d8bef9SDimitry Andric {DominatorTree::Insert, Preheader, Header}, 58*e8d8bef9SDimitry Andric }); 59*e8d8bef9SDimitry Andric 60*e8d8bef9SDimitry Andric L->addBasicBlockToLoop(Header, LI); 61*e8d8bef9SDimitry Andric L->addBasicBlockToLoop(Body, LI); 62*e8d8bef9SDimitry Andric L->addBasicBlockToLoop(Latch, LI); 63*e8d8bef9SDimitry Andric return Body; 64*e8d8bef9SDimitry Andric } 65*e8d8bef9SDimitry Andric 66*e8d8bef9SDimitry Andric // Creates the following loop nest skeleton: 67*e8d8bef9SDimitry Andric // for C = 0; C < NumColumns; C += TileSize 68*e8d8bef9SDimitry Andric // for R = 0; R < NumRows; R += TileSize 69*e8d8bef9SDimitry Andric // for K = 0; K < Inner ; K += TileSize 70*e8d8bef9SDimitry Andric BasicBlock *TileInfo::CreateTiledLoops(BasicBlock *Start, BasicBlock *End, 71*e8d8bef9SDimitry Andric IRBuilderBase &B, DomTreeUpdater &DTU, 72*e8d8bef9SDimitry Andric LoopInfo &LI) { 73*e8d8bef9SDimitry Andric Loop *ColLoop = LI.AllocateLoop(); 74*e8d8bef9SDimitry Andric Loop *RowLoop = LI.AllocateLoop(); 75*e8d8bef9SDimitry Andric Loop *InnerLoop = LI.AllocateLoop(); 76*e8d8bef9SDimitry Andric RowLoop->addChildLoop(InnerLoop); 77*e8d8bef9SDimitry Andric ColLoop->addChildLoop(RowLoop); 78*e8d8bef9SDimitry Andric if (Loop *ParentL = LI.getLoopFor(Start)) 79*e8d8bef9SDimitry Andric ParentL->addChildLoop(ColLoop); 80*e8d8bef9SDimitry Andric else 81*e8d8bef9SDimitry Andric LI.addTopLevelLoop(ColLoop); 82*e8d8bef9SDimitry Andric 83*e8d8bef9SDimitry Andric BasicBlock *ColBody = 84*e8d8bef9SDimitry Andric CreateLoop(Start, End, B.getInt64(NumColumns), B.getInt64(TileSize), 85*e8d8bef9SDimitry Andric "cols", B, DTU, ColLoop, LI); 86*e8d8bef9SDimitry Andric BasicBlock *ColLatch = ColBody->getSingleSuccessor(); 87*e8d8bef9SDimitry Andric BasicBlock *RowBody = 88*e8d8bef9SDimitry Andric CreateLoop(ColBody, ColLatch, B.getInt64(NumRows), B.getInt64(TileSize), 89*e8d8bef9SDimitry Andric "rows", B, DTU, RowLoop, LI); 90*e8d8bef9SDimitry Andric RowLoopLatch = RowBody->getSingleSuccessor(); 91*e8d8bef9SDimitry Andric 92*e8d8bef9SDimitry Andric BasicBlock *InnerBody = 93*e8d8bef9SDimitry Andric CreateLoop(RowBody, RowLoopLatch, B.getInt64(NumInner), 94*e8d8bef9SDimitry Andric B.getInt64(TileSize), "inner", B, DTU, InnerLoop, LI); 95*e8d8bef9SDimitry Andric InnerLoopLatch = InnerBody->getSingleSuccessor(); 96*e8d8bef9SDimitry Andric ColumnLoopHeader = ColBody->getSinglePredecessor(); 97*e8d8bef9SDimitry Andric RowLoopHeader = RowBody->getSinglePredecessor(); 98*e8d8bef9SDimitry Andric InnerLoopHeader = InnerBody->getSinglePredecessor(); 99*e8d8bef9SDimitry Andric CurrentRow = &*RowLoopHeader->begin(); 100*e8d8bef9SDimitry Andric CurrentCol = &*ColumnLoopHeader->begin(); 101*e8d8bef9SDimitry Andric CurrentK = &*InnerLoopHeader->begin(); 102*e8d8bef9SDimitry Andric 103*e8d8bef9SDimitry Andric return InnerBody; 104*e8d8bef9SDimitry Andric } 105