// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s

// Tests for lowering `vector.contract` with the "parallelarith" strategy
// (applied by the transform script at the bottom of this file). Each case
// contracts away unit-sized reduction dimensions, so the expected output is a
// single extract of the innermost vector per operand followed by an
// elementwise multiply-accumulate.

// Both operands already place the parallel dimension innermost: the lowering
// only extracts the 1-D slices and emits a vector.fma with the accumulator.
// CHECK-LABEL: func @parallel_contract_lowering
// CHECK: %[[E0:.*]] = vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
// CHECK: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
// CHECK: %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %{{.*}} : vector<4xf32>
// CHECK: return %[[F]] : vector<4xf32>
func.func @parallel_contract_lowering(%arg0: vector<1x1x4xf32>, %arg1: vector<1x1x4xf32>, %arg2: vector<4xf32>) -> vector<4xf32> {
  %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<1x1x4xf32>, vector<1x1x4xf32> into vector<4xf32>
  return %0 : vector<4xf32>
}

// The LHS does not carry the parallel dimension d0: it is broadcast along d0
// and then transposed so d0 becomes innermost before the extract/fma.
// CHECK-LABEL: func @parallel_contract_lowering_broadcast
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : vector<1x1xf32> to vector<4x1x1xf32>
// CHECK: %[[T:.*]] = vector.transpose %[[B]], [1, 2, 0] : vector<4x1x1xf32> to vector<1x1x4xf32>
// CHECK: %[[E0:.*]] = vector.extract %[[T]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
// CHECK: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
// CHECK: %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %{{.*}} : vector<4xf32>
// CHECK: return %[[F]] : vector<4xf32>
func.func @parallel_contract_lowering_broadcast(%arg0: vector<1x1xf32>, %arg1: vector<1x1x4xf32>, %arg2: vector<4xf32>) -> vector<4xf32> {
  %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x1x4xf32> into vector<4xf32>
  return %0 : vector<4xf32>
}

// Broadcast LHS as above, plus an RHS whose parallel dimension d0 sits in the
// middle: the RHS is transposed [0, 2, 1] so d0 becomes innermost.
// The label names this function exactly (it previously reused the first
// test's name and only matched this function by substring prefix).
// CHECK-LABEL: func @parallel_contract_lowering_transpose
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : vector<1x1xf32> to vector<4x1x1xf32>
// CHECK: %[[T0:.*]] = vector.transpose %[[B]], [1, 2, 0] : vector<4x1x1xf32> to vector<1x1x4xf32>
// CHECK: %[[T1:.*]] = vector.transpose %{{.*}}, [0, 2, 1] : vector<1x4x1xf32> to vector<1x1x4xf32>
// CHECK: %[[E0:.*]] = vector.extract %[[T0]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
// CHECK: %[[E1:.*]] = vector.extract %[[T1]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
// CHECK: %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %arg2 : vector<4xf32>
// CHECK: return %[[F]] : vector<4xf32>
func.func @parallel_contract_lowering_transpose(%arg0: vector<1x1xf32>, %arg1: vector<1x4x1xf32>, %arg2: vector<4xf32>) -> vector<4xf32> {
  %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d0, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x4x1xf32> into vector<4xf32>
  return %0 : vector<4xf32>
}

// All dimensions are reductions and the result is a scalar: the lowering
// extracts scalar elements and emits arith.mulf followed by arith.addf
// instead of a vector.fma.
// CHECK-LABEL: func @parallel_contract_lowering_scalar
// CHECK: %[[E0:.*]] = vector.extract %{{.*}}[0, 0] : f32 from vector<1x1xf32>
// CHECK: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : f32 from vector<1x1xf32>
// CHECK: %[[M:.*]] = arith.mulf %[[E0]], %[[E1]] : f32
// CHECK: %[[A:.*]] = arith.addf %[[M]], %{{.*}} : f32
// CHECK: return %[[A]] : f32
func.func @parallel_contract_lowering_scalar(%arg0: vector<1x1xf32>, %arg1: vector<1x1xf32>, %arg2: f32) -> f32 {
  %0 = vector.contract {
         indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                          affine_map<(d0, d1) -> (d0, d1)>,
                          affine_map<(d0, d1) -> ()>],
         iterator_types = ["reduction", "reduction"], kind = #vector.kind<add>}
       %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x1xf32> into f32
  return %0 : f32
}

// Transform script: matches every func.func in the payload module and applies
// the vector contraction lowering patterns with the "parallelarith" strategy.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %f = transform.structured.match ops{["func.func"]} in %module_op
      : (!transform.any_op) -> !transform.any_op

    transform.apply_patterns to %f {
      transform.apply_patterns.vector.lower_contraction lowering_strategy = "parallelarith"
    } : !transform.any_op
    transform.yield
  }
}