// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-amx" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
// RUN: mlir-translate -mlir-to-llvmir | \
// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_c_runner_utils | \
// RUN: FileCheck %s

// Note: this test executes AMX instructions, so the host CPU must support AMX.

// Prints the 4x32 buffer one row at a time, each row as a 32-wide vector.
// -1.0 is the transfer_read padding value; every read here is in bounds.
func.func @print(%buf: memref<4x32xf32>) {
  %zero = arith.constant 0 : index
  %one = arith.constant 1 : index
  %rows = arith.constant 4 : index
  %pad = arith.constant -1.0 : f32
  scf.for %r = %zero to %rows step %one {
    %row = vector.transfer_read %buf[%r, %zero], %pad : memref<4x32xf32>, vector<32xf32>
    vector.print %row : vector<32xf32>
  }
  return
}

// Overwrites the buffer with zero tiles one 2x16 block at a time:
// - outer loop walks row blocks of 2;
// - inner loop walks column blocks of 16.
// The whole buffer is printed after every tile store, so the output shows
// the zeroed region growing block by block.
func.func @kernel(%buf: memref<4x32xf32>) {
  %zero = arith.constant 0 : index
  %tile_rows = arith.constant 2 : index
  %tile_cols = arith.constant 16 : index
  %rows = arith.constant 4 : index
  %cols = arith.constant 32 : index
  scf.for %r = %zero to %rows step %tile_rows {
    scf.for %c = %zero to %cols step %tile_cols {
      %tile = amx.tile_zero : vector<2x16xf32>
      amx.tile_store %buf[%r, %c], %tile : memref<4x32xf32>, vector<2x16xf32>
      func.call @print(%buf) : (memref<4x32xf32>) -> ()
    }
  }
  return
}

// Test driver. It:
// - allocates a 4x32 buffer and fills it with 1.0;
// - runs the blocked tile-zero kernel on it;
// - frees the buffer and returns 0.
func.func @entry() -> i32 {
  %zero = arith.constant 0 : index
  %one = arith.constant 1 : index
  %rows = arith.constant 4 : index
  %cols = arith.constant 32 : index
  %fill = arith.constant 1.0 : f32

  // Set up memory: every element starts out as 1.0.
  %buf = memref.alloc() : memref<4x32xf32>
  scf.for %r = %zero to %rows step %one {
    scf.for %c = %zero to %cols step %one {
      memref.store %fill, %buf[%r, %c] : memref<4x32xf32>
    }
  }

  // Run the kernel; it prints one buffer snapshot per tile store.
  func.call @kernel(%buf) : (memref<4x32xf32>) -> ()

  // Expected output: four snapshots of the buffer, one after each tile store.
  //
  // Snapshot 1: tile at rows 0-1, cols 0-15 zeroed.
  // CHECK: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
  // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
  // CHECK-NEXT: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
  // CHECK-NEXT: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
  //
  // Snapshot 2: tile at rows 0-1, cols 16-31 zeroed as well.
  // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
  // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
  // CHECK-NEXT: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
  // CHECK-NEXT: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
  //
  // Snapshot 3: tile at rows 2-3, cols 0-15 zeroed as well.
  // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
  // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
  // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
  // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
  //
  // Snapshot 4: last tile stored; the entire buffer is zero.
  // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
  // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
  // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
  // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
  //

  // Release the buffer.
  memref.dealloc %buf : memref<4x32xf32>

  %status = arith.constant 0 : i32
  return %status : i32
}