xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp (revision f20b8e35b3bc276d09a6911746f9d44cbb5de297)
17d246e84SHsiangkai Wang //===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
27d246e84SHsiangkai Wang //
37d246e84SHsiangkai Wang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
47d246e84SHsiangkai Wang // See https://llvm.org/LICENSE.txt for license information.
57d246e84SHsiangkai Wang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
67d246e84SHsiangkai Wang //
77d246e84SHsiangkai Wang //===----------------------------------------------------------------------===//
87d246e84SHsiangkai Wang //
97d246e84SHsiangkai Wang // Implement Winograd Conv2D algorithm. The implementation is based on the
107d246e84SHsiangkai Wang // paper: Fast Algorithms for Convolutional Neural Networks
117d246e84SHsiangkai Wang // (https://arxiv.org/abs/1509.09308)
127d246e84SHsiangkai Wang //
137d246e84SHsiangkai Wang //===----------------------------------------------------------------------===//
147d246e84SHsiangkai Wang 
1527ee33d1SHsiangkai Wang #include "mlir/Dialect/Affine/IR/AffineOps.h"
1627ee33d1SHsiangkai Wang #include "mlir/Dialect/Arith/IR/Arith.h"
177d246e84SHsiangkai Wang #include "mlir/Dialect/Linalg/IR/Linalg.h"
187d246e84SHsiangkai Wang #include "mlir/Dialect/Linalg/Utils/Utils.h"
197d246e84SHsiangkai Wang #include "mlir/Dialect/Tensor/IR/Tensor.h"
207d246e84SHsiangkai Wang #include "mlir/Dialect/Utils/StaticValueUtils.h"
2127ee33d1SHsiangkai Wang #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
227d246e84SHsiangkai Wang #include "llvm/Support/MathExtras.h"
237d246e84SHsiangkai Wang 
247d246e84SHsiangkai Wang namespace mlir {
257d246e84SHsiangkai Wang namespace linalg {
267d246e84SHsiangkai Wang 
277d246e84SHsiangkai Wang namespace {
287d246e84SHsiangkai Wang 
// clang-format off
/// Winograd Conv2D uses a minimal 2D filtering algorithm to calculate its
/// result. The formula of minimal 2D filtering algorithm F(m x m, r x r),
/// m is the output dimension and r is the filter dimension, is
///
/// Y = A^T x [ (G x g x G^T) x (B^T x d x B) ] x A
///
/// g is filter and d is input data. We need to prepare 6 constant
/// transformation matrices, G, G^T, B^T, B, A^T, and A for this formula.
///
/// The following tables define these constant transformation matrices for
/// F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5).
///
/// Each table is a flat, row-major array; the row/column counts are supplied
/// where the table is wrapped into a TransformMatrix (see below).

// G: filter transform (4 x 3) for F(2x2, 3x3).
constexpr float G_2x2_3x3[] = {
   -1,     0,   0,
 1./2, -1./2, 1./2,
 1./2,  1./2, 1./2,
    0,     0,    1
};

// G^T: transpose of G (3 x 4) for F(2x2, 3x3).
constexpr float GT_2x2_3x3[] = {
   -1,  1./2, 1./2, 0,
    0, -1./2, 1./2, 0,
    0,  1./2, 1./2, 1
};

// B^T: input transform (4 x 4) for F(2x2, 3x3).
constexpr float BT_2x2_3x3[] = {
   -1,    0,   1,   0,
    0,   -1,   1,   0,
    0,    1,   1,   0,
    0,   -1,   0,   1
};

// B: transpose of B^T (4 x 4) for F(2x2, 3x3).
constexpr float B_2x2_3x3[] = {
   -1,    0,   0,   0,
    0,   -1,   1,  -1,
    1,    1,   1,   0,
    0,    0,   0,   1
};

// A^T: output transform (2 x 4) for F(2x2, 3x3).
constexpr float AT_2x2_3x3[] = {
    1,    1,   1,   0,
    0,   -1,   1,   1
};

// A: transpose of A^T (4 x 2) for F(2x2, 3x3).
constexpr float A_2x2_3x3[] = {
    1,    0,
    1,   -1,
    1,    1,
    0,    1
};

// G: filter transform (6 x 3) for F(4x4, 3x3).
constexpr float G_4x4_3x3[] = {
     1,     0,     0,
 -1./3,  1./3, -1./3,
 -1./3, -1./3, -1./3,
 1./12, -1./6,  1./3,
 1./12,  1./6,  1./3,
     0,     0,     1
};

// G^T: transpose of G (3 x 6) for F(4x4, 3x3).
constexpr float GT_4x4_3x3[] = {
 1,  -1./3, -1./3, 1./12, 1./12, 0,
 0,   1./3, -1./3, -1./6,  1./6, 0,
 0,  -1./3, -1./3,  1./3,  1./3, 1
};

// B^T: input transform (6 x 6) for F(4x4, 3x3).
constexpr float BT_4x4_3x3[] = {
 1./4,     0, -5./16,      0, 1./16,     0,
    0,  1./4,  -1./4, -1./16, 1./16,     0,
    0, -1./4,  -1./4,  1./16, 1./16,     0,
    0,  1./4,  -1./8,  -1./4,  1./8,     0,
    0, -1./4,  -1./8,   1./4,  1./8,     0,
    0,  1./4,      0, -5./16,     0, 1./16
};

// B: transpose of B^T (6 x 6) for F(4x4, 3x3).
constexpr float B_4x4_3x3[] = {
   1./4,      0,     0,     0,     0,      0,
      0,   1./4, -1./4,  1./4, -1./4,   1./4,
 -5./16,  -1./4, -1./4, -1./8, -1./8,      0,
      0, -1./16, 1./16, -1./4,  1./4, -5./16,
  1./16,  1./16, 1./16,  1./8,  1./8,      0,
      0,      0,     0,     0,     0,  1./16
};

// A^T: output transform (4 x 6) for F(4x4, 3x3).
constexpr float AT_4x4_3x3[] = {
 1./8,  1./4, 1./4,  1./8, 1./8,    0,
    0, -1./4, 1./4, -1./4, 1./4,    0,
    0,  1./4, 1./4,  1./2, 1./2,    0,
    0, -1./4, 1./4,    -1,    1, 1./2
};

// A: transpose of A^T (6 x 4) for F(4x4, 3x3).
constexpr float A_4x4_3x3[] = {
  1./8,     0,    0,     0,
  1./4, -1./4, 1./4, -1./4,
  1./4,  1./4, 1./4,  1./4,
  1./8, -1./4, 1./2,    -1,
  1./8,  1./4, 1./2,     1,
     0,     0,    0,  1./2
};

// G: filter transform (6 x 5) for F(2x2, 5x5).
constexpr float G_2x2_5x5[] = {
     1,     0,      0,      0,      0,
  1./6, -1./6,   1./6,  -1./6,   1./6,
 -1./6, -1./6,  -1./6,  -1./6,  -1./6,
-4./15, 2./15, -1./15,  1./30, -1./60,
 1./60, 1./30,  1./15,  2./15,  4./15,
     0,     0,      0,      0,      1
};

// G^T: transpose of G (5 x 6) for F(2x2, 5x5).
constexpr float GT_2x2_5x5[] = {
   1,  1./6, -1./6, -4./15, 1./60, 0,
   0, -1./6, -1./6,  2./15, 1./30, 0,
   0,  1./6, -1./6, -1./15, 1./15, 0,
   0, -1./6, -1./6,  1./30, 2./15, 0,
   0,  1./6, -1./6, -1./60, 4./15, 1
};

// B^T: input transform (6 x 6) for F(2x2, 5x5).
constexpr float BT_2x2_5x5[] = {
 1./8,  3./16,  -1./4,  -3./16,   1./8,    0,
    0,   1./8,  1./16,  -5./16,   1./8,    0,
    0,  -1./8, -5./16,  -1./16,   1./8,    0,
    0,   1./4,  -1./8,   -1./4,   1./8,    0,
    0,  -1./8,  -1./4,    1./8,   1./4,    0,
    0,   1./8,  3./16,   -1./4, -3./16, 1./8
};

// B: transpose of B^T (6 x 6) for F(2x2, 5x5).
constexpr float B_2x2_5x5[] = {
   1./8,      0,      0,     0,     0,      0,
  3./16,   1./8,  -1./8,  1./4, -1./8,   1./8,
  -1./4,  1./16, -5./16, -1./8, -1./4,  3./16,
 -3./16, -5./16, -1./16, -1./4,  1./8,  -1./4,
   1./8,   1./8,   1./8,  1./8,  1./4, -3./16,
      0,      0,      0,     0,     0,   1./8
};

// A^T: output transform (2 x 6) for F(2x2, 5x5).
constexpr float AT_2x2_5x5[] = {
  1./2,  1, 1,  2, 1,    0,
     0, -1, 1, -1, 2, 1./2
};

// A: transpose of A^T (6 x 2) for F(2x2, 5x5).
constexpr float A_2x2_5x5[] = {
 1./2,    0,
    1,   -1,
    1,    1,
    2,   -1,
    1,    2,
    0, 1./2
};
// clang-format on
17827ee33d1SHsiangkai Wang 
/// Key type used to select a set of transform matrices: the pair is (m, r),
/// where m is the output tile size and r is the filter size.
using TransformMapKeyTy = std::pair<int, int>;

/// We use F(m, r) to define the size of minimal filtering algorithms.
/// m is the output dimension and r is the filter dimension. We can get
/// the input dimension, alpha, from the formula, alpha = m + r - 1.
///
/// For example, when m = 2 and r = 3, we know its input size is 4.
/// The Conv2D will operate on 4x4 input data with 3x3 filter and get
/// 2x2 output result.
constexpr TransformMapKeyTy F_2_3{2, 3};
constexpr TransformMapKeyTy F_4_3{4, 3};
constexpr TransformMapKeyTy F_2_5{2, 5};
1917d246e84SHsiangkai Wang 
/// Structure to keep information of constant transform matrices.
/// Wraps one of the flat constexpr float tables above together with its
/// dimensions. `table` is a non-owning pointer to static data, laid out
/// row-major with `rows` x `cols` elements. `scalarFactor` is an extra
/// scalar divisor applied after the matrix multiplications (1 means none).
struct TransformMatrix {
  TransformMatrix(const float *table, int64_t rows, int64_t cols,
                  int64_t scalarFactor = 1)
      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}

  const float *table;   // Row-major matrix data (points at a constexpr table).
  int64_t rows;         // Number of rows.
  int64_t cols;         // Number of columns.
  int64_t scalarFactor; // Post-multiplication scalar divisor (default 1).
};
20327ee33d1SHsiangkai Wang 
20427ee33d1SHsiangkai Wang /// Utility function to convert constant array to arith.constant Value.
20527ee33d1SHsiangkai Wang Value create2DTransformMatrix(OpBuilder &builder, Location loc,
20627ee33d1SHsiangkai Wang                               TransformMatrix transform, Type type) {
20727ee33d1SHsiangkai Wang   ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);
20827ee33d1SHsiangkai Wang 
20927ee33d1SHsiangkai Wang   return builder.create<arith::ConstantOp>(
21027ee33d1SHsiangkai Wang       loc, DenseFPElementsAttr::get(
21127ee33d1SHsiangkai Wang                RankedTensorType::get(
21227ee33d1SHsiangkai Wang                    SmallVector<int64_t>{transform.rows, transform.cols}, type),
21327ee33d1SHsiangkai Wang                constVec));
21427ee33d1SHsiangkai Wang }
21527ee33d1SHsiangkai Wang 
21627ee33d1SHsiangkai Wang /// Extract height x width data from 4D tensors.
21727ee33d1SHsiangkai Wang Value extract2DDataFrom4D(OpBuilder &builder, Location loc, Value source,
21827ee33d1SHsiangkai Wang                           Value loopNorFIndex, Value loopCorFIndex,
21927ee33d1SHsiangkai Wang                           Value heightOffset, Value widthOffset,
22027ee33d1SHsiangkai Wang                           int64_t extractHeight, int64_t extractWidth,
22127ee33d1SHsiangkai Wang                           int64_t loopNorFIdx, int64_t loopCorFIdx,
22227ee33d1SHsiangkai Wang                           int64_t heightIdx, int64_t widthIdx) {
22327ee33d1SHsiangkai Wang   auto sourceType = cast<ShapedType>(source.getType());
22427ee33d1SHsiangkai Wang   Type elementType = sourceType.getElementType();
22527ee33d1SHsiangkai Wang   int64_t srcSize = sourceType.getRank();
22627ee33d1SHsiangkai Wang 
22727ee33d1SHsiangkai Wang   auto oneIndex = builder.getIndexAttr(1);
22827ee33d1SHsiangkai Wang   SmallVector<OpFoldResult> offsets;
22927ee33d1SHsiangkai Wang   offsets.resize(srcSize);
23027ee33d1SHsiangkai Wang   offsets[loopNorFIdx] = loopNorFIndex;
23127ee33d1SHsiangkai Wang   offsets[loopCorFIdx] = loopCorFIndex;
23227ee33d1SHsiangkai Wang   offsets[heightIdx] = heightOffset;
23327ee33d1SHsiangkai Wang   offsets[widthIdx] = widthOffset;
23427ee33d1SHsiangkai Wang   SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
23527ee33d1SHsiangkai Wang   sizes[heightIdx] = builder.getIndexAttr(extractHeight);
23627ee33d1SHsiangkai Wang   sizes[widthIdx] = builder.getIndexAttr(extractWidth);
23727ee33d1SHsiangkai Wang   SmallVector<OpFoldResult> strides(srcSize, oneIndex);
23827ee33d1SHsiangkai Wang 
23927ee33d1SHsiangkai Wang   auto extractFilterType =
24027ee33d1SHsiangkai Wang       RankedTensorType::get({extractHeight, extractWidth}, elementType);
24127ee33d1SHsiangkai Wang   auto extractFilterOp = builder.create<tensor::ExtractSliceOp>(
24227ee33d1SHsiangkai Wang       loc, extractFilterType, source, offsets, sizes, strides);
24327ee33d1SHsiangkai Wang 
24427ee33d1SHsiangkai Wang   return extractFilterOp;
24527ee33d1SHsiangkai Wang }
24627ee33d1SHsiangkai Wang 
24727ee33d1SHsiangkai Wang /// Extract height x width data from 6D tensors.
24827ee33d1SHsiangkai Wang Value extract2DDataFrom6D(OpBuilder &builder, Location loc, Value source,
24927ee33d1SHsiangkai Wang                           Value tileHIndex, Value tileWIndex,
25027ee33d1SHsiangkai Wang                           Value loopNorFIndex, Value loopCorFIndex,
25127ee33d1SHsiangkai Wang                           int64_t tileHIdx, int64_t tileWIdx,
25227ee33d1SHsiangkai Wang                           int64_t loopNorFIdx, int64_t loopCorFIdx,
25327ee33d1SHsiangkai Wang                           int64_t heightIdx, int64_t widthIdx) {
25427ee33d1SHsiangkai Wang   auto sourceType = cast<ShapedType>(source.getType());
25527ee33d1SHsiangkai Wang   Type elementType = sourceType.getElementType();
25627ee33d1SHsiangkai Wang   auto sourceShape = sourceType.getShape();
25727ee33d1SHsiangkai Wang   int64_t srcSize = sourceType.getRank();
25827ee33d1SHsiangkai Wang   int64_t height = sourceShape[heightIdx];
25927ee33d1SHsiangkai Wang   int64_t width = sourceShape[widthIdx];
26027ee33d1SHsiangkai Wang 
26127ee33d1SHsiangkai Wang   auto zeroIndex = builder.getIndexAttr(0);
26227ee33d1SHsiangkai Wang   auto oneIndex = builder.getIndexAttr(1);
26327ee33d1SHsiangkai Wang   SmallVector<OpFoldResult> offsets(srcSize, zeroIndex);
26427ee33d1SHsiangkai Wang   offsets.resize(srcSize);
26527ee33d1SHsiangkai Wang   offsets[tileHIdx] = tileHIndex;
26627ee33d1SHsiangkai Wang   offsets[tileWIdx] = tileWIndex;
26727ee33d1SHsiangkai Wang   offsets[loopNorFIdx] = loopNorFIndex;
26827ee33d1SHsiangkai Wang   offsets[loopCorFIdx] = loopCorFIndex;
26927ee33d1SHsiangkai Wang   SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
27027ee33d1SHsiangkai Wang   sizes[heightIdx] = builder.getIndexAttr(height);
27127ee33d1SHsiangkai Wang   sizes[widthIdx] = builder.getIndexAttr(width);
27227ee33d1SHsiangkai Wang   SmallVector<OpFoldResult> strides(srcSize, oneIndex);
27327ee33d1SHsiangkai Wang 
27427ee33d1SHsiangkai Wang   auto extractFilterType = RankedTensorType::get({height, width}, elementType);
27527ee33d1SHsiangkai Wang   auto extractFilterOp = builder.create<tensor::ExtractSliceOp>(
27627ee33d1SHsiangkai Wang       loc, extractFilterType, source, offsets, sizes, strides);
27727ee33d1SHsiangkai Wang 
27827ee33d1SHsiangkai Wang   return extractFilterOp;
27927ee33d1SHsiangkai Wang }
28027ee33d1SHsiangkai Wang 
28127ee33d1SHsiangkai Wang /// Insert transformed height x width data to 4D tensors which it is
28227ee33d1SHsiangkai Wang /// extracted from.
28327ee33d1SHsiangkai Wang Value insert2DDataTo4D(OpBuilder &builder, Location loc, Value source,
28427ee33d1SHsiangkai Wang                        Value dest, Value loopNorFIndex, Value loopCorFIndex,
28527ee33d1SHsiangkai Wang                        Value heightOffset, Value widthOffset, int64_t height,
28627ee33d1SHsiangkai Wang                        int64_t width, int64_t loopNorFIdx, int64_t loopCorFIdx,
28727ee33d1SHsiangkai Wang                        int64_t heightIdx, int64_t widthIdx) {
28827ee33d1SHsiangkai Wang   int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
28927ee33d1SHsiangkai Wang   auto oneIndex = builder.getIndexAttr(1);
29027ee33d1SHsiangkai Wang   SmallVector<OpFoldResult> retOffsets;
29127ee33d1SHsiangkai Wang   retOffsets.resize(destSize);
29227ee33d1SHsiangkai Wang   retOffsets[loopNorFIdx] = loopNorFIndex;
29327ee33d1SHsiangkai Wang   retOffsets[loopCorFIdx] = loopCorFIndex;
29427ee33d1SHsiangkai Wang   retOffsets[heightIdx] = heightOffset;
29527ee33d1SHsiangkai Wang   retOffsets[widthIdx] = widthOffset;
29627ee33d1SHsiangkai Wang   SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
29727ee33d1SHsiangkai Wang   retSizes[heightIdx] = builder.getIndexAttr(height);
29827ee33d1SHsiangkai Wang   retSizes[widthIdx] = builder.getIndexAttr(width);
29927ee33d1SHsiangkai Wang   SmallVector<OpFoldResult> strides(destSize, oneIndex);
30027ee33d1SHsiangkai Wang 
30127ee33d1SHsiangkai Wang   auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
30227ee33d1SHsiangkai Wang       loc, source, dest, retOffsets, retSizes, strides);
30327ee33d1SHsiangkai Wang 
30427ee33d1SHsiangkai Wang   return insertSliceOp;
30527ee33d1SHsiangkai Wang }
30627ee33d1SHsiangkai Wang 
30727ee33d1SHsiangkai Wang /// Insert transformed height x width data to 6D tensors which it is
30827ee33d1SHsiangkai Wang /// extracted from.
30927ee33d1SHsiangkai Wang Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source,
31027ee33d1SHsiangkai Wang                        Value dest, Value tileHIndex, Value tileWIndex,
31127ee33d1SHsiangkai Wang                        Value loopNorFIndex, Value loopCorFIndex, int64_t height,
31227ee33d1SHsiangkai Wang                        int64_t width, int64_t tileHIdx, int64_t tileWIdx,
31327ee33d1SHsiangkai Wang                        int64_t loopNorFIdx, int64_t loopCorFIdx,
31427ee33d1SHsiangkai Wang                        int64_t heightIdx, int64_t widthIdx) {
31527ee33d1SHsiangkai Wang   int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
31627ee33d1SHsiangkai Wang   auto zeroIndex = builder.getIndexAttr(0);
31727ee33d1SHsiangkai Wang   auto oneIndex = builder.getIndexAttr(1);
31827ee33d1SHsiangkai Wang   SmallVector<OpFoldResult> retOffsets(destSize, zeroIndex);
31927ee33d1SHsiangkai Wang   retOffsets.resize(destSize);
32027ee33d1SHsiangkai Wang   retOffsets[tileHIdx] = tileHIndex;
32127ee33d1SHsiangkai Wang   retOffsets[tileWIdx] = tileWIndex;
32227ee33d1SHsiangkai Wang   retOffsets[loopNorFIdx] = loopNorFIndex;
32327ee33d1SHsiangkai Wang   retOffsets[loopCorFIdx] = loopCorFIndex;
32427ee33d1SHsiangkai Wang   SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
32527ee33d1SHsiangkai Wang   retSizes[heightIdx] = builder.getIndexAttr(height);
32627ee33d1SHsiangkai Wang   retSizes[widthIdx] = builder.getIndexAttr(width);
32727ee33d1SHsiangkai Wang   SmallVector<OpFoldResult> strides(destSize, oneIndex);
32827ee33d1SHsiangkai Wang 
32927ee33d1SHsiangkai Wang   auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
33027ee33d1SHsiangkai Wang       loc, source, dest, retOffsets, retSizes, strides);
33127ee33d1SHsiangkai Wang 
33227ee33d1SHsiangkai Wang   return insertSliceOp;
33327ee33d1SHsiangkai Wang }
33427ee33d1SHsiangkai Wang 
/// This function transforms the filter. The data layout of the filter is FHWC.
/// The transformation matrix is 2-dimension. We need to extract H x W from
/// FHWC first. We need to generate 2 levels of loops to iterate on F and C.
/// After the transformation, we get
///
/// scf.for %f = lo_f to hi_f step 1
///   scf.for %c = lo_c to hi_c step 1
///     %extracted = extract filter<h x w> from filter<f x h x w x c>
///     %ret = linalg.matmul G, %extracted
///     %ret = linalg.matmul %ret, GT
///     %inserted = insert %ret into filter<h x w x c x f>
///
/// Returns the final loop-nest result (the HWCF tensor accumulated through
/// `retValue`), or a null Value if the filter shape is unsupported for the
/// given (m, r). `leftTransform`/`rightTransform` disable the G / GT side of
/// the transform for 1-D (Nhw/Nwh) filter variants.
Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
                      Value retValue, int64_t m, int64_t r,
                      bool leftTransform = true, bool rightTransform = true) {
  // Map from (m, r) to G transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      GMatrices = {
          {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
          {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
          {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
      };

  // Map from (m, r) to GT transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      GTMatrices = {
          {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
          {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
          {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
      };

  auto filterType = cast<ShapedType>(filter.getType());
  Type elementType = filterType.getElementType();
  auto filterShape = filterType.getShape(); // F, H, W, C
  int64_t filterF = filterShape[0];
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];
  int64_t filterC = filterShape[3];

  // Each spatial dimension must either match the filter size r or be 1
  // (the 1-D convolution case); otherwise this (m, r) config cannot apply.
  if (filterH != r && filterH != 1)
    return Value();
  if (filterW != r && filterW != 1)
    return Value();

  Value zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  // Loop body shared by the F and C loops: transform one (H, W) slice.
  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                       ValueRange args) -> scf::ValueVector {
    Value FIter = ivs[0];
    Value CIter = ivs[1];

    // Extract (H, W) from (F, H, W, C).
    auto extractFilter =
        extract2DDataFrom4D(builder, loc, filter, FIter, CIter, zeroIdx,
                            zeroIdx, filterH, filterW, /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);

    TransformMapKeyTy key = {m, r};
    int64_t retRows = 1;
    Value matmulRetValue = extractFilter;
    // Shared zero used to initialize both matmul accumulators below.
    Value zero = builder.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementType));
    if (leftTransform) {
      // Get constant transform matrix G.
      auto it = GMatrices.find(key);
      if (it == GMatrices.end())
        return {};
      const TransformMatrix &GMatrix = it->second;

      retRows = GMatrix.rows;
      auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
      // Zero-filled init tensor: linalg.matmul accumulates into its output.
      auto empty =
          builder
              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
              .getResult();
      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);

      Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType);
      // Multiply G x g.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (rightTransform) {
      // Get constant transform matrix GT.
      auto it = GTMatrices.find(key);
      if (it == GTMatrices.end())
        return {};
      const TransformMatrix &GTMatrix = it->second;

      auto matmulType =
          RankedTensorType::get({retRows, GTMatrix.cols}, elementType);
      // Zero-filled init tensor for the second accumulation.
      auto empty =
          builder
              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
              .getResult();
      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);

      Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType);
      // Multiply u = (G x g) x GT.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{matmulRetValue, GT}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    // Insert (H, W) to (H, W, C, F).
    // The transformed tile is alpha x alpha (alpha = m + r - 1) on each side
    // that was transformed, and 1 on a side that was skipped.
    int64_t retHeight = leftTransform ? m + r - 1 : 1;
    int64_t retWidth = rightTransform ? m + r - 1 : 1;

    auto insertSliceOp =
        insert2DDataTo4D(builder, loc, matmulRetValue, args[0], FIter, CIter,
                         zeroIdx, zeroIdx, retHeight, retWidth,
                         /*loopNorFIdx=*/3, /*loopCorFIdx=*/2,
                         /*heightIdx=*/0, /*widthIdx=*/1);

    return {insertSliceOp};
  };

  // Build the two-level loop nest over F and C, threading the output tensor
  // through the loop as an iter_arg (initialized to retValue).
  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF);
  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC);
  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  scf::LoopNest loops = scf::buildLoopNest(
      rewriter, loc, {zeroIdx, zeroIdx}, {fUpperBound, cUpperBound},
      {oneStep, oneStep}, {retValue}, buildBody);
  return loops.results[0];
}
46027ee33d1SHsiangkai Wang 
46127ee33d1SHsiangkai Wang /// This function transforms the input. The data layout of the input is NHWC.
46227ee33d1SHsiangkai Wang /// The transformation matrix is 2-dimension. We need to extract H x W from
46327ee33d1SHsiangkai Wang /// NHWC first. We need to generate 2 levels of loops to iterate on N and C.
46427ee33d1SHsiangkai Wang /// After the transformation, we get
46527ee33d1SHsiangkai Wang ///
46627ee33d1SHsiangkai Wang /// scf.for %h = 0 to tileH step 1
46727ee33d1SHsiangkai Wang ///   scf.for %w = 0 to tileW step 1
46827ee33d1SHsiangkai Wang ///     scf.for %n = 0 to N step 1
46927ee33d1SHsiangkai Wang ///       scf.for %c = 0 to C step 1
47027ee33d1SHsiangkai Wang ///         %extracted = extract %extracted<alphaH x alphaW> from
47127ee33d1SHsiangkai Wang ///                              %input<N x H x W x C>
47227ee33d1SHsiangkai Wang ///                              at [%n, (%h x m), (%w x m), %c]
47327ee33d1SHsiangkai Wang ///         %ret = linalg.matmul BT, %extracted
47427ee33d1SHsiangkai Wang ///         %ret = linalg.matmul %ret, B
47527ee33d1SHsiangkai Wang ///         %inserted = insert %ret<alphaH x alphaW> into
47627ee33d1SHsiangkai Wang ///                            %output<alphaH x alphaW x tileH x tileW x N x C>
47727ee33d1SHsiangkai Wang ///                            at [0, 0, %h, %w, %n, %c]
47827ee33d1SHsiangkai Wang Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
47927ee33d1SHsiangkai Wang                      Value retValue, int64_t m, int64_t r,
48027ee33d1SHsiangkai Wang                      bool leftTransform = true, bool rightTransform = true) {
48127ee33d1SHsiangkai Wang   // Map from (m, r) to BT transform matrix.
48227ee33d1SHsiangkai Wang   static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
48327ee33d1SHsiangkai Wang       BTMatrices = {
48427ee33d1SHsiangkai Wang           {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
48527ee33d1SHsiangkai Wang           {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
48627ee33d1SHsiangkai Wang           {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
48727ee33d1SHsiangkai Wang       };
48827ee33d1SHsiangkai Wang 
48927ee33d1SHsiangkai Wang   // Map from (m, r) to B transform matrix.
49027ee33d1SHsiangkai Wang   static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
49127ee33d1SHsiangkai Wang       BMatrices = {
49227ee33d1SHsiangkai Wang           {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
49327ee33d1SHsiangkai Wang           {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
49427ee33d1SHsiangkai Wang           {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
49527ee33d1SHsiangkai Wang       };
49627ee33d1SHsiangkai Wang 
49727ee33d1SHsiangkai Wang   auto inputType = cast<ShapedType>(input.getType());
49827ee33d1SHsiangkai Wang   Type elementType = inputType.getElementType();
49927ee33d1SHsiangkai Wang   auto inputShape = inputType.getShape(); // N, H, W, C
50027ee33d1SHsiangkai Wang   int64_t inputN = inputShape[0];
50127ee33d1SHsiangkai Wang   int64_t inputC = inputShape[3];
50227ee33d1SHsiangkai Wang   auto valueType = cast<ShapedType>(retValue.getType());
50327ee33d1SHsiangkai Wang   auto valueShape = valueType.getShape(); // alphaH, alphaW, HTile, WTile, N, C
50427ee33d1SHsiangkai Wang   int64_t tileH = valueShape[2];
50527ee33d1SHsiangkai Wang   int64_t tileW = valueShape[3];
50627ee33d1SHsiangkai Wang   int64_t alphaH = leftTransform ? m + r - 1 : 1;
50727ee33d1SHsiangkai Wang   int64_t alphaW = rightTransform ? m + r - 1 : 1;
50827ee33d1SHsiangkai Wang 
50927ee33d1SHsiangkai Wang   auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
51027ee33d1SHsiangkai Wang                        ValueRange args) -> scf::ValueVector {
51127ee33d1SHsiangkai Wang     Value tileHIter = ivs[0];
51227ee33d1SHsiangkai Wang     Value tileWIter = ivs[1];
51327ee33d1SHsiangkai Wang     Value NIter = ivs[2];
51427ee33d1SHsiangkai Wang     Value CIter = ivs[3];
51527ee33d1SHsiangkai Wang 
51627ee33d1SHsiangkai Wang     auto context = builder.getContext();
517*f20b8e35SDmitriy Smirnov 
518*f20b8e35SDmitriy Smirnov     auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
51927ee33d1SHsiangkai Wang     auto affineMap =
52027ee33d1SHsiangkai Wang         AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
521*f20b8e35SDmitriy Smirnov     Value heightOffset = builder.create<affine::AffineApplyOp>(
522*f20b8e35SDmitriy Smirnov         loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
523*f20b8e35SDmitriy Smirnov     Value widthOffset = builder.create<affine::AffineApplyOp>(
524*f20b8e35SDmitriy Smirnov         loc, rightTransform ? affineMap : identityAffineMap, tileWIter);
52527ee33d1SHsiangkai Wang 
52627ee33d1SHsiangkai Wang     // Extract (H, W) from (N, H, W, C).
52727ee33d1SHsiangkai Wang     auto extractInput =
52827ee33d1SHsiangkai Wang         extract2DDataFrom4D(builder, loc, input, NIter, CIter, heightOffset,
52927ee33d1SHsiangkai Wang                             widthOffset, alphaH, alphaW, /*loopNorFIdx=*/0,
53027ee33d1SHsiangkai Wang                             /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);
53127ee33d1SHsiangkai Wang 
53227ee33d1SHsiangkai Wang     TransformMapKeyTy key = {m, r};
53327ee33d1SHsiangkai Wang     int64_t retRows = 1;
53427ee33d1SHsiangkai Wang     int64_t retCols = 1;
53527ee33d1SHsiangkai Wang     Value matmulRetValue = extractInput;
536326287fdSThomas Preud'homme     Value zero = builder.create<arith::ConstantOp>(
537326287fdSThomas Preud'homme         loc, rewriter.getZeroAttr(elementType));
53827ee33d1SHsiangkai Wang     if (leftTransform) {
53927ee33d1SHsiangkai Wang       // Get constant transform matrix BT.
54027ee33d1SHsiangkai Wang       auto it = BTMatrices.find(key);
54127ee33d1SHsiangkai Wang       if (it == BTMatrices.end())
54227ee33d1SHsiangkai Wang         return {};
54327ee33d1SHsiangkai Wang       const TransformMatrix &BTMatrix = it->second;
54427ee33d1SHsiangkai Wang 
54527ee33d1SHsiangkai Wang       retRows = BTMatrix.rows;
54627ee33d1SHsiangkai Wang       auto matmulType = RankedTensorType::get({retRows, alphaW}, elementType);
547326287fdSThomas Preud'homme       auto empty =
548326287fdSThomas Preud'homme           builder
549326287fdSThomas Preud'homme               .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
550326287fdSThomas Preud'homme               .getResult();
551326287fdSThomas Preud'homme       auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
55227ee33d1SHsiangkai Wang 
55327ee33d1SHsiangkai Wang       Value BT =
55427ee33d1SHsiangkai Wang           create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
55527ee33d1SHsiangkai Wang       // Multiply BT x d.
55627ee33d1SHsiangkai Wang       auto matmulOp = builder.create<linalg::MatmulOp>(
55727ee33d1SHsiangkai Wang           loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init});
55827ee33d1SHsiangkai Wang       matmulRetValue = matmulOp.getResult(0);
55927ee33d1SHsiangkai Wang     }
56027ee33d1SHsiangkai Wang 
56127ee33d1SHsiangkai Wang     if (rightTransform) {
56227ee33d1SHsiangkai Wang       // Get constant transform matrix B.
56327ee33d1SHsiangkai Wang       auto it = BMatrices.find(key);
56427ee33d1SHsiangkai Wang       if (it == BMatrices.end())
56527ee33d1SHsiangkai Wang         return {};
56627ee33d1SHsiangkai Wang       const TransformMatrix &BMatrix = it->second;
56727ee33d1SHsiangkai Wang 
56827ee33d1SHsiangkai Wang       retCols = BMatrix.cols;
56927ee33d1SHsiangkai Wang       auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
570326287fdSThomas Preud'homme       auto empty =
571326287fdSThomas Preud'homme           builder
572326287fdSThomas Preud'homme               .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
573326287fdSThomas Preud'homme               .getResult();
574326287fdSThomas Preud'homme       auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
57527ee33d1SHsiangkai Wang       Value B =
57627ee33d1SHsiangkai Wang           create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
57727ee33d1SHsiangkai Wang       // Multiply v = (BT x d) x B.
57827ee33d1SHsiangkai Wang       auto matmulOp = builder.create<linalg::MatmulOp>(
57927ee33d1SHsiangkai Wang           loc, matmulType, ValueRange{matmulRetValue, B}, ValueRange{init});
58027ee33d1SHsiangkai Wang       matmulRetValue = matmulOp.getResult(0);
58127ee33d1SHsiangkai Wang     }
58227ee33d1SHsiangkai Wang 
58327ee33d1SHsiangkai Wang     // Insert (H, W) to (H, W, tileH, tileW, N, C).
58427ee33d1SHsiangkai Wang     auto combinedVal = insert2DDataTo6D(
58527ee33d1SHsiangkai Wang         builder, loc, matmulRetValue, args[0], tileHIter, tileWIter, NIter,
58627ee33d1SHsiangkai Wang         CIter, retRows, retCols, 2, 3, /*loopNorFIdx=*/4, /*loopCorFIdx=*/5,
58727ee33d1SHsiangkai Wang         /*heightIdx=*/0, /*widthIdx=*/1);
58827ee33d1SHsiangkai Wang 
58927ee33d1SHsiangkai Wang     return {combinedVal};
59027ee33d1SHsiangkai Wang   };
59127ee33d1SHsiangkai Wang 
59227ee33d1SHsiangkai Wang   auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
59327ee33d1SHsiangkai Wang   auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tileH);
59427ee33d1SHsiangkai Wang   auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW);
59527ee33d1SHsiangkai Wang   auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputN);
59627ee33d1SHsiangkai Wang   auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputC);
59727ee33d1SHsiangkai Wang   auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
59827ee33d1SHsiangkai Wang   scf::LoopNest loops = scf::buildLoopNest(
59927ee33d1SHsiangkai Wang       rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
60027ee33d1SHsiangkai Wang       {tileHBound, tileWBound, nUpperBound, cUpperBound},
60127ee33d1SHsiangkai Wang       {oneStep, oneStep, oneStep, oneStep}, {retValue}, buildBody);
60227ee33d1SHsiangkai Wang   return loops.results[0];
60327ee33d1SHsiangkai Wang }
60427ee33d1SHsiangkai Wang 
6057d246e84SHsiangkai Wang /// This function generates linalg.batch_matmul to multiply input with filter.
6067d246e84SHsiangkai Wang /// linalg.batch_matmul only supports 3-dimensional inputs. We can treat
6077d246e84SHsiangkai Wang /// tileH x tileW x H x W data as the 1-dimensional data array. That is to
6087d246e84SHsiangkai Wang /// convert [tileH, tileW, H, W, N, C] to [tileH x tileW x H x W, N, C]. In this
6097d246e84SHsiangkai Wang /// way, we can convert 6-dimensional inputs to 3-dimensional representation
6107d246e84SHsiangkai Wang /// that is suitable for linalg.batch_matmul.
6117d246e84SHsiangkai Wang ///
6127d246e84SHsiangkai Wang /// Batched matmul will do the matrix multiply with the reduction on channel.
6137d246e84SHsiangkai Wang ///
6147d246e84SHsiangkai Wang /// We get
6157d246e84SHsiangkai Wang ///
6167d246e84SHsiangkai Wang /// %collapsed_input = tensor.collapse_shape %input
6177d246e84SHsiangkai Wang /// %collapsed_filter = tensor.collapse_shape %filter
6187d246e84SHsiangkai Wang /// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
6197d246e84SHsiangkai Wang /// %expanded_ret = tensor.expand_shape %ret
6207d246e84SHsiangkai Wang ///
6217d246e84SHsiangkai Wang /// After this function, we get return value with data layout
6227d246e84SHsiangkai Wang /// (tileH, tileW, H, W, N, F).
6237d246e84SHsiangkai Wang static Value matrixMultiply(RewriterBase &rewriter, Location loc,
6247d246e84SHsiangkai Wang                             Value transformedFilter, Value transformedInput,
6257d246e84SHsiangkai Wang                             Type outputElementType) {
6267d246e84SHsiangkai Wang   // Convert (alphaH, alphaW, C, F) to (alphaH x alphaW, C, F) for filter.
6277d246e84SHsiangkai Wang   auto filterType = cast<ShapedType>(transformedFilter.getType());
6287d246e84SHsiangkai Wang   assert(filterType.hasStaticShape() && "only support static shapes.");
6297d246e84SHsiangkai Wang   ArrayRef<int64_t> filterShape = filterType.getShape();
6307d246e84SHsiangkai Wang   Type filterElementType = filterType.getElementType();
6317d246e84SHsiangkai Wang   auto filterReassocType = RankedTensorType::get(
6327d246e84SHsiangkai Wang       {filterShape[0] * filterShape[1], filterShape[2], filterShape[3]},
6337d246e84SHsiangkai Wang       filterElementType);
6347d246e84SHsiangkai Wang   SmallVector<ReassociationIndices> filterReassoc = {{0, 1}, {2}, {3}};
6357d246e84SHsiangkai Wang   Value collapseFilter = rewriter.create<tensor::CollapseShapeOp>(
6367d246e84SHsiangkai Wang       loc, filterReassocType, transformedFilter, filterReassoc);
6377d246e84SHsiangkai Wang 
6387d246e84SHsiangkai Wang   // Convert (alphaH, alphaW, tileH, tileW, N, C) to
6397d246e84SHsiangkai Wang   // (alphaH x alphaW, tileH x tileW x N, C) for input.
6407d246e84SHsiangkai Wang   auto inputType = cast<ShapedType>(transformedInput.getType());
6417d246e84SHsiangkai Wang   assert(inputType.hasStaticShape() && "only support static shapes.");
6427d246e84SHsiangkai Wang   ArrayRef<int64_t> inputShape = inputType.getShape();
6437d246e84SHsiangkai Wang   Type inputElementType = inputType.getElementType();
6447d246e84SHsiangkai Wang   auto inputReassocType = RankedTensorType::get(
6457d246e84SHsiangkai Wang       {inputShape[0] * inputShape[1],
6467d246e84SHsiangkai Wang        inputShape[2] * inputShape[3] * inputShape[4], inputShape[5]},
6477d246e84SHsiangkai Wang       inputElementType);
6487d246e84SHsiangkai Wang   SmallVector<ReassociationIndices> inputReassoc = {{0, 1}, {2, 3, 4}, {5}};
6497d246e84SHsiangkai Wang   Value collapseInput = rewriter.create<tensor::CollapseShapeOp>(
6507d246e84SHsiangkai Wang       loc, inputReassocType, transformedInput, inputReassoc);
6517d246e84SHsiangkai Wang 
6527d246e84SHsiangkai Wang   // Batched matrix multiply.
6537d246e84SHsiangkai Wang   auto matmulType = RankedTensorType::get(
6547d246e84SHsiangkai Wang       {inputShape[0] * inputShape[1],
6557d246e84SHsiangkai Wang        inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]},
6567d246e84SHsiangkai Wang       outputElementType);
657326287fdSThomas Preud'homme   Value empty = rewriter
658326287fdSThomas Preud'homme                     .create<tensor::EmptyOp>(loc, matmulType.getShape(),
659326287fdSThomas Preud'homme                                              outputElementType)
660326287fdSThomas Preud'homme                     .getResult();
661326287fdSThomas Preud'homme   Value zero = rewriter.create<arith::ConstantOp>(
662326287fdSThomas Preud'homme       loc, rewriter.getZeroAttr(outputElementType));
663326287fdSThomas Preud'homme   Value init = rewriter.create<linalg::FillOp>(loc, zero, empty).getResult(0);
6647d246e84SHsiangkai Wang 
6657d246e84SHsiangkai Wang   auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
6667d246e84SHsiangkai Wang       loc, matmulType, ValueRange({collapseInput, collapseFilter}),
6677d246e84SHsiangkai Wang       ValueRange{init});
6687d246e84SHsiangkai Wang 
6697d246e84SHsiangkai Wang   // The result shape of batch matmul is (alphaH x alphaW, tileH x tileW x N, F)
6707d246e84SHsiangkai Wang   // Expand matmul result to (alphaH, alphaW, tileH, tileW, N, F).
6717d246e84SHsiangkai Wang   SmallVector<ReassociationIndices> outputReassoc = {{0, 1}, {2, 3, 4}, {5}};
6727d246e84SHsiangkai Wang   auto outputReassocType =
6737d246e84SHsiangkai Wang       RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2],
6747d246e84SHsiangkai Wang                              inputShape[3], inputShape[4], filterShape[3]},
6757d246e84SHsiangkai Wang                             outputElementType);
6767d246e84SHsiangkai Wang   auto expandOutput = rewriter.create<tensor::ExpandShapeOp>(
6777d246e84SHsiangkai Wang       loc, outputReassocType, matmulOp.getResult(0), outputReassoc);
6787d246e84SHsiangkai Wang   return expandOutput;
6797d246e84SHsiangkai Wang }
6807d246e84SHsiangkai Wang 
68127ee33d1SHsiangkai Wang /// This function transforms the output. The data layout of the output is HWNF.
/// The transformation matrix is 2-dimensional. We need to extract H x W from
/// HWNF first. We generate 4 levels of loops to iterate on tileH, tileW, N
/// and F. After the transformation, we get
68427ee33d1SHsiangkai Wang /// After the transformation, we get
68527ee33d1SHsiangkai Wang ///
68627ee33d1SHsiangkai Wang /// scf.for %h = 0 to tileH step 1
68727ee33d1SHsiangkai Wang ///   scf.for %w = 0 to tileW step 1
68827ee33d1SHsiangkai Wang ///     scf.for %n = 0 to N step 1
68927ee33d1SHsiangkai Wang ///       scf.for %f = 0 to F step 1
69027ee33d1SHsiangkai Wang ///         %extracted = extract %extracted<alphaH x alphaW> from
69127ee33d1SHsiangkai Wang ///                              %input<alphaH x alphaW x tileH x tileW x N x F>
69227ee33d1SHsiangkai Wang ///                              at [0, 0, %h, %w, %n, %f]
69327ee33d1SHsiangkai Wang ///         %ret = linalg.matmul AT, %extracted
69427ee33d1SHsiangkai Wang ///         %ret = linalg.matmul %ret, A
69527ee33d1SHsiangkai Wang ///         %inserted = insert %ret<alphaH x alphaW> into
///                            %output<N x H x W x F>
69727ee33d1SHsiangkai Wang ///                            at [%n, (%h x m), (%w x m), %f]
69827ee33d1SHsiangkai Wang Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
69927ee33d1SHsiangkai Wang                       Value output, int64_t m, int64_t r,
70027ee33d1SHsiangkai Wang                       bool leftTransform = true, bool rightTransform = true) {
70127ee33d1SHsiangkai Wang   // Map from (m, r) to AT transform matrix.
70227ee33d1SHsiangkai Wang   static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
70327ee33d1SHsiangkai Wang       ATMatrices = {
70427ee33d1SHsiangkai Wang           {F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)},
70527ee33d1SHsiangkai Wang           {F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)},
70627ee33d1SHsiangkai Wang           {F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)},
70727ee33d1SHsiangkai Wang       };
70827ee33d1SHsiangkai Wang 
70927ee33d1SHsiangkai Wang   // Map from (m, r) to A transform matrix.
71027ee33d1SHsiangkai Wang   static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
71127ee33d1SHsiangkai Wang       AMatrices = {
71227ee33d1SHsiangkai Wang           {F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)},
71327ee33d1SHsiangkai Wang           {F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)},
71427ee33d1SHsiangkai Wang           {F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)},
71527ee33d1SHsiangkai Wang       };
71627ee33d1SHsiangkai Wang 
71727ee33d1SHsiangkai Wang   auto valueType = cast<ShapedType>(value.getType());
71827ee33d1SHsiangkai Wang   Type elementType = valueType.getElementType();
71927ee33d1SHsiangkai Wang   auto valueShape = valueType.getShape(); // H, W, TileH, TileW, N, F
72027ee33d1SHsiangkai Wang   int64_t valueH = valueShape[0];
72127ee33d1SHsiangkai Wang   int64_t valueW = valueShape[1];
72227ee33d1SHsiangkai Wang   int64_t valueN = valueShape[4];
72327ee33d1SHsiangkai Wang   int64_t valueF = valueShape[5];
72427ee33d1SHsiangkai Wang   int64_t alphaH = leftTransform ? m + r - 1 : 1;
72527ee33d1SHsiangkai Wang   int64_t alphaW = rightTransform ? m + r - 1 : 1;
72627ee33d1SHsiangkai Wang 
72727ee33d1SHsiangkai Wang   if (valueH != alphaH && valueH != 1)
72827ee33d1SHsiangkai Wang     return Value();
72927ee33d1SHsiangkai Wang   if (valueW != alphaW && valueW != 1)
73027ee33d1SHsiangkai Wang     return Value();
73127ee33d1SHsiangkai Wang 
73227ee33d1SHsiangkai Wang   auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
73327ee33d1SHsiangkai Wang                        ValueRange args) -> scf::ValueVector {
734bb4696ceSDmitriy Smirnov     auto context = builder.getContext();
73527ee33d1SHsiangkai Wang     Value tileHIter = ivs[0];
73627ee33d1SHsiangkai Wang     Value tileWIter = ivs[1];
73727ee33d1SHsiangkai Wang     Value NIter = ivs[2];
73827ee33d1SHsiangkai Wang     Value FIter = ivs[3];
73927ee33d1SHsiangkai Wang 
74027ee33d1SHsiangkai Wang     // Extract (H, W) from (H, W, tileH, tileW, N, F).
74127ee33d1SHsiangkai Wang     auto extractValue =
74227ee33d1SHsiangkai Wang         extract2DDataFrom6D(builder, loc, value, tileHIter, tileWIter, NIter,
74327ee33d1SHsiangkai Wang                             FIter, 2, 3, /*loopNorFIdx=*/4,
74427ee33d1SHsiangkai Wang                             /*loopCorFIdx=*/5, /*heightIdx=*/0, /*widthIdx=*/1);
74527ee33d1SHsiangkai Wang 
746bb4696ceSDmitriy Smirnov     const TransformMapKeyTy key = {m, r};
747bb4696ceSDmitriy Smirnov     const TransformMatrix &AMatrix = AMatrices.at(key);
748bb4696ceSDmitriy Smirnov     const TransformMatrix &ATMatrix = ATMatrices.at(key);
749bb4696ceSDmitriy Smirnov     int64_t scalarFactor = (rightTransform ? AMatrix.scalarFactor : 1) *
750bb4696ceSDmitriy Smirnov                            (leftTransform ? ATMatrix.scalarFactor : 1);
751bb4696ceSDmitriy Smirnov     int64_t retCols = rightTransform ? AMatrix.cols : 1;
752bb4696ceSDmitriy Smirnov     int64_t retRows = leftTransform ? ATMatrix.rows : 1;
753bb4696ceSDmitriy Smirnov 
75427ee33d1SHsiangkai Wang     Value matmulRetValue = extractValue;
755326287fdSThomas Preud'homme     Value zero = builder.create<arith::ConstantOp>(
756326287fdSThomas Preud'homme         loc, rewriter.getZeroAttr(elementType));
75727ee33d1SHsiangkai Wang 
758*f20b8e35SDmitriy Smirnov     auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
759bb4696ceSDmitriy Smirnov     auto affineMap =
760bb4696ceSDmitriy Smirnov         AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
761*f20b8e35SDmitriy Smirnov     Value heightOffset = builder.create<affine::AffineApplyOp>(
762*f20b8e35SDmitriy Smirnov         loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
763*f20b8e35SDmitriy Smirnov     Value widthOffset = builder.create<affine::AffineApplyOp>(
764*f20b8e35SDmitriy Smirnov         loc, rightTransform ? affineMap : identityAffineMap, tileWIter);
765bb4696ceSDmitriy Smirnov 
766bb4696ceSDmitriy Smirnov     Value outInitVal =
767bb4696ceSDmitriy Smirnov         extract2DDataFrom4D(builder, loc, args[0], NIter, FIter, heightOffset,
768bb4696ceSDmitriy Smirnov                             widthOffset, retRows, retCols,
769bb4696ceSDmitriy Smirnov                             /*loopNorFIdx=*/0,
770bb4696ceSDmitriy Smirnov                             /*loopCorFIdx=*/3, /*heightIdx=*/1,
771bb4696ceSDmitriy Smirnov                             /*widthIdx=*/2);
772bb4696ceSDmitriy Smirnov     if (leftTransform) {
77327ee33d1SHsiangkai Wang       auto matmulType = RankedTensorType::get({retRows, valueW}, elementType);
774bb4696ceSDmitriy Smirnov       Value init = outInitVal;
775bb4696ceSDmitriy Smirnov       if (rightTransform || scalarFactor != 1) {
776bb4696ceSDmitriy Smirnov         auto empty = builder
777bb4696ceSDmitriy Smirnov                          .create<tensor::EmptyOp>(loc, matmulType.getShape(),
778bb4696ceSDmitriy Smirnov                                                   elementType)
779326287fdSThomas Preud'homme                          .getResult();
780bb4696ceSDmitriy Smirnov         init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
781bb4696ceSDmitriy Smirnov       }
78227ee33d1SHsiangkai Wang 
78327ee33d1SHsiangkai Wang       Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType);
78427ee33d1SHsiangkai Wang       // Multiply AT x m.
78527ee33d1SHsiangkai Wang       auto matmulOp = builder.create<linalg::MatmulOp>(
78627ee33d1SHsiangkai Wang           loc, matmulType, ValueRange{AT, matmulRetValue}, ValueRange{init});
78727ee33d1SHsiangkai Wang       matmulRetValue = matmulOp.getResult(0);
78827ee33d1SHsiangkai Wang     }
78927ee33d1SHsiangkai Wang 
79027ee33d1SHsiangkai Wang     if (rightTransform) {
79127ee33d1SHsiangkai Wang       auto matmulType =
79227ee33d1SHsiangkai Wang           RankedTensorType::get({retRows, AMatrix.cols}, elementType);
793bb4696ceSDmitriy Smirnov       Value init = outInitVal;
794bb4696ceSDmitriy Smirnov       if (scalarFactor != 1) {
795bb4696ceSDmitriy Smirnov         auto empty = builder
796bb4696ceSDmitriy Smirnov                          .create<tensor::EmptyOp>(loc, matmulType.getShape(),
797bb4696ceSDmitriy Smirnov                                                   elementType)
798326287fdSThomas Preud'homme                          .getResult();
799bb4696ceSDmitriy Smirnov         init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
800bb4696ceSDmitriy Smirnov       }
80127ee33d1SHsiangkai Wang 
80227ee33d1SHsiangkai Wang       Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType);
80327ee33d1SHsiangkai Wang       // Multiply y = (AT x m) x A.
80427ee33d1SHsiangkai Wang       auto matmulOp = builder.create<linalg::MatmulOp>(
80527ee33d1SHsiangkai Wang           loc, matmulType, ValueRange{matmulRetValue, A}, ValueRange{init});
80627ee33d1SHsiangkai Wang       matmulRetValue = matmulOp.getResult(0);
80727ee33d1SHsiangkai Wang     }
80827ee33d1SHsiangkai Wang 
809bb4696ceSDmitriy Smirnov     if (scalarFactor != 1) {
810bb4696ceSDmitriy Smirnov       // Multiply by scalar factor and add outInitVal.
811bb4696ceSDmitriy Smirnov       Value scalarFactorValue = builder.create<arith::ConstantOp>(
812bb4696ceSDmitriy Smirnov           loc, FloatAttr::get(elementType, scalarFactor));
81327ee33d1SHsiangkai Wang       auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
81427ee33d1SHsiangkai Wang       auto identityAffineMap = rewriter.getMultiDimIdentityMap(2);
81527ee33d1SHsiangkai Wang       SmallVector<AffineMap> affineMaps = {
816bb4696ceSDmitriy Smirnov           AffineMap::get(2, 0, context), identityAffineMap, identityAffineMap};
817bb4696ceSDmitriy Smirnov 
818bb4696ceSDmitriy Smirnov       matmulRetValue =
81927ee33d1SHsiangkai Wang           rewriter
82027ee33d1SHsiangkai Wang               .create<linalg::GenericOp>(
821bb4696ceSDmitriy Smirnov                   loc, matmulType,
822bb4696ceSDmitriy Smirnov                   ValueRange{scalarFactorValue, matmulRetValue},
823bb4696ceSDmitriy Smirnov                   ValueRange{outInitVal}, affineMaps,
82427ee33d1SHsiangkai Wang                   llvm::ArrayRef<utils::IteratorType>{
82527ee33d1SHsiangkai Wang                       utils::IteratorType::parallel,
82627ee33d1SHsiangkai Wang                       utils::IteratorType::parallel},
82727ee33d1SHsiangkai Wang                   [&](OpBuilder &nestedBuilder, Location nestedLoc,
82827ee33d1SHsiangkai Wang                       ValueRange args) {
829bb4696ceSDmitriy Smirnov                     auto mulf = nestedBuilder.create<arith::MulFOp>(
830bb4696ceSDmitriy Smirnov                         nestedLoc, args[0], args[1]);
831bb4696ceSDmitriy Smirnov                     auto addf = nestedBuilder.create<arith::AddFOp>(
832bb4696ceSDmitriy Smirnov                         nestedLoc, mulf.getResult(), args[2]);
833bb4696ceSDmitriy Smirnov                     nestedBuilder.create<linalg::YieldOp>(nestedLoc,
834bb4696ceSDmitriy Smirnov                                                           addf.getResult());
83527ee33d1SHsiangkai Wang                   })
83627ee33d1SHsiangkai Wang               .getResult(0);
83727ee33d1SHsiangkai Wang     }
83827ee33d1SHsiangkai Wang 
83927ee33d1SHsiangkai Wang     // Insert (H, W) to (N, H, W, F).
84027ee33d1SHsiangkai Wang     Value combinedVal =
84127ee33d1SHsiangkai Wang         insert2DDataTo4D(builder, loc, matmulRetValue, args[0], NIter, FIter,
84227ee33d1SHsiangkai Wang                          heightOffset, widthOffset, retRows, retCols,
84327ee33d1SHsiangkai Wang                          /*loopNorFIdx=*/0,
84427ee33d1SHsiangkai Wang                          /*loopCorFIdx=*/3, /*heightIdx=*/1,
84527ee33d1SHsiangkai Wang                          /*widthIdx=*/2);
84627ee33d1SHsiangkai Wang 
84727ee33d1SHsiangkai Wang     return {combinedVal};
84827ee33d1SHsiangkai Wang   };
84927ee33d1SHsiangkai Wang 
85027ee33d1SHsiangkai Wang   int64_t tilwH = valueShape[2];
85127ee33d1SHsiangkai Wang   int64_t tileW = valueShape[3];
85227ee33d1SHsiangkai Wang   auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
85327ee33d1SHsiangkai Wang   auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tilwH);
85427ee33d1SHsiangkai Wang   auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW);
85527ee33d1SHsiangkai Wang   auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueN);
85627ee33d1SHsiangkai Wang   auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueF);
85727ee33d1SHsiangkai Wang   auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
85827ee33d1SHsiangkai Wang   scf::LoopNest loops = scf::buildLoopNest(
85927ee33d1SHsiangkai Wang       rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
86027ee33d1SHsiangkai Wang       {tileHBound, tileWBound, nUpperBound, fUpperBound},
86127ee33d1SHsiangkai Wang       {oneStep, oneStep, oneStep, oneStep}, {output}, buildBody);
86227ee33d1SHsiangkai Wang   return loops.results[0];
86327ee33d1SHsiangkai Wang }
86427ee33d1SHsiangkai Wang 
8657d246e84SHsiangkai Wang /// Create an empty tensor with alignedType and insert the value into the
8667d246e84SHsiangkai Wang /// created empty tensor with aligned size.
8677d246e84SHsiangkai Wang static Value padToAlignedTensor(RewriterBase &rewriter, Location loc,
8687d246e84SHsiangkai Wang                                 Value value, ArrayRef<int64_t> alignedShape) {
8697d246e84SHsiangkai Wang   auto valueType = cast<ShapedType>(value.getType());
8707d246e84SHsiangkai Wang   Type elementType = valueType.getElementType();
8717d246e84SHsiangkai Wang   auto alignedType = RankedTensorType::get(alignedShape, elementType);
8727d246e84SHsiangkai Wang   Value padValue = rewriter.create<arith::ConstantOp>(
8737d246e84SHsiangkai Wang       loc, elementType, rewriter.getZeroAttr(elementType));
8747d246e84SHsiangkai Wang 
8757d246e84SHsiangkai Wang   return linalg::makeComposedPadHighOp(rewriter, loc, alignedType, value,
8767d246e84SHsiangkai Wang                                        padValue, false);
8777d246e84SHsiangkai Wang }
8787d246e84SHsiangkai Wang 
8797d246e84SHsiangkai Wang /// Extract sub-tensor with extractedType from value.
8807d246e84SHsiangkai Wang static Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
8817d246e84SHsiangkai Wang                                       Value value,
8827d246e84SHsiangkai Wang                                       RankedTensorType extractedType) {
8837d246e84SHsiangkai Wang   OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
8847d246e84SHsiangkai Wang   OpFoldResult oneIndex = rewriter.getIndexAttr(1);
8857d246e84SHsiangkai Wang   SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
8867d246e84SHsiangkai Wang   SmallVector<OpFoldResult, 4> strides(4, oneIndex);
8877d246e84SHsiangkai Wang 
8887d246e84SHsiangkai Wang   ArrayRef<int64_t> extractedShape = extractedType.getShape();
8897d246e84SHsiangkai Wang   SmallVector<OpFoldResult> sizes =
8907d246e84SHsiangkai Wang       getAsOpFoldResult(rewriter.getI64ArrayAttr(extractedShape));
8917d246e84SHsiangkai Wang 
8927d246e84SHsiangkai Wang   return rewriter.create<tensor::ExtractSliceOp>(loc, extractedType, value,
8937d246e84SHsiangkai Wang                                                  offsets, sizes, strides);
8947d246e84SHsiangkai Wang }
8957d246e84SHsiangkai Wang 
8967d246e84SHsiangkai Wang /// Utility function to check all values in the attribute are 1.
8977d246e84SHsiangkai Wang static bool hasAllOneValues(DenseIntElementsAttr attr) {
8987d246e84SHsiangkai Wang   return llvm::all_of(
8997d246e84SHsiangkai Wang       attr, [](const APInt &element) { return element.getSExtValue() == 1; });
9007d246e84SHsiangkai Wang }
9017d246e84SHsiangkai Wang 
9027d246e84SHsiangkai Wang /// A helper function to convert linalg.conv_2d_nhwc_fhwc to
9037d246e84SHsiangkai Wang /// linalg.winograd_*_transform ops.
9047d246e84SHsiangkai Wang static FailureOr<Operation *>
9057d246e84SHsiangkai Wang winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
9067d246e84SHsiangkai Wang                      int64_t m, int64_t r) {
9077d246e84SHsiangkai Wang   Value input = convOp.getInputs()[0];
9087d246e84SHsiangkai Wang   Value filter = convOp.getInputs()[1];
9097d246e84SHsiangkai Wang   Value output = convOp.getOutputs()[0];
9107d246e84SHsiangkai Wang   auto inputType = cast<ShapedType>(input.getType());
9117d246e84SHsiangkai Wang   auto filterType = cast<ShapedType>(filter.getType());
9127d246e84SHsiangkai Wang   auto outputType = cast<ShapedType>(output.getType());
9137d246e84SHsiangkai Wang 
9147d246e84SHsiangkai Wang   if (!inputType.hasStaticShape())
9157d246e84SHsiangkai Wang     return rewriter.notifyMatchFailure(convOp,
9167d246e84SHsiangkai Wang                                        "expected a static shape for the input");
9177d246e84SHsiangkai Wang 
9187d246e84SHsiangkai Wang   if (!filterType.hasStaticShape())
9197d246e84SHsiangkai Wang     return rewriter.notifyMatchFailure(
9207d246e84SHsiangkai Wang         convOp, "expected a static shape for the filter");
9217d246e84SHsiangkai Wang 
9227d246e84SHsiangkai Wang   if (!hasAllOneValues(convOp.getDilations()))
9237d246e84SHsiangkai Wang     return rewriter.notifyMatchFailure(convOp,
9247d246e84SHsiangkai Wang                                        "expected all ones for dilations");
9257d246e84SHsiangkai Wang 
9267d246e84SHsiangkai Wang   if (!hasAllOneValues(convOp.getStrides()))
9277d246e84SHsiangkai Wang     return rewriter.notifyMatchFailure(convOp, "expected all ones for strides");
9287d246e84SHsiangkai Wang 
9297d246e84SHsiangkai Wang   ArrayRef<int64_t> filterShape = filterType.getShape();
9307d246e84SHsiangkai Wang   int64_t filterF = filterShape[0];
9317d246e84SHsiangkai Wang   int64_t filterH = filterShape[1];
9327d246e84SHsiangkai Wang   int64_t filterW = filterShape[2];
9337d246e84SHsiangkai Wang   int64_t filterC = filterShape[3];
9347d246e84SHsiangkai Wang   ArrayRef<int64_t> inputShape = inputType.getShape();
9357d246e84SHsiangkai Wang   int64_t inputN = inputShape[0];
9367d246e84SHsiangkai Wang   int64_t inputH = inputShape[1];
9377d246e84SHsiangkai Wang   int64_t inputW = inputShape[2];
9387d246e84SHsiangkai Wang   int64_t inputC = inputShape[3];
9397d246e84SHsiangkai Wang   ArrayRef<int64_t> outputShape = outputType.getShape();
9407d246e84SHsiangkai Wang   int64_t outputN = outputShape[0];
9417d246e84SHsiangkai Wang   int64_t outputH = outputShape[1];
9427d246e84SHsiangkai Wang   int64_t outputW = outputShape[2];
9437d246e84SHsiangkai Wang   int64_t outputF = outputShape[3];
9447d246e84SHsiangkai Wang 
9457d246e84SHsiangkai Wang   // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r).
9467d246e84SHsiangkai Wang   bool isSupportedFilter = false;
9477d246e84SHsiangkai Wang   if (filterH == filterW && filterH == r)
9487d246e84SHsiangkai Wang     isSupportedFilter = true;
9497d246e84SHsiangkai Wang   if (filterH == r && filterW == 1)
9507d246e84SHsiangkai Wang     isSupportedFilter = true;
9517d246e84SHsiangkai Wang   if (filterH == 1 && filterW == r)
9527d246e84SHsiangkai Wang     isSupportedFilter = true;
9537d246e84SHsiangkai Wang 
9547d246e84SHsiangkai Wang   if (!isSupportedFilter)
9557d246e84SHsiangkai Wang     return rewriter.notifyMatchFailure(
9567d246e84SHsiangkai Wang         convOp, "only support filter (r x r), (r x 1) or (1 x r)");
9577d246e84SHsiangkai Wang 
9587d246e84SHsiangkai Wang   // Currently, we support (m, r) = (2, 3) or (4, 3) or (2, 5).
9597d246e84SHsiangkai Wang   static const llvm::SmallVector<TransformMapKeyTy, 3> validConfigs = {
9607d246e84SHsiangkai Wang       F_2_3, F_4_3, F_2_5};
9617d246e84SHsiangkai Wang 
9627d246e84SHsiangkai Wang   TransformMapKeyTy key = {m, r};
9637d246e84SHsiangkai Wang   auto it = std::find(validConfigs.begin(), validConfigs.end(), key);
9647d246e84SHsiangkai Wang   // If we cannot find the constant transformation matrix, it means we do
9657d246e84SHsiangkai Wang   // not support this configuration yet.
9667d246e84SHsiangkai Wang   if (it == validConfigs.end())
9677d246e84SHsiangkai Wang     return failure();
9687d246e84SHsiangkai Wang 
9697d246e84SHsiangkai Wang   // All the criterias are satisfied. We can do Winograd Conv2D.
9707d246e84SHsiangkai Wang   Location loc = convOp.getLoc();
9717d246e84SHsiangkai Wang 
9727d246e84SHsiangkai Wang   // For F(m x 1, r x 1), we only need to do left side transform.
9737d246e84SHsiangkai Wang   bool leftTransform = filterH != 1;
9747d246e84SHsiangkai Wang   // For F(1 x m, 1 x r), we only need to do right side transform.
9757d246e84SHsiangkai Wang   bool rightTransform = filterW != 1;
9767d246e84SHsiangkai Wang   int64_t heightM = leftTransform ? m : 1;
9777d246e84SHsiangkai Wang   int64_t widthM = rightTransform ? m : 1;
9787d246e84SHsiangkai Wang   int64_t heightR = leftTransform ? r : 1;
9797d246e84SHsiangkai Wang   int64_t widthR = rightTransform ? r : 1;
9807d246e84SHsiangkai Wang 
9817d246e84SHsiangkai Wang   // --- Create operation for filter transform ---
9827d246e84SHsiangkai Wang   Type filterElementType = filterType.getElementType();
9837d246e84SHsiangkai Wang   int64_t alphaH = heightM + heightR - 1;
9847d246e84SHsiangkai Wang   int64_t alphaW = widthM + widthR - 1;
9857d246e84SHsiangkai Wang   int64_t tileH = llvm::divideCeilSigned(outputH, heightM);
9867d246e84SHsiangkai Wang   int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
9877d246e84SHsiangkai Wang   auto retType = RankedTensorType::get({alphaH, alphaW, filterC, filterF},
9887d246e84SHsiangkai Wang                                        filterElementType);
9897d246e84SHsiangkai Wang   Value retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
9907d246e84SHsiangkai Wang                                                     filterElementType);
9917d246e84SHsiangkai Wang   auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
9927d246e84SHsiangkai Wang       loc, retType, filter, retValue, m, r);
9937d246e84SHsiangkai Wang 
9947d246e84SHsiangkai Wang   // --- Create operation for input transform ---
9957d246e84SHsiangkai Wang 
9967d246e84SHsiangkai Wang   // When input size - (r - 1) is not aligned with output tile size, we need to
9977d246e84SHsiangkai Wang   // pad the input data to create the full tiles as tiling.
9987d246e84SHsiangkai Wang   Type inputElementType = inputType.getElementType();
9997d246e84SHsiangkai Wang   int64_t alignedInputH = tileH * heightM + (heightR - 1);
10007d246e84SHsiangkai Wang   int64_t alignedInputW = tileW * widthM + (widthR - 1);
10017d246e84SHsiangkai Wang   if (alignedInputH != inputH || alignedInputW != inputW) {
10027d246e84SHsiangkai Wang     input = padToAlignedTensor(rewriter, loc, input,
10037d246e84SHsiangkai Wang                                {inputN, alignedInputH, alignedInputW, inputC});
10047d246e84SHsiangkai Wang   }
10057d246e84SHsiangkai Wang 
10067d246e84SHsiangkai Wang   retType = RankedTensorType::get(
10077d246e84SHsiangkai Wang       {alphaH, alphaW, tileH, tileW, inputN, inputC}, inputElementType);
10087d246e84SHsiangkai Wang   retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
10097d246e84SHsiangkai Wang                                               inputElementType);
10107d246e84SHsiangkai Wang   auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>(
10117d246e84SHsiangkai Wang       loc, retType, input, retValue, m, r);
10127d246e84SHsiangkai Wang 
10137d246e84SHsiangkai Wang   Type outputElementType = outputType.getElementType();
10147d246e84SHsiangkai Wang   Value matmulRet = matrixMultiply(rewriter, loc, transformedFilter,
10157d246e84SHsiangkai Wang                                    transformedInput, outputElementType);
10167d246e84SHsiangkai Wang 
10177d246e84SHsiangkai Wang   // --- Create operation for output transform ---
10187d246e84SHsiangkai Wang 
10197d246e84SHsiangkai Wang   // When output size is not aligned with output tile size, we need to pad the
10207d246e84SHsiangkai Wang   // output buffer to insert the full tiles after tiling.
10217d246e84SHsiangkai Wang   int64_t alignedOutputH = tileH * heightM;
10227d246e84SHsiangkai Wang   int64_t alignedOutputW = tileW * widthM;
10237d246e84SHsiangkai Wang   bool isOutputUnaligned =
10247d246e84SHsiangkai Wang       ((alignedOutputH != outputH) || (alignedOutputW != outputW));
10257d246e84SHsiangkai Wang   if (isOutputUnaligned) {
10267d246e84SHsiangkai Wang     auto alignedOutputType = RankedTensorType::get(
10277d246e84SHsiangkai Wang         {outputN, alignedOutputH, alignedOutputW, outputF}, outputElementType);
10287d246e84SHsiangkai Wang     output =
10297d246e84SHsiangkai Wang         padToAlignedTensor(rewriter, loc, output, alignedOutputType.getShape());
10307d246e84SHsiangkai Wang     outputType = alignedOutputType;
10317d246e84SHsiangkai Wang   }
10327d246e84SHsiangkai Wang 
10337d246e84SHsiangkai Wang   Value transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>(
10347d246e84SHsiangkai Wang       loc, outputType, matmulRet, output, m, r);
10357d246e84SHsiangkai Wang 
10367d246e84SHsiangkai Wang   // When output size is not aligned with output tile size, extract the
10377d246e84SHsiangkai Wang   // value from the padded buffer.
10387d246e84SHsiangkai Wang   if (isOutputUnaligned) {
10397d246e84SHsiangkai Wang     transformedOutput = extractFromAlignedTensor(
10407d246e84SHsiangkai Wang         rewriter, loc, transformedOutput,
10417d246e84SHsiangkai Wang         RankedTensorType::get({outputN, outputH, outputW, outputF},
10427d246e84SHsiangkai Wang                               outputElementType));
10437d246e84SHsiangkai Wang   }
10447d246e84SHsiangkai Wang 
10457d246e84SHsiangkai Wang   rewriter.replaceOp(convOp, transformedOutput);
10467d246e84SHsiangkai Wang 
10477d246e84SHsiangkai Wang   return transformedOutput.getDefiningOp();
10487d246e84SHsiangkai Wang }
10497d246e84SHsiangkai Wang 
105027ee33d1SHsiangkai Wang /// A helper function to decompose linalg.winograd_filter_transform.
105127ee33d1SHsiangkai Wang FailureOr<Operation *>
105227ee33d1SHsiangkai Wang decomposeWinogradFilterTransformHelper(RewriterBase &rewriter,
105327ee33d1SHsiangkai Wang                                        linalg::WinogradFilterTransformOp op) {
105427ee33d1SHsiangkai Wang   Location loc = op.getLoc();
105527ee33d1SHsiangkai Wang   Value filter = op.getFilter();
105627ee33d1SHsiangkai Wang   auto filterType = cast<ShapedType>(filter.getType());
105727ee33d1SHsiangkai Wang   auto filterShape = filterType.getShape();
105827ee33d1SHsiangkai Wang   int64_t filterH = filterShape[1];
105927ee33d1SHsiangkai Wang   int64_t filterW = filterShape[2];
106027ee33d1SHsiangkai Wang 
106127ee33d1SHsiangkai Wang   // For F(m x 1, r x 1), we only need to do left side transform.
106227ee33d1SHsiangkai Wang   bool leftTransform = filterH != 1;
106327ee33d1SHsiangkai Wang   // For F(1 x m, 1 x r), we only need to do right side transform.
106427ee33d1SHsiangkai Wang   bool rightTransform = filterW != 1;
106527ee33d1SHsiangkai Wang   Value transformedFilter =
106627ee33d1SHsiangkai Wang       filterTransform(rewriter, loc, filter, op.getOutput(), op.getM(),
106727ee33d1SHsiangkai Wang                       op.getR(), leftTransform, rightTransform);
106827ee33d1SHsiangkai Wang   if (!transformedFilter)
106927ee33d1SHsiangkai Wang     return failure();
107027ee33d1SHsiangkai Wang 
107127ee33d1SHsiangkai Wang   rewriter.replaceOp(op, transformedFilter);
107227ee33d1SHsiangkai Wang 
107327ee33d1SHsiangkai Wang   return transformedFilter.getDefiningOp();
107427ee33d1SHsiangkai Wang }
107527ee33d1SHsiangkai Wang 
107627ee33d1SHsiangkai Wang /// A helper function to decompose linalg.winograd_input_transform.
107727ee33d1SHsiangkai Wang FailureOr<Operation *>
107827ee33d1SHsiangkai Wang decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
107927ee33d1SHsiangkai Wang                                       linalg::WinogradInputTransformOp op) {
108027ee33d1SHsiangkai Wang   Location loc = op.getLoc();
1081*f20b8e35SDmitriy Smirnov   Value output = op.getOutput();
1082*f20b8e35SDmitriy Smirnov   auto outputType = cast<ShapedType>(output.getType());
1083*f20b8e35SDmitriy Smirnov   auto outputShape = outputType.getShape();
1084*f20b8e35SDmitriy Smirnov 
1085*f20b8e35SDmitriy Smirnov   int64_t outputH = outputShape[0];
1086*f20b8e35SDmitriy Smirnov   int64_t outputW = outputShape[1];
108727ee33d1SHsiangkai Wang 
108827ee33d1SHsiangkai Wang   // For F(m x 1, r x 1), we only need to do left side transform.
1089*f20b8e35SDmitriy Smirnov   bool leftTransform = outputH != 1;
109027ee33d1SHsiangkai Wang   // For F(1 x m, 1 x r), we only need to do right side transform.
1091*f20b8e35SDmitriy Smirnov   bool rightTransform = outputW != 1;
109227ee33d1SHsiangkai Wang   Value transformedInput =
109327ee33d1SHsiangkai Wang       inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getM(),
109427ee33d1SHsiangkai Wang                      op.getR(), leftTransform, rightTransform);
109527ee33d1SHsiangkai Wang   if (!transformedInput)
109627ee33d1SHsiangkai Wang     return failure();
109727ee33d1SHsiangkai Wang 
109827ee33d1SHsiangkai Wang   rewriter.replaceOp(op, transformedInput);
109927ee33d1SHsiangkai Wang 
110027ee33d1SHsiangkai Wang   return transformedInput.getDefiningOp();
110127ee33d1SHsiangkai Wang }
110227ee33d1SHsiangkai Wang 
110327ee33d1SHsiangkai Wang /// A helper function to decompose linalg.winograd_output_transform.
110427ee33d1SHsiangkai Wang FailureOr<Operation *>
110527ee33d1SHsiangkai Wang decomposeWinogradOutputTransformHelper(RewriterBase &rewriter,
110627ee33d1SHsiangkai Wang                                        linalg::WinogradOutputTransformOp op) {
110727ee33d1SHsiangkai Wang   Location loc = op.getLoc();
110827ee33d1SHsiangkai Wang   Value value = op.getValue();
110927ee33d1SHsiangkai Wang   auto valueType = cast<ShapedType>(value.getType());
111027ee33d1SHsiangkai Wang   auto valueShape = valueType.getShape();
111127ee33d1SHsiangkai Wang   int64_t valueH = valueShape[0];
111227ee33d1SHsiangkai Wang   int64_t valueW = valueShape[1];
111327ee33d1SHsiangkai Wang 
111427ee33d1SHsiangkai Wang   // For F(m x 1, r x 1), we only need to do left side transform.
111527ee33d1SHsiangkai Wang   bool leftTransform = valueH != 1;
111627ee33d1SHsiangkai Wang   // For F(1 x m, 1 x r), we only need to do right side transform.
111727ee33d1SHsiangkai Wang   bool rightTransform = valueW != 1;
111827ee33d1SHsiangkai Wang   Value transformedOutput =
111927ee33d1SHsiangkai Wang       outputTransform(rewriter, loc, value, op.getOutput(), op.getM(),
112027ee33d1SHsiangkai Wang                       op.getR(), leftTransform, rightTransform);
112127ee33d1SHsiangkai Wang   if (!transformedOutput)
112227ee33d1SHsiangkai Wang     return failure();
112327ee33d1SHsiangkai Wang 
112427ee33d1SHsiangkai Wang   rewriter.replaceOp(op, transformedOutput);
112527ee33d1SHsiangkai Wang 
112627ee33d1SHsiangkai Wang   return transformedOutput.getDefiningOp();
112727ee33d1SHsiangkai Wang }
112827ee33d1SHsiangkai Wang 
112927ee33d1SHsiangkai Wang /// A rewrite pattern to decompose linalg.winograd_filter_transform operations.
113027ee33d1SHsiangkai Wang class DecomposeWinogradFilterTransform final
113127ee33d1SHsiangkai Wang     : public OpRewritePattern<linalg::WinogradFilterTransformOp> {
113227ee33d1SHsiangkai Wang public:
113327ee33d1SHsiangkai Wang   using OpRewritePattern::OpRewritePattern;
113427ee33d1SHsiangkai Wang 
113527ee33d1SHsiangkai Wang   LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op,
113627ee33d1SHsiangkai Wang                                 PatternRewriter &rewriter) const override {
113727ee33d1SHsiangkai Wang     return decomposeWinogradFilterTransformHelper(rewriter, op);
113827ee33d1SHsiangkai Wang   }
113927ee33d1SHsiangkai Wang };
114027ee33d1SHsiangkai Wang 
114127ee33d1SHsiangkai Wang /// A rewrite pattern to decompose linalg.winograd_input_transform operations.
114227ee33d1SHsiangkai Wang class DecomposeWinogradInputTransform final
114327ee33d1SHsiangkai Wang     : public OpRewritePattern<linalg::WinogradInputTransformOp> {
114427ee33d1SHsiangkai Wang public:
114527ee33d1SHsiangkai Wang   using OpRewritePattern::OpRewritePattern;
114627ee33d1SHsiangkai Wang 
114727ee33d1SHsiangkai Wang   LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op,
114827ee33d1SHsiangkai Wang                                 PatternRewriter &rewriter) const override {
114927ee33d1SHsiangkai Wang     return decomposeWinogradInputTransformHelper(rewriter, op);
115027ee33d1SHsiangkai Wang   }
115127ee33d1SHsiangkai Wang };
115227ee33d1SHsiangkai Wang 
115327ee33d1SHsiangkai Wang /// A rewrite pattern to decompose linalg.winograd_output_transform operations.
115427ee33d1SHsiangkai Wang class DecomposeWinogradOutputTransform final
115527ee33d1SHsiangkai Wang     : public OpRewritePattern<linalg::WinogradOutputTransformOp> {
115627ee33d1SHsiangkai Wang public:
115727ee33d1SHsiangkai Wang   using OpRewritePattern::OpRewritePattern;
115827ee33d1SHsiangkai Wang 
115927ee33d1SHsiangkai Wang   LogicalResult matchAndRewrite(linalg::WinogradOutputTransformOp op,
116027ee33d1SHsiangkai Wang                                 PatternRewriter &rewriter) const override {
116127ee33d1SHsiangkai Wang     return decomposeWinogradOutputTransformHelper(rewriter, op);
116227ee33d1SHsiangkai Wang   }
116327ee33d1SHsiangkai Wang };
116427ee33d1SHsiangkai Wang 
11657d246e84SHsiangkai Wang /// A rewrite pattern for Winograd Conv2D algorithm.
11667d246e84SHsiangkai Wang class WinogradConv2DNhwcFhwc final
11677d246e84SHsiangkai Wang     : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
11687d246e84SHsiangkai Wang public:
11697d246e84SHsiangkai Wang   using OpRewritePattern::OpRewritePattern;
11707d246e84SHsiangkai Wang   WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, int64_t m, int64_t r)
11717d246e84SHsiangkai Wang       : OpRewritePattern(context), m(m), r(r) {}
11727d246e84SHsiangkai Wang 
11737d246e84SHsiangkai Wang   LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
11747d246e84SHsiangkai Wang                                 PatternRewriter &rewriter) const override {
11757d246e84SHsiangkai Wang     if (failed(winogradConv2DHelper(rewriter, convOp, m, r)))
11767d246e84SHsiangkai Wang       return failure();
11777d246e84SHsiangkai Wang 
11787d246e84SHsiangkai Wang     return success();
11797d246e84SHsiangkai Wang   }
11807d246e84SHsiangkai Wang 
11817d246e84SHsiangkai Wang private:
11827d246e84SHsiangkai Wang   int64_t m;
11837d246e84SHsiangkai Wang   int64_t r;
11847d246e84SHsiangkai Wang };
11857d246e84SHsiangkai Wang } // end anonymous namespace
11867d246e84SHsiangkai Wang 
11877d246e84SHsiangkai Wang //===----------------------------------------------------------------------===//
1188d9c26b9dSHsiangkai Wang FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
1189d9c26b9dSHsiangkai Wang                                       linalg::Conv2DNhwcFhwcOp op, int64_t m,
1190d9c26b9dSHsiangkai Wang                                       int64_t r) {
1191d9c26b9dSHsiangkai Wang   return winogradConv2DHelper(rewriter, op, m, r);
1192d9c26b9dSHsiangkai Wang }
1193d9c26b9dSHsiangkai Wang 
1194c4bf9491SHsiangkai Wang FailureOr<Operation *>
1195c4bf9491SHsiangkai Wang decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
1196c4bf9491SHsiangkai Wang                                    linalg::WinogradFilterTransformOp op) {
1197c4bf9491SHsiangkai Wang   return decomposeWinogradFilterTransformHelper(rewriter, op);
1198c4bf9491SHsiangkai Wang }
1199c4bf9491SHsiangkai Wang 
1200c4bf9491SHsiangkai Wang FailureOr<Operation *>
1201c4bf9491SHsiangkai Wang decomposeWinogradInputTransformOp(RewriterBase &rewriter,
1202c4bf9491SHsiangkai Wang                                   linalg::WinogradInputTransformOp op) {
1203c4bf9491SHsiangkai Wang   return decomposeWinogradInputTransformHelper(rewriter, op);
1204c4bf9491SHsiangkai Wang }
1205c4bf9491SHsiangkai Wang 
1206c4bf9491SHsiangkai Wang FailureOr<Operation *>
1207c4bf9491SHsiangkai Wang decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
1208c4bf9491SHsiangkai Wang                                    linalg::WinogradOutputTransformOp op) {
1209c4bf9491SHsiangkai Wang   return decomposeWinogradOutputTransformHelper(rewriter, op);
1210c4bf9491SHsiangkai Wang }
1211c4bf9491SHsiangkai Wang 
12127d246e84SHsiangkai Wang void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
12137d246e84SHsiangkai Wang                                     int64_t r) {
12147d246e84SHsiangkai Wang   MLIRContext *context = patterns.getContext();
12157d246e84SHsiangkai Wang   // TODO: Support more Conv2D data layout, e.g., conv_2d_nchw_fchw
12167d246e84SHsiangkai Wang   patterns.insert<WinogradConv2DNhwcFhwc>(context, m, r);
12177d246e84SHsiangkai Wang }
12187d246e84SHsiangkai Wang 
121927ee33d1SHsiangkai Wang void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns) {
122027ee33d1SHsiangkai Wang   MLIRContext *context = patterns.getContext();
122127ee33d1SHsiangkai Wang   patterns
122227ee33d1SHsiangkai Wang       .insert<DecomposeWinogradFilterTransform, DecomposeWinogradInputTransform,
122327ee33d1SHsiangkai Wang               DecomposeWinogradOutputTransform>(context);
122427ee33d1SHsiangkai Wang }
122527ee33d1SHsiangkai Wang 
12267d246e84SHsiangkai Wang } // end namespace linalg
12277d246e84SHsiangkai Wang } // end namespace mlir
1228