// RUN: mlir-opt %s -test-arm-sme-tile-allocation -split-input-file | \
// RUN: FileCheck %s --check-prefix=AFTER-TILE-ALLOC
// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm,cse,canonicalize))" \
// RUN: -split-input-file -verify-diagnostics | \
// RUN: FileCheck %s --check-prefix=AFTER-LLVM-LOWERING

/// Checks that tile spills/reloads are inserted around in-memory tiles (i.e.
/// tiles that were not assigned a physical SME tile).
///
/// These spills are currently very naive and will spill/reload entire tiles
/// around ArmSME ops.
///
/// The general pattern is:
///
/// During tile allocation, if no physical tile ID is available, an op will
/// be assigned an in-memory tile ID (which is a tile ID >= 16).
///
/// Example:
///
///   arm_sme.zero : vector<[8]x[8]xi16>
///
/// Becomes:
///
///   arm_sme.zero { tile_id = 16 } : vector<[8]x[8]xi16>
///
/// This works as normal until the final lowering to LLVM, where spills and
/// reloads will be inserted around uses of in-memory tiles.
///
/// So the above example becomes:
///
///   // Placed at the top of the function:
///   %tileAlloca = memref.alloca(%svl_h, %svl_h) : memref<?x?xi16>
///
/// Then around the op:
///
///   // Swap contents of %tileAlloca and tile 0
///   scf.for %sliceIdx ...
///     %currentSlice = arm_sme.intr.read.horiz {tile_id = 0}
///     arm_sme.intr.ld1h.horiz %tileAlloca[%sliceIdx, %c0] {tile_id = 0}
///     vector.store %currentSlice, %tileAlloca[%sliceIdx, %c0]
///   // Execute the op using tile 0
///   arm_sme.intr.zero
///   // Swap contents of %tileAlloca and tile 0
///   scf.for %sliceIdx ...
///     %currentSlice = arm_sme.intr.read.horiz {tile_id = 0}
///     arm_sme.intr.ld1h.horiz %tileAlloca[%sliceIdx, %c0] {tile_id = 0}
///     vector.store %currentSlice, %tileAlloca[%sliceIdx, %c0]
///

// -----

/// Note: In this example, loads into ZA are inserted before the zero instruction.
/// These are obviously redundant, but there are no checks to avoid this.
func.func @use_too_many_tiles() {
  %0 = arm_sme.zero : vector<[4]x[4]xi32>
  "test.prevent_zero_merge"() : () -> ()
  %1 = arm_sme.zero : vector<[4]x[4]xi32>
  // expected-warning @below {{failed to allocate SME virtual tile to operation, tile value will go through memory, expect degraded performance}}
  %2 = arm_sme.zero : vector<[8]x[8]xi16>
  "test.some_use"(%0) : (vector<[4]x[4]xi32>) -> ()
  "test.some_use"(%1) : (vector<[4]x[4]xi32>) -> ()
  "test.some_use"(%2) : (vector<[8]x[8]xi16>) -> ()
  return
}
// AFTER-TILE-ALLOC-LABEL: @use_too_many_tiles
// AFTER-TILE-ALLOC: arm_sme.zero
// AFTER-TILE-ALLOC-SAME: tile_id = 0
// AFTER-TILE-ALLOC: arm_sme.zero
// AFTER-TILE-ALLOC-SAME: tile_id = 1
// AFTER-TILE-ALLOC: arm_sme.zero
// AFTER-TILE-ALLOC-SAME: tile_id = 16

// AFTER-LLVM-LOWERING-LABEL: @use_too_many_tiles
// AFTER-LLVM-LOWERING-DAG: %[[C0:.*]] = arith.constant 0 : index
// AFTER-LLVM-LOWERING-DAG: %[[C1:.*]] = arith.constant 1 : index
// AFTER-LLVM-LOWERING-DAG: %[[C8:.*]] = arith.constant 8 : index
// AFTER-LLVM-LOWERING-DAG: %[[VSCALE:.*]] = vector.vscale
// AFTER-LLVM-LOWERING-DAG: %[[SVL_H:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index

/// 0. Create an in-memory tile.
/// Note: 16 is an in-memory tile ID (in-memory tile IDs are >= 16).

// AFTER-LLVM-LOWERING-DAG: %[[TILE_ALLOCA:.*]] = memref.alloca(%[[SVL_H]], %[[SVL_H]])
// AFTER-LLVM-LOWERING-SAME: {arm_sme.in_memory_tile_id = 16 : i32} : memref<?x?xi16>
//
// AFTER-LLVM-LOWERING-NOT: scf.for

/// 1. The following instruction corresponds to %0 after tile allocation.
/// Note: 17 is the mask for the 32-bit tile 0.

// AFTER-LLVM-LOWERING: "arm_sme.intr.zero"() <{tile_mask = 17 : i32}>
//
// AFTER-LLVM-LOWERING-NOT: scf.for

/// 2. The following instruction corresponds to %1 after tile allocation.
/// Note: 34 is the mask for the 32-bit tile 1.

// AFTER-LLVM-LOWERING: "arm_sme.intr.zero"() <{tile_mask = 34 : i32}>

/// 3. swap(<in-memory-tile>, tile 0).
/// This can be interpreted as spilling %0 (the 32-bit tile 0), so that
/// %2 can be allocated a tile (16-bit tile 0). Note that this is
/// swapping vector<[8]x[8]xi16> rather than vector<[4]x[4]xi32>.

