//===- TileAllocation.cpp - Allocate SME ZA tiles -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transform allocates SME tiles at the 'func.func' op level for ArmSME
// operations. It roughly implements a linear scan register allocator, similar
// to the one outlined in [1], but with simplifications and assumptions made for
// our use case. Note that this is a greedy allocator (so it may not always find
// an optimal allocation of tiles).
//
// The allocator operates at the CF dialect level. It is the responsibility of
// users to ensure the IR has been lowered to CF before invoking the tile
// allocator.
//
// The 128-bit tiles overlap with other element tiles as follows (see section
// B2.3.2 of SME spec [2]):
//
//   Tile    Overlaps
//   ---------------------------------------------------------------------------
//   ZA0.B   ZA0.Q, ZA1.Q, ZA2.Q, ZA3.Q, ZA4.Q, ZA5.Q, ZA6.Q, ZA7.Q, ZA8.Q,
//           ZA9.Q, ZA10.Q, ZA11.Q, ZA12.Q, ZA13.Q, ZA14.Q, ZA15.Q
//   ZA0.H   ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
//   ZA1.H   ZA1.Q, ZA3.Q, ZA5.Q, ZA7.Q, ZA9.Q, ZA11.Q, ZA13.Q, ZA15.Q
//   ZA0.S   ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q
//   ZA1.S   ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
//   ZA2.S   ZA2.Q, ZA6.Q, ZA10.Q, ZA14.Q
//   ZA3.S   ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q
//   ZA0.D   ZA0.Q, ZA8.Q
//   ZA1.D   ZA1.Q, ZA9.Q
//   ZA2.D   ZA2.Q, ZA10.Q
//   ZA3.D   ZA3.Q, ZA11.Q
//   ZA4.D   ZA4.Q, ZA12.Q
//   ZA5.D   ZA5.Q, ZA13.Q
//   ZA6.D   ZA6.Q, ZA14.Q
//   ZA7.D   ZA7.Q, ZA15.Q
//
// [1] "Linear Scan Register Allocation in the Context of SSA Form and Register
//      Constraints" (Hanspeter Mössenböck and Michael Pfeiffer)
//     https://link.springer.com/content/pdf/10.1007/3-540-45937-5_17.pdf
// [2] https://developer.arm.com/documentation/ddi0616/aa
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/Liveness.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/TypeSwitch.h"
#include <algorithm>

namespace mlir::arm_sme {
#define GEN_PASS_DEF_TESTTILEALLOCATION
#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
} // namespace mlir::arm_sme

using namespace mlir;
using namespace mlir::arm_sme;

namespace {

enum class TileMask : unsigned {
  // clang-format off
  kZA0B  = 0xffff, // 1111 1111 1111 1111

  kZA0H  = 0xaaaa, // 1010 1010 1010 1010
  kZA1H  = 0x5555, // 0101 0101 0101 0101

  kZA0S  = 0x8888, // 1000 1000 1000 1000
  kZA1S  = 0x4444, // 0100 0100 0100 0100
  kZA2S  = 0x2222, // 0010 0010 0010 0010
  kZA3S  = 0x1111, // 0001 0001 0001 0001

  kZA0D  = 0x8080, // 1000 0000 1000 0000
  kZA1D  = 0x4040, // 0100 0000 0100 0000
  kZA2D  = 0x2020, // 0010 0000 0010 0000
  kZA3D  = 0x1010, // 0001 0000 0001 0000
  kZA4D  = 0x808,  // 0000 1000 0000 1000
  kZA5D  = 0x404,  // 0000 0100 0000 0100
  kZA6D  = 0x202,  // 0000 0010 0000 0010
  kZA7D  = 0x101,  // 0000 0001 0000 0001

  kZA0Q  = 0x8000, // 1000 0000 0000 0000
  kZA1Q  = 0x4000, // 0100 0000 0000 0000
  kZA2Q  = 0x2000, // 0010 0000 0000 0000
  kZA3Q  = 0x1000, // 0001 0000 0000 0000
  kZA4Q  = 0x800,  // 0000 1000 0000 0000
  kZA5Q  = 0x400,  // 0000 0100 0000 0000
  kZA6Q  = 0x200,  // 0000 0010 0000 0000
  kZA7Q  = 0x100,  // 0000 0001 0000 0000
  kZA8Q  = 0x80,   // 0000 0000 1000 0000
  kZA9Q  = 0x40,   // 0000 0000 0100 0000
  kZA10Q = 0x20,   // 0000 0000 0010 0000
  kZA11Q = 0x10,   // 0000 0000 0001 0000
  kZA12Q = 0x8,    // 0000 0000 0000 1000
  kZA13Q = 0x4,    // 0000 0000 0000 0100
  kZA14Q = 0x2,    // 0000 0000 0000 0010
  kZA15Q = 0x1,    // 0000 0000 0000 0001

  kNone = 0x0,     // 0000 0000 0000 0000
  // clang-format on

  LLVM_MARK_AS_BITMASK_ENUM(kZA0B)
};
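
// Illustrative note (ours, not from the SME spec): each bit of a mask stands
// for one 128-bit ZA*.Q sub-tile, so checking whether two tiles conflict
// reduces to a bitwise AND of their masks. For example:
//
//   kZA0H & kZA0S == 0xaaaa & 0x8888 == 0x8888  // overlap: cannot coexist
//   kZA0S & kZA1S == 0x8888 & 0x4444 == 0x0     // disjoint: can coexist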

/// Returns the set of masks relevant for the given type.
static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
  static constexpr std::array ZA_B_MASKS = {TileMask::kZA0B};
  static constexpr std::array ZA_H_MASKS = {TileMask::kZA0H, TileMask::kZA1H};
  static constexpr std::array ZA_S_MASKS = {TileMask::kZA0S, TileMask::kZA1S,
                                            TileMask::kZA2S, TileMask::kZA3S};
  static constexpr std::array ZA_D_MASKS = {
      TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D,
      TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D};
  static constexpr std::array ZA_Q_MASKS = {
      TileMask::kZA0Q,  TileMask::kZA1Q,  TileMask::kZA2Q,  TileMask::kZA3Q,
      TileMask::kZA4Q,  TileMask::kZA5Q,  TileMask::kZA6Q,  TileMask::kZA7Q,
      TileMask::kZA8Q,  TileMask::kZA9Q,  TileMask::kZA10Q, TileMask::kZA11Q,
      TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q};
  switch (type) {
  case ArmSMETileType::ZAB:
    return ZA_B_MASKS;
  case ArmSMETileType::ZAH:
    return ZA_H_MASKS;
  case ArmSMETileType::ZAS:
    return ZA_S_MASKS;
  case ArmSMETileType::ZAD:
    return ZA_D_MASKS;
  case ArmSMETileType::ZAQ:
    return ZA_Q_MASKS;
  }
  llvm_unreachable("unknown type in getMasks");
}

class TileAllocator {
public:
  /// Allocates and returns a tile ID. Fails if there are no tiles left.
  FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) {
    auto masks = getMasks(tileType);
    for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
      if ((tilesInUse & tileMask) == TileMask::kNone) {
        tilesInUse |= tileMask;
        return tileId;
      }
    }
    return failure();
  }

  /// Acquires a specific tile ID. Asserts the tile is initially free.
  void acquireTileId(ArmSMETileType tileType, unsigned tileId) {
    TileMask tileMask = getMasks(tileType)[tileId];
    assert((tilesInUse & tileMask) == TileMask::kNone &&
           "cannot acquire allocated tile!");
    tilesInUse |= tileMask;
  }

  /// Releases a previously allocated tile ID.
  void releaseTileId(ArmSMETileType tileType, unsigned tileId) {
    TileMask tileMask = getMasks(tileType)[tileId];
    assert((tilesInUse & tileMask) == tileMask &&
           "cannot release unallocated tile!");
    tilesInUse ^= tileMask;
  }

  /// Allocates an in-memory tile ID.
  unsigned allocateInMemoryTileId() {
    // Note: We never release in-memory tile IDs. We could, which may allow
    // reusing an allocation, but as we _never_ want to spill an SME tile this
    // is not optimized.
    return nextInMemoryTileId++;
  }

private:
  TileMask tilesInUse = TileMask::kNone;
  unsigned nextInMemoryTileId = kInMemoryTileIdBase;
};
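
// A minimal usage sketch (ours): allocating the single 8-bit tile consumes
// all of ZA, so a following 32-bit tile request fails and must go in memory.
//
//   TileAllocator allocator;
//   auto zab = allocator.allocateTileId(ArmSMETileType::ZAB); // ZA0.B (ID 0)
//   auto zas = allocator.allocateTileId(ArmSMETileType::ZAS); // failure()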

/// Add new intermediate blocks for the true and false destinations of
/// `cf.cond_br`s that contain tile operands. This prevents spurious liveness
/// overlaps due to copies at branches.
///
///  BEFORE:
///  ```mlir
///  cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2
///  ```
///
///  AFTER:
///  ```mlir
///    cf.cond_br %cond, ^bb1_copy, ^bb2_copy
///  ^bb1_copy:
///    cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
///  ^bb2_copy:
///    cf.br ^bb2
///  ```
void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
  SmallVector<cf::CondBranchOp> worklist;
  function.walk([&](cf::CondBranchOp condBranch) {
    if (llvm::any_of(condBranch->getOperands(), [&](Value value) {
          return isValidSMETileVectorType(value.getType());
        })) {
      worklist.push_back(condBranch);
    }
  });

  auto insertJump = [&](Location loc, Block *source, Block *dest, auto args) {
    rewriter.setInsertionPointToEnd(source);
    rewriter.create<cf::BranchOp>(loc, dest, args);
  };

  for (auto condBranch : worklist) {
    auto loc = condBranch.getLoc();
    Block *block = condBranch->getBlock();
    auto newTrueBranch = rewriter.splitBlock(block, block->end());
    auto newFalseBranch = rewriter.splitBlock(block, block->end());
    insertJump(loc, newTrueBranch, condBranch.getTrueDest(),
               condBranch.getTrueDestOperands());
    insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
               condBranch.getFalseDestOperands());
    rewriter.modifyOpInPlace(condBranch, [&] {
      condBranch.getFalseDestOperandsMutable().clear();
      condBranch.getTrueDestOperandsMutable().clear();
      condBranch.setSuccessor(newTrueBranch, 0);
      condBranch.setSuccessor(newFalseBranch, 1);
    });
  }
}

/// Inserts tile copies at `cf.br` operations.
///
///  BEFORE:
///  ```mlir
///  cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
///  ```
///
///  AFTER:
///  ```mlir
///  %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
///  cf.br ^bb1(%copy: vector<[4]x[4]xf32>)
///  ```
void insertCopiesAtBranches(IRRewriter &rewriter,
                            FunctionOpInterface function) {
  for (Block &block : function.getBlocks()) {
    Operation *terminator = block.getTerminator();
    if (!isa<cf::BranchOp>(terminator))
      continue;
    rewriter.setInsertionPoint(terminator);
    for (OpOperand &operand : terminator->getOpOperands()) {
      if (isValidSMETileVectorType(operand.get().getType())) {
        auto copy =
            rewriter.create<CopyTileOp>(terminator->getLoc(), operand.get());
        rewriter.modifyOpInPlace(terminator, [&] { operand.assign(copy); });
      }
    }
  }
}

/// Prepares the IR for tile allocation. It does this by first 'splitting'
/// conditional branches (see `splitCondBranches`), then inserting tile copies
/// at branch operations. The conditional branches are split to prevent the
/// copies needed for them from overlapping between the true and false paths of
/// the branch (see `tile-allocation-copies.mlir` and
/// `tile-allocation-liveness.mlir` for examples). The copies break up live
/// ranges and ensure that when moving out of SSA the semantics of the program
/// are preserved.
void preprocessForTileAllocation(IRRewriter &rewriter,
                                 FunctionOpInterface function) {
  splitCondBranches(rewriter, function);
  insertCopiesAtBranches(rewriter, function);
}

/// A live range for a (collection of) tile values. A live range is built up of
/// non-overlapping intervals [start, end) which represent parts of the program
/// where a value in the range needs to be live (i.e. in an SME virtual tile).
/// Note that as the intervals are non-overlapping all values within a live
/// range can be allocated to the same SME virtual tile.
struct LiveRange {
  using RangeSet = llvm::IntervalMap<uint64_t, uint8_t, 16,
                                     llvm::IntervalMapHalfOpenInfo<unsigned>>;
  using Allocator = RangeSet::Allocator;
  // Dummy value for the IntervalMap. Only the keys matter (the intervals).
  static constexpr uint8_t kValidLiveRange = 0xff;

  LiveRange(Allocator &allocator)
      : ranges(std::make_unique<RangeSet>(allocator)) {}

  /// Returns true if this range overlaps with `otherRange`.
  bool overlaps(LiveRange const &otherRange) const {
    return llvm::IntervalMapOverlaps<RangeSet, RangeSet>(*ranges,
                                                         *otherRange.ranges)
        .valid();
  }

  /// Returns true if this range is active at `point` in the program.
  bool overlaps(uint64_t point) const {
    return ranges->lookup(point) == kValidLiveRange;
  }

  /// Unions this live range with `otherRange`, aborts if the ranges overlap.
  void unionWith(LiveRange const &otherRange) {
    for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
         ++it)
      ranges->insert(it.start(), it.stop(), kValidLiveRange);
    values.set_union(otherRange.values);
  }

  /// Inserts an interval [start, end) for `value` into this range.
  void insert(Value value, unsigned start, unsigned end) {
    values.insert(value);
    if (start != end)
      ranges->insert(start, end, kValidLiveRange);
  }

  bool empty() const { return ranges->empty(); }
  unsigned start() const { return ranges->start(); }
  unsigned end() const { return ranges->stop(); }
  bool operator<(LiveRange const &other) const {
    return start() < other.start();
  }

  ArmSMETileType getTileType() const {
    return *getSMETileType(cast<VectorType>(values[0].getType()));
  }

  /// The values contained in this live range.
  SetVector<Value> values;

  /// A set of (non-overlapping) intervals that mark where any value in `values`
  /// is live.
  std::unique_ptr<RangeSet> ranges;

  /// The tile ID (or none) assigned to this live range.
  std::optional<unsigned> tileId;
};
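
// A small sketch (ours) of how the intervals behave, with hypothetical
// operation indices. A gap between intervals is a 'hole' where the value does
// not need to be in a tile:
//
//   LiveRange range(allocator);
//   range.insert(tile, /*start=*/2, /*end=*/5); // live over [2, 5)
//   range.insert(tile, /*start=*/8, /*end=*/9); // live over [8, 9)
//   range.overlaps(4); // true
//   range.overlaps(6); // false (a hole: the tile ID is reusable here)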

/// Number operations within a function to allow computing live ranges.
/// Operations are numbered consecutively within blocks, and the blocks are
/// topologically sorted (using forward edges). This function is only correct
/// if all ArmSME operations have been converted to CF (which is asserted).
DenseMap<Operation *, unsigned>
generateOperationNumbering(FunctionOpInterface function) {
  unsigned index = 0;
  SetVector<Block *> blocks =
      getBlocksSortedByDominance(function.getFunctionBody());
  DenseMap<Operation *, unsigned> operationToIndexMap;
  for (Block *block : blocks) {
    index++; // We want block args to have their own number.
    for (Operation &op : block->getOperations()) {
#ifndef NDEBUG
      op.walk([&](ArmSMETileOpInterface nestedOp) {
        assert(&op == nestedOp.getOperation() &&
               "ArmSME tile allocation does not support nested regions");
      });
#endif
      operationToIndexMap.try_emplace(&op, index++);
    }
  }
  return operationToIndexMap;
}
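
// Sketch of the numbering for a hypothetical two-block function:
//
//   ^bb0:          // block arguments live at index 0
//     op_a         // 1
//     cf.br ^bb1   // 2
//   ^bb1:          // block arguments live at index 3
//     op_b         // 4
//
// The index skipped before each block's first operation is what block
// arguments and live-ins anchor to (see `gatherTileLiveRanges` below).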

/// Gather live ranges for SME tiles from the MLIR liveness analysis.
DenseMap<Value, LiveRange>
gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
                     LiveRange::Allocator &liveRangeAllocator,
                     Liveness &liveness, FunctionOpInterface function) {
  assert(!operationToIndexMap.empty() && "expected operation numbering");
  DenseMap<Value, LiveRange> liveRanges;
  /// Defines or updates a live range for an SME tile value. Live-ins may update
  /// an existing live range (rather than define a new one). Note: If
  /// `liveAtBlockEntry` is true then `firstUseOrDef` is the first operation in
  /// the block.
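  /// E.g. (sketch): if a value is live into a block whose first operation is
  /// numbered 4, the interval starts at index 3 (the block-argument slot), so
  /// the value is treated as live on entry to the block.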
  auto defineOrUpdateValueLiveRange = [&](Value value, Operation *firstUseOrDef,
                                          LivenessBlockInfo const &livenessInfo,
                                          bool liveAtBlockEntry = false) {
    if (!isValidSMETileVectorType(value.getType()))
      return;
    // Find or create a live range for `value`.
    auto [it, _] = liveRanges.try_emplace(value, liveRangeAllocator);
    LiveRange &valueLiveRange = it->second;
    auto lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
    // Add the interval [firstUseOrDef, lastUseInBlock) to the live range.
    unsigned startOpIdx =
        operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0);
    unsigned endOpIdx = operationToIndexMap.at(lastUseInBlock);
    valueLiveRange.insert(value, startOpIdx, endOpIdx);
  };

  for (Block &block : function.getBlocks()) {
    LivenessBlockInfo const *livenessInfo = liveness.getLiveness(&block);
    // Handle block arguments:
    for (Value argument : block.getArguments())
      defineOrUpdateValueLiveRange(argument, &block.front(), *livenessInfo,
                                   /*liveAtBlockEntry=*/true);
    // Handle live-ins:
    for (Value liveIn : livenessInfo->in())
      defineOrUpdateValueLiveRange(liveIn, &block.front(), *livenessInfo,
                                   /*liveAtBlockEntry=*/true);
    // Handle new definitions:
    for (Operation &op : block) {
      for (Value result : op.getResults())
        defineOrUpdateValueLiveRange(result, &op, *livenessInfo);
    }
  }

  return liveRanges;
}

/// Iterate over all predecessor tile values to a (tile) block argument.
static void forEachPredecessorTileValue(BlockArgument blockArg,
                                        function_ref<void(Value)> callback) {
  Block *block = blockArg.getOwner();
  unsigned argNumber = blockArg.getArgNumber();
  for (Block *pred : block->getPredecessors()) {
    TypeSwitch<Operation *>(pred->getTerminator())
        .Case<cf::BranchOp>([&](auto branch) {
          Value predecessorOperand = branch.getDestOperands()[argNumber];
          callback(predecessorOperand);
        })
        .Case<cf::CondBranchOp>([&](auto condBranch) {
          if (condBranch.getFalseDest() == block) {
            Value predecessorOperand =
                condBranch.getFalseDestOperands()[argNumber];
            callback(predecessorOperand);
          }
          if (condBranch.getTrueDest() == block) {
            Value predecessorOperand =
                condBranch.getTrueDestOperands()[argNumber];
            callback(predecessorOperand);
          }
        });
  }
}

/// Coalesce live ranges where it would prevent unnecessary tile moves.
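///
/// E.g. (sketch), given the copies inserted at branches:
/// ```mlir
/// %tile = arm_sme.zero : vector<[4]x[4]xf32>
/// %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
/// ```
/// If the live ranges of `%tile` and `%copy` do not overlap they are unioned,
/// so both values get the same tile ID and the copy can later fold away.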
SmallVector<LiveRange *>
coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
  DenseMap<Value, LiveRange *> liveRanges;
  for (auto &[value, liveRange] : initialLiveRanges) {
    liveRanges.insert({value, &liveRange});
  }

  // Merge the live ranges of values `a` and `b` into one (if they do not
  // overlap). After this, the values `a` and `b` will both point to the same
  // live range (which will contain multiple values).
  auto mergeValuesIfNonOverlapping = [&](Value a, Value b) {
    LiveRange *aLiveRange = liveRanges.at(a);
    LiveRange *bLiveRange = liveRanges.at(b);
    if (aLiveRange != bLiveRange && !aLiveRange->overlaps(*bLiveRange)) {
      aLiveRange->unionWith(*bLiveRange);
      for (Value value : bLiveRange->values)
        liveRanges[value] = aLiveRange;
    }
  };

  // Merge the live ranges of new definitions with their tile operands.
  auto unifyDefinitionsWithOperands = [&](Value value) {
    auto armSMEOp = value.getDefiningOp<ArmSMETileOpInterface>();
    if (!armSMEOp)
      return;
    for (auto operand : armSMEOp->getOperands()) {
      if (isValidSMETileVectorType(operand.getType()))
        mergeValuesIfNonOverlapping(value, operand);
    }
  };

  // Merge the live ranges of block arguments with their predecessors.
  auto unifyBlockArgumentsWithPredecessors = [&](Value value) {
    auto blockArg = dyn_cast<BlockArgument>(value);
    if (!blockArg)
      return;
    forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
      mergeValuesIfNonOverlapping(blockArg, predecessorTile);
    });
  };

  auto applyRule = [&](auto rule) {
    llvm::for_each(llvm::make_first_range(initialLiveRanges), rule);
  };

  // Unify as many live ranges as we can. This prevents unnecessary moves.
  applyRule(unifyBlockArgumentsWithPredecessors);
  applyRule(unifyDefinitionsWithOperands);

  // Remove duplicate live range entries.
  SetVector<LiveRange *> uniqueLiveRanges;
  for (auto [_, liveRange] : liveRanges) {
    if (!liveRange->empty())
      uniqueLiveRanges.insert(liveRange);
  }

  // Sort the new live ranges by starting point (ready for tile allocation).
  auto coalescedLiveRanges = uniqueLiveRanges.takeVector();
  std::sort(coalescedLiveRanges.begin(), coalescedLiveRanges.end(),
            [](LiveRange *a, LiveRange *b) { return *a < *b; });
  return std::move(coalescedLiveRanges);
}

/// Choose a live range to spill (via some heuristics). This picks either a live
/// range from `overlappingRanges`, or the new live range `newRange`.
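///
/// E.g. (sketch): a live range containing only the result of an
/// `arm_sme.zero` is a cheap spill, as rather than saving the tile to memory,
/// the zero can simply be re-materialized (cloned) wherever it is needed.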
template <typename OverlappingRangesIterator>
LiveRange *
chooseSpillUsingHeuristics(OverlappingRangesIterator overlappingRanges,
                           LiveRange *newRange) {
  // Heuristic: Spill trivially copyable operations (usually free).
  auto isTrivialSpill = [&](LiveRange &allocatedRange) {
    return isTileTypeGreaterOrEqual(allocatedRange.getTileType(),
                                    newRange->getTileType()) &&
           allocatedRange.values.size() == 1 &&
           isTriviallyCloneableTileOp(
               allocatedRange.values[0].getDefiningOp<ArmSMETileOpInterface>());
  };
  if (isTrivialSpill(*newRange))
    return newRange;
  auto trivialSpill = llvm::find_if(overlappingRanges, isTrivialSpill);
  if (trivialSpill != overlappingRanges.end())
    return &*trivialSpill;

  // Heuristic: Spill the range that ends last (with a compatible tile type).
  auto isSmallerTileTypeOrEndsEarlier = [](LiveRange &a, LiveRange &b) {
    return !isTileTypeGreaterOrEqual(a.getTileType(), b.getTileType()) ||
           a.end() < b.end();
  };
  LiveRange &latestEndingLiveRange =
      *std::max_element(overlappingRanges.begin(), overlappingRanges.end(),
                        isSmallerTileTypeOrEndsEarlier);
  if (!isSmallerTileTypeOrEndsEarlier(latestEndingLiveRange, *newRange))
    return &latestEndingLiveRange;
  return newRange;
}

/// Greedily allocate tile IDs to live ranges. Spill using simple heuristics.
void allocateTilesToLiveRanges(
    ArrayRef<LiveRange *> liveRangesSortedByStartPoint) {
  TileAllocator tileAllocator;
  // `activeRanges` = Live ranges that need to be in a tile at the
  // `currentPoint` in the program.
  SetVector<LiveRange *> activeRanges;
  // `inactiveRanges` = Live ranges that _do not_ need to be in a tile
  // at the `currentPoint` in the program but could become active again later.
  // An inactive section of a live range can be seen as a 'hole' in the live
  // range, where it is possible to reuse the live range's tile ID _before_ it
  // has ended. By identifying 'holes', the allocator can reuse tiles more
  // often, which helps avoid costly tile spills.
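  //
  // E.g. (sketch, indices hypothetical):
  //
  //   Range A: |------|        |------|
  //   Range B:          |----|
  //
  // B starts inside A's hole, so while B is active A is inactive, and B can
  // reuse A's tile ID (provided A and B never overlap at any point).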
  SetVector<LiveRange *> inactiveRanges;
  for (LiveRange *nextRange : liveRangesSortedByStartPoint) {
    auto currentPoint = nextRange->start();
    // 1. Update the `activeRanges` at `currentPoint`.
    activeRanges.remove_if([&](LiveRange *activeRange) {
      // Check for live ranges that have expired.
      if (activeRange->end() <= currentPoint) {
        tileAllocator.releaseTileId(activeRange->getTileType(),
                                    *activeRange->tileId);
        return true;
      }
      // Check for live ranges that have become inactive.
      if (!activeRange->overlaps(currentPoint)) {
        tileAllocator.releaseTileId(activeRange->getTileType(),
                                    *activeRange->tileId);
        inactiveRanges.insert(activeRange);
        return true;
      }
      return false;
    });
    // 2. Update the `inactiveRanges` at `currentPoint`.
    inactiveRanges.remove_if([&](LiveRange *inactiveRange) {
      // Check for live ranges that have expired.
      if (inactiveRange->end() <= currentPoint) {
        return true;
      }
      // Check for live ranges that have become active.
      if (inactiveRange->overlaps(currentPoint)) {
        tileAllocator.acquireTileId(inactiveRange->getTileType(),
                                    *inactiveRange->tileId);
        activeRanges.insert(inactiveRange);
        return true;
      }
      return false;
    });

    // 3. Collect inactive live ranges that overlap with the new live range.
    // Note: The overlap checks in steps 1 and 2 only look at the `currentPoint`
    // whereas this checks if there is an overlap at any future point too.
    SmallVector<LiveRange *> overlappingInactiveRanges;
    for (LiveRange *inactiveRange : inactiveRanges) {
      if (inactiveRange->overlaps(*nextRange)) {
        // We need to reserve the tile IDs of overlapping inactive ranges to
        // prevent two (overlapping) live ranges from getting the same tile ID.
        tileAllocator.acquireTileId(inactiveRange->getTileType(),
                                    *inactiveRange->tileId);
        overlappingInactiveRanges.push_back(inactiveRange);
      }
    }

    // 4. Allocate a tile ID to `nextRange`.
    auto rangeTileType = nextRange->getTileType();
    auto tileId = tileAllocator.allocateTileId(rangeTileType);
    if (succeeded(tileId)) {
      nextRange->tileId = *tileId;
    } else {
      // Create an iterator over all overlapping live ranges.
      auto allOverlappingRanges = llvm::concat<LiveRange>(
          llvm::make_pointee_range(activeRanges.getArrayRef()),
          llvm::make_pointee_range(overlappingInactiveRanges));
      // Choose an overlapping live range to spill.
      LiveRange *rangeToSpill =
          chooseSpillUsingHeuristics(allOverlappingRanges, nextRange);
      if (rangeToSpill != nextRange) {
        // Spill an (in)active live range (so release its tile ID first).
        tileAllocator.releaseTileId(rangeToSpill->getTileType(),
                                    *rangeToSpill->tileId);
        // This will always succeed after a spill (of an active live range).
        nextRange->tileId = *tileAllocator.allocateTileId(rangeTileType);
        // Remove the live range from the active/inactive sets.
        if (!activeRanges.remove(rangeToSpill)) {
          bool removed = inactiveRanges.remove(rangeToSpill);
          assert(removed && "expected a range to be removed!");
          (void)removed;
        }
      }
      rangeToSpill->tileId = tileAllocator.allocateInMemoryTileId();
    }

    // 5. Insert the live range into the active ranges.
    if (nextRange->tileId < kInMemoryTileIdBase)
      activeRanges.insert(nextRange);

    // 6. Release tiles reserved for inactive live ranges (in step 3).
    for (LiveRange *range : overlappingInactiveRanges) {
      if (*range->tileId < kInMemoryTileIdBase)
        tileAllocator.releaseTileId(range->getTileType(), *range->tileId);
    }
  }
}

/// Assigns a tile ID to an MLIR value.
void assignTileIdToValue(IRRewriter &rewriter, Value value,
                         IntegerAttr tileIdAttr) {
  if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>())
    rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
  for (Operation *user : value.getUsers()) {
    if (auto tileOp = dyn_cast<ArmSMETileOpInterface>(user)) {
      // Ensure ArmSME ops that don't produce a value still get a tile ID.
      if (!hasTileResult(tileOp))
        rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
    }
  }
}

/// Assign tile IDs back to IR and attempt to resolve trivial tile ID conflicts.
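///
/// E.g. (sketch): if `%copy = arm_sme.copy_tile %tile` ends up with the same
/// tile ID as `%tile`, the copy is redundant and all uses of `%copy` are
/// replaced with `%tile`. Conversely, if a trivially cloneable operand (such
/// as an `arm_sme.zero`) lands on a different tile than its user, it is
/// cloned with the user's tile ID rather than inserting a tile move.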
LogicalResult assignTileIdsAndResolveTrivialConflicts(
    IRRewriter &rewriter, FunctionOpInterface function,
    ArrayRef<LiveRange *> allocatedLiveRanges) {
  for (LiveRange const *liveRange : allocatedLiveRanges) {
    auto tileIdAttr = rewriter.getI32IntegerAttr(*liveRange->tileId);
    auto isAllocatedToSameTile = [&](Value value) {
      if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
          tileOp && tileOp.getTileId() == tileIdAttr)
        return true;
      return liveRange->values.contains(value);
    };

    /// Eliminates copies where the operand has the same tile ID.
    auto foldRedundantCopies = [&](Value value) -> LogicalResult {
      auto copyOp = value.getDefiningOp<CopyTileOp>();
      if (!copyOp || !isAllocatedToSameTile(copyOp.getTile()))
        return failure();
      rewriter.replaceAllUsesWith(copyOp, copyOp.getTile());
      return success();
    };

    /// Validates each predecessor to a tile block argument has been assigned
    /// the same tile ID.
    auto validateBlockArguments = [&](Value value) {
      auto blockArg = dyn_cast<BlockArgument>(value);
      if (!blockArg) {
        // Not a block argument (nothing to validate).
        return success();
      }
      bool tileMismatch = false;
      forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
        if (tileMismatch)
          return;
        if (!isAllocatedToSameTile(predecessorTile)) {
          blockArg.getOwner()->getParentOp()->emitOpError(
              "block argument not allocated to the same SME virtual tile as "
              "predecessors");
          tileMismatch = true;
        }
      });
      return success(/*isSuccess=*/!tileMismatch);
    };

    /// Attempts to resolve (trivial) tile ID conflicts.
    auto resolveTrivialTileConflicts = [&](Value value) -> LogicalResult {
      auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
      OpOperand *tileOperand = getTileOpOperand(tileOp);
      if (!tileOperand || isAllocatedToSameTile(tileOperand->get())) {
        // Operand already allocated to the correct tile.
        // No conflict to resolve.
        return success();
      }
      auto operandTileOp =
          tileOperand->get().getDefiningOp<ArmSMETileOpInterface>();
      if (!isTriviallyCloneableTileOp(operandTileOp)) {
        auto error =
            tileOp.emitOpError("tile operand allocated to different SME "
                               "virtual tile (move required)");
        error.attachNote(tileOperand->get().getLoc())
            << "tile operand is: " << tileOperand->get();
        return error;
      }
      // Cloning prevents a move/spill (though may require recomputation).
      rewriter.setInsertionPoint(tileOp);
      auto clonedOp = operandTileOp.clone();
      rewriter.modifyOpInPlace(clonedOp,
                               [&] { clonedOp.setTileId(tileOp.getTileId()); });
      rewriter.insert(clonedOp);
      if (isa<CopyTileOp>(tileOp)) {
        rewriter.replaceAllUsesWith(tileOp->getResult(0),
                                    clonedOp->getResult(0));
      } else {
        rewriter.modifyOpInPlace(
            tileOp, [&] { tileOperand->assign(clonedOp->getResult(0)); });
      }
      return success();
    };

    for (Value value : liveRange->values) {
      // 1. Assign the tile ID to the value.
      assignTileIdToValue(rewriter, value, tileIdAttr);

      // 2. Attempt to eliminate redundant tile copies.
      if (succeeded(foldRedundantCopies(value)))
        continue;

      // 3. Validate tile block arguments.
      if (failed(validateBlockArguments(value)))
        return failure();

      // 4. Attempt to resolve (trivial) tile ID conflicts.
      if (failed(resolveTrivialTileConflicts(value)))
        return failure();
    }
  }
  return success();
}

/// Prints live ranges alongside operation names for debugging.
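///
/// Example output (sketch, for two live ranges): one column per live range,
/// one row per operation:
///
///   ^bb0:
///   S  arm_sme.get_tile
///   |S arm_sme.zero
///   E| cf.br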
void dumpLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
                    ArrayRef<LiveRange const *> liveRanges,
                    FunctionOpInterface function) {
  llvm::errs() << "SME Tile Liveness: @" << function.getName()
               << "\nKey:\nS - Start\nE - End\n| - Live\n";
  for (auto [blockIdx, block] : llvm::enumerate(function.getBlocks())) {
    llvm::errs() << "^bb" << blockIdx << ":\n";
    for (Operation &op : block.getOperations()) {
      unsigned operationIndex = operationToIndexMap.at(&op);
      for (LiveRange const *range : liveRanges) {
        char liveness = ' ';
        for (auto it = range->ranges->begin(); it != range->ranges->end();
             ++it) {
          if (it.start() == operationIndex)
            liveness = (liveness == 'E' ? '|' : 'S');
          else if (it.stop() == operationIndex)
            liveness = (liveness == 'S' ? '|' : 'E');
          else if (operationIndex >= it.start() && operationIndex < it.stop())
            liveness = '|';
        }
        llvm::errs() << liveness;
      }
      llvm::errs() << ' ' << op.getName() << '\n';
    }
  }
  llvm::errs() << "==========\n";
}

struct TestTileAllocationPass
    : public arm_sme::impl::TestTileAllocationBase<TestTileAllocationPass> {
  using TestTileAllocationBase::TestTileAllocationBase;
  void runOnOperation() override {
    FunctionOpInterface function = getOperation();
    if (preprocessOnly) {
      IRRewriter rewriter(function);
      return preprocessForTileAllocation(rewriter, function);
    }
    if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
      signalPassFailure();
  }
};
} // namespace
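
// For reference, the test pass above is typically driven from mlir-opt; the
// exact option names here are assumed from the pass/option definitions, e.g.:
//
//   mlir-opt -test-arm-sme-tile-allocation=dump-tile-live-ranges input.mlir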

LogicalResult mlir::arm_sme::allocateSMETiles(FunctionOpInterface function,
                                              bool dumpRanges) {
  if (function.empty()) {
    // TODO: Also return early if the function contains no ArmSME ops?
    return success();
  }

  LiveRange::Allocator liveRangeAllocator;
  IRRewriter rewriter(function.getContext());

  // 1. Preprocess the IR for tile allocation.
  preprocessForTileAllocation(rewriter, function);

  // 2. Gather live ranges for each ArmSME tile within the function.
  Liveness liveness(function);
  auto operationToIndexMap = generateOperationNumbering(function);
  auto initialLiveRanges = gatherTileLiveRanges(
      operationToIndexMap, liveRangeAllocator, liveness, function);
  if (initialLiveRanges.empty())
    return success();

  if (dumpRanges) {
    // Wrangle initial live ranges into a form suitable for printing.
    auto nonEmpty = llvm::make_filter_range(
        llvm::make_second_range(initialLiveRanges),
        [&](LiveRange const &liveRange) { return !liveRange.empty(); });
    auto initialRanges = llvm::to_vector(llvm::map_range(
        nonEmpty, [](LiveRange const &liveRange) { return &liveRange; }));
    std::sort(initialRanges.begin(), initialRanges.end(),
              [](LiveRange const *a, LiveRange const *b) { return *a < *b; });
    llvm::errs() << "\n========== Initial Live Ranges:\n";
    dumpLiveRanges(operationToIndexMap, initialRanges, function);
  }

  // 3. Coalesce (non-overlapping) live ranges where it would be beneficial
  // for tile allocation. E.g. Unify the result of an operation with its
  // operands.
  auto coalescedLiveRanges = coalesceTileLiveRanges(initialLiveRanges);

  if (dumpRanges) {
    llvm::errs() << "\n========== Coalesced Live Ranges:\n";
    dumpLiveRanges(operationToIndexMap, coalescedLiveRanges, function);
  }

  // 4. Allocate tile IDs to live ranges.
  allocateTilesToLiveRanges(coalescedLiveRanges);

  // 5. Assign the tile IDs back to the ArmSME operations.
  if (failed(assignTileIdsAndResolveTrivialConflicts(rewriter, function,
                                                     coalescedLiveRanges))) {
    return failure();
  }

  // 6. Erase trivially dead tile operations (e.g. a ZeroOp with no
  // users). This prevents the LLVM conversion needlessly inserting spills.
  eraseTriviallyDeadTileOps(rewriter, function);
  return success();
}