xref: /llvm-project/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir (revision fe55c34d19628304e0ca6a0e14a0b786b93d0e02)
1// DEFINE: %{entry_point} = test_load_store_zaq0
2// DEFINE: %{compile} = mlir-opt %s -test-lower-to-arm-sme -test-lower-to-llvm
3// DEFINE: %{run} = %mcr_aarch64_cmd \
4// DEFINE:  -march=aarch64 -mattr=+sve,+sme \
5// DEFINE:  -e %{entry_point} -entry-point-result=void \
6// DEFINE:  -shared-libs=%native_mlir_runner_utils,%native_mlir_c_runner_utils,%native_arm_sme_abi_shlib
7
8// RUN: %{compile} | %{run} | FileCheck %s
9
/// Prints the contents of `%bytes` from offset 0 up to `%len`, emitting one
/// 16-element row (vector<16xi8>) per `vector.print`.
/// Every iteration loads a full 16-byte row, so `%len` is expected to be a
/// multiple of 16 (all callers in this test pass such a length).
func.func @print_i8s(%bytes: memref<?xi8>, %len: index) {
  %row_width = arith.constant 16 : index
  %first = arith.constant 0 : index
  scf.for %offset = %first to %len step %row_width {
    %row = vector.load %bytes[%offset] : memref<?xi8>, vector<16xi8>
    vector.print %row : vector<16xi8>
  }
  return
}
19
/// Copies one scalable [1]x[1] vector of 128-bit elements from the origin of
/// `%src` to the origin of `%dst`, using a vector.load followed by a
/// vector.store.
func.func @vector_copy_i128(%src: memref<?x?xi128>, %dst: memref<?x?xi128>) {
  %origin = arith.constant 0 : index
  %za_tile = vector.load %src[%origin, %origin] : memref<?x?xi128>, vector<[1]x[1]xi128>
  vector.store %za_tile, %dst[%origin, %origin] : memref<?x?xi128>, vector<[1]x[1]xi128>
  return
}
26
/// Entry point: fills two byte buffers with distinct constants, copies the
/// first onto the second through a 128-bit element view, and prints both
/// buffers before and after the copy so FileCheck can verify the result.
func.func @test_load_store_zaq0() {
  %bytes_per_elt = arith.constant 16 : index
  %elts_per_vscale = arith.constant 1 : index
  %zero = arith.constant 0 : index

  /// Work out the size in bytes of a 128-bit tile (e.g. ZA{n}.q): the tile has
  /// (1 x vscale) rows and columns, and each element is 16 bytes wide.
  %vscale_val = vector.vscale
  %tile_dim = arith.muli %elts_per_vscale, %vscale_val : index
  %tile_elts = arith.muli %tile_dim, %tile_dim : index
  %tile_bytes = arith.muli %tile_elts, %bytes_per_elt : index

  /// Tile A backing storage, filled with the constant 7. The storage is
  /// allocated as bytes so that it can be filled and printed; very little can
  /// be done with 128-bit types directly.
  %buf_a = memref.alloca(%tile_bytes) {alignment = 16} : memref<?xi8>
  %seven = arith.constant 7 : i8
  linalg.fill ins(%seven : i8) outs(%buf_a : memref<?xi8>)

  /// Tile B backing storage, filled with the constant 64.
  %buf_b = memref.alloca(%tile_bytes) {alignment = 16} : memref<?xi8>
  %sixty_four = arith.constant 64 : i8
  linalg.fill ins(%sixty_four : i8) outs(%buf_b : memref<?xi8>)

  // CHECK-LABEL: INITIAL TILE A:
  // CHECK: ( 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7 )
  vector.print str "INITIAL TILE A:\n"
  func.call @print_i8s(%buf_a, %tile_bytes) : (memref<?xi8>, index) -> ()
  vector.print punctuation <newline>

  // CHECK-LABEL: INITIAL TILE B:
  // CHECK: ( 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 )
  vector.print str "INITIAL TILE B:\n"
  func.call @print_i8s(%buf_b, %tile_bytes) : (memref<?xi8>, index) -> ()
  vector.print punctuation <newline>

  /// Reinterpret both byte buffers as 2-D memrefs of 128-bit elements, then
  /// load tile A and store it to tile B.
  %view_a = memref.view %buf_a[%zero][%tile_dim, %tile_dim] :
    memref<?xi8> to memref<?x?xi128>
  %view_b = memref.view %buf_b[%zero][%tile_dim, %tile_dim] :
    memref<?xi8> to memref<?x?xi128>
  func.call @vector_copy_i128(%view_a, %view_b) : (memref<?x?xi128>, memref<?x?xi128>) -> ()

  // CHECK-LABEL: FINAL TILE A:
  // CHECK: ( 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7 )
  vector.print str "FINAL TILE A:\n"
  func.call @print_i8s(%buf_a, %tile_bytes) : (memref<?xi8>, index) -> ()
  vector.print punctuation <newline>

  // CHECK-LABEL: FINAL TILE B:
  // CHECK: ( 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7 )
  vector.print str "FINAL TILE B:\n"
  func.call @print_i8s(%buf_b, %tile_bytes) : (memref<?xi8>, index) -> ()

  return
}
82