// RUN: mlir-opt %s -split-input-file -linalg-generalize-named-ops | FileCheck %s

// Verifies that different argument types are legal.
func.func @generalize_matmul_tensor_f16f64f32(%A : tensor<16x8xf16>, %B: tensor<8x32xf64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf64>)
                     outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
  return %0: tensor<16x32xf32>
}

// CHECK-LABEL: @generalize_matmul_tensor_f16f64f32
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32)
// Verify floating point extension and truncation.
// CHECK-NEXT: %[[A_CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
// CHECK-NEXT: %[[B_CAST:.+]] = arith.truncf %[[B_ARG]] : f64 to f32
// CHECK-NEXT: %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
// CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
// CHECK-NEXT: linalg.yield %[[ADD]] : f32
// CHECK-NEXT: -> tensor<16x32xf32>

// -----

// Verifies that different argument types are legal.
func.func @generalize_matmul_tensor_i16i64i32(%A : tensor<16x8xi16>, %B: tensor<8x32xi64>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi16>, tensor<8x32xi64>)
                     outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
  return %0: tensor<16x32xi32>
}

// CHECK-LABEL: @generalize_matmul_tensor_i16i64i32
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: i16, %[[B_ARG:.+]]: i64, %[[C_ARG:.+]]: i32)
// Verify signed integer extension and truncation.
// CHECK-NEXT: %[[A_CAST:.+]] = arith.extsi %[[A_ARG]] : i16 to i32
// CHECK-NEXT: %[[B_CAST:.+]] = arith.trunci %[[B_ARG]] : i64 to i32
// CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[A_CAST]], %[[B_CAST]] : i32
// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[C_ARG]], %[[MUL]] : i32
// CHECK-NEXT: linalg.yield %[[ADD]] : i32
// CHECK-NEXT: -> tensor<16x32xi32>

// -----

// Verifies that cast attributes control the cast operations used.
func.func @generalize_matmul_tensor_i16i64i32_unsigned(%A : tensor<16x8xi16>, %B: tensor<8x32xi64>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
  %0 = linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
         ins(%A, %B: tensor<16x8xi16>, tensor<8x32xi64>)
         outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
  return %0: tensor<16x32xi32>
}

// CHECK-LABEL: @generalize_matmul_tensor_i16i64i32_unsigned
// CHECK: = arith.extui

// -----

func.func @generalize_matmul_tensor_i16i64f32(%A : tensor<16x8xi16>, %B: tensor<8x32xi64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi16>, tensor<8x32xi64>)
                     outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
  return %0: tensor<16x32xf32>
}

// CHECK-LABEL: @generalize_matmul_tensor_i16i64f32
// Verify signed integer to floating point cast.
// CHECK: = arith.sitofp
// CHECK: = arith.sitofp
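
// For reference: the matmul tests above only spot-check the cast operations in
// the generalized region. The fully generalized form of the first test is
// roughly the following (illustrative sketch, not a CHECK pattern; the map
// names and SSA names are placeholders):
//
//   #map_a = affine_map<(d0, d1, d2) -> (d0, d2)>
//   #map_b = affine_map<(d0, d1, d2) -> (d2, d1)>
//   #map_c = affine_map<(d0, d1, d2) -> (d0, d1)>
//   %0 = linalg.generic
//          {indexing_maps = [#map_a, #map_b, #map_c],
//           iterator_types = ["parallel", "parallel", "reduction"]}
//          ins(%A, %B : tensor<16x8xf16>, tensor<8x32xf64>)
//          outs(%C : tensor<16x32xf32>) {
//   ^bb0(%a: f16, %b: f64, %c: f32):
//     %1 = arith.extf %a : f16 to f32
//     %2 = arith.truncf %b : f64 to f32
//     %3 = arith.mulf %1, %2 : f32
//     %4 = arith.addf %c, %3 : f32
//     linalg.yield %4 : f32
//   } -> tensor<16x32xf32>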

// -----

func.func @generalize_matmul_tensor_f16f64i32(%A : tensor<16x8xf16>, %B: tensor<8x32xf64>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf64>)
                     outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
  return %0: tensor<16x32xi32>
}

// CHECK-LABEL: @generalize_matmul_tensor_f16f64i32
// Verify floating point to signed integer cast.
// CHECK: = arith.fptosi
// CHECK: = arith.fptosi

// -----

func.func @generalize_matmul_unsigned_tensor_i16i64i32(%A : tensor<16x8xi16>, %B: tensor<8x32xi64>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
  %0 = linalg.matmul { cast = #linalg.type_fn<cast_unsigned> }
         ins(%A, %B: tensor<16x8xi16>, tensor<8x32xi64>)
         outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
  return %0: tensor<16x32xi32>
}

// CHECK-LABEL: @generalize_matmul_unsigned_tensor_i16i64i32
// Verify unsigned integer extension and truncation.
// CHECK: = arith.extui
// CHECK: = arith.trunci

// -----

func.func @generalize_matmul_unsigned_tensor_i16i64f32(%A : tensor<16x8xi16>, %B: tensor<8x32xi64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
  %0 = linalg.matmul { cast = #linalg.type_fn<cast_unsigned> }
         ins(%A, %B: tensor<16x8xi16>, tensor<8x32xi64>)
         outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
  return %0: tensor<16x32xf32>
}

// CHECK-LABEL: @generalize_matmul_unsigned_tensor_i16i64f32
// Verify unsigned integer to floating point cast.
// CHECK: = arith.uitofp
// CHECK: = arith.uitofp

// -----

func.func @generalize_matmul_unsigned_tensor_f16f64i32(%A : tensor<16x8xf16>, %B: tensor<8x32xf64>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
  %0 = linalg.matmul { cast = #linalg.type_fn<cast_unsigned> }
         ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf64>)
         outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
  return %0: tensor<16x32xi32>
}

// CHECK-LABEL: @generalize_matmul_unsigned_tensor_f16f64i32
// Verify floating point to unsigned integer cast.
// CHECK: = arith.fptoui
// CHECK: = arith.fptoui

// -----

func.func @generalize_matmul_as_contraction_tensor_f16f64f32(
    %A: tensor<16x8xf16>,
    %B: tensor<8x32xf64>,
    %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
  %0 = linalg.contract
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                       affine_map<(d0, d1, d2) -> (d2, d1)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>]
      ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf64>)
      outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
  return %0: tensor<16x32xf32>
}

// CHECK-LABEL: @generalize_matmul_as_contraction_tensor_f16f64f32
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32)
// Verify floating point extension and truncation.
// CHECK-NEXT: %[[A_CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
// CHECK-NEXT: %[[B_CAST:.+]] = arith.truncf %[[B_ARG]] : f64 to f32
// CHECK-NEXT: %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
// CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
// CHECK-NEXT: linalg.yield %[[ADD]] : f32
// CHECK-NEXT: -> tensor<16x32xf32>
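
// Note: linalg.contract carries its indexing maps explicitly. The maps above
// ((d0, d2), (d2, d1), (d0, d1)) are exactly the plain-matmul access pattern,
// so generalization produces the same linalg.generic as the
// @generalize_matmul_tensor_f16f64f32 test, including the cast handling.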

// -----

func.func @generalize_matmul_as_contract_with_ext_and_trunc(
    %A: tensor<24x12xf16>,
    %B: tensor<12x25xf16>,
    %C: tensor<24x25xf32>) -> tensor<24x25xf16> {
  %0 = linalg.contract
      indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                       affine_map<(m, n, k) -> (k, n)>,
                       affine_map<(m, n, k) -> (m, n)>]
      ins(%A, %B : tensor<24x12xf16>, tensor<12x25xf16>)
      outs(%C : tensor<24x25xf32>) -> tensor<24x25xf32>
  %1 = arith.truncf %0 : tensor<24x25xf32> to tensor<24x25xf16>
  func.return %1 : tensor<24x25xf16>
}

// CHECK-LABEL: @generalize_matmul_as_contract_with_ext_and_trunc
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32)
// Verify floating point extension and truncation.
// CHECK-NEXT: %[[A_CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
// CHECK-NEXT: %[[B_CAST:.+]] = arith.extf %[[B_ARG]] : f16 to f32
// CHECK-NEXT: %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
// CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
// CHECK-NEXT: linalg.yield %[[ADD]] : f32
// CHECK-NEXT: -> tensor<24x25xf32>
// CHECK-NEXT: %[[RES:.+]] = arith.truncf {{.*}} : tensor<24x25xf32> to tensor<24x25xf16>

// -----

func.func @generalize_pooling_nhwc_max_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
  %0 = linalg.pooling_nhwc_max {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
         ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
  return %0: tensor<1x2x4x1xf32>
}

// CHECK-LABEL: @generalize_pooling_nhwc_max_f32
// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
// CHECK-NEXT: %[[MAX:.+]] = arith.maximumf %[[OUT_ARG]], %[[IN_ARG]] : f32
// CHECK-NEXT: linalg.yield %[[MAX]] : f32
// CHECK-NEXT: -> tensor<1x2x4x1xf32>

// -----

func.func @generalize_pooling_nwc_max_f32(%input : tensor<1x16x1xf32>, %shape: tensor<2xf32>, %output: tensor<1x4x1xf32>) -> tensor<1x4x1xf32> {
  %0 = linalg.pooling_nwc_max {dilations = dense<[2]> : tensor<1xi64>, strides = dense<[4]> : tensor<1xi64>}
         ins(%input, %shape : tensor<1x16x1xf32>, tensor<2xf32>) outs(%output : tensor<1x4x1xf32>) -> tensor<1x4x1xf32>
  return %0: tensor<1x4x1xf32>
}

// CHECK-LABEL: @generalize_pooling_nwc_max_f32
// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
// CHECK-NEXT: %[[MAX:.+]] = arith.maximumf %[[OUT_ARG]], %[[IN_ARG]] : f32
// CHECK-NEXT: linalg.yield %[[MAX]] : f32
// CHECK-NEXT: -> tensor<1x4x1xf32>

// -----

func.func @generalize_pooling_nhwc_max_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
  %0 = linalg.pooling_nhwc_max {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
         ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
  return %0: tensor<1x2x4x1xi32>
}

// CHECK-LABEL: @generalize_pooling_nhwc_max_i32
// Verify signed integer maximum.
// CHECK: = arith.maxsi
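
// For reference: pooling ops generalize to a generic whose input map encodes
// the strides and dilations. For the nhwc max-pooling tests above (strides
// [2, 4], dilations [1, 2]), the generalized form is roughly the following
// (illustrative sketch, not a CHECK pattern; the dimension ordering is an
// assumption):
//
//   // d0 = N, d1 = OH, d2 = OW, d3 = C, d4/d5 = window dims.
//   #map_in  = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 4 + d5 * 2, d3)>
//   #map_win = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
//   #map_out = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
//   // iterator_types = ["parallel", "parallel", "parallel", "parallel",
//   //                   "reduction", "reduction"]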

// -----

func.func @generalize_pooling_nwc_max_i32(%input : tensor<1x16x1xi32>, %shape: tensor<2xi32>, %output: tensor<1x4x1xi32>) -> tensor<1x4x1xi32> {
  %0 = linalg.pooling_nwc_max {dilations = dense<[2]> : tensor<1xi64>, strides = dense<[4]> : tensor<1xi64>}
         ins(%input, %shape : tensor<1x16x1xi32>, tensor<2xi32>) outs(%output : tensor<1x4x1xi32>) -> tensor<1x4x1xi32>
  return %0: tensor<1x4x1xi32>
}

// CHECK-LABEL: @generalize_pooling_nwc_max_i32
// Verify signed integer maximum.
// CHECK: = arith.maxsi

// -----

func.func @generalize_pooling_nhwc_max_unsigned_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
  %0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
         ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
  return %0: tensor<1x2x4x1xi32>
}

// CHECK-LABEL: @generalize_pooling_nhwc_max_unsigned_i32
// Verify unsigned integer maximum.
// CHECK: = arith.maxui

// -----

func.func @generalize_pooling_nwc_max_unsigned_i32(%input : tensor<1x16x1xi32>, %shape: tensor<2xi32>, %output: tensor<1x4x1xi32>) -> tensor<1x4x1xi32> {
  %0 = linalg.pooling_nwc_max_unsigned {dilations = dense<[2]> : tensor<1xi64>, strides = dense<[4]> : tensor<1xi64>}
         ins(%input, %shape : tensor<1x16x1xi32>, tensor<2xi32>) outs(%output : tensor<1x4x1xi32>) -> tensor<1x4x1xi32>
  return %0: tensor<1x4x1xi32>
}

// CHECK-LABEL: @generalize_pooling_nwc_max_unsigned_i32
// Verify unsigned integer maximum.
// CHECK: = arith.maxui

// -----

func.func @generalize_pooling_nhwc_min_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
  %0 = linalg.pooling_nhwc_min {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
         ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
  return %0: tensor<1x2x4x1xf32>
}

// CHECK-LABEL: @generalize_pooling_nhwc_min_f32
// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
// CHECK-NEXT: %[[MIN:.+]] = arith.minimumf %[[OUT_ARG]], %[[IN_ARG]] : f32
// CHECK-NEXT: linalg.yield %[[MIN]] : f32
// CHECK-NEXT: -> tensor<1x2x4x1xf32>

// -----

func.func @generalize_pooling_nwc_min_f32(%input : tensor<1x16x1xf32>, %shape: tensor<2xf32>, %output: tensor<1x4x1xf32>) -> tensor<1x4x1xf32> {
  %0 = linalg.pooling_nwc_min {dilations = dense<[2]> : tensor<1xi64>, strides = dense<[4]> : tensor<1xi64>}
         ins(%input, %shape : tensor<1x16x1xf32>, tensor<2xf32>) outs(%output : tensor<1x4x1xf32>) -> tensor<1x4x1xf32>
  return %0: tensor<1x4x1xf32>
}

// CHECK-LABEL: @generalize_pooling_nwc_min_f32
// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
// CHECK-NEXT: %[[MIN:.+]] = arith.minimumf %[[OUT_ARG]], %[[IN_ARG]] : f32
// CHECK-NEXT: linalg.yield %[[MIN]] : f32
// CHECK-NEXT: -> tensor<1x4x1xf32>
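
// Note: for floating-point min/max pooling the generalized region uses
// arith.minimumf/arith.maximumf, which follow the NaN-propagating IEEE-754
// minimum/maximum semantics (unlike arith.minnumf/arith.maxnumf, which
// prefer the non-NaN operand).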

// -----

func.func @generalize_pooling_nhwc_min_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
  %0 = linalg.pooling_nhwc_min {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
         ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
  return %0: tensor<1x2x4x1xi32>
}

// CHECK-LABEL: @generalize_pooling_nhwc_min_i32
// Verify signed integer minimum.
// CHECK: = arith.minsi

// -----

func.func @generalize_pooling_nwc_min_i32(%input : tensor<1x16x1xi32>, %shape: tensor<2xi32>, %output: tensor<1x4x1xi32>) -> tensor<1x4x1xi32> {
  %0 = linalg.pooling_nwc_min {dilations = dense<[2]> : tensor<1xi64>, strides = dense<[4]> : tensor<1xi64>}
         ins(%input, %shape : tensor<1x16x1xi32>, tensor<2xi32>) outs(%output : tensor<1x4x1xi32>) -> tensor<1x4x1xi32>
  return %0: tensor<1x4x1xi32>
}

// CHECK-LABEL: @generalize_pooling_nwc_min_i32
// Verify signed integer minimum.
// CHECK: = arith.minsi

// -----

func.func @generalize_pooling_nhwc_min_unsigned_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
  %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
         ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
  return %0: tensor<1x2x4x1xi32>
}

// CHECK-LABEL: @generalize_pooling_nhwc_min_unsigned_i32
// Verify unsigned integer minimum.
// CHECK: = arith.minui

// -----

func.func @generalize_pooling_nwc_min_unsigned_i32(%input : tensor<1x16x1xi32>, %shape: tensor<2xi32>, %output: tensor<1x4x1xi32>) -> tensor<1x4x1xi32> {
  %0 = linalg.pooling_nwc_min_unsigned {dilations = dense<[2]> : tensor<1xi64>, strides = dense<[4]> : tensor<1xi64>}
         ins(%input, %shape : tensor<1x16x1xi32>, tensor<2xi32>) outs(%output : tensor<1x4x1xi32>) -> tensor<1x4x1xi32>
  return %0: tensor<1x4x1xi32>
}

// CHECK-LABEL: @generalize_pooling_nwc_min_unsigned_i32
// Verify unsigned integer minimum.
// CHECK: = arith.minui

// -----

func.func @generalize_pooling_nhwc_sum_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
  %0 = linalg.pooling_nhwc_sum {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
         ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
  return %0: tensor<1x2x4x1xf32>
}

// CHECK-LABEL: @generalize_pooling_nhwc_sum_f32
// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
// CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[OUT_ARG]], %[[IN_ARG]] : f32
// CHECK-NEXT: linalg.yield %[[ADD]] : f32
// CHECK-NEXT: -> tensor<1x2x4x1xf32>

// -----

func.func @generalize_pooling_nwc_sum_f32(%input : tensor<1x16x1xf32>, %shape: tensor<2xf32>, %output: tensor<1x4x1xf32>) -> tensor<1x4x1xf32> {
  %0 = linalg.pooling_nwc_sum {dilations = dense<[2]> : tensor<1xi64>, strides = dense<[4]> : tensor<1xi64>}
         ins(%input, %shape : tensor<1x16x1xf32>, tensor<2xf32>) outs(%output : tensor<1x4x1xf32>) -> tensor<1x4x1xf32>
  return %0: tensor<1x4x1xf32>
}

// CHECK-LABEL: @generalize_pooling_nwc_sum_f32
// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
// CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[OUT_ARG]], %[[IN_ARG]] : f32
// CHECK-NEXT: linalg.yield %[[ADD]] : f32
// CHECK-NEXT: -> tensor<1x4x1xf32>
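
// Note: in every pooling test above, the window operand (%shape) contributes
// only its shape to the iteration domain; its values never appear in the
// generalized region (%[[SHAPE_ARG]] is unused in all of the regions checked
// above).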

// -----

func.func @generalize_pooling_nhwc_sum_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
  %0 = linalg.pooling_nhwc_sum {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
         ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
  return %0: tensor<1x2x4x1xi32>
}

// CHECK-LABEL: @generalize_pooling_nhwc_sum_i32
// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[OUT_ARG]], %[[IN_ARG]] : i32
// CHECK-NEXT: linalg.yield %[[ADD]] : i32
// CHECK-NEXT: -> tensor<1x2x4x1xi32>

// -----

func.func @generalize_pooling_nwc_sum_i32(%input : tensor<1x16x1xi32>, %shape: tensor<2xi32>, %output: tensor<1x4x1xi32>) -> tensor<1x4x1xi32> {
  %0 = linalg.pooling_nwc_sum {dilations = dense<[2]> : tensor<1xi64>, strides = dense<[4]> : tensor<1xi64>}
         ins(%input, %shape : tensor<1x16x1xi32>, tensor<2xi32>) outs(%output : tensor<1x4x1xi32>) -> tensor<1x4x1xi32>
  return %0: tensor<1x4x1xi32>
}

// CHECK-LABEL: @generalize_pooling_nwc_sum_i32
// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[OUT_ARG]], %[[IN_ARG]] : i32
// CHECK-NEXT: linalg.yield %[[ADD]] : i32
// CHECK-NEXT: -> tensor<1x4x1xi32>

// -----

func.func @generalize_fill_0d(%value: f64, %O: tensor<f32>) -> tensor<f32> {
  %0 = linalg.fill ins(%value: f64) outs(%O : tensor<f32>) -> tensor<f32>
  return %0: tensor<f32>
}

// CHECK-DAG: #[[$MAP0:.+]] = affine_map<() -> ()>

// CHECK-LABEL: @generalize_fill_0d
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
// CHECK-SAME: iterator_types = []

// -----

func.func @generalize_fill_2d(%value: f64, %O: memref<16x32xf32>) {
  linalg.fill ins(%value: f64) outs(%O : memref<16x32xf32>)
  return
}

// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1) -> ()>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>

// CHECK-LABEL: @generalize_fill_2d
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel"]

// -----

func.func @generalize_index(%min: f64, %max: f64, %seed: i32, %O: tensor<16x32xf32>) -> tensor<16x32xf32> {
  %0 = linalg.fill_rng_2d ins(%min, %max, %seed: f64, f64, i32) outs(%O : tensor<16x32xf32>) -> tensor<16x32xf32>
  return %0: tensor<16x32xf32>
}

// CHECK-LABEL: @generalize_index
// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index
// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
// CHECK-DAG: %[[IDX0_CAST:.+]] = arith.index_cast %[[IDX0]] : index to i32
// CHECK-DAG: %[[IDX1_CAST:.+]] = arith.index_cast %[[IDX1]] : index to i32

// -----

func.func @generalize_const(%min: f64, %max: f64, %seed: i32, %O: tensor<16x32xf32>) -> tensor<16x32xf32> {
  %0 = linalg.fill_rng_2d ins(%min, %max, %seed: f64, f64, i32) outs(%O : tensor<16x32xf32>) -> tensor<16x32xf32>
  return %0: tensor<16x32xf32>
}

// CHECK-LABEL: @generalize_const
// CHECK-DAG: %[[CST0:.+]] = arith.constant 1103515245 : i32
// CHECK-DAG: %[[CST1:.+]] = arith.constant 12345 : i32
// CHECK-DAG: %[[CST2:.+]] = arith.constant 2.3283063999999999E-10 : f64
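
// For context: the constants above are consistent with a linear congruential
// generator (multiplier 1103515245 and increment 12345, as in the classic C
// library rand()), and 2.3283064e-10 is approximately 2^-32, which rescales
// the i32 state into the requested floating-point range.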

// -----

// Verifies the default value of the fun attribute is an exp op.
func.func @generalize_elemwise_exp(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
  %0 = linalg.elemwise_unary ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
  return %0: tensor<4x8xf32>
}

// CHECK-LABEL: @generalize_elemwise_exp
// CHECK: = math.exp

// -----

// Verifies the fun attribute controls the unary function used.
func.func @generalize_elemwise_log(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
  %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<log>}
         ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
  return %0: tensor<4x8xf32>
}

// CHECK-LABEL: @generalize_elemwise_log
// CHECK: = math.log

// -----

// Verifies the fun attribute controls the unary function used.
func.func @generalize_elemwise_abs(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
  %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<abs>}
         ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
  return %0: tensor<4x8xf32>
}

// CHECK-LABEL: @generalize_elemwise_abs
// CHECK: = math.absf

// -----

// Verifies the fun attribute controls the unary function used.
func.func @generalize_elemwise_ceil(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
  %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<ceil>}
         ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
  return %0: tensor<4x8xf32>
}

// CHECK-LABEL: @generalize_elemwise_ceil
// CHECK: = math.ceil

// -----

// Verifies the fun attribute controls the unary function used.
func.func @generalize_elemwise_floor(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
  %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<floor>}
         ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
  return %0: tensor<4x8xf32>
}

// CHECK-LABEL: @generalize_elemwise_floor
// CHECK: = math.floor

// -----

// Verifies the fun attribute controls the unary function used.
func.func @generalize_elemwise_negf(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
  %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<negf>}
         ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
  return %0: tensor<4x8xf32>
}

// CHECK-LABEL: @generalize_elemwise_negf
// CHECK: = arith.negf

// -----

// Verifies the default value of the fun attribute is an add op.
func.func @generalize_elemwise_add(%lhs : tensor<4x8xf32>, %rhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
  %0 = linalg.elemwise_binary ins(%lhs, %rhs: tensor<4x8xf32>, tensor<4x8xf32>)
         outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
  return %0: tensor<4x8xf32>
}

// CHECK-LABEL: @generalize_elemwise_add
// CHECK: = arith.addf
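
// For reference: elementwise named ops generalize to a rank-matching generic
// with all-parallel iterators. A rough sketch for the add test above
// (illustrative only, not a CHECK pattern; SSA names are placeholders):
//
//   #map = affine_map<(d0, d1) -> (d0, d1)>
//   %0 = linalg.generic
//          {indexing_maps = [#map, #map, #map],
//           iterator_types = ["parallel", "parallel"]}
//          ins(%lhs, %rhs : tensor<4x8xf32>, tensor<4x8xf32>)
//          outs(%output : tensor<4x8xf32>) {
//   ^bb0(%l: f32, %r: f32, %o: f32):
//     %1 = arith.addf %l, %r : f32
//     linalg.yield %1 : f32
//   } -> tensor<4x8xf32>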

// -----

// Verifies the fun attribute controls the binary function used.
func.func @generalize_elemwise_mul(%lhs : tensor<4x8xf32>, %rhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
  %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<mul>}
         ins(%lhs, %rhs: tensor<4x8xf32>, tensor<4x8xf32>)
         outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
  return %0: tensor<4x8xf32>
}

// CHECK-LABEL: @generalize_elemwise_mul
// CHECK: = arith.mulf

// -----

// Verifies pointwise ops support rank zero input tensors.
func.func @generalize_elemwise_rank_zero(%lhs : tensor<f32>, %rhs : tensor<f32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
  %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<sub>}
         ins(%lhs, %rhs: tensor<f32>, tensor<f32>)
         outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
  return %0: tensor<4x8xf32>
}

// CHECK-LABEL: @generalize_elemwise_rank_zero
// CHECK: linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
// CHECK: = arith.subf

// -----

// Verifies that copy generalizes to a generic whose region yields its input unchanged.
func.func @generalize_copy(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
  %0 = linalg.copy ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
  return %0: tensor<4x8xf32>
}

// CHECK-LABEL: @generalize_copy
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[I:[0-9a-zA-Z]*]]: f32
// CHECK-NEXT: linalg.yield %[[I]]
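
// For reference: the generalized copy is roughly the following (illustrative
// sketch, not a CHECK pattern; SSA names are placeholders):
//
//   #map = affine_map<(d0, d1) -> (d0, d1)>
//   %0 = linalg.generic
//          {indexing_maps = [#map, #map],
//           iterator_types = ["parallel", "parallel"]}
//          ins(%lhs : tensor<4x8xf32>) outs(%output : tensor<4x8xf32>) {
//   ^bb0(%in: f32, %out: f32):
//     linalg.yield %in : f32
//   } -> tensor<4x8xf32>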