// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-amx" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
// RUN: mlir-translate -mlir-to-llvmir | \
// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_c_runner_utils | \
// RUN: FileCheck %s

// Note: executing this test requires a CPU with AMX support.

// Tile-multiplies %lhs by %rhs (both operands zero-extended) into a freshly
// zeroed i32 tile and stores the result to %acc, overwriting whatever %acc
// held before.
func.func @kernel1(%lhs: memref<2x8xi8>,
                   %rhs: memref<2x8xi8>,
                   %acc: memref<2x2xi32>) {
  %origin = arith.constant 0 : index
  %lhs_tile = amx.tile_load %lhs[%origin, %origin] : memref<2x8xi8> into vector<2x8xi8>
  %rhs_tile = amx.tile_load %rhs[%origin, %origin] : memref<2x8xi8> into vector<2x8xi8>
  %zero_tile = amx.tile_zero : vector<2x2xi32>
  %prod = amx.tile_muli %lhs_tile zext, %rhs_tile zext, %zero_tile : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
  amx.tile_store %acc[%origin, %origin], %prod : memref<2x2xi32>, vector<2x2xi32>
  return
}

// Same multiply as @kernel1, but the accumulator tile is loaded from %acc,
// so the product is added on top of the current contents of %acc.
func.func @kernel2(%lhs: memref<2x8xi8>,
                   %rhs: memref<2x8xi8>,
                   %acc: memref<2x2xi32>) {
  %origin = arith.constant 0 : index
  %lhs_tile = amx.tile_load %lhs[%origin, %origin] : memref<2x8xi8> into vector<2x8xi8>
  %rhs_tile = amx.tile_load %rhs[%origin, %origin] : memref<2x8xi8> into vector<2x8xi8>
  %acc_tile = amx.tile_load %acc[%origin, %origin] : memref<2x2xi32> into vector<2x2xi32>
  %sum = amx.tile_muli %lhs_tile zext, %rhs_tile zext, %acc_tile : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
  amx.tile_store %acc[%origin, %origin], %sum : memref<2x2xi32>, vector<2x2xi32>
  return
}

// Prints the two rows of %buf, one vector<2xi32> per output line.
func.func @print_result(%buf: memref<2x2xi32>) {
  %pad = arith.constant 0 : i32
  %first = arith.constant 0 : index
  %rows = arith.constant 2 : index
  %step = arith.constant 1 : index
  scf.for %row = %first to %rows step %step {
    %line = vector.transfer_read %buf[%row, %first], %pad : memref<2x2xi32>, vector<2xi32>
    vector.print %line : vector<2xi32>
  }
  return
}

// Test driver: fills the operands, runs both kernels on the same buffers and
// prints the result after each call for FileCheck to verify.
func.func @entry() -> i32 {
  %zero = arith.constant 0 : i32
  %origin = arith.constant 0 : index

  // Operand and result buffers.
  %lhs = memref.alloc() : memref<2x8xi8>
  %rhs = memref.alloc() : memref<2x8xi8>
  %out = memref.alloc() : memref<2x2xi32>

  // Fill the left operand with 1..16 and the right operand with 17..32.
  %lhs_init = arith.constant dense<[[1, 2, 3, 4, 5, 6, 7, 8],
                                    [9, 10, 11, 12, 13, 14, 15, 16]]> : vector<2x8xi8>
  vector.transfer_write %lhs_init, %lhs[%origin, %origin] : vector<2x8xi8>, memref<2x8xi8>
  %rhs_init = arith.constant dense<[[17, 18, 19, 20, 21, 22, 23, 24],
                                    [25, 26, 27, 28, 29, 30, 31, 32]]> : vector<2x8xi8>
  vector.transfer_write %rhs_init, %rhs[%origin, %origin] : vector<2x8xi8>, memref<2x8xi8>

  // Plain product. For example, the top-left entry is
  //   (1*17 + 2*18 + 3*19 + 4*20) + (5*25 + 6*26 + 7*27 + 8*28) = 884.
  //
  // CHECK: ( 884, 1028 )
  // CHECK-NEXT: ( 2324, 2724 )
  call @kernel1(%lhs, %rhs, %out) : (memref<2x8xi8>, memref<2x8xi8>, memref<2x2xi32>) -> ()
  call @print_result(%out) : (memref<2x2xi32>) -> ()

  // Accumulating the same product onto the previous result doubles every entry.
  //
  // CHECK-NEXT: ( 1768, 2056 )
  // CHECK-NEXT: ( 4648, 5448 )
  call @kernel2(%lhs, %rhs, %out) : (memref<2x8xi8>, memref<2x8xi8>, memref<2x2xi32>) -> ()
  call @print_result(%out) : (memref<2x2xi32>) -> ()

  // Release the buffers.
  memref.dealloc %lhs : memref<2x8xi8>
  memref.dealloc %rhs : memref<2x8xi8>
  memref.dealloc %out : memref<2x2xi32>

  return %zero : i32
}