1 //===- ArmNeon2dToIntr.cpp - convert Arm Neon 2d ops to intrinsics --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h" 10 11 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" 12 #include "mlir/Dialect/Vector/IR/VectorOps.h" 13 #include "mlir/IR/PatternMatch.h" 14 #include "mlir/Pass/Pass.h" 15 #include "mlir/Pass/PassRegistry.h" 16 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 17 18 namespace mlir { 19 #define GEN_PASS_DEF_CONVERTARMNEON2DTOINTR 20 #include "mlir/Conversion/Passes.h.inc" 21 } // namespace mlir 22 23 using namespace mlir; 24 using namespace mlir::arm_neon; 25 26 namespace { 27 28 class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> { 29 public: 30 using OpRewritePattern::OpRewritePattern; 31 32 /// Convert to 1-dimensional vector type to match the requirements of 33 /// arm.neon.intr.sdot 34 LogicalResult matchAndRewrite(Sdot2dOp op, 35 PatternRewriter &rewriter) const override { 36 Type elemType = cast<VectorType>(op.getB().getType()).getElementType(); 37 int length = cast<VectorType>(op.getB().getType()).getShape()[0] * 38 Sdot2dOp::kReductionSize; 39 VectorType flattenedVectorType = VectorType::get({length}, elemType); 40 Value b2d = op.getB(); 41 Value c2d = op.getC(); 42 Location loc = op.getLoc(); 43 Value b1d = 44 rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, b2d); 45 Value c1d = 46 rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, c2d); 47 Value newOp = rewriter.create<SdotOp>(loc, op.getRes().getType(), op.getA(), 48 b1d, c1d); 49 rewriter.replaceOp(op, {newOp}); 50 return success(); 51 } 52 }; 53 54 class ConvertArmNeon2dToIntr 55 : public impl::ConvertArmNeon2dToIntrBase<ConvertArmNeon2dToIntr> { 56 void runOnOperation() override { 57 auto *context = &getContext(); 58 59 RewritePatternSet patterns(context); 60 populateConvertArmNeon2dToIntrPatterns(patterns); 61 62 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) 63 return signalPassFailure(); 64 } 65 }; 66 67 } // namespace 68 69 void mlir::populateConvertArmNeon2dToIntrPatterns(RewritePatternSet &patterns) { 70 patterns.add<Sdot2dLoweringPattern>(patterns.getContext()); 71 } 72 73 std::unique_ptr<Pass> mlir::createConvertArmNeon2dToIntrPass() { 74 return std::make_unique<ConvertArmNeon2dToIntr>(); 75 } 76