//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implement Winograd Conv2D algorithm. The implementation is based on the
// paper: Fast Algorithms for Convolutional Neural Networks
// (https://arxiv.org/abs/1509.09308)
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/MathExtras.h"

namespace mlir {
namespace linalg {

namespace {

// clang-format off
/// Winograd Conv2D uses a minimal 2D filtering algorithm to calculate its
/// result. The formula of minimal 2D filtering algorithm F(m x m, r x r),
/// m is the output dimension and r is the filter dimension, is
///
/// Y = A^T x [ (G x g x G^T) x (B^T x d x B) ] x A
///
/// g is filter and d is input data. We need to prepare 6 constant
/// transformation matrices, G, G^T, B^T, B, A^T, and A for this formula.
///
/// The following tables define these constant transformation matrices for
/// F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5).
///
/// All tables are stored row-major; with alpha = m + r - 1, the shapes are:
///   G:  alpha x r      (filter transform, left)
///   GT: r x alpha      (filter transform, right)
///   BT: alpha x alpha  (input transform, left)
///   B:  alpha x alpha  (input transform, right)
///   AT: m x alpha      (output transform, left)
///   A:  alpha x m      (output transform, right)

// G for F(2x2, 3x3): 4 x 3.
constexpr float G_2x2_3x3[] = {
   -1,     0,   0,
 1./2, -1./2, 1./2,
 1./2,  1./2, 1./2,
    0,     0,   1
};

// G^T for F(2x2, 3x3): 3 x 4.
constexpr float GT_2x2_3x3[] = {
   -1,  1./2, 1./2, 0,
    0, -1./2, 1./2, 0,
    0,  1./2, 1./2, 1
};

// B^T for F(2x2, 3x3): 4 x 4.
constexpr float BT_2x2_3x3[] = {
   -1,    0,   1,   0,
    0,   -1,   1,   0,
    0,    1,   1,   0,
    0,   -1,   0,   1
};

// B for F(2x2, 3x3): 4 x 4.
constexpr float B_2x2_3x3[] = {
   -1,    0,   0,   0,
    0,   -1,   1,  -1,
    1,    1,   1,   0,
    0,    0,   0,   1
};

// A^T for F(2x2, 3x3): 2 x 4.
constexpr float AT_2x2_3x3[] = {
    1,    1,   1,   0,
    0,   -1,   1,   1
};

// A for F(2x2, 3x3): 4 x 2.
constexpr float A_2x2_3x3[] = {
    1,    0,
    1,   -1,
    1,    1,
    0,    1
};

// G for F(4x4, 3x3): 6 x 3.
constexpr float G_4x4_3x3[] = {
     1,     0,     0,
 -1./3,  1./3, -1./3,
 -1./3, -1./3, -1./3,
 1./12, -1./6,  1./3,
 1./12,  1./6,  1./3,
     0,     0,     1
};

// G^T for F(4x4, 3x3): 3 x 6.
constexpr float GT_4x4_3x3[] = {
 1, -1./3, -1./3, 1./12, 1./12, 0,
 0,  1./3, -1./3, -1./6,  1./6, 0,
 0, -1./3, -1./3,  1./3,  1./3, 1
};

// B^T for F(4x4, 3x3): 6 x 6.
constexpr float BT_4x4_3x3[] = {
 1./4,     0, -5./16,      0, 1./16,     0,
    0,  1./4,  -1./4, -1./16, 1./16,     0,
    0, -1./4,  -1./4,  1./16, 1./16,     0,
    0,  1./4,  -1./8,  -1./4,  1./8,     0,
    0, -1./4,  -1./8,   1./4,  1./8,     0,
    0,  1./4,      0, -5./16,     0, 1./16
};

// B for F(4x4, 3x3): 6 x 6.
constexpr float B_4x4_3x3[] = {
   1./4,      0,     0,     0,     0,      0,
      0,   1./4, -1./4,  1./4, -1./4,   1./4,
 -5./16,  -1./4, -1./4, -1./8, -1./8,      0,
      0, -1./16, 1./16, -1./4,  1./4, -5./16,
  1./16,  1./16, 1./16,  1./8,  1./8,      0,
      0,      0,     0,     0,     0,  1./16
};

// A^T for F(4x4, 3x3): 4 x 6.
constexpr float AT_4x4_3x3[] = {
 1./8,  1./4, 1./4,  1./8, 1./8,    0,
    0, -1./4, 1./4, -1./4, 1./4,    0,
    0,  1./4, 1./4,  1./2, 1./2,    0,
    0, -1./4, 1./4,    -1,    1, 1./2
};

// A for F(4x4, 3x3): 6 x 4.
constexpr float A_4x4_3x3[] = {
 1./8,     0,    0,    0,
 1./4, -1./4, 1./4, -1./4,
 1./4,  1./4, 1./4,  1./4,
 1./8, -1./4, 1./2,    -1,
 1./8,  1./4, 1./2,     1,
    0,     0,    0,  1./2
};

// G for F(2x2, 5x5): 6 x 5.
constexpr float G_2x2_5x5[] = {
     1,     0,      0,      0,      0,
  1./6, -1./6,   1./6,  -1./6,   1./6,
 -1./6, -1./6,  -1./6,  -1./6,  -1./6,
-4./15, 2./15, -1./15,  1./30, -1./60,
 1./60, 1./30,  1./15,  2./15,  4./15,
     0,     0,      0,      0,      1
};

// G^T for F(2x2, 5x5): 5 x 6.
constexpr float GT_2x2_5x5[] = {
 1,  1./6, -1./6, -4./15, 1./60, 0,
 0, -1./6, -1./6,  2./15, 1./30, 0,
 0,  1./6, -1./6, -1./15, 1./15, 0,
 0, -1./6, -1./6,  1./30, 2./15, 0,
 0,  1./6, -1./6, -1./60, 4./15, 1
};

// B^T for F(2x2, 5x5): 6 x 6.
constexpr float BT_2x2_5x5[] = {
 1./8,  3./16,  -1./4, -3./16,   1./8,    0,
    0,   1./8,  1./16, -5./16,   1./8,    0,
    0,  -1./8, -5./16, -1./16,   1./8,    0,
    0,   1./4,  -1./8,  -1./4,   1./8,    0,
    0,  -1./8,  -1./4,   1./8,   1./4,    0,
    0,   1./8,  3./16,  -1./4, -3./16, 1./8
};

// B for F(2x2, 5x5): 6 x 6.
constexpr float B_2x2_5x5[] = {
   1./8,      0,      0,     0,     0,      0,
  3./16,   1./8,  -1./8,  1./4, -1./8,   1./8,
  -1./4,  1./16, -5./16, -1./8, -1./4,  3./16,
 -3./16, -5./16, -1./16, -1./4,  1./8,  -1./4,
   1./8,   1./8,   1./8,  1./8,  1./4, -3./16,
      0,      0,      0,     0,     0,   1./8
};

// A^T for F(2x2, 5x5): 2 x 6.
constexpr float AT_2x2_5x5[] = {
 1./2,  1, 1,  2, 1,    0,
    0, -1, 1, -1, 2, 1./2
};

// A for F(2x2, 5x5): 6 x 2.
constexpr float A_2x2_5x5[] = {
 1./2,    0,
    1,   -1,
    1,    1,
    2,   -1,
    1,    2,
    0, 1./2
};
// clang-format on

/// Key type identifying a minimal filtering algorithm F(m, r):
/// first = m (output dimension), second = r (filter dimension).
using TransformMapKeyTy = std::pair<int, int>;

/// We use F(m, r) to define the size of minimal filtering algorithms.
/// m is the output dimension and r is the filter dimension. We can get
/// the input dimension, alpha, from the formula, alpha = m + r - 1.
///
/// For example, when m = 2 and r = 3, we know its input size is 4.
/// The Conv2D will operate on 4x4 input data with 3x3 filter and get
/// 2x2 output result.
constexpr TransformMapKeyTy F_2_3{2, 3};
constexpr TransformMapKeyTy F_4_3{4, 3};
constexpr TransformMapKeyTy F_2_5{2, 5};

/// Structure to keep information of constant transform matrices.
19327ee33d1SHsiangkai Wang struct TransformMatrix { 19427ee33d1SHsiangkai Wang TransformMatrix(const float *table, int64_t rows, int64_t cols, 19527ee33d1SHsiangkai Wang int64_t scalarFactor = 1) 19627ee33d1SHsiangkai Wang : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {} 19727ee33d1SHsiangkai Wang 19827ee33d1SHsiangkai Wang const float *table; 19927ee33d1SHsiangkai Wang int64_t rows; 20027ee33d1SHsiangkai Wang int64_t cols; 20127ee33d1SHsiangkai Wang int64_t scalarFactor; 20227ee33d1SHsiangkai Wang }; 20327ee33d1SHsiangkai Wang 20427ee33d1SHsiangkai Wang /// Utility function to convert constant array to arith.constant Value. 20527ee33d1SHsiangkai Wang Value create2DTransformMatrix(OpBuilder &builder, Location loc, 20627ee33d1SHsiangkai Wang TransformMatrix transform, Type type) { 20727ee33d1SHsiangkai Wang ArrayRef<float> constVec(transform.table, transform.rows * transform.cols); 20827ee33d1SHsiangkai Wang 20927ee33d1SHsiangkai Wang return builder.create<arith::ConstantOp>( 21027ee33d1SHsiangkai Wang loc, DenseFPElementsAttr::get( 21127ee33d1SHsiangkai Wang RankedTensorType::get( 21227ee33d1SHsiangkai Wang SmallVector<int64_t>{transform.rows, transform.cols}, type), 21327ee33d1SHsiangkai Wang constVec)); 21427ee33d1SHsiangkai Wang } 21527ee33d1SHsiangkai Wang 21627ee33d1SHsiangkai Wang /// Extract height x width data from 4D tensors. 
21727ee33d1SHsiangkai Wang Value extract2DDataFrom4D(OpBuilder &builder, Location loc, Value source, 21827ee33d1SHsiangkai Wang Value loopNorFIndex, Value loopCorFIndex, 21927ee33d1SHsiangkai Wang Value heightOffset, Value widthOffset, 22027ee33d1SHsiangkai Wang int64_t extractHeight, int64_t extractWidth, 22127ee33d1SHsiangkai Wang int64_t loopNorFIdx, int64_t loopCorFIdx, 22227ee33d1SHsiangkai Wang int64_t heightIdx, int64_t widthIdx) { 22327ee33d1SHsiangkai Wang auto sourceType = cast<ShapedType>(source.getType()); 22427ee33d1SHsiangkai Wang Type elementType = sourceType.getElementType(); 22527ee33d1SHsiangkai Wang int64_t srcSize = sourceType.getRank(); 22627ee33d1SHsiangkai Wang 22727ee33d1SHsiangkai Wang auto oneIndex = builder.getIndexAttr(1); 22827ee33d1SHsiangkai Wang SmallVector<OpFoldResult> offsets; 22927ee33d1SHsiangkai Wang offsets.resize(srcSize); 23027ee33d1SHsiangkai Wang offsets[loopNorFIdx] = loopNorFIndex; 23127ee33d1SHsiangkai Wang offsets[loopCorFIdx] = loopCorFIndex; 23227ee33d1SHsiangkai Wang offsets[heightIdx] = heightOffset; 23327ee33d1SHsiangkai Wang offsets[widthIdx] = widthOffset; 23427ee33d1SHsiangkai Wang SmallVector<OpFoldResult> sizes(srcSize, oneIndex); 23527ee33d1SHsiangkai Wang sizes[heightIdx] = builder.getIndexAttr(extractHeight); 23627ee33d1SHsiangkai Wang sizes[widthIdx] = builder.getIndexAttr(extractWidth); 23727ee33d1SHsiangkai Wang SmallVector<OpFoldResult> strides(srcSize, oneIndex); 23827ee33d1SHsiangkai Wang 23927ee33d1SHsiangkai Wang auto extractFilterType = 24027ee33d1SHsiangkai Wang RankedTensorType::get({extractHeight, extractWidth}, elementType); 24127ee33d1SHsiangkai Wang auto extractFilterOp = builder.create<tensor::ExtractSliceOp>( 24227ee33d1SHsiangkai Wang loc, extractFilterType, source, offsets, sizes, strides); 24327ee33d1SHsiangkai Wang 24427ee33d1SHsiangkai Wang return extractFilterOp; 24527ee33d1SHsiangkai Wang } 24627ee33d1SHsiangkai Wang 24727ee33d1SHsiangkai Wang /// Extract height x width data from 6D 
tensors. 24827ee33d1SHsiangkai Wang Value extract2DDataFrom6D(OpBuilder &builder, Location loc, Value source, 24927ee33d1SHsiangkai Wang Value tileHIndex, Value tileWIndex, 25027ee33d1SHsiangkai Wang Value loopNorFIndex, Value loopCorFIndex, 25127ee33d1SHsiangkai Wang int64_t tileHIdx, int64_t tileWIdx, 25227ee33d1SHsiangkai Wang int64_t loopNorFIdx, int64_t loopCorFIdx, 25327ee33d1SHsiangkai Wang int64_t heightIdx, int64_t widthIdx) { 25427ee33d1SHsiangkai Wang auto sourceType = cast<ShapedType>(source.getType()); 25527ee33d1SHsiangkai Wang Type elementType = sourceType.getElementType(); 25627ee33d1SHsiangkai Wang auto sourceShape = sourceType.getShape(); 25727ee33d1SHsiangkai Wang int64_t srcSize = sourceType.getRank(); 25827ee33d1SHsiangkai Wang int64_t height = sourceShape[heightIdx]; 25927ee33d1SHsiangkai Wang int64_t width = sourceShape[widthIdx]; 26027ee33d1SHsiangkai Wang 26127ee33d1SHsiangkai Wang auto zeroIndex = builder.getIndexAttr(0); 26227ee33d1SHsiangkai Wang auto oneIndex = builder.getIndexAttr(1); 26327ee33d1SHsiangkai Wang SmallVector<OpFoldResult> offsets(srcSize, zeroIndex); 26427ee33d1SHsiangkai Wang offsets.resize(srcSize); 26527ee33d1SHsiangkai Wang offsets[tileHIdx] = tileHIndex; 26627ee33d1SHsiangkai Wang offsets[tileWIdx] = tileWIndex; 26727ee33d1SHsiangkai Wang offsets[loopNorFIdx] = loopNorFIndex; 26827ee33d1SHsiangkai Wang offsets[loopCorFIdx] = loopCorFIndex; 26927ee33d1SHsiangkai Wang SmallVector<OpFoldResult> sizes(srcSize, oneIndex); 27027ee33d1SHsiangkai Wang sizes[heightIdx] = builder.getIndexAttr(height); 27127ee33d1SHsiangkai Wang sizes[widthIdx] = builder.getIndexAttr(width); 27227ee33d1SHsiangkai Wang SmallVector<OpFoldResult> strides(srcSize, oneIndex); 27327ee33d1SHsiangkai Wang 27427ee33d1SHsiangkai Wang auto extractFilterType = RankedTensorType::get({height, width}, elementType); 27527ee33d1SHsiangkai Wang auto extractFilterOp = builder.create<tensor::ExtractSliceOp>( 27627ee33d1SHsiangkai Wang loc, extractFilterType, 
source, offsets, sizes, strides); 27727ee33d1SHsiangkai Wang 27827ee33d1SHsiangkai Wang return extractFilterOp; 27927ee33d1SHsiangkai Wang } 28027ee33d1SHsiangkai Wang 28127ee33d1SHsiangkai Wang /// Insert transformed height x width data to 4D tensors which it is 28227ee33d1SHsiangkai Wang /// extracted from. 28327ee33d1SHsiangkai Wang Value insert2DDataTo4D(OpBuilder &builder, Location loc, Value source, 28427ee33d1SHsiangkai Wang Value dest, Value loopNorFIndex, Value loopCorFIndex, 28527ee33d1SHsiangkai Wang Value heightOffset, Value widthOffset, int64_t height, 28627ee33d1SHsiangkai Wang int64_t width, int64_t loopNorFIdx, int64_t loopCorFIdx, 28727ee33d1SHsiangkai Wang int64_t heightIdx, int64_t widthIdx) { 28827ee33d1SHsiangkai Wang int64_t destSize = cast<ShapedType>(dest.getType()).getRank(); 28927ee33d1SHsiangkai Wang auto oneIndex = builder.getIndexAttr(1); 29027ee33d1SHsiangkai Wang SmallVector<OpFoldResult> retOffsets; 29127ee33d1SHsiangkai Wang retOffsets.resize(destSize); 29227ee33d1SHsiangkai Wang retOffsets[loopNorFIdx] = loopNorFIndex; 29327ee33d1SHsiangkai Wang retOffsets[loopCorFIdx] = loopCorFIndex; 29427ee33d1SHsiangkai Wang retOffsets[heightIdx] = heightOffset; 29527ee33d1SHsiangkai Wang retOffsets[widthIdx] = widthOffset; 29627ee33d1SHsiangkai Wang SmallVector<OpFoldResult> retSizes(destSize, oneIndex); 29727ee33d1SHsiangkai Wang retSizes[heightIdx] = builder.getIndexAttr(height); 29827ee33d1SHsiangkai Wang retSizes[widthIdx] = builder.getIndexAttr(width); 29927ee33d1SHsiangkai Wang SmallVector<OpFoldResult> strides(destSize, oneIndex); 30027ee33d1SHsiangkai Wang 30127ee33d1SHsiangkai Wang auto insertSliceOp = builder.create<tensor::InsertSliceOp>( 30227ee33d1SHsiangkai Wang loc, source, dest, retOffsets, retSizes, strides); 30327ee33d1SHsiangkai Wang 30427ee33d1SHsiangkai Wang return insertSliceOp; 30527ee33d1SHsiangkai Wang } 30627ee33d1SHsiangkai Wang 30727ee33d1SHsiangkai Wang /// Insert transformed height x width data to 6D tensors which 
it is 30827ee33d1SHsiangkai Wang /// extracted from. 30927ee33d1SHsiangkai Wang Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source, 31027ee33d1SHsiangkai Wang Value dest, Value tileHIndex, Value tileWIndex, 31127ee33d1SHsiangkai Wang Value loopNorFIndex, Value loopCorFIndex, int64_t height, 31227ee33d1SHsiangkai Wang int64_t width, int64_t tileHIdx, int64_t tileWIdx, 31327ee33d1SHsiangkai Wang int64_t loopNorFIdx, int64_t loopCorFIdx, 31427ee33d1SHsiangkai Wang int64_t heightIdx, int64_t widthIdx) { 31527ee33d1SHsiangkai Wang int64_t destSize = cast<ShapedType>(dest.getType()).getRank(); 31627ee33d1SHsiangkai Wang auto zeroIndex = builder.getIndexAttr(0); 31727ee33d1SHsiangkai Wang auto oneIndex = builder.getIndexAttr(1); 31827ee33d1SHsiangkai Wang SmallVector<OpFoldResult> retOffsets(destSize, zeroIndex); 31927ee33d1SHsiangkai Wang retOffsets.resize(destSize); 32027ee33d1SHsiangkai Wang retOffsets[tileHIdx] = tileHIndex; 32127ee33d1SHsiangkai Wang retOffsets[tileWIdx] = tileWIndex; 32227ee33d1SHsiangkai Wang retOffsets[loopNorFIdx] = loopNorFIndex; 32327ee33d1SHsiangkai Wang retOffsets[loopCorFIdx] = loopCorFIndex; 32427ee33d1SHsiangkai Wang SmallVector<OpFoldResult> retSizes(destSize, oneIndex); 32527ee33d1SHsiangkai Wang retSizes[heightIdx] = builder.getIndexAttr(height); 32627ee33d1SHsiangkai Wang retSizes[widthIdx] = builder.getIndexAttr(width); 32727ee33d1SHsiangkai Wang SmallVector<OpFoldResult> strides(destSize, oneIndex); 32827ee33d1SHsiangkai Wang 32927ee33d1SHsiangkai Wang auto insertSliceOp = builder.create<tensor::InsertSliceOp>( 33027ee33d1SHsiangkai Wang loc, source, dest, retOffsets, retSizes, strides); 33127ee33d1SHsiangkai Wang 33227ee33d1SHsiangkai Wang return insertSliceOp; 33327ee33d1SHsiangkai Wang } 33427ee33d1SHsiangkai Wang 33527ee33d1SHsiangkai Wang /// This function transforms the filter. The data layout of the filter is FHWC. 33627ee33d1SHsiangkai Wang /// The transformation matrix is 2-dimension. 
/// This function transforms the filter. The data layout of the filter is FHWC.
/// The transformation matrix is 2-dimension. We need to extract H x W from
/// FHWC first. We need to generate 2 levels of loops to iterate on F and C.
/// After the transformation, we get
///
/// scf.for %f = lo_f to hi_f step 1
///   scf.for %c = lo_c to hi_c step 1
///     %extracted = extract filter<h x w> from filter<f x h x w x c>
///     %ret = linalg.matmul G, %extracted
///     %ret = linalg.matmul %ret, GT
///     %inserted = insert %ret into filter<h x w x c x f>
///
/// `leftTransform`/`rightTransform` select which side(s) of the 2-D transform
/// are applied (both true for the full 2-D algorithm). Returns a null Value
/// when the filter shape is incompatible with the given r.
Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
                      Value retValue, int64_t m, int64_t r,
                      bool leftTransform = true, bool rightTransform = true) {
  // Map from (m, r) to G transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      GMatrices = {
          {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
          {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
          {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
      };

  // Map from (m, r) to GT transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      GTMatrices = {
          {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
          {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
          {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
      };

  auto filterType = cast<ShapedType>(filter.getType());
  Type elementType = filterType.getElementType();
  auto filterShape = filterType.getShape(); // F, H, W, C
  int64_t filterF = filterShape[0];
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];
  int64_t filterC = filterShape[3];

  // Each spatial dimension must either match the filter size r or be 1
  // (a size-1 axis corresponds to the non-transformed side, cf. the
  // retHeight/retWidth computation below). Otherwise signal failure by
  // returning a null Value.
  if (filterH != r && filterH != 1)
    return Value();
  if (filterW != r && filterW != 1)
    return Value();

  Value zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  // Loop body for the (F, C) nest; `args[0]` is the iter_arg carrying the
  // partially-filled result tensor. Returning {} aborts on unsupported (m, r).
  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                       ValueRange args) -> scf::ValueVector {
    Value FIter = ivs[0];
    Value CIter = ivs[1];

    // Extract (H, W) from (F, H, W, C).
    auto extractFilter =
        extract2DDataFrom4D(builder, loc, filter, FIter, CIter, zeroIdx,
                            zeroIdx, filterH, filterW, /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);

    TransformMapKeyTy key = {m, r};
    int64_t retRows = 1;
    Value matmulRetValue = extractFilter;
    Value zero = builder.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementType));
    if (leftTransform) {
      // Get constant transform matrix G.
      auto it = GMatrices.find(key);
      if (it == GMatrices.end())
        return {};
      const TransformMatrix &GMatrix = it->second;

      retRows = GMatrix.rows;
      auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
      // linalg.matmul accumulates into its output, so zero-fill the init
      // tensor first.
      auto empty =
          builder
              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
              .getResult();
      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);

      Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType);
      // Multiply G x g.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (rightTransform) {
      // Get constant transform matrix GT.
      auto it = GTMatrices.find(key);
      if (it == GTMatrices.end())
        return {};
      const TransformMatrix &GTMatrix = it->second;

      auto matmulType =
          RankedTensorType::get({retRows, GTMatrix.cols}, elementType);
      // Zero-filled init tensor for the accumulating matmul, as above.
      auto empty =
          builder
              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
              .getResult();
      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);

      Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType);
      // Multiply u = (G x g) x GT.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{matmulRetValue, GT}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    // Insert (H, W) to (H, W, C, F). The non-transformed side keeps size 1;
    // the transformed side expands to alpha = m + r - 1.
    int64_t retHeight = leftTransform ? m + r - 1 : 1;
    int64_t retWidth = rightTransform ? m + r - 1 : 1;

