xref: /llvm-project/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
//===- ArmNeon2dToIntr.cpp - convert Arm Neon 2d ops to intrinsics --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
820daedacSBenoit Jacob 
920daedacSBenoit Jacob #include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
1067d0d7acSMichele Scuttari 
1120daedacSBenoit Jacob #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
1299ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
1320daedacSBenoit Jacob #include "mlir/IR/PatternMatch.h"
1420daedacSBenoit Jacob #include "mlir/Pass/Pass.h"
1520daedacSBenoit Jacob #include "mlir/Pass/PassRegistry.h"
1620daedacSBenoit Jacob #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1720daedacSBenoit Jacob 
1867d0d7acSMichele Scuttari namespace mlir {
1967d0d7acSMichele Scuttari #define GEN_PASS_DEF_CONVERTARMNEON2DTOINTR
2067d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc"
2167d0d7acSMichele Scuttari } // namespace mlir
2267d0d7acSMichele Scuttari 
2320daedacSBenoit Jacob using namespace mlir;
2420daedacSBenoit Jacob using namespace mlir::arm_neon;
2520daedacSBenoit Jacob 
2620daedacSBenoit Jacob namespace {
2720daedacSBenoit Jacob 
2820daedacSBenoit Jacob class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
2920daedacSBenoit Jacob public:
3020daedacSBenoit Jacob   using OpRewritePattern::OpRewritePattern;
3120daedacSBenoit Jacob 
3220daedacSBenoit Jacob   /// Convert to 1-dimensional vector type to match the requirements of
3320daedacSBenoit Jacob   /// arm.neon.intr.sdot
3420daedacSBenoit Jacob   LogicalResult matchAndRewrite(Sdot2dOp op,
3520daedacSBenoit Jacob                                 PatternRewriter &rewriter) const override {
365550c821STres Popp     Type elemType = cast<VectorType>(op.getB().getType()).getElementType();
375550c821STres Popp     int length = cast<VectorType>(op.getB().getType()).getShape()[0] *
3820daedacSBenoit Jacob                  Sdot2dOp::kReductionSize;
3920daedacSBenoit Jacob     VectorType flattenedVectorType = VectorType::get({length}, elemType);
408df54a6aSJacques Pienaar     Value b2d = op.getB();
418df54a6aSJacques Pienaar     Value c2d = op.getC();
4220daedacSBenoit Jacob     Location loc = op.getLoc();
4320daedacSBenoit Jacob     Value b1d =
4420daedacSBenoit Jacob         rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, b2d);
4520daedacSBenoit Jacob     Value c1d =
4620daedacSBenoit Jacob         rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, c2d);
478df54a6aSJacques Pienaar     Value newOp = rewriter.create<SdotOp>(loc, op.getRes().getType(), op.getA(),
488df54a6aSJacques Pienaar                                           b1d, c1d);
4920daedacSBenoit Jacob     rewriter.replaceOp(op, {newOp});
5020daedacSBenoit Jacob     return success();
5120daedacSBenoit Jacob   }
5220daedacSBenoit Jacob };
5320daedacSBenoit Jacob 
54039b969bSMichele Scuttari class ConvertArmNeon2dToIntr
5567d0d7acSMichele Scuttari     : public impl::ConvertArmNeon2dToIntrBase<ConvertArmNeon2dToIntr> {
5620daedacSBenoit Jacob   void runOnOperation() override {
5720daedacSBenoit Jacob     auto *context = &getContext();
5820daedacSBenoit Jacob 
5920daedacSBenoit Jacob     RewritePatternSet patterns(context);
6020daedacSBenoit Jacob     populateConvertArmNeon2dToIntrPatterns(patterns);
6120daedacSBenoit Jacob 
62*09dfc571SJacques Pienaar     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
6320daedacSBenoit Jacob       return signalPassFailure();
6420daedacSBenoit Jacob   }
6520daedacSBenoit Jacob };
6620daedacSBenoit Jacob 
6720daedacSBenoit Jacob } // namespace
6820daedacSBenoit Jacob 
6947f175b0SRiver Riddle void mlir::populateConvertArmNeon2dToIntrPatterns(RewritePatternSet &patterns) {
7020daedacSBenoit Jacob   patterns.add<Sdot2dLoweringPattern>(patterns.getContext());
7120daedacSBenoit Jacob }
72039b969bSMichele Scuttari 
73039b969bSMichele Scuttari std::unique_ptr<Pass> mlir::createConvertArmNeon2dToIntrPass() {
74039b969bSMichele Scuttari   return std::make_unique<ConvertArmNeon2dToIntr>();
75039b969bSMichele Scuttari }
76