1e8d8bef9SDimitry Andric //===- MatrixUtils.cpp - Utilities to lower matrix intrinsics ---*- C++ -*-===// 2e8d8bef9SDimitry Andric // 3e8d8bef9SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4e8d8bef9SDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 5e8d8bef9SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6e8d8bef9SDimitry Andric // 7e8d8bef9SDimitry Andric //===----------------------------------------------------------------------===// 8e8d8bef9SDimitry Andric // 9e8d8bef9SDimitry Andric // Utilities for generating tiled loops for matrix operations. 10e8d8bef9SDimitry Andric // 11e8d8bef9SDimitry Andric //===----------------------------------------------------------------------===// 12e8d8bef9SDimitry Andric 13e8d8bef9SDimitry Andric #include "llvm/Transforms/Utils/MatrixUtils.h" 14e8d8bef9SDimitry Andric #include "llvm/Analysis/DomTreeUpdater.h" 15e8d8bef9SDimitry Andric #include "llvm/Analysis/LoopInfo.h" 16e8d8bef9SDimitry Andric #include "llvm/IR/BasicBlock.h" 17e8d8bef9SDimitry Andric #include "llvm/IR/Dominators.h" 18e8d8bef9SDimitry Andric #include "llvm/IR/IRBuilder.h" 19e8d8bef9SDimitry Andric #include "llvm/IR/Type.h" 20e8d8bef9SDimitry Andric 21e8d8bef9SDimitry Andric using namespace llvm; 22e8d8bef9SDimitry Andric 23e8d8bef9SDimitry Andric BasicBlock *TileInfo::CreateLoop(BasicBlock *Preheader, BasicBlock *Exit, 24e8d8bef9SDimitry Andric Value *Bound, Value *Step, StringRef Name, 25e8d8bef9SDimitry Andric IRBuilderBase &B, DomTreeUpdater &DTU, Loop *L, 26e8d8bef9SDimitry Andric LoopInfo &LI) { 27e8d8bef9SDimitry Andric LLVMContext &Ctx = Preheader->getContext(); 28e8d8bef9SDimitry Andric BasicBlock *Header = BasicBlock::Create( 29e8d8bef9SDimitry Andric Preheader->getContext(), Name + ".header", Preheader->getParent(), Exit); 30e8d8bef9SDimitry Andric BasicBlock *Body = BasicBlock::Create(Header->getContext(), Name + ".body", 31e8d8bef9SDimitry Andric Header->getParent(), Exit); 32e8d8bef9SDimitry Andric BasicBlock *Latch = BasicBlock::Create(Header->getContext(), Name + ".latch", 33e8d8bef9SDimitry Andric Header->getParent(), Exit); 34e8d8bef9SDimitry Andric 35e8d8bef9SDimitry Andric Type *I32Ty = Type::getInt64Ty(Ctx); 36e8d8bef9SDimitry Andric BranchInst::Create(Body, Header); 37e8d8bef9SDimitry Andric BranchInst::Create(Latch, Body); 38e8d8bef9SDimitry Andric PHINode *IV = 39*0fca6ea1SDimitry Andric PHINode::Create(I32Ty, 2, Name + ".iv", Header->getTerminator()->getIterator()); 40e8d8bef9SDimitry Andric IV->addIncoming(ConstantInt::get(I32Ty, 0), Preheader); 41e8d8bef9SDimitry Andric 42e8d8bef9SDimitry Andric B.SetInsertPoint(Latch); 43e8d8bef9SDimitry Andric Value *Inc = B.CreateAdd(IV, Step, Name + ".step"); 44e8d8bef9SDimitry Andric Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond"); 45e8d8bef9SDimitry Andric BranchInst::Create(Header, Exit, Cond, Latch); 46e8d8bef9SDimitry Andric IV->addIncoming(Inc, Latch); 47e8d8bef9SDimitry Andric 48e8d8bef9SDimitry Andric BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator()); 49e8d8bef9SDimitry Andric BasicBlock *Tmp = PreheaderBr->getSuccessor(0); 50e8d8bef9SDimitry Andric PreheaderBr->setSuccessor(0, Header); 51e8d8bef9SDimitry Andric DTU.applyUpdatesPermissive({ 52e8d8bef9SDimitry Andric {DominatorTree::Delete, Preheader, Tmp}, 53e8d8bef9SDimitry Andric {DominatorTree::Insert, Header, Body}, 54e8d8bef9SDimitry Andric {DominatorTree::Insert, Body, Latch}, 55e8d8bef9SDimitry Andric {DominatorTree::Insert, Latch, Header}, 56e8d8bef9SDimitry Andric {DominatorTree::Insert, Latch, Exit}, 57e8d8bef9SDimitry Andric {DominatorTree::Insert, Preheader, Header}, 58e8d8bef9SDimitry Andric }); 59e8d8bef9SDimitry Andric 60e8d8bef9SDimitry Andric L->addBasicBlockToLoop(Header, LI); 61e8d8bef9SDimitry Andric L->addBasicBlockToLoop(Body, LI); 62e8d8bef9SDimitry Andric L->addBasicBlockToLoop(Latch, LI); 63e8d8bef9SDimitry Andric return Body; 64e8d8bef9SDimitry Andric } 65e8d8bef9SDimitry Andric 66e8d8bef9SDimitry Andric // Creates the following loop nest skeleton: 67e8d8bef9SDimitry Andric // for C = 0; C < NumColumns; C += TileSize 68e8d8bef9SDimitry Andric // for R = 0; R < NumRows; R += TileSize 69e8d8bef9SDimitry Andric // for K = 0; K < Inner ; K += TileSize 70e8d8bef9SDimitry Andric BasicBlock *TileInfo::CreateTiledLoops(BasicBlock *Start, BasicBlock *End, 71e8d8bef9SDimitry Andric IRBuilderBase &B, DomTreeUpdater &DTU, 72e8d8bef9SDimitry Andric LoopInfo &LI) { 73972a253aSDimitry Andric Loop *ColumnLoopInfo = LI.AllocateLoop(); 74972a253aSDimitry Andric Loop *RowLoopInfo = LI.AllocateLoop(); 75972a253aSDimitry Andric Loop *KLoopInfo = LI.AllocateLoop(); 76972a253aSDimitry Andric RowLoopInfo->addChildLoop(KLoopInfo); 77972a253aSDimitry Andric ColumnLoopInfo->addChildLoop(RowLoopInfo); 78e8d8bef9SDimitry Andric if (Loop *ParentL = LI.getLoopFor(Start)) 79972a253aSDimitry Andric ParentL->addChildLoop(ColumnLoopInfo); 80e8d8bef9SDimitry Andric else 81972a253aSDimitry Andric LI.addTopLevelLoop(ColumnLoopInfo); 82e8d8bef9SDimitry Andric 83e8d8bef9SDimitry Andric BasicBlock *ColBody = 84e8d8bef9SDimitry Andric CreateLoop(Start, End, B.getInt64(NumColumns), B.getInt64(TileSize), 85972a253aSDimitry Andric "cols", B, DTU, ColumnLoopInfo, LI); 86972a253aSDimitry Andric ColumnLoop.Latch = ColBody->getSingleSuccessor(); 87e8d8bef9SDimitry Andric BasicBlock *RowBody = 88972a253aSDimitry Andric CreateLoop(ColBody, ColumnLoop.Latch, B.getInt64(NumRows), 89972a253aSDimitry Andric B.getInt64(TileSize), "rows", B, DTU, RowLoopInfo, LI); 90972a253aSDimitry Andric RowLoop.Latch = RowBody->getSingleSuccessor(); 91e8d8bef9SDimitry Andric 92e8d8bef9SDimitry Andric BasicBlock *InnerBody = 93972a253aSDimitry Andric CreateLoop(RowBody, RowLoop.Latch, B.getInt64(NumInner), 94972a253aSDimitry Andric B.getInt64(TileSize), "inner", B, DTU, KLoopInfo, LI); 95972a253aSDimitry Andric KLoop.Latch = InnerBody->getSingleSuccessor(); 96972a253aSDimitry Andric ColumnLoop.Header = ColBody->getSinglePredecessor(); 97972a253aSDimitry Andric RowLoop.Header = RowBody->getSinglePredecessor(); 98972a253aSDimitry Andric KLoop.Header = InnerBody->getSinglePredecessor(); 99972a253aSDimitry Andric RowLoop.Index = &*RowLoop.Header->begin(); 100972a253aSDimitry Andric ColumnLoop.Index = &*ColumnLoop.Header->begin(); 101972a253aSDimitry Andric KLoop.Index = &*KLoop.Header->begin(); 102e8d8bef9SDimitry Andric 103e8d8bef9SDimitry Andric return InnerBody; 104e8d8bef9SDimitry Andric } 105