    auto insertSliceOp =
        insert2DDataTo4D(builder, loc, matmulRetValue, args[0], FIter, CIter,
                         zeroIdx, zeroIdx, retHeight, retWidth,
                         /*loopNorFIdx=*/3, /*loopCorFIdx=*/2,
                         /*heightIdx=*/0, /*widthIdx=*/1);

    return {insertSliceOp};
  };

  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF);
  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC);
  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  scf::LoopNest loops = scf::buildLoopNest(
      rewriter, loc, {zeroIdx, zeroIdx}, {fUpperBound, cUpperBound},
      {oneStep, oneStep}, {retValue}, buildBody);
  return loops.results[0];
}
46427ee33d1SHsiangkai Wang /// After the transformation, we get 46527ee33d1SHsiangkai Wang /// 46627ee33d1SHsiangkai Wang /// scf.for %h = 0 to tileH step 1 46727ee33d1SHsiangkai Wang /// scf.for %w = 0 to tileW step 1 46827ee33d1SHsiangkai Wang /// scf.for %n = 0 to N step 1 46927ee33d1SHsiangkai Wang /// scf.for %c = 0 to C step 1 47027ee33d1SHsiangkai Wang /// %extracted = extract %extracted<alphaH x alphaW> from 47127ee33d1SHsiangkai Wang /// %input<N x H x W x C> 47227ee33d1SHsiangkai Wang /// at [%n, (%h x m), (%w x m), %c] 47327ee33d1SHsiangkai Wang /// %ret = linalg.matmul BT, %extracted 47427ee33d1SHsiangkai Wang /// %ret = linalg.matmul %ret, B 47527ee33d1SHsiangkai Wang /// %inserted = insert %ret<alphaH x alphaW> into 47627ee33d1SHsiangkai Wang /// %output<alphaH x alphaW x tileH x tileW x N x C> 47727ee33d1SHsiangkai Wang /// at [0, 0, %h, %w, %n, %c] 47827ee33d1SHsiangkai Wang Value inputTransform(RewriterBase &rewriter, Location loc, Value input, 47927ee33d1SHsiangkai Wang Value retValue, int64_t m, int64_t r, 48027ee33d1SHsiangkai Wang bool leftTransform = true, bool rightTransform = true) { 48127ee33d1SHsiangkai Wang // Map from (m, r) to BT transform matrix. 48227ee33d1SHsiangkai Wang static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix> 48327ee33d1SHsiangkai Wang BTMatrices = { 48427ee33d1SHsiangkai Wang {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)}, 48527ee33d1SHsiangkai Wang {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)}, 48627ee33d1SHsiangkai Wang {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)}, 48727ee33d1SHsiangkai Wang }; 48827ee33d1SHsiangkai Wang 48927ee33d1SHsiangkai Wang // Map from (m, r) to B transform matrix. 
49027ee33d1SHsiangkai Wang static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix> 49127ee33d1SHsiangkai Wang BMatrices = { 49227ee33d1SHsiangkai Wang {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)}, 49327ee33d1SHsiangkai Wang {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)}, 49427ee33d1SHsiangkai Wang {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)}, 49527ee33d1SHsiangkai Wang }; 49627ee33d1SHsiangkai Wang 49727ee33d1SHsiangkai Wang auto inputType = cast<ShapedType>(input.getType()); 49827ee33d1SHsiangkai Wang Type elementType = inputType.getElementType(); 49927ee33d1SHsiangkai Wang auto inputShape = inputType.getShape(); // N, H, W, C 50027ee33d1SHsiangkai Wang int64_t inputN = inputShape[0]; 50127ee33d1SHsiangkai Wang int64_t inputC = inputShape[3]; 50227ee33d1SHsiangkai Wang auto valueType = cast<ShapedType>(retValue.getType()); 50327ee33d1SHsiangkai Wang auto valueShape = valueType.getShape(); // alphaH, alphaW, HTile, WTile, N, C 50427ee33d1SHsiangkai Wang int64_t tileH = valueShape[2]; 50527ee33d1SHsiangkai Wang int64_t tileW = valueShape[3]; 50627ee33d1SHsiangkai Wang int64_t alphaH = leftTransform ? m + r - 1 : 1; 50727ee33d1SHsiangkai Wang int64_t alphaW = rightTransform ? 
m + r - 1 : 1; 50827ee33d1SHsiangkai Wang 50927ee33d1SHsiangkai Wang auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs, 51027ee33d1SHsiangkai Wang ValueRange args) -> scf::ValueVector { 51127ee33d1SHsiangkai Wang Value tileHIter = ivs[0]; 51227ee33d1SHsiangkai Wang Value tileWIter = ivs[1]; 51327ee33d1SHsiangkai Wang Value NIter = ivs[2]; 51427ee33d1SHsiangkai Wang Value CIter = ivs[3]; 51527ee33d1SHsiangkai Wang 51627ee33d1SHsiangkai Wang auto context = builder.getContext(); 517*f20b8e35SDmitriy Smirnov 518*f20b8e35SDmitriy Smirnov auto identityAffineMap = rewriter.getMultiDimIdentityMap(1); 51927ee33d1SHsiangkai Wang auto affineMap = 52027ee33d1SHsiangkai Wang AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context); 521*f20b8e35SDmitriy Smirnov Value heightOffset = builder.create<affine::AffineApplyOp>( 522*f20b8e35SDmitriy Smirnov loc, leftTransform ? affineMap : identityAffineMap, tileHIter); 523*f20b8e35SDmitriy Smirnov Value widthOffset = builder.create<affine::AffineApplyOp>( 524*f20b8e35SDmitriy Smirnov loc, rightTransform ? affineMap : identityAffineMap, tileWIter); 52527ee33d1SHsiangkai Wang 52627ee33d1SHsiangkai Wang // Extract (H, W) from (N, H, W, C). 
52727ee33d1SHsiangkai Wang auto extractInput = 52827ee33d1SHsiangkai Wang extract2DDataFrom4D(builder, loc, input, NIter, CIter, heightOffset, 52927ee33d1SHsiangkai Wang widthOffset, alphaH, alphaW, /*loopNorFIdx=*/0, 53027ee33d1SHsiangkai Wang /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2); 53127ee33d1SHsiangkai Wang 53227ee33d1SHsiangkai Wang TransformMapKeyTy key = {m, r}; 53327ee33d1SHsiangkai Wang int64_t retRows = 1; 53427ee33d1SHsiangkai Wang int64_t retCols = 1; 53527ee33d1SHsiangkai Wang Value matmulRetValue = extractInput; 536326287fdSThomas Preud'homme Value zero = builder.create<arith::ConstantOp>( 537326287fdSThomas Preud'homme loc, rewriter.getZeroAttr(elementType)); 53827ee33d1SHsiangkai Wang if (leftTransform) { 53927ee33d1SHsiangkai Wang // Get constant transform matrix BT. 54027ee33d1SHsiangkai Wang auto it = BTMatrices.find(key); 54127ee33d1SHsiangkai Wang if (it == BTMatrices.end()) 54227ee33d1SHsiangkai Wang return {}; 54327ee33d1SHsiangkai Wang const TransformMatrix &BTMatrix = it->second; 54427ee33d1SHsiangkai Wang 54527ee33d1SHsiangkai Wang retRows = BTMatrix.rows; 54627ee33d1SHsiangkai Wang auto matmulType = RankedTensorType::get({retRows, alphaW}, elementType); 547326287fdSThomas Preud'homme auto empty = 548326287fdSThomas Preud'homme builder 549326287fdSThomas Preud'homme .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType) 550326287fdSThomas Preud'homme .getResult(); 551326287fdSThomas Preud'homme auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0); 55227ee33d1SHsiangkai Wang 55327ee33d1SHsiangkai Wang Value BT = 55427ee33d1SHsiangkai Wang create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type()); 55527ee33d1SHsiangkai Wang // Multiply BT x d. 
55627ee33d1SHsiangkai Wang auto matmulOp = builder.create<linalg::MatmulOp>( 55727ee33d1SHsiangkai Wang loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init}); 55827ee33d1SHsiangkai Wang matmulRetValue = matmulOp.getResult(0); 55927ee33d1SHsiangkai Wang } 56027ee33d1SHsiangkai Wang 56127ee33d1SHsiangkai Wang if (rightTransform) { 56227ee33d1SHsiangkai Wang // Get constant transform matrix B. 56327ee33d1SHsiangkai Wang auto it = BMatrices.find(key); 56427ee33d1SHsiangkai Wang if (it == BMatrices.end()) 56527ee33d1SHsiangkai Wang return {}; 56627ee33d1SHsiangkai Wang const TransformMatrix &BMatrix = it->second; 56727ee33d1SHsiangkai Wang 56827ee33d1SHsiangkai Wang retCols = BMatrix.cols; 56927ee33d1SHsiangkai Wang auto matmulType = RankedTensorType::get({retRows, retCols}, elementType); 570326287fdSThomas Preud'homme auto empty = 571326287fdSThomas Preud'homme builder 572326287fdSThomas Preud'homme .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType) 573326287fdSThomas Preud'homme .getResult(); 574326287fdSThomas Preud'homme auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0); 57527ee33d1SHsiangkai Wang Value B = 57627ee33d1SHsiangkai Wang create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type()); 57727ee33d1SHsiangkai Wang // Multiply v = (BT x d) x B. 57827ee33d1SHsiangkai Wang auto matmulOp = builder.create<linalg::MatmulOp>( 57927ee33d1SHsiangkai Wang loc, matmulType, ValueRange{matmulRetValue, B}, ValueRange{init}); 58027ee33d1SHsiangkai Wang matmulRetValue = matmulOp.getResult(0); 58127ee33d1SHsiangkai Wang } 58227ee33d1SHsiangkai Wang 58327ee33d1SHsiangkai Wang // Insert (H, W) to (H, W, tileH, tileW, N, C). 
58427ee33d1SHsiangkai Wang auto combinedVal = insert2DDataTo6D( 58527ee33d1SHsiangkai Wang builder, loc, matmulRetValue, args[0], tileHIter, tileWIter, NIter, 58627ee33d1SHsiangkai Wang CIter, retRows, retCols, 2, 3, /*loopNorFIdx=*/4, /*loopCorFIdx=*/5, 58727ee33d1SHsiangkai Wang /*heightIdx=*/0, /*widthIdx=*/1); 58827ee33d1SHsiangkai Wang 58927ee33d1SHsiangkai Wang return {combinedVal}; 59027ee33d1SHsiangkai Wang }; 59127ee33d1SHsiangkai Wang 59227ee33d1SHsiangkai Wang auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0); 59327ee33d1SHsiangkai Wang auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tileH); 59427ee33d1SHsiangkai Wang auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW); 59527ee33d1SHsiangkai Wang auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputN); 59627ee33d1SHsiangkai Wang auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputC); 59727ee33d1SHsiangkai Wang auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1); 59827ee33d1SHsiangkai Wang scf::LoopNest loops = scf::buildLoopNest( 59927ee33d1SHsiangkai Wang rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx}, 60027ee33d1SHsiangkai Wang {tileHBound, tileWBound, nUpperBound, cUpperBound}, 60127ee33d1SHsiangkai Wang {oneStep, oneStep, oneStep, oneStep}, {retValue}, buildBody); 60227ee33d1SHsiangkai Wang return loops.results[0]; 60327ee33d1SHsiangkai Wang } 60427ee33d1SHsiangkai Wang 6057d246e84SHsiangkai Wang /// This function generates linalg.batch_matmul to multiply input with filter. 6067d246e84SHsiangkai Wang /// linalg.batch_matmul only supports 3-dimensional inputs. We can treat 6077d246e84SHsiangkai Wang /// tileH x tileW x H x W data as the 1-dimensional data array. That is to 6087d246e84SHsiangkai Wang /// convert [tileH, tileW, H, W, N, C] to [tileH x tileW x H x W, N, C]. 
/// In this way, we can convert 6-dimensional inputs to 3-dimensional
/// representation that is suitable for linalg.batch_matmul.
///
/// Batched matmul will do the matrix multiply with the reduction on channel.
///
/// We get
///
/// %collapsed_input = tensor.collapse_shape %input
/// %collapsed_filter = tensor.collapse_shape %filter
/// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
/// %expanded_ret = tensor.expand_shape %ret
///
/// After this function, we get return value with data layout
/// (tileH, tileW, H, W, N, F).
static Value matrixMultiply(RewriterBase &rewriter, Location loc,
                            Value transformedFilter, Value transformedInput,
                            Type outputElementType) {
  // Convert (alphaH, alphaW, C, F) to (alphaH x alphaW, C, F) for filter.
  // The leading two (transform-domain) dimensions are merged into the batch
  // dimension of the batched matmul.
  auto filterType = cast<ShapedType>(transformedFilter.getType());
  assert(filterType.hasStaticShape() && "only support static shapes.");
  ArrayRef<int64_t> filterShape = filterType.getShape();
  Type filterElementType = filterType.getElementType();
  auto filterReassocType = RankedTensorType::get(
      {filterShape[0] * filterShape[1], filterShape[2], filterShape[3]},
      filterElementType);
  SmallVector<ReassociationIndices> filterReassoc = {{0, 1}, {2}, {3}};
  Value collapseFilter = rewriter.create<tensor::CollapseShapeOp>(
      loc, filterReassocType, transformedFilter, filterReassoc);

  // Convert (alphaH, alphaW, tileH, tileW, N, C) to
  // (alphaH x alphaW, tileH x tileW x N, C) for input.
  // Besides the batch dimension, (tileH, tileW, N) are flattened into the
  // row dimension of each per-batch matmul.
  auto inputType = cast<ShapedType>(transformedInput.getType());
  assert(inputType.hasStaticShape() && "only support static shapes.");
  ArrayRef<int64_t> inputShape = inputType.getShape();
  Type inputElementType = inputType.getElementType();
  auto inputReassocType = RankedTensorType::get(
      {inputShape[0] * inputShape[1],
       inputShape[2] * inputShape[3] * inputShape[4], inputShape[5]},
      inputElementType);
  SmallVector<ReassociationIndices> inputReassoc = {{0, 1}, {2, 3, 4}, {5}};
  Value collapseInput = rewriter.create<tensor::CollapseShapeOp>(
      loc, inputReassocType, transformedInput, inputReassoc);

  // Batched matrix multiply. The accumulator is zero-filled first because
  // linalg.batch_matmul accumulates into its output operand.
  auto matmulType = RankedTensorType::get(
      {inputShape[0] * inputShape[1],
       inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]},
      outputElementType);
  Value empty = rewriter
                    .create<tensor::EmptyOp>(loc, matmulType.getShape(),
                                             outputElementType)
                    .getResult();
  Value zero = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getZeroAttr(outputElementType));
  Value init = rewriter.create<linalg::FillOp>(loc, zero, empty).getResult(0);

  // Input is the LHS so the contraction reduces over the channel dimension C
  // (input columns against filter rows).
  auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
      loc, matmulType, ValueRange({collapseInput, collapseFilter}),
      ValueRange{init});

  // The result shape of batch matmul is (alphaH x alphaW, tileH x tileW x N, F)
  // Expand matmul result to (alphaH, alphaW, tileH, tileW, N, F).
  SmallVector<ReassociationIndices> outputReassoc = {{0, 1}, {2, 3, 4}, {5}};
  auto outputReassocType =
      RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2],
                             inputShape[3], inputShape[4], filterShape[3]},
                            outputElementType);
  auto expandOutput = rewriter.create<tensor::ExpandShapeOp>(
      loc, outputReassocType, matmulOp.getResult(0), outputReassoc);
  return expandOutput;
}

