1 //===- ArmNeon2dToIntr.cpp - convert Arm Neon 2d ops to intrinsics --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h" 10 #include "../PassDetail.h" 11 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" 12 #include "mlir/Dialect/Vector/IR/VectorOps.h" 13 #include "mlir/IR/PatternMatch.h" 14 #include "mlir/Pass/Pass.h" 15 #include "mlir/Pass/PassRegistry.h" 16 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 17 18 using namespace mlir; 19 using namespace mlir::arm_neon; 20 21 namespace { 22 23 class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> { 24 public: 25 using OpRewritePattern::OpRewritePattern; 26 27 /// Convert to 1-dimensional vector type to match the requirements of 28 /// arm.neon.intr.sdot 29 LogicalResult matchAndRewrite(Sdot2dOp op, 30 PatternRewriter &rewriter) const override { 31 Type elemType = op.getB().getType().cast<VectorType>().getElementType(); 32 int length = op.getB().getType().cast<VectorType>().getShape()[0] * 33 Sdot2dOp::kReductionSize; 34 VectorType flattenedVectorType = VectorType::get({length}, elemType); 35 Value b2d = op.getB(); 36 Value c2d = op.getC(); 37 Location loc = op.getLoc(); 38 Value b1d = 39 rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, b2d); 40 Value c1d = 41 rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, c2d); 42 Value newOp = rewriter.create<SdotOp>(loc, op.getRes().getType(), op.getA(), 43 b1d, c1d); 44 rewriter.replaceOp(op, {newOp}); 45 return success(); 46 } 47 }; 48 49 class ConvertArmNeon2dToIntr 50 : public ConvertArmNeon2dToIntrBase<ConvertArmNeon2dToIntr> { 51 void runOnOperation() override { 52 auto *context = &getContext(); 53 54 RewritePatternSet patterns(context); 55 populateConvertArmNeon2dToIntrPatterns(patterns); 56 57 if (failed( 58 applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) 59 return signalPassFailure(); 60 } 61 }; 62 63 } // namespace 64 65 void mlir::populateConvertArmNeon2dToIntrPatterns(RewritePatternSet &patterns) { 66 patterns.add<Sdot2dLoweringPattern>(patterns.getContext()); 67 } 68 69 std::unique_ptr<Pass> mlir::createConvertArmNeon2dToIntrPass() { 70 return std::make_unique<ConvertArmNeon2dToIntr>(); 71 } 72