xref: /llvm-project/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp (revision d5746d73cedcf7a593dc4b4f2ce2465e2d45750b)
1fb54fec7SCullen Rhodes //===- TileAllocation.cpp - Allocate SME ZA tiles -------------------------===//
2fb54fec7SCullen Rhodes //
3fb54fec7SCullen Rhodes // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4fb54fec7SCullen Rhodes // See https://llvm.org/LICENSE.txt for license information.
5fb54fec7SCullen Rhodes // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6fb54fec7SCullen Rhodes //
7fb54fec7SCullen Rhodes //===----------------------------------------------------------------------===//
8fb54fec7SCullen Rhodes //
9041baf2fSBenjamin Maxwell // This transform allocates SME tiles at the 'func.func' op level for ArmSME
10041baf2fSBenjamin Maxwell // operations. It roughly implements a linear scan register allocator, similar
11041baf2fSBenjamin Maxwell // to the one outlined in [1], but with simplifications and assumptions made for
12041baf2fSBenjamin Maxwell // our use case. Note that this is a greedy allocator (so it may not always find
13041baf2fSBenjamin Maxwell // the most optimal allocation of tiles).
14041baf2fSBenjamin Maxwell //
15041baf2fSBenjamin Maxwell // The allocator operates at the CF dialect level. It is the responsibility of
16041baf2fSBenjamin Maxwell // users to ensure the IR has been lowered to CF before invoking the tile
17041baf2fSBenjamin Maxwell // allocator.
18fb54fec7SCullen Rhodes //
19fb54fec7SCullen Rhodes // The 128-bit tiles overlap with other element tiles as follows (see section
20041baf2fSBenjamin Maxwell // B2.3.2 of SME spec [2]):
21fb54fec7SCullen Rhodes //
22fb54fec7SCullen Rhodes //   Tile    Overlaps
23fb54fec7SCullen Rhodes //   ---------------------------------------------------------------------------
24fb54fec7SCullen Rhodes //   ZA0.B   ZA0.Q, ZA1.Q, ZA2.Q, ZA3.Q, ZA4.Q, ZA5.Q, ZA6.Q, ZA7.Q, ZA8.Q,
25fb54fec7SCullen Rhodes //           ZA9.Q, ZA10.Q, ZA11.Q, ZA12.Q, ZA13.Q, ZA14.Q, ZA15.Q
26fb54fec7SCullen Rhodes //   ZA0.H   ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
27fb54fec7SCullen Rhodes //   ZA1.H   ZA1.Q, ZA3.Q, ZA5.Q, ZA7.Q, ZA9.Q, ZA11.Q, ZA13.Q, ZA15.Q
28fb54fec7SCullen Rhodes //   ZA0.S   ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q
29fb54fec7SCullen Rhodes //   ZA1.S   ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
30fb54fec7SCullen Rhodes //   ZA2.S   ZA2.Q, ZA6.Q, ZA10.Q, ZA14.Q
31fb54fec7SCullen Rhodes //   ZA3.S   ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q
32fb54fec7SCullen Rhodes //   ZA0.D   ZA0.Q, ZA8.Q
33fb54fec7SCullen Rhodes //   ZA1.D   ZA1.Q, ZA9.Q
34fb54fec7SCullen Rhodes //   ZA2.D   ZA2.Q, ZA10.Q
35fb54fec7SCullen Rhodes //   ZA3.D   ZA3.Q, ZA11.Q
36fb54fec7SCullen Rhodes //   ZA4.D   ZA4.Q, ZA12.Q
37fb54fec7SCullen Rhodes //   ZA5.D   ZA5.Q, ZA13.Q
38fb54fec7SCullen Rhodes //   ZA6.D   ZA6.Q, ZA14.Q
39fb54fec7SCullen Rhodes //   ZA7.D   ZA7.Q, ZA15.Q
40fb54fec7SCullen Rhodes //
41041baf2fSBenjamin Maxwell // [1] "Linear Scan Register Allocation in the Context of SSA Form and Register
42041baf2fSBenjamin Maxwell //      Constraints" (Hanspeter Mössenböck and Michael Pfeiffer)
43041baf2fSBenjamin Maxwell //     https://link.springer.com/content/pdf/10.1007/3-540-45937-5_17.pdf
44041baf2fSBenjamin Maxwell // [2] https://developer.arm.com/documentation/ddi0616/aa
45fb54fec7SCullen Rhodes //
46fb54fec7SCullen Rhodes //===----------------------------------------------------------------------===//
47fb54fec7SCullen Rhodes 
48041baf2fSBenjamin Maxwell #include "mlir/Analysis/Liveness.h"
49b00e0c16SChristian Ulmann #include "mlir/Analysis/TopologicalSortUtils.h"
50fb54fec7SCullen Rhodes #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
51fb54fec7SCullen Rhodes #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
52041baf2fSBenjamin Maxwell #include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
53eaff02f2SBenjamin Maxwell #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
54fb54fec7SCullen Rhodes #include "mlir/Dialect/Func/IR/FuncOps.h"
55041baf2fSBenjamin Maxwell #include "mlir/Transforms/RegionUtils.h"
56041baf2fSBenjamin Maxwell #include "llvm/ADT/IntervalMap.h"
57eaff02f2SBenjamin Maxwell #include "llvm/ADT/TypeSwitch.h"
58041baf2fSBenjamin Maxwell #include <algorithm>
59fb54fec7SCullen Rhodes 
60041baf2fSBenjamin Maxwell namespace mlir::arm_sme {
61041baf2fSBenjamin Maxwell #define GEN_PASS_DEF_TESTTILEALLOCATION
62fb54fec7SCullen Rhodes #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
63041baf2fSBenjamin Maxwell } // namespace mlir::arm_sme
64fb54fec7SCullen Rhodes 
65fb54fec7SCullen Rhodes using namespace mlir;
66fb54fec7SCullen Rhodes using namespace mlir::arm_sme;
67fb54fec7SCullen Rhodes 
68fb54fec7SCullen Rhodes namespace {
69fb54fec7SCullen Rhodes 
70fb54fec7SCullen Rhodes enum class TileMask : unsigned {
71fb54fec7SCullen Rhodes   // clang-format off
72fb54fec7SCullen Rhodes   kZA0B  = 0xffff, // 1111 1111 1111 1111
73fb54fec7SCullen Rhodes 
74fb54fec7SCullen Rhodes   kZA0H  = 0xaaaa, // 1010 1010 1010 1010
75fb54fec7SCullen Rhodes   kZA1H  = 0x5555, // 0101 0101 0101 0101
76fb54fec7SCullen Rhodes 
77fb54fec7SCullen Rhodes   kZA0S  = 0x8888, // 1000 1000 1000 1000
78fb54fec7SCullen Rhodes   kZA1S  = 0x4444, // 0100 0100 0100 0100
79fb54fec7SCullen Rhodes   kZA2S  = 0x2222, // 0010 0010 0010 0010
80fb54fec7SCullen Rhodes   kZA3S  = 0x1111, // 0001 0001 0001 0001
81fb54fec7SCullen Rhodes 
82fb54fec7SCullen Rhodes   kZA0D  = 0x8080, // 1000 0000 1000 0000
83fb54fec7SCullen Rhodes   kZA1D  = 0x4040, // 0100 0000 0100 0000
84fb54fec7SCullen Rhodes   kZA2D  = 0x2020, // 0010 0000 0010 0000
85fb54fec7SCullen Rhodes   kZA3D  = 0x1010, // 0001 0000 0001 0000
86fb54fec7SCullen Rhodes   kZA4D  = 0x808,  // 0000 1000 0000 1000
87fb54fec7SCullen Rhodes   kZA5D  = 0x404,  // 0000 0100 0000 0100
88fb54fec7SCullen Rhodes   kZA6D  = 0x202,  // 0000 0010 0000 0010
89fb54fec7SCullen Rhodes   kZA7D  = 0x101,  // 0000 0001 0000 0001
90fb54fec7SCullen Rhodes 
91fb54fec7SCullen Rhodes   kZA0Q  = 0x8000, // 1000 0000 0000 0000
92fb54fec7SCullen Rhodes   kZA1Q  = 0x4000, // 0100 0000 0000 0000
93fb54fec7SCullen Rhodes   kZA2Q  = 0x2000, // 0010 0000 0000 0000
94fb54fec7SCullen Rhodes   kZA3Q  = 0x1000, // 0001 0000 0000 0000
95fb54fec7SCullen Rhodes   kZA4Q  = 0x800,  // 0000 1000 0000 0000
96fb54fec7SCullen Rhodes   kZA5Q  = 0x400,  // 0000 0100 0000 0000
97fb54fec7SCullen Rhodes   kZA6Q  = 0x200,  // 0000 0010 0000 0000
98fb54fec7SCullen Rhodes   kZA7Q  = 0x100,  // 0000 0001 0000 0000
99fb54fec7SCullen Rhodes   kZA8Q  = 0x80,   // 0000 0000 1000 0000
100fb54fec7SCullen Rhodes   kZA9Q  = 0x40,   // 0000 0000 0100 0000
101fb54fec7SCullen Rhodes   kZA10Q = 0x20,   // 0000 0000 0010 0000
102fb54fec7SCullen Rhodes   kZA11Q = 0x10,   // 0000 0000 0001 0000
103fb54fec7SCullen Rhodes   kZA12Q = 0x8,    // 0000 0000 0000 1000
104fb54fec7SCullen Rhodes   kZA13Q = 0x4,    // 0000 0000 0000 0100
105fb54fec7SCullen Rhodes   kZA14Q = 0x2,    // 0000 0000 0000 0010
106fb54fec7SCullen Rhodes   kZA15Q = 0x1,    // 0000 0000 0000 0001
107fb54fec7SCullen Rhodes 
108fb54fec7SCullen Rhodes   kNone = 0x0,     // 0000 0000 0000 0000
109fb54fec7SCullen Rhodes   // clang-format on
110fb54fec7SCullen Rhodes 
111fb54fec7SCullen Rhodes   LLVM_MARK_AS_BITMASK_ENUM(kZA0B)
112fb54fec7SCullen Rhodes };
113fb54fec7SCullen Rhodes 
114fb54fec7SCullen Rhodes /// Returns the set of masks relevant for the given type.
115eaff02f2SBenjamin Maxwell static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
116eaff02f2SBenjamin Maxwell   static constexpr std::array ZA_B_MASKS = {TileMask::kZA0B};
117eaff02f2SBenjamin Maxwell   static constexpr std::array ZA_H_MASKS = {TileMask::kZA0H, TileMask::kZA1H};
118eaff02f2SBenjamin Maxwell   static constexpr std::array ZA_S_MASKS = {TileMask::kZA0S, TileMask::kZA1S,
119eaff02f2SBenjamin Maxwell                                             TileMask::kZA2S, TileMask::kZA3S};
120eaff02f2SBenjamin Maxwell   static constexpr std::array ZA_D_MASKS = {
121fb54fec7SCullen Rhodes       TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D,
122fb54fec7SCullen Rhodes       TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D};
123eaff02f2SBenjamin Maxwell   static constexpr std::array ZA_Q_MASKS = {
124fb54fec7SCullen Rhodes       TileMask::kZA0Q,  TileMask::kZA1Q,  TileMask::kZA2Q,  TileMask::kZA3Q,
125fb54fec7SCullen Rhodes       TileMask::kZA4Q,  TileMask::kZA5Q,  TileMask::kZA6Q,  TileMask::kZA7Q,
126fb54fec7SCullen Rhodes       TileMask::kZA8Q,  TileMask::kZA9Q,  TileMask::kZA10Q, TileMask::kZA11Q,
127fb54fec7SCullen Rhodes       TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q};
128eaff02f2SBenjamin Maxwell   switch (type) {
129eaff02f2SBenjamin Maxwell   case ArmSMETileType::ZAB:
130fb54fec7SCullen Rhodes     return ZA_B_MASKS;
131eaff02f2SBenjamin Maxwell   case ArmSMETileType::ZAH:
132fb54fec7SCullen Rhodes     return ZA_H_MASKS;
133eaff02f2SBenjamin Maxwell   case ArmSMETileType::ZAS:
134fb54fec7SCullen Rhodes     return ZA_S_MASKS;
135eaff02f2SBenjamin Maxwell   case ArmSMETileType::ZAD:
136fb54fec7SCullen Rhodes     return ZA_D_MASKS;
137eaff02f2SBenjamin Maxwell   case ArmSMETileType::ZAQ:
138fb54fec7SCullen Rhodes     return ZA_Q_MASKS;
139fb54fec7SCullen Rhodes   }
140*d5746d73SFrank Schlimbach   llvm_unreachable("unknown type in getMasks");
141fb54fec7SCullen Rhodes }
142fb54fec7SCullen Rhodes 
143041baf2fSBenjamin Maxwell class TileAllocator {
144041baf2fSBenjamin Maxwell public:
145041baf2fSBenjamin Maxwell   /// Allocates and returns a tile ID. Fails if there are no tiles left.
146041baf2fSBenjamin Maxwell   FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) {
147eaff02f2SBenjamin Maxwell     auto masks = getMasks(tileType);
148eaff02f2SBenjamin Maxwell     for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
149fb54fec7SCullen Rhodes       if ((tilesInUse & tileMask) == TileMask::kNone) {
150fb54fec7SCullen Rhodes         tilesInUse |= tileMask;
151eaff02f2SBenjamin Maxwell         return tileId;
152fb54fec7SCullen Rhodes       }
153fb54fec7SCullen Rhodes     }
154eaff02f2SBenjamin Maxwell     return failure();
155fb54fec7SCullen Rhodes   }
156fb54fec7SCullen Rhodes 
157eed72d43SBenjamin Maxwell   /// Acquires a specific tile ID. Asserts the tile is initially free.
158eed72d43SBenjamin Maxwell   void acquireTileId(ArmSMETileType tileType, unsigned tileId) {
159eed72d43SBenjamin Maxwell     TileMask tileMask = getMasks(tileType)[tileId];
160eed72d43SBenjamin Maxwell     assert((tilesInUse & tileMask) == TileMask::kNone &&
161eed72d43SBenjamin Maxwell            "cannot acquire allocated tile!");
162eed72d43SBenjamin Maxwell     tilesInUse |= tileMask;
163eed72d43SBenjamin Maxwell   }
164eed72d43SBenjamin Maxwell 
165041baf2fSBenjamin Maxwell   /// Releases a previously allocated tile ID.
166041baf2fSBenjamin Maxwell   void releaseTileId(ArmSMETileType tileType, unsigned tileId) {
167041baf2fSBenjamin Maxwell     TileMask tileMask = getMasks(tileType)[tileId];
168eed72d43SBenjamin Maxwell     assert((tilesInUse & tileMask) == tileMask &&
169041baf2fSBenjamin Maxwell            "cannot release unallocated tile!");
170041baf2fSBenjamin Maxwell     tilesInUse ^= tileMask;
171eaff02f2SBenjamin Maxwell   }
172041baf2fSBenjamin Maxwell 
173041baf2fSBenjamin Maxwell   /// Allocates an in-memory tile ID.
174041baf2fSBenjamin Maxwell   unsigned allocateInMemoryTileId() {
175041baf2fSBenjamin Maxwell     // Note: We never release in-memory tile IDs. We could, which may allow
176041baf2fSBenjamin Maxwell     // reusing an allocation, but as we _never_ want to spill an SME tile this
177041baf2fSBenjamin Maxwell     // is not optimized.
178041baf2fSBenjamin Maxwell     return nextInMemoryTileId++;
179041baf2fSBenjamin Maxwell   }
180041baf2fSBenjamin Maxwell 
181041baf2fSBenjamin Maxwell private:
182041baf2fSBenjamin Maxwell   TileMask tilesInUse = TileMask::kNone;
183041baf2fSBenjamin Maxwell   unsigned nextInMemoryTileId = kInMemoryTileIdBase;
184eaff02f2SBenjamin Maxwell };
185041baf2fSBenjamin Maxwell 
186041baf2fSBenjamin Maxwell /// Add new intermediate blocks for the true and false destinations of
187041baf2fSBenjamin Maxwell /// `cf.cond_br`s that contain tile operands. This prevents spurious liveness
188041baf2fSBenjamin Maxwell /// overlaps due to copies at branches.
189041baf2fSBenjamin Maxwell ///
190041baf2fSBenjamin Maxwell ///  BEFORE:
191041baf2fSBenjamin Maxwell ///  ```mlir
192041baf2fSBenjamin Maxwell ///  cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2
193041baf2fSBenjamin Maxwell ///  ```
194041baf2fSBenjamin Maxwell ///
195041baf2fSBenjamin Maxwell ///  AFTER:
196041baf2fSBenjamin Maxwell ///  ```mlir
197041baf2fSBenjamin Maxwell ///    cf.cond_br %cond, ^bb1_copy, ^bb2_copy
198041baf2fSBenjamin Maxwell ///  ^bb1_copy:
199041baf2fSBenjamin Maxwell ///    cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
200041baf2fSBenjamin Maxwell ///  ^bb2_copy:
201041baf2fSBenjamin Maxwell ///    cf.br ^bb2
202041baf2fSBenjamin Maxwell ///  ```
203041baf2fSBenjamin Maxwell void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
204041baf2fSBenjamin Maxwell   SmallVector<cf::CondBranchOp> worklist;
205041baf2fSBenjamin Maxwell   function.walk([&](cf::CondBranchOp condBranch) {
206041baf2fSBenjamin Maxwell     if (llvm::any_of(condBranch->getOperands(), [&](Value value) {
207041baf2fSBenjamin Maxwell           return isValidSMETileVectorType(value.getType());
208041baf2fSBenjamin Maxwell         })) {
209041baf2fSBenjamin Maxwell       worklist.push_back(condBranch);
210041baf2fSBenjamin Maxwell     }
211041baf2fSBenjamin Maxwell   });
212041baf2fSBenjamin Maxwell 
213041baf2fSBenjamin Maxwell   auto insertJump = [&](Location loc, Block *source, Block *dest, auto args) {
214041baf2fSBenjamin Maxwell     rewriter.setInsertionPointToEnd(source);
215041baf2fSBenjamin Maxwell     rewriter.create<cf::BranchOp>(loc, dest, args);
216041baf2fSBenjamin Maxwell   };
217041baf2fSBenjamin Maxwell 
218041baf2fSBenjamin Maxwell   for (auto condBranch : worklist) {
219041baf2fSBenjamin Maxwell     auto loc = condBranch.getLoc();
220041baf2fSBenjamin Maxwell     Block *block = condBranch->getBlock();
221041baf2fSBenjamin Maxwell     auto newTrueBranch = rewriter.splitBlock(block, block->end());
222041baf2fSBenjamin Maxwell     auto newFalseBranch = rewriter.splitBlock(block, block->end());
223041baf2fSBenjamin Maxwell     insertJump(loc, newTrueBranch, condBranch.getTrueDest(),
224041baf2fSBenjamin Maxwell                condBranch.getTrueDestOperands());
225041baf2fSBenjamin Maxwell     insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
226041baf2fSBenjamin Maxwell                condBranch.getFalseDestOperands());
227041baf2fSBenjamin Maxwell     rewriter.modifyOpInPlace(condBranch, [&] {
228041baf2fSBenjamin Maxwell       condBranch.getFalseDestOperandsMutable().clear();
229041baf2fSBenjamin Maxwell       condBranch.getTrueDestOperandsMutable().clear();
230041baf2fSBenjamin Maxwell       condBranch.setSuccessor(newTrueBranch, 0);
231041baf2fSBenjamin Maxwell       condBranch.setSuccessor(newFalseBranch, 1);
232041baf2fSBenjamin Maxwell     });
233041baf2fSBenjamin Maxwell   }
234041baf2fSBenjamin Maxwell }
235041baf2fSBenjamin Maxwell 
236041baf2fSBenjamin Maxwell /// Inserts tile copies at `cf.br` operations.
237041baf2fSBenjamin Maxwell ///
238041baf2fSBenjamin Maxwell ///  BEFORE:
239041baf2fSBenjamin Maxwell ///  ```mlir
240041baf2fSBenjamin Maxwell ///  cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
241041baf2fSBenjamin Maxwell ///  ```
242041baf2fSBenjamin Maxwell ///
243041baf2fSBenjamin Maxwell ///  AFTER:
244041baf2fSBenjamin Maxwell ///  ```mlir
245041baf2fSBenjamin Maxwell ///  %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
246041baf2fSBenjamin Maxwell ///  cf.br ^bb1(%copy: vector<[4]x[4]xf32>)
247041baf2fSBenjamin Maxwell ///  ```
248041baf2fSBenjamin Maxwell void insertCopiesAtBranches(IRRewriter &rewriter,
249041baf2fSBenjamin Maxwell                             FunctionOpInterface function) {
250041baf2fSBenjamin Maxwell   for (Block &block : function.getBlocks()) {
251041baf2fSBenjamin Maxwell     Operation *terminator = block.getTerminator();
252041baf2fSBenjamin Maxwell     if (!isa<cf::BranchOp>(terminator))
253eaff02f2SBenjamin Maxwell       continue;
254041baf2fSBenjamin Maxwell     rewriter.setInsertionPoint(terminator);
255041baf2fSBenjamin Maxwell     for (OpOperand &operand : terminator->getOpOperands()) {
256041baf2fSBenjamin Maxwell       if (isValidSMETileVectorType(operand.get().getType())) {
257041baf2fSBenjamin Maxwell         auto copy =
258041baf2fSBenjamin Maxwell             rewriter.create<CopyTileOp>(terminator->getLoc(), operand.get());
259041baf2fSBenjamin Maxwell         rewriter.modifyOpInPlace(terminator, [&] { operand.assign(copy); });
260041baf2fSBenjamin Maxwell       }
261041baf2fSBenjamin Maxwell     }
262041baf2fSBenjamin Maxwell   }
263041baf2fSBenjamin Maxwell }
264041baf2fSBenjamin Maxwell 
265041baf2fSBenjamin Maxwell /// Prepares the IR for tile allocation. It does this by first 'splitting'
266041baf2fSBenjamin Maxwell /// conditional branches (see `splitCondBranches`), then inserting tile copies
267041baf2fSBenjamin Maxwell /// at branch operations. The conditional branches are split to prevent the
268041baf2fSBenjamin Maxwell /// copies needed for them overlapping between the true and false paths of the
269041baf2fSBenjamin Maxwell /// branch (see `tile-allocation-copies.mlir` and
270041baf2fSBenjamin Maxwell /// `tile-allocation-liveness.mlir` for examples). The copies break up live
271041baf2fSBenjamin Maxwell /// ranges and ensure when moving out of SSA the semantics of the program are
272041baf2fSBenjamin Maxwell /// preserved.
273041baf2fSBenjamin Maxwell void preprocessForTileAllocation(IRRewriter &rewriter,
274041baf2fSBenjamin Maxwell                                  FunctionOpInterface function) {
275041baf2fSBenjamin Maxwell   splitCondBranches(rewriter, function);
276041baf2fSBenjamin Maxwell   insertCopiesAtBranches(rewriter, function);
277041baf2fSBenjamin Maxwell }
278041baf2fSBenjamin Maxwell 
279041baf2fSBenjamin Maxwell /// A live range for a (collection of) tile values. A live range is built up of
280041baf2fSBenjamin Maxwell /// non-overlapping intervals [start, end) which represent parts of the program
281041baf2fSBenjamin Maxwell /// where a value in the range needs to be live (i.e. in an SME virtual tile).
282041baf2fSBenjamin Maxwell /// Note that as the intervals are non-overlapping all values within a live
283041baf2fSBenjamin Maxwell /// range can be allocated to the same SME virtual tile.
284041baf2fSBenjamin Maxwell struct LiveRange {
285041baf2fSBenjamin Maxwell   using RangeSet = llvm::IntervalMap<uint64_t, uint8_t, 16,
286041baf2fSBenjamin Maxwell                                      llvm::IntervalMapHalfOpenInfo<unsigned>>;
287041baf2fSBenjamin Maxwell   using Allocator = RangeSet::Allocator;
288041baf2fSBenjamin Maxwell   // Dummy value for the IntervalMap. Only the keys matter (the intervals).
289041baf2fSBenjamin Maxwell   static constexpr uint8_t kValidLiveRange = 0xff;
290041baf2fSBenjamin Maxwell 
291041baf2fSBenjamin Maxwell   LiveRange(Allocator &allocator)
292041baf2fSBenjamin Maxwell       : ranges(std::make_unique<RangeSet>(allocator)) {}
293041baf2fSBenjamin Maxwell 
294041baf2fSBenjamin Maxwell   /// Returns true if this range overlaps with `otherRange`.
295041baf2fSBenjamin Maxwell   bool overlaps(LiveRange const &otherRange) const {
296041baf2fSBenjamin Maxwell     return llvm::IntervalMapOverlaps<RangeSet, RangeSet>(*ranges,
297041baf2fSBenjamin Maxwell                                                          *otherRange.ranges)
298041baf2fSBenjamin Maxwell         .valid();
299041baf2fSBenjamin Maxwell   }
300041baf2fSBenjamin Maxwell 
301eed72d43SBenjamin Maxwell   /// Returns true if this range is active at `point` in the program.
302eed72d43SBenjamin Maxwell   bool overlaps(uint64_t point) const {
303eed72d43SBenjamin Maxwell     return ranges->lookup(point) == kValidLiveRange;
304eed72d43SBenjamin Maxwell   }
305eed72d43SBenjamin Maxwell 
306041baf2fSBenjamin Maxwell   /// Unions this live range with `otherRange`, aborts if the ranges overlap.
307041baf2fSBenjamin Maxwell   void unionWith(LiveRange const &otherRange) {
308041baf2fSBenjamin Maxwell     for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
309041baf2fSBenjamin Maxwell          ++it)
310041baf2fSBenjamin Maxwell       ranges->insert(it.start(), it.stop(), kValidLiveRange);
311041baf2fSBenjamin Maxwell     values.set_union(otherRange.values);
312041baf2fSBenjamin Maxwell   }
313041baf2fSBenjamin Maxwell 
314041baf2fSBenjamin Maxwell   /// Inserts an interval [start, end) for `value` into this range.
315041baf2fSBenjamin Maxwell   void insert(Value value, unsigned start, unsigned end) {
316041baf2fSBenjamin Maxwell     values.insert(value);
317041baf2fSBenjamin Maxwell     if (start != end)
318041baf2fSBenjamin Maxwell       ranges->insert(start, end, kValidLiveRange);
319041baf2fSBenjamin Maxwell   }
320041baf2fSBenjamin Maxwell 
321041baf2fSBenjamin Maxwell   bool empty() const { return ranges->empty(); }
322041baf2fSBenjamin Maxwell   unsigned start() const { return ranges->start(); }
323041baf2fSBenjamin Maxwell   unsigned end() const { return ranges->stop(); }
324041baf2fSBenjamin Maxwell   bool operator<(LiveRange const &other) const {
325041baf2fSBenjamin Maxwell     return start() < other.start();
326041baf2fSBenjamin Maxwell   }
327041baf2fSBenjamin Maxwell 
328041baf2fSBenjamin Maxwell   ArmSMETileType getTileType() const {
329041baf2fSBenjamin Maxwell     return *getSMETileType(cast<VectorType>(values[0].getType()));
330041baf2fSBenjamin Maxwell   }
331041baf2fSBenjamin Maxwell 
332041baf2fSBenjamin Maxwell   /// The values contained in this live range.
333041baf2fSBenjamin Maxwell   SetVector<Value> values;
334041baf2fSBenjamin Maxwell 
335041baf2fSBenjamin Maxwell   /// A set of (non-overlapping) intervals that mark where any value in `values`
336041baf2fSBenjamin Maxwell   /// is live.
337041baf2fSBenjamin Maxwell   std::unique_ptr<RangeSet> ranges;
338041baf2fSBenjamin Maxwell 
339041baf2fSBenjamin Maxwell   /// The tile ID (or none) assigned to this live range.
340041baf2fSBenjamin Maxwell   std::optional<unsigned> tileId;
341041baf2fSBenjamin Maxwell };
342041baf2fSBenjamin Maxwell 
343041baf2fSBenjamin Maxwell /// Number operations within a function to allow computing live ranges.
344041baf2fSBenjamin Maxwell /// Operations are numbered consecutively wihin blocks, and the blocks are
345041baf2fSBenjamin Maxwell /// topologically sorted (using forward edges). This function is only correct if
346041baf2fSBenjamin Maxwell /// all ArmSME have been converted to CF (which is asserted).
347041baf2fSBenjamin Maxwell DenseMap<Operation *, unsigned>
348041baf2fSBenjamin Maxwell generateOperationNumbering(FunctionOpInterface function) {
349041baf2fSBenjamin Maxwell   unsigned index = 0;
350041baf2fSBenjamin Maxwell   SetVector<Block *> blocks =
3519e39a0c7SChristian Ulmann       getBlocksSortedByDominance(function.getFunctionBody());
352041baf2fSBenjamin Maxwell   DenseMap<Operation *, unsigned> operationToIndexMap;
353041baf2fSBenjamin Maxwell   for (Block *block : blocks) {
354041baf2fSBenjamin Maxwell     index++; // We want block args to have their own number.
355041baf2fSBenjamin Maxwell     for (Operation &op : block->getOperations()) {
356041baf2fSBenjamin Maxwell #ifndef NDEBUG
357041baf2fSBenjamin Maxwell       op.walk([&](ArmSMETileOpInterface nestedOp) {
358041baf2fSBenjamin Maxwell         assert(&op == nestedOp.getOperation() &&
359041baf2fSBenjamin Maxwell                "ArmSME tile allocation does not support nested regions");
360041baf2fSBenjamin Maxwell       });
361041baf2fSBenjamin Maxwell #endif
362041baf2fSBenjamin Maxwell       operationToIndexMap.try_emplace(&op, index++);
363041baf2fSBenjamin Maxwell     }
364041baf2fSBenjamin Maxwell   }
365041baf2fSBenjamin Maxwell   return operationToIndexMap;
366041baf2fSBenjamin Maxwell }
367041baf2fSBenjamin Maxwell 
368041baf2fSBenjamin Maxwell /// Gather live ranges for SME tiles from the MLIR liveness analysis.
369041baf2fSBenjamin Maxwell DenseMap<Value, LiveRange>
370041baf2fSBenjamin Maxwell gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
371041baf2fSBenjamin Maxwell                      LiveRange::Allocator &liveRangeAllocator,
372041baf2fSBenjamin Maxwell                      Liveness &liveness, FunctionOpInterface function) {
373041baf2fSBenjamin Maxwell   assert(!operationToIndexMap.empty() && "expected operation numbering");
374041baf2fSBenjamin Maxwell   DenseMap<Value, LiveRange> liveRanges;
375041baf2fSBenjamin Maxwell   /// Defines or updates a live range for an SME tile value. Live-ins may update
376041baf2fSBenjamin Maxwell   /// an existing live range (rather than define a new one). Note: If
377041baf2fSBenjamin Maxwell   /// `liveAtBlockEntry` is true then `firstUseOrDef` is the first operation in
378041baf2fSBenjamin Maxwell   /// the block.
379041baf2fSBenjamin Maxwell   auto defineOrUpdateValueLiveRange = [&](Value value, Operation *firstUseOrDef,
380041baf2fSBenjamin Maxwell                                           LivenessBlockInfo const &livenessInfo,
381041baf2fSBenjamin Maxwell                                           bool liveAtBlockEntry = false) {
382041baf2fSBenjamin Maxwell     if (!isValidSMETileVectorType(value.getType()))
383041baf2fSBenjamin Maxwell       return;
384041baf2fSBenjamin Maxwell     // Find or create a live range for `value`.
385041baf2fSBenjamin Maxwell     auto [it, _] = liveRanges.try_emplace(value, liveRangeAllocator);
386041baf2fSBenjamin Maxwell     LiveRange &valueLiveRange = it->second;
387041baf2fSBenjamin Maxwell     auto lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
388041baf2fSBenjamin Maxwell     // Add the interval [firstUseOrDef, lastUseInBlock) to the live range.
389041baf2fSBenjamin Maxwell     unsigned startOpIdx =
390041baf2fSBenjamin Maxwell         operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0);
391041baf2fSBenjamin Maxwell     unsigned endOpIdx = operationToIndexMap.at(lastUseInBlock);
392041baf2fSBenjamin Maxwell     valueLiveRange.insert(value, startOpIdx, endOpIdx);
393041baf2fSBenjamin Maxwell   };
394041baf2fSBenjamin Maxwell 
395041baf2fSBenjamin Maxwell   for (Block &block : function.getBlocks()) {
396041baf2fSBenjamin Maxwell     LivenessBlockInfo const *livenessInfo = liveness.getLiveness(&block);
397041baf2fSBenjamin Maxwell     // Handle block arguments:
398041baf2fSBenjamin Maxwell     for (Value argument : block.getArguments())
399041baf2fSBenjamin Maxwell       defineOrUpdateValueLiveRange(argument, &block.front(), *livenessInfo,
400041baf2fSBenjamin Maxwell                                    /*liveAtBlockEntry=*/true);
401041baf2fSBenjamin Maxwell     // Handle live-ins:
402041baf2fSBenjamin Maxwell     for (Value liveIn : livenessInfo->in())
403041baf2fSBenjamin Maxwell       defineOrUpdateValueLiveRange(liveIn, &block.front(), *livenessInfo,
404041baf2fSBenjamin Maxwell                                    /*liveAtBlockEntry=*/true);
405041baf2fSBenjamin Maxwell     // Handle new definitions:
406041baf2fSBenjamin Maxwell     for (Operation &op : block) {
407041baf2fSBenjamin Maxwell       for (Value result : op.getResults())
408041baf2fSBenjamin Maxwell         defineOrUpdateValueLiveRange(result, &op, *livenessInfo);
409041baf2fSBenjamin Maxwell     }
410041baf2fSBenjamin Maxwell   }
411041baf2fSBenjamin Maxwell 
412041baf2fSBenjamin Maxwell   return liveRanges;
413041baf2fSBenjamin Maxwell }
414041baf2fSBenjamin Maxwell 
415041baf2fSBenjamin Maxwell /// Iterate over all predecessor tile values to a (tile) block argument.
416041baf2fSBenjamin Maxwell static void forEachPredecessorTileValue(BlockArgument blockArg,
417041baf2fSBenjamin Maxwell                                         function_ref<void(Value)> callback) {
418041baf2fSBenjamin Maxwell   Block *block = blockArg.getOwner();
419041baf2fSBenjamin Maxwell   unsigned argNumber = blockArg.getArgNumber();
420041baf2fSBenjamin Maxwell   for (Block *pred : block->getPredecessors()) {
421041baf2fSBenjamin Maxwell     TypeSwitch<Operation *>(pred->getTerminator())
422041baf2fSBenjamin Maxwell         .Case<cf::BranchOp>([&](auto branch) {
423041baf2fSBenjamin Maxwell           Value predecessorOperand = branch.getDestOperands()[argNumber];
424041baf2fSBenjamin Maxwell           callback(predecessorOperand);
425eaff02f2SBenjamin Maxwell         })
426041baf2fSBenjamin Maxwell         .Case<cf::CondBranchOp>([&](auto condBranch) {
427041baf2fSBenjamin Maxwell           if (condBranch.getFalseDest() == block) {
428041baf2fSBenjamin Maxwell             Value predecessorOperand =
429041baf2fSBenjamin Maxwell                 condBranch.getFalseDestOperands()[argNumber];
430041baf2fSBenjamin Maxwell             callback(predecessorOperand);
431041baf2fSBenjamin Maxwell           }
432041baf2fSBenjamin Maxwell           if (condBranch.getTrueDest() == block) {
433041baf2fSBenjamin Maxwell             Value predecessorOperand =
434041baf2fSBenjamin Maxwell                 condBranch.getTrueDestOperands()[argNumber];
435041baf2fSBenjamin Maxwell             callback(predecessorOperand);
436041baf2fSBenjamin Maxwell           }
437eaff02f2SBenjamin Maxwell         });
438eaff02f2SBenjamin Maxwell   }
439eaff02f2SBenjamin Maxwell }
440041baf2fSBenjamin Maxwell 
441041baf2fSBenjamin Maxwell /// Coalesce live ranges where it would prevent unnecessary tile moves.
442041baf2fSBenjamin Maxwell SmallVector<LiveRange *>
443041baf2fSBenjamin Maxwell coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
444041baf2fSBenjamin Maxwell   DenseMap<Value, LiveRange *> liveRanges;
445041baf2fSBenjamin Maxwell   for (auto &[value, liveRange] : initialLiveRanges) {
446041baf2fSBenjamin Maxwell     liveRanges.insert({value, &liveRange});
447041baf2fSBenjamin Maxwell   }
448041baf2fSBenjamin Maxwell 
449041baf2fSBenjamin Maxwell   // Merge the live ranges of values `a` and `b` into one (if they do not
450041baf2fSBenjamin Maxwell   // overlap). After this, the values `a` and `b` will both point to the same
451041baf2fSBenjamin Maxwell   // live range (which will contain multiple values).
452041baf2fSBenjamin Maxwell   auto mergeValuesIfNonOverlapping = [&](Value a, Value b) {
453041baf2fSBenjamin Maxwell     LiveRange *aLiveRange = liveRanges.at(a);
454041baf2fSBenjamin Maxwell     LiveRange *bLiveRange = liveRanges.at(b);
455041baf2fSBenjamin Maxwell     if (aLiveRange != bLiveRange && !aLiveRange->overlaps(*bLiveRange)) {
456041baf2fSBenjamin Maxwell       aLiveRange->unionWith(*bLiveRange);
457041baf2fSBenjamin Maxwell       for (Value value : bLiveRange->values)
458041baf2fSBenjamin Maxwell         liveRanges[value] = aLiveRange;
459041baf2fSBenjamin Maxwell     }
460041baf2fSBenjamin Maxwell   };
461041baf2fSBenjamin Maxwell 
462041baf2fSBenjamin Maxwell   // Merge the live ranges of new definitions with their tile operands.
463041baf2fSBenjamin Maxwell   auto unifyDefinitionsWithOperands = [&](Value value) {
464041baf2fSBenjamin Maxwell     auto armSMEOp = value.getDefiningOp<ArmSMETileOpInterface>();
465041baf2fSBenjamin Maxwell     if (!armSMEOp)
466041baf2fSBenjamin Maxwell       return;
467041baf2fSBenjamin Maxwell     for (auto operand : armSMEOp->getOperands()) {
468041baf2fSBenjamin Maxwell       if (isValidSMETileVectorType(operand.getType()))
469041baf2fSBenjamin Maxwell         mergeValuesIfNonOverlapping(value, operand);
470041baf2fSBenjamin Maxwell     }
471041baf2fSBenjamin Maxwell   };
472041baf2fSBenjamin Maxwell 
473041baf2fSBenjamin Maxwell   // Merge the live ranges of block arguments with their predecessors.
474041baf2fSBenjamin Maxwell   auto unifyBlockArgumentsWithPredecessors = [&](Value value) {
475041baf2fSBenjamin Maxwell     auto blockArg = dyn_cast<BlockArgument>(value);
476041baf2fSBenjamin Maxwell     if (!blockArg)
477041baf2fSBenjamin Maxwell       return;
478041baf2fSBenjamin Maxwell     forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
479041baf2fSBenjamin Maxwell       mergeValuesIfNonOverlapping(blockArg, predecessorTile);
480041baf2fSBenjamin Maxwell     });
481041baf2fSBenjamin Maxwell   };
482041baf2fSBenjamin Maxwell 
483041baf2fSBenjamin Maxwell   auto applyRule = [&](auto rule) {
484041baf2fSBenjamin Maxwell     llvm::for_each(llvm::make_first_range(initialLiveRanges), rule);
485041baf2fSBenjamin Maxwell   };
486041baf2fSBenjamin Maxwell 
487041baf2fSBenjamin Maxwell   // Unify as many live ranges as we can. This prevents unnecessary moves.
488041baf2fSBenjamin Maxwell   applyRule(unifyBlockArgumentsWithPredecessors);
489041baf2fSBenjamin Maxwell   applyRule(unifyDefinitionsWithOperands);
490041baf2fSBenjamin Maxwell 
491041baf2fSBenjamin Maxwell   // Remove duplicate live range entries.
492041baf2fSBenjamin Maxwell   SetVector<LiveRange *> uniqueLiveRanges;
493041baf2fSBenjamin Maxwell   for (auto [_, liveRange] : liveRanges) {
494041baf2fSBenjamin Maxwell     if (!liveRange->empty())
495041baf2fSBenjamin Maxwell       uniqueLiveRanges.insert(liveRange);
496041baf2fSBenjamin Maxwell   }
497041baf2fSBenjamin Maxwell 
498041baf2fSBenjamin Maxwell   // Sort the new live ranges by starting point (ready for tile allocation).
499041baf2fSBenjamin Maxwell   auto coalescedLiveRanges = uniqueLiveRanges.takeVector();
500041baf2fSBenjamin Maxwell   std::sort(coalescedLiveRanges.begin(), coalescedLiveRanges.end(),
501041baf2fSBenjamin Maxwell             [](LiveRange *a, LiveRange *b) { return *a < *b; });
502041baf2fSBenjamin Maxwell   return std::move(coalescedLiveRanges);
503041baf2fSBenjamin Maxwell }
504041baf2fSBenjamin Maxwell 
505eed72d43SBenjamin Maxwell /// Choose a live range to spill (via some heuristics). This picks either a live
506eed72d43SBenjamin Maxwell /// range from `overlappingRanges`, or the new live range `newRange`.
507eed72d43SBenjamin Maxwell template <typename OverlappingRangesIterator>
508eed72d43SBenjamin Maxwell LiveRange *
509eed72d43SBenjamin Maxwell chooseSpillUsingHeuristics(OverlappingRangesIterator overlappingRanges,
510041baf2fSBenjamin Maxwell                            LiveRange *newRange) {
511041baf2fSBenjamin Maxwell   // Heuristic: Spill trivially copyable operations (usually free).
512eed72d43SBenjamin Maxwell   auto isTrivialSpill = [&](LiveRange &allocatedRange) {
513eed72d43SBenjamin Maxwell     return isTileTypeGreaterOrEqual(allocatedRange.getTileType(),
514041baf2fSBenjamin Maxwell                                     newRange->getTileType()) &&
515eed72d43SBenjamin Maxwell            allocatedRange.values.size() == 1 &&
516041baf2fSBenjamin Maxwell            isTriviallyCloneableTileOp(
517eed72d43SBenjamin Maxwell                allocatedRange.values[0].getDefiningOp<ArmSMETileOpInterface>());
518041baf2fSBenjamin Maxwell   };
519eed72d43SBenjamin Maxwell   if (isTrivialSpill(*newRange))
520041baf2fSBenjamin Maxwell     return newRange;
521eed72d43SBenjamin Maxwell   auto trivialSpill = llvm::find_if(overlappingRanges, isTrivialSpill);
522eed72d43SBenjamin Maxwell   if (trivialSpill != overlappingRanges.end())
523eed72d43SBenjamin Maxwell     return &*trivialSpill;
524041baf2fSBenjamin Maxwell 
525041baf2fSBenjamin Maxwell   // Heuristic: Spill the range that ends last (with a compatible tile type).
526eed72d43SBenjamin Maxwell   auto isSmallerTileTypeOrEndsEarlier = [](LiveRange &a, LiveRange &b) {
527eed72d43SBenjamin Maxwell     return !isTileTypeGreaterOrEqual(a.getTileType(), b.getTileType()) ||
528eed72d43SBenjamin Maxwell            a.end() < b.end();
529041baf2fSBenjamin Maxwell   };
530eed72d43SBenjamin Maxwell   LiveRange &latestEndingLiveRange =
531eed72d43SBenjamin Maxwell       *std::max_element(overlappingRanges.begin(), overlappingRanges.end(),
532eed72d43SBenjamin Maxwell                         isSmallerTileTypeOrEndsEarlier);
533eed72d43SBenjamin Maxwell   if (!isSmallerTileTypeOrEndsEarlier(latestEndingLiveRange, *newRange))
534eed72d43SBenjamin Maxwell     return &latestEndingLiveRange;
535041baf2fSBenjamin Maxwell   return newRange;
536041baf2fSBenjamin Maxwell }
537041baf2fSBenjamin Maxwell 
538041baf2fSBenjamin Maxwell /// Greedily allocate tile IDs to live ranges. Spill using simple heuristics.
539041baf2fSBenjamin Maxwell void allocateTilesToLiveRanges(
540041baf2fSBenjamin Maxwell     ArrayRef<LiveRange *> liveRangesSortedByStartPoint) {
541041baf2fSBenjamin Maxwell   TileAllocator tileAllocator;
542eed72d43SBenjamin Maxwell   // `activeRanges` = Live ranges that need to be in a tile at the
543eed72d43SBenjamin Maxwell   // `currentPoint` in the program.
544041baf2fSBenjamin Maxwell   SetVector<LiveRange *> activeRanges;
545eed72d43SBenjamin Maxwell   // `inactiveRanges` = Live ranges that _do not_ need to be in a tile
546eed72d43SBenjamin Maxwell   // at the `currentPoint` in the program but could become active again later.
547eed72d43SBenjamin Maxwell   // An inactive section of a live range can be seen as a 'hole' in the live
548eed72d43SBenjamin Maxwell   // range, where it is possible to reuse the live range's tile ID _before_ it
549eed72d43SBenjamin Maxwell   // has ended. By identifying 'holes', the allocator can reuse tiles more
550eed72d43SBenjamin Maxwell   // often, which helps avoid costly tile spills.
551eed72d43SBenjamin Maxwell   SetVector<LiveRange *> inactiveRanges;
552041baf2fSBenjamin Maxwell   for (LiveRange *nextRange : liveRangesSortedByStartPoint) {
553eed72d43SBenjamin Maxwell     auto currentPoint = nextRange->start();
554eed72d43SBenjamin Maxwell     // 1. Update the `activeRanges` at `currentPoint`.
555041baf2fSBenjamin Maxwell     activeRanges.remove_if([&](LiveRange *activeRange) {
556eed72d43SBenjamin Maxwell       // Check for live ranges that have expired.
557eed72d43SBenjamin Maxwell       if (activeRange->end() <= currentPoint) {
558041baf2fSBenjamin Maxwell         tileAllocator.releaseTileId(activeRange->getTileType(),
559041baf2fSBenjamin Maxwell                                     *activeRange->tileId);
560041baf2fSBenjamin Maxwell         return true;
561041baf2fSBenjamin Maxwell       }
562eed72d43SBenjamin Maxwell       // Check for live ranges that have become inactive.
563eed72d43SBenjamin Maxwell       if (!activeRange->overlaps(currentPoint)) {
564eed72d43SBenjamin Maxwell         tileAllocator.releaseTileId(activeRange->getTileType(),
565eed72d43SBenjamin Maxwell                                     *activeRange->tileId);
566eed72d43SBenjamin Maxwell         inactiveRanges.insert(activeRange);
567eed72d43SBenjamin Maxwell         return true;
568eed72d43SBenjamin Maxwell       }
569eed72d43SBenjamin Maxwell       return false;
570eed72d43SBenjamin Maxwell     });
571eed72d43SBenjamin Maxwell     // 2. Update the `inactiveRanges` at `currentPoint`.
572eed72d43SBenjamin Maxwell     inactiveRanges.remove_if([&](LiveRange *inactiveRange) {
573eed72d43SBenjamin Maxwell       // Check for live ranges that have expired.
574eed72d43SBenjamin Maxwell       if (inactiveRange->end() <= currentPoint) {
575eed72d43SBenjamin Maxwell         return true;
576eed72d43SBenjamin Maxwell       }
577eed72d43SBenjamin Maxwell       // Check for live ranges that have become active.
578eed72d43SBenjamin Maxwell       if (inactiveRange->overlaps(currentPoint)) {
579eed72d43SBenjamin Maxwell         tileAllocator.acquireTileId(inactiveRange->getTileType(),
580eed72d43SBenjamin Maxwell                                     *inactiveRange->tileId);
581eed72d43SBenjamin Maxwell         activeRanges.insert(inactiveRange);
582eed72d43SBenjamin Maxwell         return true;
583eed72d43SBenjamin Maxwell       }
584041baf2fSBenjamin Maxwell       return false;
585041baf2fSBenjamin Maxwell     });
586041baf2fSBenjamin Maxwell 
587eed72d43SBenjamin Maxwell     // 3. Collect inactive live ranges that overlap with the new live range.
588eed72d43SBenjamin Maxwell     // Note: The overlap checks in steps 1 and 2 only look at the `currentPoint`
589eed72d43SBenjamin Maxwell     // whereas this checks if there is an overlap at any future point too.
590eed72d43SBenjamin Maxwell     SmallVector<LiveRange *> overlappingInactiveRanges;
591eed72d43SBenjamin Maxwell     for (LiveRange *inactiveRange : inactiveRanges) {
592eed72d43SBenjamin Maxwell       if (inactiveRange->overlaps(*nextRange)) {
593eed72d43SBenjamin Maxwell         // We need to reserve the tile IDs of overlapping inactive ranges to
594eed72d43SBenjamin Maxwell         // prevent two (overlapping) live ranges from getting the same tile ID.
595eed72d43SBenjamin Maxwell         tileAllocator.acquireTileId(inactiveRange->getTileType(),
596eed72d43SBenjamin Maxwell                                     *inactiveRange->tileId);
597eed72d43SBenjamin Maxwell         overlappingInactiveRanges.push_back(inactiveRange);
598eed72d43SBenjamin Maxwell       }
599eed72d43SBenjamin Maxwell     }
600eed72d43SBenjamin Maxwell 
601eed72d43SBenjamin Maxwell     // 4. Allocate a tile ID to `nextRange`.
602041baf2fSBenjamin Maxwell     auto rangeTileType = nextRange->getTileType();
603041baf2fSBenjamin Maxwell     auto tileId = tileAllocator.allocateTileId(rangeTileType);
604041baf2fSBenjamin Maxwell     if (succeeded(tileId)) {
605041baf2fSBenjamin Maxwell       nextRange->tileId = *tileId;
606041baf2fSBenjamin Maxwell     } else {
607eed72d43SBenjamin Maxwell       // Create an iterator over all overlapping live ranges.
608eed72d43SBenjamin Maxwell       auto allOverlappingRanges = llvm::concat<LiveRange>(
609eed72d43SBenjamin Maxwell           llvm::make_pointee_range(activeRanges.getArrayRef()),
610eed72d43SBenjamin Maxwell           llvm::make_pointee_range(overlappingInactiveRanges));
611eed72d43SBenjamin Maxwell       // Choose an overlapping live range to spill.
612041baf2fSBenjamin Maxwell       LiveRange *rangeToSpill =
613eed72d43SBenjamin Maxwell           chooseSpillUsingHeuristics(allOverlappingRanges, nextRange);
614041baf2fSBenjamin Maxwell       if (rangeToSpill != nextRange) {
615eed72d43SBenjamin Maxwell         // Spill an (in)active live range (so release its tile ID first).
616041baf2fSBenjamin Maxwell         tileAllocator.releaseTileId(rangeToSpill->getTileType(),
617041baf2fSBenjamin Maxwell                                     *rangeToSpill->tileId);
618041baf2fSBenjamin Maxwell         // This will always succeed after a spill (of an active live range).
619041baf2fSBenjamin Maxwell         nextRange->tileId = *tileAllocator.allocateTileId(rangeTileType);
620eed72d43SBenjamin Maxwell         // Remove the live range from the active/inactive sets.
621eed72d43SBenjamin Maxwell         if (!activeRanges.remove(rangeToSpill)) {
622eed72d43SBenjamin Maxwell           bool removed = inactiveRanges.remove(rangeToSpill);
623eed72d43SBenjamin Maxwell           assert(removed && "expected a range to be removed!");
62499faa038SKazu Hirata           (void)removed;
625eed72d43SBenjamin Maxwell         }
626041baf2fSBenjamin Maxwell       }
627041baf2fSBenjamin Maxwell       rangeToSpill->tileId = tileAllocator.allocateInMemoryTileId();
628041baf2fSBenjamin Maxwell     }
629041baf2fSBenjamin Maxwell 
630eed72d43SBenjamin Maxwell     // 5. Insert the live range into the active ranges.
631041baf2fSBenjamin Maxwell     if (nextRange->tileId < kInMemoryTileIdBase)
632041baf2fSBenjamin Maxwell       activeRanges.insert(nextRange);
633eed72d43SBenjamin Maxwell 
634eed72d43SBenjamin Maxwell     // 6. Release tiles reserved for inactive live ranges (in step 3).
635eed72d43SBenjamin Maxwell     for (LiveRange *range : overlappingInactiveRanges) {
636eed72d43SBenjamin Maxwell       if (*range->tileId < kInMemoryTileIdBase)
637eed72d43SBenjamin Maxwell         tileAllocator.releaseTileId(range->getTileType(), *range->tileId);
638eed72d43SBenjamin Maxwell     }
639041baf2fSBenjamin Maxwell   }
640041baf2fSBenjamin Maxwell }
641041baf2fSBenjamin Maxwell 
642041baf2fSBenjamin Maxwell /// Assigns a tile ID to an MLIR value.
643041baf2fSBenjamin Maxwell void assignTileIdToValue(IRRewriter &rewriter, Value value,
644041baf2fSBenjamin Maxwell                          IntegerAttr tileIdAttr) {
645041baf2fSBenjamin Maxwell   if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>())
646041baf2fSBenjamin Maxwell     rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
647041baf2fSBenjamin Maxwell   for (Operation *user : value.getUsers()) {
648041baf2fSBenjamin Maxwell     if (auto tileOp = dyn_cast<ArmSMETileOpInterface>(user)) {
649041baf2fSBenjamin Maxwell       // Ensure ArmSME ops that don't produce a value still get a tile ID.
650041baf2fSBenjamin Maxwell       if (!hasTileResult(tileOp))
651041baf2fSBenjamin Maxwell         rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
652041baf2fSBenjamin Maxwell     }
653041baf2fSBenjamin Maxwell   }
654041baf2fSBenjamin Maxwell }
655041baf2fSBenjamin Maxwell 
656041baf2fSBenjamin Maxwell /// Assign tile IDs back to IR and attempt to resolve trivial tile ID conflicts.
657041baf2fSBenjamin Maxwell LogicalResult assignTileIdsAndResolveTrivialConflicts(
658041baf2fSBenjamin Maxwell     IRRewriter &rewriter, FunctionOpInterface function,
659041baf2fSBenjamin Maxwell     ArrayRef<LiveRange *> allocatedLiveRanges) {
660041baf2fSBenjamin Maxwell   for (LiveRange const *liveRange : allocatedLiveRanges) {
661041baf2fSBenjamin Maxwell     auto tileIdAttr = rewriter.getI32IntegerAttr(*liveRange->tileId);
662041baf2fSBenjamin Maxwell     auto isAllocatedToSameTile = [&](Value value) {
663041baf2fSBenjamin Maxwell       if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
664041baf2fSBenjamin Maxwell           tileOp && tileOp.getTileId() == tileIdAttr)
665041baf2fSBenjamin Maxwell         return true;
666041baf2fSBenjamin Maxwell       return liveRange->values.contains(value);
667041baf2fSBenjamin Maxwell     };
668041baf2fSBenjamin Maxwell 
669041baf2fSBenjamin Maxwell     /// Eliminates copies where the operand has the same tile ID.
670041baf2fSBenjamin Maxwell     auto foldRedundantCopies = [&](Value value) -> LogicalResult {
671041baf2fSBenjamin Maxwell       auto copyOp = value.getDefiningOp<CopyTileOp>();
672041baf2fSBenjamin Maxwell       if (!copyOp || !isAllocatedToSameTile(copyOp.getTile()))
673fb54fec7SCullen Rhodes         return failure();
674041baf2fSBenjamin Maxwell       rewriter.replaceAllUsesWith(copyOp, copyOp.getTile());
675041baf2fSBenjamin Maxwell       return success();
6765417a5feSBenjamin Maxwell     };
6775417a5feSBenjamin Maxwell 
678041baf2fSBenjamin Maxwell     /// Validates each predecessor to a tile block argument has been assigned
679041baf2fSBenjamin Maxwell     /// the same tile ID.
680041baf2fSBenjamin Maxwell     auto validateBlockArguments = [&](Value value) {
681041baf2fSBenjamin Maxwell       auto blockArg = dyn_cast<BlockArgument>(value);
682041baf2fSBenjamin Maxwell       if (!blockArg) {
683041baf2fSBenjamin Maxwell         // Not a block argument (nothing to validate).
684fb54fec7SCullen Rhodes         return success();
685fb54fec7SCullen Rhodes       }
686041baf2fSBenjamin Maxwell       bool tileMismatch = false;
687041baf2fSBenjamin Maxwell       forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
688041baf2fSBenjamin Maxwell         if (tileMismatch)
689041baf2fSBenjamin Maxwell           return;
690041baf2fSBenjamin Maxwell         if (!isAllocatedToSameTile(predecessorTile)) {
691041baf2fSBenjamin Maxwell           blockArg.getOwner()->getParentOp()->emitOpError(
692041baf2fSBenjamin Maxwell               "block argument not allocated to the same SME virtial tile as "
693041baf2fSBenjamin Maxwell               "predecessors");
694041baf2fSBenjamin Maxwell           tileMismatch = true;
695041baf2fSBenjamin Maxwell         }
696041baf2fSBenjamin Maxwell       });
697041baf2fSBenjamin Maxwell       return success(/*isSuccess=*/!tileMismatch);
698fb54fec7SCullen Rhodes     };
699fb54fec7SCullen Rhodes 
700041baf2fSBenjamin Maxwell     /// Attempts to resolve (trivial) tile ID conflicts.
701041baf2fSBenjamin Maxwell     auto resolveTrivialTileConflicts = [&](Value value) -> LogicalResult {
702041baf2fSBenjamin Maxwell       auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
703041baf2fSBenjamin Maxwell       OpOperand *tileOperand = getTileOpOperand(tileOp);
704041baf2fSBenjamin Maxwell       if (!tileOperand || isAllocatedToSameTile(tileOperand->get())) {
705041baf2fSBenjamin Maxwell         // Operand already allocated to the correct tile.
706041baf2fSBenjamin Maxwell         // No conflict to resolve.
707041baf2fSBenjamin Maxwell         return success();
708fb54fec7SCullen Rhodes       }
709041baf2fSBenjamin Maxwell       auto operandTileOp =
710041baf2fSBenjamin Maxwell           tileOperand->get().getDefiningOp<ArmSMETileOpInterface>();
711041baf2fSBenjamin Maxwell       if (!isTriviallyCloneableTileOp(operandTileOp)) {
712041baf2fSBenjamin Maxwell         auto error =
713041baf2fSBenjamin Maxwell             tileOp.emitOpError("tile operand allocated to different SME "
714041baf2fSBenjamin Maxwell                                "virtial tile (move required)");
715041baf2fSBenjamin Maxwell         error.attachNote(tileOperand->get().getLoc())
716041baf2fSBenjamin Maxwell             << "tile operand is: " << tileOperand->get();
717041baf2fSBenjamin Maxwell         return error;
718041baf2fSBenjamin Maxwell       }
719041baf2fSBenjamin Maxwell       // Cloning prevents a move/spill (though may require recomputation).
720041baf2fSBenjamin Maxwell       rewriter.setInsertionPoint(tileOp);
721041baf2fSBenjamin Maxwell       auto clonedOp = operandTileOp.clone();
722041baf2fSBenjamin Maxwell       rewriter.modifyOpInPlace(clonedOp,
723041baf2fSBenjamin Maxwell                                [&] { clonedOp.setTileId(tileOp.getTileId()); });
724041baf2fSBenjamin Maxwell       rewriter.insert(clonedOp);
725041baf2fSBenjamin Maxwell       if (isa<CopyTileOp>(tileOp)) {
726041baf2fSBenjamin Maxwell         rewriter.replaceAllUsesWith(tileOp->getResult(0),
727041baf2fSBenjamin Maxwell                                     clonedOp->getResult(0));
728041baf2fSBenjamin Maxwell       } else {
729041baf2fSBenjamin Maxwell         rewriter.modifyOpInPlace(
730041baf2fSBenjamin Maxwell             tileOp, [&] { tileOperand->assign(clonedOp->getResult(0)); });
731041baf2fSBenjamin Maxwell       }
732041baf2fSBenjamin Maxwell       return success();
733041baf2fSBenjamin Maxwell     };
734041baf2fSBenjamin Maxwell 
735041baf2fSBenjamin Maxwell     for (Value value : liveRange->values) {
736041baf2fSBenjamin Maxwell       // 1. Assign the tile ID to the value.
737041baf2fSBenjamin Maxwell       assignTileIdToValue(rewriter, value, tileIdAttr);
738041baf2fSBenjamin Maxwell 
739041baf2fSBenjamin Maxwell       // 2. Attempt to eliminate redundant tile copies.
740041baf2fSBenjamin Maxwell       if (succeeded(foldRedundantCopies(value)))
741041baf2fSBenjamin Maxwell         continue;
742041baf2fSBenjamin Maxwell 
743041baf2fSBenjamin Maxwell       // 3. Validate tile block arguments.
744041baf2fSBenjamin Maxwell       if (failed(validateBlockArguments(value)))
745041baf2fSBenjamin Maxwell         return failure();
746041baf2fSBenjamin Maxwell 
747041baf2fSBenjamin Maxwell       // 4. Attempt to resolve (trivial) tile ID conflicts.
748041baf2fSBenjamin Maxwell       if (failed(resolveTrivialTileConflicts(value)))
749041baf2fSBenjamin Maxwell         return failure();
750041baf2fSBenjamin Maxwell     }
751041baf2fSBenjamin Maxwell   }
752041baf2fSBenjamin Maxwell   return success();
753041baf2fSBenjamin Maxwell }
754041baf2fSBenjamin Maxwell 
755041baf2fSBenjamin Maxwell /// Prints live ranges alongside operation names for debugging.
756041baf2fSBenjamin Maxwell void dumpLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
757041baf2fSBenjamin Maxwell                     ArrayRef<LiveRange const *> liveRanges,
758041baf2fSBenjamin Maxwell                     FunctionOpInterface function) {
759041baf2fSBenjamin Maxwell   llvm::errs() << "SME Tile Liveness: @" << function.getName()
760041baf2fSBenjamin Maxwell                << "\nKey:\nS - Start\nE - End\n| - Live\n";
761041baf2fSBenjamin Maxwell   for (auto [blockIdx, block] : llvm::enumerate(function.getBlocks())) {
762041baf2fSBenjamin Maxwell     llvm::errs() << "^bb" << blockIdx << ":\n";
763041baf2fSBenjamin Maxwell     for (Operation &op : block.getOperations()) {
764041baf2fSBenjamin Maxwell       unsigned operationIndex = operationToIndexMap.at(&op);
765041baf2fSBenjamin Maxwell       for (LiveRange const *range : liveRanges) {
766041baf2fSBenjamin Maxwell         char liveness = ' ';
767041baf2fSBenjamin Maxwell         for (auto it = range->ranges->begin(); it != range->ranges->end();
768041baf2fSBenjamin Maxwell              ++it) {
769041baf2fSBenjamin Maxwell           if (it.start() == operationIndex)
770041baf2fSBenjamin Maxwell             liveness = (liveness == 'E' ? '|' : 'S');
771041baf2fSBenjamin Maxwell           else if (it.stop() == operationIndex)
772041baf2fSBenjamin Maxwell             liveness = (liveness == 'S' ? '|' : 'E');
773041baf2fSBenjamin Maxwell           else if (operationIndex >= it.start() && operationIndex < it.stop())
774041baf2fSBenjamin Maxwell             liveness = '|';
775041baf2fSBenjamin Maxwell         }
776041baf2fSBenjamin Maxwell         llvm::errs() << liveness;
777041baf2fSBenjamin Maxwell       }
778041baf2fSBenjamin Maxwell       llvm::errs() << ' ' << op.getName() << '\n';
779041baf2fSBenjamin Maxwell     }
780041baf2fSBenjamin Maxwell   }
781041baf2fSBenjamin Maxwell   llvm::errs() << "==========\n";
782041baf2fSBenjamin Maxwell }
783041baf2fSBenjamin Maxwell 
784041baf2fSBenjamin Maxwell struct TestTileAllocationPass
785041baf2fSBenjamin Maxwell     : public arm_sme::impl::TestTileAllocationBase<TestTileAllocationPass> {
786041baf2fSBenjamin Maxwell   using TestTileAllocationBase::TestTileAllocationBase;
787041baf2fSBenjamin Maxwell   void runOnOperation() override {
788041baf2fSBenjamin Maxwell     FunctionOpInterface function = getOperation();
789041baf2fSBenjamin Maxwell     if (preprocessOnly) {
790041baf2fSBenjamin Maxwell       IRRewriter rewriter(function);
791041baf2fSBenjamin Maxwell       return preprocessForTileAllocation(rewriter, function);
792041baf2fSBenjamin Maxwell     }
793041baf2fSBenjamin Maxwell     if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
794041baf2fSBenjamin Maxwell       signalPassFailure();
795eaff02f2SBenjamin Maxwell   }
796fb54fec7SCullen Rhodes };
797fb54fec7SCullen Rhodes } // namespace
798fb54fec7SCullen Rhodes 
799041baf2fSBenjamin Maxwell LogicalResult mlir::arm_sme::allocateSMETiles(FunctionOpInterface function,
800041baf2fSBenjamin Maxwell                                               bool dumpRanges) {
801041baf2fSBenjamin Maxwell   if (function.empty()) {
802041baf2fSBenjamin Maxwell     // TODO: Also return early if the function contains no ArmSME ops?
803041baf2fSBenjamin Maxwell     return success();
804041baf2fSBenjamin Maxwell   }
805041baf2fSBenjamin Maxwell 
806041baf2fSBenjamin Maxwell   LiveRange::Allocator liveRangeAllocator;
807041baf2fSBenjamin Maxwell   IRRewriter rewriter(function.getContext());
808041baf2fSBenjamin Maxwell 
809041baf2fSBenjamin Maxwell   // 1. Preprocess the IR for tile allocation.
810041baf2fSBenjamin Maxwell   preprocessForTileAllocation(rewriter, function);
811041baf2fSBenjamin Maxwell 
812041baf2fSBenjamin Maxwell   // 2. Gather live ranges for each ArmSME tile within the function.
813041baf2fSBenjamin Maxwell   Liveness liveness(function);
814041baf2fSBenjamin Maxwell   auto operationToIndexMap = generateOperationNumbering(function);
815041baf2fSBenjamin Maxwell   auto initialLiveRanges = gatherTileLiveRanges(
816041baf2fSBenjamin Maxwell       operationToIndexMap, liveRangeAllocator, liveness, function);
817041baf2fSBenjamin Maxwell   if (initialLiveRanges.empty())
818041baf2fSBenjamin Maxwell     return success();
819041baf2fSBenjamin Maxwell 
820041baf2fSBenjamin Maxwell   if (dumpRanges) {
821041baf2fSBenjamin Maxwell     // Wrangle initial live ranges into a form suitable for printing.
822041baf2fSBenjamin Maxwell     auto nonEmpty = llvm::make_filter_range(
823041baf2fSBenjamin Maxwell         llvm::make_second_range(initialLiveRanges),
824041baf2fSBenjamin Maxwell         [&](LiveRange const &liveRange) { return !liveRange.empty(); });
825041baf2fSBenjamin Maxwell     auto initialRanges = llvm::to_vector(llvm::map_range(
826041baf2fSBenjamin Maxwell         nonEmpty, [](LiveRange const &liveRange) { return &liveRange; }));
827041baf2fSBenjamin Maxwell     std::sort(initialRanges.begin(), initialRanges.end(),
828041baf2fSBenjamin Maxwell               [](LiveRange const *a, LiveRange const *b) { return *a < *b; });
829041baf2fSBenjamin Maxwell     llvm::errs() << "\n========== Initial Live Ranges:\n";
830041baf2fSBenjamin Maxwell     dumpLiveRanges(operationToIndexMap, initialRanges, function);
831041baf2fSBenjamin Maxwell   }
832041baf2fSBenjamin Maxwell 
833041baf2fSBenjamin Maxwell   // 3. Coalesce (non-overlapping) live ranges where it would be beneficial
834041baf2fSBenjamin Maxwell   // for tile allocation. E.g. Unify the result of an operation with its
835041baf2fSBenjamin Maxwell   // operands.
836041baf2fSBenjamin Maxwell   auto coalescedLiveRanges = coalesceTileLiveRanges(initialLiveRanges);
837041baf2fSBenjamin Maxwell 
838041baf2fSBenjamin Maxwell   if (dumpRanges) {
839041baf2fSBenjamin Maxwell     llvm::errs() << "\n========== Coalesced Live Ranges:\n";
840041baf2fSBenjamin Maxwell     dumpLiveRanges(operationToIndexMap, coalescedLiveRanges, function);
841041baf2fSBenjamin Maxwell   }
842041baf2fSBenjamin Maxwell 
843041baf2fSBenjamin Maxwell   // 4. Allocate tile IDs to live ranges.
844041baf2fSBenjamin Maxwell   allocateTilesToLiveRanges(coalescedLiveRanges);
845041baf2fSBenjamin Maxwell 
846041baf2fSBenjamin Maxwell   // 5. Assign the tile IDs back to the ArmSME operations.
847041baf2fSBenjamin Maxwell   if (failed(assignTileIdsAndResolveTrivialConflicts(rewriter, function,
848041baf2fSBenjamin Maxwell                                                      coalescedLiveRanges))) {
849041baf2fSBenjamin Maxwell     return failure();
850041baf2fSBenjamin Maxwell   }
851041baf2fSBenjamin Maxwell 
852041baf2fSBenjamin Maxwell   // 6. Erase trivially dead tile operations (e.g. a ZeroOp with no
853041baf2fSBenjamin Maxwell   // users). This prevents the LLVM conversion needlessly inserting spills.
854041baf2fSBenjamin Maxwell   eraseTriviallyDeadTileOps(rewriter, function);
855041baf2fSBenjamin Maxwell   return success();
856fb54fec7SCullen Rhodes }
857