// RUN: mlir-opt %s -allow-unregistered-dialect -arm-sve-legalize-vector-storage -split-input-file -verify-diagnostics | FileCheck %s

/// This tests the basic functionality of the `-arm-sve-legalize-vector-storage` pass.
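///
/// Specifically, the tests below check that:
///
///  * Stores and loads of SVE predicate types narrower than an svbool
///    (vector<[16]xi1>), i.e. vector<[1|2|4|8]xi1>, are widened to svbool via
///    arm_sve.convert_to_svbool / arm_sve.convert_from_svbool.
///  * Allocas of scalable vectors receive an explicit alignment the backend
///    can handle (2 bytes for predicate allocas, 16 bytes for data vector
///    allocas).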

// -----

// CHECK-LABEL: @store_and_reload_sve_predicate_nxv1i1(
// CHECK-SAME: %[[MASK:.*]]: vector<[1]xi1>)
func.func @store_and_reload_sve_predicate_nxv1i1(%mask: vector<[1]xi1>) -> vector<[1]xi1> {
  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
  %alloca = memref.alloca() : memref<vector<[1]xi1>>
  // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[1]xi1>
  // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
  memref.store %mask, %alloca[] : memref<vector<[1]xi1>>
  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
  // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[1]xi1>
  %reload = memref.load %alloca[] : memref<vector<[1]xi1>>
  // CHECK-NEXT: return %[[MASK]] : vector<[1]xi1>
  return %reload : vector<[1]xi1>
}

// -----

// CHECK-LABEL: @store_and_reload_sve_predicate_nxv2i1(
// CHECK-SAME: %[[MASK:.*]]: vector<[2]xi1>)
func.func @store_and_reload_sve_predicate_nxv2i1(%mask: vector<[2]xi1>) -> vector<[2]xi1> {
  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
  %alloca = memref.alloca() : memref<vector<[2]xi1>>
  // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[2]xi1>
  // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
  memref.store %mask, %alloca[] : memref<vector<[2]xi1>>
  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
  // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[2]xi1>
  %reload = memref.load %alloca[] : memref<vector<[2]xi1>>
  // CHECK-NEXT: return %[[MASK]] : vector<[2]xi1>
  return %reload : vector<[2]xi1>
}

// -----

// CHECK-LABEL: @store_and_reload_sve_predicate_nxv4i1(
// CHECK-SAME: %[[MASK:.*]]: vector<[4]xi1>)
func.func @store_and_reload_sve_predicate_nxv4i1(%mask: vector<[4]xi1>) -> vector<[4]xi1> {
  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
  %alloca = memref.alloca() : memref<vector<[4]xi1>>
  // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[4]xi1>
  // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
  memref.store %mask, %alloca[] : memref<vector<[4]xi1>>
  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
  // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[4]xi1>
  %reload = memref.load %alloca[] : memref<vector<[4]xi1>>
  // CHECK-NEXT: return %[[MASK]] : vector<[4]xi1>
  return %reload : vector<[4]xi1>
}

// -----

// CHECK-LABEL: @store_and_reload_sve_predicate_nxv8i1(
// CHECK-SAME: %[[MASK:.*]]: vector<[8]xi1>)
func.func @store_and_reload_sve_predicate_nxv8i1(%mask: vector<[8]xi1>) -> vector<[8]xi1> {
  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
  %alloca = memref.alloca() : memref<vector<[8]xi1>>
  // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[8]xi1>
  // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
  memref.store %mask, %alloca[] : memref<vector<[8]xi1>>
  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
  // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[8]xi1>
  %reload = memref.load %alloca[] : memref<vector<[8]xi1>>
  // CHECK-NEXT: return %[[MASK]] : vector<[8]xi1>
  return %reload : vector<[8]xi1>
}

// -----

// CHECK-LABEL: @store_and_reload_sve_predicate_nxv16i1(
// CHECK-SAME: %[[MASK:.*]]: vector<[16]xi1>)
func.func @store_and_reload_sve_predicate_nxv16i1(%mask: vector<[16]xi1>) -> vector<[16]xi1> {
  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
  %alloca = memref.alloca() : memref<vector<[16]xi1>>
  // CHECK-NEXT: memref.store %[[MASK]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
  memref.store %mask, %alloca[] : memref<vector<[16]xi1>>
  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
  %reload = memref.load %alloca[] : memref<vector<[16]xi1>>
  // CHECK-NEXT: return %[[RELOAD]] : vector<[16]xi1>
  return %reload : vector<[16]xi1>
}

// -----

/// vector<[7]xi1> is not a valid SVE predicate mask type (predicates are
/// vector<[1|2|4|8|16]xi1>), so it is ignored by the
/// `-arm-sve-legalize-vector-storage` pass.

// CHECK-LABEL: @store_and_reload_unsupported_type(
// CHECK-SAME: %[[MASK:.*]]: vector<[7]xi1>)
func.func @store_and_reload_unsupported_type(%mask: vector<[7]xi1>) -> vector<[7]xi1> {
  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[7]xi1>>
  %alloca = memref.alloca() : memref<vector<[7]xi1>>
  // CHECK-NEXT: memref.store %[[MASK]], %[[ALLOCA]][] : memref<vector<[7]xi1>>
  memref.store %mask, %alloca[] : memref<vector<[7]xi1>>
  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[7]xi1>>
  %reload = memref.load %alloca[] : memref<vector<[7]xi1>>
  // CHECK-NEXT: return %[[RELOAD]] : vector<[7]xi1>
  return %reload : vector<[7]xi1>
}

// -----

// CHECK-LABEL: @store_2d_mask_and_reload_slice(
// CHECK-SAME: %[[MASK:.*]]: vector<3x[8]xi1>)
func.func @store_2d_mask_and_reload_slice(%mask: vector<3x[8]xi1>) -> vector<[8]xi1> {
  // CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
  %c0 = arith.constant 0 : index
  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<3x[16]xi1>>
  %alloca = memref.alloca() : memref<vector<3x[8]xi1>>
  // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<3x[8]xi1>
  // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<3x[16]xi1>>
  memref.store %mask, %alloca[] : memref<vector<3x[8]xi1>>
  // CHECK-NEXT: %[[UNPACK:.*]] = vector.type_cast %[[ALLOCA]] : memref<vector<3x[16]xi1>> to memref<3xvector<[16]xi1>>
  %unpack = vector.type_cast %alloca : memref<vector<3x[8]xi1>> to memref<3xvector<[8]xi1>>
  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[UNPACK]][%[[C0]]] : memref<3xvector<[16]xi1>>
  // CHECK-NEXT: %[[SLICE:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[8]xi1>
  %slice = memref.load %unpack[%c0] : memref<3xvector<[8]xi1>>
  // CHECK-NEXT: return %[[SLICE]] : vector<[8]xi1>
  return %slice : vector<[8]xi1>
}

// -----

// CHECK-LABEL: @set_sve_alloca_alignment
func.func @set_sve_alloca_alignment() {
  /// This checks that the alignment of allocas of scalable vectors is
  /// something the backend can handle. Currently, the backend sets the
  /// alignment of scalable vectors to their base size (i.e. their size at
  /// vscale = 1). This works for hardware-sized types, which always get a
  /// 16-byte alignment. The problem is that larger types, e.g.
  /// vector<[8]xf32>, end up with alignments larger than 16 bytes (e.g.
  /// 32 bytes here), which are unsupported. The
  /// `-arm-sve-legalize-vector-storage` pass avoids this issue by explicitly
  /// setting the alignment to 16 bytes for all scalable vectors.

  // CHECK-COUNT-6: alignment = 16
  %a1 = memref.alloca() : memref<vector<[32]xi8>>
  %a2 = memref.alloca() : memref<vector<[16]xi8>>
  %a3 = memref.alloca() : memref<vector<[8]xi8>>
  %a4 = memref.alloca() : memref<vector<[4]xi8>>
  %a5 = memref.alloca() : memref<vector<[2]xi8>>
  %a6 = memref.alloca() : memref<vector<[1]xi8>>

  // CHECK-COUNT-6: alignment = 16
  %b1 = memref.alloca() : memref<vector<[32]xi16>>
  %b2 = memref.alloca() : memref<vector<[16]xi16>>
  %b3 = memref.alloca() : memref<vector<[8]xi16>>
  %b4 = memref.alloca() : memref<vector<[4]xi16>>
  %b5 = memref.alloca() : memref<vector<[2]xi16>>
  %b6 = memref.alloca() : memref<vector<[1]xi16>>

  // CHECK-COUNT-6: alignment = 16
  %c1 = memref.alloca() : memref<vector<[32]xi32>>
  %c2 = memref.alloca() : memref<vector<[16]xi32>>
  %c3 = memref.alloca() : memref<vector<[8]xi32>>
  %c4 = memref.alloca() : memref<vector<[4]xi32>>
  %c5 = memref.alloca() : memref<vector<[2]xi32>>
  %c6 = memref.alloca() : memref<vector<[1]xi32>>

  // CHECK-COUNT-6: alignment = 16
  %d1 = memref.alloca() : memref<vector<[32]xi64>>
  %d2 = memref.alloca() : memref<vector<[16]xi64>>
  %d3 = memref.alloca() : memref<vector<[8]xi64>>
  %d4 = memref.alloca() : memref<vector<[4]xi64>>
  %d5 = memref.alloca() : memref<vector<[2]xi64>>
  %d6 = memref.alloca() : memref<vector<[1]xi64>>

  // CHECK-COUNT-6: alignment = 16
  %e1 = memref.alloca() : memref<vector<[32]xf32>>
  %e2 = memref.alloca() : memref<vector<[16]xf32>>
  %e3 = memref.alloca() : memref<vector<[8]xf32>>
  %e4 = memref.alloca() : memref<vector<[4]xf32>>
  %e5 = memref.alloca() : memref<vector<[2]xf32>>
  %e6 = memref.alloca() : memref<vector<[1]xf32>>

  // CHECK-COUNT-6: alignment = 16
  %f1 = memref.alloca() : memref<vector<[32]xf64>>
  %f2 = memref.alloca() : memref<vector<[16]xf64>>
  %f3 = memref.alloca() : memref<vector<[8]xf64>>
  %f4 = memref.alloca() : memref<vector<[4]xf64>>
  %f5 = memref.alloca() : memref<vector<[2]xf64>>
  %f6 = memref.alloca() : memref<vector<[1]xf64>>

  "prevent.dce"(
    %a1, %a2, %a3, %a4, %a5, %a6,
    %b1, %b2, %b3, %b4, %b5, %b6,
    %c1, %c2, %c3, %c4, %c5, %c6,
    %d1, %d2, %d3, %d4, %d5, %d6,
    %e1, %e2, %e3, %e4, %e5, %e6,
    %f1, %f2, %f3, %f4, %f5, %f6)
    : (memref<vector<[32]xi8>>, memref<vector<[16]xi8>>, memref<vector<[8]xi8>>, memref<vector<[4]xi8>>, memref<vector<[2]xi8>>, memref<vector<[1]xi8>>,
       memref<vector<[32]xi16>>, memref<vector<[16]xi16>>, memref<vector<[8]xi16>>, memref<vector<[4]xi16>>, memref<vector<[2]xi16>>, memref<vector<[1]xi16>>,
       memref<vector<[32]xi32>>, memref<vector<[16]xi32>>, memref<vector<[8]xi32>>, memref<vector<[4]xi32>>, memref<vector<[2]xi32>>, memref<vector<[1]xi32>>,
       memref<vector<[32]xi64>>, memref<vector<[16]xi64>>, memref<vector<[8]xi64>>, memref<vector<[4]xi64>>, memref<vector<[2]xi64>>, memref<vector<[1]xi64>>,
       memref<vector<[32]xf32>>, memref<vector<[16]xf32>>, memref<vector<[8]xf32>>, memref<vector<[4]xf32>>, memref<vector<[2]xf32>>, memref<vector<[1]xf32>>,
       memref<vector<[32]xf64>>, memref<vector<[16]xf64>>, memref<vector<[8]xf64>>, memref<vector<[4]xf64>>, memref<vector<[2]xf64>>, memref<vector<[1]xf64>>) -> ()
  return
}