1 //===- TransposeConv2D.cpp - Convolution transposition -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Func/IR/FuncOps.h" 10 #include "mlir/Dialect/Linalg/IR/Linalg.h" 11 #include "mlir/Dialect/MemRef/IR/MemRef.h" 12 #include "mlir/Dialect/Tensor/IR/Tensor.h" 13 #include "mlir/IR/BuiltinTypes.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/IR/ValueRange.h" 16 #include "mlir/Transforms/DialectConversion.h" 17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 18 #include "llvm/ADT/SmallVector.h" 19 #include "llvm/Support/ErrorHandling.h" 20 #include "llvm/Support/RWMutex.h" 21 #include <memory> 22 #include <numeric> 23 24 namespace mlir { 25 namespace linalg { 26 namespace { 27 // clang-format off 28 /// Convolution converter that applies the following rewrite: 29 /// 30 /// Before: 31 /// 32 /// %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, 33 /// strides = dense<2> : tensor<2xi64>} 34 /// ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>) 35 /// outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> 36 /// 37 /// After: 38 /// 39 /// %cst = arith.constant 0.000000e+00 : f32 40 /// %0 = tensor.empty() : tensor<2x2x6x8xf32> 41 /// %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x2x6x8xf32>) -> tensor<2x2x6x8xf32> 42 /// %transposed = linalg.transpose ins(%arg1 : tensor<8x2x2x6xf32>) outs(%1 : tensor<2x2x6x8xf32>) 43 /// permutation = [1, 2, 3, 0] 44 /// %2 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} 45 /// ins(%arg0, %transposed : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%arg2 : tensor<1x2x2x8xf32>) 46 /// -> tensor<1x2x2x8xf32> 47 /// 48 /// with an analogous example for the quantized case. 49 // clang-format on 50 template <typename FHWCConvOp, typename HWCFConvOp> 51 FailureOr<Operation *> transposeConv2DHelper(RewriterBase &rewriter, 52 FHWCConvOp op) { 53 // Construct a permutation of the filter tensor dimensions. For a 2D 54 // convolution this will be known statically as [1, 2, 3, 0]. 55 SmallVector<int64_t> filterPerm = {1, 2, 3, 0}; 56 57 // Create the type for the transposed filter tensor. 58 auto filter = op->getOperand(1); 59 auto filterTy = cast<ShapedType>(filter.getType()); 60 SmallVector<int64_t> newFilterShape(filterPerm.size()); 61 std::generate(std::begin(newFilterShape), std::end(newFilterShape), 62 [dim = 0, &filterTy, &filterPerm]() mutable { 63 return filterTy.getShape()[filterPerm[dim++]]; 64 }); 65 66 // Because linalg.transpose expects an "out" parameter we need to pass it a 67 // tensor of zeros of the result type so here we construct that tensor. 68 auto inputType = op->getOperand(0).getType(); 69 auto elementTy = cast<ShapedType>(inputType).getElementType(); 70 auto loc = op->getLoc(); 71 72 const auto isTensorOp = isa<TensorType>(inputType); 73 Value input; 74 if (isTensorOp) { 75 76 input = rewriter.create<tensor::EmptyOp>(loc, newFilterShape, elementTy) 77 .getResult(); 78 } else { 79 input = rewriter 80 .create<memref::AllocOp>( 81 loc, MemRefType::get(newFilterShape, elementTy)) 82 .getResult(); 83 } 84 85 // We can then construct the transposition on our filter. 86 auto transpose = 87 rewriter.create<linalg::TransposeOp>(loc, filter, input, filterPerm); 88 89 Value newFilter; 90 if (isTensorOp) { 91 newFilter = transpose.getResult()[0]; 92 } else { 93 newFilter = input; 94 } 95 96 SmallVector<Value> newInputs{op.getInputs()}; 97 // The filter is always the second input argument, the other inputs can be 98 // left as they are. 99 newInputs[1] = newFilter; 100 // It is possible the convolution doesn't define any results and its 101 // out argument is just used instead. 102 SmallVector<Type> resultTy; 103 if (op.getNumResults()) { 104 resultTy.push_back(op->getResult(0).getType()); 105 } 106 auto newConv = 107 rewriter.create<HWCFConvOp>(loc, resultTy, newInputs, op.getOutputs(), 108 op.getStrides(), op.getDilations()); 109 rewriter.replaceOp(op, newConv); 110 return newConv.getOperation(); 111 } 112 113 template <typename FHWCConvOp, typename HWCFConvOp> 114 class ConvConverter : public OpRewritePattern<FHWCConvOp> { 115 public: 116 using OpRewritePattern<FHWCConvOp>::OpRewritePattern; 117 LogicalResult matchAndRewrite(FHWCConvOp op, 118 PatternRewriter &rewriter) const final { 119 if (failed(transposeConv2DHelper<FHWCConvOp, HWCFConvOp>(rewriter, op))) { 120 return failure(); 121 } 122 return success(); 123 } 124 }; 125 } // namespace 126 127 FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter, 128 linalg::Conv2DNhwcFhwcOp op) { 129 130 return transposeConv2DHelper<linalg::Conv2DNhwcFhwcOp, 131 linalg::Conv2DNhwcHwcfOp>(rewriter, op); 132 } 133 134 FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter, 135 linalg::Conv2DNhwcFhwcQOp op) { 136 137 return transposeConv2DHelper<linalg::Conv2DNhwcFhwcQOp, 138 linalg::Conv2DNhwcHwcfQOp>(rewriter, op); 139 } 140 141 void populateTransposeConv2DPatterns(RewritePatternSet &patterns) { 142 MLIRContext *context = patterns.getContext(); 143 patterns.insert< 144 ConvConverter<linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcHwcfOp>, 145 ConvConverter<linalg::Conv2DNhwcFhwcQOp, linalg::Conv2DNhwcHwcfQOp>>( 146 context); 147 } 148 } // namespace linalg 149 } // namespace mlir 150