/// This function transforms the output. The data layout of the output is HWNF.
/// The transformation matrix is 2-dimension. We need to extract H x W from
/// HWNF first. We need to generate 2 levels of loops to iterate on N and F.
68427ee33d1SHsiangkai Wang /// After the transformation, we get 68527ee33d1SHsiangkai Wang /// 68627ee33d1SHsiangkai Wang /// scf.for %h = 0 to tileH step 1 68727ee33d1SHsiangkai Wang /// scf.for %w = 0 to tileW step 1 68827ee33d1SHsiangkai Wang /// scf.for %n = 0 to N step 1 68927ee33d1SHsiangkai Wang /// scf.for %f = 0 to F step 1 69027ee33d1SHsiangkai Wang /// %extracted = extract %extracted<alphaH x alphaW> from 69127ee33d1SHsiangkai Wang /// %input<alphaH x alphaW x tileH x tileW x N x F> 69227ee33d1SHsiangkai Wang /// at [0, 0, %h, %w, %n, %f] 69327ee33d1SHsiangkai Wang /// %ret = linalg.matmul AT, %extracted 69427ee33d1SHsiangkai Wang /// %ret = linalg.matmul %ret, A 69527ee33d1SHsiangkai Wang /// %inserted = insert %ret<alphaH x alphaW> into 69627ee33d1SHsiangkai Wang /// output<N x H x W x F> 69727ee33d1SHsiangkai Wang /// at [%n, (%h x m), (%w x m), %f] 69827ee33d1SHsiangkai Wang Value outputTransform(RewriterBase &rewriter, Location loc, Value value, 69927ee33d1SHsiangkai Wang Value output, int64_t m, int64_t r, 70027ee33d1SHsiangkai Wang bool leftTransform = true, bool rightTransform = true) { 70127ee33d1SHsiangkai Wang // Map from (m, r) to AT transform matrix. 70227ee33d1SHsiangkai Wang static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix> 70327ee33d1SHsiangkai Wang ATMatrices = { 70427ee33d1SHsiangkai Wang {F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)}, 70527ee33d1SHsiangkai Wang {F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)}, 70627ee33d1SHsiangkai Wang {F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)}, 70727ee33d1SHsiangkai Wang }; 70827ee33d1SHsiangkai Wang 70927ee33d1SHsiangkai Wang // Map from (m, r) to A transform matrix. 
71027ee33d1SHsiangkai Wang static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix> 71127ee33d1SHsiangkai Wang AMatrices = { 71227ee33d1SHsiangkai Wang {F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)}, 71327ee33d1SHsiangkai Wang {F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)}, 71427ee33d1SHsiangkai Wang {F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)}, 71527ee33d1SHsiangkai Wang }; 71627ee33d1SHsiangkai Wang 71727ee33d1SHsiangkai Wang auto valueType = cast<ShapedType>(value.getType()); 71827ee33d1SHsiangkai Wang Type elementType = valueType.getElementType(); 71927ee33d1SHsiangkai Wang auto valueShape = valueType.getShape(); // H, W, TileH, TileW, N, F 72027ee33d1SHsiangkai Wang int64_t valueH = valueShape[0]; 72127ee33d1SHsiangkai Wang int64_t valueW = valueShape[1]; 72227ee33d1SHsiangkai Wang int64_t valueN = valueShape[4]; 72327ee33d1SHsiangkai Wang int64_t valueF = valueShape[5]; 72427ee33d1SHsiangkai Wang int64_t alphaH = leftTransform ? m + r - 1 : 1; 72527ee33d1SHsiangkai Wang int64_t alphaW = rightTransform ? m + r - 1 : 1; 72627ee33d1SHsiangkai Wang 72727ee33d1SHsiangkai Wang if (valueH != alphaH && valueH != 1) 72827ee33d1SHsiangkai Wang return Value(); 72927ee33d1SHsiangkai Wang if (valueW != alphaW && valueW != 1) 73027ee33d1SHsiangkai Wang return Value(); 73127ee33d1SHsiangkai Wang 73227ee33d1SHsiangkai Wang auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs, 73327ee33d1SHsiangkai Wang ValueRange args) -> scf::ValueVector { 734bb4696ceSDmitriy Smirnov auto context = builder.getContext(); 73527ee33d1SHsiangkai Wang Value tileHIter = ivs[0]; 73627ee33d1SHsiangkai Wang Value tileWIter = ivs[1]; 73727ee33d1SHsiangkai Wang Value NIter = ivs[2]; 73827ee33d1SHsiangkai Wang Value FIter = ivs[3]; 73927ee33d1SHsiangkai Wang 74027ee33d1SHsiangkai Wang // Extract (H, W) from (H, W, tileH, tileW, N, F). 
74127ee33d1SHsiangkai Wang auto extractValue = 74227ee33d1SHsiangkai Wang extract2DDataFrom6D(builder, loc, value, tileHIter, tileWIter, NIter, 74327ee33d1SHsiangkai Wang FIter, 2, 3, /*loopNorFIdx=*/4, 74427ee33d1SHsiangkai Wang /*loopCorFIdx=*/5, /*heightIdx=*/0, /*widthIdx=*/1); 74527ee33d1SHsiangkai Wang 746bb4696ceSDmitriy Smirnov const TransformMapKeyTy key = {m, r}; 747bb4696ceSDmitriy Smirnov const TransformMatrix &AMatrix = AMatrices.at(key); 748bb4696ceSDmitriy Smirnov const TransformMatrix &ATMatrix = ATMatrices.at(key); 749bb4696ceSDmitriy Smirnov int64_t scalarFactor = (rightTransform ? AMatrix.scalarFactor : 1) * 750bb4696ceSDmitriy Smirnov (leftTransform ? ATMatrix.scalarFactor : 1); 751bb4696ceSDmitriy Smirnov int64_t retCols = rightTransform ? AMatrix.cols : 1; 752bb4696ceSDmitriy Smirnov int64_t retRows = leftTransform ? ATMatrix.rows : 1; 753bb4696ceSDmitriy Smirnov 75427ee33d1SHsiangkai Wang Value matmulRetValue = extractValue; 755326287fdSThomas Preud'homme Value zero = builder.create<arith::ConstantOp>( 756326287fdSThomas Preud'homme loc, rewriter.getZeroAttr(elementType)); 75727ee33d1SHsiangkai Wang 758*f20b8e35SDmitriy Smirnov auto identityAffineMap = rewriter.getMultiDimIdentityMap(1); 759bb4696ceSDmitriy Smirnov auto affineMap = 760bb4696ceSDmitriy Smirnov AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context); 761*f20b8e35SDmitriy Smirnov Value heightOffset = builder.create<affine::AffineApplyOp>( 762*f20b8e35SDmitriy Smirnov loc, leftTransform ? affineMap : identityAffineMap, tileHIter); 763*f20b8e35SDmitriy Smirnov Value widthOffset = builder.create<affine::AffineApplyOp>( 764*f20b8e35SDmitriy Smirnov loc, rightTransform ? 
affineMap : identityAffineMap, tileWIter); 765bb4696ceSDmitriy Smirnov 766bb4696ceSDmitriy Smirnov Value outInitVal = 767bb4696ceSDmitriy Smirnov extract2DDataFrom4D(builder, loc, args[0], NIter, FIter, heightOffset, 768bb4696ceSDmitriy Smirnov widthOffset, retRows, retCols, 769bb4696ceSDmitriy Smirnov /*loopNorFIdx=*/0, 770bb4696ceSDmitriy Smirnov /*loopCorFIdx=*/3, /*heightIdx=*/1, 771bb4696ceSDmitriy Smirnov /*widthIdx=*/2); 772bb4696ceSDmitriy Smirnov if (leftTransform) { 77327ee33d1SHsiangkai Wang auto matmulType = RankedTensorType::get({retRows, valueW}, elementType); 774bb4696ceSDmitriy Smirnov Value init = outInitVal; 775bb4696ceSDmitriy Smirnov if (rightTransform || scalarFactor != 1) { 776bb4696ceSDmitriy Smirnov auto empty = builder 777bb4696ceSDmitriy Smirnov .create<tensor::EmptyOp>(loc, matmulType.getShape(), 778bb4696ceSDmitriy Smirnov elementType) 779326287fdSThomas Preud'homme .getResult(); 780bb4696ceSDmitriy Smirnov init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0); 781bb4696ceSDmitriy Smirnov } 78227ee33d1SHsiangkai Wang 78327ee33d1SHsiangkai Wang Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType); 78427ee33d1SHsiangkai Wang // Multiply AT x m. 
78527ee33d1SHsiangkai Wang auto matmulOp = builder.create<linalg::MatmulOp>( 78627ee33d1SHsiangkai Wang loc, matmulType, ValueRange{AT, matmulRetValue}, ValueRange{init}); 78727ee33d1SHsiangkai Wang matmulRetValue = matmulOp.getResult(0); 78827ee33d1SHsiangkai Wang } 78927ee33d1SHsiangkai Wang 79027ee33d1SHsiangkai Wang if (rightTransform) { 79127ee33d1SHsiangkai Wang auto matmulType = 79227ee33d1SHsiangkai Wang RankedTensorType::get({retRows, AMatrix.cols}, elementType); 793bb4696ceSDmitriy Smirnov Value init = outInitVal; 794bb4696ceSDmitriy Smirnov if (scalarFactor != 1) { 795bb4696ceSDmitriy Smirnov auto empty = builder 796bb4696ceSDmitriy Smirnov .create<tensor::EmptyOp>(loc, matmulType.getShape(), 797bb4696ceSDmitriy Smirnov elementType) 798326287fdSThomas Preud'homme .getResult(); 799bb4696ceSDmitriy Smirnov init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0); 800bb4696ceSDmitriy Smirnov } 80127ee33d1SHsiangkai Wang 80227ee33d1SHsiangkai Wang Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType); 80327ee33d1SHsiangkai Wang // Multiply y = (AT x m) x A. 80427ee33d1SHsiangkai Wang auto matmulOp = builder.create<linalg::MatmulOp>( 80527ee33d1SHsiangkai Wang loc, matmulType, ValueRange{matmulRetValue, A}, ValueRange{init}); 80627ee33d1SHsiangkai Wang matmulRetValue = matmulOp.getResult(0); 80727ee33d1SHsiangkai Wang } 80827ee33d1SHsiangkai Wang 809bb4696ceSDmitriy Smirnov if (scalarFactor != 1) { 810bb4696ceSDmitriy Smirnov // Multiply by scalar factor and add outInitVal. 
811bb4696ceSDmitriy Smirnov Value scalarFactorValue = builder.create<arith::ConstantOp>( 812bb4696ceSDmitriy Smirnov loc, FloatAttr::get(elementType, scalarFactor)); 81327ee33d1SHsiangkai Wang auto matmulType = RankedTensorType::get({retRows, retCols}, elementType); 81427ee33d1SHsiangkai Wang auto identityAffineMap = rewriter.getMultiDimIdentityMap(2); 81527ee33d1SHsiangkai Wang SmallVector<AffineMap> affineMaps = { 816bb4696ceSDmitriy Smirnov AffineMap::get(2, 0, context), identityAffineMap, identityAffineMap}; 817bb4696ceSDmitriy Smirnov 818bb4696ceSDmitriy Smirnov matmulRetValue = 81927ee33d1SHsiangkai Wang rewriter 82027ee33d1SHsiangkai Wang .create<linalg::GenericOp>( 821bb4696ceSDmitriy Smirnov loc, matmulType, 822bb4696ceSDmitriy Smirnov ValueRange{scalarFactorValue, matmulRetValue}, 823bb4696ceSDmitriy Smirnov ValueRange{outInitVal}, affineMaps, 82427ee33d1SHsiangkai Wang llvm::ArrayRef<utils::IteratorType>{ 82527ee33d1SHsiangkai Wang utils::IteratorType::parallel, 82627ee33d1SHsiangkai Wang utils::IteratorType::parallel}, 82727ee33d1SHsiangkai Wang [&](OpBuilder &nestedBuilder, Location nestedLoc, 82827ee33d1SHsiangkai Wang ValueRange args) { 829bb4696ceSDmitriy Smirnov auto mulf = nestedBuilder.create<arith::MulFOp>( 830bb4696ceSDmitriy Smirnov nestedLoc, args[0], args[1]); 831bb4696ceSDmitriy Smirnov auto addf = nestedBuilder.create<arith::AddFOp>( 832bb4696ceSDmitriy Smirnov nestedLoc, mulf.getResult(), args[2]); 833bb4696ceSDmitriy Smirnov nestedBuilder.create<linalg::YieldOp>(nestedLoc, 834bb4696ceSDmitriy Smirnov addf.getResult()); 83527ee33d1SHsiangkai Wang }) 83627ee33d1SHsiangkai Wang .getResult(0); 83727ee33d1SHsiangkai Wang } 83827ee33d1SHsiangkai Wang 83927ee33d1SHsiangkai Wang // Insert (H, W) to (N, H, W, F). 
84027ee33d1SHsiangkai Wang Value combinedVal = 84127ee33d1SHsiangkai Wang insert2DDataTo4D(builder, loc, matmulRetValue, args[0], NIter, FIter, 84227ee33d1SHsiangkai Wang heightOffset, widthOffset, retRows, retCols, 84327ee33d1SHsiangkai Wang /*loopNorFIdx=*/0, 84427ee33d1SHsiangkai Wang /*loopCorFIdx=*/3, /*heightIdx=*/1, 84527ee33d1SHsiangkai Wang /*widthIdx=*/2); 84627ee33d1SHsiangkai Wang 84727ee33d1SHsiangkai Wang return {combinedVal}; 84827ee33d1SHsiangkai Wang }; 84927ee33d1SHsiangkai Wang 85027ee33d1SHsiangkai Wang int64_t tilwH = valueShape[2]; 85127ee33d1SHsiangkai Wang int64_t tileW = valueShape[3]; 85227ee33d1SHsiangkai Wang auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0); 85327ee33d1SHsiangkai Wang auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tilwH); 85427ee33d1SHsiangkai Wang auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW); 85527ee33d1SHsiangkai Wang auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueN); 85627ee33d1SHsiangkai Wang auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueF); 85727ee33d1SHsiangkai Wang auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1); 85827ee33d1SHsiangkai Wang scf::LoopNest loops = scf::buildLoopNest( 85927ee33d1SHsiangkai Wang rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx}, 86027ee33d1SHsiangkai Wang {tileHBound, tileWBound, nUpperBound, fUpperBound}, 86127ee33d1SHsiangkai Wang {oneStep, oneStep, oneStep, oneStep}, {output}, buildBody); 86227ee33d1SHsiangkai Wang return loops.results[0]; 86327ee33d1SHsiangkai Wang } 86427ee33d1SHsiangkai Wang 8657d246e84SHsiangkai Wang /// Create an empty tensor with alignedType and insert the value into the 8667d246e84SHsiangkai Wang /// created empty tensor with aligned size. 
8677d246e84SHsiangkai Wang static Value padToAlignedTensor(RewriterBase &rewriter, Location loc, 8687d246e84SHsiangkai Wang Value value, ArrayRef<int64_t> alignedShape) { 8697d246e84SHsiangkai Wang auto valueType = cast<ShapedType>(value.getType()); 8707d246e84SHsiangkai Wang Type elementType = valueType.getElementType(); 8717d246e84SHsiangkai Wang auto alignedType = RankedTensorType::get(alignedShape, elementType); 8727d246e84SHsiangkai Wang Value padValue = rewriter.create<arith::ConstantOp>( 8737d246e84SHsiangkai Wang loc, elementType, rewriter.getZeroAttr(elementType)); 8747d246e84SHsiangkai Wang 8757d246e84SHsiangkai Wang return linalg::makeComposedPadHighOp(rewriter, loc, alignedType, value, 8767d246e84SHsiangkai Wang padValue, false); 8777d246e84SHsiangkai Wang } 8787d246e84SHsiangkai Wang 8797d246e84SHsiangkai Wang /// Extract sub-tensor with extractedType from value. 8807d246e84SHsiangkai Wang static Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc, 8817d246e84SHsiangkai Wang Value value, 8827d246e84SHsiangkai Wang RankedTensorType extractedType) { 8837d246e84SHsiangkai Wang OpFoldResult zeroIndex = rewriter.getIndexAttr(0); 8847d246e84SHsiangkai Wang OpFoldResult oneIndex = rewriter.getIndexAttr(1); 8857d246e84SHsiangkai Wang SmallVector<OpFoldResult, 4> offsets(4, zeroIndex); 8867d246e84SHsiangkai Wang SmallVector<OpFoldResult, 4> strides(4, oneIndex); 8877d246e84SHsiangkai Wang 8887d246e84SHsiangkai Wang ArrayRef<int64_t> extractedShape = extractedType.getShape(); 8897d246e84SHsiangkai Wang SmallVector<OpFoldResult> sizes = 8907d246e84SHsiangkai Wang getAsOpFoldResult(rewriter.getI64ArrayAttr(extractedShape)); 8917d246e84SHsiangkai Wang 8927d246e84SHsiangkai Wang return rewriter.create<tensor::ExtractSliceOp>(loc, extractedType, value, 8937d246e84SHsiangkai Wang offsets, sizes, strides); 8947d246e84SHsiangkai Wang } 8957d246e84SHsiangkai Wang 8967d246e84SHsiangkai Wang /// Utility function to check all values in the attribute are 1. 
static bool hasAllOneValues(DenseIntElementsAttr attr) {
  return llvm::all_of(
      attr, [](const APInt &element) { return element.getSExtValue() == 1; });
}

