// DEFINE: %{entry_point} = test_load_store_zaq0
// DEFINE: %{compile} = mlir-opt %s -test-lower-to-arm-sme -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
// DEFINE: -e %{entry_point} -entry-point-result=void \
// DEFINE: -shared-libs=%native_mlir_runner_utils,%native_mlir_c_runner_utils,%native_arm_sme_abi_shlib

// RUN: %{compile} | %{run} | FileCheck %s

/// Walks %bytes from offset 0 up to (but excluding) %len in steps of 16 and
/// prints each 16-byte chunk as a vector<16xi8>, one chunk per output line.
/// Every iteration loads a full 16 bytes starting at the current offset.
func.func @print_i8s(%bytes: memref<?xi8>, %len: index) {
  %row_width = arith.constant 16 : index
  %start = arith.constant 0 : index
  scf.for %offset = %start to %len step %row_width {
    %row = vector.load %bytes[%offset] : memref<?xi8>, vector<16xi8>
    vector.print %row : vector<16xi8>
  }
  return
}

/// Loads a single scalable vector<[1]x[1]xi128> from position [0, 0] of %src
/// and stores it to position [0, 0] of %dst.
func.func @vector_copy_i128(%src: memref<?x?xi128>, %dst: memref<?x?xi128>) {
  %origin = arith.constant 0 : index
  %za_q = vector.load %src[%origin, %origin] : memref<?x?xi128>, vector<[1]x[1]xi128>
  vector.store %za_q, %dst[%origin, %origin] : memref<?x?xi128>, vector<[1]x[1]xi128>
  return
}

/// Entry point: fills two byte buffers (A with 7s, B with 64s), copies A into B
/// through 128-bit element views using @vector_copy_i128, and prints both
/// buffers before and after the copy so FileCheck can verify the contents.
func.func @test_load_store_zaq0() {
  %zero = arith.constant 0 : index
  %q_bytes = arith.constant 16 : index
  %q_min_elts = arith.constant 1 : index

  /// Size in bytes of a 128-bit tile (e.g. ZA{n}.q):
  ///   (1 * vscale) rows x (1 * vscale) columns x 16 bytes per element.
  %vs = vector.vscale
  %dim_q = arith.muli %q_min_elts, %vs : index
  %num_elts = arith.muli %dim_q, %dim_q : index
  %num_bytes = arith.muli %num_elts, %q_bytes : index

  /// Tiles A and B are backed by plain byte buffers so they can be filled and
  /// printed; very little can be done with 128-bit types directly.
  /// Tile A is filled with the byte value 7.
  %a_bytes = memref.alloca(%num_bytes) {alignment = 16} : memref<?xi8>
  %a_fill = arith.constant 7 : i8
  linalg.fill ins(%a_fill : i8) outs(%a_bytes : memref<?xi8>)

  /// Tile B is filled with the byte value 64.
  %b_bytes = memref.alloca(%num_bytes) {alignment = 16} : memref<?xi8>
  %b_fill = arith.constant 64 : i8
  linalg.fill ins(%b_fill : i8) outs(%b_bytes : memref<?xi8>)

  // CHECK-LABEL: INITIAL TILE A:
  // CHECK: ( 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7 )
  vector.print str "INITIAL TILE A:\n"
  func.call @print_i8s(%a_bytes, %num_bytes) : (memref<?xi8>, index) -> ()
  vector.print punctuation <newline>

  // CHECK-LABEL: INITIAL TILE B:
  // CHECK: ( 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 )
  vector.print str "INITIAL TILE B:\n"
  func.call @print_i8s(%b_bytes, %num_bytes) : (memref<?xi8>, index) -> ()
  vector.print punctuation <newline>

  /// Reinterpret both byte buffers as 2-D memrefs of i128 elements, then load
  /// tile A and store it over tile B.
  %a_view = memref.view %a_bytes[%zero][%dim_q, %dim_q] : memref<?xi8> to memref<?x?xi128>
  %b_view = memref.view %b_bytes[%zero][%dim_q, %dim_q] : memref<?xi8> to memref<?x?xi128>
  func.call @vector_copy_i128(%a_view, %b_view) : (memref<?x?xi128>, memref<?x?xi128>) -> ()

  /// Tile A must be unchanged by the copy.
  // CHECK-LABEL: FINAL TILE A:
  // CHECK: ( 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7 )
  vector.print str "FINAL TILE A:\n"
  func.call @print_i8s(%a_bytes, %num_bytes) : (memref<?xi8>, index) -> ()
  vector.print punctuation <newline>

  /// Tile B must now hold tile A's contents.
  // CHECK-LABEL: FINAL TILE B:
  // CHECK: ( 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7 )
  vector.print str "FINAL TILE B:\n"
  func.call @print_i8s(%b_bytes, %num_bytes) : (memref<?xi8>, index) -> ()

  return
}