// DEFINE: %{entry_point} = entry
// DEFINE: %{compile} = mlir-opt %s -test-lower-to-arm-sme -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
// DEFINE: -e %{entry_point} -entry-point-result=i32 \
// DEFINE: -shared-libs=%native_mlir_runner_utils,%native_mlir_c_runner_utils,%native_arm_sme_abi_shlib

// RUN: %{compile} | %{run} | FileCheck %s

// Test driver: fills an SVLb x SVLb byte buffer, then checks its contents
// with a multiplicative and an additive reduction, printing one result per
// stage. Always returns 0.
func.func @entry() -> i32 {
  // Scalar constants shared by all stages of the test.
  %idx0 = arith.constant 0 : index
  %idx1 = arith.constant 1 : index
  %idx3 = arith.constant 3 : index
  %idx7 = arith.constant 7 : index
  %zero_i8 = arith.constant 0 : i8
  %one_i8 = arith.constant 1 : i8
  %four_i8 = arith.constant 4 : i8
  %fifteen_i8 = arith.constant 15 : i8
  %one_i64 = arith.constant 1 : i64
  %exit_code = arith.constant 0 : i32

  // "SVL" is the Streaming Vector Length; %svl_bytes is the number of 8-bit
  // elements in a vector of SVL bits.
  %svl_bytes = arm_sme.streaming_vl <byte>

  // Stage 1: allocate the buffer and set every element to one.
  //
  // TODO: type conversion of rank > 1 vector types generates array(s) of
  // vectors. This is invalid for scalable vectors since LLVM doesn't support
  // arrays of scalable vectors. This prevents initializing 2-d vectors with
  // 'vector.store' or 'vector.transfer_write' ops until this is resolved or
  // there's a custom lowering path.
  %buf = memref.alloca(%svl_bytes, %svl_bytes) : memref<?x?xi8>
  scf.for %r = %idx0 to %svl_bytes step %idx1 {
    scf.for %c = %idx0 to %svl_bytes step %idx1 {
      memref.store %one_i8, %buf[%r, %c] : memref<?x?xi8>
    }
  }

  // Stage 2: confirm the buffer holds only ones. Every element is widened to
  // i64 and multiplied into an accumulator that starts at one.
  %prod_all_ones = scf.for %r = %idx0 to %svl_bytes step %idx1 iter_args(%acc = %one_i64) -> (i64) {
    %vec = vector.load %buf[%r, %idx0] : memref<?x?xi8>, vector<[16]xi8>

    // Product of the elements of this row, taken one lane at a time.
    %row_prod = scf.for %lane = %idx0 to %svl_bytes step %idx1 iter_args(%row_acc = %one_i64) -> (i64) {
      %elt = vector.extractelement %vec[%lane : index] : vector<[16]xi8>
      %elt_i64 = arith.extui %elt : i8 to i64
      %row_acc_next = arith.muli %row_acc, %elt_i64 : i64
      scf.yield %row_acc_next : i64
    }

    %acc_next = arith.muli %acc, %row_prod : i64
    scf.yield %acc_next : i64
  }

  // CHECK: 1
  vector.print %prod_all_ones : i64

  // Stage 3: sanity-check the mul reduction itself by overwriting two
  // elements with known values and reducing again.
  //
  // TODO: ZA currently isn't re-enabled after calls and is therefore disabled
  // by the callee on return. Once this is resolved the reduction can be moved
  // to a function instead of being repeated inline.
  memref.store %four_i8, %buf[%idx3, %idx7] : memref<?x?xi8>
  memref.store %fifteen_i8, %buf[%idx7, %idx3] : memref<?x?xi8>
  %prod_poked = scf.for %r = %idx0 to %svl_bytes step %idx1 iter_args(%acc = %one_i64) -> (i64) {
    %vec = vector.load %buf[%r, %idx0] : memref<?x?xi8>, vector<[16]xi8>

    %row_prod = scf.for %lane = %idx0 to %svl_bytes step %idx1 iter_args(%row_acc = %one_i64) -> (i64) {
      %elt = vector.extractelement %vec[%lane : index] : vector<[16]xi8>
      %elt_i64 = arith.extui %elt : i8 to i64
      %row_acc_next = arith.muli %row_acc, %elt_i64 : i64
      scf.yield %row_acc_next : i64
    }

    %acc_next = arith.muli %acc, %row_prod : i64
    scf.yield %acc_next : i64
  }

  // 15*4=60
  // CHECK: 60
  vector.print %prod_poked : i64

  // Stage 4: overwrite the whole buffer with zeroes using a 2-d scalable
  // vector write.
  //
  // This will get lowered to:
  //
  //   zero {za}
  //   for vnum = 0; vnum < SVLb; ++vnum
  //     str za[vnum], [ptr]
  //   ...
  //
  %zero_tile = arith.constant dense<0> : vector<[16]x[16]xi8>
  vector.transfer_write %zero_tile, %buf[%idx0, %idx0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>

  // Stage 5: confirm the buffer is zeroed by summing every row into an
  // accumulator that starts at zero.
  %sum_zeroed = scf.for %r = %idx0 to %svl_bytes step %idx1 iter_args(%acc = %zero_i8) -> (i8) {
    %vec = vector.load %buf[%r, %idx0] : memref<?x?xi8>, vector<[16]xi8>
    %row_sum = vector.reduction <add>, %vec : vector<[16]xi8> into i8
    %acc_next = arith.addi %acc, %row_sum : i8
    scf.yield %acc_next : i8
  }

  // CHECK-NEXT: 0
  vector.print %sum_zeroed : i8

  // Stage 6: sanity-check the add reduction itself by writing the same two
  // known values and summing again.
  //
  // TODO: ZA currently isn't re-enabled after calls and is therefore disabled
  // by the callee on return. Once this is resolved the reduction can be moved
  // to a function instead of being repeated inline.
  memref.store %four_i8, %buf[%idx3, %idx7] : memref<?x?xi8>
  memref.store %fifteen_i8, %buf[%idx7, %idx3] : memref<?x?xi8>
  %sum_poked = scf.for %r = %idx0 to %svl_bytes step %idx1 iter_args(%acc = %zero_i8) -> (i8) {
    %vec = vector.load %buf[%r, %idx0] : memref<?x?xi8>, vector<[16]xi8>
    %row_sum = vector.reduction <add>, %vec : vector<[16]xi8> into i8
    %acc_next = arith.addi %acc, %row_sum : i8
    scf.yield %acc_next : i8
  }

  // 15+4=19
  // CHECK-NEXT: 19
  vector.print %sum_poked : i8

  return %exit_code : i32
}