// RUN: mlir-opt %s -split-input-file -linalg-generalize-named-ops | FileCheck %s

func.func @generalize_matmul_buffer(%A : memref<16x8xf32>, %B: memref<8x32xf32>, %C: memref<16x32xf32>) {
  linalg.matmul ins(%A, %B: memref<16x8xf32>, memref<8x32xf32>)
                outs(%C: memref<16x32xf32>)
  return
}


// CHECK: #[[A_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK: #[[B_MAP:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK: #[[C_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// CHECK: func @generalize_matmul_buffer
// CHECK-SAME: %[[A:.+]]: memref<16x8xf32>
// CHECK-SAME: %[[B:.+]]: memref<8x32xf32>
// CHECK-SAME: %[[C:.+]]: memref<16x32xf32>

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[A_MAP]], #[[B_MAP]], #[[C_MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[A]], %[[B]]
// CHECK-SAME: outs(%[[C]]

// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
// CHECK: %[[MUL:.+]] = arith.mulf %[[A_ARG]], %[[B_ARG]] : f32
// CHECK: %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
// CHECK: linalg.yield %[[ADD]] : f32

// -----

func.func @matmul_bcast_a(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
  linalg.matmul indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d2)>,
                  affine_map<(d0, d1, d2) -> (d2, d1)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
  return
}

// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func.func @matmul_bcast_a(
// CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<5x7xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) {
// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32):
// CHECK: %[[VAL_6:.*]] = arith.mulf %[[VAL_3]], %[[VAL_4]] : f32
// CHECK: %[[VAL_7:.*]] = arith.addf %[[VAL_5]], %[[VAL_6]] : f32
// CHECK: linalg.yield %[[VAL_7]] : f32
// CHECK: }
// CHECK: return
// CHECK: }

// -----

func.func @generalize_matmul_tensor(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
                     outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
  return %0: tensor<16x32xf32>
}

// CHECK: func @generalize_matmul_tensor

// CHECK: linalg.generic
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<16x8xf32>, tensor<8x32xf32>)
// CHECK-SAME: outs(%{{.+}} : tensor<16x32xf32>)

// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
// CHECK-NEXT: %[[MUL:.+]] = arith.mulf %[[A_ARG]], %[[B_ARG]] : f32
// CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
// CHECK-NEXT: linalg.yield %[[ADD]] : f32
// CHECK-NEXT: -> tensor<16x32xf32>

// -----

func.func @generalize_matmul_tensor_complex(%A : tensor<16x8xcomplex<f32>>,
                                            %B: tensor<8x32xcomplex<f32>>,
                                            %C: tensor<16x32xcomplex<f32>>)
          -> tensor<16x32xcomplex<f32>> {
  %0 = linalg.matmul ins(%A, %B: tensor<16x8xcomplex<f32>>, tensor<8x32xcomplex<f32>>)
                     outs(%C: tensor<16x32xcomplex<f32>>) -> tensor<16x32xcomplex<f32>>
  return %0: tensor<16x32xcomplex<f32>>
}

// CHECK: func @generalize_matmul_tensor_complex

// CHECK: linalg.generic
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<16x8xcomplex<f32>>, tensor<8x32xcomplex<f32>>)
// CHECK-SAME: outs(%{{.+}} : tensor<16x32xcomplex<f32>>)

// CHECK: ^{{.*}}(%[[A_ARG:.+]]: complex<f32>, %[[B_ARG:.+]]: complex<f32>, %[[C_ARG:.+]]: complex<f32>)
// CHECK-NEXT: %[[MUL:.+]] = complex.mul %[[A_ARG]], %[[B_ARG]] : complex<f32>
// CHECK-NEXT: %[[ADD:.+]] = complex.add %[[C_ARG]], %[[MUL]] : complex<f32>
// CHECK-NEXT: linalg.yield %[[ADD]] : complex<f32>
// CHECK-NEXT: -> tensor<16x32xcomplex<f32>>

// -----