// AFTER-LLVM-LOWERING: scf.for
// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_H]] step %[[C1]] {
// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
// AFTER-LLVM-LOWERING: %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]]
// AFTER-LLVM-LOWERING: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}>
// AFTER-LLVM-LOWERING-NEXT: "arm_sme.intr.ld1h.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}>
// AFTER-LLVM-LOWERING-NEXT: vector.store %[[SLICE]], %[[TILE_ALLOCA]]
// AFTER-LLVM-LOWERING-NEXT: }

/// 4. The following instruction corresponds to %2 after tile allocation.
/// Note: 85 is the mask for the 16-bit tile 0.

// AFTER-LLVM-LOWERING: "arm_sme.intr.zero"() <{tile_mask = 85 : i32}>

/// 5. swap(<in-memory-tile>, tile 0).
/// This can be interpreted as restoring %0.

// AFTER-LLVM-LOWERING: scf.for
// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_H]] step %[[C1]] {
// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
// AFTER-LLVM-LOWERING: %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]]
// AFTER-LLVM-LOWERING: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}>
// AFTER-LLVM-LOWERING-NEXT: "arm_sme.intr.ld1h.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}>
// AFTER-LLVM-LOWERING-NEXT: vector.store %[[SLICE]], %[[TILE_ALLOCA]]
// AFTER-LLVM-LOWERING-NEXT: }

// -----

/// Note: In this example, an entire tile swap is inserted before/after the
/// `arm_sme.load_tile_slice` operation. Really, this only needs to spill a
/// single tile slice (and can omit the initial load, like in the previous example).
func.func @very_excessive_spills(%useAllTiles : vector<[16]x[16]xi8>, %memref: memref<?x?xf32>) -> vector<[4]x[4]xf32> {
  %c0 = arith.constant 0 : index
  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
  %mask = vector.constant_mask [4] : vector<[4]xi1>
  // expected-warning @below {{failed to allocate SME virtual tile to operation, tile value will go through memory, expect degraded performance}}
  %loadSlice = arm_sme.load_tile_slice %memref[%c0, %c0], %mask, %tile, %c0 : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
  "test.some_use"(%useAllTiles) : (vector<[16]x[16]xi8>) -> ()
  "test.some_use"(%loadSlice) : (vector<[4]x[4]xf32>) -> ()
  // Terminate the body explicitly so it matches the declared result type,
  // rather than relying on the unregistered op above being a potential terminator.
  return %loadSlice : vector<[4]x[4]xf32>
}
// AFTER-TILE-ALLOC-LABEL: @very_excessive_spills
// AFTER-TILE-ALLOC: arm_sme.load_tile_slice
// AFTER-TILE-ALLOC-SAME: tile_id = 16

// AFTER-LLVM-LOWERING-LABEL: @very_excessive_spills
// AFTER-LLVM-LOWERING-DAG: %[[C0:.*]] = arith.constant 0 : index
// AFTER-LLVM-LOWERING-DAG: %[[C1:.*]] = arith.constant 1 : index
// AFTER-LLVM-LOWERING-DAG: %[[C4:.*]] = arith.constant 4 : index
// AFTER-LLVM-LOWERING-DAG: %[[VSCALE:.*]] = vector.vscale
// AFTER-LLVM-LOWERING-DAG: %[[SVL_S:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
// AFTER-LLVM-LOWERING-DAG: %[[TILE_ALLOCA:.*]] = memref.alloca(%[[SVL_S]], %[[SVL_S]])
// AFTER-LLVM-LOWERING-SAME: {arm_sme.in_memory_tile_id = 16 : i32} : memref<?x?xf32>
//

/// 1. Swap %useAllTiles and %tile. Note that this only swaps one 32-bit
/// tile (vector<[4]x[4]xf32>).

// AFTER-LLVM-LOWERING: scf.for
// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_S]] step %[[C1]] {
// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
// AFTER-LLVM-LOWERING: %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]]
/// Read ZA tile slice -> vector
// AFTER-LLVM-LOWERING: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}>
/// Load vector from memory -> ZA tile
// AFTER-LLVM-LOWERING-NEXT: "arm_sme.intr.ld1w.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}>
/// Store ZA tile slice in memory
// AFTER-LLVM-LOWERING-NEXT: vector.store %[[SLICE]], %[[TILE_ALLOCA]]
// AFTER-LLVM-LOWERING-NEXT: }

/// 2. Load into %tile.
// AFTER-LLVM-LOWERING: "arm_sme.intr.ld1w.horiz"{{.*}} <{tile_id = 0 : i32}>

/// 3. Swap %useAllTiles and %tile back. Note that this only swaps one 32-bit
/// tile (vector<[4]x[4]xf32>).

// AFTER-LLVM-LOWERING: scf.for
// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_S]] step %[[C1]] {
// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
// AFTER-LLVM-LOWERING: %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]]
/// Read ZA tile slice -> vector
// AFTER-LLVM-LOWERING: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}>
/// Load vector from memory -> ZA tile
// AFTER-LLVM-LOWERING-NEXT: "arm_sme.intr.ld1w.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}>
/// Store ZA tile slice in memory
// AFTER-LLVM-LOWERING-NEXT: vector.store %[[SLICE]], %[[TILE_ALLOCA]]
// AFTER-LLVM-LOWERING-NEXT: }