/// A helper function to convert linalg.conv_2d_nhwc_fhwc to
/// linalg.winograd_*_transform ops.
/// Rewrites `convOp` into filter-transform, input-transform, batched matmul
/// and output-transform ops for F(m x m, r x r) (or the 1-D variants).
/// Returns the defining op of the final value on success, or a match failure
/// when the convolution shape/configuration is unsupported.
static FailureOr<Operation *>
winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
                     int64_t m, int64_t r) {
  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];
  auto inputType = cast<ShapedType>(input.getType());
  auto filterType = cast<ShapedType>(filter.getType());
  auto outputType = cast<ShapedType>(output.getType());

  // The lowering computes all tile counts and aligned sizes statically, so
  // dynamic shapes cannot be handled.
  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  // Winograd as implemented here only covers unit-stride, unit-dilation
  // convolutions.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  if (!hasAllOneValues(convOp.getStrides()))
    return rewriter.notifyMatchFailure(convOp, "expected all ones for strides");

  // Filter layout is FHWC, input/output layout is NHWC.
  ArrayRef<int64_t> filterShape = filterType.getShape();
  int64_t filterF = filterShape[0];
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];
  int64_t filterC = filterShape[3];
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputN = inputShape[0];
  int64_t inputH = inputShape[1];
  int64_t inputW = inputShape[2];
  int64_t inputC = inputShape[3];
  ArrayRef<int64_t> outputShape = outputType.getShape();
  int64_t outputN = outputShape[0];
  int64_t outputH = outputShape[1];
  int64_t outputW = outputShape[2];
  int64_t outputF = outputShape[3];

  // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r).
  bool isSupportedFilter = false;
  if (filterH == filterW && filterH == r)
    isSupportedFilter = true;
  if (filterH == r && filterW == 1)
    isSupportedFilter = true;
  if (filterH == 1 && filterW == r)
    isSupportedFilter = true;

  if (!isSupportedFilter)
    return rewriter.notifyMatchFailure(
        convOp, "only support filter (r x r), (r x 1) or (1 x r)");

  // Currently, we support (m, r) = (2, 3) or (4, 3) or (2, 5).
  static const llvm::SmallVector<TransformMapKeyTy, 3> validConfigs = {
      F_2_3, F_4_3, F_2_5};

  TransformMapKeyTy key = {m, r};
  auto it = std::find(validConfigs.begin(), validConfigs.end(), key);
  // If we cannot find the constant transformation matrix, it means we do
  // not support this configuration yet.
  if (it == validConfigs.end())
    return failure();

  // All the criteria are satisfied. We can do Winograd Conv2D.
  Location loc = convOp.getLoc();

  // For F(m x 1, r x 1), we only need to do left side transform.
  bool leftTransform = filterH != 1;
  // For F(1 x m, 1 x r), we only need to do right side transform.
  bool rightTransform = filterW != 1;
  // On a skipped axis, the effective tile size and filter size are 1.
  int64_t heightM = leftTransform ? m : 1;
  int64_t widthM = rightTransform ? m : 1;
  int64_t heightR = leftTransform ? r : 1;
  int64_t widthR = rightTransform ? r : 1;

  // --- Create operation for filter transform ---
  Type filterElementType = filterType.getElementType();
  // alpha = m + r - 1: the transform-domain tile extent per axis.
  int64_t alphaH = heightM + heightR - 1;
  int64_t alphaW = widthM + widthR - 1;
  int64_t tileH = llvm::divideCeilSigned(outputH, heightM);
  int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
  auto retType = RankedTensorType::get({alphaH, alphaW, filterC, filterF},
                                       filterElementType);
  Value retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
                                                    filterElementType);
  auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
      loc, retType, filter, retValue, m, r);

  // --- Create operation for input transform ---

  // When input size - (r - 1) is not aligned with output tile size, we need to
  // pad the input data to create the full tiles as tiling.
  Type inputElementType = inputType.getElementType();
  int64_t alignedInputH = tileH * heightM + (heightR - 1);
  int64_t alignedInputW = tileW * widthM + (widthR - 1);
  if (alignedInputH != inputH || alignedInputW != inputW) {
    input = padToAlignedTensor(rewriter, loc, input,
                               {inputN, alignedInputH, alignedInputW, inputC});
  }

  retType = RankedTensorType::get(
      {alphaH, alphaW, tileH, tileW, inputN, inputC}, inputElementType);
  retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
                                              inputElementType);
  auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>(
      loc, retType, input, retValue, m, r);

  // Element-wise multiply in the transform domain, expressed as a batched
  // matmul reducing over the channel dimension.
  Type outputElementType = outputType.getElementType();
  Value matmulRet = matrixMultiply(rewriter, loc, transformedFilter,
                                   transformedInput, outputElementType);

  // --- Create operation for output transform ---

  // When output size is not aligned with output tile size, we need to pad the
  // output buffer to insert the full tiles after tiling.
  int64_t alignedOutputH = tileH * heightM;
  int64_t alignedOutputW = tileW * widthM;
  bool isOutputUnaligned =
      ((alignedOutputH != outputH) || (alignedOutputW != outputW));
  if (isOutputUnaligned) {
    auto alignedOutputType = RankedTensorType::get(
        {outputN, alignedOutputH, alignedOutputW, outputF}, outputElementType);
    output =
        padToAlignedTensor(rewriter, loc, output, alignedOutputType.getShape());
    outputType = alignedOutputType;
  }

  Value transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>(
      loc, outputType, matmulRet, output, m, r);

  // When output size is not aligned with output tile size, extract the
  // value from the padded buffer.
  if (isOutputUnaligned) {
    transformedOutput = extractFromAlignedTensor(
        rewriter, loc, transformedOutput,
        RankedTensorType::get({outputN, outputH, outputW, outputF},
                              outputElementType));
  }

  rewriter.replaceOp(convOp, transformedOutput);

  return transformedOutput.getDefiningOp();
}

