// RUN: mlir-opt %s -test-arm-sme-tile-allocation -split-input-file | \
// RUN: FileCheck %s  --check-prefix=AFTER-TILE-ALLOC
// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm,cse,canonicalize))" \
// RUN:   -split-input-file -verify-diagnostics | \
// RUN: FileCheck %s  --check-prefix=AFTER-LLVM-LOWERING

/// Checks that tile spills/reloads are inserted around in-memory tiles (i.e.
/// tiles that were not assigned a physical SME tile).
///
/// These spills are currently very naive and will spill/reload entire tiles
/// around ArmSME ops.
///
/// The general pattern is:
///
/// During tile allocation, if there is no physical tile ID available, an op
/// will be assigned an in-memory tile ID (which is a tile ID >= 16).
///
/// Example:
///
///   arm_sme.zero : vector<[8]x[8]xi16>
///
/// Becomes:
///
///   arm_sme.zero { tile_id = 16 } : vector<[8]x[8]xi16>
///
/// This works as normal until the final lowering to LLVM, where spills and
/// reloads will be inserted around uses of in-memory tiles.
///
/// So the above example becomes:
///
/// // Placed at the top of the function:
/// %tileAlloca = memref.alloca(%svl_h, %svl_h) : memref<?x?xi16>
///
/// Then around the op:
///
/// // Swap contents of %tileAlloca and tile 0
/// scf.for %sliceIdx ...
///   %currentSlice = arm_sme.intr.read.horiz {tile_id = 0}
///   arm_sme.intr.ld1h.horiz %tileAlloca[%sliceIdx, %c0] {tile_id = 0}
///   vector.store %currentSlice, %tileAlloca[%sliceIdx, %c0]
/// // Execute the op using tile 0
/// arm_sme.intr.zero
/// // Swap contents of %tileAlloca and tile 0
/// scf.for %sliceIdx ...
///   %currentSlice = arm_sme.intr.read.horiz {tile_id = 0}
///   arm_sme.intr.ld1h.horiz %tileAlloca[%sliceIdx, %c0] {tile_id = 0}
///   vector.store %currentSlice, %tileAlloca[%sliceIdx, %c0]
///

// -----

/// Note: In this example loads into ZA are inserted before the zero instruction.
/// These are obviously redundant, but there are currently no checks to avoid this.
func.func @use_too_many_tiles() {
  %0 = arm_sme.zero : vector<[4]x[4]xi32>
  "test.prevent_zero_merge"() : () -> ()
  %1 = arm_sme.zero : vector<[4]x[4]xi32>
  // expected-warning @below {{failed to allocate SME virtual tile to operation, tile value will go through memory, expect degraded performance}}
  %2 = arm_sme.zero : vector<[8]x[8]xi16>
  "test.some_use"(%0) : (vector<[4]x[4]xi32>) -> ()
  "test.some_use"(%1) : (vector<[4]x[4]xi32>) -> ()
  "test.some_use"(%2) : (vector<[8]x[8]xi16>) -> ()
  return
}
// AFTER-TILE-ALLOC-LABEL: @use_too_many_tiles
//      AFTER-TILE-ALLOC: arm_sme.zero
// AFTER-TILE-ALLOC-SAME:   tile_id = 0
//      AFTER-TILE-ALLOC: arm_sme.zero
// AFTER-TILE-ALLOC-SAME:   tile_id = 1
//      AFTER-TILE-ALLOC: arm_sme.zero
// AFTER-TILE-ALLOC-SAME:   tile_id = 16

// AFTER-LLVM-LOWERING-LABEL: @use_too_many_tiles
//  AFTER-LLVM-LOWERING-DAG: %[[C0:.*]] = arith.constant 0 : index
//  AFTER-LLVM-LOWERING-DAG: %[[C1:.*]] = arith.constant 1 : index
//  AFTER-LLVM-LOWERING-DAG: %[[C8:.*]] = arith.constant 8 : index
//  AFTER-LLVM-LOWERING-DAG: %[[VSCALE:.*]] = vector.vscale
//  AFTER-LLVM-LOWERING-DAG: %[[SVL_H:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index

///     0. Create an in-memory tile
///        Note: 16 is an in-memory tile ID, that is, a tile ID >= 16.
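///        (Physical ZA tile IDs never exceed 15, since ZA provides at most
///        sixteen 128-bit tiles, ZA0.Q-ZA15.Q, so an ID >= 16 cannot refer to
///        a real tile. The alloca below is sized %svl_h x %svl_h 16-bit
///        elements, i.e. one full 16-bit tile.)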

//  AFTER-LLVM-LOWERING-DAG: %[[TILE_ALLOCA:.*]] = memref.alloca(%[[SVL_H]], %[[SVL_H]])
// AFTER-LLVM-LOWERING-SAME:   {arm_sme.in_memory_tile_id = 16 : i32} : memref<?x?xi16>
//
//  AFTER-LLVM-LOWERING-NOT: scf.for

///     1. The following instruction corresponds to %0 after tile allocation.
///        Note: 17 is the mask for the 32-bit tile 0.
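///        (The mask derivation below assumes the tile_mask follows the SME
///        ZERO instruction encoding, i.e. one bit per 64-bit tile ZA0.D-ZA7.D:
///        the 32-bit tile ZA0.S overlaps ZA0.D and ZA4.D, so its mask is
///        (1 << 0) | (1 << 4) = 0b00010001 = 17.)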

//      AFTER-LLVM-LOWERING: "arm_sme.intr.zero"() <{tile_mask = 17 : i32}>
//
//  AFTER-LLVM-LOWERING-NOT: scf.for

///     2. The following instruction corresponds to %1 after tile allocation.
///        Note: 34 is the mask for the 32-bit tile 1.
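///        (Under the same encoding as above, ZA1.S overlaps ZA1.D and ZA5.D,
///        giving (1 << 1) | (1 << 5) = 34.)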

//      AFTER-LLVM-LOWERING: "arm_sme.intr.zero"() <{tile_mask = 34 : i32}>

///     3. swap(<in-memory-tile>, tile 0).
///        This can be interpreted as spilling %0 (the 32-bit tile 0), so that
///        %2 can be allocated a tile (the 16-bit tile 0). Note that this is
///        swapping vector<[8]x[8]xi16> rather than vector<[4]x[4]xi32>.

//      AFTER-LLVM-LOWERING: scf.for
// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_H]] step %[[C1]] {
//      AFTER-LLVM-LOWERING:   %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
//      AFTER-LLVM-LOWERING:   %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
//      AFTER-LLVM-LOWERING:   %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]]
//      AFTER-LLVM-LOWERING:   %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}>
// AFTER-LLVM-LOWERING-NEXT:   "arm_sme.intr.ld1h.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}>
// AFTER-LLVM-LOWERING-NEXT:   vector.store %[[SLICE]], %[[TILE_ALLOCA]]
// AFTER-LLVM-LOWERING-NEXT: }

///     4. The following instruction corresponds to %2 after tile allocation.
///        Note: 85 is the mask for the 16-bit tile 0.
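///        (The 16-bit tile ZA0.H overlaps ZA0.D, ZA2.D, ZA4.D, and ZA6.D,
///        giving 0b01010101 = 85 under the same encoding as above.)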

//      AFTER-LLVM-LOWERING: "arm_sme.intr.zero"() <{tile_mask = 85 : i32}>

///     5. swap(<in-memory-tile>, tile 0)
///        This can be interpreted as restoring %0.

//      AFTER-LLVM-LOWERING: scf.for
// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_H]] step %[[C1]] {
//      AFTER-LLVM-LOWERING:   %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
//      AFTER-LLVM-LOWERING:   %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
//      AFTER-LLVM-LOWERING:   %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]]
//      AFTER-LLVM-LOWERING:   %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}>
// AFTER-LLVM-LOWERING-NEXT:   "arm_sme.intr.ld1h.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}>
// AFTER-LLVM-LOWERING-NEXT:   vector.store %[[SLICE]], %[[TILE_ALLOCA]]
// AFTER-LLVM-LOWERING-NEXT: }

// -----

/// Note: In this example an entire tile swap is inserted before/after the
/// `arm_sme.load_tile_slice` operation. Really, only a single tile slice needs
/// to be spilled (and the initial load could be omitted, as in the previous example).
func.func @very_excessive_spills(%useAllTiles : vector<[16]x[16]xi8>, %memref: memref<?x?xf32>) -> vector<[4]x[4]xf32> {
  %c0 = arith.constant 0 : index
  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
  %mask = vector.constant_mask [4] : vector<[4]xi1>
  // expected-warning @below {{failed to allocate SME virtual tile to operation, tile value will go through memory, expect degraded performance}}
  %loadSlice = arm_sme.load_tile_slice %memref[%c0, %c0], %mask, %tile, %c0 : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
  "test.some_use"(%useAllTiles) : (vector<[16]x[16]xi8>) -> ()
  "test.some_use"(%loadSlice) : (vector<[4]x[4]xf32>) -> ()
  return %loadSlice : vector<[4]x[4]xf32>
}
// AFTER-TILE-ALLOC-LABEL: @very_excessive_spills
//      AFTER-TILE-ALLOC: arm_sme.load_tile_slice
// AFTER-TILE-ALLOC-SAME:   tile_id = 16

// AFTER-LLVM-LOWERING-LABEL: @very_excessive_spills
//  AFTER-LLVM-LOWERING-DAG: %[[C0:.*]] = arith.constant 0 : index
//  AFTER-LLVM-LOWERING-DAG: %[[C1:.*]] = arith.constant 1 : index
//  AFTER-LLVM-LOWERING-DAG: %[[C4:.*]] = arith.constant 4 : index
//  AFTER-LLVM-LOWERING-DAG: %[[VSCALE:.*]] = vector.vscale
//  AFTER-LLVM-LOWERING-DAG: %[[SVL_S:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
//  AFTER-LLVM-LOWERING-DAG: %[[TILE_ALLOCA:.*]] = memref.alloca(%[[SVL_S]], %[[SVL_S]])
// AFTER-LLVM-LOWERING-SAME:   {arm_sme.in_memory_tile_id = 16 : i32} : memref<?x?xf32>
//

/// 1. Swap %useAllTiles and %tile - note that this will only swap one 32-bit
///    tile (vector<[4]x[4]xf32>)

//      AFTER-LLVM-LOWERING: scf.for
// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_S]] step %[[C1]] {
//      AFTER-LLVM-LOWERING:   %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
//      AFTER-LLVM-LOWERING:   %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
//      AFTER-LLVM-LOWERING:   %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]]
/// Read ZA tile slice -> vector
//      AFTER-LLVM-LOWERING:   %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}>
/// Load vector from memory -> ZA tile
// AFTER-LLVM-LOWERING-NEXT:   "arm_sme.intr.ld1w.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}>
/// Store ZA tile slice in memory
// AFTER-LLVM-LOWERING-NEXT:   vector.store %[[SLICE]], %[[TILE_ALLOCA]]
// AFTER-LLVM-LOWERING-NEXT: }

/// 2. Load into %tile
//      AFTER-LLVM-LOWERING: "arm_sme.intr.ld1w.horiz"{{.*}} <{tile_id = 0 : i32}>

/// 3. Swap %useAllTiles and %tile - note that this will only swap one 32-bit
///    tile (vector<[4]x[4]xf32>)

//      AFTER-LLVM-LOWERING: scf.for
// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_S]] step %[[C1]] {
//      AFTER-LLVM-LOWERING:   %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
//      AFTER-LLVM-LOWERING:   %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
//      AFTER-LLVM-LOWERING:   %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]]
/// Read ZA tile slice -> vector
//      AFTER-LLVM-LOWERING:   %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}>
/// Load vector from memory -> ZA tile
// AFTER-LLVM-LOWERING-NEXT:   "arm_sme.intr.ld1w.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}>
/// Store ZA tile slice in memory
// AFTER-LLVM-LOWERING-NEXT:   vector.store %[[SLICE]], %[[TILE_ALLOCA]]
// AFTER-LLVM-LOWERING-NEXT: }
195