// Source: llvm-project/mlir/test/Dialect/MemRef/expand-ops.mlir (revision 889b67c9d30e3024a1317431d66c22599f6c2011)
// RUN: mlir-opt -memref-expand %s -split-input-file | FileCheck %s

// Each floating-point min/max kind of memref.atomic_rmw (maximumf, minimumf,
// maxnumf, minnumf) is expected to be expanded into a
// memref.generic_atomic_rmw. The region of that op applies the matching arith
// op to the current memory value and the operand, then yields the result.
// The captured memref/index arguments are reused below instead of hard-coding
// SSA names such as %arg0/%arg2.
// CHECK-LABEL: func @atomic_rmw_to_generic
// CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index)
func.func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
  %a = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
  %b = memref.atomic_rmw minimumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
  %c = memref.atomic_rmw maxnumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
  %d = memref.atomic_rmw minnumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
  return %a : f32
}
// CHECK: [[RESULT:%.*]] = memref.generic_atomic_rmw [[F]]{{\[}}[[i]]] : memref<10xf32> {
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
// CHECK:   [[MAXIMUM:%.*]] = arith.maximumf [[CUR_VAL]], [[f]] : f32
// CHECK:   memref.atomic_yield [[MAXIMUM]] : f32
// CHECK: }
// CHECK: memref.generic_atomic_rmw [[F]]{{\[}}[[i]]] : memref<10xf32> {
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
// CHECK:   [[MINIMUM:%.*]] = arith.minimumf [[CUR_VAL]], [[f]] : f32
// CHECK:   memref.atomic_yield [[MINIMUM]] : f32
// CHECK: }
// CHECK: memref.generic_atomic_rmw [[F]]{{\[}}[[i]]] : memref<10xf32> {
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
// CHECK:   [[MAXNUM:%.*]] = arith.maxnumf [[CUR_VAL]], [[f]] : f32
// CHECK:   memref.atomic_yield [[MAXNUM]] : f32
// CHECK: }
// CHECK: memref.generic_atomic_rmw [[F]]{{\[}}[[i]]] : memref<10xf32> {
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
// CHECK:   [[MINNUM:%.*]] = arith.minnumf [[CUR_VAL]], [[f]] : f32
// CHECK:   memref.atomic_yield [[MINNUM]] : f32
// CHECK: }
// CHECK: return [[RESULT]] : f32

// -----

// An addf memref.atomic_rmw is expected to pass through -memref-expand
// unchanged. The positive match ensures the op still exists, so the absence
// checks cannot pass vacuously if the op were dropped.
// CHECK-LABEL: func @atomic_rmw_no_conversion
func.func @atomic_rmw_no_conversion(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
  %x = memref.atomic_rmw addf %f, %F[%i] : (f32, memref<10xf32>) -> f32
  return %x : f32
}
// CHECK-NOT: generic_atomic_rmw
// CHECK: memref.atomic_rmw addf
// CHECK-NOT: generic_atomic_rmw

// -----

// A memref.reshape from an unranked source to memref<?x?x8xf32> is expected to
// be expanded as follows:
// - each dynamic size is loaded from the i32 shape operand and index_cast to
//   index;
// - the static size 8 stays a constant;
// - the result is a memref.reinterpret_cast with offset 0 and strides
//   [8 * size_1, 8, 1].
// CHECK-LABEL: func @memref_reshape(
func.func @memref_reshape(%input: memref<*xf32>,
                     %shape: memref<3xi32>) -> memref<?x?x8xf32> {
  %result = memref.reshape %input(%shape)
               : (memref<*xf32>, memref<3xi32>) -> memref<?x?x8xf32>
  return %result : memref<?x?x8xf32>
}
// CHECK-SAME: [[SRC:%.*]]: memref<*xf32>,
// CHECK-SAME: [[SHAPE:%.*]]: memref<3xi32>) -> memref<?x?x8xf32> {

// CHECK: [[C8:%.*]] = arith.constant 8 : index
// CHECK: [[C1:%.*]] = arith.constant 1 : index
// CHECK: [[DIM_1:%.*]] = memref.load [[SHAPE]]{{\[}}[[C1]]] : memref<3xi32>
// CHECK: [[SIZE_1:%.*]] = arith.index_cast [[DIM_1]] : i32 to index

// CHECK: [[C8_:%.*]] = arith.constant 8 : index
// CHECK: [[STRIDE_0:%.*]] = arith.muli [[C8_]], [[SIZE_1]] : index

// CHECK: [[C0:%.*]] = arith.constant 0 : index
// CHECK: [[DIM_0:%.*]] = memref.load [[SHAPE]]{{\[}}[[C0]]] : memref<3xi32>
// CHECK: [[SIZE_0:%.*]] = arith.index_cast [[DIM_0]] : i32 to index

// CHECK: [[RESULT:%.*]] = memref.reinterpret_cast [[SRC]]
// CHECK-SAME: to offset: [0], sizes: {{\[}}[[SIZE_0]], [[SIZE_1]], 8],
// CHECK-SAME: strides: {{\[}}[[STRIDE_0]], 8, 1]
// CHECK-SAME: : memref<*xf32> to memref<?x?x8xf32>