/// A helper function to decompose linalg.winograd_filter_transform.
105127ee33d1SHsiangkai Wang FailureOr<Operation *> 105227ee33d1SHsiangkai Wang decomposeWinogradFilterTransformHelper(RewriterBase &rewriter, 105327ee33d1SHsiangkai Wang linalg::WinogradFilterTransformOp op) { 105427ee33d1SHsiangkai Wang Location loc = op.getLoc(); 105527ee33d1SHsiangkai Wang Value filter = op.getFilter(); 105627ee33d1SHsiangkai Wang auto filterType = cast<ShapedType>(filter.getType()); 105727ee33d1SHsiangkai Wang auto filterShape = filterType.getShape(); 105827ee33d1SHsiangkai Wang int64_t filterH = filterShape[1]; 105927ee33d1SHsiangkai Wang int64_t filterW = filterShape[2]; 106027ee33d1SHsiangkai Wang 106127ee33d1SHsiangkai Wang // For F(m x 1, r x 1), we only need to do left side transform. 106227ee33d1SHsiangkai Wang bool leftTransform = filterH != 1; 106327ee33d1SHsiangkai Wang // For F(1 x m, 1 x r), we only need to do right side transform. 106427ee33d1SHsiangkai Wang bool rightTransform = filterW != 1; 106527ee33d1SHsiangkai Wang Value transformedFilter = 106627ee33d1SHsiangkai Wang filterTransform(rewriter, loc, filter, op.getOutput(), op.getM(), 106727ee33d1SHsiangkai Wang op.getR(), leftTransform, rightTransform); 106827ee33d1SHsiangkai Wang if (!transformedFilter) 106927ee33d1SHsiangkai Wang return failure(); 107027ee33d1SHsiangkai Wang 107127ee33d1SHsiangkai Wang rewriter.replaceOp(op, transformedFilter); 107227ee33d1SHsiangkai Wang 107327ee33d1SHsiangkai Wang return transformedFilter.getDefiningOp(); 107427ee33d1SHsiangkai Wang } 107527ee33d1SHsiangkai Wang 107627ee33d1SHsiangkai Wang /// A helper function to decompose linalg.winograd_input_transform. 
107727ee33d1SHsiangkai Wang FailureOr<Operation *> 107827ee33d1SHsiangkai Wang decomposeWinogradInputTransformHelper(RewriterBase &rewriter, 107927ee33d1SHsiangkai Wang linalg::WinogradInputTransformOp op) { 108027ee33d1SHsiangkai Wang Location loc = op.getLoc(); 1081*f20b8e35SDmitriy Smirnov Value output = op.getOutput(); 1082*f20b8e35SDmitriy Smirnov auto outputType = cast<ShapedType>(output.getType()); 1083*f20b8e35SDmitriy Smirnov auto outputShape = outputType.getShape(); 1084*f20b8e35SDmitriy Smirnov 1085*f20b8e35SDmitriy Smirnov int64_t outputH = outputShape[0]; 1086*f20b8e35SDmitriy Smirnov int64_t outputW = outputShape[1]; 108727ee33d1SHsiangkai Wang 108827ee33d1SHsiangkai Wang // For F(m x 1, r x 1), we only need to do left side transform. 1089*f20b8e35SDmitriy Smirnov bool leftTransform = outputH != 1; 109027ee33d1SHsiangkai Wang // For F(1 x m, 1 x r), we only need to do right side transform. 1091*f20b8e35SDmitriy Smirnov bool rightTransform = outputW != 1; 109227ee33d1SHsiangkai Wang Value transformedInput = 109327ee33d1SHsiangkai Wang inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getM(), 109427ee33d1SHsiangkai Wang op.getR(), leftTransform, rightTransform); 109527ee33d1SHsiangkai Wang if (!transformedInput) 109627ee33d1SHsiangkai Wang return failure(); 109727ee33d1SHsiangkai Wang 109827ee33d1SHsiangkai Wang rewriter.replaceOp(op, transformedInput); 109927ee33d1SHsiangkai Wang 110027ee33d1SHsiangkai Wang return transformedInput.getDefiningOp(); 110127ee33d1SHsiangkai Wang } 110227ee33d1SHsiangkai Wang 110327ee33d1SHsiangkai Wang /// A helper function to decompose linalg.winograd_output_transform. 
110427ee33d1SHsiangkai Wang FailureOr<Operation *> 110527ee33d1SHsiangkai Wang decomposeWinogradOutputTransformHelper(RewriterBase &rewriter, 110627ee33d1SHsiangkai Wang linalg::WinogradOutputTransformOp op) { 110727ee33d1SHsiangkai Wang Location loc = op.getLoc(); 110827ee33d1SHsiangkai Wang Value value = op.getValue(); 110927ee33d1SHsiangkai Wang auto valueType = cast<ShapedType>(value.getType()); 111027ee33d1SHsiangkai Wang auto valueShape = valueType.getShape(); 111127ee33d1SHsiangkai Wang int64_t valueH = valueShape[0]; 111227ee33d1SHsiangkai Wang int64_t valueW = valueShape[1]; 111327ee33d1SHsiangkai Wang 111427ee33d1SHsiangkai Wang // For F(m x 1, r x 1), we only need to do left side transform. 111527ee33d1SHsiangkai Wang bool leftTransform = valueH != 1; 111627ee33d1SHsiangkai Wang // For F(1 x m, 1 x r), we only need to do right side transform. 111727ee33d1SHsiangkai Wang bool rightTransform = valueW != 1; 111827ee33d1SHsiangkai Wang Value transformedOutput = 111927ee33d1SHsiangkai Wang outputTransform(rewriter, loc, value, op.getOutput(), op.getM(), 112027ee33d1SHsiangkai Wang op.getR(), leftTransform, rightTransform); 112127ee33d1SHsiangkai Wang if (!transformedOutput) 112227ee33d1SHsiangkai Wang return failure(); 112327ee33d1SHsiangkai Wang 112427ee33d1SHsiangkai Wang rewriter.replaceOp(op, transformedOutput); 112527ee33d1SHsiangkai Wang 112627ee33d1SHsiangkai Wang return transformedOutput.getDefiningOp(); 112727ee33d1SHsiangkai Wang } 112827ee33d1SHsiangkai Wang 112927ee33d1SHsiangkai Wang /// A rewrite pattern to decompose linalg.winograd_filter_transform operations. 
113027ee33d1SHsiangkai Wang class DecomposeWinogradFilterTransform final 113127ee33d1SHsiangkai Wang : public OpRewritePattern<linalg::WinogradFilterTransformOp> { 113227ee33d1SHsiangkai Wang public: 113327ee33d1SHsiangkai Wang using OpRewritePattern::OpRewritePattern; 113427ee33d1SHsiangkai Wang 113527ee33d1SHsiangkai Wang LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op, 113627ee33d1SHsiangkai Wang PatternRewriter &rewriter) const override { 113727ee33d1SHsiangkai Wang return decomposeWinogradFilterTransformHelper(rewriter, op); 113827ee33d1SHsiangkai Wang } 113927ee33d1SHsiangkai Wang }; 114027ee33d1SHsiangkai Wang 114127ee33d1SHsiangkai Wang /// A rewrite pattern to decompose linalg.winograd_input_transform operations. 114227ee33d1SHsiangkai Wang class DecomposeWinogradInputTransform final 114327ee33d1SHsiangkai Wang : public OpRewritePattern<linalg::WinogradInputTransformOp> { 114427ee33d1SHsiangkai Wang public: 114527ee33d1SHsiangkai Wang using OpRewritePattern::OpRewritePattern; 114627ee33d1SHsiangkai Wang 114727ee33d1SHsiangkai Wang LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op, 114827ee33d1SHsiangkai Wang PatternRewriter &rewriter) const override { 114927ee33d1SHsiangkai Wang return decomposeWinogradInputTransformHelper(rewriter, op); 115027ee33d1SHsiangkai Wang } 115127ee33d1SHsiangkai Wang }; 115227ee33d1SHsiangkai Wang 115327ee33d1SHsiangkai Wang /// A rewrite pattern to decompose linalg.winograd_output_transform operations. 
class DecomposeWinogradOutputTransform final
    : public OpRewritePattern<linalg::WinogradOutputTransformOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  // Delegates the rewrite to the output-transform decompose helper; the
  // helper's FailureOr converts to this pattern's LogicalResult.
  LogicalResult matchAndRewrite(linalg::WinogradOutputTransformOp op,
                                PatternRewriter &rewriter) const override {
    return decomposeWinogradOutputTransformHelper(rewriter, op);
  }
};

/// A rewrite pattern for Winograd Conv2D algorithm.
class WinogradConv2DNhwcFhwc final
    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  // `m` and `r` are the F(m x m, r x r) parameters forwarded to
  // winogradConv2DHelper for every matched convolution.
  WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, int64_t m, int64_t r)
      : OpRewritePattern(context), m(m), r(r) {}

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(winogradConv2DHelper(rewriter, convOp, m, r)))
      return failure();

    return success();
  }

