// RUN: mlir-opt %s -split-input-file -verify-diagnostics

// Verify that ops with the broadcastable trait verify their operand and result
// type combinations and emit an error for invalid combinations. Cases without
// an expected-error annotation must pass verification without diagnostics.

// Rank-0 operands and rank-0 result.
func.func @broadcast_scalar_scalar_scalar(%lhs: tensor<i32>, %rhs: tensor<i32>) -> tensor<i32> {
  %result = "test.broadcastable"(%lhs, %rhs) : (tensor<i32>, tensor<i32>) -> tensor<i32>
  return %result : tensor<i32>
}

// -----

// Rank-1 operand combined with a rank-0 operand.
func.func @broadcast_tensor_scalar_tensor(%lhs: tensor<4xi32>, %rhs: tensor<i32>) -> tensor<4xi32> {
  %result = "test.broadcastable"(%lhs, %rhs) : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
  return %result : tensor<4xi32>
}

// -----

// Check only one dimension has size 1
func.func @broadcast_tensor_tensor_tensor(%lhs: tensor<4x3x2xi32>, %rhs: tensor<3x1xi32>) -> tensor<4x3x2xi32> {
  %result = "test.broadcastable"(%lhs, %rhs) : (tensor<4x3x2xi32>, tensor<3x1xi32>) -> tensor<4x3x2xi32>
  return %result : tensor<4x3x2xi32>
}

// -----

