// RUN: mlir-opt --resolve-ranked-shaped-type-result-dims --split-input-file %s | FileCheck %s

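// Test case: memref.dim with a constant out-of-bounds index is not folded.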
// CHECK-LABEL: func @dim_out_of_bounds(
//  CHECK-NEXT:   arith.constant
//  CHECK-NEXT:   memref.dim
//  CHECK-NEXT:   return
func.func @dim_out_of_bounds(%m : memref<7x8xf32>) -> index {
  %idx = arith.constant 7 : index
  %0 = memref.dim %m, %idx : memref<7x8xf32>
  return %0 : index
}

// -----

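// Test case: tensor.dim with a constant out-of-bounds index is not folded.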
// CHECK-LABEL: func @dim_out_of_bounds_2(
//  CHECK-NEXT:   arith.constant
//  CHECK-NEXT:   arith.constant
//  CHECK-NEXT:   bufferization.alloc_tensor
//  CHECK-NEXT:   tensor.dim
//  CHECK-NEXT:   return
func.func @dim_out_of_bounds_2(%idx1 : index, %idx2 : index) -> index {
  %idx = arith.constant 7 : index
  %sz = arith.constant 5 : index
  %alloc = bufferization.alloc_tensor(%sz, %sz) : tensor<?x?xf32>
  %0 = tensor.dim %alloc, %idx : tensor<?x?xf32>
  return %0 : index
}

// -----

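// Test case: tensor.dim of a dynamic dimension of a tosa.transpose result is
// resolved to a tensor.dim on the transpose input at the permuted index.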
// CHECK-LABEL:   func.func @dynamic_dim_of_transpose_op(
//  CHECK-SAME:                                   %[[arg:.*]]: tensor<1x2x?x8xi8>) -> index {
//  CHECK-NEXT:           %[[c2:.*]] = arith.constant 2
//  CHECK-NEXT:           tensor.dim %[[arg]], %[[c2]]
//  CHECK-NEXT:           return
func.func @dynamic_dim_of_transpose_op(%arg0: tensor<1x2x?x8xi8>) -> index {
  %0 = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
  %1 = tosa.transpose %arg0, %0 : (tensor<1x2x?x8xi8>, tensor<4xi32>) -> tensor<1x8x2x?xi8>
  %c3 = arith.constant 3 : index
  %dim = tensor.dim %1, %c3 : tensor<1x8x2x?xi8>
  return %dim : index
}

// -----

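// Test case: tensor.dim of a static dimension of a tosa.transpose result folds
// to a constant.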
// CHECK-LABEL:   func.func @static_dim_of_transpose_op(
//  CHECK:           arith.constant 100 : index
//  CHECK:           return
func.func @static_dim_of_transpose_op(%arg0: tensor<1x100x?x8xi8>) -> index {
  %0 = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
  %1 = tosa.transpose %arg0, %0 : (tensor<1x100x?x8xi8>, tensor<4xi32>) -> tensor<1x8x100x?xi8>
  %c2 = arith.constant 2 : index
  %dim = tensor.dim %1, %c2 : tensor<1x8x100x?xi8>
  return %dim : index
}

// -----

// Test case: Folding of memref.dim(memref.expand_shape)
// CHECK-LABEL: func @dim_of_memref_expand_shape(
//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<?x8xi32>
//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 0
//  CHECK-NEXT:   %[[DIM:.*]] = memref.dim %[[MEM]], %[[IDX]] : memref<?x8xi32>
//       CHECK:   return %[[DIM]] : index
func.func @dim_of_memref_expand_shape(%arg0: memref<?x8xi32>)
    -> index {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %s = memref.dim %arg0, %c0 : memref<?x8xi32>
  %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, %s, 2, 4]: memref<?x8xi32> into memref<1x?x2x4xi32>
  %1 = memref.dim %0, %c1 : memref<1x?x2x4xi32>
  return %1 : index
}

// -----

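// Test case: tensor.dim of an scf.forall shared_outs iter arg is resolved to a
// tensor.dim of the corresponding init operand.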
// CHECK-LABEL: @iter_to_init_arg_loop_like
//  CHECK-SAME:   (%[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
//       CHECK:    %[[RESULT:.*]] = scf.forall
//  CHECK-SAME:                       shared_outs(%[[OUTS:.*]] = %[[ARG1]]) -> (tensor<?x?xf32>) {
//  CHECK-NEXT:       %{{.*}} = tensor.dim %[[ARG1]], %{{.*}} : tensor<?x?xf32>
func.func @iter_to_init_arg_loop_like(
  %arg0 : tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>

  %result = scf.forall (%i) = (%c0) to (%dim0)
      step (%c1) shared_outs(%o = %arg1) -> (tensor<?x?xf32>) {

    %dim1 = tensor.dim %o, %c1 : tensor<?x?xf32>
    %slice = tensor.extract_slice %arg1[%i, 0] [1, %dim1] [1, 1]
      : tensor<?x?xf32> to tensor<1x?xf32>

    scf.forall.in_parallel {
      tensor.parallel_insert_slice %slice into %o[%i, 0] [1, %dim1] [1, 1]
        : tensor<1x?xf32> into tensor<?x?xf32>
    }
  }
  return %result : tensor<?x?xf32>
}