// xref: /llvm-project/mlir/test/Integration/Dialect/Vector/CPU/AMX/mulf.mlir (revision 435114f9fe2139bec770e5a95799f4eab20639e7)
// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-amx" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
// RUN: mlir-translate -mlir-to-llvmir | \
// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_c_runner_utils | \
// RUN: FileCheck %s

// Note: This test executes AMX instructions, so it can only run on a CPU that
// supports AMX, including the AMX-BF16 extension.

// Multiplies into a zeroed destination: loads the two bf16 tiles from %lhs and
// %rhs, multiplies them with amx.tile_mulf into an accumulator tile produced by
// amx.tile_zero, and stores the f32 result tile to %out (overwriting it).
func.func @kernel1(%lhs: memref<2x4xbf16>,
                   %rhs: memref<2x4xbf16>,
                   %out: memref<2x2xf32>) {
  %origin = arith.constant 0 : index
  %lhs_tile = amx.tile_load %lhs[%origin, %origin] : memref<2x4xbf16> into vector<2x4xbf16>
  %rhs_tile = amx.tile_load %rhs[%origin, %origin] : memref<2x4xbf16> into vector<2x4xbf16>
  // Start from an all-zero accumulator so the previous contents of %out are ignored.
  %acc = amx.tile_zero : vector<2x2xf32>
  %prod = amx.tile_mulf %lhs_tile, %rhs_tile, %acc : vector<2x4xbf16>, vector<2x4xbf16>, vector<2x2xf32>
  amx.tile_store %out[%origin, %origin], %prod : memref<2x2xf32>, vector<2x2xf32>
  return
}

// Multiplies and updates the destination in place: loads the two bf16 tiles
// from %lhs and %rhs, loads the current f32 contents of %out as the
// accumulator, multiplies with amx.tile_mulf on top of it, and stores the
// updated tile back to %out.
func.func @kernel2(%lhs: memref<2x4xbf16>,
                   %rhs: memref<2x4xbf16>,
                   %out: memref<2x2xf32>) {
  %origin = arith.constant 0 : index
  %lhs_tile = amx.tile_load %lhs[%origin, %origin] : memref<2x4xbf16> into vector<2x4xbf16>
  %rhs_tile = amx.tile_load %rhs[%origin, %origin] : memref<2x4xbf16> into vector<2x4xbf16>
  // Seed the accumulator with the existing destination so the product is added to it.
  %acc = amx.tile_load %out[%origin, %origin] : memref<2x2xf32> into vector<2x2xf32>
  %updated = amx.tile_mulf %lhs_tile, %rhs_tile, %acc : vector<2x4xbf16>, vector<2x4xbf16>, vector<2x2xf32>
  amx.tile_store %out[%origin, %origin], %updated : memref<2x2xf32>, vector<2x2xf32>
  return
}

// Prints each of the two rows of a 2x2 f32 matrix as a vector<2xf32>.
func.func @print_rows(%mat: memref<2x2xf32>) {
  %pad = arith.constant 0.0 : f32
  %lo = arith.constant 0 : index
  %step = arith.constant 1 : index
  %rows = arith.constant 2 : index
  scf.for %row = %lo to %rows step %step {
    %vals = vector.transfer_read %mat[%row, %lo], %pad : memref<2x2xf32>, vector<2xf32>
    vector.print %vals : vector<2xf32>
  }
  return
}

// Test driver: fills the two bf16 input buffers, runs @kernel1 and then
// @kernel2 on the same output buffer, and prints the f32 result after each
// call so FileCheck can verify it. Returns 0.
func.func @entry() -> i32 {
  // Allocate the bf16 inputs and the f32 output.
  %lhs_buf = memref.alloc() : memref<2x4xbf16>
  %rhs_buf = memref.alloc() : memref<2x4xbf16>
  %out_buf = memref.alloc() : memref<2x2xf32>

  // Initialize the inputs.
  %origin = arith.constant 0 : index
  %lhs_vals = arith.constant dense<[[1.0, 2.0, 3.0, 4.0],
                                    [5.0, 6.0, 7.0, 8.0]]> : vector<2x4xbf16>
  %rhs_vals = arith.constant dense<[[ 9.0, 10.0, 11.0, 12.0],
                                    [13.0, 14.0, 15.0, 16.0]]> : vector<2x4xbf16>
  vector.transfer_write %lhs_vals, %lhs_buf[%origin, %origin] : vector<2x4xbf16>, memref<2x4xbf16>
  vector.transfer_write %rhs_vals, %rhs_buf[%origin, %origin] : vector<2x4xbf16>, memref<2x4xbf16>

  // @kernel1 overwrites %out_buf, starting from a zeroed accumulator.
  //
  // CHECK:      ( 124, 144 )
  // CHECK-NEXT: ( 308, 360 )
  call @kernel1(%lhs_buf, %rhs_buf, %out_buf) : (memref<2x4xbf16>, memref<2x4xbf16>, memref<2x2xf32>) -> ()
  call @print_rows(%out_buf) : (memref<2x2xf32>) -> ()

  // @kernel2 accumulates the same product on top of %out_buf, so every entry doubles.
  //
  // CHECK-NEXT: ( 248, 288 )
  // CHECK-NEXT: ( 616, 720 )
  call @kernel2(%lhs_buf, %rhs_buf, %out_buf) : (memref<2x4xbf16>, memref<2x4xbf16>, memref<2x2xf32>) -> ()
  call @print_rows(%out_buf) : (memref<2x2xf32>) -> ()

  // Release resources.
  memref.dealloc %lhs_buf : memref<2x4xbf16>
  memref.dealloc %rhs_buf : memref<2x4xbf16>
  memref.dealloc %out_buf : memref<2x2xf32>

  %status = arith.constant 0 : i32
  return %status : i32
}