xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp (revision aa2952165cd1808dab2bb49b97becc097f4c9cac)
14a3d2088SJack Frankland //===- TransposeConv2D.cpp - Convolution transposition  -------------------===//
24a3d2088SJack Frankland //
34a3d2088SJack Frankland // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44a3d2088SJack Frankland // See https://llvm.org/LICENSE.txt for license information.
54a3d2088SJack Frankland // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
64a3d2088SJack Frankland //
74a3d2088SJack Frankland //===----------------------------------------------------------------------===//
84a3d2088SJack Frankland 
94a3d2088SJack Frankland #include "mlir/Dialect/Func/IR/FuncOps.h"
104a3d2088SJack Frankland #include "mlir/Dialect/Linalg/IR/Linalg.h"
114a3d2088SJack Frankland #include "mlir/Dialect/MemRef/IR/MemRef.h"
124a3d2088SJack Frankland #include "mlir/Dialect/Tensor/IR/Tensor.h"
134a3d2088SJack Frankland #include "mlir/IR/BuiltinTypes.h"
144a3d2088SJack Frankland #include "mlir/IR/PatternMatch.h"
154a3d2088SJack Frankland #include "mlir/IR/ValueRange.h"
164a3d2088SJack Frankland #include "mlir/Transforms/DialectConversion.h"
174a3d2088SJack Frankland #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
184a3d2088SJack Frankland #include "llvm/ADT/SmallVector.h"
194a3d2088SJack Frankland #include "llvm/Support/ErrorHandling.h"
204a3d2088SJack Frankland #include "llvm/Support/RWMutex.h"
214a3d2088SJack Frankland #include <memory>
224a3d2088SJack Frankland #include <numeric>
234a3d2088SJack Frankland 
244a3d2088SJack Frankland namespace mlir {
254a3d2088SJack Frankland namespace linalg {
264a3d2088SJack Frankland namespace {
274a3d2088SJack Frankland // clang-format off
284a3d2088SJack Frankland /// Convolution converter that applies the following rewrite:
294a3d2088SJack Frankland ///
304a3d2088SJack Frankland /// Before:
314a3d2088SJack Frankland ///
324a3d2088SJack Frankland ///   %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
334a3d2088SJack Frankland ///                                               strides = dense<2> : tensor<2xi64>}
344a3d2088SJack Frankland ///      ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>)
354a3d2088SJack Frankland ///     outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
364a3d2088SJack Frankland ///
374a3d2088SJack Frankland /// After:
384a3d2088SJack Frankland ///
394a3d2088SJack Frankland ///    %cst = arith.constant 0.000000e+00 : f32
404a3d2088SJack Frankland ///    %0 = tensor.empty() : tensor<2x2x6x8xf32>
414a3d2088SJack Frankland ///    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x2x6x8xf32>) -> tensor<2x2x6x8xf32>
424a3d2088SJack Frankland ///    %transposed = linalg.transpose ins(%arg1 : tensor<8x2x2x6xf32>) outs(%1 : tensor<2x2x6x8xf32>)
434a3d2088SJack Frankland ///                  permutation = [1, 2, 3, 0]
444a3d2088SJack Frankland ///    %2 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
454a3d2088SJack Frankland ///         ins(%arg0, %transposed : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%arg2 : tensor<1x2x2x8xf32>)
464a3d2088SJack Frankland ///         -> tensor<1x2x2x8xf32>
474a3d2088SJack Frankland ///
484a3d2088SJack Frankland /// with an analogous example for the quantized case.
494a3d2088SJack Frankland // clang-format on
504a3d2088SJack Frankland template <typename FHWCConvOp, typename HWCFConvOp>
514a3d2088SJack Frankland FailureOr<Operation *> transposeConv2DHelper(RewriterBase &rewriter,
524a3d2088SJack Frankland                                              FHWCConvOp op) {
534a3d2088SJack Frankland   // Construct a permutation of the filter tensor dimensions. For a 2D
544a3d2088SJack Frankland   // convolution this will be known statically as [1, 2, 3, 0].
559cbc1f29SHan-Chung Wang   SmallVector<int64_t> filterPerm = {1, 2, 3, 0};
564a3d2088SJack Frankland 
574a3d2088SJack Frankland   // Create the type for the transposed filter tensor.
584a3d2088SJack Frankland   auto filter = op->getOperand(1);
594a3d2088SJack Frankland   auto filterTy = cast<ShapedType>(filter.getType());
604a3d2088SJack Frankland   SmallVector<int64_t> newFilterShape(filterPerm.size());
614a3d2088SJack Frankland   std::generate(std::begin(newFilterShape), std::end(newFilterShape),
624a3d2088SJack Frankland                 [dim = 0, &filterTy, &filterPerm]() mutable {
634a3d2088SJack Frankland                   return filterTy.getShape()[filterPerm[dim++]];
644a3d2088SJack Frankland                 });
654a3d2088SJack Frankland 
664a3d2088SJack Frankland   // Because linalg.transpose expects an "out" parameter we need to pass it a
674a3d2088SJack Frankland   // tensor of zeros of the result type so here we construct that tensor.
684a3d2088SJack Frankland   auto inputType = op->getOperand(0).getType();
694a3d2088SJack Frankland   auto elementTy = cast<ShapedType>(inputType).getElementType();
704a3d2088SJack Frankland   auto loc = op->getLoc();
714a3d2088SJack Frankland 
724a3d2088SJack Frankland   const auto isTensorOp = isa<TensorType>(inputType);
734a3d2088SJack Frankland   Value input;
744a3d2088SJack Frankland   if (isTensorOp) {
754a3d2088SJack Frankland 
764a3d2088SJack Frankland     input = rewriter.create<tensor::EmptyOp>(loc, newFilterShape, elementTy)
774a3d2088SJack Frankland                 .getResult();
784a3d2088SJack Frankland   } else {
794a3d2088SJack Frankland     input = rewriter
804a3d2088SJack Frankland                 .create<memref::AllocOp>(
814a3d2088SJack Frankland                     loc, MemRefType::get(newFilterShape, elementTy))
824a3d2088SJack Frankland                 .getResult();
834a3d2088SJack Frankland   }
844a3d2088SJack Frankland 
854a3d2088SJack Frankland   // We can then construct the transposition on our filter.
864a3d2088SJack Frankland   auto transpose =
874a3d2088SJack Frankland       rewriter.create<linalg::TransposeOp>(loc, filter, input, filterPerm);
884a3d2088SJack Frankland 
894a3d2088SJack Frankland   Value newFilter;
904a3d2088SJack Frankland   if (isTensorOp) {
914a3d2088SJack Frankland     newFilter = transpose.getResult()[0];
924a3d2088SJack Frankland   } else {
934a3d2088SJack Frankland     newFilter = input;
944a3d2088SJack Frankland   }
954a3d2088SJack Frankland 
964a3d2088SJack Frankland   SmallVector<Value> newInputs{op.getInputs()};
974a3d2088SJack Frankland   // The filter is always the second input argument, the other inputs can be
984a3d2088SJack Frankland   // left as they are.
994a3d2088SJack Frankland   newInputs[1] = newFilter;
1004a3d2088SJack Frankland   // It is possible the convolution doesn't define any results and its
1014a3d2088SJack Frankland   // out argument is just used instead.
1024a3d2088SJack Frankland   SmallVector<Type> resultTy;
1034a3d2088SJack Frankland   if (op.getNumResults()) {
1044a3d2088SJack Frankland     resultTy.push_back(op->getResult(0).getType());
1054a3d2088SJack Frankland   }
1064a3d2088SJack Frankland   auto newConv =
1074a3d2088SJack Frankland       rewriter.create<HWCFConvOp>(loc, resultTy, newInputs, op.getOutputs(),
1084a3d2088SJack Frankland                                   op.getStrides(), op.getDilations());
1094a3d2088SJack Frankland   rewriter.replaceOp(op, newConv);
1104a3d2088SJack Frankland   return newConv.getOperation();
1114a3d2088SJack Frankland }
1124a3d2088SJack Frankland 
1134a3d2088SJack Frankland template <typename FHWCConvOp, typename HWCFConvOp>
1144a3d2088SJack Frankland class ConvConverter : public OpRewritePattern<FHWCConvOp> {
1154a3d2088SJack Frankland public:
1164a3d2088SJack Frankland   using OpRewritePattern<FHWCConvOp>::OpRewritePattern;
1174a3d2088SJack Frankland   LogicalResult matchAndRewrite(FHWCConvOp op,
1184a3d2088SJack Frankland                                 PatternRewriter &rewriter) const final {
1194a3d2088SJack Frankland     if (failed(transposeConv2DHelper<FHWCConvOp, HWCFConvOp>(rewriter, op))) {
1204a3d2088SJack Frankland       return failure();
1214a3d2088SJack Frankland     }
1224a3d2088SJack Frankland     return success();
1234a3d2088SJack Frankland   }
1244a3d2088SJack Frankland };
1254a3d2088SJack Frankland } // namespace
1264a3d2088SJack Frankland 
1274a3d2088SJack Frankland FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
1284a3d2088SJack Frankland                                        linalg::Conv2DNhwcFhwcOp op) {
1294a3d2088SJack Frankland 
1304a3d2088SJack Frankland   return transposeConv2DHelper<linalg::Conv2DNhwcFhwcOp,
1314a3d2088SJack Frankland                                linalg::Conv2DNhwcHwcfOp>(rewriter, op);
1324a3d2088SJack Frankland }
1334a3d2088SJack Frankland 
1344a3d2088SJack Frankland FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
1354a3d2088SJack Frankland                                        linalg::Conv2DNhwcFhwcQOp op) {
1364a3d2088SJack Frankland 
1374a3d2088SJack Frankland   return transposeConv2DHelper<linalg::Conv2DNhwcFhwcQOp,
1384a3d2088SJack Frankland                                linalg::Conv2DNhwcHwcfQOp>(rewriter, op);
1394a3d2088SJack Frankland }
1404a3d2088SJack Frankland 
141*aa295216SJay Foad void populateTransposeConv2DPatterns(RewritePatternSet &patterns) {
1424a3d2088SJack Frankland   MLIRContext *context = patterns.getContext();
1434a3d2088SJack Frankland   patterns.insert<
1444a3d2088SJack Frankland       ConvConverter<linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcHwcfOp>,
1454a3d2088SJack Frankland       ConvConverter<linalg::Conv2DNhwcFhwcQOp, linalg::Conv2DNhwcHwcfQOp>>(
1464a3d2088SJack Frankland       context);
1474a3d2088SJack Frankland }
1484a3d2088SJack Frankland } // namespace linalg
1494a3d2088SJack Frankland } // namespace mlir
150