// Check multiple dimensions have size 1
func.func @broadcast_tensor_tensor_tensor(%lhs: tensor<8x1x6x1xi32>, %rhs: tensor<7x1x5xi32>) -> tensor<8x7x6x5xi32> {
  %result = "test.broadcastable"(%lhs, %rhs) : (tensor<8x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x5xi32>
  return %result : tensor<8x7x6x5xi32>
}

// -----

// Check leading unknown dimension
func.func @broadcast_tensor_tensor_tensor(%lhs: tensor<?x1x6x1xi32>, %rhs: tensor<7x1x5xi32>) -> tensor<?x7x6x5xi32> {
  %result = "test.broadcastable"(%lhs, %rhs) : (tensor<?x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<?x7x6x5xi32>
  return %result : tensor<?x7x6x5xi32>
}

// -----

// Check unknown dimension in the middle
func.func @broadcast_tensor_tensor_tensor(%lhs: tensor<8x1x?x1xi32>, %rhs: tensor<7x1x5xi32>) -> tensor<8x7x?x5xi32> {
  %result = "test.broadcastable"(%lhs, %rhs) : (tensor<8x1x?x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x?x5xi32>
  return %result : tensor<8x7x?x5xi32>
}

// -----

// Check that a vector result type is rejected when the operands are tensors
func.func @broadcast_scalar_vector_vector(%lhs: tensor<4xf32>, %rhs: tensor<4xf32>) -> vector<4xf32> {
  // expected-error @+1 {{op result #0 must be tensor of any type values, but got 'vector<4xf32>'}}
  %result = "test.broadcastable"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> vector<4xf32>
  return %result : vector<4xf32>
}

// -----

// Check incompatible operand types with known dimension
func.func @broadcast_tensor_tensor_tensor(%lhs: tensor<4x3x2xi32>, %rhs: tensor<3x3xi32>) -> tensor<4x3x2xi32> {
  // expected-error @+1 {{operands don't have broadcast-compatible shapes}}
  %result = "test.broadcastable"(%lhs, %rhs) : (tensor<4x3x2xi32>, tensor<3x3xi32>) -> tensor<4x3x2xi32>
  return %result : tensor<4x3x2xi32>
}

// -----

// Check incompatible result type with known dimension
func.func @broadcast_tensor_tensor_tensor(%lhs: tensor<4x3x2xi32>, %rhs: tensor<3x1xi32>) -> tensor<4x3x3xi32> {
  // expected-error @+1 {{op result type '4x3x3' not broadcast compatible with broadcasted operands's shapes '4x3x2'}}
  %result = "test.broadcastable"(%lhs, %rhs) : (tensor<4x3x2xi32>, tensor<3x1xi32>) -> tensor<4x3x3xi32>
  return %result : tensor<4x3x3xi32>
}

// -----

// Check incompatible result type with known dimension
func.func @broadcast_tensor_tensor_tensor(%lhs: tensor<8x1x6x1xi32>, %rhs: tensor<7x1x5xi32>) -> tensor<8x7x6x1xi32> {
  // expected-error @+1 {{op result type '8x7x6x1' not broadcast compatible with broadcasted operands's shapes '8x7x6x5'}}
  %result = "test.broadcastable"(%lhs, %rhs) : (tensor<8x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x1xi32>
  return %result : tensor<8x7x6x1xi32>
}

// -----

// Statically shaped operands with an unranked result
func.func @broadcast_tensor_tensor_tensor(%lhs: tensor<2xi32>, %rhs: tensor<2xi32>) -> tensor<*xi32> {
  %result = "test.broadcastable"(%lhs, %rhs) : (tensor<2xi32>, tensor<2xi32>) -> tensor<*xi32>
  return %result : tensor<*xi32>
}

// -----

// Statically shaped operand combined with a rank-1 operand of unknown size
func.func @broadcast_tensor_tensor_tensor(%lhs: tensor<4x3x2xi32>, %rhs: tensor<?xi32>) -> tensor<4x3x2xi32> {
  %result = "test.broadcastable"(%lhs, %rhs) : (tensor<4x3x2xi32>, tensor<?xi32>) -> tensor<4x3x2xi32>
  return %result : tensor<4x3x2xi32>
}

// -----

// It is alright to have an implicit dynamic-to-static cast in a dimension size
// as long as the runtime result size is consistent with the result tensor's
// static dimension.
func.func @broadcast_tensor_tensor_tensor(%lhs: tensor<?xi32>, %rhs: tensor<?xi32>) -> tensor<2xi32> {
  %result = "test.broadcastable"(%lhs, %rhs) : (tensor<?xi32>, tensor<?xi32>) -> tensor<2xi32>
  return %result : tensor<2xi32>
}

// -----

// Ranked operand combined with an unranked operand, partially dynamic result
func.func @broadcast_tensor_tensor_tensor(%lhs: tensor<?x6x1xi32>, %rhs: tensor<*xi32>) -> tensor<?x6x?xi32> {
  %result = "test.broadcastable"(%lhs, %rhs) : (tensor<?x6x1xi32>, tensor<*xi32>) -> tensor<?x6x?xi32>
  return %result : tensor<?x6x?xi32>
}

// -----

// Unranked operands but ranked result
func.func @broadcast_tensor_tensor_tensor(%lhs: tensor<*xi32>, %rhs: tensor<*xi32>) -> tensor<2xi32> {
  %result = "test.broadcastable"(%lhs, %rhs) : (tensor<*xi32>, tensor<*xi32>) -> tensor<2xi32>
  return %result : tensor<2xi32>
}

// -----

// Unranked operand and compatible ranked result. The ranked operand is passed
// twice, so the op is exercised with three operands here.
func.func @broadcast_tensor_tensor_tensor(%lhs: tensor<3x2xi32>, %rhs: tensor<*xi32>) -> tensor<4x3x2xi32> {
  %result = "test.broadcastable"(%lhs, %lhs, %rhs) : (tensor<3x2xi32>, tensor<3x2xi32>, tensor<*xi32>) -> tensor<4x3x2xi32>
  return %result : tensor<4x3x2xi32>
}

// -----

// Unranked operand and incompatible ranked result
func.func @broadcast_tensor_tensor_tensor(%lhs: tensor<3x2xi32>, %rhs: tensor<*xi32>) -> tensor<2xi32> {
  // expected-error @+1 {{op result type '2' not broadcast compatible with broadcasted operands's shapes '3x2'}}
  %result = "test.broadcastable"(%lhs, %rhs) : (tensor<3x2xi32>, tensor<*xi32>) -> tensor<2xi32>
  return %result : tensor<2xi32>
}

// -----

// Correct use of broadcast semantics for input dimensions
func.func @broadcast_tensor_tensor_tensor(%lhs: tensor<?x1x6x1xi32>, %rhs: tensor<7x1x5xi32>) -> tensor<?x7x6x5xi32> {
  %result = "test.broadcastable"(%lhs, %rhs) : (tensor<?x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<?x7x6x5xi32>
  return %result : tensor<?x7x6x5xi32>
}

// -----

// Incorrect attempt to use broadcast semantics for result
func.func @broadcast_tensor_tensor_tensor(%lhs: tensor<1xi32>, %rhs: tensor<1xi32>) -> tensor<5xi32> {
  // expected-error @+1 {{op result type '5' not broadcast compatible with broadcasted operands's shapes '1'}}
  %result = "test.broadcastable"(%lhs, %rhs) : (tensor<1xi32>, tensor<1xi32>) -> tensor<5xi32>
  return %result : tensor<5xi32>
}

// -----

// Result element type (i1) differs from the operands' element type (i32)
func.func @broadcastDifferentResultType(%lhs: tensor<4xi32>, %rhs: tensor<4xi32>) -> tensor<4xi1> {
  %result = "test.broadcastable"(%lhs, %rhs) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
  return %result : tensor<4xi1>
}