xref: /llvm-project/mlir/test/Integration/Dialect/Vector/CPU/AMX/tilezero-block.mlir (revision 435114f9fe2139bec770e5a95799f4eab20639e7)
1// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-amx" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
2// RUN: mlir-translate -mlir-to-llvmir | \
3// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_c_runner_utils | \
4// RUN: FileCheck %s
5
6// Note: To run this test, your CPU must support AMX.
7
// Dumps the 4x32 f32 buffer to stdout, one 32-element row per output line.
// Every row is loaded as a single vector<32xf32> starting at column 0. The
// -1.0 padding operand is required by vector.transfer_read; all reads here
// stay within the 4x32 bounds.
func.func @print(%arg0: memref<4x32xf32>) {
  %first = arith.constant 0 : index
  %stride = arith.constant 1 : index
  %num_rows = arith.constant 4 : index
  %pad = arith.constant -1.0 : f32
  scf.for %row = %first to %num_rows step %stride {
    %line = vector.transfer_read %arg0[%row, %first], %pad : memref<4x32xf32>, vector<32xf32>
    vector.print %line : vector<32xf32>
  }
  return
}
19
// Clears the 4x32 buffer block by block. The buffer is walked in 2x16 blocks
// (row step 2, column step 16, columns innermost); for each block a zero tile
// is created and stored at the block origin, and the whole buffer is printed
// right after the store so the progress of the blocked zeroing is visible.
func.func @kernel(%arg0: memref<4x32xf32>) {
  %origin = arith.constant 0 : index
  %row_step = arith.constant 2 : index
  %row_end = arith.constant 4 : index
  %col_step = arith.constant 16 : index
  %col_end = arith.constant 32 : index
  scf.for %row = %origin to %row_end step %row_step {
    scf.for %col = %origin to %col_end step %col_step {
      // Materialize an all-zero 2x16 tile and write it into the current block.
      %tile = amx.tile_zero : vector<2x16xf32>
      amx.tile_store %arg0[%row, %col], %tile : memref<4x32xf32>, vector<2x16xf32>
      // Print the full buffer after every block store.
      func.call @print(%arg0) : (memref<4x32xf32>) -> ()
    }
  }
  return
}
35
// Test driver: allocates a 4x32 f32 buffer, fills it with 1.0, runs @kernel
// on it (which prints the buffer after each 2x16 block is zeroed), frees the
// buffer and returns 0.
func.func @entry() -> i32 {
  %lo = arith.constant 0 : index
  %unit = arith.constant 1 : index
  %num_rows = arith.constant 4 : index
  %num_cols = arith.constant 32 : index
  %one = arith.constant 1.0 : f32

  // Allocate the buffer and initialize every element to 1.0 so that zeroed
  // blocks are distinguishable from untouched ones in the printed output.
  %buf = memref.alloc() : memref<4x32xf32>
  scf.for %r = %lo to %num_rows step %unit {
    scf.for %c = %lo to %num_cols step %unit {
      memref.store %one, %buf[%r, %c] : memref<4x32xf32>
    }
  }

  // Run the blocked zeroing kernel.
  func.call @kernel(%buf) : (memref<4x32xf32>) -> ()

  // Expected output: four snapshots of the buffer, one per zeroed 2x16 block,
  // in the order top-left, top-right, bottom-left, bottom-right.
  //
  // CHECK:      ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
  // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
  // CHECK-NEXT: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
  // CHECK-NEXT: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
  //
  // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
  // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
  // CHECK-NEXT: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
  // CHECK-NEXT: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
  //
  // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
  // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
  // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
  // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
  //
  // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
  // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
  // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
  // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
  //

  // Free the buffer before returning.
  memref.dealloc %buf : memref<4x32xf32>

  // Exit status 0.
  %status = arith.constant 0 : i32
  return %status : i32
}
83