// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s

// Tests for lowering vector.outerproduct via the transform interpreter.
// The pattern sets applied to every function are listed in the
// @__transform_main sequence at the bottom of this file.
//
// All function labels below end with an opening parenthesis so that each one
// matches exactly one function in the output (e.g. the label for
// @outerproduct_noacc must not also match @outerproduct_noacc_int).

// Vector x vector, f32, no accumulator: one extract / splat / mulf / insert
// group per row, inserted into a zero-initialized result.
// CHECK-LABEL: func @outerproduct_noacc(
// CHECK-SAME: %[[A:.*0]]: vector<2xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<3xf32>
// CHECK:      %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32>
// CHECK:      %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
// CHECK:      %[[T2:.*]] = arith.mulf %[[T1]], %[[B]] : vector<3xf32>
// CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
// CHECK:      %[[T4:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32>
// CHECK:      %[[T5:.*]] = vector.splat %[[T4]] : vector<3xf32>
// CHECK:      %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32>
// CHECK:      %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32>
// CHECK:      return %[[T7]] : vector<2x3xf32>

func.func @outerproduct_noacc(%arg0: vector<2xf32>,
                              %arg1: vector<3xf32>) -> vector<2x3xf32> {
  %0 = vector.outerproduct %arg0, %arg1 : vector<2xf32>, vector<3xf32>
  return %0: vector<2x3xf32>
}

// Vector x vector, f32, with accumulator: each row of the accumulator is
// extracted and combined with the product through vector.fma.
// CHECK-LABEL: func @outerproduct_acc(
// CHECK-SAME: %[[A:.*0]]: vector<2xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<3xf32>,
// CHECK-SAME: %[[C:.*2]]: vector<2x3xf32>
// CHECK:      %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32>
// CHECK:      %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
// CHECK:      %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xf32> from vector<2x3xf32>
// CHECK:      %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32>
// CHECK:      %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
// CHECK:      %[[T5:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32>
// CHECK:      %[[T6:.*]] = vector.splat %[[T5]] : vector<3xf32>
// CHECK:      %[[T7:.*]] = vector.extract %[[C]][1] : vector<3xf32> from vector<2x3xf32>
// CHECK:      %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32>
// CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32>
// CHECK:      return %[[T9]] : vector<2x3xf32>

func.func @outerproduct_acc(%arg0: vector<2xf32>,
                            %arg1: vector<3xf32>,
                            %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
  %0 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xf32>, vector<3xf32>
  return %0: vector<2x3xf32>
}

// Vector x vector, i32, no accumulator: same shape as the f32 case, with
// arith.muli as the multiply.
// CHECK-LABEL: func @outerproduct_noacc_int(
// CHECK-SAME: %[[A:.*0]]: vector<2xi32>,
// CHECK-SAME: %[[B:.*1]]: vector<3xi32>
// CHECK:      %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32>
// CHECK:      %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
// CHECK:      %[[T2:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
// CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
// CHECK:      %[[T4:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32>
// CHECK:      %[[T5:.*]] = vector.splat %[[T4]] : vector<3xi32>
// CHECK:      %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32>
// CHECK:      %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xi32> into vector<2x3xi32>
// CHECK:      return %[[T7]] : vector<2x3xi32>

func.func @outerproduct_noacc_int(%arg0: vector<2xi32>,
                                  %arg1: vector<3xi32>) -> vector<2x3xi32> {
  %0 = vector.outerproduct %arg0, %arg1 : vector<2xi32>, vector<3xi32>
  return %0: vector<2x3xi32>
}

