// RUN: not mlir-opt -split-input-file -verify-diagnostics %s 2>&1 | FileCheck %s

// Negative tests for elementwise Linalg named ops. Every case below is
// invalid IR; mlir-opt reports it on stderr, which the RUN line pipes into
// the matcher together with stdout.

// linalg.add: the two inputs disagree on element type (f32 vs. f16).
func.func @add_type_cast(%lhs: memref<4x8x16xf32>, %rhs: memref<4x8x16xf16>, %init: memref<4x8x16xf32>) {
  // CHECK: op requires the same type for all operands and results
  linalg.add ins(%lhs, %rhs : memref<4x8x16xf32>, memref<4x8x16xf16>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.add: a rank-2 input paired with rank-3 operands; ranks must agree.
func.func @add_broadcast(%lhs: memref<8x16xf32>, %rhs: memref<4x8x16xf32>, %init: memref<4x8x16xf32>) {
  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
  linalg.add ins(%lhs, %rhs : memref<8x16xf32>, memref<4x8x16xf32>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.sub: the two inputs disagree on element type (f32 vs. f16).
func.func @sub_type_cast(%lhs: memref<4x8x16xf32>, %rhs: memref<4x8x16xf16>, %init: memref<4x8x16xf32>) {
  // CHECK: op requires the same type for all operands and results
  linalg.sub ins(%lhs, %rhs : memref<4x8x16xf32>, memref<4x8x16xf16>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.sub: a rank-2 input paired with rank-3 operands; ranks must agree.
func.func @sub_broadcast(%lhs: memref<8x16xf32>, %rhs: memref<4x8x16xf32>, %init: memref<4x8x16xf32>) {
  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
  linalg.sub ins(%lhs, %rhs : memref<8x16xf32>, memref<4x8x16xf32>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.mul: the two inputs disagree on element type (f32 vs. f16).
func.func @mul_type_cast(%lhs: memref<4x8x16xf32>, %rhs: memref<4x8x16xf16>, %init: memref<4x8x16xf32>) {
  // CHECK: op requires the same type for all operands and results
  linalg.mul ins(%lhs, %rhs : memref<4x8x16xf32>, memref<4x8x16xf16>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.mul: a rank-2 input paired with rank-3 operands; ranks must agree.
func.func @mul_broadcast(%lhs: memref<8x16xf32>, %rhs: memref<4x8x16xf32>, %init: memref<4x8x16xf32>) {
  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
  linalg.mul ins(%lhs, %rhs : memref<8x16xf32>, memref<4x8x16xf32>) outs(%init : memref<4x8x16xf32>)
  return
}
// -----

// linalg.div: the two inputs disagree on element type (f32 vs. f16).
func.func @div_type_cast(%lhs: memref<4x8x16xf32>, %rhs: memref<4x8x16xf16>, %init: memref<4x8x16xf32>) {
  // CHECK: op requires the same type for all operands and results
  linalg.div ins(%lhs, %rhs : memref<4x8x16xf32>, memref<4x8x16xf16>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.div: a rank-2 input paired with rank-3 operands; ranks must agree.
func.func @div_broadcast(%lhs: memref<8x16xf32>, %rhs: memref<4x8x16xf32>, %init: memref<4x8x16xf32>) {
  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
  linalg.div ins(%lhs, %rhs : memref<8x16xf32>, memref<4x8x16xf32>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.div_unsigned: the two inputs disagree on element type (i32 vs. i16).
func.func @divu_type_cast(%lhs: memref<4x8x16xi32>, %rhs: memref<4x8x16xi16>, %init: memref<4x8x16xi32>) {
  // CHECK: op requires the same type for all operands and results
  linalg.div_unsigned ins(%lhs, %rhs : memref<4x8x16xi32>, memref<4x8x16xi16>) outs(%init : memref<4x8x16xi32>)
  return
}

// -----

// linalg.div_unsigned: a rank-2 input paired with rank-3 operands.
func.func @divu_broadcast(%lhs: memref<8x16xi32>, %rhs: memref<4x8x16xi32>, %init: memref<4x8x16xi32>) {
  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
  linalg.div_unsigned ins(%lhs, %rhs : memref<8x16xi32>, memref<4x8x16xi32>) outs(%init : memref<4x8x16xi32>)
  return
}

// -----

// linalg.exp: f16 input written into an f32 init (mismatched element types).
func.func @exp_type_cast(%input: memref<4x8x16xf16>, %init: memref<4x8x16xf32>) {
  // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
  linalg.exp ins(%input : memref<4x8x16xf16>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.exp: rank-2 input with a rank-3 init; ranks must agree.
func.func @exp_broadcast(%input: memref<8x16xf32>, %init: memref<4x8x16xf32>) {
  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
  linalg.exp ins(%input : memref<8x16xf32>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.log: f16 input written into an f32 init (mismatched element types).
func.func @log_type_cast(%input: memref<4x8x16xf16>, %init: memref<4x8x16xf32>) {
  // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
  linalg.log ins(%input : memref<4x8x16xf16>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.log: rank-2 input with a rank-3 init; ranks must agree.
func.func @log_broadcast(%input: memref<8x16xf32>, %init: memref<4x8x16xf32>) {
  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
  linalg.log ins(%input : memref<8x16xf32>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.abs: f16 input written into an f32 init (mismatched element types).
func.func @abs_type_cast(%input: memref<4x8x16xf16>, %init: memref<4x8x16xf32>) {
  // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
  linalg.abs ins(%input : memref<4x8x16xf16>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.abs: rank-2 input with a rank-3 init; ranks must agree.
func.func @abs_broadcast(%input: memref<8x16xf32>, %init: memref<4x8x16xf32>) {
  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
  linalg.abs ins(%input : memref<8x16xf32>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.ceil: f16 input written into an f32 init (mismatched element types).
func.func @ceil_type_cast(%input: memref<4x8x16xf16>, %init: memref<4x8x16xf32>) {
  // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
  linalg.ceil ins(%input : memref<4x8x16xf16>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.ceil: rank-2 input with a rank-3 init; ranks must agree.
func.func @ceil_broadcast(%input: memref<8x16xf32>, %init: memref<4x8x16xf32>) {
  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
  linalg.ceil ins(%input : memref<8x16xf32>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.floor: f16 input written into an f32 init (mismatched element types).
func.func @floor_type_cast(%input: memref<4x8x16xf16>, %init: memref<4x8x16xf32>) {
  // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
  linalg.floor ins(%input : memref<4x8x16xf16>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.floor: rank-2 input with a rank-3 init; ranks must agree.
func.func @floor_broadcast(%input: memref<8x16xf32>, %init: memref<4x8x16xf32>) {
  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
  linalg.floor ins(%input : memref<8x16xf32>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.negf: f16 input written into an f32 init (mismatched element types).
func.func @negf_type_cast(%input: memref<4x8x16xf16>, %init: memref<4x8x16xf32>) {
  // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
  linalg.negf ins(%input : memref<4x8x16xf16>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.negf: rank-2 input with a rank-3 init; ranks must agree.
func.func @negf_broadcast(%input: memref<8x16xf32>, %init: memref<4x8x16xf32>) {
  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
  linalg.negf ins(%input : memref<8x16xf32>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.reciprocal: f16 input written into an f32 init (mismatched element types).
func.func @reciprocal_type_cast(%input: memref<4x8x16xf16>, %init: memref<4x8x16xf32>) {
  // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
  linalg.reciprocal ins(%input : memref<4x8x16xf16>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.reciprocal: rank-2 input with a rank-3 init; ranks must agree.
func.func @reciprocal_broadcast(%input: memref<8x16xf32>, %init: memref<4x8x16xf32>) {
  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
  linalg.reciprocal ins(%input : memref<8x16xf32>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.round: f16 input written into an f32 init (mismatched element types).
func.func @round_type_cast(%input: memref<4x8x16xf16>, %init: memref<4x8x16xf32>) {
  // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
  linalg.round ins(%input : memref<4x8x16xf16>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.round: rank-2 input with a rank-3 init; ranks must agree.
func.func @round_broadcast(%input: memref<8x16xf32>, %init: memref<4x8x16xf32>) {
  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
  linalg.round ins(%input : memref<8x16xf32>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.sqrt: f16 input written into an f32 init (mismatched element types).
func.func @sqrt_type_cast(%input: memref<4x8x16xf16>, %init: memref<4x8x16xf32>) {
  // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
  linalg.sqrt ins(%input : memref<4x8x16xf16>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.sqrt: rank-2 input with a rank-3 init; ranks must agree.
func.func @sqrt_broadcast(%input: memref<8x16xf32>, %init: memref<4x8x16xf32>) {
  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
  linalg.sqrt ins(%input : memref<8x16xf32>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.rsqrt: f16 input written into an f32 init (mismatched element types).
func.func @rsqrt_type_cast(%input: memref<4x8x16xf16>, %init: memref<4x8x16xf32>) {
  // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
  linalg.rsqrt ins(%input : memref<4x8x16xf16>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.rsqrt: rank-2 input with a rank-3 init; ranks must agree.
func.func @rsqrt_broadcast(%input: memref<8x16xf32>, %init: memref<4x8x16xf32>) {
  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
  linalg.rsqrt ins(%input : memref<8x16xf32>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.square: f16 input written into an f32 init (mismatched element types).
func.func @square_type_cast(%input: memref<4x8x16xf16>, %init: memref<4x8x16xf32>) {
  // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
  linalg.square ins(%input : memref<4x8x16xf16>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.square: rank-2 input with a rank-3 init; ranks must agree.
func.func @square_broadcast(%input: memref<8x16xf32>, %init: memref<4x8x16xf32>) {
  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
  linalg.square ins(%input : memref<8x16xf32>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.tanh: f16 input written into an f32 init (mismatched element types).
func.func @tanh_type_cast(%input: memref<4x8x16xf16>, %init: memref<4x8x16xf32>) {
  // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
  linalg.tanh ins(%input : memref<4x8x16xf16>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.tanh: rank-2 input with a rank-3 init; ranks must agree.
func.func @tanh_broadcast(%input: memref<8x16xf32>, %init: memref<4x8x16xf32>) {
  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
  linalg.tanh ins(%input : memref<8x16xf32>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.erf: f16 input written into an f32 init (mismatched element types).
func.func @erf_type_cast(%input: memref<4x8x16xf16>, %init: memref<4x8x16xf32>) {
  // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
  linalg.erf ins(%input : memref<4x8x16xf16>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.erf: rank-2 input with a rank-3 init; ranks must agree.
func.func @erf_broadcast(%input: memref<8x16xf32>, %init: memref<4x8x16xf32>) {
  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
  linalg.erf ins(%input : memref<8x16xf32>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.max: the two inputs disagree on element type (f32 vs. f16).
func.func @max_type_cast(%lhs: memref<4x8x16xf32>, %rhs: memref<4x8x16xf16>, %init: memref<4x8x16xf32>) {
  // CHECK: op requires the same type for all operands and results
  linalg.max ins(%lhs, %rhs : memref<4x8x16xf32>, memref<4x8x16xf16>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.max: a rank-2 input paired with rank-3 operands; ranks must agree.
func.func @max_broadcast(%lhs: memref<8x16xf32>, %rhs: memref<4x8x16xf32>, %init: memref<4x8x16xf32>) {
  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
  linalg.max ins(%lhs, %rhs : memref<8x16xf32>, memref<4x8x16xf32>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.min: the two inputs disagree on element type (f32 vs. f16).
func.func @min_type_cast(%lhs: memref<4x8x16xf32>, %rhs: memref<4x8x16xf16>, %init: memref<4x8x16xf32>) {
  // CHECK: op requires the same type for all operands and results
  linalg.min ins(%lhs, %rhs : memref<4x8x16xf32>, memref<4x8x16xf16>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----
// linalg.min: a rank-2 input paired with rank-3 operands; ranks must agree.
func.func @min_broadcast(%lhs: memref<8x16xf32>, %rhs: memref<4x8x16xf32>, %init: memref<4x8x16xf32>) {
  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
  linalg.min ins(%lhs, %rhs : memref<8x16xf32>, memref<4x8x16xf32>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.powf: base and exponent disagree on element type (f32 vs. f16).
func.func @powf_type_cast(%lhs: memref<4x8x16xf32>, %rhs: memref<4x8x16xf16>, %init: memref<4x8x16xf32>) {
  // CHECK: op requires the same type for all operands and results
  linalg.powf ins(%lhs, %rhs : memref<4x8x16xf32>, memref<4x8x16xf16>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.powf: a rank-2 input paired with rank-3 operands; ranks must agree.
func.func @powf_broadcast(%lhs: memref<8x16xf32>, %rhs: memref<4x8x16xf32>, %init: memref<4x8x16xf32>) {
  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
  linalg.powf ins(%lhs, %rhs : memref<8x16xf32>, memref<4x8x16xf32>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.select: the true value is f16 while the false value and init are f32.
func.func @select_type_cast(%cond: memref<4x8x16xi1>, %true_val: memref<4x8x16xf16>, %false_val: memref<4x8x16xf32>, %init: memref<4x8x16xf32>) {
  // CHECK: op failed to verify that all of {true_value, false_value, result} have same type
  linalg.select ins(%cond, %true_val, %false_val : memref<4x8x16xi1>, memref<4x8x16xf16>, memref<4x8x16xf32>) outs(%init : memref<4x8x16xf32>)
  return
}

// -----

// linalg.select: the condition operand is f32 rather than bool-like.
func.func @select_wrong_condition_type(%cond: memref<4x8x16xf32>, %true_val: memref<4x8x16xf32>, %false_val: memref<4x8x16xf32>, %init: memref<4x8x16xf32>) {
  // CHECK: op operand #0 must be bool-like, but got 'f32'
  linalg.select ins(%cond, %true_val, %false_val : memref<4x8x16xf32>, memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%init : memref<4x8x16xf32>)
  return
}