// RUN: mlir-opt -transform-preload-library='transform-library-paths=%p/transpose-matmul-a.mlir' -transform-interpreter -split-input-file %s | FileCheck %s --check-prefixes=CHECK,TRANSPOSE-A
// RUN: mlir-opt -transform-preload-library='transform-library-paths=%p/transpose-matmul-b.mlir' -transform-interpreter -split-input-file %s | FileCheck %s --check-prefixes=CHECK,TRANSPOSE-B

// Exercises the two preloaded transform libraries on linalg.matmul and
// linalg.batch_matmul. Expectations shared by both pipelines use the common
// prefix. With transpose-matmul-a.mlir, the LHS is expected to be materialized
// through a linalg.transpose and fed to the *_transpose_a op. With
// transpose-matmul-b.mlir, the RHS is transposed instead and fed to the
// *_transpose_b op. Each function is a separate split-input-file case.

// Static shapes: the transposed operand's init is a tensor.empty with no
// dynamic sizes.
// CHECK-LABEL: func.func @matmul_static(
// CHECK-SAME:      %[[A:.*]]: tensor<16x8xf32>,
// CHECK-SAME:      %[[B:.*]]: tensor<8x16xf32>) -> tensor<16x16xf32> {
// CHECK:           %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
// CHECK:           %[[C_INIT:.*]] = tensor.empty() : tensor<16x16xf32>
// CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<16x16xf32>) -> tensor<16x16xf32>
// TRANSPOSE-A:     %[[A_TRANSP_INIT:.*]] = tensor.empty() : tensor<8x16xf32>
// TRANSPOSE-A:     %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<16x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x16xf32>) permutation = [1, 0]
// TRANSPOSE-A:     %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x16xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<16x16xf32>) -> tensor<16x16xf32>
// TRANSPOSE-B:     %[[B_TRANSP_INIT:.*]] = tensor.empty() : tensor<16x8xf32>
// TRANSPOSE-B:     %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<8x16xf32>) outs(%[[B_TRANSP_INIT]] : tensor<16x8xf32>) permutation = [1, 0]
// TRANSPOSE-B:     %[[C:.*]] = linalg.matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<16x8xf32>, tensor<16x8xf32>) outs(%[[C_ZERO]] : tensor<16x16xf32>) -> tensor<16x16xf32>
// CHECK:           return %[[C]] : tensor<16x16xf32>
// CHECK:         }
func.func @matmul_static(%A: tensor<16x8xf32>, %B: tensor<8x16xf32>) -> (tensor<16x16xf32>) {
  %cst = arith.constant 0.0 : f32
  %init = tensor.empty() : tensor<16x16xf32>
  %C = linalg.fill ins(%cst : f32) outs(%init : tensor<16x16xf32>) -> tensor<16x16xf32>
  %0 = linalg.matmul ins(%A, %B : tensor<16x8xf32>, tensor<8x16xf32>) outs(%C : tensor<16x16xf32>) -> tensor<16x16xf32>
  return %0 : tensor<16x16xf32>
}

// -----

