// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s
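// This file exercises the transform-dialect path for Winograd Conv2D: the
// linalg.winograd_input_transform and linalg.winograd_output_transform ops are
// tiled over their tile dimensions with transform.structured.tile_using_for,
// and the (tiled) Winograd ops are then rewritten into small matmul sequences
// by transform.structured.decompose_winograd_op. The first case is the aligned
// F(4, 3) base case; the later cases cover an unaligned shape that needs
// padding and degenerate unit-width (1-D) variants.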

func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
  %0 = tensor.empty() : tensor<6x6x5x2xf32>
  %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
  %2 = tensor.empty() : tensor<6x6x2x2x2x5xf32>
  %3 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%2 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
  %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
  %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x2x2x2x5xf32> into tensor<36x8x5xf32>
  %4 = tensor.empty() : tensor<36x8x2xf32>
  %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x8x5xf32>, tensor<36x5x2xf32>) outs(%4 : tensor<36x8x2xf32>) -> tensor<36x8x2xf32>
  %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32>
  %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x2x2x2x2xf32>) outs(%arg2 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
  return %6 : tensor<2x8x8x2xf32>
}

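// The same transform script is reused for every test case in this file: match
// the Winograd ops, tile the input and output transforms by 1 along the two
// tile dimensions, then decompose the filter transform and the tiled
// input/output transforms into matmul-based loop nests.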
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op)
    %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
    %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
    %10 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %5 : (!transform.any_op) -> !transform.any_op
    %11 = transform.structured.decompose_winograd_op %10 : (!transform.any_op) -> (!transform.any_op)
    transform.yield
  }
}

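// In the decomposed form below, the dense constants correspond to the Winograd
// transform matrices: the filter tile is multiplied by the 6x3 and 3x6
// constants, the input tile by two 6x6 constants, and the output tile by the
// 4x6 and 6x4 constants, with the scalar 1.024000e+03 applied as the final
// scaling factor through a linalg.generic that accumulates into the output
// slice.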
// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func.func @conv2d
// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
// CHECK:  %[[CST:.*]] = arith.constant 1.024000e+03 : f32
// CHECK:  %[[CST_0:.*]] = arith.constant dense<{{.*}}> : tensor<6x4xf32>
// CHECK:  %[[CST_1:.*]] = arith.constant dense<{{.*}}> : tensor<4x6xf32>
// CHECK:  %[[CST_2:.*]] = arith.constant dense<{{.*}}> : tensor<6x6xf32>
// CHECK:  %[[CST_3:.*]] = arith.constant dense<{{.*}}> : tensor<6x6xf32>
// CHECK:  %[[CST_4:.*]] = arith.constant dense<{{.*}}> : tensor<3x6xf32>
// CHECK:  %[[CST_5:.*]] = arith.constant dense<{{.*}}> : tensor<6x3xf32>
// CHECK:  %[[CST_6:.*]] = arith.constant 0.000000e+00 : f32
// CHECK:  %[[C1:.*]] = arith.constant 1 : index
// CHECK:  %[[C5:.*]] = arith.constant 5 : index
// CHECK:  %[[C2:.*]] = arith.constant 2 : index
// CHECK:  %[[C0:.*]] = arith.constant 0 : index
// CHECK:  %[[S0:.*]] = tensor.empty()
// CHECK:  %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]])
// CHECK:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
// CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1]
// CHECK:      %[[S10:.*]] = tensor.empty() : tensor<6x3xf32>
// CHECK:      %[[S11:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S10]] : tensor<6x3xf32>) -> tensor<6x3xf32>
// CHECK:      %[[S12:.*]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE]] : tensor<6x3xf32>, tensor<3x3xf32>) outs(%[[S11]] : tensor<6x3xf32>) -> tensor<6x3xf32>
// CHECK:      %[[S13:.*]] = tensor.empty() : tensor<6x6xf32>
// CHECK:      %[[S14:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S13]] : tensor<6x6xf32>) -> tensor<6x6xf32>
// CHECK:      %[[S15:.*]] = linalg.matmul ins(%[[S12]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%[[S14]] : tensor<6x6xf32>) -> tensor<6x6xf32>
// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1]
// CHECK:      scf.yield %[[INSERTED_SLICE]]
// CHECK:    scf.yield %[[S9]]
// CHECK:  %[[S2:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
// CHECK:  %[[S4:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]])
// CHECK:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
// CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
// CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
// CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
// CHECK:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
// CHECK:      %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]])
// CHECK:        %[[S13:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
// CHECK:          %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1]
// CHECK:          %[[S14:.*]] = tensor.empty() : tensor<6x6xf32>
// CHECK:          %[[S15:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S14]] : tensor<6x6xf32>) -> tensor<6x6xf32>
// CHECK:          %[[S16:.*]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_8]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S15]] : tensor<6x6xf32>) -> tensor<6x6xf32>
// CHECK:          %[[S17:.*]] = tensor.empty() : tensor<6x6xf32>
// CHECK:          %[[S18:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S17]] : tensor<6x6xf32>) -> tensor<6x6xf32>
// CHECK:          %[[S19:.*]] = linalg.matmul ins(%[[S16]], %[[CST_2]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S18]] : tensor<6x6xf32>) -> tensor<6x6xf32>
// CHECK:          %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S19]] into %[[ARG10]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
// CHECK:          scf.yield %[[INSERTED_SLICE_9]]
// CHECK:        scf.yield %[[S13]]
// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
// CHECK:      scf.yield %[[INSERTED_SLICE]]
// CHECK:    scf.yield %[[S9]]
// CHECK:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
// CHECK:  %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
// CHECK:  %[[S7:.*]] = tensor.empty()
// CHECK:  %[[S6:.*]] = linalg.batch_matmul
// CHECK:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2]
// CHECK:  %[[S8:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[ARG2]])
// CHECK:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
// CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
// CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
// CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
// CHECK:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG6]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
// CHECK:      %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]])
// CHECK:        %[[S15:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
// CHECK:          %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
// CHECK:          %[[S25:.*]] = tensor.extract_slice %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
// CHECK:          %[[S16:.*]] = tensor.empty() : tensor<4x6xf32>
// CHECK:          %[[S17:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S16]] : tensor<4x6xf32>) -> tensor<4x6xf32>
// CHECK:          %[[S18:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_8]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S17]] : tensor<4x6xf32>) -> tensor<4x6xf32>
// CHECK:          %[[S19:.*]] = tensor.empty() : tensor<4x4xf32>
// CHECK:          %[[S20:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S19]] : tensor<4x4xf32>) -> tensor<4x4xf32>
// CHECK:          %[[S21:.*]] = linalg.matmul ins(%[[S18]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S20]] : tensor<4x4xf32>) -> tensor<4x4xf32>
// CHECK:          %[[S23:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S21]] : f32, tensor<4x4xf32>) outs(%[[S25]] : tensor<4x4xf32>) {
// CHECK:          ^bb0(%[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32):
// CHECK:             %[[VAL_90:.*]] = arith.mulf %[[IN1]], %[[IN2]] : f32
// CHECK:             %[[VAL_91:.*]] = arith.addf %[[VAL_90]], %[[OUT]] : f32
// CHECK:             linalg.yield %[[VAL_91]] : f32
// CHECK:          } -> tensor<4x4xf32>
// CHECK:          %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S23]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
// CHECK:          scf.yield %[[INSERTED_SLICE_9]]
// CHECK:        scf.yield %[[S15]]
// CHECK:      %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
// CHECK:      %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG6]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
// CHECK:      scf.yield %[[INSERTED_SLICE]]
// CHECK:    scf.yield %[[S9]]

// -----

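// Unaligned case: the 11x11 input and 9x9 output are not multiples of the tile
// size, so the input is padded to 14x14, the output to 12x12 (a 3x3 grid of
// tiles), and the valid 2x9x9x2 result is extracted at the end.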
func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<6x6x5x2xf32>
  %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
  %padded = tensor.pad %arg0 low[0, 0, 0, 0] high[0, 3, 3, 0] {
  ^bb0(%arg4: index, %arg5: index, %arg6: index, %arg7: index):
    tensor.yield %cst : f32
  } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
  %2 = tensor.empty() : tensor<6x6x3x3x2x5xf32>
  %3 = linalg.winograd_input_transform m(4) r(3) ins(%padded : tensor<2x14x14x5xf32>) outs(%2 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
  %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
  %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
  %4 = tensor.empty() : tensor<36x18x2xf32>
  %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
  %6 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%5 : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
  %expanded = tensor.expand_shape %6 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
  %padded_1 = tensor.pad %arg2 low[0, 0, 0, 0] high[0, 3, 3, 0] {
  ^bb0(%arg4: index, %arg5: index, %arg6: index, %arg7: index):
    tensor.yield %cst : f32
  } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
  %7 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x3x3x2x2xf32>) outs(%padded_1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
  %extracted_slice = tensor.extract_slice %7[0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
  return %extracted_slice : tensor<2x9x9x2xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op)
    %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
    %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
    %10 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %5 : (!transform.any_op) -> !transform.any_op
    %11 = transform.structured.decompose_winograd_op %10 : (!transform.any_op) -> (!transform.any_op)
    transform.yield
  }
}

// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func.func @conv2d_unaligned
// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
// CHECK:  %[[CST:.*]] = arith.constant 1.024000e+03 : f32
// CHECK:  %[[CST_0:.*]] = arith.constant dense<{{.*}}> : tensor<6x4xf32>
// CHECK:  %[[CST_1:.*]] = arith.constant dense<{{.*}}> : tensor<4x6xf32>
// CHECK:  %[[CST_2:.*]] = arith.constant dense<{{.*}}> : tensor<6x6xf32>
// CHECK:  %[[CST_3:.*]] = arith.constant dense<{{.*}}> : tensor<6x6xf32>
// CHECK:  %[[C3:.*]] = arith.constant 3 : index
// CHECK:  %[[CST_4:.*]] = arith.constant dense<{{.*}}> : tensor<3x6xf32>
// CHECK:  %[[CST_5:.*]] = arith.constant dense<{{.*}}> : tensor<6x3xf32>
// CHECK:  %[[C1:.*]] = arith.constant 1 : index
// CHECK:  %[[C5:.*]] = arith.constant 5 : index
// CHECK:  %[[C2:.*]] = arith.constant 2 : index
// CHECK:  %[[C0:.*]] = arith.constant 0 : index
// CHECK:  %[[CST_6:.*]] = arith.constant 0.000000e+00 : f32
// CHECK:  %[[S0:.*]] = tensor.empty()
// CHECK:  %[[S1:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S0]])
// CHECK:    %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
// CHECK:      %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 3, 3, 1] [1, 1, 1, 1]
// CHECK:      %[[S11:.*]] = tensor.empty() : tensor<6x3xf32>
// CHECK:      %[[S12:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S11]] : tensor<6x3xf32>) -> tensor<6x3xf32>
// CHECK:      %[[S13:.*]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE_9]] : tensor<6x3xf32>, tensor<3x3xf32>) outs(%[[S12]] : tensor<6x3xf32>) -> tensor<6x3xf32>
// CHECK:      %[[S14:.*]] = tensor.empty() : tensor<6x6xf32>
// CHECK:      %[[S15:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S14]] : tensor<6x6xf32>) -> tensor<6x6xf32>
// CHECK:      %[[S16:.*]] = linalg.matmul ins(%[[S13]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%[[S15]] : tensor<6x6xf32>) -> tensor<6x6xf32>
// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S16]] into %[[ARG7]][0, 0, %[[ARG6]], %[[ARG4]]] [6, 6, 1, 1] [1, 1, 1, 1]
// CHECK:      scf.yield %[[INSERTED_SLICE]] : tensor<6x6x5x2xf32>
// CHECK:    scf.yield %[[S9]] : tensor<6x6x5x2xf32>
// CHECK:  %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0]
// CHECK:  %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
// CHECK:  %[[S4:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S2]])
// CHECK:    %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
// CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
// CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
// CHECK:      %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
// CHECK:      %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[ARG7]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
// CHECK:      %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]])
// CHECK:        %[[S13:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]])
// CHECK:          %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 6, 6, 1] [1, 1, 1, 1]
// CHECK:          %[[S15:.*]] = tensor.empty() : tensor<6x6xf32>
// CHECK:          %[[S16:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S15]] : tensor<6x6xf32>) -> tensor<6x6xf32>
// CHECK:          %[[S17:.*]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_11]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S16]] : tensor<6x6xf32>) -> tensor<6x6xf32>
// CHECK:          %[[S18:.*]] = tensor.empty() : tensor<6x6xf32>
// CHECK:          %[[S19:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S18]] : tensor<6x6xf32>) -> tensor<6x6xf32>
// CHECK:          %[[S20:.*]] = linalg.matmul ins(%[[S17]], %[[CST_2]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S19]] : tensor<6x6xf32>) -> tensor<6x6xf32>
// CHECK:          %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S20]] into %[[ARG11]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
// CHECK:          scf.yield %[[INSERTED_SLICE_12]] : tensor<6x6x1x1x2x5xf32>
// CHECK:        scf.yield %[[S13]] : tensor<6x6x1x1x2x5xf32>
// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
// CHECK:      scf.yield %[[INSERTED_SLICE]]
// CHECK:    scf.yield %[[S9]]
// CHECK:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
// CHECK:  %[[COLLAPSED_7:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
// CHECK:  %[[S7:.*]] = tensor.empty()
// CHECK:  %[[S6:.*]] = linalg.batch_matmul
// CHECK:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2]
// CHECK:  %[[PADDED_8:.*]] = tensor.pad %[[ARG2]] low[0, 0, 0, 0] high[0, 3, 3, 0]
// CHECK:  %[[S8:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[PADDED_8]])
// CHECK:    %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
// CHECK:      %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
// CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
// CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
// CHECK:      %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[ARG7]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
// CHECK:      %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]])
// CHECK:        %[[S15:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]])
// CHECK:          %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
// CHECK:          %[[S26:.*]] = tensor.extract_slice %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1]
// CHECK:          %[[S17:.*]] = tensor.empty() : tensor<4x6xf32>
// CHECK:          %[[S18:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S17]] : tensor<4x6xf32>) -> tensor<4x6xf32>
// CHECK:          %[[S19:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_11]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S18]] : tensor<4x6xf32>) -> tensor<4x6xf32>
// CHECK:          %[[S20:.*]] = tensor.empty() : tensor<4x4xf32>
// CHECK:          %[[S21:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S20]] : tensor<4x4xf32>) -> tensor<4x4xf32>
// CHECK:          %[[S22:.*]] = linalg.matmul ins(%[[S19]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S21]] : tensor<4x4xf32>) -> tensor<4x4xf32>
// CHECK:          %[[S24:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S22]] : f32, tensor<4x4xf32>) outs(%[[S26]] : tensor<4x4xf32>) {
// CHECK:          ^bb0(%[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32):
// CHECK:             %[[VAL_104:.*]] = arith.mulf %[[IN1]], %[[IN2]] : f32
// CHECK:             %[[VAL_105:.*]] = arith.addf %[[VAL_104]], %[[OUT]] : f32
// CHECK:             linalg.yield %[[VAL_105]] : f32
// CHECK:          } -> tensor<4x4xf32>
// CHECK:          %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S24]] into %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1]
// CHECK:          scf.yield %[[INSERTED_SLICE_12]]
// CHECK:        scf.yield %[[S15]] : tensor<2x4x4x2xf32>
// CHECK:      %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
// CHECK:      %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
// CHECK:      scf.yield %[[INSERTED_SLICE]]
// CHECK:    scf.yield %[[S9]]
// CHECK:  %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1]
// CHECK:  return %[[EXTRACTED_SLICE]]

// -----

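// Degenerate 1-D case: the filter and input have a unit W dimension, so each
// transform reduces to a single matmul against the 1-D Winograd constants and
// the final scaling factor becomes 3.200000e+01.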
func.func @conv2d_mx1_rx1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<6x1x5x2xf32>
  %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x1x5xf32>) outs(%0 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
  %2 = tensor.empty() : tensor<6x1x1x1x2x5xf32>
  %3 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x6x1x5xf32>) outs(%2 : tensor<6x1x1x1x2x5xf32>) -> tensor<6x1x1x1x2x5xf32>
  %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
  %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x1x2x5xf32> into tensor<6x2x5xf32>
  %4 = tensor.empty() : tensor<6x2x2xf32>
  %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
  %6 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%5 : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
  %expanded = tensor.expand_shape %6 [[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
  %7 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x1x1x1x2x2xf32>) outs(%arg2 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
  return %7 : tensor<2x4x1x2xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op)
    %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
    %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
    %10 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %5 : (!transform.any_op) -> !transform.any_op
    %11 = transform.structured.decompose_winograd_op %10 : (!transform.any_op) -> (!transform.any_op)
    transform.yield
  }
}

// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> ()>
// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func.func @conv2d_mx1_rx1
// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
// CHECK:   %[[CST:.*]] = arith.constant 3.200000e+01 : f32
// CHECK:   %[[CST_0:.*]] = arith.constant dense<{{.*}}> : tensor<4x6xf32>
// CHECK:   %[[CST_1:.*]] = arith.constant dense<{{.*}}> : tensor<6x6xf32>
// CHECK:   %[[CST_2:.*]] = arith.constant dense<{{.*}}> : tensor<6x3xf32>
// CHECK:   %[[C1:.*]] = arith.constant 1 : index
// CHECK:   %[[C5:.*]] = arith.constant 5 : index
// CHECK:   %[[C2:.*]] = arith.constant 2 : index
// CHECK:   %[[C0:.*]] = arith.constant 0 : index
// CHECK:   %[[CST_3:.*]] = arith.constant 0.000000e+00 : f32
// CHECK:   %[[S0:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
// CHECK:   %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]])
// CHECK:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
// CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 1, 1] [1, 1, 1, 1]
// CHECK:       %[[S8:.*]] = tensor.empty() : tensor<6x1xf32>
// CHECK:       %[[S9:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S8]] : tensor<6x1xf32>) -> tensor<6x1xf32>
// CHECK:       %[[S10:.*]] = linalg.matmul ins(%[[CST_2]], %[[EXTRACTED_SLICE]] : tensor<6x3xf32>, tensor<3x1xf32>) outs(%[[S9]] : tensor<6x1xf32>) -> tensor<6x1xf32>
// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S10]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 1, 1, 1] [1, 1, 1, 1]
// CHECK:       scf.yield %[[INSERTED_SLICE]]
// CHECK:     scf.yield %[[S7]]
// CHECK:   %[[S2:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
// CHECK:   %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]])
// CHECK:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
// CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 6, 1, 1] [1, 1, 1, 1]
// CHECK:       %[[S8:.*]] = tensor.empty() : tensor<6x1xf32>
// CHECK:       %[[S9:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S8]] : tensor<6x1xf32>) -> tensor<6x1xf32>
// CHECK:       %[[S10:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE]] : tensor<6x6xf32>, tensor<6x1xf32>) outs(%[[S9]] : tensor<6x1xf32>) -> tensor<6x1xf32>
// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S10]] into %[[ARG6]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
// CHECK:       scf.yield %[[INSERTED_SLICE]]
// CHECK:     scf.yield %[[S7]]
// CHECK:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
// CHECK:   %[[COLLAPSED_3:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]]
// CHECK:   %[[S4:.*]] = tensor.empty() : tensor<6x2x2xf32>
// CHECK:   %[[S5:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S4]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
// CHECK:   %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_3]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S5]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
// CHECK:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2]
// CHECK:   %[[S6:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[ARG2]])
// CHECK:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
// CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
// CHECK:       %[[S15:.*]] = tensor.extract_slice %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 1, 1] [1, 1, 1, 1]
// CHECK:       %[[S9:.*]] = tensor.empty() : tensor<4x1xf32>
// CHECK:       %[[S10:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S9]] : tensor<4x1xf32>) -> tensor<4x1xf32>
// CHECK:       %[[S11:.*]] = linalg.matmul ins(%[[CST_0]], %[[EXTRACTED_SLICE]] : tensor<4x6xf32>, tensor<6x1xf32>) outs(%[[S10]] : tensor<4x1xf32>) -> tensor<4x1xf32>
// CHECK:       %[[S13:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S11]] : f32, tensor<4x1xf32>) outs(%[[S15]] : tensor<4x1xf32>) {
// CHECK:       ^bb0(%[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32):
// CHECK:          %[[VAL_57:.*]] = arith.mulf %[[IN1]], %[[IN2]] : f32
// CHECK:          %[[VAL_58:.*]] = arith.addf %[[VAL_57]], %[[OUT]] : f32
// CHECK:          linalg.yield %[[VAL_58]] : f32
// CHECK:       } -> tensor<4x1xf32>
// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 1, 1] [1, 1, 1, 1]
// CHECK:       scf.yield %[[INSERTED_SLICE]]
// CHECK:     scf.yield %[[S7]]
// CHECK:   return %[[S6]]

// -----

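// Same unit-width filter as above, but with two tile positions along W
// (input 2x6x2x5, output 2x4x2x2); after canonicalization only the loop over
// the two W tiles remains around the tiled input and output transforms.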
func.func @conv2d_mx1_rx1_2(%arg0: tensor<2x6x2x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<2x4x2x2xf32>) -> tensor<2x4x2x2xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<6x1x5x2xf32>
  %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x1x5xf32>) outs(%0 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
  %2 = tensor.empty() : tensor<6x1x1x2x2x5xf32>
  %3 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x6x2x5xf32>) outs(%2 : tensor<6x1x1x2x2x5xf32>) -> tensor<6x1x1x2x2x5xf32>
  %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
  %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x2x2x5xf32> into tensor<6x4x5xf32>
  %4 = tensor.empty() : tensor<6x4x2xf32>
  %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<6x4x2xf32>) -> tensor<6x4x2xf32>
  %6 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<6x4x5xf32>, tensor<6x5x2xf32>) outs(%5 : tensor<6x4x2xf32>) -> tensor<6x4x2xf32>
  %expanded = tensor.expand_shape %6 [[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 2, 2, 2] : tensor<6x4x2xf32> into tensor<6x1x1x2x2x2xf32>
  %7 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x1x1x2x2x2xf32>) outs(%arg2 : tensor<2x4x2x2xf32>) -> tensor<2x4x2x2xf32>
  return %7 : tensor<2x4x2x2xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op)
    %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
    %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
    %10 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %5 : (!transform.any_op) -> !transform.any_op
    %11 = transform.structured.decompose_winograd_op %10 : (!transform.any_op) -> (!transform.any_op)
    transform.yield
  }
}

// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> ()>
// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func.func @conv2d_mx1_rx1_2
// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x2x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<2x4x2x2xf32>) -> tensor<2x4x2x2xf32> {
// CHECK:   %[[CST:.*]] = arith.constant 3.200000e+01 : f32
// CHECK:   %[[CST_0:.*]] = arith.constant dense<{{.*}}> : tensor<4x6xf32>
// CHECK:   %[[CST_1:.*]] = arith.constant dense<{{.*}}> : tensor<6x6xf32>
// CHECK:   %[[CST_2:.*]] = arith.constant dense<{{.*}}> : tensor<6x3xf32>
// CHECK:   %[[C1:.*]] = arith.constant 1 : index
// CHECK:   %[[C5:.*]] = arith.constant 5 : index
// CHECK:   %[[C2:.*]] = arith.constant 2 : index
// CHECK:   %[[C0:.*]] = arith.constant 0 : index
// CHECK:   %[[CST_3:.*]] = arith.constant 0.000000e+00 : f32
// CHECK:   %[[S0:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
// CHECK:   %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]])
// CHECK:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
// CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 1, 1] [1, 1, 1, 1]
// CHECK:       %[[S8:.*]] = tensor.empty() : tensor<6x1xf32>
// CHECK:       %[[S9:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S8]] : tensor<6x1xf32>) -> tensor<6x1xf32>
// CHECK:       %[[S10:.*]] = linalg.matmul ins(%[[CST_2]], %[[EXTRACTED_SLICE]] : tensor<6x3xf32>, tensor<3x1xf32>) outs(%[[S9]] : tensor<6x1xf32>) -> tensor<6x1xf32>
// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S10]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 1, 1, 1] [1, 1, 1, 1]
// CHECK:       scf.yield %[[INSERTED_SLICE]]
// CHECK:     scf.yield %[[S7]]
// CHECK:   %[[S2:.*]] = tensor.empty() : tensor<6x1x1x2x2x5xf32>
// CHECK:   %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]])
// CHECK:     %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG3]], 0] [2, 6, 1, 5] [1, 1, 1, 1]
// CHECK:     %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG4]][0, 0, 0, %[[ARG3]], 0, 0] [6, 1, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
// CHECK:     %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[EXTRACTED_SLICE_5]])
// CHECK:     %[[S10:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]])
// CHECK:       %[[EXTRACTED_SLICE_6:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG5]], 0, 0, %[[ARG7]]] [1, 6, 1, 1] [1, 1, 1, 1]
// CHECK:       %[[S11:.*]] = tensor.empty() : tensor<6x1xf32>
// CHECK:       %[[S12:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S11]] : tensor<6x1xf32>) -> tensor<6x1xf32>
// CHECK:       %[[S13:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_6]] : tensor<6x6xf32>, tensor<6x1xf32>) outs(%[[S12]] : tensor<6x1xf32>) -> tensor<6x1xf32>
// CHECK:       %[[INSERTED_SLICE_7:.*]] = tensor.insert_slice %[[S13]] into %[[ARG8]][0, 0, 0, 0, %[[ARG5]], %[[ARG7]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
// CHECK:       scf.yield %[[INSERTED_SLICE_7]]
// CHECK:     scf.yield %[[S10]]
// CHECK:    %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG4]][0, 0, 0, %[[ARG3]], 0, 0] [6, 1, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
// CHECK:    scf.yield %[[INSERTED_SLICE]]
// CHECK:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
// CHECK:   %[[COLLAPSED_4:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]]
// CHECK:   %[[S4:.*]] = tensor.empty() : tensor<6x4x2xf32>
// CHECK:   %[[S5:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S4]] : tensor<6x4x2xf32>) -> tensor<6x4x2xf32>
// CHECK:   %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_4]], %[[COLLAPSED]] : tensor<6x4x5xf32>, tensor<6x5x2xf32>) outs(%[[S5]] : tensor<6x4x2xf32>) -> tensor<6x4x2xf32>
// CHECK:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 2, 2, 2]
// CHECK:   %[[S7:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[ARG2]])
// CHECK:     %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, %[[ARG3]], 0, 0] [6, 1, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
// CHECK:     %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG4]][0, 0, %[[ARG3]], 0] [2, 4, 1, 2] [1, 1, 1, 1]
// CHECK:     %[[S8:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[EXTRACTED_SLICE_5]])
// CHECK:       %[[S9:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]])
// CHECK:       %[[EXTRACTED_SLICE_6:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, %[[ARG5]], %[[ARG7]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
// CHECK:       %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG8]][%[[ARG5]], 0, 0, %[[ARG7]]] [1, 4, 1, 1] [1, 1, 1, 1]
// CHECK:       %[[S10:.*]] = tensor.empty() : tensor<4x1xf32>
// CHECK:       %[[S11:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S10]] : tensor<4x1xf32>) -> tensor<4x1xf32>
// CHECK:       %[[S12:.*]] = linalg.matmul ins(%[[CST_0]], %[[EXTRACTED_SLICE_6]] : tensor<4x6xf32>, tensor<6x1xf32>) outs(%[[S11]] : tensor<4x1xf32>) -> tensor<4x1xf32>
// CHECK:       %[[S13:.*]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S12]] : f32, tensor<4x1xf32>) outs(%[[EXTRACTED_SLICE_7]] : tensor<4x1xf32>) {
// CHECK:       ^bb0(%[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32):
// CHECK:          %[[VAL_57:.*]] = arith.mulf %[[IN1]], %[[IN2]] : f32
// CHECK:          %[[VAL_58:.*]] = arith.addf %[[VAL_57]], %[[OUT]] : f32
// CHECK:          linalg.yield %[[VAL_58]] : f32
// CHECK:       } -> tensor<4x1xf32>
// CHECK:       %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[S13]] into %[[ARG8]][%[[ARG5]], 0, 0, %[[ARG7]]] [1, 4, 1, 1] [1, 1, 1, 1]
// CHECK:       scf.yield %[[INSERTED_SLICE_8]]
// CHECK:     scf.yield %[[S9]]
// CHECK:     %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S8]] into %[[ARG4]][0, 0, %[[ARG3]], 0] [2, 4, 1, 2] [1, 1, 1, 1]
// CHECK:     scf.yield %[[INSERTED_SLICE]]
// CHECK:   return %[[S7]]
448