// DEFINE: %{compile} = mlir-opt %s \
// DEFINE:    -transform-interpreter -test-transform-dialect-erase-schedule \
// DEFINE:    -one-shot-bufferize="bufferize-function-boundaries" -buffer-deallocation-pipeline -cse -canonicalize -convert-vector-to-scf -arm-sve-legalize-vector-storage \
// DEFINE:    -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm -o %t
// DEFINE: %{entry_point} = matmul_f32
// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve" \
// DEFINE:    -shared-libs=%native_mlir_runner_utils,%native_mlir_c_runner_utils

// RUN: %{compile}

// RUN: %{run} | FileCheck %s --check-prefix=F32

// REDEFINE: %{entry_point} = matmul_mixed_ty
// RUN: %{run} | FileCheck %s --check-prefix=MIXED

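// The %{compile} pipeline runs the transform sequence at the bottom of this
// file (tiling + vectorisation), bufferizes, legalizes scalable vector
// storage for SVE, and lowers to LLVM. %{run} then executes the result on an
// AArch64 target with SVE enabled; the second RUN line reuses the same
// invocation with a redefined %{entry_point}.
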
func.func @matmul_f32() {
  // Matrix dimensions
  %K = arith.constant 3 : index
  %M = arith.constant 5 : index
  %N = arith.constant 15 : index
  %c0_f32 = arith.constant 0.0 : f32

  // Allocate the matrices
  %A_alloc = bufferization.alloc_tensor(%M, %K) : tensor<?x?xf32>
  %B_alloc = bufferization.alloc_tensor(%K, %N) : tensor<?x?xf32>
  %C_alloc = bufferization.alloc_tensor(%M, %N) : tensor<?x?xf32>

  // Initialise the matrices
  %pi = arith.constant 3.14 : f32
  %A = linalg.fill ins(%pi : f32) outs(%A_alloc : tensor<?x?xf32>) -> tensor<?x?xf32>
  %B = linalg.fill ins(%pi : f32) outs(%B_alloc : tensor<?x?xf32>) -> tensor<?x?xf32>
  %C_in = linalg.fill ins(%c0_f32 : f32) outs(%C_alloc : tensor<?x?xf32>) -> tensor<?x?xf32>

  // Matmul
  %C_out = linalg.matmul ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>) outs(%C_in: tensor<?x?xf32>) -> tensor<?x?xf32>

  // Print and verify the output
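  // Each element of C is the dot product of a row of A with a column of B,
  // i.e. K * pi * pi = 3 * 3.14 * 3.14 = 29.5788.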
  // F32-LABEL: SVE: START OF TEST OUTPUT
  vector.print str "SVE: START OF TEST OUTPUT\n"

  // F32-NEXT: Unranked Memref {{.*}} rank = 2 offset = 0 sizes = [5, 15] strides = [15, 1] data =
  // F32-COUNT-5: [29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788]
  %xf = tensor.cast %C_out : tensor<?x?xf32> to tensor<*xf32>
  call @printMemrefF32(%xf) : (tensor<*xf32>) -> ()

  // F32-NEXT: SVE: END OF TEST OUTPUT
  vector.print str "SVE: END OF TEST OUTPUT\n"

  return
}

func.func @matmul_mixed_ty() {
  // Matrix dimensions
  %K = arith.constant 3 : index
  %M = arith.constant 5 : index
  %N = arith.constant 15 : index
  %c0_i8 = arith.constant 0 : i8
  %c0_i32 = arith.constant 0 : i32

  // Allocate the matrices
  %A_alloc = bufferization.alloc_tensor(%M, %K) : tensor<?x?xi8>
  %B_alloc = bufferization.alloc_tensor(%K, %N) : tensor<?x?xi8>
  %C_alloc = bufferization.alloc_tensor(%M, %N) : tensor<?x?xi32>

  // Initialise the matrices
  %c123 = arith.constant 123 : i8
  %A = linalg.fill ins(%c123 : i8) outs(%A_alloc : tensor<?x?xi8>) -> tensor<?x?xi8>
  %B = linalg.fill ins(%c123 : i8) outs(%B_alloc : tensor<?x?xi8>) -> tensor<?x?xi8>
  %C_in = linalg.fill ins(%c0_i32 : i32) outs(%C_alloc : tensor<?x?xi32>) -> tensor<?x?xi32>

  // Matmul
  %C_out = linalg.matmul ins(%A, %B: tensor<?x?xi8>, tensor<?x?xi8>) outs(%C_in: tensor<?x?xi32>) -> tensor<?x?xi32>

  // Print and verify the output
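  // With i8 inputs accumulated into i32, each element of C is
  // 3 * 123 * 123 = 45387.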
  // MIXED-LABEL: SVE: START OF TEST OUTPUT
  vector.print str "SVE: START OF TEST OUTPUT\n"

  // MIXED-NEXT: Unranked Memref {{.*}} rank = 2 offset = 0 sizes = [5, 15] strides = [15, 1] data =
  // MIXED-COUNT-5: [45387,   45387,   45387,   45387,   45387,   45387,   45387,   45387,   45387,   45387,   45387,   45387,   45387,   45387,   45387]
  %xf = tensor.cast %C_out : tensor<?x?xi32> to tensor<*xi32>
  call @printMemrefI32(%xf) : (tensor<*xi32>) -> ()

  // MIXED-NEXT: SVE: END OF TEST OUTPUT
  vector.print str "SVE: END OF TEST OUTPUT\n"

  return
}

module attributes {transform.with_named_sequence} {
  // A sequence that will tile and vectorise a matmul op
  transform.named_sequence @tile_and_vectorize_matmul(%func
    : !transform.op<"func.func"> {transform.readonly}) {

    // Step 0: Get a handle to the matmul op
    %matmul = transform.structured.match ops{["linalg.matmul"]} in %func
      : (!transform.op<"func.func">) -> !transform.any_op

    // Step 1: Tile
    %tiled_matmul, %loops:3 = transform.structured.tile_using_for %matmul tile_sizes [2, [4], 1]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
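    // Note: the middle tile size, [4], is scalable, i.e. the generated loop
    // steps by 4 * vector.vscale, so the tiling adapts to the SVE vector
    // length available at runtime.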

    // Step 2: Vectorize
    transform.structured.vectorize %tiled_matmul vector_sizes [2, [4], 1] : !transform.any_op
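    // The vector sizes mirror the tile sizes above. Vectorising with explicit
    // (scalable) vector sizes produces masked transfers and a masked
    // multi-reduction, which the patterns below lower further.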

    // Step 3: Lower vector.multi_reduction to vector.contract (+ some helpful patterns)
    transform.apply_patterns to %func {
      transform.apply_patterns.vector.reduction_to_contract
      transform.apply_patterns.vector.transfer_permutation_patterns
      transform.apply_patterns.vector.lower_masked_transfers
    } : !transform.op<"func.func">

    // Step 4: Lower vector.contract to vector.fma
    transform.apply_patterns to %func {
      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
      transform.apply_patterns.vector.lower_outerproduct
    } : !transform.op<"func.func">
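    // With the "outerproduct" strategy the contraction becomes a sequence of
    // vector.outerproduct ops, which in turn lower to vector.fma on scalable
    // vectors, a good fit for SVE's fused multiply-add instructions.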

    transform.yield
  }

  // A sequence that goes over all functions in this module and applies
  // "tile_and_vectorize_matmul"
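  // (@__transform_main is the entry point that the -transform-interpreter
  // pass looks up and runs by default.)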
  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
    %funcs = transform.structured.match ops{["func.func"]} in %module
        : (!transform.any_op) -> !transform.op<"func.func">

    transform.foreach %funcs : !transform.op<"func.func"> {
      ^bb2(%func : !transform.op<"func.func">):
        transform.include @tile_and_vectorize_matmul failures(propagate)
        (%func) : (!transform.op<"func.func">) -> ()
    }
    transform.yield
  }
}

func.func private @printMemrefF32(%ptr : tensor<*xf32>)
func.func private @printMemrefI32(%ptr : tensor<*xi32>)
138