// RUN: mlir-opt %s --transform-interpreter | FileCheck %s

// Regression test for the vector.shape_cast lowering patterns. Each function
// below holds one or more vector.shape_cast ops; the transform sequence at the
// end of the file applies `transform.apply_patterns.vector.lower_shape_cast`
// to every func.func, and the expected IR is verified per function.

// Cast between identical types: the expected output returns the argument
// directly.
// CHECK-LABEL: func @nop_shape_cast
// CHECK-SAME:  %[[A:.*]]: vector<16xf32>
// CHECK:       return %[[A]] : vector<16xf32>
func.func @nop_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
  %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<16xf32>
  return %0 : vector<16xf32>
}

// A cast to 4x4 followed by the inverse cast back to 16: the expected output
// returns the argument directly.
// CHECK-LABEL: func @cancel_shape_cast
// CHECK-SAME:  %[[A:.*]]: vector<16xf32>
// CHECK:       return %[[A]] : vector<16xf32>

func.func @cancel_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
  %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32>
  %1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32>
  return %1 : vector<16xf32>
}

// Shape up and downcasts for 2-D vectors, for supporting conversion to
// llvm.matrix operations
// CHECK-LABEL: func @shape_casts
func.func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>) {
  // CHECK-DAG: %[[cst22:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
  // CHECK-DAG: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
  // CHECK: %[[ex0:.*]] = vector.extract %{{.*}}[0] : vector<2xf32> from vector<2x2xf32>
  //
  // CHECK: %[[in0:.*]] = vector.insert_strided_slice %[[ex0]], %[[cst]]
  // CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
  //
  // CHECK: %[[ex1:.*]] = vector.extract %{{.*}}[1] : vector<2xf32> from vector<2x2xf32>
  //
  // CHECK: %[[in2:.*]] = vector.insert_strided_slice %[[ex1]], %[[in0]]
  // CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
  //
  %0 = vector.shape_cast %a : vector<2x2xf32> to vector<4xf32>
  // CHECK: %[[add:.*]] = arith.addf %[[in2]], %[[in2]] : vector<4xf32>
  %r0 = arith.addf %0, %0: vector<4xf32>
  //
  // CHECK: %[[ss0:.*]] = vector.extract_strided_slice %[[add]]
  // CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} :
  // CHECK-SAME: vector<4xf32> to vector<2xf32>
  //
  // CHECK: %[[res0:.*]] = vector.insert %[[ss0]], %[[cst22]] [0] :
  // CHECK-SAME: vector<2xf32> into vector<2x2xf32>
  //
  // CHECK: %[[s2:.*]] = vector.extract_strided_slice %[[add]]
  // CHECK-SAME: {offsets = [2], sizes = [2], strides = [1]} :
  // CHECK-SAME: vector<4xf32> to vector<2xf32>
  //
  // CHECK: %[[res1:.*]] = vector.insert %[[s2]], %[[res0]] [1] :
  // CHECK-SAME: vector<2xf32> into vector<2x2xf32>
  //
  %1 = vector.shape_cast %r0 : vector<4xf32> to vector<2x2xf32>
  // CHECK: return %[[add]], %[[res1]] : vector<4xf32>, vector<2x2xf32>
  return %r0, %1 : vector<4xf32>, vector<2x2xf32>
}

// 2-D to 2-D cast (3x2 to 2x3): the expected output moves the data one scalar
// at a time with vector.extract / vector.insert pairs.
// CHECK-LABEL: func @shape_cast_2d2d
// CHECK-SAME: %[[A:.*]]: vector<3x2xf32>
// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<3x2xf32>
// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0] : f32 into vector<2x3xf32>
// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : f32 from vector<3x2xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<2x3xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : f32 from vector<3x2xf32>
// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 2] : f32 into vector<2x3xf32>
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : f32 from vector<3x2xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0] : f32 into vector<2x3xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : f32 from vector<3x2xf32>
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<2x3xf32>
// CHECK: %[[T10:.*]] = vector.extract %[[A]][2, 1] : f32 from vector<3x2xf32>
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 2] : f32 into vector<2x3xf32>
// CHECK: return %[[T11]] : vector<2x3xf32>

