// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-linalg-to-loops,convert-scf-to-cf,convert-arith-to-llvm),finalize-memref-to-llvm,convert-func-to-llvm,convert-cf-to-llvm,reconcile-unrealized-casts)" | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils | FileCheck %s

// CHECK: rank = 2
// CHECK-SAME: sizes = [10, 3]
// CHECK-SAME: strides = [3, 1]
// CHECK-COUNT-10: [10, 10, 10]
//
// CHECK: rank = 2
// CHECK-SAME: sizes = [10, 3]
// CHECK-SAME: strides = [3, 1]
// CHECK-COUNT-10: [5, 5, 5]
//
// CHECK: rank = 2
// CHECK-SAME: sizes = [10, 3]
// CHECK-SAME: strides = [3, 1]
// CHECK-COUNT-10: [2, 2, 2]
//
// CHECK: rank = 0
// 122 is ASCII for 'z'.
// CHECK: [z]
//
// CHECK: rank = 2
// CHECK-SAME: sizes = [4, 3]
// CHECK-SAME: strides = [3, 1]
// CHECK-COUNT-4: [1, 1, 1]
//
// CHECK: rank = 2
// CHECK-SAME: sizes = [4, 3]
// CHECK-SAME: strides = [3, 1]
// CHECK-COUNT-4: [1, 1, 1]
//
// CHECK: rank = 2
// CHECK-SAME: sizes = [4, 3]
// CHECK-SAME: strides = [3, 1]
// CHECK-COUNT-4: [1, 1, 1]

// Entry point (selected by `-e main` in the RUN line).
// - Fills one 10x3 f32 buffer (%A) three times: with 10.0, 5.0 and 2.0.
//   Each fill goes through a different chain of ranked/unranked
//   memref.cast views of that same buffer, and the buffer is printed
//   after each fill.
// - Prints a rank-0 i8 memref holding 122.
// - Frees both heap allocations.
// - Calls the remaining test functions in order.
func.func @main() -> () {
  %A = memref.alloc() : memref<10x3xf32, 0>
  %f2 = arith.constant 2.00000e+00 : f32
  %f5 = arith.constant 5.00000e+00 : f32
  %f10 = arith.constant 10.00000e+00 : f32

  // Static ranked -> dynamic ranked view; fill through it, print through an
  // unranked view of %A.
  %V = memref.cast %A : memref<10x3xf32, 0> to memref<?x?xf32>
  linalg.fill ins(%f10 : f32) outs(%V : memref<?x?xf32, 0>)
  %U = memref.cast %A : memref<10x3xf32, 0> to memref<*xf32>
  call @printMemrefF32(%U) : (memref<*xf32>) -> ()

  // Unranked -> dynamic ranked: the fill must still land in %A's storage.
  %V2 = memref.cast %U : memref<*xf32> to memref<?x?xf32>
  linalg.fill ins(%f5 : f32) outs(%V2 : memref<?x?xf32, 0>)
  %U2 = memref.cast %V2 : memref<?x?xf32, 0> to memref<*xf32>
  call @printMemrefF32(%U2) : (memref<*xf32>) -> ()

  // Ranked -> unranked -> ranked round trip before filling. %U3 re-casts
  // %V2 directly (not %V3), so the print goes through a separate view.
  %V3 = memref.cast %V2 : memref<?x?xf32> to memref<*xf32>
  %V4 = memref.cast %V3 : memref<*xf32> to memref<?x?xf32>
  linalg.fill ins(%f2 : f32) outs(%V4 : memref<?x?xf32, 0>)
  %U3 = memref.cast %V2 : memref<?x?xf32> to memref<*xf32>
  call @printMemrefF32(%U3) : (memref<*xf32>) -> ()

  // 122 is ASCII for 'z'.
  %i8_z = arith.constant 122 : i8
  %I8 = memref.alloc() : memref<i8>
  memref.store %i8_z, %I8[]: memref<i8>
  %U4 = memref.cast %I8 : memref<i8> to memref<*xi8>
  call @printMemrefI8(%U4) : (memref<*xi8>) -> ()

  // %I8 is freed through its unranked view %U4, which exercises
  // memref.dealloc on an unranked operand; %A is freed directly.
  memref.dealloc %U4 : memref<*xi8>
  memref.dealloc %A : memref<10x3xf32, 0>

  call @return_var_memref_caller() : () -> ()
  call @return_two_var_memref_caller() : () -> ()
  call @dim_op_of_unranked() : () -> ()
  return
}

// External print helpers (declared here, no body). They are presumably
// resolved from the -shared-libs runner libraries named in the RUN line.
func.func private @printMemrefI8(memref<*xi8>) attributes { llvm.emit_c_interface }
func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interface }

// Fills a 4x3 memref.alloca buffer with 1.0, passes it to
// @return_two_var_memref, and prints both unranked results.
func.func @return_two_var_memref_caller() {
  %0 = memref.alloca() : memref<4x3xf32>
  %c1f32 = arith.constant 1.0 : f32
  linalg.fill ins(%c1f32 : f32) outs(%0 : memref<4x3xf32>)
  %1:2 = call @return_two_var_memref(%0) : (memref<4x3xf32>) -> (memref<*xf32>, memref<*xf32>)
  call @printMemrefF32(%1#0) : (memref<*xf32>) -> ()
  call @printMemrefF32(%1#1) : (memref<*xf32>) -> ()
  return
}

// Returns the same unranked view of %arg0 as both results.
func.func @return_two_var_memref(%arg0: memref<4x3xf32>) -> (memref<*xf32>, memref<*xf32>) {
  %0 = memref.cast %arg0 : memref<4x3xf32> to memref<*xf32>
  return %0, %0 : memref<*xf32>, memref<*xf32>
}

// Fills a 4x3 memref.alloca buffer with 1.0, passes it to
// @return_var_memref, and prints the unranked result.
func.func @return_var_memref_caller() {
  %0 = memref.alloca() : memref<4x3xf32>
  %c1f32 = arith.constant 1.0 : f32
  linalg.fill ins(%c1f32 : f32) outs(%0 : memref<4x3xf32>)
  %1 = call @return_var_memref(%0) : (memref<4x3xf32>) -> memref<*xf32>
  call @printMemrefF32(%1) : (memref<*xf32>) -> ()
  return
}

// Returns %arg0 cast to an unranked memref.
func.func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> {
  %0 = memref.cast %arg0: memref<4x3xf32> to memref<*xf32>
  return %0 : memref<*xf32>
}

// External scalar-print helpers (declared here, no body).
func.func private @printU64(index) -> ()
func.func private @printNewline() -> ()

// Casts a 4x3 buffer to an unranked memref and reads dims 0 and 1 back
// with memref.dim, printing each on its own line (expected 4, then 3).
func.func @dim_op_of_unranked() {
  %ranked = memref.alloca() : memref<4x3xf32>
  %unranked = memref.cast %ranked: memref<4x3xf32> to memref<*xf32>

  %c0 = arith.constant 0 : index
  %dim_0 = memref.dim %unranked, %c0 : memref<*xf32>
  call @printU64(%dim_0) : (index) -> ()
  call @printNewline() : () -> ()
  // CHECK: 4

  %c1 = arith.constant 1 : index
  %dim_1 = memref.dim %unranked, %c1 : memref<*xf32>
  call @printU64(%dim_1) : (index) -> ()
  call @printNewline() : () -> ()
  // CHECK: 3

  return
}