// Source: llvm-project/mlir/test/Dialect/Linalg/affine.mlir (revision 52556c8e3561e7f3fa620e9d0c8f60cd4736b10f)
// RUN: mlir-opt %s -convert-linalg-to-affine-loops | FileCheck %s

// Test that we can lower all the way to LLVM without crashing; the results are
// not checked here.
// RUN: mlir-opt %s -convert-linalg-to-affine-loops -test-lower-to-llvm -o=/dev/null 2>&1

// Smoke test for the full lowering pipeline: this function has no FileCheck
// expectations, so it only guards against crashes in the lower-to-LLVM RUN.
// A, B and C are all views into %arg0 starting at byte offset 0, so they
// alias one another; that is harmless because no results are inspected.
func.func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
  %c0 = arith.constant 0 : index
  %A = memref.view %arg0[%c0][%M, %K] : memref<?xi8> to memref<?x?xf32>
  %B = memref.view %arg0[%c0][%K, %N] : memref<?xi8> to memref<?x?xf32>
  %C = memref.view %arg0[%c0][%M, %N] : memref<?xi8> to memref<?x?xf32>
  linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
               outs(%C: memref<?x?xf32>)
  return
}

//----------------------------------------------------------------------------//
// Named ops to loops.
//----------------------------------------------------------------------------//
// Named batched matmul lowered to affine loops. The expectations below look for
// four nested affine.for loops (batch, M, N, then the K reduction) whose
// innermost body loads both inputs and the accumulator, multiplies, adds, and
// stores the sum back into the accumulator.
func.func @named_batch_matmul(%lhs: memref<?x?x?xf32>, %rhs: memref<?x?x?xf32>, %acc: memref<?x?x?xf32>) {
  linalg.batch_matmul ins(%lhs, %rhs : memref<?x?x?xf32>, memref<?x?x?xf32>)
                      outs(%acc : memref<?x?x?xf32>)
  return
}
// CHECK-LABEL: @named_batch_matmul
//  CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
//  CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
//  CHECK-SAME: %[[ACC:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
// The memref.dim results captured here become the loop upper bounds.
//       CHECK: %[[DIM_B:.*]] = memref.dim %[[LHS]], %c0 : memref<?x?x?xf32>
//       CHECK: %[[DIM_M:.*]] = memref.dim %[[LHS]], %c1 : memref<?x?x?xf32>
//       CHECK: %[[DIM_K:.*]] = memref.dim %[[LHS]], %c2 : memref<?x?x?xf32>
//       CHECK: %[[DIM_N:.*]] = memref.dim %[[RHS]], %c2 : memref<?x?x?xf32>
//       CHECK: affine.for %[[IV_B:.*]] = {{.*}}0 to %[[DIM_B]] {
//       CHECK:   affine.for %[[IV_M:.*]] = {{.*}}0 to %[[DIM_M]] {
//       CHECK:     affine.for %[[IV_N:.*]] = {{.*}}0 to %[[DIM_N]] {
//       CHECK:       affine.for %[[IV_K:.*]] = {{.*}}0 to %[[DIM_K]] {
//       CHECK:         %[[LHS_V:.*]] = affine.load %[[LHS]][%[[IV_B]], %[[IV_M]], %[[IV_K]]] : memref<?x?x?xf32>
//       CHECK:         %[[RHS_V:.*]] = affine.load %[[RHS]][%[[IV_B]], %[[IV_K]], %[[IV_N]]] : memref<?x?x?xf32>
//       CHECK:         %[[ACC_V:.*]] = affine.load %[[ACC]][%[[IV_B]], %[[IV_M]], %[[IV_N]]] : memref<?x?x?xf32>
//       CHECK:         %[[PROD:.*]] = arith.mulf %[[LHS_V]], %[[RHS_V]] : f32
//       CHECK:         %[[SUM:.*]] = arith.addf %[[ACC_V]], %[[PROD]] : f32
//       CHECK:         affine.store %[[SUM]], %[[ACC]][%[[IV_B]], %[[IV_M]], %[[IV_N]]] : memref<?x?x?xf32>