private:
  // Winograd F(m x m, r x r) parameters captured at pattern construction.
  int64_t m;
  int64_t r;
};
} // end anonymous namespace

//===----------------------------------------------------------------------===//
/// Applies the Winograd Conv2D rewriting to a single
/// linalg.conv_2d_nhwc_fhwc op with parameters F(m x m, r x r).
/// Returns the op defining the replacement value on success.
FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
                                      int64_t r) {
  return winogradConv2DHelper(rewriter, op, m, r);
}

/// Decomposes one linalg.winograd_filter_transform op; thin public wrapper
/// over the file-local helper.
FailureOr<Operation *>
decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
                                   linalg::WinogradFilterTransformOp op) {
  return decomposeWinogradFilterTransformHelper(rewriter, op);
}

/// Decomposes one linalg.winograd_input_transform op; thin public wrapper
/// over the file-local helper.
FailureOr<Operation *>
decomposeWinogradInputTransformOp(RewriterBase &rewriter,
                                  linalg::WinogradInputTransformOp op) {
  return decomposeWinogradInputTransformHelper(rewriter, op);
}

/// Decomposes one linalg.winograd_output_transform op; thin public wrapper
/// over the file-local helper.
FailureOr<Operation *>
decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
                                   linalg::WinogradOutputTransformOp op) {
  return decomposeWinogradOutputTransformHelper(rewriter, op);
}

/// Registers the Winograd Conv2D rewrite pattern (currently NHWC/FHWC layout
/// only) with the given F(m x m, r x r) parameters.
void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
                                    int64_t r) {
  MLIRContext *context = patterns.getContext();
  // TODO: Support more Conv2D data layout, e.g., conv_2d_nchw_fchw
  patterns.insert<WinogradConv2DNhwcFhwc>(context, m, r);
}

/// Registers the three patterns that decompose the Winograd transform ops
/// into primitive ops.
void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns
      .insert<DecomposeWinogradFilterTransform, DecomposeWinogradInputTransform,
              DecomposeWinogradOutputTransform>(context);
}

} // end namespace linalg
} // end namespace mlir