// Fully dynamic shapes: the transposed operand's init sizes come from
// tensor.dim ops. A size already queried for the output init is reused.
// CHECK-LABEL: func.func @matmul_dynamic(
// CHECK-SAME:      %[[A:.*]]: tensor<?x?xf32>,
// CHECK-SAME:      %[[B:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK:           %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
// CHECK:           %[[C0:.*]] = arith.constant 0 : index
// CHECK:           %[[C1:.*]] = arith.constant 1 : index
// CHECK:           %[[A_DIM0:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x?xf32>
// CHECK:           %[[B_DIM1:.*]] = tensor.dim %[[B]], %[[C1]] : tensor<?x?xf32>
// CHECK:           %[[C_INIT:.*]] = tensor.empty(%[[A_DIM0]], %[[B_DIM1]]) : tensor<?x?xf32>
// CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// TRANSPOSE-A:     %[[A_DIM1:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?xf32>
// TRANSPOSE-A:     %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM1]], %[[A_DIM0]]) : tensor<?x?xf32>
// TRANSPOSE-A:     %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x?xf32>) outs(%[[A_TRANSP_INIT]] : tensor<?x?xf32>) permutation = [1, 0]
// TRANSPOSE-A:     %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// TRANSPOSE-B:     %[[B_DIM0:.*]] = tensor.dim %[[B]], %[[C0]] : tensor<?x?xf32>
// TRANSPOSE-B:     %[[B_TRANSP_INIT:.*]] = tensor.empty(%[[B_DIM1]], %[[B_DIM0]]) : tensor<?x?xf32>
// TRANSPOSE-B:     %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<?x?xf32>) outs(%[[B_TRANSP_INIT]] : tensor<?x?xf32>) permutation = [1, 0]
// TRANSPOSE-B:     %[[C:.*]] = linalg.matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK:           return %[[C]] : tensor<?x?xf32>
// CHECK:         }
func.func @matmul_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
  %cst = arith.constant 0.0 : f32
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %d0 = tensor.dim %A, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %B, %c1 : tensor<?x?xf32>
  %init = tensor.empty(%d0, %d1) : tensor<?x?xf32>
  %C = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

// -----

// Mixed static/dynamic shapes: only A's dynamic extent is queried. Transposing
// the fully static B needs no tensor.dim.
// CHECK-LABEL: func.func @matmul_mixed(
// CHECK-SAME:      %[[A:.*]]: tensor<?x8xf32>,
// CHECK-SAME:      %[[B:.*]]: tensor<8x16xf32>) -> tensor<?x16xf32> {
// CHECK:           %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
// CHECK:           %[[C0:.*]] = arith.constant 0 : index
// CHECK:           %[[A_DIM0:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x8xf32>
// CHECK:           %[[C_INIT:.*]] = tensor.empty(%[[A_DIM0]]) : tensor<?x16xf32>
// CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<?x16xf32>) -> tensor<?x16xf32>
// TRANSPOSE-A:     %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM0]]) : tensor<8x?xf32>
// TRANSPOSE-A:     %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x?xf32>) permutation = [1, 0]
// TRANSPOSE-A:     %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x?xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<?x16xf32>) -> tensor<?x16xf32>
// TRANSPOSE-B:     %[[B_TRANSP_INIT:.*]] = tensor.empty() : tensor<16x8xf32>
// TRANSPOSE-B:     %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<8x16xf32>) outs(%[[B_TRANSP_INIT]] : tensor<16x8xf32>) permutation = [1, 0]
// TRANSPOSE-B:     %[[C:.*]] = linalg.matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<?x8xf32>, tensor<16x8xf32>) outs(%[[C_ZERO]] : tensor<?x16xf32>) -> tensor<?x16xf32>
// CHECK:           return %[[C]] : tensor<?x16xf32>
// CHECK:         }
func.func @matmul_mixed(%A: tensor<?x8xf32>, %B: tensor<8x16xf32>) -> (tensor<?x16xf32>) {
  %cst = arith.constant 0.0 : f32
  %c0 = arith.constant 0 : index
  %d0 = tensor.dim %A, %c0 : tensor<?x8xf32>
  %init = tensor.empty(%d0) : tensor<?x16xf32>
  %C = linalg.fill ins(%cst : f32) outs(%init : tensor<?x16xf32>) -> tensor<?x16xf32>
  %0 = linalg.matmul ins(%A, %B : tensor<?x8xf32>, tensor<8x16xf32>) outs(%C : tensor<?x16xf32>) -> tensor<?x16xf32>
  return %0 : tensor<?x16xf32>
}

// -----

