1 //===- NamedOpConversions.cpp - Implements conversions between named ops --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements conversions between named ops that can be seens as 10 // canonicalizations of named ops. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Linalg/Passes.h" 15 16 #include "mlir/Dialect/Linalg/IR/Linalg.h" 17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 #include "llvm/ADT/SmallVector.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 23 namespace mlir { 24 #define GEN_PASS_DEF_LINALGNAMEDOPCONVERSIONPASS 25 #include "mlir/Dialect/Linalg/Passes.h.inc" 26 } // namespace mlir 27 28 using namespace mlir; 29 using namespace mlir::linalg; 30 31 static llvm::SmallVector<int64_t> getIndicesVector(int start, int end) { 32 return llvm::to_vector<2>(llvm::seq<int64_t>(start, end)); 33 } 34 35 static LogicalResult 36 matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel, 37 Value iZp, Value kZp, Value init, Attribute stride, 38 Attribute dilation, PatternRewriter &rewriter) { 39 Location loc = operation->getLoc(); 40 auto linalgOp = dyn_cast<LinalgOp>(operation); 41 // Exit out on the memref version of this operation. 42 if (!linalgOp || !linalgOp.hasPureTensorSemantics()) 43 return failure(); 44 45 auto result = operation->getResult(0); 46 47 auto kernelTy = dyn_cast<RankedTensorType>(kernel.getType()); 48 auto initTy = dyn_cast<RankedTensorType>(init.getType()); 49 auto resultTy = dyn_cast<RankedTensorType>(result.getType()); 50 if (!kernelTy || !initTy || !resultTy) 51 return failure(); 52 53 if (kernelTy.getDimSize(3) != 1) 54 return failure(); 55 56 // Collapse kernel dims. 57 SmallVector<ReassociationIndices, 4> collapsedKernelDims = { 58 getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 4)}; 59 auto newKernelTy = RankedTensorType::get( 60 {kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)}, 61 kernelTy.getElementType()); 62 auto collapsedKernel = rewriter.create<tensor::CollapseShapeOp>( 63 loc, newKernelTy, kernel, collapsedKernelDims); 64 65 // Collapse init dims. 66 SmallVector<ReassociationIndices, 4> collapsedInitDims = { 67 getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 3), 68 getIndicesVector(3, 5)}; 69 auto newInitTy = 70 RankedTensorType::get({initTy.getDimSize(0), initTy.getDimSize(1), 71 initTy.getDimSize(2), initTy.getDimSize(3)}, 72 initTy.getElementType()); 73 auto collapsedInit = rewriter.create<tensor::CollapseShapeOp>( 74 loc, newInitTy, init, collapsedInitDims); 75 76 SmallVector<NamedAttribute> preservedAttrs; 77 Operation *newConv = 78 TypeSwitch<Operation *, Operation *>(operation) 79 .Case<DepthwiseConv2DNhwcHwcmOp>([&](auto op) { 80 preservedAttrs = getPrunedAttributeList(op); 81 return rewriter.create<DepthwiseConv2DNhwcHwcOp>( 82 loc, newInitTy, ValueRange{input, collapsedKernel}, 83 ValueRange{collapsedInit}, stride, dilation); 84 }) 85 .Case<DepthwiseConv2DNhwcHwcmQOp>([&](auto op) { 86 preservedAttrs = getPrunedAttributeList(op); 87 return rewriter.create<DepthwiseConv2DNhwcHwcQOp>( 88 loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp}, 89 ValueRange{collapsedInit}, stride, dilation); 90 }) 91 .Default([](Operation *op) { return nullptr; }); 92 if (!newConv) 93 return failure(); 94 for (auto attr : preservedAttrs) 95 newConv->setAttr(attr.getName(), attr.getValue()); 96 97 // Expand dimensions back out to 98 rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>( 99 operation, resultTy, newConv->getResult(0), collapsedInitDims); 100 return success(); 101 } 102 103 namespace { 104 struct SimplifyDepthwiseConvOp 105 : public OpRewritePattern<DepthwiseConv2DNhwcHwcmOp> { 106 using OpRewritePattern<DepthwiseConv2DNhwcHwcmOp>::OpRewritePattern; 107 108 LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmOp op, 109 PatternRewriter &rewriter) const override { 110 Operation *operation = op.getOperation(); 111 Value input = op.getDpsInputOperand(0)->get(); 112 Value kernel = op.getDpsInputOperand(1)->get(); 113 Value init = op.getDpsInitOperand(0)->get(); 114 115 auto stride = op.getStrides(); 116 auto dilation = op.getDilations(); 117 118 return matchAndReplaceDepthwiseConv(operation, input, kernel, nullptr, 119 nullptr, init, stride, dilation, 120 rewriter); 121 } 122 }; 123 124 struct SimplifyDepthwiseConvQOp 125 : public OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp> { 126 using OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp>::OpRewritePattern; 127 128 LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmQOp op, 129 PatternRewriter &rewriter) const override { 130 Operation *operation = op.getOperation(); 131 Value input = op.getDpsInputOperand(0)->get(); 132 Value kernel = op.getDpsInputOperand(1)->get(); 133 Value iZp = op.getDpsInputOperand(2)->get(); 134 Value kZp = op.getDpsInputOperand(3)->get(); 135 Value init = op.getDpsInitOperand(0)->get(); 136 137 auto stride = op.getStrides(); 138 auto dilation = op.getDilations(); 139 140 return matchAndReplaceDepthwiseConv(operation, input, kernel, iZp, kZp, 141 init, stride, dilation, rewriter); 142 } 143 }; 144 145 struct LinalgNamedOpConversionPass 146 : public impl::LinalgNamedOpConversionPassBase< 147 LinalgNamedOpConversionPass> { 148 using impl::LinalgNamedOpConversionPassBase< 149 LinalgNamedOpConversionPass>::LinalgNamedOpConversionPassBase; 150 151 void runOnOperation() override { 152 Operation *op = getOperation(); 153 RewritePatternSet patterns(op->getContext()); 154 populateLinalgNamedOpConversionPatterns(patterns); 155 if (failed(applyPatternsGreedily(op, std::move(patterns)))) 156 return signalPassFailure(); 157 } 158 }; 159 } // namespace 160 161 void mlir::linalg::populateLinalgNamedOpConversionPatterns( 162 RewritePatternSet &patterns) { 163 patterns.add<SimplifyDepthwiseConvOp, SimplifyDepthwiseConvQOp>( 164 patterns.getContext()); 165 } 166