func.func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> {
  %s = vector.shape_cast %arg0: vector<3x2xf32> to vector<2x3xf32>
  return %s : vector<2x3xf32>
}

// 3-D to 1-D cast (1x3x2 to 6): the expected output extracts each innermost
// 2-element row and places it with vector.insert_strided_slice.
// CHECK-LABEL: func @shape_cast_3d1d
// CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32>
// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<6xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2xf32> from vector<1x3x2xf32>
// CHECK: %[[T1:.*]] = vector.insert_strided_slice %[[T0]], %[[C]]
// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<6xf32>
// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<2xf32> from vector<1x3x2xf32>
// CHECK: %[[T3:.*]] = vector.insert_strided_slice %[[T2]], %[[T1]]
// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<6xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 2] : vector<2xf32> from vector<1x3x2xf32>
// CHECK: %[[T5:.*]] = vector.insert_strided_slice %[[T4]], %[[T3]]
// CHECK-SAME: {offsets = [4], strides = [1]} : vector<2xf32> into vector<6xf32>
// CHECK: return %[[T5]] : vector<6xf32>

func.func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> {
  %s = vector.shape_cast %arg0 : vector<1x3x2xf32> to vector<6xf32>
  return %s : vector<6xf32>
}

// 1-D to 3-D cast (6 to 2x1x3): the expected output slices 3-element chunks
// with vector.extract_strided_slice and inserts each one as a row.
// CHECK-LABEL: func @shape_cast_1d3d
// CHECK-SAME: %[[A:.*]]: vector<6xf32>
// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x1x3xf32>
// CHECK: %[[T0:.*]] = vector.extract_strided_slice %[[A]]
// CHECK-SAME: {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0] : vector<3xf32> into vector<2x1x3xf32>
// CHECK: %[[T2:.*]] = vector.extract_strided_slice %[[A]]
// CHECK-SAME: {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1, 0] : vector<3xf32> into vector<2x1x3xf32>
// CHECK: return %[[T3]] : vector<2x1x3xf32>

func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
  %s = vector.shape_cast %arg0 : vector<6xf32> to vector<2x1x3xf32>
  return %s : vector<2x1x3xf32>
}

// 0-D to 1-D cast: the expected output reads the scalar with
// vector.extractelement and inserts it at position 0.
// CHECK-LABEL: func.func @shape_cast_0d1d(
// CHECK-SAME: %[[VAL_0:.*]]: vector<f32>) -> vector<1xf32> {
// CHECK: %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
// CHECK: %[[VAL_2:.*]] = vector.extractelement %[[VAL_0]][] : vector<f32>
// CHECK: %[[VAL_3:.*]] = vector.insert %[[VAL_2]], %[[VAL_1]] [0] : f32 into vector<1xf32>
// CHECK: return %[[VAL_3]] : vector<1xf32>
// CHECK: }

func.func @shape_cast_0d1d(%arg0 : vector<f32>) -> vector<1xf32> {
  %s = vector.shape_cast %arg0 : vector<f32> to vector<1xf32>
  return %s : vector<1xf32>
}

// 1-D to 0-D cast: the expected output extracts element 0 and writes it into
// the 0-D result with vector.insertelement.
// CHECK-LABEL: func.func @shape_cast_1d0d(
// CHECK-SAME: %[[VAL_0:.*]]: vector<1xf32>) -> vector<f32> {
// CHECK: %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : vector<f32>
// CHECK: %[[VAL_2:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<1xf32>
// CHECK: %[[VAL_3:.*]] = vector.insertelement %[[VAL_2]], %[[VAL_1]][] : vector<f32>
// CHECK: return %[[VAL_3]] : vector<f32>
// CHECK: }

func.func @shape_cast_1d0d(%arg0 : vector<1xf32>) -> vector<f32> {
  %s = vector.shape_cast %arg0 : vector<1xf32> to vector<f32>
  return %s : vector<f32>
}

// Transform entry point: matches every func.func in the payload module and
// applies the vector shape_cast lowering patterns to the matched ops.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %f = transform.structured.match ops{["func.func"]} in %module_op
      : (!transform.any_op) -> !transform.any_op

    transform.apply_patterns to %f {
      transform.apply_patterns.vector.lower_shape_cast
    } : !transform.any_op
    transform.yield
  }
}