//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implement the Winograd Conv2D algorithm. The implementation is based on the
// paper: Fast Algorithms for Convolutional Neural Networks
// (https://arxiv.org/abs/1509.09308)
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/MathExtras.h"

namespace mlir {
namespace linalg {

namespace {

// clang-format off
/// Winograd Conv2D uses a minimal 2D filtering algorithm to calculate its
/// result. The formula of the minimal 2D filtering algorithm F(m x m, r x r),
/// where m is the output dimension and r is the filter dimension, is
///
/// Y = A^T x [ (G x g x G^T) (*) (B^T x d x B) ] x A
///
/// g is the filter, d is the input data, and (*) denotes element-wise
/// multiplication. We need to prepare 6 constant transformation matrices,
/// G, G^T, B^T, B, A^T, and A for this formula.
///
/// The following tables define these constant transformation matrices for
/// F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5).
constexpr float G_2x2_3x3[] = {
   -1,     0,    0,
 1./2, -1./2, 1./2,
 1./2,  1./2, 1./2,
    0,     0,    1
};

constexpr float GT_2x2_3x3[] = {
  -1,  1./2, 1./2, 0,
   0, -1./2, 1./2, 0,
   0,  1./2, 1./2, 1
};

constexpr float BT_2x2_3x3[] = {
  -1,  0, 1, 0,
   0, -1, 1, 0,
   0,  1, 1, 0,
   0, -1, 0, 1
};

constexpr float B_2x2_3x3[] = {
  -1,  0, 0,  0,
   0, -1, 1, -1,
   1,  1, 1,  0,
   0,  0, 0,  1
};

constexpr float AT_2x2_3x3[] = {
  1,  1, 1, 0,
  0, -1, 1, 1
};

constexpr float A_2x2_3x3[] = {
  1,  0,
  1, -1,
  1,  1,
  0,  1
};

constexpr float G_4x4_3x3[] = {
     1,     0,     0,
 -1./3,  1./3, -1./3,
 -1./3, -1./3, -1./3,
 1./12, -1./6,  1./3,
 1./12,  1./6,  1./3,
     0,     0,     1
};

constexpr float GT_4x4_3x3[] = {
  1, -1./3, -1./3, 1./12, 1./12, 0,
  0,  1./3, -1./3, -1./6,  1./6, 0,
  0, -1./3, -1./3,  1./3,  1./3, 1
};

constexpr float BT_4x4_3x3[] = {
  1./4,     0, -5./16,      0, 1./16,     0,
     0,  1./4,  -1./4, -1./16, 1./16,     0,
     0, -1./4,  -1./4,  1./16, 1./16,     0,
     0,  1./4,  -1./8,  -1./4,  1./8,     0,
     0, -1./4,  -1./8,   1./4,  1./8,     0,
     0,  1./4,      0, -5./16,     0, 1./16
};

constexpr float B_4x4_3x3[] = {
    1./4,      0,     0,     0,     0,      0,
       0,   1./4, -1./4,  1./4, -1./4,   1./4,
  -5./16,  -1./4, -1./4, -1./8, -1./8,      0,
       0, -1./16, 1./16, -1./4,  1./4, -5./16,
   1./16,  1./16, 1./16,  1./8,  1./8,      0,
       0,      0,     0,     0,     0,  1./16
};

constexpr float AT_4x4_3x3[] = {
  1./8,  1./4, 1./4,  1./8, 1./8,    0,
     0, -1./4, 1./4, -1./4, 1./4,    0,
     0,  1./4, 1./4,  1./2, 1./2,    0,
     0, -1./4, 1./4,    -1,    1, 1./2
};

constexpr float A_4x4_3x3[] = {
  1./8,     0,    0,     0,
  1./4, -1./4, 1./4, -1./4,
  1./4,  1./4, 1./4,  1./4,
  1./8, -1./4, 1./2,    -1,
  1./8,  1./4, 1./2,     1,
     0,     0,    0,  1./2
};

constexpr float G_2x2_5x5[] = {
      1,     0,      0,     0,      0,
   1./6, -1./6,   1./6, -1./6,   1./6,
  -1./6, -1./6,  -1./6, -1./6,  -1./6,
 -4./15, 2./15, -1./15, 1./30, -1./60,
  1./60, 1./30,  1./15, 2./15,  4./15,
      0,     0,      0,     0,      1
};

constexpr float GT_2x2_5x5[] = {
  1,  1./6, -1./6, -4./15, 1./60, 0,
  0, -1./6, -1./6,  2./15, 1./30, 0,
  0,  1./6, -1./6, -1./15, 1./15, 0,
  0, -1./6, -1./6,  1./30, 2./15, 0,
  0,  1./6, -1./6, -1./60, 4./15, 1
};

constexpr float BT_2x2_5x5[] = {
  1./8, 3./16,  -1./4, -3./16,   1./8,    0,
     0,  1./8,  1./16, -5./16,   1./8,    0,
     0, -1./8, -5./16, -1./16,   1./8,    0,
     0,  1./4,  -1./8,  -1./4,   1./8,    0,
     0, -1./8,  -1./4,   1./8,   1./4,    0,
     0,  1./8,  3./16,  -1./4, -3./16, 1./8
};

constexpr float B_2x2_5x5[] = {
   1./8,      0,      0,     0,     0,      0,
  3./16,   1./8,  -1./8,  1./4, -1./8,   1./8,
  -1./4,  1./16, -5./16, -1./8, -1./4,  3./16,
 -3./16, -5./16, -1./16, -1./4,  1./8,  -1./4,
   1./8,   1./8,   1./8,  1./8,  1./4, -3./16,
      0,      0,      0,     0,     0,   1./8
};

constexpr float AT_2x2_5x5[] = {
  1./2,  1, 1,  2, 1,    0,
     0, -1, 1, -1, 2, 1./2
};

constexpr float A_2x2_5x5[] = {
  1./2,    0,
     1,   -1,
     1,    1,
     2,   -1,
     1,    2,
     0, 1./2
};
// clang-format on

using TransformMapKeyTy = std::pair<int, int>;

/// We use F(m, r) to define the size of minimal filtering algorithms.
/// m is the output dimension and r is the filter dimension. We can get
/// the input dimension, alpha, from the formula, alpha = m + r - 1.
///
/// For example, when m = 2 and r = 3, the input size is 4. The Conv2D
/// operates on a 4x4 input with a 3x3 filter and produces a 2x2 output.
constexpr TransformMapKeyTy F_2_3{2, 3};
constexpr TransformMapKeyTy F_4_3{4, 3};
constexpr TransformMapKeyTy F_2_5{2, 5};

/// Structure to keep information of constant transform matrices.
/// scalarFactor records the constant by which a pre-scaled table deviates
/// from the true transform; outputTransform multiplies it back into the
/// final result.
struct TransformMatrix {
  TransformMatrix(const float *table, int64_t rows, int64_t cols,
                  int64_t scalarFactor = 1)
      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}

  const float *table;
  int64_t rows;
  int64_t cols;
  int64_t scalarFactor;
};

/// Utility function to convert a constant array to an arith.constant Value.
Value create2DTransformMatrix(OpBuilder &builder, Location loc,
                              TransformMatrix transform, Type type) {
  ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);

  return builder.create<arith::ConstantOp>(
      loc, DenseFPElementsAttr::get(
               RankedTensorType::get(
                   SmallVector<int64_t>{transform.rows, transform.cols}, type),
               constVec));
}
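
// For instance, for the F(2, 3) G matrix above, create2DTransformMatrix
// emits IR roughly like (a sketch; exact printing may differ):
//
//   %cst = arith.constant dense<[[-1.0, 0.0, 0.0],
//                                [0.5, -0.5, 0.5],
//                                [0.5, 0.5, 0.5],
//                                [0.0, 0.0, 1.0]]> : tensor<4x3xf32>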

/// Extract height x width data from 4D tensors.
Value extract2DDataFrom4D(OpBuilder &builder, Location loc, Value source,
                          Value loopNorFIndex, Value loopCorFIndex,
                          Value heightOffset, Value widthOffset,
                          int64_t extractHeight, int64_t extractWidth,
                          int64_t loopNorFIdx, int64_t loopCorFIdx,
                          int64_t heightIdx, int64_t widthIdx) {
  auto sourceType = cast<ShapedType>(source.getType());
  Type elementType = sourceType.getElementType();
  int64_t srcSize = sourceType.getRank();

  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> offsets;
  offsets.resize(srcSize);
  offsets[loopNorFIdx] = loopNorFIndex;
  offsets[loopCorFIdx] = loopCorFIndex;
  offsets[heightIdx] = heightOffset;
  offsets[widthIdx] = widthOffset;
  SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
  sizes[heightIdx] = builder.getIndexAttr(extractHeight);
  sizes[widthIdx] = builder.getIndexAttr(extractWidth);
  SmallVector<OpFoldResult> strides(srcSize, oneIndex);

  auto extractFilterType =
      RankedTensorType::get({extractHeight, extractWidth}, elementType);
  auto extractFilterOp = builder.create<tensor::ExtractSliceOp>(
      loc, extractFilterType, source, offsets, sizes, strides);

  return extractFilterOp;
}

/// Extract height x width data from 6D tensors.
Value extract2DDataFrom6D(OpBuilder &builder, Location loc, Value source,
                          Value tileHIndex, Value tileWIndex,
                          Value loopNorFIndex, Value loopCorFIndex,
                          int64_t tileHIdx, int64_t tileWIdx,
                          int64_t loopNorFIdx, int64_t loopCorFIdx,
                          int64_t heightIdx, int64_t widthIdx) {
  auto sourceType = cast<ShapedType>(source.getType());
  Type elementType = sourceType.getElementType();
  auto sourceShape = sourceType.getShape();
  int64_t srcSize = sourceType.getRank();
  int64_t height = sourceShape[heightIdx];
  int64_t width = sourceShape[widthIdx];

  auto zeroIndex = builder.getIndexAttr(0);
  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> offsets(srcSize, zeroIndex);
  offsets[tileHIdx] = tileHIndex;
  offsets[tileWIdx] = tileWIndex;
  offsets[loopNorFIdx] = loopNorFIndex;
  offsets[loopCorFIdx] = loopCorFIndex;
  SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
  sizes[heightIdx] = builder.getIndexAttr(height);
  sizes[widthIdx] = builder.getIndexAttr(width);
  SmallVector<OpFoldResult> strides(srcSize, oneIndex);

  auto extractFilterType = RankedTensorType::get({height, width}, elementType);
  auto extractFilterOp = builder.create<tensor::ExtractSliceOp>(
      loc, extractFilterType, source, offsets, sizes, strides);

  return extractFilterOp;
}
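
// Both extract helpers above lower to a single rank-reducing
// tensor.extract_slice. For example, pulling an H x W tile out of an FHWC
// filter looks roughly like (H, W, F, C stand in for concrete sizes):
//
//   %tile = tensor.extract_slice %src[%f, 0, 0, %c] [1, H, W, 1] [1, 1, 1, 1]
//       : tensor<FxHxWxCxf32> to tensor<HxWxf32>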

/// Insert transformed height x width data into the 4D tensor from which it
/// was extracted.
Value insert2DDataTo4D(OpBuilder &builder, Location loc, Value source,
                       Value dest, Value loopNorFIndex, Value loopCorFIndex,
                       Value heightOffset, Value widthOffset, int64_t height,
                       int64_t width, int64_t loopNorFIdx, int64_t loopCorFIdx,
                       int64_t heightIdx, int64_t widthIdx) {
  int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> retOffsets;
  retOffsets.resize(destSize);
  retOffsets[loopNorFIdx] = loopNorFIndex;
  retOffsets[loopCorFIdx] = loopCorFIndex;
  retOffsets[heightIdx] = heightOffset;
  retOffsets[widthIdx] = widthOffset;
  SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
  retSizes[heightIdx] = builder.getIndexAttr(height);
  retSizes[widthIdx] = builder.getIndexAttr(width);
  SmallVector<OpFoldResult> strides(destSize, oneIndex);

  auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
      loc, source, dest, retOffsets, retSizes, strides);

  return insertSliceOp;
}

/// Insert transformed height x width data into the 6D tensor from which it
/// was extracted.
Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source,
                       Value dest, Value tileHIndex, Value tileWIndex,
                       Value loopNorFIndex, Value loopCorFIndex, int64_t height,
                       int64_t width, int64_t tileHIdx, int64_t tileWIdx,
                       int64_t loopNorFIdx, int64_t loopCorFIdx,
                       int64_t heightIdx, int64_t widthIdx) {
  int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
  auto zeroIndex = builder.getIndexAttr(0);
  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> retOffsets(destSize, zeroIndex);
  retOffsets[tileHIdx] = tileHIndex;
  retOffsets[tileWIdx] = tileWIndex;
  retOffsets[loopNorFIdx] = loopNorFIndex;
  retOffsets[loopCorFIdx] = loopCorFIndex;
  SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
  retSizes[heightIdx] = builder.getIndexAttr(height);
  retSizes[widthIdx] = builder.getIndexAttr(width);
  SmallVector<OpFoldResult> strides(destSize, oneIndex);

  auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
      loc, source, dest, retOffsets, retSizes, strides);

  return insertSliceOp;
}
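
// The two insert helpers above are the inverse of the extract helpers: they
// write a transformed H x W tile back into the destination with a single
// tensor.insert_slice, using the same offset/size/stride conventions.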

/// This function transforms the filter. The data layout of the filter is FHWC.
/// The transformation matrices are 2-dimensional, so we extract the H x W
/// slice from FHWC and generate 2 levels of loops to iterate on F and C.
/// After the transformation, we get
///
/// scf.for %f = lo_f to hi_f step 1
///   scf.for %c = lo_c to hi_c step 1
///     %extracted = extract filter<h x w> from filter<f x h x w x c>
///     %ret = linalg.matmul G, %extracted
///     %ret = linalg.matmul %ret, GT
///     %inserted = insert %ret into filter<h x w x c x f>
Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
                      Value retValue, int64_t m, int64_t r,
                      bool leftTransform = true, bool rightTransform = true) {
  // Map from (m, r) to the G transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      GMatrices = {
          {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
          {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
          {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
      };

  // Map from (m, r) to the GT transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      GTMatrices = {
          {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
          {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
          {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
      };

  auto filterType = cast<ShapedType>(filter.getType());
  Type elementType = filterType.getElementType();
  auto filterShape = filterType.getShape(); // F, H, W, C
  int64_t filterF = filterShape[0];
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];
  int64_t filterC = filterShape[3];

  if (filterH != r && filterH != 1)
    return Value();
  if (filterW != r && filterW != 1)
    return Value();

  Value zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                       ValueRange args) -> scf::ValueVector {
    Value FIter = ivs[0];
    Value CIter = ivs[1];

    // Extract (H, W) from (F, H, W, C).
    auto extractFilter =
        extract2DDataFrom4D(builder, loc, filter, FIter, CIter, zeroIdx,
                            zeroIdx, filterH, filterW, /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);

    TransformMapKeyTy key = {m, r};
    int64_t retRows = 1;
    Value matmulRetValue = extractFilter;
    Value zero = builder.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementType));
    if (leftTransform) {
      // Get the constant transform matrix G.
      auto it = GMatrices.find(key);
      if (it == GMatrices.end())
        return {};
      const TransformMatrix &GMatrix = it->second;

      retRows = GMatrix.rows;
      auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
      auto empty =
          builder
              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
              .getResult();
      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);

      Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType);
      // Multiply G x g.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (rightTransform) {
      // Get the constant transform matrix GT.
      auto it = GTMatrices.find(key);
      if (it == GTMatrices.end())
        return {};
      const TransformMatrix &GTMatrix = it->second;

      auto matmulType =
          RankedTensorType::get({retRows, GTMatrix.cols}, elementType);
      auto empty =
          builder
              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
              .getResult();
      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);

      Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType);
      // Multiply u = (G x g) x GT.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{matmulRetValue, GT}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    // Insert (H, W) into (H, W, C, F).
    int64_t retHeight = leftTransform ? m + r - 1 : 1;
    int64_t retWidth = rightTransform ? m + r - 1 : 1;

    auto insertSliceOp =
        insert2DDataTo4D(builder, loc, matmulRetValue, args[0], FIter, CIter,
                         zeroIdx, zeroIdx, retHeight, retWidth,
                         /*loopNorFIdx=*/3, /*loopCorFIdx=*/2,
                         /*heightIdx=*/0, /*widthIdx=*/1);

    return {insertSliceOp};
  };

  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF);
  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC);
  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  scf::LoopNest loops = scf::buildLoopNest(
      rewriter, loc, {zeroIdx, zeroIdx}, {fUpperBound, cUpperBound},
      {oneStep, oneStep}, {retValue}, buildBody);
  return loops.results[0];
}
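
// Shape example for filterTransform (a sketch): for F(2, 3) and an FHWC
// filter of type tensor<64x3x3x32xf32>, each loop iteration transforms a
// 3x3 slice into a 4x4 tile (alpha = m + r - 1 = 4), so the accumulated
// result has type tensor<4x4x32x64xf32>, i.e. layout (alphaH, alphaW, C, F).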

/// This function transforms the input. The data layout of the input is NHWC.
/// The transformation matrices are 2-dimensional, so we extract the H x W
/// slice from NHWC and generate 4 levels of loops to iterate on the tiles
/// and on N and C. After the transformation, we get
///
/// scf.for %h = 0 to tileH step 1
///   scf.for %w = 0 to tileW step 1
///     scf.for %n = 0 to N step 1
///       scf.for %c = 0 to C step 1
///         %extracted = extract %extracted<alphaH x alphaW> from
///                        %input<N x H x W x C>
///                        at [%n, (%h x m), (%w x m), %c]
///         %ret = linalg.matmul BT, %extracted
///         %ret = linalg.matmul %ret, B
///         %inserted = insert %ret<alphaH x alphaW> into
///                       %output<alphaH x alphaW x tileH x tileW x N x C>
///                       at [0, 0, %h, %w, %n, %c]
Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
                     Value retValue, int64_t m, int64_t r,
                     bool leftTransform = true, bool rightTransform = true) {
  // Map from (m, r) to the BT transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      BTMatrices = {
          {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
          {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
          {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
      };

  // Map from (m, r) to the B transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      BMatrices = {
          {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
          {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
          {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
      };

  auto inputType = cast<ShapedType>(input.getType());
  Type elementType = inputType.getElementType();
  auto inputShape = inputType.getShape(); // N, H, W, C
  int64_t inputN = inputShape[0];
  int64_t inputC = inputShape[3];
  auto valueType = cast<ShapedType>(retValue.getType());
  auto valueShape = valueType.getShape(); // alphaH, alphaW, HTile, WTile, N, C
  int64_t tileH = valueShape[2];
  int64_t tileW = valueShape[3];
  int64_t alphaH = leftTransform ? m + r - 1 : 1;
  int64_t alphaW = rightTransform ? m + r - 1 : 1;

  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                       ValueRange args) -> scf::ValueVector {
    Value tileHIter = ivs[0];
    Value tileWIter = ivs[1];
    Value NIter = ivs[2];
    Value CIter = ivs[3];

    auto context = builder.getContext();

    auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
    auto affineMap =
        AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
    Value heightOffset = builder.create<affine::AffineApplyOp>(
        loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
    Value widthOffset = builder.create<affine::AffineApplyOp>(
        loc, rightTransform ? affineMap : identityAffineMap, tileWIter);

    // Extract (H, W) from (N, H, W, C).
    auto extractInput =
        extract2DDataFrom4D(builder, loc, input, NIter, CIter, heightOffset,
                            widthOffset, alphaH, alphaW, /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);

    TransformMapKeyTy key = {m, r};
    int64_t retRows = 1;
    int64_t retCols = 1;
    Value matmulRetValue = extractInput;
    Value zero = builder.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementType));
    if (leftTransform) {
      // Get the constant transform matrix BT.
      auto it = BTMatrices.find(key);
      if (it == BTMatrices.end())
        return {};
      const TransformMatrix &BTMatrix = it->second;

      retRows = BTMatrix.rows;
      auto matmulType = RankedTensorType::get({retRows, alphaW}, elementType);
      auto empty =
          builder
              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
              .getResult();
      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);

      Value BT =
          create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
      // Multiply BT x d.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (rightTransform) {
      // Get the constant transform matrix B.
      auto it = BMatrices.find(key);
      if (it == BMatrices.end())
        return {};
      const TransformMatrix &BMatrix = it->second;

      retCols = BMatrix.cols;
      auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
      auto empty =
          builder
              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
              .getResult();
      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
      Value B =
          create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
      // Multiply v = (BT x d) x B.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{matmulRetValue, B}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    // Insert (H, W) into (H, W, tileH, tileW, N, C).
    auto combinedVal = insert2DDataTo6D(
        builder, loc, matmulRetValue, args[0], tileHIter, tileWIter, NIter,
        CIter, retRows, retCols, /*tileHIdx=*/2, /*tileWIdx=*/3,
        /*loopNorFIdx=*/4, /*loopCorFIdx=*/5,
        /*heightIdx=*/0, /*widthIdx=*/1);

    return {combinedVal};
  };

  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tileH);
  auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW);
  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputN);
  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputC);
  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  scf::LoopNest loops = scf::buildLoopNest(
      rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
      {tileHBound, tileWBound, nUpperBound, cUpperBound},
      {oneStep, oneStep, oneStep, oneStep}, {retValue}, buildBody);
  return loops.results[0];
}
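
// Shape example for inputTransform (a sketch): for F(2, 3) and an NHWC input
// of type tensor<1x8x8x32xf32>, the convolution output is 6x6, so
// tileH = tileW = 3 and the transformed input has type
// tensor<4x4x3x3x1x32xf32>, i.e. layout (alphaH, alphaW, tileH, tileW, N, C).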

/// This function generates linalg.batch_matmul to multiply the input with the
/// filter. linalg.batch_matmul only supports 3-dimensional inputs, so we
/// treat the alphaH x alphaW transform dimensions as the batch dimension and
/// collapse the remaining dimensions. That is, we convert
/// [alphaH, alphaW, tileH, tileW, N, C] to [alphaH x alphaW, tileH x tileW x N, C]
/// for the input and [alphaH, alphaW, C, F] to [alphaH x alphaW, C, F] for
/// the filter. In this way, the 6-dimensional input gets a 3-dimensional
/// representation that is suitable for linalg.batch_matmul.
///
/// The batched matmul does the matrix multiply with the reduction on the
/// channel dimension C.
///
/// We get
///
/// %collapsed_input = tensor.collapse_shape %input
/// %collapsed_filter = tensor.collapse_shape %filter
/// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
/// %expanded_ret = tensor.expand_shape %ret
///
/// After this function, we get the return value with data layout
/// (alphaH, alphaW, tileH, tileW, N, F).
static Value matrixMultiply(RewriterBase &rewriter, Location loc,
                            Value transformedFilter, Value transformedInput,
                            Type outputElementType) {
  // Convert (alphaH, alphaW, C, F) to (alphaH x alphaW, C, F) for the filter.
  auto filterType = cast<ShapedType>(transformedFilter.getType());
  assert(filterType.hasStaticShape() && "only support static shapes.");
  ArrayRef<int64_t> filterShape = filterType.getShape();
  Type filterElementType = filterType.getElementType();
  auto filterReassocType = RankedTensorType::get(
      {filterShape[0] * filterShape[1], filterShape[2], filterShape[3]},
      filterElementType);
  SmallVector<ReassociationIndices> filterReassoc = {{0, 1}, {2}, {3}};
  Value collapseFilter = rewriter.create<tensor::CollapseShapeOp>(
      loc, filterReassocType, transformedFilter, filterReassoc);

  // Convert (alphaH, alphaW, tileH, tileW, N, C) to
  // (alphaH x alphaW, tileH x tileW x N, C) for the input.
  auto inputType = cast<ShapedType>(transformedInput.getType());
  assert(inputType.hasStaticShape() && "only support static shapes.");
  ArrayRef<int64_t> inputShape = inputType.getShape();
  Type inputElementType = inputType.getElementType();
  auto inputReassocType = RankedTensorType::get(
      {inputShape[0] * inputShape[1],
       inputShape[2] * inputShape[3] * inputShape[4], inputShape[5]},
      inputElementType);
  SmallVector<ReassociationIndices> inputReassoc = {{0, 1}, {2, 3, 4}, {5}};
  Value collapseInput = rewriter.create<tensor::CollapseShapeOp>(
      loc, inputReassocType, transformedInput, inputReassoc);

  // Batched matrix multiply.
  auto matmulType = RankedTensorType::get(
      {inputShape[0] * inputShape[1],
       inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]},
      outputElementType);
  Value empty = rewriter
                    .create<tensor::EmptyOp>(loc, matmulType.getShape(),
                                             outputElementType)
                    .getResult();
  Value zero = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getZeroAttr(outputElementType));
  Value init = rewriter.create<linalg::FillOp>(loc, zero, empty).getResult(0);

  auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
      loc, matmulType, ValueRange({collapseInput, collapseFilter}),
      ValueRange{init});

  // The result shape of the batch matmul is
  // (alphaH x alphaW, tileH x tileW x N, F). Expand it back to
  // (alphaH, alphaW, tileH, tileW, N, F).
  SmallVector<ReassociationIndices> outputReassoc = {{0, 1}, {2, 3, 4}, {5}};
  auto outputReassocType =
      RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2],
                             inputShape[3], inputShape[4], filterShape[3]},
                            outputElementType);
  auto expandOutput = rewriter.create<tensor::ExpandShapeOp>(
      loc, outputReassocType, matmulOp.getResult(0), outputReassoc);
  return expandOutput;
}
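
// Shape example for matrixMultiply (a sketch): with alphaH = alphaW = 4,
// tileH = tileW = 3, N = 1, C = 32, and F = 64, the filter collapses to
// tensor<16x32x64xf32>, the input collapses to tensor<16x9x32xf32>, the
// batch matmul produces tensor<16x9x64xf32>, and the result expands back to
// tensor<4x4x3x3x1x64xf32>.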

/// This function transforms the output. The data layout of the incoming value
/// is (H, W, tileH, tileW, N, F) and the final output layout is NHWF. The
/// transformation matrices are 2-dimensional, so we extract the H x W slice
/// and generate 4 levels of loops to iterate on the tiles and on N and F.
/// After the transformation, we get
///
/// scf.for %h = 0 to tileH step 1
///   scf.for %w = 0 to tileW step 1
///     scf.for %n = 0 to N step 1
///       scf.for %f = 0 to F step 1
///         %extracted = extract %extracted<alphaH x alphaW> from
///                        %input<alphaH x alphaW x tileH x tileW x N x F>
///                        at [0, 0, %h, %w, %n, %f]
///         %ret = linalg.matmul AT, %extracted
///         %ret = linalg.matmul %ret, A
///         %inserted = insert %ret<m x m> into
///                       output<N x H x W x F>
///                       at [%n, (%h x m), (%w x m), %f]
Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
                      Value output, int64_t m, int64_t r,
                      bool leftTransform = true, bool rightTransform = true) {
  // Map from (m, r) to the AT transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      ATMatrices = {
          {F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)},
          {F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)},
          {F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)},
      };

  // Map from (m, r) to the A transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      AMatrices = {
          {F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)},
          {F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)},
          {F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)},
      };

  auto valueType = cast<ShapedType>(value.getType());
  Type elementType = valueType.getElementType();
  auto valueShape = valueType.getShape(); // H, W, TileH, TileW, N, F
  int64_t valueH = valueShape[0];
  int64_t valueW = valueShape[1];
  int64_t valueN = valueShape[4];
  int64_t valueF = valueShape[5];
  int64_t alphaH = leftTransform ? m + r - 1 : 1;
  int64_t alphaW = rightTransform ? m + r - 1 : 1;

  if (valueH != alphaH && valueH != 1)
    return Value();
  if (valueW != alphaW && valueW != 1)
    return Value();

  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                       ValueRange args) -> scf::ValueVector {
    auto context = builder.getContext();
    Value tileHIter = ivs[0];
    Value tileWIter = ivs[1];
    Value NIter = ivs[2];
    Value FIter = ivs[3];

    // Extract (H, W) from (H, W, tileH, tileW, N, F).
    auto extractValue =
        extract2DDataFrom6D(builder, loc, value, tileHIter, tileWIter, NIter,
                            FIter, /*tileHIdx=*/2, /*tileWIdx=*/3,
                            /*loopNorFIdx=*/4, /*loopCorFIdx=*/5,
                            /*heightIdx=*/0, /*widthIdx=*/1);

    const TransformMapKeyTy key = {m, r};
    const TransformMatrix &AMatrix = AMatrices.at(key);
    const TransformMatrix &ATMatrix = ATMatrices.at(key);
    int64_t scalarFactor = (rightTransform ? AMatrix.scalarFactor : 1) *
                           (leftTransform ? ATMatrix.scalarFactor : 1);
    int64_t retCols = rightTransform ? AMatrix.cols : 1;
    int64_t retRows = leftTransform ? ATMatrix.rows : 1;

    Value matmulRetValue = extractValue;
    Value zero = builder.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementType));

    auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
    auto affineMap =
        AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
    Value heightOffset = builder.create<affine::AffineApplyOp>(
        loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
    Value widthOffset = builder.create<affine::AffineApplyOp>(
        loc, rightTransform ? affineMap : identityAffineMap, tileWIter);

    Value outInitVal =
        extract2DDataFrom4D(builder, loc, args[0], NIter, FIter, heightOffset,
                            widthOffset, retRows, retCols,
                            /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1,
                            /*widthIdx=*/2);
    if (leftTransform) {
      auto matmulType = RankedTensorType::get({retRows, valueW}, elementType);
      Value init = outInitVal;
      if (rightTransform || scalarFactor != 1) {
        auto empty = builder
                         .create<tensor::EmptyOp>(loc, matmulType.getShape(),
                                                  elementType)
                         .getResult();
        init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
      }

      Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType);
      // Multiply AT x m.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{AT, matmulRetValue}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (rightTransform) {
      auto matmulType =
          RankedTensorType::get({retRows, AMatrix.cols}, elementType);
      Value init = outInitVal;
      if (scalarFactor != 1) {
        auto empty = builder
                         .create<tensor::EmptyOp>(loc, matmulType.getShape(),
                                                  elementType)
                         .getResult();
        init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
      }

      Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType);
      // Multiply y = (AT x m) x A.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{matmulRetValue, A}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (scalarFactor != 1) {
      // Multiply by the scalar factor and add outInitVal.
      Value scalarFactorValue = builder.create<arith::ConstantOp>(
          loc, FloatAttr::get(elementType, scalarFactor));
      auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
      auto identityAffineMap = rewriter.getMultiDimIdentityMap(2);
      SmallVector<AffineMap> affineMaps = {
          AffineMap::get(2, 0, context), identityAffineMap, identityAffineMap};

      matmulRetValue =
          rewriter
              .create<linalg::GenericOp>(
                  loc, matmulType,
                  ValueRange{scalarFactorValue, matmulRetValue},
                  ValueRange{outInitVal}, affineMaps,
                  llvm::ArrayRef<utils::IteratorType>{
                      utils::IteratorType::parallel,
                      utils::IteratorType::parallel},
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    // Compute scalarFactor * value + outInitVal element-wise.
                    auto mulf = nestedBuilder.create<arith::MulFOp>(
                        nestedLoc, args[0], args[1]);
                    auto addf = nestedBuilder.create<arith::AddFOp>(
                        nestedLoc, mulf.getResult(), args[2]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc,
                                                          addf.getResult());
                  })
              .getResult(0);
    }

    // Insert (H, W) into (N, H, W, F).
    Value combinedVal =
        insert2DDataTo4D(builder, loc, matmulRetValue, args[0], NIter, FIter,
                         heightOffset, widthOffset, retRows, retCols,
                         /*loopNorFIdx=*/0,
                         /*loopCorFIdx=*/3, /*heightIdx=*/1,
                         /*widthIdx=*/2);

    return {combinedVal};
  };

  int64_t tileH = valueShape[2];
  int64_t tileW = valueShape[3];
  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tileH);
  auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW);
  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueN);
  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueF);
  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  scf::LoopNest loops = scf::buildLoopNest(
      rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
      {tileHBound, tileWBound, nUpperBound, fUpperBound},
      {oneStep, oneStep, oneStep, oneStep}, {output}, buildBody);
  return loops.results[0];
}
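
// Continuing the F(2, 3) example (a sketch): outputTransform shrinks each
// 4x4 Winograd-domain tile back to a 2x2 output tile (retRows = retCols = m)
// and writes it to the NHWF output at spatial offsets (%h * m, %w * m).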

/// Zero-pad the value at the end of each dimension (pad high) so that the
/// result has the aligned shape.
static Value padToAlignedTensor(RewriterBase &rewriter, Location loc,
                                Value value, ArrayRef<int64_t> alignedShape) {
  auto valueType = cast<ShapedType>(value.getType());
  Type elementType = valueType.getElementType();
  auto alignedType = RankedTensorType::get(alignedShape, elementType);
  Value padValue = rewriter.create<arith::ConstantOp>(
      loc, elementType, rewriter.getZeroAttr(elementType));

  return linalg::makeComposedPadHighOp(rewriter, loc, alignedType, value,
                                       padValue, false);
}

/// Extract a sub-tensor with extractedType from value.
static Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
                                      Value value,
                                      RankedTensorType extractedType) {
  OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
  OpFoldResult oneIndex = rewriter.getIndexAttr(1);
  SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
  SmallVector<OpFoldResult, 4> strides(4, oneIndex);

  ArrayRef<int64_t> extractedShape = extractedType.getShape();
  SmallVector<OpFoldResult> sizes =
      getAsOpFoldResult(rewriter.getI64ArrayAttr(extractedShape));

  return rewriter.create<tensor::ExtractSliceOp>(loc, extractedType, value,
                                                 offsets, sizes, strides);
}

/// Utility function to check that all values in the attribute are 1.
static bool hasAllOneValues(DenseIntElementsAttr attr) {
  return llvm::all_of(
      attr, [](const APInt &element) { return element.getSExtValue() == 1; });
}

/// A helper function to convert linalg.conv_2d_nhwc_fhwc to
/// linalg.winograd_*_transform ops.
static FailureOr<Operation *>
winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
                     int64_t m, int64_t r) {
  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];
  auto inputType = cast<ShapedType>(input.getType());
  auto filterType = cast<ShapedType>(filter.getType());
  auto outputType = cast<ShapedType>(output.getType());

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  if (!hasAllOneValues(convOp.getStrides()))
    return rewriter.notifyMatchFailure(convOp, "expected all ones for strides");

  ArrayRef<int64_t> filterShape = filterType.getShape();
  int64_t filterF = filterShape[0];
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];
  int64_t filterC = filterShape[3];
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputN = inputShape[0];
  int64_t inputH = inputShape[1];
  int64_t inputW = inputShape[2];
  int64_t inputC = inputShape[3];
  ArrayRef<int64_t> outputShape = outputType.getShape();
  int64_t outputN = outputShape[0];
  int64_t outputH = outputShape[1];
  int64_t outputW = outputShape[2];
  int64_t outputF = outputShape[3];

  // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r).
  bool isSupportedFilter = false;
  if (filterH == filterW && filterH == r)
    isSupportedFilter = true;
  if (filterH == r && filterW == 1)
    isSupportedFilter = true;
  if (filterH == 1 && filterW == r)
    isSupportedFilter = true;

  if (!isSupportedFilter)
    return rewriter.notifyMatchFailure(
        convOp, "only support filter (r x r), (r x 1) or (1 x r)");

  // Currently, we support (m, r) = (2, 3), (4, 3), or (2, 5).
  static const llvm::SmallVector<TransformMapKeyTy, 3> validConfigs = {
      F_2_3, F_4_3, F_2_5};

  TransformMapKeyTy key = {m, r};
  auto it = std::find(validConfigs.begin(), validConfigs.end(), key);
  // If we cannot find the constant transformation matrices, it means we do
  // not support this configuration yet.
  if (it == validConfigs.end())
    return failure();

  // All the criteria are satisfied. We can do Winograd Conv2D.
  Location loc = convOp.getLoc();

  // For F(m x 1, r x 1), we only need to do the left side transform.
  bool leftTransform = filterH != 1;
  // For F(1 x m, 1 x r), we only need to do the right side transform.
  bool rightTransform = filterW != 1;
  int64_t heightM = leftTransform ? m : 1;
  int64_t widthM = rightTransform ? m : 1;
  int64_t heightR = leftTransform ? r : 1;
  int64_t widthR = rightTransform ? r : 1;

  // --- Create operation for filter transform ---
  Type filterElementType = filterType.getElementType();
  int64_t alphaH = heightM + heightR - 1;
  int64_t alphaW = widthM + widthR - 1;
  int64_t tileH = llvm::divideCeilSigned(outputH, heightM);
  int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
  auto retType = RankedTensorType::get({alphaH, alphaW, filterC, filterF},
                                       filterElementType);
  Value retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
                                                    filterElementType);
  auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
      loc, retType, filter, retValue, m, r);

  // --- Create operation for input transform ---

  // When input size - (r - 1) is not aligned with the output tile size, we
  // need to pad the input data to create full tiles for tiling.
  Type inputElementType = inputType.getElementType();
  int64_t alignedInputH = tileH * heightM + (heightR - 1);
  int64_t alignedInputW = tileW * widthM + (widthR - 1);
  if (alignedInputH != inputH || alignedInputW != inputW) {
    input = padToAlignedTensor(rewriter, loc, input,
                               {inputN, alignedInputH, alignedInputW, inputC});
  }

  retType = RankedTensorType::get(
      {alphaH, alphaW, tileH, tileW, inputN, inputC}, inputElementType);
  retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
                                              inputElementType);
  auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>(
      loc, retType, input, retValue, m, r);

  Type outputElementType = outputType.getElementType();
  Value matmulRet = matrixMultiply(rewriter, loc, transformedFilter,
                                   transformedInput, outputElementType);

  // --- Create operation for output transform ---

  // When the output size is not aligned with the output tile size, we need
  // to pad the output buffer so full tiles can be inserted after tiling.
  int64_t alignedOutputH = tileH * heightM;
  int64_t alignedOutputW = tileW * widthM;
  bool isOutputUnaligned =
      ((alignedOutputH != outputH) || (alignedOutputW != outputW));
  if (isOutputUnaligned) {
    auto alignedOutputType = RankedTensorType::get(
        {outputN, alignedOutputH, alignedOutputW, outputF}, outputElementType);
    output =
        padToAlignedTensor(rewriter, loc, output, alignedOutputType.getShape());
    outputType = alignedOutputType;
  }

  Value transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>(
      loc, outputType, matmulRet, output, m, r);

  // When the output size is not aligned with the output tile size, extract
  // the value from the padded buffer.
  if (isOutputUnaligned) {
    transformedOutput = extractFromAlignedTensor(
        rewriter, loc, transformedOutput,
        RankedTensorType::get({outputN, outputH, outputW, outputF},
                              outputElementType));
  }

  rewriter.replaceOp(convOp, transformedOutput);

  return transformedOutput.getDefiningOp();
}
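
// After winogradConv2DHelper, the convolution is expressed at a high level
// roughly as (a sketch; see the ops' assembly formats for exact syntax):
//
//   %1 = linalg.winograd_filter_transform m(2) r(3) ins(%filter) outs(%0)
//   %3 = linalg.winograd_input_transform m(2) r(3) ins(%input) outs(%2)
//   %4 = linalg.batch_matmul ins(%collapsed_input, %collapsed_filter) ...
//   %5 = linalg.winograd_output_transform m(2) r(3) ins(%expanded) outs(%out)
//
// The decompose patterns below then lower each winograd_*_transform op into
// the loop nests built by filterTransform / inputTransform / outputTransform.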

/// A helper function to decompose linalg.winograd_filter_transform.
FailureOr<Operation *>
decomposeWinogradFilterTransformHelper(RewriterBase &rewriter,
                                       linalg::WinogradFilterTransformOp op) {
  Location loc = op.getLoc();
  Value filter = op.getFilter();
  auto filterType = cast<ShapedType>(filter.getType());
  auto filterShape = filterType.getShape();
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];

  // For F(m x 1, r x 1), we only need to do the left side transform.
  bool leftTransform = filterH != 1;
  // For F(1 x m, 1 x r), we only need to do the right side transform.
  bool rightTransform = filterW != 1;
  Value transformedFilter =
      filterTransform(rewriter, loc, filter, op.getOutput(), op.getM(),
                      op.getR(), leftTransform, rightTransform);
  if (!transformedFilter)
    return failure();

  rewriter.replaceOp(op, transformedFilter);

  return transformedFilter.getDefiningOp();
}

/// A helper function to decompose linalg.winograd_input_transform.
FailureOr<Operation *>
decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
                                      linalg::WinogradInputTransformOp op) {
  Location loc = op.getLoc();
  Value output = op.getOutput();
  auto outputType = cast<ShapedType>(output.getType());
  auto outputShape = outputType.getShape();

  int64_t outputH = outputShape[0];
  int64_t outputW = outputShape[1];

  // For F(m x 1, r x 1), we only need to do the left side transform.
  bool leftTransform = outputH != 1;
  // For F(1 x m, 1 x r), we only need to do the right side transform.
  bool rightTransform = outputW != 1;
  Value transformedInput =
      inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getM(),
                     op.getR(), leftTransform, rightTransform);
  if (!transformedInput)
    return failure();

  rewriter.replaceOp(op, transformedInput);

  return transformedInput.getDefiningOp();
}

/// A helper function to decompose linalg.winograd_output_transform.
FailureOr<Operation *>
decomposeWinogradOutputTransformHelper(RewriterBase &rewriter,
                                       linalg::WinogradOutputTransformOp op) {
  Location loc = op.getLoc();
  Value value = op.getValue();
  auto valueType = cast<ShapedType>(value.getType());
  auto valueShape = valueType.getShape();
  int64_t valueH = valueShape[0];
  int64_t valueW = valueShape[1];

  // For F(m x 1, r x 1), we only need to do the left side transform.
  bool leftTransform = valueH != 1;
  // For F(1 x m, 1 x r), we only need to do the right side transform.
  bool rightTransform = valueW != 1;
  Value transformedOutput =
      outputTransform(rewriter, loc, value, op.getOutput(), op.getM(),
                      op.getR(), leftTransform, rightTransform);
  if (!transformedOutput)
    return failure();

  rewriter.replaceOp(op, transformedOutput);

  return transformedOutput.getDefiningOp();
}

/// A rewrite pattern to decompose linalg.winograd_filter_transform operations.
class DecomposeWinogradFilterTransform final
    : public OpRewritePattern<linalg::WinogradFilterTransformOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op,
                                PatternRewriter &rewriter) const override {
    return decomposeWinogradFilterTransformHelper(rewriter, op);
  }
};

/// A rewrite pattern to decompose linalg.winograd_input_transform operations.
class DecomposeWinogradInputTransform final
    : public OpRewritePattern<linalg::WinogradInputTransformOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op,
                                PatternRewriter &rewriter) const override {
    return decomposeWinogradInputTransformHelper(rewriter, op);
  }
};

/// A rewrite pattern to decompose linalg.winograd_output_transform operations.
class DecomposeWinogradOutputTransform final
    : public OpRewritePattern<linalg::WinogradOutputTransformOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::WinogradOutputTransformOp op,
                                PatternRewriter &rewriter) const override {
    return decomposeWinogradOutputTransformHelper(rewriter, op);
  }
};

/// A rewrite pattern for the Winograd Conv2D algorithm.
class WinogradConv2DNhwcFhwc final
    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, int64_t m, int64_t r)
      : OpRewritePattern(context), m(m), r(r) {}

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(winogradConv2DHelper(rewriter, convOp, m, r)))
      return failure();

    return success();
  }

private:
  int64_t m;
  int64_t r;
};
} // end anonymous namespace

//===----------------------------------------------------------------------===//
FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
                                      int64_t r) {
  return winogradConv2DHelper(rewriter, op, m, r);
}

FailureOr<Operation *>
decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
                                   linalg::WinogradFilterTransformOp op) {
  return decomposeWinogradFilterTransformHelper(rewriter, op);
}

FailureOr<Operation *>
decomposeWinogradInputTransformOp(RewriterBase &rewriter,
                                  linalg::WinogradInputTransformOp op) {
  return decomposeWinogradInputTransformHelper(rewriter, op);
}

FailureOr<Operation *>
decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
                                   linalg::WinogradOutputTransformOp op) {
  return decomposeWinogradOutputTransformHelper(rewriter, op);
}

void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
                                    int64_t r) {
  MLIRContext *context = patterns.getContext();
  // TODO: Support more Conv2D data layouts, e.g., conv_2d_nchw_fchw.
  patterns.insert<WinogradConv2DNhwcFhwc>(context, m, r);
}

void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns
      .insert<DecomposeWinogradFilterTransform, DecomposeWinogradInputTransform,
              DecomposeWinogradOutputTransform>(context);
}

} // end namespace linalg
} // end namespace mlir