// Vector x vector, i32, with accumulator: the accumulation is expressed as a
// separate arith.muli followed by arith.addi (no fused op is expected).
// CHECK-LABEL: func @outerproduct_acc_int(
// CHECK-SAME: %[[A:.*0]]: vector<2xi32>,
// CHECK-SAME: %[[B:.*1]]: vector<3xi32>,
// CHECK-SAME: %[[C:.*2]]: vector<2x3xi32>
// CHECK:      %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32>
// CHECK:      %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
// CHECK:      %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xi32> from vector<2x3xi32>
// CHECK:      %[[T3:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
// CHECK:      %[[T4:.*]] = arith.addi %[[T3]], %[[T2]] : vector<3xi32>
// CHECK:      %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
// CHECK:      %[[T6:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32>
// CHECK:      %[[T7:.*]] = vector.splat %[[T6]] : vector<3xi32>
// CHECK:      %[[T8:.*]] = vector.extract %[[C]][1] : vector<3xi32> from vector<2x3xi32>
// CHECK:      %[[T9:.*]] = arith.muli %[[T7]], %[[B]] : vector<3xi32>
// CHECK:      %[[T10:.*]] = arith.addi %[[T9]], %[[T8]] : vector<3xi32>
// CHECK:      %[[T11:.*]] = vector.insert %[[T10]], %[[T5]] [1] : vector<3xi32> into vector<2x3xi32>
// CHECK:      return %[[T11]] : vector<2x3xi32>

func.func @outerproduct_acc_int(%arg0: vector<2xi32>,
                                %arg1: vector<3xi32>,
                                %arg2: vector<2x3xi32>) -> vector<2x3xi32> {
  %0 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xi32>, vector<3xi32>
  return %0: vector<2x3xi32>
}

// Vector x scalar, f32, no accumulator: the scalar is splatted and multiplied
// element-wise with the vector operand.
// CHECK-LABEL: func @axpy_fp(
// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
// CHECK-SAME: %[[B:.*1]]: f32)
// CHECK:      %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
// CHECK:      %[[T1:.*]] = arith.mulf %[[A]], %[[T0]] : vector<16xf32>
// CHECK:      return %[[T1]] : vector<16xf32>

func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
  %0 = vector.outerproduct %arg0, %arg1: vector<16xf32>, f32
  return %0: vector<16xf32>
}

// Vector x scalar, f32, with accumulator: a single vector.fma is expected.
// CHECK-LABEL: func @axpy_fp_add(
// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
// CHECK-SAME: %[[B:.*1]]: f32,
// CHECK-SAME: %[[C:.*2]]: vector<16xf32>)
// CHECK:      %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
// CHECK:      %[[T1:.*]] = vector.fma %[[A]], %[[T0]], %[[C]] : vector<16xf32>
// CHECK:      return %[[T1]] : vector<16xf32>

func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) -> vector<16xf32> {
  %0 = vector.outerproduct %arg0, %arg1, %arg2: vector<16xf32>, f32
  return %0: vector<16xf32>
}

// Vector x scalar, i32, no accumulator: splat followed by arith.muli.
// CHECK-LABEL: func @axpy_int(
// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
// CHECK-SAME: %[[B:.*1]]: i32)
// CHECK:      %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
// CHECK:      %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
// CHECK:      return %[[T1]] : vector<16xi32>

func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
  %0 = vector.outerproduct %arg0, %arg1: vector<16xi32>, i32
  return %0: vector<16xi32>
}

// Vector x scalar, i32, with accumulator: splat, arith.muli, then arith.addi.
// CHECK-LABEL: func @axpy_int_add(
// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
// CHECK-SAME: %[[B:.*1]]: i32,
// CHECK-SAME: %[[C:.*2]]: vector<16xi32>)
// CHECK:      %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
// CHECK:      %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
// CHECK:      %[[T2:.*]] = arith.addi %[[T1]], %[[C]] : vector<16xi32>
// CHECK:      return %[[T2]] : vector<16xi32>

func.func @axpy_int_add(%arg0: vector<16xi32>, %arg1: i32, %arg2: vector<16xi32>) -> vector<16xi32> {
  %0 = vector.outerproduct %arg0, %arg1, %arg2: vector<16xi32>, i32
  return %0: vector<16xi32>
}

// Transform script executed by --transform-interpreter: collects every
// func.func in the module and applies two pattern sets to them in order.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %f = transform.structured.match ops{["func.func"]} in %module_op
      : (!transform.any_op) -> !transform.any_op

    // First rewrite vector.outerproduct itself.
    transform.apply_patterns to %f {
      transform.apply_patterns.vector.lower_outerproduct
    } : !transform.any_op

    // NOTE(review): presumably this second pass turns the broadcasts produced
    // by the outer-product lowering into the vector.splat ops the
    // expectations above refer to -- confirm against the pattern definitions.
    transform.apply_patterns to %f {
      transform.apply_patterns.vector.lower_broadcast
    } : !transform.any_op
    transform.yield
  }
}