func.func @depthwise_conv_2d_nhwc_hwcm(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x3x4x2x3xf32>) {
  linalg.depthwise_conv_2d_nhwc_hwcm
    { dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
    ins(%input, %filter : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
    outs(%output : memref<2x3x4x2x3xf32>)
  return
}

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>

// CHECK: func @depthwise_conv_2d_nhwc_hwcm

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<2x3x4x2x3xf32>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
// CHECK-NEXT: %[[MUL:.+]] = arith.mulf %[[BBARG0]], %[[BBARG1]] : f32
// CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]] : f32
// CHECK-NEXT: linalg.yield %[[ADD]] : f32

// -----

func.func @depthwise_conv_2d_nhwc_hwcm(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x2x3x2x3xf32>) {
  linalg.depthwise_conv_2d_nhwc_hwcm
    { dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
    ins(%input, %filter : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
    outs(%output : memref<2x2x3x2x3xf32>)
  return
}

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5 * 2, d2 + d6 * 2, d3)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>

// CHECK: func @depthwise_conv_2d_nhwc_hwcm

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<2x2x3x2x3xf32>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
// CHECK-NEXT: %[[MUL:.+]] = arith.mulf %[[BBARG0]], %[[BBARG1]] : f32
// CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]] : f32
// CHECK-NEXT: linalg.yield %[[ADD]] : f32

// -----

func.func @depthwise_conv_2d_nhwc_hwc(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
  linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
    ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
    outs(%output: memref<1x56x56x96xf32>)
  return
}

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 2 + d5, d3)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>

// CHECK: func @depthwise_conv_2d_nhwc_hwc

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x113x113x96xf32>, memref<3x3x96xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<1x56x56x96xf32>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
// CHECK-NEXT: %[[MUL:.+]] = arith.mulf %[[BBARG0]], %[[BBARG1]] : f32
// CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]] : f32
// CHECK-NEXT: linalg.yield %[[ADD]] : f32

// -----

func.func @conv_1d_nwc_wcf(%input: memref<?x?x?xf32>, %filter: memref<?x?x?xf32>, %output: memref<?x?x?xf32>) {
  linalg.conv_1d_nwc_wcf {dilations = dense<1> : tensor<1xi64>,
                          strides = dense<1> : tensor<1xi64>}
    ins (%input, %filter: memref<?x?x?xf32>, memref<?x?x?xf32>)
    outs (%output: memref<?x?x?xf32>)
  return
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>

// CHECK: func @conv_1d_nwc_wcf

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<?x?x?xf32>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
// CHECK-NEXT: %[[MUL:.+]] = arith.mulf %[[BBARG0]], %[[BBARG1]] : f32
// CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]] : f32
// CHECK-NEXT: linalg.yield %[[ADD]] : f32

// -----

func.func @conv_1d_ncw_fcw(%input: memref<?x?x?xf32>, %filter: memref<?x?x?xf32>, %output: memref<?x?x?xf32>) {
  linalg.conv_1d_ncw_fcw {dilations = dense<1> : tensor<1xi64>,
                          strides = dense<1> : tensor<1xi64>}
    ins (%input, %filter: memref<?x?x?xf32>, memref<?x?x?xf32>)
    outs (%output: memref<?x?x?xf32>)
  return
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>

// CHECK: func @conv_1d_ncw_fcw

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<?x?x?xf32>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
// CHECK-NEXT: %[[MUL:.+]] = arith.mulf %[[BBARG0]], %[[BBARG1]] : f32
// CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]] : f32
// CHECK-NEXT: linalg.yield %[[ADD]] : f32

// -----

func.func @conv_2d_ngchw_gfchw_q(%input: memref<?x?x?x?x?xi8>, %filter: memref<?x?x?x?x?xi8>, %inputzp: i32, %filterzp: i32, %output: memref<?x?x?x?x?xi32>) {
  linalg.conv_2d_ngchw_gfchw_q {dilations = dense<1> : tensor<2xi64>,
                                strides = dense<1> : tensor<2xi64>}
    ins (%input, %filter, %inputzp, %filterzp: memref<?x?x?x?x?xi8>, memref<?x?x?x?x?xi8>, i32, i32)
    outs (%output: memref<?x?x?x?x?xi32>)
  return
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>

// CHECK: func @conv_2d_ngchw_gfchw_q

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP2]], #[[MAP3]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]}
// CHECK-SAME: ins(%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}} : memref<?x?x?x?x?xi8>, memref<?x?x?x?x?xi8>, i32, i32)
// CHECK-SAME: outs(%{{.+}} : memref<?x?x?x?x?xi32>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: i32, %[[BBARG3:.+]]: i32, %[[BBARG4:.+]]: i32)
// CHECK-NEXT: %[[EXTSI0:.+]] = arith.extsi %[[BBARG0]] : i8 to i32
// CHECK-NEXT: %[[SUB0:.+]] = arith.subi %[[EXTSI0]], %[[BBARG2]] : i32
// CHECK-NEXT: %[[EXTSI1:.+]] = arith.extsi %[[BBARG1]] : i8 to i32
// CHECK-NEXT: %[[SUB1:.+]] = arith.subi %[[EXTSI1]], %[[BBARG3]] : i32
// CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[SUB0]], %[[SUB1]] : i32
// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[BBARG4]], %[[MUL]] : i32
// CHECK-NEXT: linalg.yield %[[ADD]] : i32

// -----

func.func @generalize_fill(%output: memref<?x?xf32>, %value : f32) {
  linalg.fill ins(%value : f32) outs(%output : memref<?x?xf32>)
  return
}

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> ()>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>

// CHECK: func @generalize_fill
// CHECK-SAME: (%[[ARG0:.+]]: memref<?x?xf32>, %[[VAL:.+]]: f32)

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins(%[[VAL]] : f32)
// CHECK-SAME: outs(%{{.+}} : memref<?x?xf32>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
// CHECK-NEXT: linalg.yield %[[BBARG0]] : f32

// -----

func.func @generalize_batch_matm_vec(%lhs : memref<?x?x?xi8>, %rhs: memref<?x?xi8>, %out: memref<?x?xf32>) {
  linalg.batch_matvec ins(%lhs, %rhs: memref<?x?x?xi8>, memref<?x?xi8>)
                      outs(%out: memref<?x?xf32>)
  return
}
// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// CHECK: @generalize_batch_matm_vec

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?xi8>, memref<?x?xi8>)
// CHECK-SAME: outs(%{{.+}} : memref<?x?xf32>)
// CHECK: ^{{.+}}(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: f32)
// CHECK: %[[BBARG0_F32:.+]] = arith.sitofp %[[BBARG0]] : i8 to f32
// CHECK: %[[BBARG1_F32:.+]] = arith.sitofp %[[BBARG1]] : i8 to f32
// CHECK: %[[MUL:.+]] = arith.mulf %[[BBARG0_F32]], %[[BBARG1_F32]]
// CHECK: %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]]
// CHECK: linalg.yield %[[ADD]] : f32

// -----

func.func @generalize_batch_vecmat(%lhs : memref<?x?xi8>, %rhs: memref<?x?x?xi8>, %out: memref<?x?xf32>) {
  linalg.batch_vecmat ins(%lhs, %rhs: memref<?x?xi8>, memref<?x?x?xi8>)
                      outs(%out: memref<?x?xf32>)
  return
}
// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// CHECK: @generalize_batch_vecmat

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?xi8>, memref<?x?x?xi8>)
// CHECK-SAME: outs(%{{.+}} : memref<?x?xf32>)
// CHECK: ^{{.+}}(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: f32)
// CHECK: %[[BBARG0_F32:.+]] = arith.sitofp %[[BBARG0]] : i8 to f32
// CHECK: %[[BBARG1_F32:.+]] = arith.sitofp %[[BBARG1]] : i8 to f32
// CHECK: %[[MUL:.+]] = arith.mulf %[[BBARG0_F32]], %[[BBARG1_F32]]
// CHECK: %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]]
// CHECK: linalg.yield %[[ADD]] : f32

// -----

func.func @batch_reduce_gemm(%lhs: memref<7x8x9xf32>, %rhs: memref<7x9x8xf32>, %out: memref<8x8xf32>) {
  linalg.batch_reduce_matmul ins(%lhs, %rhs: memref<7x8x9xf32>, memref<7x9x8xf32>)
                             outs(%out: memref<8x8xf32>)
  return
}

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>

// CHECK: @batch_reduce_gemm

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<7x8x9xf32>, memref<7x9x8xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<8x8xf32>
// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
// CHECK: %[[MUL:.+]] = arith.mulf %[[BBARG0]], %[[BBARG1]] : f32
// CHECK: %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]] : f32
// CHECK: linalg.yield %[[ADD]] : f32

// -----

func.func @generalize_batch_reduce_gemm_bf16(%lhs: memref<7x8x9xbf16>, %rhs: memref<7x9x8xbf16>, %out: memref<8x8xf32>) {
  linalg.batch_reduce_matmul ins(%lhs, %rhs: memref<7x8x9xbf16>, memref<7x9x8xbf16>)
                             outs(%out: memref<8x8xf32>)
  return
}

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>

// CHECK: @generalize_batch_reduce_gemm_bf16

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<7x8x9xbf16>, memref<7x9x8xbf16>)
// CHECK-SAME: outs(%{{.+}} : memref<8x8xf32>
// CHECK: ^{{.+}}(%[[BBARG0:.+]]: bf16, %[[BBARG1:.+]]: bf16, %[[BBARG2:.+]]: f32)
// CHECK: %[[EXTBF16_0:.+]] = arith.extf %[[BBARG0]] : bf16 to f32
// CHECK: %[[EXTBF16_1:.+]] = arith.extf %[[BBARG1]] : bf16 to f32
// CHECK: %[[MUL:.+]] = arith.mulf %[[EXTBF16_0]], %[[EXTBF16_1]] : f32
// CHECK: %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]] : f32
// CHECK: linalg.yield %[[ADD]] : f32


