//===- TileAllocation.cpp - Allocate SME ZA tiles -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transform allocates SME tiles at the 'func.func' op level for ArmSME
// operations. It roughly implements a linear scan register allocator, similar
// to the one outlined in [1], but with simplifications and assumptions made
// for our use case. Note that this is a greedy allocator (so it may not
// always find an optimal allocation of tiles).
//
// The allocator operates at the CF dialect level. It is the responsibility of
// users to ensure the IR has been lowered to CF before invoking the tile
// allocator.
//
// The 128-bit tiles overlap with other element tiles as follows (see section
// B2.3.2 of SME spec [2]):
//
//   Tile    Overlaps
//   -------------------------------------------------------------------------
//   ZA0.B   ZA0.Q, ZA1.Q, ZA2.Q, ZA3.Q, ZA4.Q, ZA5.Q, ZA6.Q, ZA7.Q, ZA8.Q,
//           ZA9.Q, ZA10.Q, ZA11.Q, ZA12.Q, ZA13.Q, ZA14.Q, ZA15.Q
//   ZA0.H   ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
//   ZA1.H   ZA1.Q, ZA3.Q, ZA5.Q, ZA7.Q, ZA9.Q, ZA11.Q, ZA13.Q, ZA15.Q
//   ZA0.S   ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q
//   ZA1.S   ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
//   ZA2.S   ZA2.Q, ZA6.Q, ZA10.Q, ZA14.Q
//   ZA3.S   ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q
//   ZA0.D   ZA0.Q, ZA8.Q
//   ZA1.D   ZA1.Q, ZA9.Q
//   ZA2.D   ZA2.Q, ZA10.Q
//   ZA3.D   ZA3.Q, ZA11.Q
//   ZA4.D   ZA4.Q, ZA12.Q
//   ZA5.D   ZA5.Q, ZA13.Q
//   ZA6.D   ZA6.Q, ZA14.Q
//   ZA7.D   ZA7.Q, ZA15.Q
//
// [1] "Linear Scan Register Allocation in the Context of SSA Form and Register
//     Constraints" (Hanspeter Mössenböck and Michael Pfeiffer)
//     https://link.springer.com/content/pdf/10.1007/3-540-45937-5_17.pdf
// [2] https://developer.arm.com/documentation/ddi0616/aa
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/Liveness.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/TypeSwitch.h"
#include <algorithm>

namespace mlir::arm_sme {
#define GEN_PASS_DEF_TESTTILEALLOCATION
#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
} // namespace mlir::arm_sme

using namespace mlir;
using namespace mlir::arm_sme;

namespace {
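// Each tile mask below encodes which 128-bit ZA*.Q tiles a given tile
// occupies as a 16-bit bitmask (bit 15 = ZA0.Q ... bit 0 = ZA15.Q), so the
// overlap table above reduces to bitwise tests. A worked example (derived
// from the mask values below): allocating ZA0.S marks mask 0x8888 (ZA0.Q,
// ZA4.Q, ZA8.Q, ZA12.Q) as in use. A subsequent request for ZA0.H (mask
// 0xaaaa) must then fail, as the two tiles overlap:
//
//   0x8888 & 0xaaaa = 0x8888 != 0 (common tiles: ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q)
//
// whereas ZA1.H (mask 0x5555) is still free: 0x8888 & 0x5555 = 0.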
enum class TileMask : unsigned {
  // clang-format off
  kZA0B  = 0xffff, // 1111 1111 1111 1111

  kZA0H  = 0xaaaa, // 1010 1010 1010 1010
  kZA1H  = 0x5555, // 0101 0101 0101 0101

  kZA0S  = 0x8888, // 1000 1000 1000 1000
  kZA1S  = 0x4444, // 0100 0100 0100 0100
  kZA2S  = 0x2222, // 0010 0010 0010 0010
  kZA3S  = 0x1111, // 0001 0001 0001 0001

  kZA0D  = 0x8080, // 1000 0000 1000 0000
  kZA1D  = 0x4040, // 0100 0000 0100 0000
  kZA2D  = 0x2020, // 0010 0000 0010 0000
  kZA3D  = 0x1010, // 0001 0000 0001 0000
  kZA4D  = 0x808,  // 0000 1000 0000 1000
  kZA5D  = 0x404,  // 0000 0100 0000 0100
  kZA6D  = 0x202,  // 0000 0010 0000 0010
  kZA7D  = 0x101,  // 0000 0001 0000 0001

  kZA0Q  = 0x8000, // 1000 0000 0000 0000
  kZA1Q  = 0x4000, // 0100 0000 0000 0000
  kZA2Q  = 0x2000, // 0010 0000 0000 0000
  kZA3Q  = 0x1000, // 0001 0000 0000 0000
  kZA4Q  = 0x800,  // 0000 1000 0000 0000
  kZA5Q  = 0x400,  // 0000 0100 0000 0000
  kZA6Q  = 0x200,  // 0000 0010 0000 0000
  kZA7Q  = 0x100,  // 0000 0001 0000 0000
  kZA8Q  = 0x80,   // 0000 0000 1000 0000
  kZA9Q  = 0x40,   // 0000 0000 0100 0000
  kZA10Q = 0x20,   // 0000 0000 0010 0000
  kZA11Q = 0x10,   // 0000 0000 0001 0000
  kZA12Q = 0x8,    // 0000 0000 0000 1000
  kZA13Q = 0x4,    // 0000 0000 0000 0100
  kZA14Q = 0x2,    // 0000 0000 0000 0010
  kZA15Q = 0x1,    // 0000 0000 0000 0001

  kNone  = 0x0,    // 0000 0000 0000 0000
  // clang-format on

  LLVM_MARK_AS_BITMASK_ENUM(kZA0B)
};

/// Returns the set of masks relevant for the given type.
static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
  static constexpr std::array ZA_B_MASKS = {TileMask::kZA0B};
  static constexpr std::array ZA_H_MASKS = {TileMask::kZA0H, TileMask::kZA1H};
  static constexpr std::array ZA_S_MASKS = {TileMask::kZA0S, TileMask::kZA1S,
                                            TileMask::kZA2S, TileMask::kZA3S};
  static constexpr std::array ZA_D_MASKS = {
      TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D,
      TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D};
  static constexpr std::array ZA_Q_MASKS = {
      TileMask::kZA0Q,  TileMask::kZA1Q,  TileMask::kZA2Q,  TileMask::kZA3Q,
      TileMask::kZA4Q,  TileMask::kZA5Q,  TileMask::kZA6Q,  TileMask::kZA7Q,
      TileMask::kZA8Q,  TileMask::kZA9Q,  TileMask::kZA10Q, TileMask::kZA11Q,
      TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q};
  switch (type) {
  case ArmSMETileType::ZAB:
    return ZA_B_MASKS;
  case ArmSMETileType::ZAH:
    return ZA_H_MASKS;
  case ArmSMETileType::ZAS:
    return ZA_S_MASKS;
  case ArmSMETileType::ZAD:
    return ZA_D_MASKS;
  case ArmSMETileType::ZAQ:
    return ZA_Q_MASKS;
  }
  llvm_unreachable("unknown type in getMasks");
}

class TileAllocator {
public:
  /// Allocates and returns a tile ID. Fails if there are no tiles left.
  FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) {
    auto masks = getMasks(tileType);
    for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
      if ((tilesInUse & tileMask) == TileMask::kNone) {
        tilesInUse |= tileMask;
        return tileId;
      }
    }
    return failure();
  }

  /// Acquires a specific tile ID. Asserts the tile is initially free.
  void acquireTileId(ArmSMETileType tileType, unsigned tileId) {
    TileMask tileMask = getMasks(tileType)[tileId];
    assert((tilesInUse & tileMask) == TileMask::kNone &&
           "cannot acquire allocated tile!");
    tilesInUse |= tileMask;
  }

  /// Releases a previously allocated tile ID.
  void releaseTileId(ArmSMETileType tileType, unsigned tileId) {
    TileMask tileMask = getMasks(tileType)[tileId];
    assert((tilesInUse & tileMask) == tileMask &&
           "cannot release unallocated tile!");
    tilesInUse ^= tileMask;
  }

  /// Allocates an in-memory tile ID.
  unsigned allocateInMemoryTileId() {
    // Note: We never release in-memory tile IDs. We could, which may allow
    // reusing an allocation, but as we _never_ want to spill an SME tile this
    // is not optimized.
    return nextInMemoryTileId++;
  }

private:
  TileMask tilesInUse = TileMask::kNone;
  unsigned nextInMemoryTileId = kInMemoryTileIdBase;
};
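// An illustrative walk-through of the allocator for 32-bit tiles: the first
// four calls to `allocateTileId(ArmSMETileType::ZAS)` succeed, returning tile
// IDs 0 to 3 (ZA0.S to ZA3.S). With all four tiles in use, a fifth call
// fails, and the caller must either release a tile first or fall back to
// `allocateInMemoryTileId()`, which returns an ID >= kInMemoryTileIdBase
// (an in-memory tile, i.e. a spill).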
/// Add new intermediate blocks for the true and false destinations of
/// `cf.cond_br`s that contain tile operands. This prevents spurious liveness
/// overlaps due to copies at branches.
///
/// BEFORE:
/// ```mlir
/// cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2
/// ```
///
/// AFTER:
/// ```mlir
/// cf.cond_br %cond, ^bb1_copy, ^bb2_copy
/// ^bb1_copy:
///   cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
/// ^bb2_copy:
///   cf.br ^bb2
/// ```
void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
  SmallVector<cf::CondBranchOp> worklist;
  function.walk([&](cf::CondBranchOp condBranch) {
    if (llvm::any_of(condBranch->getOperands(), [&](Value value) {
          return isValidSMETileVectorType(value.getType());
        })) {
      worklist.push_back(condBranch);
    }
  });

  auto insertJump = [&](Location loc, Block *source, Block *dest, auto args) {
    rewriter.setInsertionPointToEnd(source);
    rewriter.create<cf::BranchOp>(loc, dest, args);
  };

  for (auto condBranch : worklist) {
    auto loc = condBranch.getLoc();
    Block *block = condBranch->getBlock();
    auto newTrueBranch = rewriter.splitBlock(block, block->end());
    auto newFalseBranch = rewriter.splitBlock(block, block->end());
    insertJump(loc, newTrueBranch, condBranch.getTrueDest(),
               condBranch.getTrueDestOperands());
    insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
               condBranch.getFalseDestOperands());
    rewriter.modifyOpInPlace(condBranch, [&] {
      condBranch.getFalseDestOperandsMutable().clear();
      condBranch.getTrueDestOperandsMutable().clear();
      condBranch.setSuccessor(newTrueBranch, 0);
      condBranch.setSuccessor(newFalseBranch, 1);
    });
  }
}

/// Inserts tile copies at `cf.br` operations.
///
/// BEFORE:
/// ```mlir
/// cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
/// ```
///
/// AFTER:
/// ```mlir
/// %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
/// cf.br ^bb1(%copy: vector<[4]x[4]xf32>)
/// ```
void insertCopiesAtBranches(IRRewriter &rewriter,
                            FunctionOpInterface function) {
  for (Block &block : function.getBlocks()) {
    Operation *terminator = block.getTerminator();
    if (!isa<cf::BranchOp>(terminator))
      continue;
    rewriter.setInsertionPoint(terminator);
    for (OpOperand &operand : terminator->getOpOperands()) {
      if (isValidSMETileVectorType(operand.get().getType())) {
        auto copy =
            rewriter.create<CopyTileOp>(terminator->getLoc(), operand.get());
        rewriter.modifyOpInPlace(terminator, [&] { operand.assign(copy); });
      }
    }
  }
}

/// Prepares the IR for tile allocation. It does this by first 'splitting'
/// conditional branches (see `splitCondBranches`), then inserting tile copies
/// at branch operations. The conditional branches are split to prevent the
/// copies needed for them overlapping between the true and false paths of the
/// branch (see `tile-allocation-copies.mlir` and
/// `tile-allocation-liveness.mlir` for examples). The copies break up live
/// ranges and ensure that, when moving out of SSA, the semantics of the
/// program are preserved.
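///
/// Combining the two rewrites above, a `cf.cond_br` with a tile operand
/// becomes (as a sketch):
///
/// ```mlir
/// cf.cond_br %cond, ^bb1_copy, ^bb2_copy
/// ^bb1_copy:
///   %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
///   cf.br ^bb1(%copy : vector<[4]x[4]xf32>)
/// ^bb2_copy:
///   cf.br ^bb2
/// ```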
void preprocessForTileAllocation(IRRewriter &rewriter,
                                 FunctionOpInterface function) {
  splitCondBranches(rewriter, function);
  insertCopiesAtBranches(rewriter, function);
}

/// A live range for a (collection of) tile values. A live range is built up
/// of non-overlapping intervals [start, end) which represent parts of the
/// program where a value in the range needs to be live (i.e. in an SME
/// virtual tile). Note that as the intervals are non-overlapping, all values
/// within a live range can be allocated to the same SME virtual tile.
struct LiveRange {
  using RangeSet = llvm::IntervalMap<uint64_t, uint8_t, 16,
                                     llvm::IntervalMapHalfOpenInfo<unsigned>>;
  using Allocator = RangeSet::Allocator;
  // Dummy value for the IntervalMap. Only the keys matter (the intervals).
  static constexpr uint8_t kValidLiveRange = 0xff;

  LiveRange(Allocator &allocator)
      : ranges(std::make_unique<RangeSet>(allocator)) {}

  /// Returns true if this range overlaps with `otherRange`.
  bool overlaps(LiveRange const &otherRange) const {
    return llvm::IntervalMapOverlaps<RangeSet, RangeSet>(*ranges,
                                                         *otherRange.ranges)
        .valid();
  }

  /// Returns true if this range is active at `point` in the program.
  bool overlaps(uint64_t point) const {
    return ranges->lookup(point) == kValidLiveRange;
  }

  /// Unions this live range with `otherRange`, aborts if the ranges overlap.
  void unionWith(LiveRange const &otherRange) {
    for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
         ++it)
      ranges->insert(it.start(), it.stop(), kValidLiveRange);
    values.set_union(otherRange.values);
  }

  /// Inserts an interval [start, end) for `value` into this range.
  void insert(Value value, unsigned start, unsigned end) {
    values.insert(value);
    if (start != end)
      ranges->insert(start, end, kValidLiveRange);
  }

  bool empty() const { return ranges->empty(); }
  unsigned start() const { return ranges->start(); }
  unsigned end() const { return ranges->stop(); }
  bool operator<(LiveRange const &other) const {
    return start() < other.start();
  }

  ArmSMETileType getTileType() const {
    return *getSMETileType(cast<VectorType>(values[0].getType()));
  }

  /// The values contained in this live range.
  SetVector<Value> values;

  /// A set of (non-overlapping) intervals that mark where any value in
  /// `values` is live.
  std::unique_ptr<RangeSet> ranges;

  /// The tile ID (or none) assigned to this live range.
  std::optional<unsigned> tileId;
};

/// Number operations within a function to allow computing live ranges.
/// Operations are numbered consecutively within blocks, and the blocks are
/// topologically sorted (using forward edges). This function is only correct
/// once the IR has been lowered to CF, so that no ArmSME ops appear in nested
/// regions (which is asserted).
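///
/// An illustrative numbering (the first number of each block is reserved for
/// its block arguments):
///
/// ```
/// ^bb0:          // block arguments: 0
///   <op>         // 1
///   cf.br ^bb1   // 2
/// ^bb1:          // block arguments: 3
///   <op>         // 4
/// ```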
DenseMap<Operation *, unsigned>
generateOperationNumbering(FunctionOpInterface function) {
  unsigned index = 0;
  SetVector<Block *> blocks =
      getBlocksSortedByDominance(function.getFunctionBody());
  DenseMap<Operation *, unsigned> operationToIndexMap;
  for (Block *block : blocks) {
    index++; // We want block args to have their own number.
    for (Operation &op : block->getOperations()) {
#ifndef NDEBUG
      op.walk([&](ArmSMETileOpInterface nestedOp) {
        assert(&op == nestedOp.getOperation() &&
               "ArmSME tile allocation does not support nested regions");
      });
#endif
      operationToIndexMap.try_emplace(&op, index++);
    }
  }
  return operationToIndexMap;
}

/// Gather live ranges for SME tiles from the MLIR liveness analysis.
DenseMap<Value, LiveRange>
gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
                     LiveRange::Allocator &liveRangeAllocator,
                     Liveness &liveness, FunctionOpInterface function) {
  assert(!operationToIndexMap.empty() && "expected operation numbering");
  DenseMap<Value, LiveRange> liveRanges;
  /// Defines or updates a live range for an SME tile value. Live-ins may
  /// update an existing live range (rather than define a new one). Note: If
  /// `liveAtBlockEntry` is true then `firstUseOrDef` is the first operation
  /// in the block.
  auto defineOrUpdateValueLiveRange = [&](Value value, Operation *firstUseOrDef,
                                          LivenessBlockInfo const &livenessInfo,
                                          bool liveAtBlockEntry = false) {
    if (!isValidSMETileVectorType(value.getType()))
      return;
    // Find or create a live range for `value`.
    auto [it, _] = liveRanges.try_emplace(value, liveRangeAllocator);
    LiveRange &valueLiveRange = it->second;
    auto lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
    // Add the interval [firstUseOrDef, lastUseInBlock) to the live range.
    unsigned startOpIdx =
        operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0);
    unsigned endOpIdx = operationToIndexMap.at(lastUseInBlock);
    valueLiveRange.insert(value, startOpIdx, endOpIdx);
  };

  for (Block &block : function.getBlocks()) {
    LivenessBlockInfo const *livenessInfo = liveness.getLiveness(&block);
    // Handle block arguments:
    for (Value argument : block.getArguments())
      defineOrUpdateValueLiveRange(argument, &block.front(), *livenessInfo,
                                   /*liveAtBlockEntry=*/true);
    // Handle live-ins:
    for (Value liveIn : livenessInfo->in())
      defineOrUpdateValueLiveRange(liveIn, &block.front(), *livenessInfo,
                                   /*liveAtBlockEntry=*/true);
    // Handle new definitions:
    for (Operation &op : block) {
      for (Value result : op.getResults())
        defineOrUpdateValueLiveRange(result, &op, *livenessInfo);
    }
  }

  return liveRanges;
}

/// Iterate over all predecessor tile values to a (tile) block argument.
static void forEachPredecessorTileValue(BlockArgument blockArg,
                                        function_ref<void(Value)> callback) {
  Block *block = blockArg.getOwner();
  unsigned argNumber = blockArg.getArgNumber();
  for (Block *pred : block->getPredecessors()) {
    TypeSwitch<Operation *>(pred->getTerminator())
        .Case<cf::BranchOp>([&](auto branch) {
          Value predecessorOperand = branch.getDestOperands()[argNumber];
          callback(predecessorOperand);
        })
        .Case<cf::CondBranchOp>([&](auto condBranch) {
          if (condBranch.getFalseDest() == block) {
            Value predecessorOperand =
                condBranch.getFalseDestOperands()[argNumber];
            callback(predecessorOperand);
          }
          if (condBranch.getTrueDest() == block) {
            Value predecessorOperand =
                condBranch.getTrueDestOperands()[argNumber];
            callback(predecessorOperand);
          }
        });
  }
}

/// Coalesce live ranges where it would prevent unnecessary tile moves.
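///
/// For example (a sketch), given the (preprocessed) IR:
///
/// ```mlir
/// %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
/// %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
/// cf.br ^bb1(%copy : vector<[4]x[4]xf32>)
/// ```
///
/// if the live ranges of %tile, %copy, and the block argument of ^bb1 do not
/// overlap, they are merged into one live range that is allocated a single
/// tile ID, which makes the copy redundant (it is folded away later, when
/// tile IDs are assigned back to the IR).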
SmallVector<LiveRange *>
coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
  DenseMap<Value, LiveRange *> liveRanges;
  for (auto &[value, liveRange] : initialLiveRanges) {
    liveRanges.insert({value, &liveRange});
  }

  // Merge the live ranges of values `a` and `b` into one (if they do not
  // overlap). After this, the values `a` and `b` will both point to the same
  // live range (which will contain multiple values).
  auto mergeValuesIfNonOverlapping = [&](Value a, Value b) {
    LiveRange *aLiveRange = liveRanges.at(a);
    LiveRange *bLiveRange = liveRanges.at(b);
    if (aLiveRange != bLiveRange && !aLiveRange->overlaps(*bLiveRange)) {
      aLiveRange->unionWith(*bLiveRange);
      for (Value value : bLiveRange->values)
        liveRanges[value] = aLiveRange;
    }
  };

  // Merge the live ranges of new definitions with their tile operands.
  auto unifyDefinitionsWithOperands = [&](Value value) {
    auto armSMEOp = value.getDefiningOp<ArmSMETileOpInterface>();
    if (!armSMEOp)
      return;
    for (auto operand : armSMEOp->getOperands()) {
      if (isValidSMETileVectorType(operand.getType()))
        mergeValuesIfNonOverlapping(value, operand);
    }
  };

  // Merge the live ranges of block arguments with their predecessors.
  auto unifyBlockArgumentsWithPredecessors = [&](Value value) {
    auto blockArg = dyn_cast<BlockArgument>(value);
    if (!blockArg)
      return;
    forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
      mergeValuesIfNonOverlapping(blockArg, predecessorTile);
    });
  };

  auto applyRule = [&](auto rule) {
    llvm::for_each(llvm::make_first_range(initialLiveRanges), rule);
  };

  // Unify as many live ranges as we can. This prevents unnecessary moves.
  applyRule(unifyBlockArgumentsWithPredecessors);
  applyRule(unifyDefinitionsWithOperands);

  // Remove duplicate live range entries.
  SetVector<LiveRange *> uniqueLiveRanges;
  for (auto [_, liveRange] : liveRanges) {
    if (!liveRange->empty())
      uniqueLiveRanges.insert(liveRange);
  }

  // Sort the new live ranges by starting point (ready for tile allocation).
  auto coalescedLiveRanges = uniqueLiveRanges.takeVector();
  std::sort(coalescedLiveRanges.begin(), coalescedLiveRanges.end(),
            [](LiveRange *a, LiveRange *b) { return *a < *b; });
  return std::move(coalescedLiveRanges);
}

/// Choose a live range to spill (via some heuristics). This picks either a
/// live range from `overlappingRanges`, or the new live range `newRange`.
template <typename OverlappingRangesIterator>
LiveRange *
chooseSpillUsingHeuristics(OverlappingRangesIterator overlappingRanges,
                           LiveRange *newRange) {
  // Heuristic: Spill trivially copyable operations (usually free).
  auto isTrivialSpill = [&](LiveRange &allocatedRange) {
    return isTileTypeGreaterOrEqual(allocatedRange.getTileType(),
                                    newRange->getTileType()) &&
           allocatedRange.values.size() == 1 &&
           isTriviallyCloneableTileOp(
               allocatedRange.values[0]
                   .getDefiningOp<ArmSMETileOpInterface>());
  };
  if (isTrivialSpill(*newRange))
    return newRange;
  auto trivialSpill = llvm::find_if(overlappingRanges, isTrivialSpill);
  if (trivialSpill != overlappingRanges.end())
    return &*trivialSpill;

  // Heuristic: Spill the range that ends last (with a compatible tile type).
  auto isSmallerTileTypeOrEndsEarlier = [](LiveRange &a, LiveRange &b) {
    return !isTileTypeGreaterOrEqual(a.getTileType(), b.getTileType()) ||
           a.end() < b.end();
  };
  LiveRange &latestEndingLiveRange =
      *std::max_element(overlappingRanges.begin(), overlappingRanges.end(),
                        isSmallerTileTypeOrEndsEarlier);
  if (!isSmallerTileTypeOrEndsEarlier(latestEndingLiveRange, *newRange))
    return &latestEndingLiveRange;
  return newRange;
}

/// Greedily allocate tile IDs to live ranges. Spill using simple heuristics.
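///
/// A conceptual sketch of how 'holes' in live ranges allow tile reuse:
///
///   Range A: [0, 4), [8, 10)  <- inactive (a hole) over [4, 8)
///   Range B: [4, 6)
///
/// When B starts at point 4, A no longer needs its tile, so A's tile ID is
/// released (A becomes inactive) and can be reused for B. A's tile ID is only
/// re-acquired where A becomes active again (B has expired by then), so no
/// spill is needed.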
void allocateTilesToLiveRanges(
    ArrayRef<LiveRange *> liveRangesSortedByStartPoint) {
  TileAllocator tileAllocator;
  // `activeRanges` = Live ranges that need to be in a tile at the
  // `currentPoint` in the program.
  SetVector<LiveRange *> activeRanges;
  // `inactiveRanges` = Live ranges that _do not_ need to be in a tile
  // at the `currentPoint` in the program but could become active again later.
  // An inactive section of a live range can be seen as a 'hole' in the live
  // range, where it is possible to reuse the live range's tile ID _before_ it
  // has ended. By identifying 'holes', the allocator can reuse tiles more
  // often, which helps avoid costly tile spills.
  SetVector<LiveRange *> inactiveRanges;
  for (LiveRange *nextRange : liveRangesSortedByStartPoint) {
    auto currentPoint = nextRange->start();
    // 1. Update the `activeRanges` at `currentPoint`.
    activeRanges.remove_if([&](LiveRange *activeRange) {
      // Check for live ranges that have expired.
      if (activeRange->end() <= currentPoint) {
        tileAllocator.releaseTileId(activeRange->getTileType(),
                                    *activeRange->tileId);
        return true;
      }
      // Check for live ranges that have become inactive.
      if (!activeRange->overlaps(currentPoint)) {
        tileAllocator.releaseTileId(activeRange->getTileType(),
                                    *activeRange->tileId);
        inactiveRanges.insert(activeRange);
        return true;
      }
      return false;
    });
    // 2. Update the `inactiveRanges` at `currentPoint`.
    inactiveRanges.remove_if([&](LiveRange *inactiveRange) {
      // Check for live ranges that have expired.
      if (inactiveRange->end() <= currentPoint) {
        return true;
      }
      // Check for live ranges that have become active.
      if (inactiveRange->overlaps(currentPoint)) {
        tileAllocator.acquireTileId(inactiveRange->getTileType(),
                                    *inactiveRange->tileId);
        activeRanges.insert(inactiveRange);
        return true;
      }
      return false;
    });

    // 3. Collect inactive live ranges that overlap with the new live range.
    // Note: The overlap checks in steps 1 and 2 only look at the
    // `currentPoint`, whereas this checks if there is an overlap at any
    // future point too.
    SmallVector<LiveRange *> overlappingInactiveRanges;
    for (LiveRange *inactiveRange : inactiveRanges) {
      if (inactiveRange->overlaps(*nextRange)) {
        // We need to reserve the tile IDs of overlapping inactive ranges to
        // prevent two (overlapping) live ranges from getting the same tile
        // ID.
        tileAllocator.acquireTileId(inactiveRange->getTileType(),
                                    *inactiveRange->tileId);
        overlappingInactiveRanges.push_back(inactiveRange);
      }
    }

    // 4. Allocate a tile ID to `nextRange`.
    auto rangeTileType = nextRange->getTileType();
    auto tileId = tileAllocator.allocateTileId(rangeTileType);
    if (succeeded(tileId)) {
      nextRange->tileId = *tileId;
    } else {
      // Create an iterator over all overlapping live ranges.
      auto allOverlappingRanges = llvm::concat<LiveRange>(
          llvm::make_pointee_range(activeRanges.getArrayRef()),
          llvm::make_pointee_range(overlappingInactiveRanges));
      // Choose an overlapping live range to spill.
      LiveRange *rangeToSpill =
          chooseSpillUsingHeuristics(allOverlappingRanges, nextRange);
      if (rangeToSpill != nextRange) {
        // Spill an (in)active live range (so release its tile ID first).
        tileAllocator.releaseTileId(rangeToSpill->getTileType(),
                                    *rangeToSpill->tileId);
        // This will always succeed after a spill (of an active live range).
        nextRange->tileId = *tileAllocator.allocateTileId(rangeTileType);
        // Remove the live range from the active/inactive sets.
        if (!activeRanges.remove(rangeToSpill)) {
          bool removed = inactiveRanges.remove(rangeToSpill);
          assert(removed && "expected a range to be removed!");
          (void)removed;
        }
      }
      rangeToSpill->tileId = tileAllocator.allocateInMemoryTileId();
    }

    // 5. Insert the live range into the active ranges.
    if (nextRange->tileId < kInMemoryTileIdBase)
      activeRanges.insert(nextRange);

    // 6. Release tiles reserved for inactive live ranges (in step 3).
    for (LiveRange *range : overlappingInactiveRanges) {
      if (*range->tileId < kInMemoryTileIdBase)
        tileAllocator.releaseTileId(range->getTileType(), *range->tileId);
    }
  }
}
/// Assigns a tile ID to an MLIR value.
void assignTileIdToValue(IRRewriter &rewriter, Value value,
                         IntegerAttr tileIdAttr) {
  if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>())
    rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
  for (Operation *user : value.getUsers()) {
    if (auto tileOp = dyn_cast<ArmSMETileOpInterface>(user)) {
      // Ensure ArmSME ops that don't produce a value still get a tile ID.
      if (!hasTileResult(tileOp))
        rewriter.modifyOpInPlace(tileOp,
                                 [&] { tileOp.setTileId(tileIdAttr); });
    }
  }
}

/// Assign tile IDs back to IR and attempt to resolve trivial tile ID
/// conflicts.
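///
/// As a sketch of a trivial conflict resolution: if `%zero = arm_sme.zero`
/// was allocated tile 0, but an op consuming `%zero` had its result
/// allocated tile 1, the zero op is trivially cloneable, so a clone of it is
/// given tile ID 1 instead of emitting a tile-to-tile move (at the cost of
/// possibly re-zeroing the tile).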
LogicalResult assignTileIdsAndResolveTrivialConflicts(
    IRRewriter &rewriter, FunctionOpInterface function,
    ArrayRef<LiveRange *> allocatedLiveRanges) {
  for (LiveRange const *liveRange : allocatedLiveRanges) {
    auto tileIdAttr = rewriter.getI32IntegerAttr(*liveRange->tileId);
    auto isAllocatedToSameTile = [&](Value value) {
      if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
          tileOp && tileOp.getTileId() == tileIdAttr)
        return true;
      return liveRange->values.contains(value);
    };

    /// Eliminates copies where the operand has the same tile ID.
    auto foldRedundantCopies = [&](Value value) -> LogicalResult {
      auto copyOp = value.getDefiningOp<CopyTileOp>();
      if (!copyOp || !isAllocatedToSameTile(copyOp.getTile()))
        return failure();
      rewriter.replaceAllUsesWith(copyOp, copyOp.getTile());
      return success();
    };

    /// Validates each predecessor to a tile block argument has been assigned
    /// the same tile ID.
    auto validateBlockArguments = [&](Value value) {
      auto blockArg = dyn_cast<BlockArgument>(value);
      if (!blockArg) {
        // Not a block argument (nothing to validate).
        return success();
      }
      bool tileMismatch = false;
      forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
        if (tileMismatch)
          return;
        if (!isAllocatedToSameTile(predecessorTile)) {
          blockArg.getOwner()->getParentOp()->emitOpError(
              "block argument not allocated to the same SME virtual tile as "
              "predecessors");
          tileMismatch = true;
        }
      });
      return success(/*isSuccess=*/!tileMismatch);
    };

    /// Attempts to resolve (trivial) tile ID conflicts.
    auto resolveTrivialTileConflicts = [&](Value value) -> LogicalResult {
      auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
      OpOperand *tileOperand = getTileOpOperand(tileOp);
      if (!tileOperand || isAllocatedToSameTile(tileOperand->get())) {
        // Operand already allocated to the correct tile.
        // No conflict to resolve.
        return success();
      }
      auto operandTileOp =
          tileOperand->get().getDefiningOp<ArmSMETileOpInterface>();
      if (!isTriviallyCloneableTileOp(operandTileOp)) {
        auto error =
            tileOp.emitOpError("tile operand allocated to different SME "
                               "virtual tile (move required)");
        error.attachNote(tileOperand->get().getLoc())
            << "tile operand is: " << tileOperand->get();
        return error;
      }
      // Cloning prevents a move/spill (though may require recomputation).
      rewriter.setInsertionPoint(tileOp);
      auto clonedOp = operandTileOp.clone();
      rewriter.modifyOpInPlace(
          clonedOp, [&] { clonedOp.setTileId(tileOp.getTileId()); });
      rewriter.insert(clonedOp);
      if (isa<CopyTileOp>(tileOp)) {
        rewriter.replaceAllUsesWith(tileOp->getResult(0),
                                    clonedOp->getResult(0));
      } else {
        rewriter.modifyOpInPlace(
            tileOp, [&] { tileOperand->assign(clonedOp->getResult(0)); });
      }
      return success();
    };

    for (Value value : liveRange->values) {
      // 1. Assign the tile ID to the value.
      assignTileIdToValue(rewriter, value, tileIdAttr);

      // 2. Attempt to eliminate redundant tile copies.
      if (succeeded(foldRedundantCopies(value)))
        continue;

      // 3. Validate tile block arguments.
      if (failed(validateBlockArguments(value)))
        return failure();

      // 4. Attempt to resolve (trivial) tile ID conflicts.
      if (failed(resolveTrivialTileConflicts(value)))
        return failure();
    }
  }
  return success();
}

/// Prints live ranges alongside operation names for debugging.
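///
/// Illustrative output for two initial live ranges with the intervals
/// [1, 2) and [2, 3) (one column per live range):
///
///   ^bb0:
///   S  arm_sme.get_tile
///   ES arm_sme.copy_tile
///    E cf.br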
void dumpLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
                    ArrayRef<LiveRange const *> liveRanges,
                    FunctionOpInterface function) {
  llvm::errs() << "SME Tile Liveness: @" << function.getName()
               << "\nKey:\nS - Start\nE - End\n| - Live\n";
  for (auto [blockIdx, block] : llvm::enumerate(function.getBlocks())) {
    llvm::errs() << "^bb" << blockIdx << ":\n";
    for (Operation &op : block.getOperations()) {
      unsigned operationIndex = operationToIndexMap.at(&op);
      for (LiveRange const *range : liveRanges) {
        char liveness = ' ';
        for (auto it = range->ranges->begin(); it != range->ranges->end();
             ++it) {
          if (it.start() == operationIndex)
            liveness = (liveness == 'E' ? '|' : 'S');
          else if (it.stop() == operationIndex)
            liveness = (liveness == 'S' ? '|' : 'E');
          else if (operationIndex >= it.start() && operationIndex < it.stop())
            liveness = '|';
        }
        llvm::errs() << liveness;
      }
      llvm::errs() << ' ' << op.getName() << '\n';
    }
  }
  llvm::errs() << "==========\n";
}

struct TestTileAllocationPass
    : public arm_sme::impl::TestTileAllocationBase<TestTileAllocationPass> {
  using TestTileAllocationBase::TestTileAllocationBase;
  void runOnOperation() override {
    FunctionOpInterface function = getOperation();
    if (preprocessOnly) {
      IRRewriter rewriter(function);
      return preprocessForTileAllocation(rewriter, function);
    }
    if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
      signalPassFailure();
  }
};
} // namespace

LogicalResult mlir::arm_sme::allocateSMETiles(FunctionOpInterface function,
                                              bool dumpRanges) {
  if (function.empty()) {
    // TODO: Also return early if the function contains no ArmSME ops?
    return success();
  }

  LiveRange::Allocator liveRangeAllocator;
  IRRewriter rewriter(function.getContext());

  // 1. Preprocess the IR for tile allocation.
  preprocessForTileAllocation(rewriter, function);

  // 2. Gather live ranges for each ArmSME tile within the function.
  Liveness liveness(function);
  auto operationToIndexMap = generateOperationNumbering(function);
  auto initialLiveRanges = gatherTileLiveRanges(
      operationToIndexMap, liveRangeAllocator, liveness, function);
  if (initialLiveRanges.empty())
    return success();

  if (dumpRanges) {
    // Wrangle initial live ranges into a form suitable for printing.
    auto nonEmpty = llvm::make_filter_range(
        llvm::make_second_range(initialLiveRanges),
        [&](LiveRange const &liveRange) { return !liveRange.empty(); });
    auto initialRanges = llvm::to_vector(llvm::map_range(
        nonEmpty, [](LiveRange const &liveRange) { return &liveRange; }));
    std::sort(initialRanges.begin(), initialRanges.end(),
              [](LiveRange const *a, LiveRange const *b) { return *a < *b; });
    llvm::errs() << "\n========== Initial Live Ranges:\n";
    dumpLiveRanges(operationToIndexMap, initialRanges, function);
  }

  // 3. Coalesce (non-overlapping) live ranges where it would be beneficial
  // for tile allocation. E.g. unify the result of an operation with its
  // operands.
  auto coalescedLiveRanges = coalesceTileLiveRanges(initialLiveRanges);

  if (dumpRanges) {
    llvm::errs() << "\n========== Coalesced Live Ranges:\n";
    dumpLiveRanges(operationToIndexMap, coalescedLiveRanges, function);
  }

  // 4. Allocate tile IDs to live ranges.
  allocateTilesToLiveRanges(coalescedLiveRanges);

  // 5. Assign the tile IDs back to the ArmSME operations.
  if (failed(assignTileIdsAndResolveTrivialConflicts(rewriter, function,
                                                     coalescedLiveRanges))) {
    return failure();
  }

  // 6. Erase trivially dead tile operations (e.g. a ZeroOp with no
  // users). This prevents the LLVM conversion needlessly inserting spills.
  eraseTriviallyDeadTileOps(rewriter, function);
  return success();
}