// Source: llvm-project/mlir/test/mlir-runner/memref-reinterpret-cast.mlir (revision eb206e9ea84eff0a0596fed2de8316d924f946d1)
// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-scf-to-cf),finalize-memref-to-llvm,func.func(convert-arith-to-llvm),convert-func-to-llvm,convert-cf-to-llvm,reconcile-unrealized-casts)" \
// RUN: | mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils \
// RUN: | FileCheck %s

// External printing routine with no body in this file; it is resolved at run
// time from the shared libraries passed to mlir-runner. Every test case below
// casts its result to an unranked f32 memref and prints it through this call.
// `llvm.emit_c_interface` requests the C-interface calling convention for the
// external symbol.
func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interface }

// Entry point (run via `mlir-runner -e main`). It fills a 2x3 f32 buffer with
// the values 0..5 in row-major order and prints it. It then runs each
// reinterpret_cast test case on that buffer and finally frees it.
func.func @main() -> () {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index

  // Initialize input: element (i, j) holds i * 3 + j (dim 1 is 3), so the
  // buffer reads 0..5 in row-major order.
  %input = memref.alloc() : memref<2x3xf32>
  %dim_x = memref.dim %input, %c0 : memref<2x3xf32>
  %dim_y = memref.dim %input, %c1 : memref<2x3xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
    %prod = arith.muli %i, %dim_y : index
    %val = arith.addi %prod, %j : index
    %val_i64 = arith.index_cast %val : index to i64
    %val_f32 = arith.sitofp %val_i64 : i64 to f32
    memref.store %val_f32, %input[%i, %j] : memref<2x3xf32>
  }
  %unranked_input = memref.cast %input : memref<2x3xf32> to memref<*xf32>
  call @printMemrefF32(%unranked_input) : (memref<*xf32>) -> ()
  // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1] data =
  // CHECK-NEXT: [0,   1,   2]
  // CHECK-NEXT: [3,   4,   5]

  // Test cases. Each one prints its result. The calls must therefore stay in
  // the same order as the expected-output directives that follow in this file.
  call @cast_ranked_memref_to_static_shape(%input) : (memref<2x3xf32>) -> ()
  call @cast_ranked_memref_to_dynamic_shape(%input) : (memref<2x3xf32>) -> ()
  call @cast_unranked_memref_to_static_shape(%input) : (memref<2x3xf32>) -> ()
  call @cast_unranked_memref_to_dynamic_shape(%input) : (memref<2x3xf32>) -> ()
  memref.dealloc %input : memref<2x3xf32>
  return
}

// Ranked source, static layout. It reinterprets the 2x3 buffer as a 6x1 column
// with constant offset, sizes and strides, then prints it. The six elements
// come out one per row.
func.func @cast_ranked_memref_to_static_shape(%input : memref<2x3xf32>) {
  %output = memref.reinterpret_cast %input to
           offset: [0], sizes: [6, 1], strides: [1, 1]
           : memref<2x3xf32> to memref<6x1xf32>

  %unranked_output = memref.cast %output
      : memref<6x1xf32> to memref<*xf32>
  call @printMemrefF32(%unranked_output) : (memref<*xf32>) -> ()
  // CHECK: rank = 2 offset = 0 sizes = [6, 1] strides = [1, 1] data =
  // CHECK-NEXT: [0],
  // CHECK-NEXT: [1],
  // CHECK-NEXT: [2],
  // CHECK-NEXT: [3],
  // CHECK-NEXT: [4],
  // CHECK-NEXT: [5]
  return
}

// Ranked source, dynamic layout. It reinterprets the 2x3 buffer as a 1x6 row.
// Offset, sizes and strides are passed as SSA values, so the result type is
// fully dynamic (?x? with a dynamic strided layout). The result is then printed.
func.func @cast_ranked_memref_to_dynamic_shape(%input : memref<2x3xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c6 = arith.constant 6 : index
  %output = memref.reinterpret_cast %input to
           offset: [%c0], sizes: [%c1, %c6], strides: [%c6, %c1]
           : memref<2x3xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>

  %unranked_output = memref.cast %output
      : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<*xf32>
  call @printMemrefF32(%unranked_output) : (memref<*xf32>) -> ()
  // CHECK: rank = 2 offset = 0 sizes = [1, 6] strides = [6, 1] data =
  // CHECK-NEXT: [0,   1,   2,   3,   4,   5]
  return
}

// Unranked source, static layout. It first erases the rank of the 2x3 buffer,
// then reinterprets the unranked memref as a 6x1 column with constant offset,
// sizes and strides, and prints it.
func.func @cast_unranked_memref_to_static_shape(%input : memref<2x3xf32>) {
  %unranked_input = memref.cast %input : memref<2x3xf32> to memref<*xf32>
  %output = memref.reinterpret_cast %unranked_input to
           offset: [0], sizes: [6, 1], strides: [1, 1]
           : memref<*xf32> to memref<6x1xf32>

  %unranked_output = memref.cast %output
      : memref<6x1xf32> to memref<*xf32>
  call @printMemrefF32(%unranked_output) : (memref<*xf32>) -> ()
  // CHECK: rank = 2 offset = 0 sizes = [6, 1] strides = [1, 1] data =
  // CHECK-NEXT: [0],
  // CHECK-NEXT: [1],
  // CHECK-NEXT: [2],
  // CHECK-NEXT: [3],
  // CHECK-NEXT: [4],
  // CHECK-NEXT: [5]
  return
}

// Unranked source, dynamic layout. It first erases the rank of the 2x3 buffer,
// then reinterprets the unranked memref as a 1x6 row. Offset, sizes and
// strides are passed as SSA values, and the result is printed.
func.func @cast_unranked_memref_to_dynamic_shape(%input : memref<2x3xf32>) {
  %unranked_input = memref.cast %input : memref<2x3xf32> to memref<*xf32>
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c6 = arith.constant 6 : index
  %output = memref.reinterpret_cast %unranked_input to
           offset: [%c0], sizes: [%c1, %c6], strides: [%c6, %c1]
           : memref<*xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>

  %unranked_output = memref.cast %output
      : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<*xf32>
  call @printMemrefF32(%unranked_output) : (memref<*xf32>) -> ()
  // CHECK: rank = 2 offset = 0 sizes = [1, 6] strides = [6, 1] data =
  // CHECK-NEXT: [0,   1,   2,   3,   4,   5]
  return
}
