xref: /llvm-project/mlir/test/Dialect/Linalg/standard.mlir (revision ff94419a287c0b20bf357ab85cf611d4e9bad4c0)
1// RUN: mlir-opt %s -convert-linalg-to-std --split-input-file -verify-diagnostics | FileCheck %s
2
// Positive test input: a linalg.dot whose two inputs are 1-D, dynamically
// sized f32 memrefs with a unit stride and dynamic offset, and whose output is
// a rank-0 memref<f32>. The RUN line at the top of the file runs
// -convert-linalg-to-std on it; the FileCheck directives that follow this
// function look for one memref.cast per operand and a call to
// @linalg_dot_viewsxf32_viewsxf32_viewf32 taking the three cast results.
3func.func @dot(%arg0: memref<?xf32, strided<[1], offset: ?>>,
4          %arg1: memref<?xf32, strided<[1], offset: ?>>,
5          %arg2: memref<f32>) {
6  linalg.dot ins(%arg0, %arg1: memref<?xf32, strided<[1], offset: ?>>,
7                               memref<?xf32, strided<[1], offset: ?>>)
8             outs(%arg2: memref<f32>)
9  return
10}
11// CHECK-LABEL: func @dot(
12//  CHECK-SAME: %[[arg0:[a-zA-Z0-9]*]]: memref<?xf32, strided<[1], offset: ?>>,
13//  CHECK-SAME: %[[arg1:[a-zA-Z0-9]*]]: memref<?xf32, strided<[1], offset: ?>>,
14//  CHECK-SAME: %[[arg2:[a-zA-Z0-9]*]]: memref<f32>) {
15//       CHECK:   %[[o0:.*]] = memref.cast %[[arg0]] :
16//  CHECK-SAME:     memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
17//       CHECK:   %[[o1:.*]] = memref.cast %[[arg1]] :
18//  CHECK-SAME:     memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
19//       CHECK:   %[[o2:.*]] = memref.cast %[[arg2]] :
20//  CHECK-SAME:     memref<f32> to memref<f32, strided<[], offset: ?>>
21//       CHECK:   call @linalg_dot_viewsxf32_viewsxf32_viewf32(
22//  CHECK-SAME:     %[[o0]], %[[o1]], %[[o2]]) :
23//  CHECK-SAME:   memref<?xf32, strided<[?], offset: ?>>, memref<?xf32, strided<[?], offset: ?>>, memref<f32, strided<[], offset: ?>>
24
25// -----
26
// Attribute and type aliases shared by the @matmul_vec_impl test below.
//
// #matmul_accesses holds three affine maps over the iteration space (m, n, k):
// (m, k), (k, n) and (m, n). The linalg.generic below passes %A and %B as
// inputs and %C as output, so the maps read as A[m, k], B[k, n], C[m, n].
27#matmul_accesses = [
28  affine_map<(m, n, k) -> (m, k)>,
29  affine_map<(m, n, k) -> (k, n)>,
30  affine_map<(m, n, k) -> (m, n)>
31]
// #matmul_trait bundles the iterator types (two parallel, one reduction), the
// maps above, and a library_call symbol name. The FileCheck directive after
// @matmul_vec_impl looks for a call to that same symbol.
32#matmul_trait = {
33  iterator_types = ["parallel", "parallel", "reduction"],
34  indexing_maps = #matmul_accesses,
35  library_call = "external_outerproduct_matmul"
36}
37
// Element types: A and B elements are vector<4xf32>, C elements are
// vector<4x4xf32>.
38!vector_type_A = vector<4xf32>
39!vector_type_B = vector<4xf32>
40!vector_type_C = vector<4x4xf32>
41
// Operand types: rank-2 memrefs with both dimensions dynamic, holding the
// vector element types above.
42!matrix_type_A = memref<?x?x!vector_type_A>
43!matrix_type_B = memref<?x?x!vector_type_B>
44!matrix_type_C = memref<?x?x!vector_type_C>
45
// Positive test input: a linalg.generic using #matmul_trait, with %A and %B as
// inputs and %C as output. Its region takes one vector element from each
// operand, feeds %a, %b and %c to vector.outerproduct, and yields the result
// %d as the new C element. The FileCheck directives after this function look
// for a call to @external_outerproduct_matmul, the symbol named by the
// trait's library_call entry.
46func.func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C) {
47  linalg.generic #matmul_trait
48      ins(%A, %B : !matrix_type_A, !matrix_type_B)
49     outs(%C : !matrix_type_C) {
50    ^bb0(%a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C):
51      %d = vector.outerproduct %a, %b, %c: !vector_type_A, !vector_type_B
52      linalg.yield %d: !vector_type_C
53  }
54  return
55}
56// CHECK-LABEL: func @matmul_vec_impl(
57// CHECK:  call @external_outerproduct_matmul(%{{.*}}) :
58
59// -----
60
// Map aliases for the negative test below: #map is the 2-D identity map and
// #map1 keeps only d0.
61#map = affine_map<(d0, d1) -> (d0, d1)>
62#map1 = affine_map<(d0, d1) -> (d0)>
63
// Negative test: a linalg.generic over tensor operands (tensor<?x?xf32> input,
// tensor<?xf32> output) with one parallel and one reduction iterator, and a
// body that yields the input element. The diagnostic annotation inside the
// function body is checked by -verify-diagnostics and must stay on the line
// directly above the op, so do not insert lines between the two.
64func.func @func(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>)  {
65  // expected-error @below {{failed to legalize}}
66  %0 = linalg.generic {
67    indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]}
68  ins(%arg0 : tensor<?x?xf32>) outs(%arg1 : tensor<?xf32>) {
69  ^bb0(%in: f32, %out: f32):
70    linalg.yield %in : f32
71  } -> tensor<?xf32>
72  return
73}
74
75// -----
76
// Negative test: a linalg.copy between two tensor<4x8xf32> values whose result
// is returned from the function. The diagnostic annotation inside the function
// body is checked by -verify-diagnostics and must stay on the line directly
// above the op, so do not insert lines between the two.
77func.func @func(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> tensor<4x8xf32> {
78  // expected-error @below {{failed to legalize}}
79  %0 = linalg.copy ins(%arg0 : tensor<4x8xf32>) outs(%arg1 : tensor<4x8xf32>) -> tensor<4x8xf32>
80  return %0 : tensor<4x8xf32>
81}
82