// -----

// CHECK-LABEL: generalize_linalg_map
func.func @generalize_linalg_map(%arg0: memref<1x8x8x8xf32>) {
  %cst = arith.constant 0.000000e+00 : f32
  // CHECK: linalg.map
  // CHECK-NOT: linalg.generic
  linalg.map outs(%arg0 : memref<1x8x8x8xf32>)
    () {
      linalg.yield %cst : f32
    }
  return
}

// -----

func.func @generalize_add(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
                          %out: memref<7x14x21xf32>) {
  linalg.add ins(%lhs, %rhs : memref<7x14x21xf32>, memref<7x14x21xf32>)
             outs(%out : memref<7x14x21xf32>)
  return
}

// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// CHECK: func @generalize_add
// CHECK-SAME: (%[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
// CHECK-SAME: %[[OUT:.+]]: memref<7x14x21xf32>)

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : memref<7x14x21xf32>, memref<7x14x21xf32>)
// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
// CHECK-NEXT: %[[SUM:.+]] = arith.addf %[[BBARG0]], %[[BBARG1]] : f32
// CHECK-NEXT: linalg.yield %[[SUM]] : f32

// -----

func.func @generalize_sub(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
                          %out: memref<7x14x21xf32>) {
  linalg.sub ins(%lhs, %rhs : memref<7x14x21xf32>, memref<7x14x21xf32>)
             outs(%out : memref<7x14x21xf32>)
  return
}

// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// CHECK: func @generalize_sub
// CHECK-SAME: (%[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
// CHECK-SAME: %[[OUT:.+]]: memref<7x14x21xf32>)

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : memref<7x14x21xf32>, memref<7x14x21xf32>)
// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
// CHECK-NEXT: %[[SUB:.+]] = arith.subf %[[BBARG0]], %[[BBARG1]] : f32
// CHECK-NEXT: linalg.yield %[[SUB]] : f32

// -----

func.func @generalize_mul(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
                          %out: memref<7x14x21xf32>) {
  linalg.mul ins(%lhs, %rhs : memref<7x14x21xf32>, memref<7x14x21xf32>)
             outs(%out : memref<7x14x21xf32>)
  return
}

// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// CHECK: func @generalize_mul
// CHECK-SAME: (%[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
// CHECK-SAME: %[[OUT:.+]]: memref<7x14x21xf32>)

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : memref<7x14x21xf32>, memref<7x14x21xf32>)
// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
// CHECK-NEXT: %[[MUL:.+]] = arith.mulf %[[BBARG0]], %[[BBARG1]] : f32
// CHECK-NEXT: linalg.yield %[[MUL]] : f32

// -----

func.func @generalize_div(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
                          %out: memref<7x14x21xf32>) {
  linalg.div ins(%lhs, %rhs : memref<7x14x21xf32>, memref<7x14x21xf32>)
             outs(%out : memref<7x14x21xf32>)
  return
}

// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// CHECK: func @generalize_div
// CHECK-SAME: (%[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
// CHECK-SAME: %[[OUT:.+]]: memref<7x14x21xf32>)

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : memref<7x14x21xf32>, memref<7x14x21xf32>)
// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
// CHECK-NEXT: %[[DIV:.+]] = arith.divf %[[BBARG0]], %[[BBARG1]] : f32
// CHECK-NEXT: linalg.yield %[[DIV]] : f32

// -----

func.func @generalize_divu(%lhs: memref<7x14x21xi32>, %rhs: memref<7x14x21xi32>,
                           %out: memref<7x14x21xi32>) {
  linalg.div_unsigned ins(%lhs, %rhs : memref<7x14x21xi32>, memref<7x14x21xi32>)
                      outs(%out : memref<7x14x21xi32>)
  return
}

// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// CHECK: func @generalize_divu
// CHECK-SAME: (%[[LHS:.+]]: memref<7x14x21xi32>, %[[RHS:.+]]: memref<7x14x21xi32>,
// CHECK-SAME: %[[OUT:.+]]: memref<7x14x21xi32>)

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : memref<7x14x21xi32>, memref<7x14x21xi32>)
// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xi32>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: i32, %[[BBARG1:.+]]: i32, %[[BBARG2:.+]]: i32)
// CHECK-NEXT: %[[DIVU:.+]] = arith.divui %[[BBARG0]], %[[BBARG1]] : i32
// CHECK-NEXT: linalg.yield %[[DIVU]] : i32

// -----

func.func @generalize_exp(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
  linalg.exp ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
  return
}

// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// CHECK: func @generalize_exp
// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// CHECK-SAME: ins(%[[ARG]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
// CHECK-NEXT: %[[EXP:.+]] = math.exp %[[BBARG0]] : f32
// CHECK-NEXT: linalg.yield %[[EXP]] : f32

// -----

func.func @generalize_log(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
  linalg.log ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
  return
}

// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// CHECK: func @generalize_log
// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// CHECK-SAME: ins(%[[ARG]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
// CHECK-NEXT: %[[log:.+]] = math.log %[[BBARG0]] : f32
// CHECK-NEXT: linalg.yield %[[log]] : f32

// -----

func.func @generalize_abs(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
  linalg.abs ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
  return
}

// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// CHECK: func @generalize_abs
// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// CHECK-SAME: ins(%[[ARG]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
// CHECK-NEXT: %[[abs:.+]] = math.absf %[[BBARG0]] : f32
// CHECK-NEXT: linalg.yield %[[abs]] : f32

// -----

func.func @generalize_ceil(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
  linalg.ceil ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
  return
}

// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// CHECK: func @generalize_ceil
// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// CHECK-SAME: ins(%[[ARG]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
// CHECK-NEXT: %[[ceil:.+]] = math.ceil %[[BBARG0]] : f32
// CHECK-NEXT: linalg.yield %[[ceil]] : f32

// -----

func.func @generalize_floor(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
  linalg.floor ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
  return
}

// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// CHECK: func @generalize_floor
// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// CHECK-SAME: ins(%[[ARG]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
// CHECK-NEXT: %[[floor:.+]] = math.floor %[[BBARG0]] : f32
// CHECK-NEXT: linalg.yield %[[floor]] : f32

// -----

func.func @generalize_negf(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
  linalg.negf ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
  return
}

// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// CHECK: func @generalize_negf
// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// CHECK-SAME: ins(%[[ARG]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
// CHECK-NEXT: %[[negf:.+]] = arith.negf %[[BBARG0]] : f32
// CHECK-NEXT: linalg.yield %[[negf]] : f32

// -----

func.func @generalize_reciprocal(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
  linalg.reciprocal ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
  return
}

// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// CHECK: func @generalize_reciprocal
// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)

// CHECK: %[[one:.+]] = arith.constant 1.000000e+00 : f32

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// CHECK-SAME: ins(%[[ARG]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
// CHECK-NEXT: %[[reciprocal:.+]] = arith.divf %[[one]], %[[BBARG0]] : f32
// CHECK-NEXT: linalg.yield %[[reciprocal]] : f32

// -----

func.func @generalize_round(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
  linalg.round ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
  return
}

// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// CHECK: func @generalize_round
// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// CHECK-SAME: ins(%[[ARG]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
// CHECK-NEXT: %[[round:.+]] = math.round %[[BBARG0]] : f32
// CHECK-NEXT: linalg.yield %[[round]] : f32

// -----

func.func @generalize_sqrt(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
  linalg.sqrt ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
  return
}

// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// CHECK: func @generalize_sqrt
// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// CHECK-SAME: ins(%[[ARG]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
// CHECK-NEXT: %[[sqrt:.+]] = math.sqrt %[[BBARG0]] : f32
// CHECK-NEXT: linalg.yield %[[sqrt]] : f32

// -----

func.func @generalize_rsqrt(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
  linalg.rsqrt ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
  return
}

// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// CHECK: func @generalize_rsqrt
// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// CHECK-SAME: ins(%[[ARG]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
// CHECK-NEXT: %[[rsqrt:.+]] = math.rsqrt %[[BBARG0]] : f32
// CHECK-NEXT: linalg.yield %[[rsqrt]] : f32

// -----

func.func @generalize_square(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
  linalg.square ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
  return
}

// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// CHECK: func @generalize_square
// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// CHECK-SAME: ins(%[[ARG]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
// CHECK-NEXT: %[[square:.+]] = arith.mulf %[[BBARG0]], %[[BBARG0]] : f32
// CHECK-NEXT: linalg.yield %[[square]] : f32

// -----

func.func @generalize_tanh(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
  linalg.tanh ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
  return
}

// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// CHECK: func @generalize_tanh
// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// CHECK-SAME: ins(%[[ARG]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
// CHECK-NEXT: %[[tanh:.+]] = math.tanh %[[BBARG0]] : f32
// CHECK-NEXT: linalg.yield %[[tanh]] : f32

// -----

func.func @generalize_erf(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
  linalg.erf ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
  return
}

// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// CHECK: func @generalize_erf
// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// CHECK-SAME: ins(%[[ARG]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
// CHECK-NEXT: %[[erf:.+]] = math.erf %[[BBARG0]] : f32
// CHECK-NEXT: linalg.yield %[[erf]] : f32

// -----

func.func @generalize_max(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
                          %out: memref<7x14x21xf32>) {
  linalg.max ins(%lhs, %rhs : memref<7x14x21xf32>, memref<7x14x21xf32>)
             outs(%out : memref<7x14x21xf32>)
  return
}

// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// CHECK: func @generalize_max
// CHECK-SAME: (%[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
// CHECK-SAME: %[[OUT:.+]]: memref<7x14x21xf32>)

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : memref<7x14x21xf32>, memref<7x14x21xf32>)
// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
// CHECK-NEXT: %[[max:.+]] = arith.maximumf %[[BBARG0]], %[[BBARG1]] : f32
// CHECK-NEXT: linalg.yield %[[max]] : f32

// -----
func.func @generalize_min(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
                          %out: memref<7x14x21xf32>) {
  linalg.min ins(%lhs, %rhs : memref<7x14x21xf32>, memref<7x14x21xf32>)
             outs(%out : memref<7x14x21xf32>)
  return
}

// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// CHECK: func @generalize_min
// CHECK-SAME: (%[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
// CHECK-SAME: %[[OUT:.+]]: memref<7x14x21xf32>)

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : memref<7x14x21xf32>, memref<7x14x21xf32>)
// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
// CHECK-NEXT: %[[min:.+]] = arith.minimumf %[[BBARG0]], %[[BBARG1]] : f32
// CHECK-NEXT: linalg.yield %[[min]] : f32


// -----

func.func @generalize_powf(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
                           %out: memref<7x14x21xf32>) {
  linalg.powf ins(%lhs, %rhs : memref<7x14x21xf32>, memref<7x14x21xf32>)
              outs(%out : memref<7x14x21xf32>)
  return
}

// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// CHECK: func @generalize_powf
// CHECK-SAME: (%[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
// CHECK-SAME: %[[OUT:.+]]: memref<7x14x21xf32>)

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : memref<7x14x21xf32>, memref<7x14x21xf32>)
// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
// CHECK-NEXT: %[[powf:.+]] = math.powf %[[BBARG0]], %[[BBARG1]] : f32
// CHECK-NEXT: linalg.yield %[[powf]] : f32


// -----

func.func @generalize_select(%cond: memref<7x14x21xi1>, %lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
                             %out: memref<7x14x21xf32>) {
  linalg.select ins(%cond, %lhs, %rhs: memref<7x14x21xi1>, memref<7x14x21xf32>, memref<7x14x21xf32>)
                outs(%out: memref<7x14x21xf32>)
  return
}

// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// CHECK: func @generalize_select
// CHECK-SAME: (%[[COND:.+]]: memref<7x14x21xi1>, %[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
// CHECK-SAME: %[[OUT:.+]]: memref<7x14x21xf32>)

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]], #[[MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// CHECK-SAME: ins(%[[COND]], %[[LHS]], %[[RHS]] : memref<7x14x21xi1>, memref<7x14x21xf32>, memref<7x14x21xf32>)
// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: i1, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32, %[[BBARG3:.+]]: f32)
// CHECK-NEXT: %[[select:.+]] = arith.select %[[BBARG0]], %[[BBARG1]], %[[BBARG2]] : f32
// CHECK-NEXT: linalg.yield %[[select]] : f32


// -----

