120daedacSBenoit Jacob //===- ArmNeon2dToIntr.cpp - convert Arm Neon 2d ops to intrinsics --------===// 220daedacSBenoit Jacob // 320daedacSBenoit Jacob // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 420daedacSBenoit Jacob // See https://llvm.org/LICENSE.txt for license information. 520daedacSBenoit Jacob // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 620daedacSBenoit Jacob // 720daedacSBenoit Jacob //===----------------------------------------------------------------------===// 820daedacSBenoit Jacob 920daedacSBenoit Jacob #include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h" 1067d0d7acSMichele Scuttari 1120daedacSBenoit Jacob #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" 1299ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h" 1320daedacSBenoit Jacob #include "mlir/IR/PatternMatch.h" 1420daedacSBenoit Jacob #include "mlir/Pass/Pass.h" 1520daedacSBenoit Jacob #include "mlir/Pass/PassRegistry.h" 1620daedacSBenoit Jacob #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 1720daedacSBenoit Jacob 1867d0d7acSMichele Scuttari namespace mlir { 1967d0d7acSMichele Scuttari #define GEN_PASS_DEF_CONVERTARMNEON2DTOINTR 2067d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc" 2167d0d7acSMichele Scuttari } // namespace mlir 2267d0d7acSMichele Scuttari 2320daedacSBenoit Jacob using namespace mlir; 2420daedacSBenoit Jacob using namespace mlir::arm_neon; 2520daedacSBenoit Jacob 2620daedacSBenoit Jacob namespace { 2720daedacSBenoit Jacob 2820daedacSBenoit Jacob class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> { 2920daedacSBenoit Jacob public: 3020daedacSBenoit Jacob using OpRewritePattern::OpRewritePattern; 3120daedacSBenoit Jacob 3220daedacSBenoit Jacob /// Convert to 1-dimensional vector type to match the requirements of 3320daedacSBenoit Jacob /// arm.neon.intr.sdot 3420daedacSBenoit Jacob LogicalResult matchAndRewrite(Sdot2dOp op, 3520daedacSBenoit Jacob PatternRewriter &rewriter) const override { 365550c821STres Popp Type elemType = cast<VectorType>(op.getB().getType()).getElementType(); 375550c821STres Popp int length = cast<VectorType>(op.getB().getType()).getShape()[0] * 3820daedacSBenoit Jacob Sdot2dOp::kReductionSize; 3920daedacSBenoit Jacob VectorType flattenedVectorType = VectorType::get({length}, elemType); 408df54a6aSJacques Pienaar Value b2d = op.getB(); 418df54a6aSJacques Pienaar Value c2d = op.getC(); 4220daedacSBenoit Jacob Location loc = op.getLoc(); 4320daedacSBenoit Jacob Value b1d = 4420daedacSBenoit Jacob rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, b2d); 4520daedacSBenoit Jacob Value c1d = 4620daedacSBenoit Jacob rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, c2d); 478df54a6aSJacques Pienaar Value newOp = rewriter.create<SdotOp>(loc, op.getRes().getType(), op.getA(), 488df54a6aSJacques Pienaar b1d, c1d); 4920daedacSBenoit Jacob rewriter.replaceOp(op, {newOp}); 5020daedacSBenoit Jacob return success(); 5120daedacSBenoit Jacob } 5220daedacSBenoit Jacob }; 5320daedacSBenoit Jacob 54039b969bSMichele Scuttari class ConvertArmNeon2dToIntr 5567d0d7acSMichele Scuttari : public impl::ConvertArmNeon2dToIntrBase<ConvertArmNeon2dToIntr> { 5620daedacSBenoit Jacob void runOnOperation() override { 5720daedacSBenoit Jacob auto *context = &getContext(); 5820daedacSBenoit Jacob 5920daedacSBenoit Jacob RewritePatternSet patterns(context); 6020daedacSBenoit Jacob populateConvertArmNeon2dToIntrPatterns(patterns); 6120daedacSBenoit Jacob 62*09dfc571SJacques Pienaar if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) 6320daedacSBenoit Jacob return signalPassFailure(); 6420daedacSBenoit Jacob } 6520daedacSBenoit Jacob }; 6620daedacSBenoit Jacob 6720daedacSBenoit Jacob } // namespace 6820daedacSBenoit Jacob 6947f175b0SRiver Riddle void mlir::populateConvertArmNeon2dToIntrPatterns(RewritePatternSet &patterns) { 7020daedacSBenoit Jacob patterns.add<Sdot2dLoweringPattern>(patterns.getContext()); 7120daedacSBenoit Jacob } 72039b969bSMichele Scuttari 73039b969bSMichele Scuttari std::unique_ptr<Pass> mlir::createConvertArmNeon2dToIntrPass() { 74039b969bSMichele Scuttari return std::make_unique<ConvertArmNeon2dToIntr>(); 75039b969bSMichele Scuttari } 76