// RUN: mlir-opt -memref-expand %s -split-input-file | FileCheck %s

// The maximumf/minimumf/maxnumf/minnumf kinds of memref.atomic_rmw are
// expected to be rewritten into memref.generic_atomic_rmw regions that apply
// the matching arith op to the current value and the %f operand.

// CHECK-LABEL: func @atomic_rmw_to_generic
// CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index)
func.func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
  %a = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
  %b = memref.atomic_rmw minimumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
  %c = memref.atomic_rmw maxnumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
  %d = memref.atomic_rmw minnumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
  return %a : f32
}
// The memref and index operands are matched through the names captured on the
// signature line, so the patterns do not depend on how the printer spells
// block arguments.
// CHECK: [[RESULT:%.*]] = memref.generic_atomic_rmw [[F]]{{\[}}[[i]]] : memref<10xf32> {
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
// CHECK:   [[MAXIMUM:%.*]] = arith.maximumf [[CUR_VAL]], [[f]] : f32
// CHECK:   memref.atomic_yield [[MAXIMUM]] : f32
// CHECK: }
// CHECK: memref.generic_atomic_rmw [[F]]{{\[}}[[i]]] : memref<10xf32> {
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
// CHECK:   [[MINIMUM:%.*]] = arith.minimumf [[CUR_VAL]], [[f]] : f32
// CHECK:   memref.atomic_yield [[MINIMUM]] : f32
// CHECK: }
// CHECK: memref.generic_atomic_rmw [[F]]{{\[}}[[i]]] : memref<10xf32> {
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
// CHECK:   [[MAXNUM:%.*]] = arith.maxnumf [[CUR_VAL]], [[f]] : f32
// CHECK:   memref.atomic_yield [[MAXNUM]] : f32
// CHECK: }
// CHECK: memref.generic_atomic_rmw [[F]]{{\[}}[[i]]] : memref<10xf32> {
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
// CHECK:   [[MINNUM:%.*]] = arith.minnumf [[CUR_VAL]], [[f]] : f32
// CHECK:   memref.atomic_yield [[MINNUM]] : f32
// CHECK: }
// CHECK: return [[RESULT]] : f32

// -----

// An addf atomic_rmw is expected to be left alone: no generic_atomic_rmw may
// appear in the output for this function.

// CHECK-LABEL: func @atomic_rmw_no_conversion
func.func @atomic_rmw_no_conversion(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
  %x = memref.atomic_rmw addf %f, %F[%i] : (f32, memref<10xf32>) -> f32
  return %x : f32
}
// CHECK-NOT: generic_atomic_rmw

// -----

// memref.reshape of an unranked source with a memref<3xi32> shape operand is
// expected to become a memref.reinterpret_cast: the two dynamic sizes are
// loaded from the shape memref and cast to index, the innermost size is the
// static 8, and the outer stride is computed as 8 * SIZE_1.

// CHECK-LABEL: func @memref_reshape(
func.func @memref_reshape(%input: memref<*xf32>,
                          %shape: memref<3xi32>) -> memref<?x?x8xf32> {
  %result = memref.reshape %input(%shape)
               : (memref<*xf32>, memref<3xi32>) -> memref<?x?x8xf32>
  return %result : memref<?x?x8xf32>
}
// CHECK-SAME: [[SRC:%.*]]: memref<*xf32>,
// CHECK-SAME: [[SHAPE:%.*]]: memref<3xi32>) -> memref<?x?x8xf32> {

// CHECK: [[C8:%.*]] = arith.constant 8 : index
// CHECK: [[C1:%.*]] = arith.constant 1 : index
// CHECK: [[DIM_1:%.*]] = memref.load [[SHAPE]]{{\[}}[[C1]]] : memref<3xi32>
// CHECK: [[SIZE_1:%.*]] = arith.index_cast [[DIM_1]] : i32 to index

// CHECK: [[C8_:%.*]] = arith.constant 8 : index
// CHECK: [[STRIDE_0:%.*]] = arith.muli [[C8_]], [[SIZE_1]] : index

// CHECK: [[C0:%.*]] = arith.constant 0 : index
// CHECK: [[DIM_0:%.*]] = memref.load [[SHAPE]]{{\[}}[[C0]]] : memref<3xi32>
// CHECK: [[SIZE_0:%.*]] = arith.index_cast [[DIM_0]] : i32 to index

// CHECK: [[RESULT:%.*]] = memref.reinterpret_cast [[SRC]]
// CHECK-SAME: to offset: [0], sizes: {{\[}}[[SIZE_0]], [[SIZE_1]], 8],
// CHECK-SAME: strides: {{\[}}[[STRIDE_0]], 8, 1]
// CHECK-SAME: : memref<*xf32> to memref<?x?x8xf32>