// CHECK-LABEL: func @fill_tensor
func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vector<2x4xf32>>) {
  %e0 = tensor.empty() : tensor<f32>
  %0 = linalg.fill ins(%f : f32) outs(%e0 : tensor<f32>) -> tensor<f32>
// CHECK: linalg.generic
// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
// CHECK-NEXT: linalg.yield %[[BBARG0]] : f32

  %e1 = tensor.empty() : tensor<vector<2x4xf32>>
  %1 = linalg.fill ins(%v : vector<2x4xf32>) outs(%e1 : tensor<vector<2x4xf32>>) -> tensor<vector<2x4xf32>>
// CHECK: linalg.generic
// CHECK: ^{{.+}}(%[[BBARG0:.+]]: vector<2x4xf32>, %[[BBARG1:.+]]: vector<2x4xf32>)
// CHECK-NEXT: linalg.yield %[[BBARG0]] : vector<2x4xf32>

  return %0, %1: tensor<f32>, tensor<vector<2x4xf32>>
}

// -----

// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// CHECK-LABEL: func.func @matmul_transpose_a_explicit(
// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<5x7xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {

// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
// CHECK: arith.mulf
// CHECK: arith.addf

func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
  linalg.matmul indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d2, d0)>,
                  affine_map<(d0, d1, d2) -> (d2, d1)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
                outs(%arg2: memref<3x7xf32>)
  return
}

// -----

// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func.func @matmul_transpose_b_explicit(
// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {

// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
// CHECK: arith.mulf
// CHECK: arith.addf

func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
  linalg.matmul indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d0, d2)>,
                  affine_map<(d0, d1, d2) -> (d1, d2)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>)
                outs(%arg2: memref<3x7xf32>)
  return
}

// -----

// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// CHECK-LABEL: func.func @matmul_transpose_a_b_explicit(
// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {

// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
// CHECK: arith.mulf
// CHECK: arith.addf

func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
  linalg.matmul indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d2, d0)>,
                  affine_map<(d0, d1, d2) -> (d1, d2)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
                outs(%arg2: memref<3x7xf32>)
  return
}

// -----

// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// CHECK-LABEL: func.func @contract_matmul(
// CHECK-SAME: %[[A:.*]]: memref<3x5xf32>,
// CHECK-SAME: %[[B:.*]]: memref<5x7xf32>,
// CHECK-SAME: %[[C:.*]]: memref<3x7xf32>) {

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
// CHECK-NEXT: ^{{.+}}(
// CHECK-NEXT: arith.mulf
// CHECK-NEXT: arith.addf
// CHECK-NEXT: linalg.yield

func.func @contract_matmul(%A: memref<3x5xf32>, %B: memref<5x7xf32>, %C: memref<3x7xf32>) {
  linalg.contract
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                       affine_map<(d0, d1, d2) -> (d2, d1)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>]
      ins(%A, %B : memref<3x5xf32>, memref<5x7xf32>)
      outs(%C: memref<3x7xf32>)

  return
}

// -----

// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// CHECK-LABEL: func.func @contract_matmul_transpose_a_b(
// CHECK-SAME: %[[A:.*]]: memref<5x3xf32>,
// CHECK-SAME: %[[B:.*]]: memref<7x5xf32>,
// CHECK-SAME: %[[C:.*]]: memref<3x7xf32>) {

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
// CHECK-NEXT: ^{{.+}}(
// CHECK-NEXT: arith.mulf
// CHECK-NEXT: arith.addf
// CHECK-NEXT: linalg.yield

func.func @contract_matmul_transpose_a_b(%A: memref<5x3xf32>, %B: memref<7x5xf32>, %C: memref<3x7xf32>) {
  linalg.contract
      indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
                       affine_map<(d0, d1, d2) -> (d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>]
      ins(%A, %B : memref<5x3xf32>, memref<7x5xf32>)
      outs(%C: memref<3x7xf32>)
  return
}

// -----

// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>

// CHECK-LABEL: func.func @contract_batch_matmul(
// CHECK-SAME: %[[A:.*]]: memref<9x3x5xf32>,
// CHECK-SAME: %[[B:.*]]: memref<9x5x7xf32>,
// CHECK-SAME: %[[C:.*]]: memref<9x3x7xf32>) {

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
// CHECK-NEXT: ^{{.+}}(
// CHECK-NEXT: arith.mulf
// CHECK-NEXT: arith.addf
// CHECK-NEXT: linalg.yield

func.func @contract_batch_matmul(%A: memref<9x3x5xf32>, %B: memref<9x5x7xf32>, %C: memref<9x3x7xf32>) {
  linalg.contract
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]
      ins(%A, %B : memref<9x3x5xf32>, memref<9x5x7xf32>)
      outs(%C: memref<9x3x7xf32>)
  return
}

