xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp (revision aa2952165cd1808dab2bb49b97becc097f4c9cac)
1 //===- TransposeConv2D.cpp - Convolution transposition  -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Func/IR/FuncOps.h"
10 #include "mlir/Dialect/Linalg/IR/Linalg.h"
11 #include "mlir/Dialect/MemRef/IR/MemRef.h"
12 #include "mlir/Dialect/Tensor/IR/Tensor.h"
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/IR/ValueRange.h"
16 #include "mlir/Transforms/DialectConversion.h"
17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/Support/ErrorHandling.h"
20 #include "llvm/Support/RWMutex.h"
21 #include <memory>
22 #include <numeric>
23 
24 namespace mlir {
25 namespace linalg {
26 namespace {
27 // clang-format off
28 /// Convolution converter that applies the following rewrite:
29 ///
30 /// Before:
31 ///
32 ///   %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
33 ///                                               strides = dense<2> : tensor<2xi64>}
34 ///      ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>)
35 ///     outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
36 ///
37 /// After:
38 ///
39 ///    %cst = arith.constant 0.000000e+00 : f32
40 ///    %0 = tensor.empty() : tensor<2x2x6x8xf32>
41 ///    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x2x6x8xf32>) -> tensor<2x2x6x8xf32>
42 ///    %transposed = linalg.transpose ins(%arg1 : tensor<8x2x2x6xf32>) outs(%1 : tensor<2x2x6x8xf32>)
43 ///                  permutation = [1, 2, 3, 0]
44 ///    %2 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
45 ///         ins(%arg0, %transposed : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%arg2 : tensor<1x2x2x8xf32>)
46 ///         -> tensor<1x2x2x8xf32>
47 ///
48 /// with an analogous example for the quantized case.
49 // clang-format on
50 template <typename FHWCConvOp, typename HWCFConvOp>
51 FailureOr<Operation *> transposeConv2DHelper(RewriterBase &rewriter,
52                                              FHWCConvOp op) {
53   // Construct a permutation of the filter tensor dimensions. For a 2D
54   // convolution this will be known statically as [1, 2, 3, 0].
55   SmallVector<int64_t> filterPerm = {1, 2, 3, 0};
56 
57   // Create the type for the transposed filter tensor.
58   auto filter = op->getOperand(1);
59   auto filterTy = cast<ShapedType>(filter.getType());
60   SmallVector<int64_t> newFilterShape(filterPerm.size());
61   std::generate(std::begin(newFilterShape), std::end(newFilterShape),
62                 [dim = 0, &filterTy, &filterPerm]() mutable {
63                   return filterTy.getShape()[filterPerm[dim++]];
64                 });
65 
66   // Because linalg.transpose expects an "out" parameter we need to pass it a
67   // tensor of zeros of the result type so here we construct that tensor.
68   auto inputType = op->getOperand(0).getType();
69   auto elementTy = cast<ShapedType>(inputType).getElementType();
70   auto loc = op->getLoc();
71 
72   const auto isTensorOp = isa<TensorType>(inputType);
73   Value input;
74   if (isTensorOp) {
75 
76     input = rewriter.create<tensor::EmptyOp>(loc, newFilterShape, elementTy)
77                 .getResult();
78   } else {
79     input = rewriter
80                 .create<memref::AllocOp>(
81                     loc, MemRefType::get(newFilterShape, elementTy))
82                 .getResult();
83   }
84 
85   // We can then construct the transposition on our filter.
86   auto transpose =
87       rewriter.create<linalg::TransposeOp>(loc, filter, input, filterPerm);
88 
89   Value newFilter;
90   if (isTensorOp) {
91     newFilter = transpose.getResult()[0];
92   } else {
93     newFilter = input;
94   }
95 
96   SmallVector<Value> newInputs{op.getInputs()};
97   // The filter is always the second input argument, the other inputs can be
98   // left as they are.
99   newInputs[1] = newFilter;
100   // It is possible the convolution doesn't define any results and its
101   // out argument is just used instead.
102   SmallVector<Type> resultTy;
103   if (op.getNumResults()) {
104     resultTy.push_back(op->getResult(0).getType());
105   }
106   auto newConv =
107       rewriter.create<HWCFConvOp>(loc, resultTy, newInputs, op.getOutputs(),
108                                   op.getStrides(), op.getDilations());
109   rewriter.replaceOp(op, newConv);
110   return newConv.getOperation();
111 }
112 
113 template <typename FHWCConvOp, typename HWCFConvOp>
114 class ConvConverter : public OpRewritePattern<FHWCConvOp> {
115 public:
116   using OpRewritePattern<FHWCConvOp>::OpRewritePattern;
117   LogicalResult matchAndRewrite(FHWCConvOp op,
118                                 PatternRewriter &rewriter) const final {
119     if (failed(transposeConv2DHelper<FHWCConvOp, HWCFConvOp>(rewriter, op))) {
120       return failure();
121     }
122     return success();
123   }
124 };
125 } // namespace
126 
127 FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
128                                        linalg::Conv2DNhwcFhwcOp op) {
129 
130   return transposeConv2DHelper<linalg::Conv2DNhwcFhwcOp,
131                                linalg::Conv2DNhwcHwcfOp>(rewriter, op);
132 }
133 
134 FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
135                                        linalg::Conv2DNhwcFhwcQOp op) {
136 
137   return transposeConv2DHelper<linalg::Conv2DNhwcFhwcQOp,
138                                linalg::Conv2DNhwcHwcfQOp>(rewriter, op);
139 }
140 
141 void populateTransposeConv2DPatterns(RewritePatternSet &patterns) {
142   MLIRContext *context = patterns.getContext();
143   patterns.insert<
144       ConvConverter<linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcHwcfOp>,
145       ConvConverter<linalg::Conv2DNhwcFhwcQOp, linalg::Conv2DNhwcHwcfQOp>>(
146       context);
147 }
148 } // namespace linalg
149 } // namespace mlir
150