//===- TileAllocation.cpp - Allocate SME ZA tiles -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transform allocates SME tiles at the 'func.func' op level for ArmSME
// operations. It roughly implements a linear scan register allocator, similar
// to the one outlined in [1], but with simplifications and assumptions made for
// our use case. Note that this is a greedy allocator (so it may not always find
// the optimal allocation of tiles).
//
// The allocator operates at the CF dialect level. It is the responsibility of
// users to ensure the IR has been lowered to CF before invoking the tile
// allocator.
//
// The 128-bit tiles overlap with other element tiles as follows (see section
// B2.3.2 of SME spec [2]):
//
//   Tile    Overlaps
//   ---------------------------------------------------------------------------
//   ZA0.B   ZA0.Q, ZA1.Q, ZA2.Q, ZA3.Q, ZA4.Q, ZA5.Q, ZA6.Q, ZA7.Q, ZA8.Q,
//           ZA9.Q, ZA10.Q, ZA11.Q, ZA12.Q, ZA13.Q, ZA14.Q, ZA15.Q
//   ZA0.H   ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
//   ZA1.H   ZA1.Q, ZA3.Q, ZA5.Q, ZA7.Q, ZA9.Q, ZA11.Q, ZA13.Q, ZA15.Q
//   ZA0.S   ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q
//   ZA1.S   ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
//   ZA2.S   ZA2.Q, ZA6.Q, ZA10.Q, ZA14.Q
//   ZA3.S   ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q
//   ZA0.D   ZA0.Q, ZA8.Q
//   ZA1.D   ZA1.Q, ZA9.Q
//   ZA2.D   ZA2.Q, ZA10.Q
//   ZA3.D   ZA3.Q, ZA11.Q
//   ZA4.D   ZA4.Q, ZA12.Q
//   ZA5.D   ZA5.Q, ZA13.Q
//   ZA6.D   ZA6.Q, ZA14.Q
//   ZA7.D   ZA7.Q, ZA15.Q
//
// [1] "Linear Scan Register Allocation in the Context of SSA Form and Register
//      Constraints" (Hanspeter Mössenböck and Michael Pfeiffer)
//     https://link.springer.com/content/pdf/10.1007/3-540-45937-5_17.pdf
// [2] https://developer.arm.com/documentation/ddi0616/aa
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/Liveness.h"
49b00e0c16SChristian Ulmann #include "mlir/Analysis/TopologicalSortUtils.h" 50fb54fec7SCullen Rhodes #include "mlir/Dialect/ArmSME/IR/ArmSME.h" 51fb54fec7SCullen Rhodes #include "mlir/Dialect/ArmSME/Transforms/Passes.h" 52041baf2fSBenjamin Maxwell #include "mlir/Dialect/ArmSME/Transforms/Transforms.h" 53eaff02f2SBenjamin Maxwell #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 54fb54fec7SCullen Rhodes #include "mlir/Dialect/Func/IR/FuncOps.h" 55041baf2fSBenjamin Maxwell #include "mlir/Transforms/RegionUtils.h" 56041baf2fSBenjamin Maxwell #include "llvm/ADT/IntervalMap.h" 57eaff02f2SBenjamin Maxwell #include "llvm/ADT/TypeSwitch.h" 58041baf2fSBenjamin Maxwell #include <algorithm> 59fb54fec7SCullen Rhodes 60041baf2fSBenjamin Maxwell namespace mlir::arm_sme { 61041baf2fSBenjamin Maxwell #define GEN_PASS_DEF_TESTTILEALLOCATION 62fb54fec7SCullen Rhodes #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc" 63041baf2fSBenjamin Maxwell } // namespace mlir::arm_sme 64fb54fec7SCullen Rhodes 65fb54fec7SCullen Rhodes using namespace mlir; 66fb54fec7SCullen Rhodes using namespace mlir::arm_sme; 67fb54fec7SCullen Rhodes 68fb54fec7SCullen Rhodes namespace { 69fb54fec7SCullen Rhodes 70fb54fec7SCullen Rhodes enum class TileMask : unsigned { 71fb54fec7SCullen Rhodes // clang-format off 72fb54fec7SCullen Rhodes kZA0B = 0xffff, // 1111 1111 1111 1111 73fb54fec7SCullen Rhodes 74fb54fec7SCullen Rhodes kZA0H = 0xaaaa, // 1010 1010 1010 1010 75fb54fec7SCullen Rhodes kZA1H = 0x5555, // 0101 0101 0101 0101 76fb54fec7SCullen Rhodes 77fb54fec7SCullen Rhodes kZA0S = 0x8888, // 1000 1000 1000 1000 78fb54fec7SCullen Rhodes kZA1S = 0x4444, // 0100 0100 0100 0100 79fb54fec7SCullen Rhodes kZA2S = 0x2222, // 0010 0010 0010 0010 80fb54fec7SCullen Rhodes kZA3S = 0x1111, // 0001 0001 0001 0001 81fb54fec7SCullen Rhodes 82fb54fec7SCullen Rhodes kZA0D = 0x8080, // 1000 0000 1000 0000 83fb54fec7SCullen Rhodes kZA1D = 0x4040, // 0100 0000 0100 0000 84fb54fec7SCullen Rhodes kZA2D = 0x2020, // 0010 
0000 0010 0000 85fb54fec7SCullen Rhodes kZA3D = 0x1010, // 0001 0000 0001 0000 86fb54fec7SCullen Rhodes kZA4D = 0x808, // 0000 1000 0000 1000 87fb54fec7SCullen Rhodes kZA5D = 0x404, // 0000 0100 0000 0100 88fb54fec7SCullen Rhodes kZA6D = 0x202, // 0000 0010 0000 0010 89fb54fec7SCullen Rhodes kZA7D = 0x101, // 0000 0001 0000 0001 90fb54fec7SCullen Rhodes 91fb54fec7SCullen Rhodes kZA0Q = 0x8000, // 1000 0000 0000 0000 92fb54fec7SCullen Rhodes kZA1Q = 0x4000, // 0100 0000 0000 0000 93fb54fec7SCullen Rhodes kZA2Q = 0x2000, // 0010 0000 0000 0000 94fb54fec7SCullen Rhodes kZA3Q = 0x1000, // 0001 0000 0000 0000 95fb54fec7SCullen Rhodes kZA4Q = 0x800, // 0000 1000 0000 0000 96fb54fec7SCullen Rhodes kZA5Q = 0x400, // 0000 0100 0000 0000 97fb54fec7SCullen Rhodes kZA6Q = 0x200, // 0000 0010 0000 0000 98fb54fec7SCullen Rhodes kZA7Q = 0x100, // 0000 0001 0000 0000 99fb54fec7SCullen Rhodes kZA8Q = 0x80, // 0000 0000 1000 0000 100fb54fec7SCullen Rhodes kZA9Q = 0x40, // 0000 0000 0100 0000 101fb54fec7SCullen Rhodes kZA10Q = 0x20, // 0000 0000 0010 0000 102fb54fec7SCullen Rhodes kZA11Q = 0x10, // 0000 0000 0001 0000 103fb54fec7SCullen Rhodes kZA12Q = 0x8, // 0000 0000 0000 1000 104fb54fec7SCullen Rhodes kZA13Q = 0x4, // 0000 0000 0000 0100 105fb54fec7SCullen Rhodes kZA14Q = 0x2, // 0000 0000 0000 0010 106fb54fec7SCullen Rhodes kZA15Q = 0x1, // 0000 0000 0000 0001 107fb54fec7SCullen Rhodes 108fb54fec7SCullen Rhodes kNone = 0x0, // 0000 0000 0000 0000 109fb54fec7SCullen Rhodes // clang-format on 110fb54fec7SCullen Rhodes 111fb54fec7SCullen Rhodes LLVM_MARK_AS_BITMASK_ENUM(kZA0B) 112fb54fec7SCullen Rhodes }; 113fb54fec7SCullen Rhodes 114fb54fec7SCullen Rhodes /// Returns the set of masks relevant for the given type. 
115eaff02f2SBenjamin Maxwell static ArrayRef<TileMask> getMasks(ArmSMETileType type) { 116eaff02f2SBenjamin Maxwell static constexpr std::array ZA_B_MASKS = {TileMask::kZA0B}; 117eaff02f2SBenjamin Maxwell static constexpr std::array ZA_H_MASKS = {TileMask::kZA0H, TileMask::kZA1H}; 118eaff02f2SBenjamin Maxwell static constexpr std::array ZA_S_MASKS = {TileMask::kZA0S, TileMask::kZA1S, 119eaff02f2SBenjamin Maxwell TileMask::kZA2S, TileMask::kZA3S}; 120eaff02f2SBenjamin Maxwell static constexpr std::array ZA_D_MASKS = { 121fb54fec7SCullen Rhodes TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D, 122fb54fec7SCullen Rhodes TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D}; 123eaff02f2SBenjamin Maxwell static constexpr std::array ZA_Q_MASKS = { 124fb54fec7SCullen Rhodes TileMask::kZA0Q, TileMask::kZA1Q, TileMask::kZA2Q, TileMask::kZA3Q, 125fb54fec7SCullen Rhodes TileMask::kZA4Q, TileMask::kZA5Q, TileMask::kZA6Q, TileMask::kZA7Q, 126fb54fec7SCullen Rhodes TileMask::kZA8Q, TileMask::kZA9Q, TileMask::kZA10Q, TileMask::kZA11Q, 127fb54fec7SCullen Rhodes TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q}; 128eaff02f2SBenjamin Maxwell switch (type) { 129eaff02f2SBenjamin Maxwell case ArmSMETileType::ZAB: 130fb54fec7SCullen Rhodes return ZA_B_MASKS; 131eaff02f2SBenjamin Maxwell case ArmSMETileType::ZAH: 132fb54fec7SCullen Rhodes return ZA_H_MASKS; 133eaff02f2SBenjamin Maxwell case ArmSMETileType::ZAS: 134fb54fec7SCullen Rhodes return ZA_S_MASKS; 135eaff02f2SBenjamin Maxwell case ArmSMETileType::ZAD: 136fb54fec7SCullen Rhodes return ZA_D_MASKS; 137eaff02f2SBenjamin Maxwell case ArmSMETileType::ZAQ: 138fb54fec7SCullen Rhodes return ZA_Q_MASKS; 139fb54fec7SCullen Rhodes } 140*d5746d73SFrank Schlimbach llvm_unreachable("unknown type in getMasks"); 141fb54fec7SCullen Rhodes } 142fb54fec7SCullen Rhodes 143041baf2fSBenjamin Maxwell class TileAllocator { 144041baf2fSBenjamin Maxwell public: 145041baf2fSBenjamin Maxwell /// 
Allocates and returns a tile ID. Fails if there are no tiles left. 146041baf2fSBenjamin Maxwell FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) { 147eaff02f2SBenjamin Maxwell auto masks = getMasks(tileType); 148eaff02f2SBenjamin Maxwell for (auto [tileId, tileMask] : llvm::enumerate(masks)) { 149fb54fec7SCullen Rhodes if ((tilesInUse & tileMask) == TileMask::kNone) { 150fb54fec7SCullen Rhodes tilesInUse |= tileMask; 151eaff02f2SBenjamin Maxwell return tileId; 152fb54fec7SCullen Rhodes } 153fb54fec7SCullen Rhodes } 154eaff02f2SBenjamin Maxwell return failure(); 155fb54fec7SCullen Rhodes } 156fb54fec7SCullen Rhodes 157eed72d43SBenjamin Maxwell /// Acquires a specific tile ID. Asserts the tile is initially free. 158eed72d43SBenjamin Maxwell void acquireTileId(ArmSMETileType tileType, unsigned tileId) { 159eed72d43SBenjamin Maxwell TileMask tileMask = getMasks(tileType)[tileId]; 160eed72d43SBenjamin Maxwell assert((tilesInUse & tileMask) == TileMask::kNone && 161eed72d43SBenjamin Maxwell "cannot acquire allocated tile!"); 162eed72d43SBenjamin Maxwell tilesInUse |= tileMask; 163eed72d43SBenjamin Maxwell } 164eed72d43SBenjamin Maxwell 165041baf2fSBenjamin Maxwell /// Releases a previously allocated tile ID. 166041baf2fSBenjamin Maxwell void releaseTileId(ArmSMETileType tileType, unsigned tileId) { 167041baf2fSBenjamin Maxwell TileMask tileMask = getMasks(tileType)[tileId]; 168eed72d43SBenjamin Maxwell assert((tilesInUse & tileMask) == tileMask && 169041baf2fSBenjamin Maxwell "cannot release unallocated tile!"); 170041baf2fSBenjamin Maxwell tilesInUse ^= tileMask; 171eaff02f2SBenjamin Maxwell } 172041baf2fSBenjamin Maxwell 173041baf2fSBenjamin Maxwell /// Allocates an in-memory tile ID. 174041baf2fSBenjamin Maxwell unsigned allocateInMemoryTileId() { 175041baf2fSBenjamin Maxwell // Note: We never release in-memory tile IDs. 
We could, which may allow 176041baf2fSBenjamin Maxwell // reusing an allocation, but as we _never_ want to spill an SME tile this 177041baf2fSBenjamin Maxwell // is not optimized. 178041baf2fSBenjamin Maxwell return nextInMemoryTileId++; 179041baf2fSBenjamin Maxwell } 180041baf2fSBenjamin Maxwell 181041baf2fSBenjamin Maxwell private: 182041baf2fSBenjamin Maxwell TileMask tilesInUse = TileMask::kNone; 183041baf2fSBenjamin Maxwell unsigned nextInMemoryTileId = kInMemoryTileIdBase; 184eaff02f2SBenjamin Maxwell }; 185041baf2fSBenjamin Maxwell 186041baf2fSBenjamin Maxwell /// Add new intermediate blocks for the true and false destinations of 187041baf2fSBenjamin Maxwell /// `cf.cond_br`s that contain tile operands. This prevents spurious liveness 188041baf2fSBenjamin Maxwell /// overlaps due to copies at branches. 189041baf2fSBenjamin Maxwell /// 190041baf2fSBenjamin Maxwell /// BEFORE: 191041baf2fSBenjamin Maxwell /// ```mlir 192041baf2fSBenjamin Maxwell /// cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2 193041baf2fSBenjamin Maxwell /// ``` 194041baf2fSBenjamin Maxwell /// 195041baf2fSBenjamin Maxwell /// AFTER: 196041baf2fSBenjamin Maxwell /// ```mlir 197041baf2fSBenjamin Maxwell /// cf.cond_br %cond, ^bb1_copy, ^bb2_copy 198041baf2fSBenjamin Maxwell /// ^bb1_copy: 199041baf2fSBenjamin Maxwell /// cf.br ^bb1(%tile: vector<[4]x[4]xf32>) 200041baf2fSBenjamin Maxwell /// ^bb2_copy: 201041baf2fSBenjamin Maxwell /// cf.br ^bb2 202041baf2fSBenjamin Maxwell /// ``` 203041baf2fSBenjamin Maxwell void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) { 204041baf2fSBenjamin Maxwell SmallVector<cf::CondBranchOp> worklist; 205041baf2fSBenjamin Maxwell function.walk([&](cf::CondBranchOp condBranch) { 206041baf2fSBenjamin Maxwell if (llvm::any_of(condBranch->getOperands(), [&](Value value) { 207041baf2fSBenjamin Maxwell return isValidSMETileVectorType(value.getType()); 208041baf2fSBenjamin Maxwell })) { 209041baf2fSBenjamin Maxwell 
worklist.push_back(condBranch); 210041baf2fSBenjamin Maxwell } 211041baf2fSBenjamin Maxwell }); 212041baf2fSBenjamin Maxwell 213041baf2fSBenjamin Maxwell auto insertJump = [&](Location loc, Block *source, Block *dest, auto args) { 214041baf2fSBenjamin Maxwell rewriter.setInsertionPointToEnd(source); 215041baf2fSBenjamin Maxwell rewriter.create<cf::BranchOp>(loc, dest, args); 216041baf2fSBenjamin Maxwell }; 217041baf2fSBenjamin Maxwell 218041baf2fSBenjamin Maxwell for (auto condBranch : worklist) { 219041baf2fSBenjamin Maxwell auto loc = condBranch.getLoc(); 220041baf2fSBenjamin Maxwell Block *block = condBranch->getBlock(); 221041baf2fSBenjamin Maxwell auto newTrueBranch = rewriter.splitBlock(block, block->end()); 222041baf2fSBenjamin Maxwell auto newFalseBranch = rewriter.splitBlock(block, block->end()); 223041baf2fSBenjamin Maxwell insertJump(loc, newTrueBranch, condBranch.getTrueDest(), 224041baf2fSBenjamin Maxwell condBranch.getTrueDestOperands()); 225041baf2fSBenjamin Maxwell insertJump(loc, newFalseBranch, condBranch.getFalseDest(), 226041baf2fSBenjamin Maxwell condBranch.getFalseDestOperands()); 227041baf2fSBenjamin Maxwell rewriter.modifyOpInPlace(condBranch, [&] { 228041baf2fSBenjamin Maxwell condBranch.getFalseDestOperandsMutable().clear(); 229041baf2fSBenjamin Maxwell condBranch.getTrueDestOperandsMutable().clear(); 230041baf2fSBenjamin Maxwell condBranch.setSuccessor(newTrueBranch, 0); 231041baf2fSBenjamin Maxwell condBranch.setSuccessor(newFalseBranch, 1); 232041baf2fSBenjamin Maxwell }); 233041baf2fSBenjamin Maxwell } 234041baf2fSBenjamin Maxwell } 235041baf2fSBenjamin Maxwell 236041baf2fSBenjamin Maxwell /// Inserts tile copies at `cf.br` operations. 
237041baf2fSBenjamin Maxwell /// 238041baf2fSBenjamin Maxwell /// BEFORE: 239041baf2fSBenjamin Maxwell /// ```mlir 240041baf2fSBenjamin Maxwell /// cf.br ^bb1(%tile: vector<[4]x[4]xf32>) 241041baf2fSBenjamin Maxwell /// ``` 242041baf2fSBenjamin Maxwell /// 243041baf2fSBenjamin Maxwell /// AFTER: 244041baf2fSBenjamin Maxwell /// ```mlir 245041baf2fSBenjamin Maxwell /// %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32> 246041baf2fSBenjamin Maxwell /// cf.br ^bb1(%copy: vector<[4]x[4]xf32>) 247041baf2fSBenjamin Maxwell /// ``` 248041baf2fSBenjamin Maxwell void insertCopiesAtBranches(IRRewriter &rewriter, 249041baf2fSBenjamin Maxwell FunctionOpInterface function) { 250041baf2fSBenjamin Maxwell for (Block &block : function.getBlocks()) { 251041baf2fSBenjamin Maxwell Operation *terminator = block.getTerminator(); 252041baf2fSBenjamin Maxwell if (!isa<cf::BranchOp>(terminator)) 253eaff02f2SBenjamin Maxwell continue; 254041baf2fSBenjamin Maxwell rewriter.setInsertionPoint(terminator); 255041baf2fSBenjamin Maxwell for (OpOperand &operand : terminator->getOpOperands()) { 256041baf2fSBenjamin Maxwell if (isValidSMETileVectorType(operand.get().getType())) { 257041baf2fSBenjamin Maxwell auto copy = 258041baf2fSBenjamin Maxwell rewriter.create<CopyTileOp>(terminator->getLoc(), operand.get()); 259041baf2fSBenjamin Maxwell rewriter.modifyOpInPlace(terminator, [&] { operand.assign(copy); }); 260041baf2fSBenjamin Maxwell } 261041baf2fSBenjamin Maxwell } 262041baf2fSBenjamin Maxwell } 263041baf2fSBenjamin Maxwell } 264041baf2fSBenjamin Maxwell 265041baf2fSBenjamin Maxwell /// Prepares the IR for tile allocation. It does this by first 'splitting' 266041baf2fSBenjamin Maxwell /// conditional branches (see `splitCondBranches`), then inserting tile copies 267041baf2fSBenjamin Maxwell /// at branch operations. 
The conditional branches are split to prevent the 268041baf2fSBenjamin Maxwell /// copies needed for them overlapping between the true and false paths of the 269041baf2fSBenjamin Maxwell /// branch (see `tile-allocation-copies.mlir` and 270041baf2fSBenjamin Maxwell /// `tile-allocation-liveness.mlir` for examples). The copies break up live 271041baf2fSBenjamin Maxwell /// ranges and ensure when moving out of SSA the semantics of the program are 272041baf2fSBenjamin Maxwell /// preserved. 273041baf2fSBenjamin Maxwell void preprocessForTileAllocation(IRRewriter &rewriter, 274041baf2fSBenjamin Maxwell FunctionOpInterface function) { 275041baf2fSBenjamin Maxwell splitCondBranches(rewriter, function); 276041baf2fSBenjamin Maxwell insertCopiesAtBranches(rewriter, function); 277041baf2fSBenjamin Maxwell } 278041baf2fSBenjamin Maxwell 279041baf2fSBenjamin Maxwell /// A live range for a (collection of) tile values. A live range is built up of 280041baf2fSBenjamin Maxwell /// non-overlapping intervals [start, end) which represent parts of the program 281041baf2fSBenjamin Maxwell /// where a value in the range needs to be live (i.e. in an SME virtual tile). 282041baf2fSBenjamin Maxwell /// Note that as the intervals are non-overlapping all values within a live 283041baf2fSBenjamin Maxwell /// range can be allocated to the same SME virtual tile. 284041baf2fSBenjamin Maxwell struct LiveRange { 285041baf2fSBenjamin Maxwell using RangeSet = llvm::IntervalMap<uint64_t, uint8_t, 16, 286041baf2fSBenjamin Maxwell llvm::IntervalMapHalfOpenInfo<unsigned>>; 287041baf2fSBenjamin Maxwell using Allocator = RangeSet::Allocator; 288041baf2fSBenjamin Maxwell // Dummy value for the IntervalMap. Only the keys matter (the intervals). 
289041baf2fSBenjamin Maxwell static constexpr uint8_t kValidLiveRange = 0xff; 290041baf2fSBenjamin Maxwell 291041baf2fSBenjamin Maxwell LiveRange(Allocator &allocator) 292041baf2fSBenjamin Maxwell : ranges(std::make_unique<RangeSet>(allocator)) {} 293041baf2fSBenjamin Maxwell 294041baf2fSBenjamin Maxwell /// Returns true if this range overlaps with `otherRange`. 295041baf2fSBenjamin Maxwell bool overlaps(LiveRange const &otherRange) const { 296041baf2fSBenjamin Maxwell return llvm::IntervalMapOverlaps<RangeSet, RangeSet>(*ranges, 297041baf2fSBenjamin Maxwell *otherRange.ranges) 298041baf2fSBenjamin Maxwell .valid(); 299041baf2fSBenjamin Maxwell } 300041baf2fSBenjamin Maxwell 301eed72d43SBenjamin Maxwell /// Returns true if this range is active at `point` in the program. 302eed72d43SBenjamin Maxwell bool overlaps(uint64_t point) const { 303eed72d43SBenjamin Maxwell return ranges->lookup(point) == kValidLiveRange; 304eed72d43SBenjamin Maxwell } 305eed72d43SBenjamin Maxwell 306041baf2fSBenjamin Maxwell /// Unions this live range with `otherRange`, aborts if the ranges overlap. 307041baf2fSBenjamin Maxwell void unionWith(LiveRange const &otherRange) { 308041baf2fSBenjamin Maxwell for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end(); 309041baf2fSBenjamin Maxwell ++it) 310041baf2fSBenjamin Maxwell ranges->insert(it.start(), it.stop(), kValidLiveRange); 311041baf2fSBenjamin Maxwell values.set_union(otherRange.values); 312041baf2fSBenjamin Maxwell } 313041baf2fSBenjamin Maxwell 314041baf2fSBenjamin Maxwell /// Inserts an interval [start, end) for `value` into this range. 
315041baf2fSBenjamin Maxwell void insert(Value value, unsigned start, unsigned end) { 316041baf2fSBenjamin Maxwell values.insert(value); 317041baf2fSBenjamin Maxwell if (start != end) 318041baf2fSBenjamin Maxwell ranges->insert(start, end, kValidLiveRange); 319041baf2fSBenjamin Maxwell } 320041baf2fSBenjamin Maxwell 321041baf2fSBenjamin Maxwell bool empty() const { return ranges->empty(); } 322041baf2fSBenjamin Maxwell unsigned start() const { return ranges->start(); } 323041baf2fSBenjamin Maxwell unsigned end() const { return ranges->stop(); } 324041baf2fSBenjamin Maxwell bool operator<(LiveRange const &other) const { 325041baf2fSBenjamin Maxwell return start() < other.start(); 326041baf2fSBenjamin Maxwell } 327041baf2fSBenjamin Maxwell 328041baf2fSBenjamin Maxwell ArmSMETileType getTileType() const { 329041baf2fSBenjamin Maxwell return *getSMETileType(cast<VectorType>(values[0].getType())); 330041baf2fSBenjamin Maxwell } 331041baf2fSBenjamin Maxwell 332041baf2fSBenjamin Maxwell /// The values contained in this live range. 333041baf2fSBenjamin Maxwell SetVector<Value> values; 334041baf2fSBenjamin Maxwell 335041baf2fSBenjamin Maxwell /// A set of (non-overlapping) intervals that mark where any value in `values` 336041baf2fSBenjamin Maxwell /// is live. 337041baf2fSBenjamin Maxwell std::unique_ptr<RangeSet> ranges; 338041baf2fSBenjamin Maxwell 339041baf2fSBenjamin Maxwell /// The tile ID (or none) assigned to this live range. 340041baf2fSBenjamin Maxwell std::optional<unsigned> tileId; 341041baf2fSBenjamin Maxwell }; 342041baf2fSBenjamin Maxwell 343041baf2fSBenjamin Maxwell /// Number operations within a function to allow computing live ranges. 344041baf2fSBenjamin Maxwell /// Operations are numbered consecutively wihin blocks, and the blocks are 345041baf2fSBenjamin Maxwell /// topologically sorted (using forward edges). This function is only correct if 346041baf2fSBenjamin Maxwell /// all ArmSME have been converted to CF (which is asserted). 
347041baf2fSBenjamin Maxwell DenseMap<Operation *, unsigned> 348041baf2fSBenjamin Maxwell generateOperationNumbering(FunctionOpInterface function) { 349041baf2fSBenjamin Maxwell unsigned index = 0; 350041baf2fSBenjamin Maxwell SetVector<Block *> blocks = 3519e39a0c7SChristian Ulmann getBlocksSortedByDominance(function.getFunctionBody()); 352041baf2fSBenjamin Maxwell DenseMap<Operation *, unsigned> operationToIndexMap; 353041baf2fSBenjamin Maxwell for (Block *block : blocks) { 354041baf2fSBenjamin Maxwell index++; // We want block args to have their own number. 355041baf2fSBenjamin Maxwell for (Operation &op : block->getOperations()) { 356041baf2fSBenjamin Maxwell #ifndef NDEBUG 357041baf2fSBenjamin Maxwell op.walk([&](ArmSMETileOpInterface nestedOp) { 358041baf2fSBenjamin Maxwell assert(&op == nestedOp.getOperation() && 359041baf2fSBenjamin Maxwell "ArmSME tile allocation does not support nested regions"); 360041baf2fSBenjamin Maxwell }); 361041baf2fSBenjamin Maxwell #endif 362041baf2fSBenjamin Maxwell operationToIndexMap.try_emplace(&op, index++); 363041baf2fSBenjamin Maxwell } 364041baf2fSBenjamin Maxwell } 365041baf2fSBenjamin Maxwell return operationToIndexMap; 366041baf2fSBenjamin Maxwell } 367041baf2fSBenjamin Maxwell 368041baf2fSBenjamin Maxwell /// Gather live ranges for SME tiles from the MLIR liveness analysis. 369041baf2fSBenjamin Maxwell DenseMap<Value, LiveRange> 370041baf2fSBenjamin Maxwell gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap, 371041baf2fSBenjamin Maxwell LiveRange::Allocator &liveRangeAllocator, 372041baf2fSBenjamin Maxwell Liveness &liveness, FunctionOpInterface function) { 373041baf2fSBenjamin Maxwell assert(!operationToIndexMap.empty() && "expected operation numbering"); 374041baf2fSBenjamin Maxwell DenseMap<Value, LiveRange> liveRanges; 375041baf2fSBenjamin Maxwell /// Defines or updates a live range for an SME tile value. 
Live-ins may update 376041baf2fSBenjamin Maxwell /// an existing live range (rather than define a new one). Note: If 377041baf2fSBenjamin Maxwell /// `liveAtBlockEntry` is true then `firstUseOrDef` is the first operation in 378041baf2fSBenjamin Maxwell /// the block. 379041baf2fSBenjamin Maxwell auto defineOrUpdateValueLiveRange = [&](Value value, Operation *firstUseOrDef, 380041baf2fSBenjamin Maxwell LivenessBlockInfo const &livenessInfo, 381041baf2fSBenjamin Maxwell bool liveAtBlockEntry = false) { 382041baf2fSBenjamin Maxwell if (!isValidSMETileVectorType(value.getType())) 383041baf2fSBenjamin Maxwell return; 384041baf2fSBenjamin Maxwell // Find or create a live range for `value`. 385041baf2fSBenjamin Maxwell auto [it, _] = liveRanges.try_emplace(value, liveRangeAllocator); 386041baf2fSBenjamin Maxwell LiveRange &valueLiveRange = it->second; 387041baf2fSBenjamin Maxwell auto lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef); 388041baf2fSBenjamin Maxwell // Add the interval [firstUseOrDef, lastUseInBlock) to the live range. 389041baf2fSBenjamin Maxwell unsigned startOpIdx = 390041baf2fSBenjamin Maxwell operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? 
-1 : 0); 391041baf2fSBenjamin Maxwell unsigned endOpIdx = operationToIndexMap.at(lastUseInBlock); 392041baf2fSBenjamin Maxwell valueLiveRange.insert(value, startOpIdx, endOpIdx); 393041baf2fSBenjamin Maxwell }; 394041baf2fSBenjamin Maxwell 395041baf2fSBenjamin Maxwell for (Block &block : function.getBlocks()) { 396041baf2fSBenjamin Maxwell LivenessBlockInfo const *livenessInfo = liveness.getLiveness(&block); 397041baf2fSBenjamin Maxwell // Handle block arguments: 398041baf2fSBenjamin Maxwell for (Value argument : block.getArguments()) 399041baf2fSBenjamin Maxwell defineOrUpdateValueLiveRange(argument, &block.front(), *livenessInfo, 400041baf2fSBenjamin Maxwell /*liveAtBlockEntry=*/true); 401041baf2fSBenjamin Maxwell // Handle live-ins: 402041baf2fSBenjamin Maxwell for (Value liveIn : livenessInfo->in()) 403041baf2fSBenjamin Maxwell defineOrUpdateValueLiveRange(liveIn, &block.front(), *livenessInfo, 404041baf2fSBenjamin Maxwell /*liveAtBlockEntry=*/true); 405041baf2fSBenjamin Maxwell // Handle new definitions: 406041baf2fSBenjamin Maxwell for (Operation &op : block) { 407041baf2fSBenjamin Maxwell for (Value result : op.getResults()) 408041baf2fSBenjamin Maxwell defineOrUpdateValueLiveRange(result, &op, *livenessInfo); 409041baf2fSBenjamin Maxwell } 410041baf2fSBenjamin Maxwell } 411041baf2fSBenjamin Maxwell 412041baf2fSBenjamin Maxwell return liveRanges; 413041baf2fSBenjamin Maxwell } 414041baf2fSBenjamin Maxwell 415041baf2fSBenjamin Maxwell /// Iterate over all predecessor tile values to a (tile) block argument. 
416041baf2fSBenjamin Maxwell static void forEachPredecessorTileValue(BlockArgument blockArg, 417041baf2fSBenjamin Maxwell function_ref<void(Value)> callback) { 418041baf2fSBenjamin Maxwell Block *block = blockArg.getOwner(); 419041baf2fSBenjamin Maxwell unsigned argNumber = blockArg.getArgNumber(); 420041baf2fSBenjamin Maxwell for (Block *pred : block->getPredecessors()) { 421041baf2fSBenjamin Maxwell TypeSwitch<Operation *>(pred->getTerminator()) 422041baf2fSBenjamin Maxwell .Case<cf::BranchOp>([&](auto branch) { 423041baf2fSBenjamin Maxwell Value predecessorOperand = branch.getDestOperands()[argNumber]; 424041baf2fSBenjamin Maxwell callback(predecessorOperand); 425eaff02f2SBenjamin Maxwell }) 426041baf2fSBenjamin Maxwell .Case<cf::CondBranchOp>([&](auto condBranch) { 427041baf2fSBenjamin Maxwell if (condBranch.getFalseDest() == block) { 428041baf2fSBenjamin Maxwell Value predecessorOperand = 429041baf2fSBenjamin Maxwell condBranch.getFalseDestOperands()[argNumber]; 430041baf2fSBenjamin Maxwell callback(predecessorOperand); 431041baf2fSBenjamin Maxwell } 432041baf2fSBenjamin Maxwell if (condBranch.getTrueDest() == block) { 433041baf2fSBenjamin Maxwell Value predecessorOperand = 434041baf2fSBenjamin Maxwell condBranch.getTrueDestOperands()[argNumber]; 435041baf2fSBenjamin Maxwell callback(predecessorOperand); 436041baf2fSBenjamin Maxwell } 437eaff02f2SBenjamin Maxwell }); 438eaff02f2SBenjamin Maxwell } 439eaff02f2SBenjamin Maxwell } 440041baf2fSBenjamin Maxwell 441041baf2fSBenjamin Maxwell /// Coalesce live ranges where it would prevent unnecessary tile moves. 
442041baf2fSBenjamin Maxwell SmallVector<LiveRange *> 443041baf2fSBenjamin Maxwell coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) { 444041baf2fSBenjamin Maxwell DenseMap<Value, LiveRange *> liveRanges; 445041baf2fSBenjamin Maxwell for (auto &[value, liveRange] : initialLiveRanges) { 446041baf2fSBenjamin Maxwell liveRanges.insert({value, &liveRange}); 447041baf2fSBenjamin Maxwell } 448041baf2fSBenjamin Maxwell 449041baf2fSBenjamin Maxwell // Merge the live ranges of values `a` and `b` into one (if they do not 450041baf2fSBenjamin Maxwell // overlap). After this, the values `a` and `b` will both point to the same 451041baf2fSBenjamin Maxwell // live range (which will contain multiple values). 452041baf2fSBenjamin Maxwell auto mergeValuesIfNonOverlapping = [&](Value a, Value b) { 453041baf2fSBenjamin Maxwell LiveRange *aLiveRange = liveRanges.at(a); 454041baf2fSBenjamin Maxwell LiveRange *bLiveRange = liveRanges.at(b); 455041baf2fSBenjamin Maxwell if (aLiveRange != bLiveRange && !aLiveRange->overlaps(*bLiveRange)) { 456041baf2fSBenjamin Maxwell aLiveRange->unionWith(*bLiveRange); 457041baf2fSBenjamin Maxwell for (Value value : bLiveRange->values) 458041baf2fSBenjamin Maxwell liveRanges[value] = aLiveRange; 459041baf2fSBenjamin Maxwell } 460041baf2fSBenjamin Maxwell }; 461041baf2fSBenjamin Maxwell 462041baf2fSBenjamin Maxwell // Merge the live ranges of new definitions with their tile operands. 
463041baf2fSBenjamin Maxwell auto unifyDefinitionsWithOperands = [&](Value value) { 464041baf2fSBenjamin Maxwell auto armSMEOp = value.getDefiningOp<ArmSMETileOpInterface>(); 465041baf2fSBenjamin Maxwell if (!armSMEOp) 466041baf2fSBenjamin Maxwell return; 467041baf2fSBenjamin Maxwell for (auto operand : armSMEOp->getOperands()) { 468041baf2fSBenjamin Maxwell if (isValidSMETileVectorType(operand.getType())) 469041baf2fSBenjamin Maxwell mergeValuesIfNonOverlapping(value, operand); 470041baf2fSBenjamin Maxwell } 471041baf2fSBenjamin Maxwell }; 472041baf2fSBenjamin Maxwell 473041baf2fSBenjamin Maxwell // Merge the live ranges of block arguments with their predecessors. 474041baf2fSBenjamin Maxwell auto unifyBlockArgumentsWithPredecessors = [&](Value value) { 475041baf2fSBenjamin Maxwell auto blockArg = dyn_cast<BlockArgument>(value); 476041baf2fSBenjamin Maxwell if (!blockArg) 477041baf2fSBenjamin Maxwell return; 478041baf2fSBenjamin Maxwell forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) { 479041baf2fSBenjamin Maxwell mergeValuesIfNonOverlapping(blockArg, predecessorTile); 480041baf2fSBenjamin Maxwell }); 481041baf2fSBenjamin Maxwell }; 482041baf2fSBenjamin Maxwell 483041baf2fSBenjamin Maxwell auto applyRule = [&](auto rule) { 484041baf2fSBenjamin Maxwell llvm::for_each(llvm::make_first_range(initialLiveRanges), rule); 485041baf2fSBenjamin Maxwell }; 486041baf2fSBenjamin Maxwell 487041baf2fSBenjamin Maxwell // Unify as many live ranges as we can. This prevents unnecessary moves. 488041baf2fSBenjamin Maxwell applyRule(unifyBlockArgumentsWithPredecessors); 489041baf2fSBenjamin Maxwell applyRule(unifyDefinitionsWithOperands); 490041baf2fSBenjamin Maxwell 491041baf2fSBenjamin Maxwell // Remove duplicate live range entries. 
492041baf2fSBenjamin Maxwell SetVector<LiveRange *> uniqueLiveRanges; 493041baf2fSBenjamin Maxwell for (auto [_, liveRange] : liveRanges) { 494041baf2fSBenjamin Maxwell if (!liveRange->empty()) 495041baf2fSBenjamin Maxwell uniqueLiveRanges.insert(liveRange); 496041baf2fSBenjamin Maxwell } 497041baf2fSBenjamin Maxwell 498041baf2fSBenjamin Maxwell // Sort the new live ranges by starting point (ready for tile allocation). 499041baf2fSBenjamin Maxwell auto coalescedLiveRanges = uniqueLiveRanges.takeVector(); 500041baf2fSBenjamin Maxwell std::sort(coalescedLiveRanges.begin(), coalescedLiveRanges.end(), 501041baf2fSBenjamin Maxwell [](LiveRange *a, LiveRange *b) { return *a < *b; }); 502041baf2fSBenjamin Maxwell return std::move(coalescedLiveRanges); 503041baf2fSBenjamin Maxwell } 504041baf2fSBenjamin Maxwell 505eed72d43SBenjamin Maxwell /// Choose a live range to spill (via some heuristics). This picks either a live 506eed72d43SBenjamin Maxwell /// range from `overlappingRanges`, or the new live range `newRange`. 507eed72d43SBenjamin Maxwell template <typename OverlappingRangesIterator> 508eed72d43SBenjamin Maxwell LiveRange * 509eed72d43SBenjamin Maxwell chooseSpillUsingHeuristics(OverlappingRangesIterator overlappingRanges, 510041baf2fSBenjamin Maxwell LiveRange *newRange) { 511041baf2fSBenjamin Maxwell // Heuristic: Spill trivially copyable operations (usually free). 
512eed72d43SBenjamin Maxwell auto isTrivialSpill = [&](LiveRange &allocatedRange) { 513eed72d43SBenjamin Maxwell return isTileTypeGreaterOrEqual(allocatedRange.getTileType(), 514041baf2fSBenjamin Maxwell newRange->getTileType()) && 515eed72d43SBenjamin Maxwell allocatedRange.values.size() == 1 && 516041baf2fSBenjamin Maxwell isTriviallyCloneableTileOp( 517eed72d43SBenjamin Maxwell allocatedRange.values[0].getDefiningOp<ArmSMETileOpInterface>()); 518041baf2fSBenjamin Maxwell }; 519eed72d43SBenjamin Maxwell if (isTrivialSpill(*newRange)) 520041baf2fSBenjamin Maxwell return newRange; 521eed72d43SBenjamin Maxwell auto trivialSpill = llvm::find_if(overlappingRanges, isTrivialSpill); 522eed72d43SBenjamin Maxwell if (trivialSpill != overlappingRanges.end()) 523eed72d43SBenjamin Maxwell return &*trivialSpill; 524041baf2fSBenjamin Maxwell 525041baf2fSBenjamin Maxwell // Heuristic: Spill the range that ends last (with a compatible tile type). 526eed72d43SBenjamin Maxwell auto isSmallerTileTypeOrEndsEarlier = [](LiveRange &a, LiveRange &b) { 527eed72d43SBenjamin Maxwell return !isTileTypeGreaterOrEqual(a.getTileType(), b.getTileType()) || 528eed72d43SBenjamin Maxwell a.end() < b.end(); 529041baf2fSBenjamin Maxwell }; 530eed72d43SBenjamin Maxwell LiveRange &latestEndingLiveRange = 531eed72d43SBenjamin Maxwell *std::max_element(overlappingRanges.begin(), overlappingRanges.end(), 532eed72d43SBenjamin Maxwell isSmallerTileTypeOrEndsEarlier); 533eed72d43SBenjamin Maxwell if (!isSmallerTileTypeOrEndsEarlier(latestEndingLiveRange, *newRange)) 534eed72d43SBenjamin Maxwell return &latestEndingLiveRange; 535041baf2fSBenjamin Maxwell return newRange; 536041baf2fSBenjamin Maxwell } 537041baf2fSBenjamin Maxwell 538041baf2fSBenjamin Maxwell /// Greedily allocate tile IDs to live ranges. Spill using simple heuristics. 
539041baf2fSBenjamin Maxwell void allocateTilesToLiveRanges( 540041baf2fSBenjamin Maxwell ArrayRef<LiveRange *> liveRangesSortedByStartPoint) { 541041baf2fSBenjamin Maxwell TileAllocator tileAllocator; 542eed72d43SBenjamin Maxwell // `activeRanges` = Live ranges that need to be in a tile at the 543eed72d43SBenjamin Maxwell // `currentPoint` in the program. 544041baf2fSBenjamin Maxwell SetVector<LiveRange *> activeRanges; 545eed72d43SBenjamin Maxwell // `inactiveRanges` = Live ranges that _do not_ need to be in a tile 546eed72d43SBenjamin Maxwell // at the `currentPoint` in the program but could become active again later. 547eed72d43SBenjamin Maxwell // An inactive section of a live range can be seen as a 'hole' in the live 548eed72d43SBenjamin Maxwell // range, where it is possible to reuse the live range's tile ID _before_ it 549eed72d43SBenjamin Maxwell // has ended. By identifying 'holes', the allocator can reuse tiles more 550eed72d43SBenjamin Maxwell // often, which helps avoid costly tile spills. 551eed72d43SBenjamin Maxwell SetVector<LiveRange *> inactiveRanges; 552041baf2fSBenjamin Maxwell for (LiveRange *nextRange : liveRangesSortedByStartPoint) { 553eed72d43SBenjamin Maxwell auto currentPoint = nextRange->start(); 554eed72d43SBenjamin Maxwell // 1. Update the `activeRanges` at `currentPoint`. 555041baf2fSBenjamin Maxwell activeRanges.remove_if([&](LiveRange *activeRange) { 556eed72d43SBenjamin Maxwell // Check for live ranges that have expired. 557eed72d43SBenjamin Maxwell if (activeRange->end() <= currentPoint) { 558041baf2fSBenjamin Maxwell tileAllocator.releaseTileId(activeRange->getTileType(), 559041baf2fSBenjamin Maxwell *activeRange->tileId); 560041baf2fSBenjamin Maxwell return true; 561041baf2fSBenjamin Maxwell } 562eed72d43SBenjamin Maxwell // Check for live ranges that have become inactive. 
563eed72d43SBenjamin Maxwell if (!activeRange->overlaps(currentPoint)) { 564eed72d43SBenjamin Maxwell tileAllocator.releaseTileId(activeRange->getTileType(), 565eed72d43SBenjamin Maxwell *activeRange->tileId); 566eed72d43SBenjamin Maxwell inactiveRanges.insert(activeRange); 567eed72d43SBenjamin Maxwell return true; 568eed72d43SBenjamin Maxwell } 569eed72d43SBenjamin Maxwell return false; 570eed72d43SBenjamin Maxwell }); 571eed72d43SBenjamin Maxwell // 2. Update the `inactiveRanges` at `currentPoint`. 572eed72d43SBenjamin Maxwell inactiveRanges.remove_if([&](LiveRange *inactiveRange) { 573eed72d43SBenjamin Maxwell // Check for live ranges that have expired. 574eed72d43SBenjamin Maxwell if (inactiveRange->end() <= currentPoint) { 575eed72d43SBenjamin Maxwell return true; 576eed72d43SBenjamin Maxwell } 577eed72d43SBenjamin Maxwell // Check for live ranges that have become active. 578eed72d43SBenjamin Maxwell if (inactiveRange->overlaps(currentPoint)) { 579eed72d43SBenjamin Maxwell tileAllocator.acquireTileId(inactiveRange->getTileType(), 580eed72d43SBenjamin Maxwell *inactiveRange->tileId); 581eed72d43SBenjamin Maxwell activeRanges.insert(inactiveRange); 582eed72d43SBenjamin Maxwell return true; 583eed72d43SBenjamin Maxwell } 584041baf2fSBenjamin Maxwell return false; 585041baf2fSBenjamin Maxwell }); 586041baf2fSBenjamin Maxwell 587eed72d43SBenjamin Maxwell // 3. Collect inactive live ranges that overlap with the new live range. 588eed72d43SBenjamin Maxwell // Note: The overlap checks in steps 1 and 2 only look at the `currentPoint` 589eed72d43SBenjamin Maxwell // whereas this checks if there is an overlap at any future point too. 
590eed72d43SBenjamin Maxwell SmallVector<LiveRange *> overlappingInactiveRanges; 591eed72d43SBenjamin Maxwell for (LiveRange *inactiveRange : inactiveRanges) { 592eed72d43SBenjamin Maxwell if (inactiveRange->overlaps(*nextRange)) { 593eed72d43SBenjamin Maxwell // We need to reserve the tile IDs of overlapping inactive ranges to 594eed72d43SBenjamin Maxwell // prevent two (overlapping) live ranges from getting the same tile ID. 595eed72d43SBenjamin Maxwell tileAllocator.acquireTileId(inactiveRange->getTileType(), 596eed72d43SBenjamin Maxwell *inactiveRange->tileId); 597eed72d43SBenjamin Maxwell overlappingInactiveRanges.push_back(inactiveRange); 598eed72d43SBenjamin Maxwell } 599eed72d43SBenjamin Maxwell } 600eed72d43SBenjamin Maxwell 601eed72d43SBenjamin Maxwell // 4. Allocate a tile ID to `nextRange`. 602041baf2fSBenjamin Maxwell auto rangeTileType = nextRange->getTileType(); 603041baf2fSBenjamin Maxwell auto tileId = tileAllocator.allocateTileId(rangeTileType); 604041baf2fSBenjamin Maxwell if (succeeded(tileId)) { 605041baf2fSBenjamin Maxwell nextRange->tileId = *tileId; 606041baf2fSBenjamin Maxwell } else { 607eed72d43SBenjamin Maxwell // Create an iterator over all overlapping live ranges. 608eed72d43SBenjamin Maxwell auto allOverlappingRanges = llvm::concat<LiveRange>( 609eed72d43SBenjamin Maxwell llvm::make_pointee_range(activeRanges.getArrayRef()), 610eed72d43SBenjamin Maxwell llvm::make_pointee_range(overlappingInactiveRanges)); 611eed72d43SBenjamin Maxwell // Choose an overlapping live range to spill. 612041baf2fSBenjamin Maxwell LiveRange *rangeToSpill = 613eed72d43SBenjamin Maxwell chooseSpillUsingHeuristics(allOverlappingRanges, nextRange); 614041baf2fSBenjamin Maxwell if (rangeToSpill != nextRange) { 615eed72d43SBenjamin Maxwell // Spill an (in)active live range (so release its tile ID first). 
616041baf2fSBenjamin Maxwell tileAllocator.releaseTileId(rangeToSpill->getTileType(), 617041baf2fSBenjamin Maxwell *rangeToSpill->tileId); 618041baf2fSBenjamin Maxwell // This will always succeed after a spill (of an active live range). 619041baf2fSBenjamin Maxwell nextRange->tileId = *tileAllocator.allocateTileId(rangeTileType); 620eed72d43SBenjamin Maxwell // Remove the live range from the active/inactive sets. 621eed72d43SBenjamin Maxwell if (!activeRanges.remove(rangeToSpill)) { 622eed72d43SBenjamin Maxwell bool removed = inactiveRanges.remove(rangeToSpill); 623eed72d43SBenjamin Maxwell assert(removed && "expected a range to be removed!"); 62499faa038SKazu Hirata (void)removed; 625eed72d43SBenjamin Maxwell } 626041baf2fSBenjamin Maxwell } 627041baf2fSBenjamin Maxwell rangeToSpill->tileId = tileAllocator.allocateInMemoryTileId(); 628041baf2fSBenjamin Maxwell } 629041baf2fSBenjamin Maxwell 630eed72d43SBenjamin Maxwell // 5. Insert the live range into the active ranges. 631041baf2fSBenjamin Maxwell if (nextRange->tileId < kInMemoryTileIdBase) 632041baf2fSBenjamin Maxwell activeRanges.insert(nextRange); 633eed72d43SBenjamin Maxwell 634eed72d43SBenjamin Maxwell // 6. Release tiles reserved for inactive live ranges (in step 3). 635eed72d43SBenjamin Maxwell for (LiveRange *range : overlappingInactiveRanges) { 636eed72d43SBenjamin Maxwell if (*range->tileId < kInMemoryTileIdBase) 637eed72d43SBenjamin Maxwell tileAllocator.releaseTileId(range->getTileType(), *range->tileId); 638eed72d43SBenjamin Maxwell } 639041baf2fSBenjamin Maxwell } 640041baf2fSBenjamin Maxwell } 641041baf2fSBenjamin Maxwell 642041baf2fSBenjamin Maxwell /// Assigns a tile ID to an MLIR value. 
643041baf2fSBenjamin Maxwell void assignTileIdToValue(IRRewriter &rewriter, Value value, 644041baf2fSBenjamin Maxwell IntegerAttr tileIdAttr) { 645041baf2fSBenjamin Maxwell if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>()) 646041baf2fSBenjamin Maxwell rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); }); 647041baf2fSBenjamin Maxwell for (Operation *user : value.getUsers()) { 648041baf2fSBenjamin Maxwell if (auto tileOp = dyn_cast<ArmSMETileOpInterface>(user)) { 649041baf2fSBenjamin Maxwell // Ensure ArmSME ops that don't produce a value still get a tile ID. 650041baf2fSBenjamin Maxwell if (!hasTileResult(tileOp)) 651041baf2fSBenjamin Maxwell rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); }); 652041baf2fSBenjamin Maxwell } 653041baf2fSBenjamin Maxwell } 654041baf2fSBenjamin Maxwell } 655041baf2fSBenjamin Maxwell 656041baf2fSBenjamin Maxwell /// Assign tile IDs back to IR and attempt to resolve trivial tile ID conflicts. 657041baf2fSBenjamin Maxwell LogicalResult assignTileIdsAndResolveTrivialConflicts( 658041baf2fSBenjamin Maxwell IRRewriter &rewriter, FunctionOpInterface function, 659041baf2fSBenjamin Maxwell ArrayRef<LiveRange *> allocatedLiveRanges) { 660041baf2fSBenjamin Maxwell for (LiveRange const *liveRange : allocatedLiveRanges) { 661041baf2fSBenjamin Maxwell auto tileIdAttr = rewriter.getI32IntegerAttr(*liveRange->tileId); 662041baf2fSBenjamin Maxwell auto isAllocatedToSameTile = [&](Value value) { 663041baf2fSBenjamin Maxwell if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>(); 664041baf2fSBenjamin Maxwell tileOp && tileOp.getTileId() == tileIdAttr) 665041baf2fSBenjamin Maxwell return true; 666041baf2fSBenjamin Maxwell return liveRange->values.contains(value); 667041baf2fSBenjamin Maxwell }; 668041baf2fSBenjamin Maxwell 669041baf2fSBenjamin Maxwell /// Eliminates copies where the operand has the same tile ID. 
670041baf2fSBenjamin Maxwell auto foldRedundantCopies = [&](Value value) -> LogicalResult { 671041baf2fSBenjamin Maxwell auto copyOp = value.getDefiningOp<CopyTileOp>(); 672041baf2fSBenjamin Maxwell if (!copyOp || !isAllocatedToSameTile(copyOp.getTile())) 673fb54fec7SCullen Rhodes return failure(); 674041baf2fSBenjamin Maxwell rewriter.replaceAllUsesWith(copyOp, copyOp.getTile()); 675041baf2fSBenjamin Maxwell return success(); 6765417a5feSBenjamin Maxwell }; 6775417a5feSBenjamin Maxwell 678041baf2fSBenjamin Maxwell /// Validates each predecessor to a tile block argument has been assigned 679041baf2fSBenjamin Maxwell /// the same tile ID. 680041baf2fSBenjamin Maxwell auto validateBlockArguments = [&](Value value) { 681041baf2fSBenjamin Maxwell auto blockArg = dyn_cast<BlockArgument>(value); 682041baf2fSBenjamin Maxwell if (!blockArg) { 683041baf2fSBenjamin Maxwell // Not a block argument (nothing to validate). 684fb54fec7SCullen Rhodes return success(); 685fb54fec7SCullen Rhodes } 686041baf2fSBenjamin Maxwell bool tileMismatch = false; 687041baf2fSBenjamin Maxwell forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) { 688041baf2fSBenjamin Maxwell if (tileMismatch) 689041baf2fSBenjamin Maxwell return; 690041baf2fSBenjamin Maxwell if (!isAllocatedToSameTile(predecessorTile)) { 691041baf2fSBenjamin Maxwell blockArg.getOwner()->getParentOp()->emitOpError( 692041baf2fSBenjamin Maxwell "block argument not allocated to the same SME virtial tile as " 693041baf2fSBenjamin Maxwell "predecessors"); 694041baf2fSBenjamin Maxwell tileMismatch = true; 695041baf2fSBenjamin Maxwell } 696041baf2fSBenjamin Maxwell }); 697041baf2fSBenjamin Maxwell return success(/*isSuccess=*/!tileMismatch); 698fb54fec7SCullen Rhodes }; 699fb54fec7SCullen Rhodes 700041baf2fSBenjamin Maxwell /// Attempts to resolve (trivial) tile ID conflicts. 
701041baf2fSBenjamin Maxwell auto resolveTrivialTileConflicts = [&](Value value) -> LogicalResult { 702041baf2fSBenjamin Maxwell auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>(); 703041baf2fSBenjamin Maxwell OpOperand *tileOperand = getTileOpOperand(tileOp); 704041baf2fSBenjamin Maxwell if (!tileOperand || isAllocatedToSameTile(tileOperand->get())) { 705041baf2fSBenjamin Maxwell // Operand already allocated to the correct tile. 706041baf2fSBenjamin Maxwell // No conflict to resolve. 707041baf2fSBenjamin Maxwell return success(); 708fb54fec7SCullen Rhodes } 709041baf2fSBenjamin Maxwell auto operandTileOp = 710041baf2fSBenjamin Maxwell tileOperand->get().getDefiningOp<ArmSMETileOpInterface>(); 711041baf2fSBenjamin Maxwell if (!isTriviallyCloneableTileOp(operandTileOp)) { 712041baf2fSBenjamin Maxwell auto error = 713041baf2fSBenjamin Maxwell tileOp.emitOpError("tile operand allocated to different SME " 714041baf2fSBenjamin Maxwell "virtial tile (move required)"); 715041baf2fSBenjamin Maxwell error.attachNote(tileOperand->get().getLoc()) 716041baf2fSBenjamin Maxwell << "tile operand is: " << tileOperand->get(); 717041baf2fSBenjamin Maxwell return error; 718041baf2fSBenjamin Maxwell } 719041baf2fSBenjamin Maxwell // Cloning prevents a move/spill (though may require recomputation). 
720041baf2fSBenjamin Maxwell rewriter.setInsertionPoint(tileOp); 721041baf2fSBenjamin Maxwell auto clonedOp = operandTileOp.clone(); 722041baf2fSBenjamin Maxwell rewriter.modifyOpInPlace(clonedOp, 723041baf2fSBenjamin Maxwell [&] { clonedOp.setTileId(tileOp.getTileId()); }); 724041baf2fSBenjamin Maxwell rewriter.insert(clonedOp); 725041baf2fSBenjamin Maxwell if (isa<CopyTileOp>(tileOp)) { 726041baf2fSBenjamin Maxwell rewriter.replaceAllUsesWith(tileOp->getResult(0), 727041baf2fSBenjamin Maxwell clonedOp->getResult(0)); 728041baf2fSBenjamin Maxwell } else { 729041baf2fSBenjamin Maxwell rewriter.modifyOpInPlace( 730041baf2fSBenjamin Maxwell tileOp, [&] { tileOperand->assign(clonedOp->getResult(0)); }); 731041baf2fSBenjamin Maxwell } 732041baf2fSBenjamin Maxwell return success(); 733041baf2fSBenjamin Maxwell }; 734041baf2fSBenjamin Maxwell 735041baf2fSBenjamin Maxwell for (Value value : liveRange->values) { 736041baf2fSBenjamin Maxwell // 1. Assign the tile ID to the value. 737041baf2fSBenjamin Maxwell assignTileIdToValue(rewriter, value, tileIdAttr); 738041baf2fSBenjamin Maxwell 739041baf2fSBenjamin Maxwell // 2. Attempt to eliminate redundant tile copies. 740041baf2fSBenjamin Maxwell if (succeeded(foldRedundantCopies(value))) 741041baf2fSBenjamin Maxwell continue; 742041baf2fSBenjamin Maxwell 743041baf2fSBenjamin Maxwell // 3. Validate tile block arguments. 744041baf2fSBenjamin Maxwell if (failed(validateBlockArguments(value))) 745041baf2fSBenjamin Maxwell return failure(); 746041baf2fSBenjamin Maxwell 747041baf2fSBenjamin Maxwell // 4. Attempt to resolve (trivial) tile ID conflicts. 748041baf2fSBenjamin Maxwell if (failed(resolveTrivialTileConflicts(value))) 749041baf2fSBenjamin Maxwell return failure(); 750041baf2fSBenjamin Maxwell } 751041baf2fSBenjamin Maxwell } 752041baf2fSBenjamin Maxwell return success(); 753041baf2fSBenjamin Maxwell } 754041baf2fSBenjamin Maxwell 755041baf2fSBenjamin Maxwell /// Prints live ranges alongside operation names for debugging. 
756041baf2fSBenjamin Maxwell void dumpLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap, 757041baf2fSBenjamin Maxwell ArrayRef<LiveRange const *> liveRanges, 758041baf2fSBenjamin Maxwell FunctionOpInterface function) { 759041baf2fSBenjamin Maxwell llvm::errs() << "SME Tile Liveness: @" << function.getName() 760041baf2fSBenjamin Maxwell << "\nKey:\nS - Start\nE - End\n| - Live\n"; 761041baf2fSBenjamin Maxwell for (auto [blockIdx, block] : llvm::enumerate(function.getBlocks())) { 762041baf2fSBenjamin Maxwell llvm::errs() << "^bb" << blockIdx << ":\n"; 763041baf2fSBenjamin Maxwell for (Operation &op : block.getOperations()) { 764041baf2fSBenjamin Maxwell unsigned operationIndex = operationToIndexMap.at(&op); 765041baf2fSBenjamin Maxwell for (LiveRange const *range : liveRanges) { 766041baf2fSBenjamin Maxwell char liveness = ' '; 767041baf2fSBenjamin Maxwell for (auto it = range->ranges->begin(); it != range->ranges->end(); 768041baf2fSBenjamin Maxwell ++it) { 769041baf2fSBenjamin Maxwell if (it.start() == operationIndex) 770041baf2fSBenjamin Maxwell liveness = (liveness == 'E' ? '|' : 'S'); 771041baf2fSBenjamin Maxwell else if (it.stop() == operationIndex) 772041baf2fSBenjamin Maxwell liveness = (liveness == 'S' ? 
'|' : 'E'); 773041baf2fSBenjamin Maxwell else if (operationIndex >= it.start() && operationIndex < it.stop()) 774041baf2fSBenjamin Maxwell liveness = '|'; 775041baf2fSBenjamin Maxwell } 776041baf2fSBenjamin Maxwell llvm::errs() << liveness; 777041baf2fSBenjamin Maxwell } 778041baf2fSBenjamin Maxwell llvm::errs() << ' ' << op.getName() << '\n'; 779041baf2fSBenjamin Maxwell } 780041baf2fSBenjamin Maxwell } 781041baf2fSBenjamin Maxwell llvm::errs() << "==========\n"; 782041baf2fSBenjamin Maxwell } 783041baf2fSBenjamin Maxwell 784041baf2fSBenjamin Maxwell struct TestTileAllocationPass 785041baf2fSBenjamin Maxwell : public arm_sme::impl::TestTileAllocationBase<TestTileAllocationPass> { 786041baf2fSBenjamin Maxwell using TestTileAllocationBase::TestTileAllocationBase; 787041baf2fSBenjamin Maxwell void runOnOperation() override { 788041baf2fSBenjamin Maxwell FunctionOpInterface function = getOperation(); 789041baf2fSBenjamin Maxwell if (preprocessOnly) { 790041baf2fSBenjamin Maxwell IRRewriter rewriter(function); 791041baf2fSBenjamin Maxwell return preprocessForTileAllocation(rewriter, function); 792041baf2fSBenjamin Maxwell } 793041baf2fSBenjamin Maxwell if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges))) 794041baf2fSBenjamin Maxwell signalPassFailure(); 795eaff02f2SBenjamin Maxwell } 796fb54fec7SCullen Rhodes }; 797fb54fec7SCullen Rhodes } // namespace 798fb54fec7SCullen Rhodes 799041baf2fSBenjamin Maxwell LogicalResult mlir::arm_sme::allocateSMETiles(FunctionOpInterface function, 800041baf2fSBenjamin Maxwell bool dumpRanges) { 801041baf2fSBenjamin Maxwell if (function.empty()) { 802041baf2fSBenjamin Maxwell // TODO: Also return early if the function contains no ArmSME ops? 
803041baf2fSBenjamin Maxwell return success(); 804041baf2fSBenjamin Maxwell } 805041baf2fSBenjamin Maxwell 806041baf2fSBenjamin Maxwell LiveRange::Allocator liveRangeAllocator; 807041baf2fSBenjamin Maxwell IRRewriter rewriter(function.getContext()); 808041baf2fSBenjamin Maxwell 809041baf2fSBenjamin Maxwell // 1. Preprocess the IR for tile allocation. 810041baf2fSBenjamin Maxwell preprocessForTileAllocation(rewriter, function); 811041baf2fSBenjamin Maxwell 812041baf2fSBenjamin Maxwell // 2. Gather live ranges for each ArmSME tile within the function. 813041baf2fSBenjamin Maxwell Liveness liveness(function); 814041baf2fSBenjamin Maxwell auto operationToIndexMap = generateOperationNumbering(function); 815041baf2fSBenjamin Maxwell auto initialLiveRanges = gatherTileLiveRanges( 816041baf2fSBenjamin Maxwell operationToIndexMap, liveRangeAllocator, liveness, function); 817041baf2fSBenjamin Maxwell if (initialLiveRanges.empty()) 818041baf2fSBenjamin Maxwell return success(); 819041baf2fSBenjamin Maxwell 820041baf2fSBenjamin Maxwell if (dumpRanges) { 821041baf2fSBenjamin Maxwell // Wrangle initial live ranges into a form suitable for printing. 
822041baf2fSBenjamin Maxwell auto nonEmpty = llvm::make_filter_range( 823041baf2fSBenjamin Maxwell llvm::make_second_range(initialLiveRanges), 824041baf2fSBenjamin Maxwell [&](LiveRange const &liveRange) { return !liveRange.empty(); }); 825041baf2fSBenjamin Maxwell auto initialRanges = llvm::to_vector(llvm::map_range( 826041baf2fSBenjamin Maxwell nonEmpty, [](LiveRange const &liveRange) { return &liveRange; })); 827041baf2fSBenjamin Maxwell std::sort(initialRanges.begin(), initialRanges.end(), 828041baf2fSBenjamin Maxwell [](LiveRange const *a, LiveRange const *b) { return *a < *b; }); 829041baf2fSBenjamin Maxwell llvm::errs() << "\n========== Initial Live Ranges:\n"; 830041baf2fSBenjamin Maxwell dumpLiveRanges(operationToIndexMap, initialRanges, function); 831041baf2fSBenjamin Maxwell } 832041baf2fSBenjamin Maxwell 833041baf2fSBenjamin Maxwell // 3. Coalesce (non-overlapping) live ranges where it would be beneficial 834041baf2fSBenjamin Maxwell // for tile allocation. E.g. Unify the result of an operation with its 835041baf2fSBenjamin Maxwell // operands. 836041baf2fSBenjamin Maxwell auto coalescedLiveRanges = coalesceTileLiveRanges(initialLiveRanges); 837041baf2fSBenjamin Maxwell 838041baf2fSBenjamin Maxwell if (dumpRanges) { 839041baf2fSBenjamin Maxwell llvm::errs() << "\n========== Coalesced Live Ranges:\n"; 840041baf2fSBenjamin Maxwell dumpLiveRanges(operationToIndexMap, coalescedLiveRanges, function); 841041baf2fSBenjamin Maxwell } 842041baf2fSBenjamin Maxwell 843041baf2fSBenjamin Maxwell // 4. Allocate tile IDs to live ranges. 844041baf2fSBenjamin Maxwell allocateTilesToLiveRanges(coalescedLiveRanges); 845041baf2fSBenjamin Maxwell 846041baf2fSBenjamin Maxwell // 5. Assign the tile IDs back to the ArmSME operations. 
847041baf2fSBenjamin Maxwell if (failed(assignTileIdsAndResolveTrivialConflicts(rewriter, function, 848041baf2fSBenjamin Maxwell coalescedLiveRanges))) { 849041baf2fSBenjamin Maxwell return failure(); 850041baf2fSBenjamin Maxwell } 851041baf2fSBenjamin Maxwell 852041baf2fSBenjamin Maxwell // 6. Erase trivially dead tile operations (e.g. a ZeroOp with no 853041baf2fSBenjamin Maxwell // users). This prevents the LLVM conversion needlessly inserting spills. 854041baf2fSBenjamin Maxwell eraseTriviallyDeadTileOps(rewriter, function); 855041baf2fSBenjamin Maxwell return success(); 856fb54fec7SCullen Rhodes } 857