// -----

// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>

// CHECK-LABEL: func.func @contract_batch_reduce_matmul(
// CHECK-SAME: %[[A:.*]]: memref<9x3x5xf32>,
// CHECK-SAME: %[[B:.*]]: memref<9x5x7xf32>,
// CHECK-SAME: %[[C:.*]]: memref<3x7xf32>) {

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction"]
// CHECK-NEXT: ^{{.+}}(
// CHECK-NEXT: arith.mulf
// CHECK-NEXT: arith.addf
// CHECK-NEXT: linalg.yield

#accessA = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#accessB = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
#accessC = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
func.func @contract_batch_reduce_matmul(
    %A: memref<9x3x5xf32>, %B: memref<9x5x7xf32>, %C: memref<3x7xf32>) {
  linalg.contract
      indexing_maps = [#accessA, #accessB, #accessC]
      ins(%A, %B : memref<9x3x5xf32>, memref<9x5x7xf32>)
      outs(%C: memref<3x7xf32>)
  return
}

// -----

// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>

// CHECK-LABEL: func.func @contract_batch_reduce_matmul_permute_m_with_k_and_k_with_n(
// CHECK-SAME: %[[A:.*]]: memref<9x5x3xf32>,
// CHECK-SAME: %[[B:.*]]: memref<9x7x5xf32>,
// CHECK-SAME: %[[C:.*]]: memref<3x7xf32>) {

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction"]
// CHECK-NEXT: ^{{.+}}(
// CHECK-NEXT: arith.mulf
// CHECK-NEXT: arith.addf
// CHECK-NEXT: linalg.yield

#accessA = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
#accessB = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
#accessC = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
func.func @contract_batch_reduce_matmul_permute_m_with_k_and_k_with_n(
    %A: memref<9x5x3xf32>, %B: memref<9x7x5xf32>, %C: memref<3x7xf32>) {
  linalg.contract
      indexing_maps = [#accessA, #accessB, #accessC]
      ins(%A, %B : memref<9x5x3xf32>, memref<9x7x5xf32>)
      outs(%C: memref<3x7xf32>)
  return
}

// -----

// CHECK: #[[$ACCESS_A_B:.+]] = affine_map<(d0) -> (d0)>
// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0) -> ()>

// CHECK-LABEL: func.func @contract_dot(
// CHECK-SAME: %[[A:.*]]: memref<9xf32>, %[[B:.*]]: memref<9xf32>,
// CHECK-SAME: %[[C:.*]]: memref<f32>) {

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$ACCESS_A_B]], #[[$ACCESS_A_B]], #[[$ACCESS_C]]]
// CHECK-SAME: iterator_types = ["reduction"]
// CHECK-NEXT: ^{{.+}}(
// CHECK-NEXT: arith.mulf
// CHECK-NEXT: arith.addf
// CHECK-NEXT: linalg.yield

#accessAB = affine_map<(d0) -> (d0)>
#accessC = affine_map<(d0) -> ()>
func.func @contract_dot(
    %A: memref<9xf32>, %B: memref<9xf32>, %C: memref<f32>) {
  linalg.contract
      indexing_maps = [#accessAB, #accessAB, #accessC]
      ins(%A, %B : memref<9xf32>, memref<9xf32>)
      outs(%C: memref<f32>)
  return
}

// -----

// CHECK: #[[$ACCESS_A_B:.+]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// CHECK-LABEL: func.func @contract_matmul_bcast_a_b(
// CHECK-SAME: %[[A:.*]]: memref<5xf32>, %[[B:.*]]: memref<5xf32>,
// CHECK-SAME: %[[C:.*]]: memref<3x7xf32>) {

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$ACCESS_A_B]], #[[$ACCESS_A_B]], #[[$ACCESS_C]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
// CHECK-NEXT: ^{{.+}}(
// CHECK-NEXT: arith.mulf
// CHECK-NEXT: arith.addf
// CHECK-NEXT: linalg.yield

#accessAB = affine_map<(d0, d1, d2) -> (d2)>
#accessC = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @contract_matmul_bcast_a_b(
    %A: memref<5xf32>, %B: memref<5xf32>, %C: memref<3x7xf32>) {
  linalg.contract
      indexing_maps = [#accessAB, #accessAB, #accessC]
      ins(%A, %B : memref<5xf32>, memref<5xf32>)
      outs(%C: memref<3x7xf32>)
  return
}