//===- TileAllocation.cpp - Allocate SME ZA tiles -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transform allocates SME tiles at the 'func.func' op level for ArmSME
// operations. It roughly implements a linear scan register allocator, similar
// to the one outlined in [1], but with simplifications and assumptions made
// for our use case. Note that this is a greedy allocator (so it may not
// always find an optimal allocation of tiles).
//
// The allocator operates at the CF dialect level. It is the responsibility of
// users to ensure the IR has been lowered to CF before invoking the tile
// allocator.
//
// The 128-bit tiles overlap with other element tiles as follows (see section
// B2.3.2 of the SME spec [2]):
//
//   Tile    Overlaps
//   -------------------------------------------------------------------------
//   ZA0.B   ZA0.Q, ZA1.Q, ZA2.Q, ZA3.Q, ZA4.Q, ZA5.Q, ZA6.Q, ZA7.Q, ZA8.Q,
//           ZA9.Q, ZA10.Q, ZA11.Q, ZA12.Q, ZA13.Q, ZA14.Q, ZA15.Q
//   ZA0.H   ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
//   ZA1.H   ZA1.Q, ZA3.Q, ZA5.Q, ZA7.Q, ZA9.Q, ZA11.Q, ZA13.Q, ZA15.Q
//   ZA0.S   ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q
//   ZA1.S   ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
//   ZA2.S   ZA2.Q, ZA6.Q, ZA10.Q, ZA14.Q
//   ZA3.S   ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q
//   ZA0.D   ZA0.Q, ZA8.Q
//   ZA1.D   ZA1.Q, ZA9.Q
//   ZA2.D   ZA2.Q, ZA10.Q
//   ZA3.D   ZA3.Q, ZA11.Q
//   ZA4.D   ZA4.Q, ZA12.Q
//   ZA5.D   ZA5.Q, ZA13.Q
//   ZA6.D   ZA6.Q, ZA14.Q
//   ZA7.D   ZA7.Q, ZA15.Q
//
// [1] "Linear Scan Register Allocation in the Context of SSA Form and Register
//     Constraints" (Hanspeter Mössenböck and Michael Pfeiffer)
//     https://link.springer.com/content/pdf/10.1007/3-540-45937-5_17.pdf
// [2] https://developer.arm.com/documentation/ddi0616/aa
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/Liveness.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/TypeSwitch.h"
#include <algorithm>

namespace mlir::arm_sme {
#define GEN_PASS_DEF_TESTTILEALLOCATION
#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
} // namespace mlir::arm_sme

using namespace mlir;
using namespace mlir::arm_sme;

namespace {

enum class TileMask : unsigned {
  // clang-format off
  kZA0B  = 0xffff, // 1111 1111 1111 1111

  kZA0H  = 0xaaaa, // 1010 1010 1010 1010
  kZA1H  = 0x5555, // 0101 0101 0101 0101

  kZA0S  = 0x8888, // 1000 1000 1000 1000
  kZA1S  = 0x4444, // 0100 0100 0100 0100
  kZA2S  = 0x2222, // 0010 0010 0010 0010
  kZA3S  = 0x1111, // 0001 0001 0001 0001

  kZA0D  = 0x8080, // 1000 0000 1000 0000
  kZA1D  = 0x4040, // 0100 0000 0100 0000
  kZA2D  = 0x2020, // 0010 0000 0010 0000
  kZA3D  = 0x1010, // 0001 0000 0001 0000
  kZA4D  = 0x808,  // 0000 1000 0000 1000
  kZA5D  = 0x404,  // 0000 0100 0000 0100
  kZA6D  = 0x202,  // 0000 0010 0000 0010
  kZA7D  = 0x101,  // 0000 0001 0000 0001

  kZA0Q  = 0x8000, // 1000 0000 0000 0000
  kZA1Q  = 0x4000, // 0100 0000 0000 0000
  kZA2Q  = 0x2000, // 0010 0000 0000 0000
  kZA3Q  = 0x1000, // 0001 0000 0000 0000
  kZA4Q  = 0x800,  // 0000 1000 0000 0000
  kZA5Q  = 0x400,  // 0000 0100 0000 0000
  kZA6Q  = 0x200,  // 0000 0010 0000 0000
  kZA7Q  = 0x100,  // 0000 0001 0000 0000
  kZA8Q  = 0x80,   // 0000 0000 1000 0000
  kZA9Q  = 0x40,   // 0000 0000 0100 0000
  kZA10Q = 0x20,   // 0000 0000 0010 0000
  kZA11Q = 0x10,   // 0000 0000 0001 0000
  kZA12Q = 0x8,    // 0000 0000 0000 1000
  kZA13Q = 0x4,    // 0000 0000 0000 0100
  kZA14Q = 0x2,    // 0000 0000 0000 0010
  kZA15Q = 0x1,    // 0000 0000 0000 0001
  kNone  = 0x0,    // 0000 0000 0000 0000
  // clang-format on

  LLVM_MARK_AS_BITMASK_ENUM(kZA0B)
};
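// As a worked example of how these masks encode the overlap table in the
// header: allocating ZA0.H sets the bits 0xaaaa, so a later request for
// ZA0.S (0x8888) conflicts (0xaaaa & 0x8888 != 0), whereas ZA1.S (0x4444)
// does not (0xaaaa & 0x4444 == 0), as ZA1.S only overlaps the odd-numbered
// 128-bit tiles.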
/// Returns the set of masks relevant for the given type.
static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
  static constexpr std::array ZA_B_MASKS = {TileMask::kZA0B};
  static constexpr std::array ZA_H_MASKS = {TileMask::kZA0H, TileMask::kZA1H};
  static constexpr std::array ZA_S_MASKS = {TileMask::kZA0S, TileMask::kZA1S,
                                            TileMask::kZA2S, TileMask::kZA3S};
  static constexpr std::array ZA_D_MASKS = {
      TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D,
      TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D};
  static constexpr std::array ZA_Q_MASKS = {
      TileMask::kZA0Q,  TileMask::kZA1Q,  TileMask::kZA2Q,  TileMask::kZA3Q,
      TileMask::kZA4Q,  TileMask::kZA5Q,  TileMask::kZA6Q,  TileMask::kZA7Q,
      TileMask::kZA8Q,  TileMask::kZA9Q,  TileMask::kZA10Q, TileMask::kZA11Q,
      TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q};
  switch (type) {
  case ArmSMETileType::ZAB:
    return ZA_B_MASKS;
  case ArmSMETileType::ZAH:
    return ZA_H_MASKS;
  case ArmSMETileType::ZAS:
    return ZA_S_MASKS;
  case ArmSMETileType::ZAD:
    return ZA_D_MASKS;
  case ArmSMETileType::ZAQ:
    return ZA_Q_MASKS;
  }
  llvm_unreachable("unknown type in getMasks");
}

class TileAllocator {
public:
  /// Allocates and returns a tile ID. Fails if there are no tiles left.
  FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) {
    auto masks = getMasks(tileType);
    for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
      if ((tilesInUse & tileMask) == TileMask::kNone) {
        tilesInUse |= tileMask;
        return tileId;
      }
    }
    return failure();
  }

  /// Acquires a specific tile ID. Asserts the tile is initially free.
  void acquireTileId(ArmSMETileType tileType, unsigned tileId) {
    TileMask tileMask = getMasks(tileType)[tileId];
    assert((tilesInUse & tileMask) == TileMask::kNone &&
           "cannot acquire allocated tile!");
    tilesInUse |= tileMask;
  }

  /// Releases a previously allocated tile ID.
  void releaseTileId(ArmSMETileType tileType, unsigned tileId) {
    TileMask tileMask = getMasks(tileType)[tileId];
    assert((tilesInUse & tileMask) == tileMask &&
           "cannot release unallocated tile!");
    tilesInUse ^= tileMask;
  }

  /// Allocates an in-memory tile ID.
  unsigned allocateInMemoryTileId() {
    // Note: We never release in-memory tile IDs. We could, which may allow
    // reusing an allocation, but as we _never_ want to spill an SME tile this
    // is not optimized.
    return nextInMemoryTileId++;
  }

private:
  TileMask tilesInUse = TileMask::kNone;
  unsigned nextInMemoryTileId = kInMemoryTileIdBase;
};
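// For example, starting from a fresh TileAllocator, four consecutive requests
// for ArmSMETileType::ZAS succeed (returning tile IDs 0 to 3, i.e. ZA0.S to
// ZA3.S) and together set all 16 bits of `tilesInUse`, so a fifth request
// fails. Callers handle this failure by spilling, i.e. assigning an ID from
// allocateInMemoryTileId() instead (see allocateTilesToLiveRanges below).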
/// Add new intermediate blocks for the true and false destinations of
/// `cf.cond_br`s that contain tile operands. This prevents spurious liveness
/// overlaps due to copies at branches.
///
/// BEFORE:
/// ```mlir
/// cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2
/// ```
///
/// AFTER:
/// ```mlir
///   cf.cond_br %cond, ^bb1_copy, ^bb2_copy
/// ^bb1_copy:
///   cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
/// ^bb2_copy:
///   cf.br ^bb2
/// ```
void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
  SmallVector<cf::CondBranchOp> worklist;
  function.walk([&](cf::CondBranchOp condBranch) {
    if (llvm::any_of(condBranch->getOperands(), [&](Value value) {
          return isValidSMETileVectorType(value.getType());
        })) {
      worklist.push_back(condBranch);
    }
  });

  auto insertJump = [&](Location loc, Block *source, Block *dest, auto args) {
    rewriter.setInsertionPointToEnd(source);
    rewriter.create<cf::BranchOp>(loc, dest, args);
  };

  for (auto condBranch : worklist) {
    auto loc = condBranch.getLoc();
    Block *block = condBranch->getBlock();
    auto newTrueBranch = rewriter.splitBlock(block, block->end());
    auto newFalseBranch = rewriter.splitBlock(block, block->end());
    insertJump(loc, newTrueBranch, condBranch.getTrueDest(),
               condBranch.getTrueDestOperands());
    insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
               condBranch.getFalseDestOperands());
    rewriter.modifyOpInPlace(condBranch, [&] {
      condBranch.getFalseDestOperandsMutable().clear();
      condBranch.getTrueDestOperandsMutable().clear();
      condBranch.setSuccessor(newTrueBranch, 0);
      condBranch.setSuccessor(newFalseBranch, 1);
    });
  }
}

/// Inserts tile copies at `cf.br` operations.
///
/// BEFORE:
/// ```mlir
/// cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
/// ```
///
/// AFTER:
/// ```mlir
/// %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
/// cf.br ^bb1(%copy: vector<[4]x[4]xf32>)
/// ```
void insertCopiesAtBranches(IRRewriter &rewriter,
                            FunctionOpInterface function) {
  for (Block &block : function.getBlocks()) {
    Operation *terminator = block.getTerminator();
    if (!isa<cf::BranchOp>(terminator))
      continue;
    rewriter.setInsertionPoint(terminator);
    for (OpOperand &operand : terminator->getOpOperands()) {
      if (isValidSMETileVectorType(operand.get().getType())) {
        auto copy =
            rewriter.create<CopyTileOp>(terminator->getLoc(), operand.get());
        rewriter.modifyOpInPlace(terminator, [&] { operand.assign(copy); });
      }
    }
  }
}

/// Prepares the IR for tile allocation. It does this by first 'splitting'
/// conditional branches (see `splitCondBranches`), then inserting tile copies
/// at branch operations. The conditional branches are split to prevent the
/// copies needed for them overlapping between the true and false paths of the
/// branch (see `tile-allocation-copies.mlir` and
/// `tile-allocation-liveness.mlir` for examples). The copies break up live
/// ranges and ensure that, when moving out of SSA, the semantics of the
/// program are preserved.
void preprocessForTileAllocation(IRRewriter &rewriter,
                                 FunctionOpInterface function) {
  splitCondBranches(rewriter, function);
  insertCopiesAtBranches(rewriter, function);
}
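// Composing the two rewrites above, a conditional branch with a tile operand,
// such as:
//
//   cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2
//
// ends up as:
//
//     cf.cond_br %cond, ^bb1_copy, ^bb2_copy
//   ^bb1_copy:
//     %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
//     cf.br ^bb1(%copy: vector<[4]x[4]xf32>)
//   ^bb2_copy:
//     cf.br ^bb2
//
// so the copy made for the true path cannot appear live on the false path.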
/// A live range for a (collection of) tile values. A live range is built up
/// of non-overlapping intervals [start, end) which represent parts of the
/// program where a value in the range needs to be live (i.e. in an SME
/// virtual tile). Note that as the intervals are non-overlapping all values
/// within a live range can be allocated to the same SME virtual tile.
struct LiveRange {
  using RangeSet = llvm::IntervalMap<unsigned, uint8_t, 16,
                                     llvm::IntervalMapHalfOpenInfo<unsigned>>;
  using Allocator = RangeSet::Allocator;
  // Dummy value for the IntervalMap. Only the keys matter (the intervals).
  static constexpr uint8_t kValidLiveRange = 0xff;

  LiveRange(Allocator &allocator)
      : ranges(std::make_unique<RangeSet>(allocator)) {}

  /// Returns true if this range overlaps with `otherRange`.
  bool overlaps(LiveRange const &otherRange) const {
    return llvm::IntervalMapOverlaps<RangeSet, RangeSet>(*ranges,
                                                         *otherRange.ranges)
        .valid();
  }

  /// Returns true if this range is active at `point` in the program.
  bool overlaps(uint64_t point) const {
    return ranges->lookup(point) == kValidLiveRange;
  }

  /// Unions this live range with `otherRange`, aborts if the ranges overlap.
  void unionWith(LiveRange const &otherRange) {
    for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
         ++it)
      ranges->insert(it.start(), it.stop(), kValidLiveRange);
    values.set_union(otherRange.values);
  }

  /// Inserts an interval [start, end) for `value` into this range.
  void insert(Value value, unsigned start, unsigned end) {
    values.insert(value);
    if (start != end)
      ranges->insert(start, end, kValidLiveRange);
  }

  bool empty() const { return ranges->empty(); }
  unsigned start() const { return ranges->start(); }
  unsigned end() const { return ranges->stop(); }

  bool operator<(LiveRange const &other) const {
    return start() < other.start();
  }

  ArmSMETileType getTileType() const {
    return *getSMETileType(cast<VectorType>(values[0].getType()));
  }

  /// The values contained in this live range.
  SetVector<Value> values;

  /// A set of (non-overlapping) intervals that mark where any value in
  /// `values` is live.
  std::unique_ptr<RangeSet> ranges;

  /// The tile ID (or none) assigned to this live range.
  std::optional<unsigned> tileId;
};
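// As an illustration (using hypothetical operation numbers): a value defined
// at operation 2 and last used at operation 5 is recorded as the half-open
// interval [2, 5). A live range covering [5, 8) does not overlap it, so the
// two ranges can be unioned and their values allocated to the same SME
// virtual tile.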
/// Number operations within a function to allow computing live ranges.
/// Operations are numbered consecutively within blocks, and the blocks are
/// topologically sorted (using forward edges). This function is only correct
/// if all control flow around ArmSME operations has been converted to CF
/// (which is asserted).
DenseMap<Operation *, unsigned>
generateOperationNumbering(FunctionOpInterface function) {
  unsigned index = 0;
  SetVector<Block *> blocks =
      getBlocksSortedByDominance(function.getFunctionBody());
  DenseMap<Operation *, unsigned> operationToIndexMap;
  for (Block *block : blocks) {
    index++; // We want block args to have their own number.
    for (Operation &op : block->getOperations()) {
#ifndef NDEBUG
      op.walk([&](ArmSMETileOpInterface nestedOp) {
        assert(&op == nestedOp.getOperation() &&
               "ArmSME tile allocation does not support nested regions");
      });
#endif
      operationToIndexMap.try_emplace(&op, index++);
    }
  }
  return operationToIndexMap;
}
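// For example, for a function with two blocks of two operations each, the
// numbering is: ^bb0's block arguments get number 0 and its operations get 1
// and 2, then ^bb1's block arguments get 3 and its operations get 4 and 5.
// (The block argument numbers are the start points used when a value is live
// at block entry in gatherTileLiveRanges below.)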
/// Gather live ranges for SME tiles from the MLIR liveness analysis.
DenseMap<Value, LiveRange> gatherTileLiveRanges(
    DenseMap<Operation *, unsigned> const &operationToIndexMap,
    LiveRange::Allocator &liveRangeAllocator, Liveness &liveness,
    FunctionOpInterface function) {
  assert(!operationToIndexMap.empty() && "expected operation numbering");
  DenseMap<Value, LiveRange> liveRanges;
  /// Defines or updates a live range for an SME tile value. Live-ins may
  /// update an existing live range (rather than define a new one). Note: If
  /// `liveAtBlockEntry` is true then `firstUseOrDef` is the first operation
  /// in the block.
  auto defineOrUpdateValueLiveRange = [&](Value value, Operation *firstUseOrDef,
                                          LivenessBlockInfo const &livenessInfo,
                                          bool liveAtBlockEntry = false) {
    if (!isValidSMETileVectorType(value.getType()))
      return;
    // Find or create a live range for `value`.
    auto [it, _] = liveRanges.try_emplace(value, liveRangeAllocator);
    LiveRange &valueLiveRange = it->second;
    auto lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
    // Add the interval [firstUseOrDef, lastUseInBlock) to the live range.
    unsigned startOpIdx =
        operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0);
    unsigned endOpIdx = operationToIndexMap.at(lastUseInBlock);
    valueLiveRange.insert(value, startOpIdx, endOpIdx);
  };

  for (Block &block : function.getBlocks()) {
    LivenessBlockInfo const *livenessInfo = liveness.getLiveness(&block);
    // Handle block arguments:
    for (Value argument : block.getArguments())
      defineOrUpdateValueLiveRange(argument, &block.front(), *livenessInfo,
                                   /*liveAtBlockEntry=*/true);
    // Handle live-ins:
    for (Value liveIn : livenessInfo->in())
      defineOrUpdateValueLiveRange(liveIn, &block.front(), *livenessInfo,
                                   /*liveAtBlockEntry=*/true);
    // Handle new definitions:
    for (Operation &op : block) {
      for (Value result : op.getResults())
        defineOrUpdateValueLiveRange(result, &op, *livenessInfo);
    }
  }

  return liveRanges;
}

/// Iterate over all predecessor tile values to a (tile) block argument.
static void forEachPredecessorTileValue(BlockArgument blockArg,
                                        function_ref<void(Value)> callback) {
  Block *block = blockArg.getOwner();
  unsigned argNumber = blockArg.getArgNumber();
  for (Block *pred : block->getPredecessors()) {
    TypeSwitch<Operation *>(pred->getTerminator())
        .Case<cf::BranchOp>([&](auto branch) {
          Value predecessorOperand = branch.getDestOperands()[argNumber];
          callback(predecessorOperand);
        })
        .Case<cf::CondBranchOp>([&](auto condBranch) {
          if (condBranch.getFalseDest() == block) {
            Value predecessorOperand =
                condBranch.getFalseDestOperands()[argNumber];
            callback(predecessorOperand);
          }
          if (condBranch.getTrueDest() == block) {
            Value predecessorOperand =
                condBranch.getTrueDestOperands()[argNumber];
            callback(predecessorOperand);
          }
        });
  }
}

/// Coalesce live ranges where it would prevent unnecessary tile moves.
SmallVector<LiveRange *>
coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
  DenseMap<Value, LiveRange *> liveRanges;
  for (auto &[value, liveRange] : initialLiveRanges) {
    liveRanges.insert({value, &liveRange});
  }

  // Merge the live ranges of values `a` and `b` into one (if they do not
  // overlap). After this, the values `a` and `b` will both point to the same
  // live range (which will contain multiple values).
  auto mergeValuesIfNonOverlapping = [&](Value a, Value b) {
    LiveRange *aLiveRange = liveRanges.at(a);
    LiveRange *bLiveRange = liveRanges.at(b);
    if (aLiveRange != bLiveRange && !aLiveRange->overlaps(*bLiveRange)) {
      aLiveRange->unionWith(*bLiveRange);
      for (Value value : bLiveRange->values)
        liveRanges[value] = aLiveRange;
    }
  };

  // Merge the live ranges of new definitions with their tile operands.
  auto unifyDefinitionsWithOperands = [&](Value value) {
    auto armSMEOp = value.getDefiningOp<ArmSMETileOpInterface>();
    if (!armSMEOp)
      return;
    for (auto operand : armSMEOp->getOperands()) {
      if (isValidSMETileVectorType(operand.getType()))
        mergeValuesIfNonOverlapping(value, operand);
    }
  };

  // Merge the live ranges of block arguments with their predecessors.
  auto unifyBlockArgumentsWithPredecessors = [&](Value value) {
    auto blockArg = dyn_cast<BlockArgument>(value);
    if (!blockArg)
      return;
    forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
      mergeValuesIfNonOverlapping(blockArg, predecessorTile);
    });
  };

  auto applyRule = [&](auto rule) {
    llvm::for_each(llvm::make_first_range(initialLiveRanges), rule);
  };

  // Unify as many live ranges as we can. This prevents unnecessary moves.
  applyRule(unifyBlockArgumentsWithPredecessors);
  applyRule(unifyDefinitionsWithOperands);

  // Remove duplicate live range entries.
  SetVector<LiveRange *> uniqueLiveRanges;
  for (auto [_, liveRange] : liveRanges) {
    if (!liveRange->empty())
      uniqueLiveRanges.insert(liveRange);
  }

  // Sort the new live ranges by starting point (ready for tile allocation).
  auto coalescedLiveRanges = uniqueLiveRanges.takeVector();
  std::sort(coalescedLiveRanges.begin(), coalescedLiveRanges.end(),
            [](LiveRange *a, LiveRange *b) { return *a < *b; });
  return std::move(coalescedLiveRanges);
}
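// For example, given `%copy = arm_sme.copy_tile %tile` (as inserted by the
// preprocessing), `unifyDefinitionsWithOperands` merges the live ranges of
// `%copy` and `%tile` when they do not overlap, so both values receive the
// same tile ID and the copy can later be folded away (see
// `foldRedundantCopies` below).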
/// Choose a live range to spill (via some heuristics). This picks either a
/// live range from `overlappingRanges`, or the new live range `newRange`.
template <typename OverlappingRangesIterator>
LiveRange *
chooseSpillUsingHeuristics(OverlappingRangesIterator overlappingRanges,
                           LiveRange *newRange) {
  // Heuristic: Spill trivially copyable operations (usually free).
  auto isTrivialSpill = [&](LiveRange &allocatedRange) {
    return isTileTypeGreaterOrEqual(allocatedRange.getTileType(),
                                    newRange->getTileType()) &&
           allocatedRange.values.size() == 1 &&
           isTriviallyCloneableTileOp(
               allocatedRange.values[0]
                   .getDefiningOp<ArmSMETileOpInterface>());
  };
  if (isTrivialSpill(*newRange))
    return newRange;
  auto trivialSpill = llvm::find_if(overlappingRanges, isTrivialSpill);
  if (trivialSpill != overlappingRanges.end())
    return &*trivialSpill;

  // Heuristic: Spill the range that ends last (with a compatible tile type).
  auto isSmallerTileTypeOrEndsEarlier = [](LiveRange &a, LiveRange &b) {
    return !isTileTypeGreaterOrEqual(a.getTileType(), b.getTileType()) ||
           a.end() < b.end();
  };
  LiveRange &latestEndingLiveRange =
      *std::max_element(overlappingRanges.begin(), overlappingRanges.end(),
                        isSmallerTileTypeOrEndsEarlier);
  if (!isSmallerTileTypeOrEndsEarlier(latestEndingLiveRange, *newRange))
    return &latestEndingLiveRange;
  return newRange;
}
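// For example, a live range whose only value is the result of an op that can
// be re-materialized for free (plausibly an `arm_sme.zero`, assuming it
// counts as trivially cloneable) is spilled in preference to any other range;
// failing that, the compatible live range that ends last is chosen, as in
// classic linear scan.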
/// Greedily allocate tile IDs to live ranges. Spill using simple heuristics.
void allocateTilesToLiveRanges(
    ArrayRef<LiveRange *> liveRangesSortedByStartPoint) {
  TileAllocator tileAllocator;
  // `activeRanges` = Live ranges that need to be in a tile at the
  // `currentPoint` in the program.
  SetVector<LiveRange *> activeRanges;
  // `inactiveRanges` = Live ranges that _do not_ need to be in a tile at the
  // `currentPoint` in the program but could become active again later. An
  // inactive section of a live range can be seen as a 'hole' in the live
  // range, where it is possible to reuse the live range's tile ID _before_ it
  // has ended. By identifying 'holes', the allocator can reuse tiles more
  // often, which helps avoid costly tile spills.
  SetVector<LiveRange *> inactiveRanges;
  for (LiveRange *nextRange : liveRangesSortedByStartPoint) {
    auto currentPoint = nextRange->start();
    // 1. Update the `activeRanges` at `currentPoint`.
    activeRanges.remove_if([&](LiveRange *activeRange) {
      // Check for live ranges that have expired.
      if (activeRange->end() <= currentPoint) {
        tileAllocator.releaseTileId(activeRange->getTileType(),
                                    *activeRange->tileId);
        return true;
      }
      // Check for live ranges that have become inactive.
      if (!activeRange->overlaps(currentPoint)) {
        tileAllocator.releaseTileId(activeRange->getTileType(),
                                    *activeRange->tileId);
        inactiveRanges.insert(activeRange);
        return true;
      }
      return false;
    });
    // 2. Update the `inactiveRanges` at `currentPoint`.
    inactiveRanges.remove_if([&](LiveRange *inactiveRange) {
      // Check for live ranges that have expired.
      if (inactiveRange->end() <= currentPoint) {
        return true;
      }
      // Check for live ranges that have become active.
      if (inactiveRange->overlaps(currentPoint)) {
        tileAllocator.acquireTileId(inactiveRange->getTileType(),
                                    *inactiveRange->tileId);
        activeRanges.insert(inactiveRange);
        return true;
      }
      return false;
    });
    // 3. Collect inactive live ranges that overlap with the new live range.
    // Note: The overlap checks in steps 1 and 2 only look at the
    // `currentPoint` whereas this checks if there is an overlap at any future
    // point too.
    SmallVector<LiveRange *> overlappingInactiveRanges;
    for (LiveRange *inactiveRange : inactiveRanges) {
      if (inactiveRange->overlaps(*nextRange)) {
        // We need to reserve the tile IDs of overlapping inactive ranges to
        // prevent two (overlapping) live ranges from getting the same tile ID.
        tileAllocator.acquireTileId(inactiveRange->getTileType(),
                                    *inactiveRange->tileId);
        overlappingInactiveRanges.push_back(inactiveRange);
      }
    }
    // 4. Allocate a tile ID to `nextRange`.
    auto rangeTileType = nextRange->getTileType();
    auto tileId = tileAllocator.allocateTileId(rangeTileType);
    if (succeeded(tileId)) {
      nextRange->tileId = *tileId;
    } else {
      // Create an iterator over all overlapping live ranges.
      auto allOverlappingRanges = llvm::concat<LiveRange>(
          llvm::make_pointee_range(activeRanges.getArrayRef()),
          llvm::make_pointee_range(overlappingInactiveRanges));
      // Choose an overlapping live range to spill.
      LiveRange *rangeToSpill =
          chooseSpillUsingHeuristics(allOverlappingRanges, nextRange);
      if (rangeToSpill != nextRange) {
        // Spill an (in)active live range (so release its tile ID first).
        tileAllocator.releaseTileId(rangeToSpill->getTileType(),
                                    *rangeToSpill->tileId);
        // This will always succeed after a spill (of an active live range).
        nextRange->tileId = *tileAllocator.allocateTileId(rangeTileType);
        // Remove the live range from the active/inactive sets.
        if (!activeRanges.remove(rangeToSpill)) {
          bool removed = inactiveRanges.remove(rangeToSpill);
          assert(removed && "expected a range to be removed!");
          (void)removed;
        }
      }
      rangeToSpill->tileId = tileAllocator.allocateInMemoryTileId();
    }
    // 5. Insert the live range into the active ranges.
    if (nextRange->tileId < kInMemoryTileIdBase)
      activeRanges.insert(nextRange);
    // 6. Release tiles reserved for inactive live ranges (in step 3).
    for (LiveRange *range : overlappingInactiveRanges) {
      if (*range->tileId < kInMemoryTileIdBase)
        tileAllocator.releaseTileId(range->getTileType(), *range->tileId);
    }
  }
}
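// As an illustration of the 'holes' above: suppose live range A covers
// [1, 4) and [8, 10), and live range B starts at point 5. When B is
// allocated, A is inactive (point 5 falls in A's hole), and if B never
// overlaps A at a later point, B may reuse A's tile ID. If B instead extended
// into [8, 10), step 3 would reserve A's tile ID while allocating B, forcing
// B onto a different tile.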
/// Assigns a tile ID to an MLIR value.
void assignTileIdToValue(IRRewriter &rewriter, Value value,
                         IntegerAttr tileIdAttr) {
  if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>())
    rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
  for (Operation *user : value.getUsers()) {
    if (auto tileOp = dyn_cast<ArmSMETileOpInterface>(user)) {
      // Ensure ArmSME ops that don't produce a value still get a tile ID.
      if (!hasTileResult(tileOp))
        rewriter.modifyOpInPlace(tileOp,
                                 [&] { tileOp.setTileId(tileIdAttr); });
    }
  }
}

/// Assign tile IDs back to IR and attempt to resolve trivial tile ID
/// conflicts.
LogicalResult assignTileIdsAndResolveTrivialConflicts(
    IRRewriter &rewriter, FunctionOpInterface function,
    ArrayRef<LiveRange *> allocatedLiveRanges) {
  for (LiveRange const *liveRange : allocatedLiveRanges) {
    auto tileIdAttr = rewriter.getI32IntegerAttr(*liveRange->tileId);
    auto isAllocatedToSameTile = [&](Value value) {
      if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
          tileOp && tileOp.getTileId() == tileIdAttr)
        return true;
      return liveRange->values.contains(value);
    };

    /// Eliminates copies where the operand has the same tile ID.
    auto foldRedundantCopies = [&](Value value) -> LogicalResult {
      auto copyOp = value.getDefiningOp<CopyTileOp>();
      if (!copyOp || !isAllocatedToSameTile(copyOp.getTile()))
        return failure();
      rewriter.replaceAllUsesWith(copyOp, copyOp.getTile());
      return success();
    };

    /// Validates each predecessor to a tile block argument has been assigned
    /// the same tile ID.
    auto validateBlockArguments = [&](Value value) {
      auto blockArg = dyn_cast<BlockArgument>(value);
      if (!blockArg) {
        // Not a block argument (nothing to validate).
        return success();
      }
      bool tileMismatch = false;
      forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
        if (tileMismatch)
          return;
        if (!isAllocatedToSameTile(predecessorTile)) {
          blockArg.getOwner()->getParentOp()->emitOpError(
              "block argument not allocated to the same SME virtual tile as "
              "predecessors");
          tileMismatch = true;
        }
      });
      return success(/*isSuccess=*/!tileMismatch);
    };

    /// Attempts to resolve (trivial) tile ID conflicts.
    auto resolveTrivialTileConflicts = [&](Value value) -> LogicalResult {
      auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
      OpOperand *tileOperand = getTileOpOperand(tileOp);
      if (!tileOperand || isAllocatedToSameTile(tileOperand->get())) {
        // Operand already allocated to the correct tile.
        // No conflict to resolve.
        return success();
      }
      auto operandTileOp =
          tileOperand->get().getDefiningOp<ArmSMETileOpInterface>();
      if (!isTriviallyCloneableTileOp(operandTileOp)) {
        auto error =
            tileOp.emitOpError("tile operand allocated to different SME "
                               "virtual tile (move required)");
        error.attachNote(tileOperand->get().getLoc())
            << "tile operand is: " << tileOperand->get();
        return error;
      }
      // Cloning prevents a move/spill (though may require recomputation).
      rewriter.setInsertionPoint(tileOp);
      auto clonedOp = operandTileOp.clone();
      rewriter.modifyOpInPlace(
          clonedOp, [&] { clonedOp.setTileId(tileOp.getTileId()); });
      rewriter.insert(clonedOp);
      if (isa<CopyTileOp>(tileOp)) {
        rewriter.replaceAllUsesWith(tileOp->getResult(0),
                                    clonedOp->getResult(0));
      } else {
        rewriter.modifyOpInPlace(
            tileOp, [&] { tileOperand->assign(clonedOp->getResult(0)); });
      }
      return success();
    };

    for (Value value : liveRange->values) {
      // 1. Assign the tile ID to the value.
      assignTileIdToValue(rewriter, value, tileIdAttr);

      // 2. Attempt to eliminate redundant tile copies.
      if (succeeded(foldRedundantCopies(value)))
        continue;

      // 3. Validate tile block arguments.
      if (failed(validateBlockArguments(value)))
        return failure();

      // 4. Attempt to resolve (trivial) tile ID conflicts.
      if (failed(resolveTrivialTileConflicts(value)))
        return failure();
    }
  }
  return success();
}

/// Prints live ranges alongside operation names for debugging.
void dumpLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
                    ArrayRef<LiveRange const *> liveRanges,
                    FunctionOpInterface function) {
  llvm::errs() << "SME Tile Liveness: @" << function.getName()
               << "\nKey:\nS - Start\nE - End\n| - Live\n";
  for (auto [blockIdx, block] : llvm::enumerate(function.getBlocks())) {
    llvm::errs() << "^bb" << blockIdx << ":\n";
    for (Operation &op : block.getOperations()) {
      unsigned operationIndex = operationToIndexMap.at(&op);
      for (LiveRange const *range : liveRanges) {
        char liveness = ' ';
        for (auto it = range->ranges->begin(); it != range->ranges->end();
             ++it) {
          if (it.start() == operationIndex)
            liveness = (liveness == 'E' ? '|' : 'S');
          else if (it.stop() == operationIndex)
            liveness = (liveness == 'S' ? '|' : 'E');
          else if (operationIndex >= it.start() && operationIndex < it.stop())
            liveness = '|';
        }
        llvm::errs() << liveness;
      }
      llvm::errs() << ' ' << op.getName() << '\n';
    }
  }
  llvm::errs() << "==========\n";
}

struct TestTileAllocationPass
    : public arm_sme::impl::TestTileAllocationBase<TestTileAllocationPass> {
  using TestTileAllocationBase::TestTileAllocationBase;
  void runOnOperation() override {
    FunctionOpInterface function = getOperation();
    if (preprocessOnly) {
      IRRewriter rewriter(function);
      return preprocessForTileAllocation(rewriter, function);
    }
    if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
      signalPassFailure();
  }
};
} // namespace

LogicalResult mlir::arm_sme::allocateSMETiles(FunctionOpInterface function,
                                              bool dumpRanges) {
  if (function.empty()) {
    // TODO: Also return early if the function contains no ArmSME ops?
    return success();
  }

  LiveRange::Allocator liveRangeAllocator;
  IRRewriter rewriter(function.getContext());

  // 1. Preprocess the IR for tile allocation.
  preprocessForTileAllocation(rewriter, function);

  // 2. Gather live ranges for each ArmSME tile within the function.
  Liveness liveness(function);
  auto operationToIndexMap = generateOperationNumbering(function);
  auto initialLiveRanges = gatherTileLiveRanges(
      operationToIndexMap, liveRangeAllocator, liveness, function);
  if (initialLiveRanges.empty())
    return success();

  if (dumpRanges) {
    // Wrangle initial live ranges into a form suitable for printing.
    auto nonEmpty = llvm::make_filter_range(
        llvm::make_second_range(initialLiveRanges),
        [&](LiveRange const &liveRange) { return !liveRange.empty(); });
    auto initialRanges = llvm::to_vector(llvm::map_range(
        nonEmpty, [](LiveRange const &liveRange) { return &liveRange; }));
    std::sort(initialRanges.begin(), initialRanges.end(),
              [](LiveRange const *a, LiveRange const *b) { return *a < *b; });
    llvm::errs() << "\n========== Initial Live Ranges:\n";
    dumpLiveRanges(operationToIndexMap, initialRanges, function);
  }

  // 3. Coalesce (non-overlapping) live ranges where it would be beneficial
  // for tile allocation. E.g. Unify the result of an operation with its
  // operands.
  auto coalescedLiveRanges = coalesceTileLiveRanges(initialLiveRanges);

  if (dumpRanges) {
    llvm::errs() << "\n========== Coalesced Live Ranges:\n";
    dumpLiveRanges(operationToIndexMap, coalescedLiveRanges, function);
  }

  // 4. Allocate tile IDs to live ranges.
  allocateTilesToLiveRanges(coalescedLiveRanges);

  // 5. Assign the tile IDs back to the ArmSME operations.
  if (failed(assignTileIdsAndResolveTrivialConflicts(rewriter, function,
                                                     coalescedLiveRanges))) {
    return failure();
  }

  // 6. Erase trivially dead tile operations (e.g. a ZeroOp with no users).
  // This prevents the LLVM conversion needlessly inserting spills.
  eraseTriviallyDeadTileOps(rewriter, function);
  return success();
}