// Batched, static shapes: only the two inner dimensions are swapped
// (permutation [0, 2, 1]). The batch dimension stays leading.
// CHECK-LABEL: func.func @batch_matmul_static(
// CHECK-SAME:      %[[A:.*]]: tensor<2x16x8xf32>,
// CHECK-SAME:      %[[B:.*]]: tensor<2x8x16xf32>) -> tensor<2x16x16xf32> {
// CHECK:           %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
// CHECK:           %[[C_INIT:.*]] = tensor.empty() : tensor<2x16x16xf32>
// CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
// TRANSPOSE-A:     %[[A_TRANSP_INIT:.*]] = tensor.empty() : tensor<2x8x16xf32>
// TRANSPOSE-A:     %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<2x16x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<2x8x16xf32>) permutation = [0, 2, 1]
// TRANSPOSE-A:     %[[C:.*]] = linalg.batch_matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<2x8x16xf32>, tensor<2x8x16xf32>) outs(%[[C_ZERO]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
// TRANSPOSE-B:     %[[B_TRANSP_INIT:.*]] = tensor.empty() : tensor<2x16x8xf32>
// TRANSPOSE-B:     %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<2x8x16xf32>) outs(%[[B_TRANSP_INIT]] : tensor<2x16x8xf32>) permutation = [0, 2, 1]
// TRANSPOSE-B:     %[[C:.*]] = linalg.batch_matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<2x16x8xf32>, tensor<2x16x8xf32>) outs(%[[C_ZERO]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
// CHECK:           return %[[C]] : tensor<2x16x16xf32>
// CHECK:         }
func.func @batch_matmul_static(%A: tensor<2x16x8xf32>, %B: tensor<2x8x16xf32>) -> (tensor<2x16x16xf32>) {
  %cst = arith.constant 0.0 : f32
  %init = tensor.empty() : tensor<2x16x16xf32>
  %C = linalg.fill ins(%cst : f32) outs(%init : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
  %0 = linalg.batch_matmul ins(%A, %B : tensor<2x16x8xf32>, tensor<2x8x16xf32>) outs(%C : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
  return %0 : tensor<2x16x16xf32>
}

// -----

// Batched, fully dynamic shapes: the transposed operand's init takes its
// batch size and swapped inner sizes from tensor.dim ops.
// CHECK-LABEL: func.func @batch_matmul_dynamic(
// CHECK-SAME:      %[[A:.*]]: tensor<?x?x?xf32>,
// CHECK-SAME:      %[[B:.*]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
// CHECK:           %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
// CHECK:           %[[C0:.*]] = arith.constant 0 : index
// CHECK:           %[[C1:.*]] = arith.constant 1 : index
// CHECK:           %[[C2:.*]] = arith.constant 2 : index
// CHECK:           %[[A_DIM0:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x?x?xf32>
// CHECK:           %[[A_DIM1:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?x?xf32>
// CHECK:           %[[B_DIM2:.*]] = tensor.dim %[[B]], %[[C2]] : tensor<?x?x?xf32>
// CHECK:           %[[C_INIT:.*]] = tensor.empty(%[[A_DIM0]], %[[A_DIM1]], %[[B_DIM2]]) : tensor<?x?x?xf32>
// CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// TRANSPOSE-A:     %[[A_DIM2:.*]] = tensor.dim %[[A]], %[[C2]] : tensor<?x?x?xf32>
// TRANSPOSE-A:     %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM0]], %[[A_DIM2]], %[[A_DIM1]]) : tensor<?x?x?xf32>
// TRANSPOSE-A:     %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x?x?xf32>) outs(%[[A_TRANSP_INIT]] : tensor<?x?x?xf32>) permutation = [0, 2, 1]
// TRANSPOSE-A:     %[[C:.*]] = linalg.batch_matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// TRANSPOSE-B:     %[[B_DIM0:.*]] = tensor.dim %[[B]], %[[C0]] : tensor<?x?x?xf32>
// TRANSPOSE-B:     %[[B_DIM1:.*]] = tensor.dim %[[B]], %[[C1]] : tensor<?x?x?xf32>
// TRANSPOSE-B:     %[[B_TRANSP_INIT:.*]] = tensor.empty(%[[B_DIM0]], %[[B_DIM2]], %[[B_DIM1]]) : tensor<?x?x?xf32>
// TRANSPOSE-B:     %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<?x?x?xf32>) outs(%[[B_TRANSP_INIT]] : tensor<?x?x?xf32>) permutation = [0, 2, 1]
// TRANSPOSE-B:     %[[C:.*]] = linalg.batch_matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK:           return %[[C]] : tensor<?x?x?xf32>
// CHECK:         }
func.func @batch_matmul_dynamic(%A: tensor<?x?x?xf32>, %B: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>) {
  %cst = arith.constant 0.0 : f32
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %d0 = tensor.dim %A, %c0 : tensor<?x?x?xf32>
  %d1 = tensor.dim %A, %c1 : tensor<?x?x?xf32>
  %d2 = tensor.dim %B, %c2 : tensor<?x?x?xf32>
  %init = tensor.empty(%d0, %d1, %d2) : tensor<?x?x?xf32>
  %C = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
  %0 = linalg.batch_matmul ins(%A, %B : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%C : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
  return %0 : tensor<?x?x?xf32>
}

// -----

// Batched, mixed static/dynamic shapes: only A's dynamic middle extent is
// queried. Transposing the fully static B needs no tensor.dim.
// CHECK-LABEL: func.func @batch_matmul_mixed(
// CHECK-SAME:      %[[A:.*]]: tensor<2x?x8xf32>,
// CHECK-SAME:      %[[B:.*]]: tensor<2x8x16xf32>) -> tensor<2x?x16xf32> {
// CHECK:           %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
// CHECK:           %[[C1:.*]] = arith.constant 1 : index
// CHECK:           %[[A_DIM1:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<2x?x8xf32>
// CHECK:           %[[C_INIT:.*]] = tensor.empty(%[[A_DIM1]]) : tensor<2x?x16xf32>
// CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<2x?x16xf32>) -> tensor<2x?x16xf32>
// TRANSPOSE-A:     %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM1]]) : tensor<2x8x?xf32>
// TRANSPOSE-A:     %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<2x?x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<2x8x?xf32>) permutation = [0, 2, 1]
// TRANSPOSE-A:     %[[C:.*]] = linalg.batch_matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<2x8x?xf32>, tensor<2x8x16xf32>) outs(%[[C_ZERO]] : tensor<2x?x16xf32>) -> tensor<2x?x16xf32>
// TRANSPOSE-B:     %[[B_TRANSP_INIT:.*]] = tensor.empty() : tensor<2x16x8xf32>
// TRANSPOSE-B:     %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<2x8x16xf32>) outs(%[[B_TRANSP_INIT]] : tensor<2x16x8xf32>) permutation = [0, 2, 1]
// TRANSPOSE-B:     %[[C:.*]] = linalg.batch_matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<2x?x8xf32>, tensor<2x16x8xf32>) outs(%[[C_ZERO]] : tensor<2x?x16xf32>) -> tensor<2x?x16xf32>
// CHECK:           return %[[C]] : tensor<2x?x16xf32>
// CHECK:         }
func.func @batch_matmul_mixed(%A: tensor<2x?x8xf32>, %B: tensor<2x8x16xf32>) -> (tensor<2x?x16xf32>) {
  %cst = arith.constant 0.0 : f32
  %c1 = arith.constant 1 : index
  %d1 = tensor.dim %A, %c1 : tensor<2x?x8xf32>
  %init = tensor.empty(%d1) : tensor<2x?x16xf32>
  %C = linalg.fill ins(%cst : f32) outs(%init : tensor<2x?x16xf32>) -> tensor<2x?x16xf32>
  %0 = linalg.batch_matmul ins(%A, %B : tensor<2x?x8xf32>, tensor<2x8x16xf32>) outs(%C : tensor<2x?x16xf32>) -> tensor<2x?x16xf32>
  return %0 : tensor<2x?x16xf32>
}