//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implement the Winograd Conv2D algorithm. The implementation is based on the
// paper: Fast Algorithms for Convolutional Neural Networks
// (https://arxiv.org/abs/1509.09308)
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/MathExtras.h"

namespace mlir {
namespace linalg {

namespace {

// clang-format off
/// Winograd Conv2D uses a minimal 2D filtering algorithm to calculate its
/// result. The formula of the minimal 2D filtering algorithm F(m x m, r x r),
/// where m is the output dimension and r is the filter dimension, is
///
/// Y = A^T x [ (G x g x G^T) x (B^T x d x B) ] x A
///
/// where g is the filter and d is the input data. We need to prepare 6
/// constant transformation matrices, G, G^T, B^T, B, A^T, and A for this
/// formula.
///
/// The following tables define these constant transformation matrices for
/// F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5).
constexpr float G_2x2_3x3[] = {
   -1,     0,   0,
 1./2, -1./2, 1./2,
 1./2,  1./2, 1./2,
    0,     0,    1
};

constexpr float GT_2x2_3x3[] = {
   -1,  1./2, 1./2, 0,
    0, -1./2, 1./2, 0,
    0,  1./2, 1./2, 1
};

constexpr float BT_2x2_3x3[] = {
   -1,    0,   1,   0,
    0,   -1,   1,   0,
    0,    1,   1,   0,
    0,   -1,   0,   1
};

constexpr float B_2x2_3x3[] = {
   -1,    0,   0,   0,
    0,   -1,   1,  -1,
    1,    1,   1,   0,
    0,    0,   0,   1
};

constexpr float AT_2x2_3x3[] = {
    1,    1,   1,   0,
    0,   -1,   1,   1
};

constexpr float A_2x2_3x3[] = {
    1,    0,
    1,   -1,
    1,    1,
    0,    1
};

constexpr float G_4x4_3x3[] = {
     1,     0,     0,
 -1./3,  1./3, -1./3,
 -1./3, -1./3, -1./3,
 1./12, -1./6,  1./3,
 1./12,  1./6,  1./3,
     0,     0,     1
};

constexpr float GT_4x4_3x3[] = {
 1,  -1./3, -1./3, 1./12, 1./12, 0,
 0,   1./3, -1./3, -1./6,  1./6, 0,
 0,  -1./3, -1./3,  1./3,  1./3, 1
};

constexpr float BT_4x4_3x3[] = {
 1./4,     0, -5./16,      0, 1./16,     0,
    0,  1./4,  -1./4, -1./16, 1./16,     0,
    0, -1./4,  -1./4,  1./16, 1./16,     0,
    0,  1./4,  -1./8,  -1./4,  1./8,     0,
    0, -1./4,  -1./8,   1./4,  1./8,     0,
    0,  1./4,      0, -5./16,     0, 1./16
};

constexpr float B_4x4_3x3[] = {
   1./4,      0,     0,     0,     0,      0,
      0,   1./4, -1./4,  1./4, -1./4,   1./4,
 -5./16,  -1./4, -1./4, -1./8, -1./8,      0,
      0, -1./16, 1./16, -1./4,  1./4, -5./16,
  1./16,  1./16, 1./16,  1./8,  1./8,      0,
      0,      0,     0,     0,     0,  1./16
};

constexpr float AT_4x4_3x3[] = {
 1./8,  1./4, 1./4,  1./8, 1./8,    0,
    0, -1./4, 1./4, -1./4, 1./4,    0,
    0,  1./4, 1./4,  1./2, 1./2,    0,
    0, -1./4, 1./4,    -1,    1, 1./2
};

constexpr float A_4x4_3x3[] = {
  1./8,     0,    0,     0,
  1./4, -1./4, 1./4, -1./4,
  1./4,  1./4, 1./4,  1./4,
  1./8, -1./4, 1./2,    -1,
  1./8,  1./4, 1./2,     1,
     0,     0,    0,  1./2
};

constexpr float G_2x2_5x5[] = {
     1,     0,      0,      0,      0,
  1./6, -1./6,   1./6,  -1./6,   1./6,
 -1./6, -1./6,  -1./6,  -1./6,  -1./6,
-4./15, 2./15, -1./15,  1./30, -1./60,
 1./60, 1./30,  1./15,  2./15,  4./15,
     0,     0,      0,      0,      1
};

constexpr float GT_2x2_5x5[] = {
   1,  1./6, -1./6, -4./15, 1./60, 0,
   0, -1./6, -1./6,  2./15, 1./30, 0,
   0,  1./6, -1./6, -1./15, 1./15, 0,
   0, -1./6, -1./6,  1./30, 2./15, 0,
   0,  1./6, -1./6, -1./60, 4./15, 1
};

constexpr float BT_2x2_5x5[] = {
 1./8,  3./16,  -1./4,  -3./16,   1./8,    0,
    0,   1./8,  1./16,  -5./16,   1./8,    0,
    0,  -1./8, -5./16,  -1./16,   1./8,    0,
    0,   1./4,  -1./8,   -1./4,   1./8,    0,
    0,  -1./8,  -1./4,    1./8,   1./4,    0,
    0,   1./8,  3./16,   -1./4, -3./16, 1./8
};

constexpr float B_2x2_5x5[] = {
   1./8,      0,      0,     0,     0,      0,
  3./16,   1./8,  -1./8,  1./4, -1./8,   1./8,
  -1./4,  1./16, -5./16, -1./8, -1./4,  3./16,
 -3./16, -5./16, -1./16, -1./4,  1./8,  -1./4,
   1./8,   1./8,   1./8,  1./8,  1./4, -3./16,
      0,      0,      0,     0,     0,   1./8
};

constexpr float AT_2x2_5x5[] = {
  1./2,  1, 1,  2, 1,    0,
     0, -1, 1, -1, 2, 1./2
};

constexpr float A_2x2_5x5[] = {
 1./2,    0,
    1,   -1,
    1,    1,
    2,   -1,
    1,    2,
    0, 1./2
};
// clang-format on

using TransformMapKeyTy = std::pair<int, int>;

/// We use F(m, r) to define the size of minimal filtering algorithms, where
/// m is the output dimension and r is the filter dimension. We can get the
/// input dimension, alpha, from the formula alpha = m + r - 1.
///
/// For example, when m = 2 and r = 3, the input size is 4. The Conv2D will
/// operate on 4x4 input data with a 3x3 filter and get a 2x2 output result.
constexpr TransformMapKeyTy F_2_3{2, 3};
constexpr TransformMapKeyTy F_4_3{4, 3};
constexpr TransformMapKeyTy F_2_5{2, 5};

/// Structure to keep the information of a constant transform matrix.
struct TransformMatrix {
  TransformMatrix(const float *table, int64_t rows, int64_t cols,
                  int64_t scalarFactor = 1)
      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}

  const float *table;
  int64_t rows;
  int64_t cols;
  int64_t scalarFactor;
};

/// Utility function to convert a constant array to an arith.constant Value.
Value create2DTransformMatrix(OpBuilder &builder, Location loc,
                              TransformMatrix transform, Type type) {
  ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);

  return builder.create<arith::ConstantOp>(
      loc, DenseFPElementsAttr::get(
               RankedTensorType::get(
                   SmallVector<int64_t>{transform.rows, transform.cols}, type),
               constVec));
}
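
// Illustrative example (not part of the pass): for the F(2, 3) G matrix above
// with an f32 element type, this utility materializes a constant like
//   %g = arith.constant dense<[[-1.0, 0.0, 0.0], [0.5, -0.5, 0.5],
//                              [0.5, 0.5, 0.5], [0.0, 0.0, 1.0]]>
//        : tensor<4x3xf32>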

/// Extract height x width data from 4D tensors.
Value extract2DDataFrom4D(OpBuilder &builder, Location loc, Value source,
                          Value loopNorFIndex, Value loopCorFIndex,
                          Value heightOffset, Value widthOffset,
                          int64_t extractHeight, int64_t extractWidth,
                          int64_t loopNorFIdx, int64_t loopCorFIdx,
                          int64_t heightIdx, int64_t widthIdx) {
  auto sourceType = cast<ShapedType>(source.getType());
  Type elementType = sourceType.getElementType();
  int64_t srcSize = sourceType.getRank();

  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> offsets;
  offsets.resize(srcSize);
  offsets[loopNorFIdx] = loopNorFIndex;
  offsets[loopCorFIdx] = loopCorFIndex;
  offsets[heightIdx] = heightOffset;
  offsets[widthIdx] = widthOffset;
  SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
  sizes[heightIdx] = builder.getIndexAttr(extractHeight);
  sizes[widthIdx] = builder.getIndexAttr(extractWidth);
  SmallVector<OpFoldResult> strides(srcSize, oneIndex);

  auto extractType =
      RankedTensorType::get({extractHeight, extractWidth}, elementType);
  auto extractOp = builder.create<tensor::ExtractSliceOp>(
      loc, extractType, source, offsets, sizes, strides);

  return extractOp;
}
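
// Illustrative example (hypothetical shapes): extracting a 3x3 H x W slice
// from an FHWC filter tensor<2x3x3x5xf32> at indices (%f, 0, 0, %c) produces
// a rank-reducing slice like
//   %slice = tensor.extract_slice %filter[%f, 0, 0, %c] [1, 3, 3, 1]
//            [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<3x3xf32>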

/// Extract height x width data from 6D tensors.
Value extract2DDataFrom6D(OpBuilder &builder, Location loc, Value source,
                          Value tileHIndex, Value tileWIndex,
                          Value loopNorFIndex, Value loopCorFIndex,
                          int64_t tileHIdx, int64_t tileWIdx,
                          int64_t loopNorFIdx, int64_t loopCorFIdx,
                          int64_t heightIdx, int64_t widthIdx) {
  auto sourceType = cast<ShapedType>(source.getType());
  Type elementType = sourceType.getElementType();
  auto sourceShape = sourceType.getShape();
  int64_t srcSize = sourceType.getRank();
  int64_t height = sourceShape[heightIdx];
  int64_t width = sourceShape[widthIdx];

  auto zeroIndex = builder.getIndexAttr(0);
  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> offsets(srcSize, zeroIndex);
  offsets[tileHIdx] = tileHIndex;
  offsets[tileWIdx] = tileWIndex;
  offsets[loopNorFIdx] = loopNorFIndex;
  offsets[loopCorFIdx] = loopCorFIndex;
  SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
  sizes[heightIdx] = builder.getIndexAttr(height);
  sizes[widthIdx] = builder.getIndexAttr(width);
  SmallVector<OpFoldResult> strides(srcSize, oneIndex);

  auto extractType = RankedTensorType::get({height, width}, elementType);
  auto extractOp = builder.create<tensor::ExtractSliceOp>(
      loc, extractType, source, offsets, sizes, strides);

  return extractOp;
}

/// Insert transformed height x width data into the 4D tensor it was
/// extracted from.
Value insert2DDataTo4D(OpBuilder &builder, Location loc, Value source,
                       Value dest, Value loopNorFIndex, Value loopCorFIndex,
                       Value heightOffset, Value widthOffset, int64_t height,
                       int64_t width, int64_t loopNorFIdx, int64_t loopCorFIdx,
                       int64_t heightIdx, int64_t widthIdx) {
  int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> retOffsets;
  retOffsets.resize(destSize);
  retOffsets[loopNorFIdx] = loopNorFIndex;
  retOffsets[loopCorFIdx] = loopCorFIndex;
  retOffsets[heightIdx] = heightOffset;
  retOffsets[widthIdx] = widthOffset;
  SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
  retSizes[heightIdx] = builder.getIndexAttr(height);
  retSizes[widthIdx] = builder.getIndexAttr(width);
  SmallVector<OpFoldResult> strides(destSize, oneIndex);

  auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
      loc, source, dest, retOffsets, retSizes, strides);

  return insertSliceOp;
}
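
// Illustrative counterpart of the extract example above (hypothetical
// shapes): inserting a transformed tensor<4x4xf32> tile back into an
// H x W x C x F destination tensor<4x4x5x2xf32> at (0, 0, %c, %f) produces
//   %r = tensor.insert_slice %tile into %dest[0, 0, %c, %f] [4, 4, 1, 1]
//        [1, 1, 1, 1] : tensor<4x4xf32> into tensor<4x4x5x2xf32>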

/// Insert transformed height x width data into the 6D tensor it was
/// extracted from.
Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source,
                       Value dest, Value tileHIndex, Value tileWIndex,
                       Value loopNorFIndex, Value loopCorFIndex, int64_t height,
                       int64_t width, int64_t tileHIdx, int64_t tileWIdx,
                       int64_t loopNorFIdx, int64_t loopCorFIdx,
                       int64_t heightIdx, int64_t widthIdx) {
  int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
  auto zeroIndex = builder.getIndexAttr(0);
  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> retOffsets(destSize, zeroIndex);
  retOffsets[tileHIdx] = tileHIndex;
  retOffsets[tileWIdx] = tileWIndex;
  retOffsets[loopNorFIdx] = loopNorFIndex;
  retOffsets[loopCorFIdx] = loopCorFIndex;
  SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
  retSizes[heightIdx] = builder.getIndexAttr(height);
  retSizes[widthIdx] = builder.getIndexAttr(width);
  SmallVector<OpFoldResult> strides(destSize, oneIndex);

  auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
      loc, source, dest, retOffsets, retSizes, strides);

  return insertSliceOp;
}

/// This function transforms the filter. The data layout of the filter is FHWC.
/// The transformation matrix is 2-dimensional. We need to extract H x W from
/// FHWC first, so we generate 2 levels of loops to iterate on F and C.
/// After the transformation, we get
///
/// scf.for %f = lo_f to hi_f step 1
///   scf.for %c = lo_c to hi_c step 1
///     %extracted = extract filter<h x w> from filter<f x h x w x c>
///     %ret = linalg.matmul G, %extracted
///     %ret = linalg.matmul %ret, GT
///     %inserted = insert %ret into filter<h x w x c x f>
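///
/// For example (illustrative), with F(2, 3) the G matrix is 4x3 and GT is
/// 3x4, so a 3x3 filter tile g becomes a 4x4 transformed tile
/// u = G x g x GT, i.e. (m + r - 1) x (m + r - 1).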
Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
                      Value retValue, int64_t m, int64_t r,
                      bool leftTransform = true, bool rightTransform = true) {
  // Map from (m, r) to G transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      GMatrices = {
          {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
          {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
          {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
      };

  // Map from (m, r) to GT transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      GTMatrices = {
          {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
          {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
          {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
      };

  auto filterType = cast<ShapedType>(filter.getType());
  Type elementType = filterType.getElementType();
  auto filterShape = filterType.getShape(); // F, H, W, C
  int64_t filterF = filterShape[0];
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];
  int64_t filterC = filterShape[3];

  if (filterH != r && filterH != 1)
    return Value();
  if (filterW != r && filterW != 1)
    return Value();

  Value zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                       ValueRange args) -> scf::ValueVector {
    Value FIter = ivs[0];
    Value CIter = ivs[1];

    // Extract (H, W) from (F, H, W, C).
    auto extractFilter =
        extract2DDataFrom4D(builder, loc, filter, FIter, CIter, zeroIdx,
                            zeroIdx, filterH, filterW, /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);

    TransformMapKeyTy key = {m, r};
    int64_t retRows = 1;
    Value matmulRetValue = extractFilter;
    Value zero = builder.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementType));
    if (leftTransform) {
      // Get constant transform matrix G.
      auto it = GMatrices.find(key);
      if (it == GMatrices.end())
        return {};
      const TransformMatrix &GMatrix = it->second;

      retRows = GMatrix.rows;
      auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
      auto empty =
          builder
              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
              .getResult();
      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);

      Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType);
      // Multiply G x g.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (rightTransform) {
      // Get constant transform matrix GT.
      auto it = GTMatrices.find(key);
      if (it == GTMatrices.end())
        return {};
      const TransformMatrix &GTMatrix = it->second;

      auto matmulType =
          RankedTensorType::get({retRows, GTMatrix.cols}, elementType);
      auto empty =
          builder
              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
              .getResult();
      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);

      Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType);
      // Multiply u = (G x g) x GT.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{matmulRetValue, GT}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    // Insert (H, W) to (H, W, C, F).
    int64_t retHeight = leftTransform ? m + r - 1 : 1;
    int64_t retWidth = rightTransform ? m + r - 1 : 1;

    auto insertSliceOp =
        insert2DDataTo4D(builder, loc, matmulRetValue, args[0], FIter, CIter,
                         zeroIdx, zeroIdx, retHeight, retWidth,
                         /*loopNorFIdx=*/3, /*loopCorFIdx=*/2,
                         /*heightIdx=*/0, /*widthIdx=*/1);

    return {insertSliceOp};
  };

  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF);
  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC);
  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  scf::LoopNest loops = scf::buildLoopNest(
      rewriter, loc, {zeroIdx, zeroIdx}, {fUpperBound, cUpperBound},
      {oneStep, oneStep}, {retValue}, buildBody);
  return loops.results[0];
}

/// This function transforms the input. The data layout of the input is NHWC.
/// The transformation matrix is 2-dimensional. We need to extract H x W from
/// NHWC first, so we generate 4 levels of loops to iterate on the tiles and
/// on N and C. After the transformation, we get
///
/// scf.for %h = 0 to tileH step 1
///   scf.for %w = 0 to tileW step 1
///     scf.for %n = 0 to N step 1
///       scf.for %c = 0 to C step 1
///         %extracted = extract %extracted<alphaH x alphaW> from
///                              %input<N x H x W x C>
///                              at [%n, (%h x m), (%w x m), %c]
///         %ret = linalg.matmul BT, %extracted
///         %ret = linalg.matmul %ret, B
///         %inserted = insert %ret<alphaH x alphaW> into
///                            %output<alphaH x alphaW x tileH x tileW x N x C>
///                            at [0, 0, %h, %w, %n, %c]
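///
/// Note (illustrative): for F(4, 3), alpha = 6, so each iteration reads a
/// 6x6 input tile starting at (%h * 4, %w * 4); adjacent tiles overlap by
/// r - 1 = 2 rows/columns, as required by the Winograd algorithm.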
Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
                     Value retValue, int64_t m, int64_t r,
                     bool leftTransform = true, bool rightTransform = true) {
  // Map from (m, r) to BT transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      BTMatrices = {
          {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
          {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
          {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
      };

  // Map from (m, r) to B transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      BMatrices = {
          {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
          {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
          {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
      };

  auto inputType = cast<ShapedType>(input.getType());
  Type elementType = inputType.getElementType();
  auto inputShape = inputType.getShape(); // N, H, W, C
  int64_t inputN = inputShape[0];
  int64_t inputC = inputShape[3];
  auto valueType = cast<ShapedType>(retValue.getType());
  auto valueShape = valueType.getShape(); // alphaH, alphaW, HTile, WTile, N, C
  int64_t tileH = valueShape[2];
  int64_t tileW = valueShape[3];
  int64_t alphaH = leftTransform ? m + r - 1 : 1;
  int64_t alphaW = rightTransform ? m + r - 1 : 1;

  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                       ValueRange args) -> scf::ValueVector {
    Value tileHIter = ivs[0];
    Value tileWIter = ivs[1];
    Value NIter = ivs[2];
    Value CIter = ivs[3];

    auto context = builder.getContext();

    auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
    auto affineMap =
        AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
    Value heightOffset = builder.create<affine::AffineApplyOp>(
        loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
    Value widthOffset = builder.create<affine::AffineApplyOp>(
        loc, rightTransform ? affineMap : identityAffineMap, tileWIter);

    // Extract (H, W) from (N, H, W, C).
    auto extractInput =
        extract2DDataFrom4D(builder, loc, input, NIter, CIter, heightOffset,
                            widthOffset, alphaH, alphaW, /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);

    TransformMapKeyTy key = {m, r};
    int64_t retRows = 1;
    int64_t retCols = 1;
    Value matmulRetValue = extractInput;
    Value zero = builder.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementType));
    if (leftTransform) {
      // Get constant transform matrix BT.
      auto it = BTMatrices.find(key);
      if (it == BTMatrices.end())
        return {};
      const TransformMatrix &BTMatrix = it->second;

      retRows = BTMatrix.rows;
      auto matmulType = RankedTensorType::get({retRows, alphaW}, elementType);
      auto empty =
          builder
              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
              .getResult();
      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);

      Value BT =
          create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
      // Multiply BT x d.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (rightTransform) {
      // Get constant transform matrix B.
      auto it = BMatrices.find(key);
      if (it == BMatrices.end())
        return {};
      const TransformMatrix &BMatrix = it->second;

      retCols = BMatrix.cols;
      auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
      auto empty =
          builder
              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
              .getResult();
      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
      Value B =
          create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
      // Multiply v = (BT x d) x B.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{matmulRetValue, B}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    // Insert (H, W) to (H, W, tileH, tileW, N, C).
    auto combinedVal = insert2DDataTo6D(
        builder, loc, matmulRetValue, args[0], tileHIter, tileWIter, NIter,
        CIter, retRows, retCols, 2, 3, /*loopNorFIdx=*/4, /*loopCorFIdx=*/5,
        /*heightIdx=*/0, /*widthIdx=*/1);

    return {combinedVal};
  };

  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tileH);
  auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW);
  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputN);
  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputC);
  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  scf::LoopNest loops = scf::buildLoopNest(
      rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
      {tileHBound, tileWBound, nUpperBound, cUpperBound},
      {oneStep, oneStep, oneStep, oneStep}, {retValue}, buildBody);
  return loops.results[0];
}

/// This function generates linalg.batch_matmul to multiply the input with the
/// filter. linalg.batch_matmul only supports 3-dimensional inputs, so we
/// collapse the leading (alphaH, alphaW) dimensions into a single batch
/// dimension and (tileH, tileW, N) into a single row dimension. That is, we
/// convert [alphaH, alphaW, tileH, tileW, N, C] to
/// [alphaH x alphaW, tileH x tileW x N, C]. In this way, we can convert the
/// 6-dimensional input to a 3-dimensional representation that is suitable
/// for linalg.batch_matmul.
///
/// The batched matmul does the matrix multiply with the reduction on channel.
///
/// We get
///
/// %collapsed_input = tensor.collapse_shape %input
/// %collapsed_filter = tensor.collapse_shape %filter
/// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
/// %expanded_ret = tensor.expand_shape %ret
///
/// After this function, we get the return value with data layout
/// (alphaH, alphaW, tileH, tileW, N, F).
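///
/// Worked example (illustrative shapes): with F(4, 3), a 1x8x8x32 input, and
/// 64 filters, the transformed input tensor<6x6x2x2x1x32xf32> collapses to
/// tensor<36x4x32xf32> and the transformed filter tensor<6x6x32x64xf32>
/// collapses to tensor<36x32x64xf32>; linalg.batch_matmul then yields
/// tensor<36x4x64xf32>, which expands back to tensor<6x6x2x2x1x64xf32>.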
static Value matrixMultiply(RewriterBase &rewriter, Location loc,
                            Value transformedFilter, Value transformedInput,
                            Type outputElementType) {
  // Convert (alphaH, alphaW, C, F) to (alphaH x alphaW, C, F) for filter.
  auto filterType = cast<ShapedType>(transformedFilter.getType());
  assert(filterType.hasStaticShape() && "only support static shapes.");
  ArrayRef<int64_t> filterShape = filterType.getShape();
  Type filterElementType = filterType.getElementType();
  auto filterReassocType = RankedTensorType::get(
      {filterShape[0] * filterShape[1], filterShape[2], filterShape[3]},
      filterElementType);
  SmallVector<ReassociationIndices> filterReassoc = {{0, 1}, {2}, {3}};
  Value collapseFilter = rewriter.create<tensor::CollapseShapeOp>(
      loc, filterReassocType, transformedFilter, filterReassoc);

  // Convert (alphaH, alphaW, tileH, tileW, N, C) to
  // (alphaH x alphaW, tileH x tileW x N, C) for input.
  auto inputType = cast<ShapedType>(transformedInput.getType());
  assert(inputType.hasStaticShape() && "only support static shapes.");
  ArrayRef<int64_t> inputShape = inputType.getShape();
  Type inputElementType = inputType.getElementType();
  auto inputReassocType = RankedTensorType::get(
      {inputShape[0] * inputShape[1],
       inputShape[2] * inputShape[3] * inputShape[4], inputShape[5]},
      inputElementType);
  SmallVector<ReassociationIndices> inputReassoc = {{0, 1}, {2, 3, 4}, {5}};
  Value collapseInput = rewriter.create<tensor::CollapseShapeOp>(
      loc, inputReassocType, transformedInput, inputReassoc);

  // Batched matrix multiply.
  auto matmulType = RankedTensorType::get(
      {inputShape[0] * inputShape[1],
       inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]},
      outputElementType);
  Value empty = rewriter
                    .create<tensor::EmptyOp>(loc, matmulType.getShape(),
                                             outputElementType)
                    .getResult();
  Value zero = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getZeroAttr(outputElementType));
  Value init = rewriter.create<linalg::FillOp>(loc, zero, empty).getResult(0);

  auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
      loc, matmulType, ValueRange({collapseInput, collapseFilter}),
      ValueRange{init});

  // The result shape of the batch matmul is
  // (alphaH x alphaW, tileH x tileW x N, F). Expand the matmul result to
  // (alphaH, alphaW, tileH, tileW, N, F).
  SmallVector<ReassociationIndices> outputReassoc = {{0, 1}, {2, 3, 4}, {5}};
  auto outputReassocType =
      RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2],
                             inputShape[3], inputShape[4], filterShape[3]},
                            outputElementType);
  auto expandOutput = rewriter.create<tensor::ExpandShapeOp>(
      loc, outputReassocType, matmulOp.getResult(0), outputReassoc);
  return expandOutput;
}

/// This function transforms the output. The data layout of the input value is
/// (alphaH, alphaW, tileH, tileW, N, F) and the final output layout is NHWF.
/// The transformation matrix is 2-dimensional. We need to extract
/// alphaH x alphaW tiles from the 6D value first, so we generate 4 levels of
/// loops to iterate on the tiles and on N and F. After the transformation,
/// we get
///
/// scf.for %h = 0 to tileH step 1
///   scf.for %w = 0 to tileW step 1
///     scf.for %n = 0 to N step 1
///       scf.for %f = 0 to F step 1
///         %extracted = extract %extracted<alphaH x alphaW> from
///                              %input<alphaH x alphaW x tileH x tileW x N x F>
///                              at [0, 0, %h, %w, %n, %f]
///         %ret = linalg.matmul AT, %extracted
///         %ret = linalg.matmul %ret, A
///         %inserted = insert %ret<m x m> into
///                            output<N x H x W x F>
///                            at [%n, (%h x m), (%w x m), %f]
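///
/// Note (illustrative): the AT/A tables above carry a scalarFactor (e.g., 32
/// for each of the F(4, 3) pair); the combined factor is multiplied back into
/// the result by a linalg.generic after the two matmuls, which also
/// accumulates into the extracted output tile.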
Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
                      Value output, int64_t m, int64_t r,
                      bool leftTransform = true, bool rightTransform = true) {
  // Map from (m, r) to AT transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      ATMatrices = {
          {F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)},
          {F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)},
          {F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)},
      };

  // Map from (m, r) to A transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      AMatrices = {
          {F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)},
          {F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)},
          {F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)},
      };

  auto valueType = cast<ShapedType>(value.getType());
  Type elementType = valueType.getElementType();
  auto valueShape = valueType.getShape(); // H, W, TileH, TileW, N, F
  int64_t valueH = valueShape[0];
  int64_t valueW = valueShape[1];
  int64_t valueN = valueShape[4];
  int64_t valueF = valueShape[5];
  int64_t alphaH = leftTransform ? m + r - 1 : 1;
  int64_t alphaW = rightTransform ? m + r - 1 : 1;

  if (valueH != alphaH && valueH != 1)
    return Value();
  if (valueW != alphaW && valueW != 1)
    return Value();

  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                       ValueRange args) -> scf::ValueVector {
    auto context = builder.getContext();
    Value tileHIter = ivs[0];
    Value tileWIter = ivs[1];
    Value NIter = ivs[2];
    Value FIter = ivs[3];

    // Extract (H, W) from (H, W, tileH, tileW, N, F).
    auto extractValue =
        extract2DDataFrom6D(builder, loc, value, tileHIter, tileWIter, NIter,
                            FIter, 2, 3, /*loopNorFIdx=*/4,
                            /*loopCorFIdx=*/5, /*heightIdx=*/0, /*widthIdx=*/1);

    const TransformMapKeyTy key = {m, r};
    const TransformMatrix &AMatrix = AMatrices.at(key);
    const TransformMatrix &ATMatrix = ATMatrices.at(key);
    int64_t scalarFactor = (rightTransform ? AMatrix.scalarFactor : 1) *
                           (leftTransform ? ATMatrix.scalarFactor : 1);
    int64_t retCols = rightTransform ? AMatrix.cols : 1;
    int64_t retRows = leftTransform ? ATMatrix.rows : 1;

    Value matmulRetValue = extractValue;
    Value zero = builder.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementType));

    auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
    auto affineMap =
        AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
    Value heightOffset = builder.create<affine::AffineApplyOp>(
        loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
    Value widthOffset = builder.create<affine::AffineApplyOp>(
        loc, rightTransform ? affineMap : identityAffineMap, tileWIter);

    Value outInitVal =
        extract2DDataFrom4D(builder, loc, args[0], NIter, FIter, heightOffset,
                            widthOffset, retRows, retCols,
                            /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1,
                            /*widthIdx=*/2);
    if (leftTransform) {
      auto matmulType = RankedTensorType::get({retRows, valueW}, elementType);
      Value init = outInitVal;
      if (rightTransform || scalarFactor != 1) {
        auto empty = builder
                         .create<tensor::EmptyOp>(loc, matmulType.getShape(),
                                                  elementType)
                         .getResult();
        init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
      }

      Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType);
      // Multiply AT x m.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{AT, matmulRetValue}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (rightTransform) {
      auto matmulType =
          RankedTensorType::get({retRows, AMatrix.cols}, elementType);
      Value init = outInitVal;
      if (scalarFactor != 1) {
        auto empty = builder
                         .create<tensor::EmptyOp>(loc, matmulType.getShape(),
                                                  elementType)
                         .getResult();
        init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
      }

      Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType);
      // Multiply y = (AT x m) x A.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{matmulRetValue, A}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (scalarFactor != 1) {
      // Multiply by the scalar factor and add outInitVal.
      Value scalarFactorValue = builder.create<arith::ConstantOp>(
          loc, FloatAttr::get(elementType, scalarFactor));
      auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
      auto identityAffineMap = rewriter.getMultiDimIdentityMap(2);
      SmallVector<AffineMap> affineMaps = {
          AffineMap::get(2, 0, context), identityAffineMap, identityAffineMap};

      matmulRetValue =
          rewriter
              .create<linalg::GenericOp>(
                  loc, matmulType,
                  ValueRange{scalarFactorValue, matmulRetValue},
                  ValueRange{outInitVal}, affineMaps,
                  llvm::ArrayRef<utils::IteratorType>{
                      utils::IteratorType::parallel,
                      utils::IteratorType::parallel},
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    auto mulf = nestedBuilder.create<arith::MulFOp>(
                        nestedLoc, args[0], args[1]);
                    auto addf = nestedBuilder.create<arith::AddFOp>(
                        nestedLoc, mulf.getResult(), args[2]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc,
                                                          addf.getResult());
                  })
              .getResult(0);
    }

    // Insert (H, W) to (N, H, W, F).
    Value combinedVal =
        insert2DDataTo4D(builder, loc, matmulRetValue, args[0], NIter, FIter,
                         heightOffset, widthOffset, retRows, retCols,
                         /*loopNorFIdx=*/0,
                         /*loopCorFIdx=*/3, /*heightIdx=*/1,
                         /*widthIdx=*/2);

    return {combinedVal};
  };

  int64_t tileH = valueShape[2];
  int64_t tileW = valueShape[3];
  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tileH);
  auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW);
  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueN);
  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueF);
  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  scf::LoopNest loops = scf::buildLoopNest(
      rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
      {tileHBound, tileWBound, nUpperBound, fUpperBound},
      {oneStep, oneStep, oneStep, oneStep}, {output}, buildBody);
  return loops.results[0];
}

/// Pad the value with zeros on the high side so that it reaches the aligned
/// shape.
static Value padToAlignedTensor(RewriterBase &rewriter, Location loc,
                                Value value, ArrayRef<int64_t> alignedShape) {
  auto valueType = cast<ShapedType>(value.getType());
  Type elementType = valueType.getElementType();
  auto alignedType = RankedTensorType::get(alignedShape, elementType);
  Value padValue = rewriter.create<arith::ConstantOp>(
      loc, elementType, rewriter.getZeroAttr(elementType));

  return linalg::makeComposedPadHighOp(rewriter, loc, alignedType, value,
                                       padValue, false);
}
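
// Illustrative example (hypothetical shapes): padding a tensor<1x8x8x32xf32>
// input to an aligned 1x10x10x32 shape produces something like
//   %padded = tensor.pad %input low[0, 0, 0, 0] high[0, 2, 2, 0] {
//   ^bb0(%i: index, %j: index, %k: index, %l: index):
//     tensor.yield %zero : f32
//   } : tensor<1x8x8x32xf32> to tensor<1x10x10x32xf32>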

/// Extract a sub-tensor with extractedType from value.
static Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
                                      Value value,
                                      RankedTensorType extractedType) {
  OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
  OpFoldResult oneIndex = rewriter.getIndexAttr(1);
  SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
  SmallVector<OpFoldResult, 4> strides(4, oneIndex);

  ArrayRef<int64_t> extractedShape = extractedType.getShape();
  SmallVector<OpFoldResult> sizes =
      getAsOpFoldResult(rewriter.getI64ArrayAttr(extractedShape));

  return rewriter.create<tensor::ExtractSliceOp>(loc, extractedType, value,
                                                 offsets, sizes, strides);
}

/// Utility function to check that all values in the attribute are 1.
static bool hasAllOneValues(DenseIntElementsAttr attr) {
  return llvm::all_of(
      attr, [](const APInt &element) { return element.getSExtValue() == 1; });
}

/// A helper function to convert linalg.conv_2d_nhwc_fhwc to
/// linalg.winograd_*_transform ops.
static FailureOr<Operation *>
winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
                     int64_t m, int64_t r) {
  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];
  auto inputType = cast<ShapedType>(input.getType());
  auto filterType = cast<ShapedType>(filter.getType());
  auto outputType = cast<ShapedType>(output.getType());

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  if (!hasAllOneValues(convOp.getStrides()))
    return rewriter.notifyMatchFailure(convOp, "expected all ones for strides");

  ArrayRef<int64_t> filterShape = filterType.getShape();
  int64_t filterF = filterShape[0];
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];
  int64_t filterC = filterShape[3];
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputN = inputShape[0];
  int64_t inputH = inputShape[1];
  int64_t inputW = inputShape[2];
  int64_t inputC = inputShape[3];
  ArrayRef<int64_t> outputShape = outputType.getShape();
  int64_t outputN = outputShape[0];
  int64_t outputH = outputShape[1];
  int64_t outputW = outputShape[2];
  int64_t outputF = outputShape[3];

  // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r).
  bool isSupportedFilter = false;
  if (filterH == filterW && filterH == r)
    isSupportedFilter = true;
  if (filterH == r && filterW == 1)
    isSupportedFilter = true;
  if (filterH == 1 && filterW == r)
    isSupportedFilter = true;

  if (!isSupportedFilter)
    return rewriter.notifyMatchFailure(
        convOp, "only support filter (r x r), (r x 1) or (1 x r)");

  // Currently, we support (m, r) = (2, 3), (4, 3) or (2, 5).
  static const llvm::SmallVector<TransformMapKeyTy, 3> validConfigs = {
      F_2_3, F_4_3, F_2_5};

  // If we cannot find the constant transformation matrices for this
  // configuration, it means we do not support it yet.
  TransformMapKeyTy key = {m, r};
  if (!llvm::is_contained(validConfigs, key))
    return failure();

  // All the criteria are satisfied. We can do Winograd Conv2D.
  Location loc = convOp.getLoc();

  // For F(m x 1, r x 1), we only need to do the left side transform.
  bool leftTransform = filterH != 1;
  // For F(1 x m, 1 x r), we only need to do the right side transform.
  bool rightTransform = filterW != 1;
  int64_t heightM = leftTransform ? m : 1;
  int64_t widthM = rightTransform ? m : 1;
  int64_t heightR = leftTransform ? r : 1;
  int64_t widthR = rightTransform ? r : 1;

  // --- Create operation for filter transform ---
  Type filterElementType = filterType.getElementType();
  int64_t alphaH = heightM + heightR - 1;
  int64_t alphaW = widthM + widthR - 1;
  int64_t tileH = llvm::divideCeilSigned(outputH, heightM);
  int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
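
  // Worked example (illustrative): with F(4, 3) and outputH = outputW = 8,
  // we get alphaH = alphaW = 6 and tileH = tileW = ceil(8 / 4) = 2, so the
  // transformed filter below has shape (6, 6, C, F).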
  auto retType = RankedTensorType::get({alphaH, alphaW, filterC, filterF},
                                       filterElementType);
  Value retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
                                                    filterElementType);
  auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
      loc, retType, filter, retValue, m, r);

  // --- Create operation for input transform ---

  // When the input size minus (r - 1) is not aligned with the output tile
  // size, we need to pad the input data to create full tiles for tiling.
  Type inputElementType = inputType.getElementType();
  int64_t alignedInputH = tileH * heightM + (heightR - 1);
  int64_t alignedInputW = tileW * widthM + (widthR - 1);
  if (alignedInputH != inputH || alignedInputW != inputW) {
    input = padToAlignedTensor(rewriter, loc, input,
                               {inputN, alignedInputH, alignedInputW, inputC});
  }

  retType = RankedTensorType::get(
      {alphaH, alphaW, tileH, tileW, inputN, inputC}, inputElementType);
  retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
                                              inputElementType);
  auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>(
      loc, retType, input, retValue, m, r);

  Type outputElementType = outputType.getElementType();
  Value matmulRet = matrixMultiply(rewriter, loc, transformedFilter,
                                   transformedInput, outputElementType);

  // --- Create operation for output transform ---

  // When the output size is not aligned with the output tile size, we need to
  // pad the output buffer so that we can insert full tiles after tiling.
  int64_t alignedOutputH = tileH * heightM;
  int64_t alignedOutputW = tileW * widthM;
  bool isOutputUnaligned =
      ((alignedOutputH != outputH) || (alignedOutputW != outputW));
  if (isOutputUnaligned) {
    auto alignedOutputType = RankedTensorType::get(
        {outputN, alignedOutputH, alignedOutputW, outputF}, outputElementType);
    output =
        padToAlignedTensor(rewriter, loc, output, alignedOutputType.getShape());
    outputType = alignedOutputType;
  }

  Value transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>(
      loc, outputType, matmulRet, output, m, r);

  // When the output size is not aligned with the output tile size, extract
  // the value from the padded buffer.
  if (isOutputUnaligned) {
    transformedOutput = extractFromAlignedTensor(
        rewriter, loc, transformedOutput,
        RankedTensorType::get({outputN, outputH, outputW, outputF},
                              outputElementType));
  }

  rewriter.replaceOp(convOp, transformedOutput);

  return transformedOutput.getDefiningOp();
}

/// A helper function to decompose linalg.winograd_filter_transform.
FailureOr<Operation *>
decomposeWinogradFilterTransformHelper(RewriterBase &rewriter,
                                       linalg::WinogradFilterTransformOp op) {
  Location loc = op.getLoc();
  Value filter = op.getFilter();
  auto filterType = cast<ShapedType>(filter.getType());
  auto filterShape = filterType.getShape();
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];

  // For F(m x 1, r x 1), we only need to do the left side transform.
  bool leftTransform = filterH != 1;
  // For F(1 x m, 1 x r), we only need to do the right side transform.
  bool rightTransform = filterW != 1;
  Value transformedFilter =
      filterTransform(rewriter, loc, filter, op.getOutput(), op.getM(),
                      op.getR(), leftTransform, rightTransform);
  if (!transformedFilter)
    return failure();

  rewriter.replaceOp(op, transformedFilter);

  return transformedFilter.getDefiningOp();
}

/// A helper function to decompose linalg.winograd_input_transform.
FailureOr<Operation *>
decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
                                      linalg::WinogradInputTransformOp op) {
  Location loc = op.getLoc();
  Value output = op.getOutput();
  auto outputType = cast<ShapedType>(output.getType());
  auto outputShape = outputType.getShape();

  int64_t outputH = outputShape[0];
  int64_t outputW = outputShape[1];

  // For F(m x 1, r x 1), we only need to do the left side transform.
  bool leftTransform = outputH != 1;
  // For F(1 x m, 1 x r), we only need to do the right side transform.
  bool rightTransform = outputW != 1;
  Value transformedInput =
      inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getM(),
                     op.getR(), leftTransform, rightTransform);
  if (!transformedInput)
    return failure();

  rewriter.replaceOp(op, transformedInput);

  return transformedInput.getDefiningOp();
}

/// A helper function to decompose linalg.winograd_output_transform.
FailureOr<Operation *>
decomposeWinogradOutputTransformHelper(RewriterBase &rewriter,
                                       linalg::WinogradOutputTransformOp op) {
  Location loc = op.getLoc();
  Value value = op.getValue();
  auto valueType = cast<ShapedType>(value.getType());
  auto valueShape = valueType.getShape();
  int64_t valueH = valueShape[0];
  int64_t valueW = valueShape[1];

  // For F(m x 1, r x 1), we only need to do the left side transform.
  bool leftTransform = valueH != 1;
  // For F(1 x m, 1 x r), we only need to do the right side transform.
  bool rightTransform = valueW != 1;
  Value transformedOutput =
      outputTransform(rewriter, loc, value, op.getOutput(), op.getM(),
                      op.getR(), leftTransform, rightTransform);
  if (!transformedOutput)
    return failure();

  rewriter.replaceOp(op, transformedOutput);

  return transformedOutput.getDefiningOp();
}

/// A rewrite pattern to decompose linalg.winograd_filter_transform operations.
class DecomposeWinogradFilterTransform final
    : public OpRewritePattern<linalg::WinogradFilterTransformOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op,
                                PatternRewriter &rewriter) const override {
    return decomposeWinogradFilterTransformHelper(rewriter, op);
  }
};

/// A rewrite pattern to decompose linalg.winograd_input_transform operations.
class DecomposeWinogradInputTransform final
    : public OpRewritePattern<linalg::WinogradInputTransformOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op,
                                PatternRewriter &rewriter) const override {
    return decomposeWinogradInputTransformHelper(rewriter, op);
  }
};

/// A rewrite pattern to decompose linalg.winograd_output_transform operations.
class DecomposeWinogradOutputTransform final
    : public OpRewritePattern<linalg::WinogradOutputTransformOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::WinogradOutputTransformOp op,
                                PatternRewriter &rewriter) const override {
    return decomposeWinogradOutputTransformHelper(rewriter, op);
  }
};

/// A rewrite pattern for the Winograd Conv2D algorithm.
class WinogradConv2DNhwcFhwc final
    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, int64_t m, int64_t r)
      : OpRewritePattern(context), m(m), r(r) {}

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(winogradConv2DHelper(rewriter, convOp, m, r)))
      return failure();

    return success();
  }

private:
  int64_t m;
  int64_t r;
};
} // end anonymous namespace

//===----------------------------------------------------------------------===//
FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
                                      int64_t r) {
  return winogradConv2DHelper(rewriter, op, m, r);
}

FailureOr<Operation *>
decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
                                   linalg::WinogradFilterTransformOp op) {
  return decomposeWinogradFilterTransformHelper(rewriter, op);
}

FailureOr<Operation *>
decomposeWinogradInputTransformOp(RewriterBase &rewriter,
                                  linalg::WinogradInputTransformOp op) {
  return decomposeWinogradInputTransformHelper(rewriter, op);
}

FailureOr<Operation *>
decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
                                   linalg::WinogradOutputTransformOp op) {
  return decomposeWinogradOutputTransformHelper(rewriter, op);
}

void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
                                    int64_t r) {
  MLIRContext *context = patterns.getContext();
  // TODO: Support more Conv2D data layouts, e.g., conv_2d_nchw_fchw.
  patterns.insert<WinogradConv2DNhwcFhwc>(context, m, r);
}
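
// A minimal usage sketch (illustrative, not part of this file): clients
// typically register the pattern and run a greedy rewrite, e.g.
//   RewritePatternSet patterns(ctx);
//   linalg::populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3);
//   (void)applyPatternsAndFoldGreedily(module, std::move(patterns));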

void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns
      .insert<DecomposeWinogradFilterTransform, DecomposeWinogradInputTransform,
              DecomposeWinogradOutputTransform>(context);
}

} // end namespace linalg
} // end namespace mlir