14a3d2088SJack Frankland //===- TransposeConv2D.cpp - Convolution transposition -------------------===// 24a3d2088SJack Frankland // 34a3d2088SJack Frankland // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 44a3d2088SJack Frankland // See https://llvm.org/LICENSE.txt for license information. 54a3d2088SJack Frankland // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 64a3d2088SJack Frankland // 74a3d2088SJack Frankland //===----------------------------------------------------------------------===// 84a3d2088SJack Frankland 94a3d2088SJack Frankland #include "mlir/Dialect/Func/IR/FuncOps.h" 104a3d2088SJack Frankland #include "mlir/Dialect/Linalg/IR/Linalg.h" 114a3d2088SJack Frankland #include "mlir/Dialect/MemRef/IR/MemRef.h" 124a3d2088SJack Frankland #include "mlir/Dialect/Tensor/IR/Tensor.h" 134a3d2088SJack Frankland #include "mlir/IR/BuiltinTypes.h" 144a3d2088SJack Frankland #include "mlir/IR/PatternMatch.h" 154a3d2088SJack Frankland #include "mlir/IR/ValueRange.h" 164a3d2088SJack Frankland #include "mlir/Transforms/DialectConversion.h" 174a3d2088SJack Frankland #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 184a3d2088SJack Frankland #include "llvm/ADT/SmallVector.h" 194a3d2088SJack Frankland #include "llvm/Support/ErrorHandling.h" 204a3d2088SJack Frankland #include "llvm/Support/RWMutex.h" 214a3d2088SJack Frankland #include <memory> 224a3d2088SJack Frankland #include <numeric> 234a3d2088SJack Frankland 244a3d2088SJack Frankland namespace mlir { 254a3d2088SJack Frankland namespace linalg { 264a3d2088SJack Frankland namespace { 274a3d2088SJack Frankland // clang-format off 284a3d2088SJack Frankland /// Convolution converter that applies the following rewrite: 294a3d2088SJack Frankland /// 304a3d2088SJack Frankland /// Before: 314a3d2088SJack Frankland /// 324a3d2088SJack Frankland /// %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, 334a3d2088SJack Frankland /// strides = dense<2> : tensor<2xi64>} 344a3d2088SJack Frankland /// ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>) 354a3d2088SJack Frankland /// outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> 364a3d2088SJack Frankland /// 374a3d2088SJack Frankland /// After: 384a3d2088SJack Frankland /// 394a3d2088SJack Frankland /// %cst = arith.constant 0.000000e+00 : f32 404a3d2088SJack Frankland /// %0 = tensor.empty() : tensor<2x2x6x8xf32> 414a3d2088SJack Frankland /// %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x2x6x8xf32>) -> tensor<2x2x6x8xf32> 424a3d2088SJack Frankland /// %transposed = linalg.transpose ins(%arg1 : tensor<8x2x2x6xf32>) outs(%1 : tensor<2x2x6x8xf32>) 434a3d2088SJack Frankland /// permutation = [1, 2, 3, 0] 444a3d2088SJack Frankland /// %2 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} 454a3d2088SJack Frankland /// ins(%arg0, %transposed : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%arg2 : tensor<1x2x2x8xf32>) 464a3d2088SJack Frankland /// -> tensor<1x2x2x8xf32> 474a3d2088SJack Frankland /// 484a3d2088SJack Frankland /// with an analogous example for the quantized case. 494a3d2088SJack Frankland // clang-format on 504a3d2088SJack Frankland template <typename FHWCConvOp, typename HWCFConvOp> 514a3d2088SJack Frankland FailureOr<Operation *> transposeConv2DHelper(RewriterBase &rewriter, 524a3d2088SJack Frankland FHWCConvOp op) { 534a3d2088SJack Frankland // Construct a permutation of the filter tensor dimensions. For a 2D 544a3d2088SJack Frankland // convolution this will be known statically as [1, 2, 3, 0]. 559cbc1f29SHan-Chung Wang SmallVector<int64_t> filterPerm = {1, 2, 3, 0}; 564a3d2088SJack Frankland 574a3d2088SJack Frankland // Create the type for the transposed filter tensor. 584a3d2088SJack Frankland auto filter = op->getOperand(1); 594a3d2088SJack Frankland auto filterTy = cast<ShapedType>(filter.getType()); 604a3d2088SJack Frankland SmallVector<int64_t> newFilterShape(filterPerm.size()); 614a3d2088SJack Frankland std::generate(std::begin(newFilterShape), std::end(newFilterShape), 624a3d2088SJack Frankland [dim = 0, &filterTy, &filterPerm]() mutable { 634a3d2088SJack Frankland return filterTy.getShape()[filterPerm[dim++]]; 644a3d2088SJack Frankland }); 654a3d2088SJack Frankland 664a3d2088SJack Frankland // Because linalg.transpose expects an "out" parameter we need to pass it a 674a3d2088SJack Frankland // tensor of zeros of the result type so here we construct that tensor. 684a3d2088SJack Frankland auto inputType = op->getOperand(0).getType(); 694a3d2088SJack Frankland auto elementTy = cast<ShapedType>(inputType).getElementType(); 704a3d2088SJack Frankland auto loc = op->getLoc(); 714a3d2088SJack Frankland 724a3d2088SJack Frankland const auto isTensorOp = isa<TensorType>(inputType); 734a3d2088SJack Frankland Value input; 744a3d2088SJack Frankland if (isTensorOp) { 754a3d2088SJack Frankland 764a3d2088SJack Frankland input = rewriter.create<tensor::EmptyOp>(loc, newFilterShape, elementTy) 774a3d2088SJack Frankland .getResult(); 784a3d2088SJack Frankland } else { 794a3d2088SJack Frankland input = rewriter 804a3d2088SJack Frankland .create<memref::AllocOp>( 814a3d2088SJack Frankland loc, MemRefType::get(newFilterShape, elementTy)) 824a3d2088SJack Frankland .getResult(); 834a3d2088SJack Frankland } 844a3d2088SJack Frankland 854a3d2088SJack Frankland // We can then construct the transposition on our filter. 864a3d2088SJack Frankland auto transpose = 874a3d2088SJack Frankland rewriter.create<linalg::TransposeOp>(loc, filter, input, filterPerm); 884a3d2088SJack Frankland 894a3d2088SJack Frankland Value newFilter; 904a3d2088SJack Frankland if (isTensorOp) { 914a3d2088SJack Frankland newFilter = transpose.getResult()[0]; 924a3d2088SJack Frankland } else { 934a3d2088SJack Frankland newFilter = input; 944a3d2088SJack Frankland } 954a3d2088SJack Frankland 964a3d2088SJack Frankland SmallVector<Value> newInputs{op.getInputs()}; 974a3d2088SJack Frankland // The filter is always the second input argument, the other inputs can be 984a3d2088SJack Frankland // left as they are. 994a3d2088SJack Frankland newInputs[1] = newFilter; 1004a3d2088SJack Frankland // It is possible the convolution doesn't define any results and its 1014a3d2088SJack Frankland // out argument is just used instead. 1024a3d2088SJack Frankland SmallVector<Type> resultTy; 1034a3d2088SJack Frankland if (op.getNumResults()) { 1044a3d2088SJack Frankland resultTy.push_back(op->getResult(0).getType()); 1054a3d2088SJack Frankland } 1064a3d2088SJack Frankland auto newConv = 1074a3d2088SJack Frankland rewriter.create<HWCFConvOp>(loc, resultTy, newInputs, op.getOutputs(), 1084a3d2088SJack Frankland op.getStrides(), op.getDilations()); 1094a3d2088SJack Frankland rewriter.replaceOp(op, newConv); 1104a3d2088SJack Frankland return newConv.getOperation(); 1114a3d2088SJack Frankland } 1124a3d2088SJack Frankland 1134a3d2088SJack Frankland template <typename FHWCConvOp, typename HWCFConvOp> 1144a3d2088SJack Frankland class ConvConverter : public OpRewritePattern<FHWCConvOp> { 1154a3d2088SJack Frankland public: 1164a3d2088SJack Frankland using OpRewritePattern<FHWCConvOp>::OpRewritePattern; 1174a3d2088SJack Frankland LogicalResult matchAndRewrite(FHWCConvOp op, 1184a3d2088SJack Frankland PatternRewriter &rewriter) const final { 1194a3d2088SJack Frankland if (failed(transposeConv2DHelper<FHWCConvOp, HWCFConvOp>(rewriter, op))) { 1204a3d2088SJack Frankland return failure(); 1214a3d2088SJack Frankland } 1224a3d2088SJack Frankland return success(); 1234a3d2088SJack Frankland } 1244a3d2088SJack Frankland }; 1254a3d2088SJack Frankland } // namespace 1264a3d2088SJack Frankland 1274a3d2088SJack Frankland FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter, 1284a3d2088SJack Frankland linalg::Conv2DNhwcFhwcOp op) { 1294a3d2088SJack Frankland 1304a3d2088SJack Frankland return transposeConv2DHelper<linalg::Conv2DNhwcFhwcOp, 1314a3d2088SJack Frankland linalg::Conv2DNhwcHwcfOp>(rewriter, op); 1324a3d2088SJack Frankland } 1334a3d2088SJack Frankland 1344a3d2088SJack Frankland FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter, 1354a3d2088SJack Frankland linalg::Conv2DNhwcFhwcQOp op) { 1364a3d2088SJack Frankland 1374a3d2088SJack Frankland return transposeConv2DHelper<linalg::Conv2DNhwcFhwcQOp, 1384a3d2088SJack Frankland linalg::Conv2DNhwcHwcfQOp>(rewriter, op); 1394a3d2088SJack Frankland } 1404a3d2088SJack Frankland 141*aa295216SJay Foad void populateTransposeConv2DPatterns(RewritePatternSet &patterns) { 1424a3d2088SJack Frankland MLIRContext *context = patterns.getContext(); 1434a3d2088SJack Frankland patterns.insert< 1444a3d2088SJack Frankland ConvConverter<linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcHwcfOp>, 1454a3d2088SJack Frankland ConvConverter<linalg::Conv2DNhwcFhwcQOp, linalg::Conv2DNhwcHwcfQOp>>( 1464a3d2088SJack Frankland context); 1474a3d2088SJack Frankland } 1484a3d2088SJack Frankland } // namespace linalg 1494a3d2088SJack Frankland } // namespace mlir 150