xref: /llvm-project/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
19f7fff7fSCullen Rhodes //===- ArithToArmSME.cpp - Arith to ArmSME dialect conversion -------------===//
29f7fff7fSCullen Rhodes //
39f7fff7fSCullen Rhodes // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
49f7fff7fSCullen Rhodes // See https://llvm.org/LICENSE.txt for license information.
59f7fff7fSCullen Rhodes // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
69f7fff7fSCullen Rhodes //
79f7fff7fSCullen Rhodes //===----------------------------------------------------------------------===//
89f7fff7fSCullen Rhodes 
99f7fff7fSCullen Rhodes #include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
109f7fff7fSCullen Rhodes 
119f7fff7fSCullen Rhodes #include "mlir/Dialect/Arith/IR/Arith.h"
129f7fff7fSCullen Rhodes #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
139f7fff7fSCullen Rhodes #include "mlir/Dialect/ArmSME/Utils/Utils.h"
149f7fff7fSCullen Rhodes #include "mlir/Pass/Pass.h"
159f7fff7fSCullen Rhodes #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
169f7fff7fSCullen Rhodes 
179f7fff7fSCullen Rhodes namespace mlir {
189f7fff7fSCullen Rhodes #define GEN_PASS_DEF_ARITHTOARMSMECONVERSIONPASS
199f7fff7fSCullen Rhodes #include "mlir/Conversion/Passes.h.inc"
209f7fff7fSCullen Rhodes } // namespace mlir
219f7fff7fSCullen Rhodes 
229f7fff7fSCullen Rhodes #define DEBUG_TYPE "arith-to-arm-sme"
239f7fff7fSCullen Rhodes 
249f7fff7fSCullen Rhodes using namespace mlir;
259f7fff7fSCullen Rhodes 
269f7fff7fSCullen Rhodes //===----------------------------------------------------------------------===//
279f7fff7fSCullen Rhodes // Conversion helpers
289f7fff7fSCullen Rhodes //===----------------------------------------------------------------------===//
299f7fff7fSCullen Rhodes 
309f7fff7fSCullen Rhodes /// Returns true if 'val' is a splat of zero, false otherwise.
319f7fff7fSCullen Rhodes static bool isSplatZero(Type elemType, DenseElementsAttr val) {
329f7fff7fSCullen Rhodes   if (llvm::isa<FloatType>(elemType))
339f7fff7fSCullen Rhodes     return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
349f7fff7fSCullen Rhodes   if (llvm::isa<IntegerType>(elemType))
359f7fff7fSCullen Rhodes     return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
369f7fff7fSCullen Rhodes   return false;
379f7fff7fSCullen Rhodes }
389f7fff7fSCullen Rhodes 
399f7fff7fSCullen Rhodes namespace {
409f7fff7fSCullen Rhodes 
419f7fff7fSCullen Rhodes //===----------------------------------------------------------------------===//
429f7fff7fSCullen Rhodes // ConstantOp
439f7fff7fSCullen Rhodes //===----------------------------------------------------------------------===//
449f7fff7fSCullen Rhodes 
459f7fff7fSCullen Rhodes /// Conversion pattern for dense arith.constant.
469f7fff7fSCullen Rhodes struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
479f7fff7fSCullen Rhodes   using OpRewritePattern<arith::ConstantOp>::OpRewritePattern;
489f7fff7fSCullen Rhodes 
499f7fff7fSCullen Rhodes   LogicalResult matchAndRewrite(arith::ConstantOp constantOp,
509f7fff7fSCullen Rhodes                                 PatternRewriter &rewriter) const final {
519f7fff7fSCullen Rhodes     auto tileType = dyn_cast<VectorType>(constantOp.getType());
529f7fff7fSCullen Rhodes     if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
539f7fff7fSCullen Rhodes       return failure();
549f7fff7fSCullen Rhodes 
559f7fff7fSCullen Rhodes     auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
569f7fff7fSCullen Rhodes     if (!denseAttr || !denseAttr.isSplat())
579f7fff7fSCullen Rhodes       return failure();
589f7fff7fSCullen Rhodes 
599f7fff7fSCullen Rhodes     auto tileElementType = tileType.getElementType();
609f7fff7fSCullen Rhodes 
619f7fff7fSCullen Rhodes     // Lower 'arith.constant dense<0>' to 'arm_sme.zero' op.
629f7fff7fSCullen Rhodes     if (isSplatZero(tileElementType, denseAttr)) {
639f7fff7fSCullen Rhodes       rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, tileType);
649f7fff7fSCullen Rhodes       return success();
659f7fff7fSCullen Rhodes     }
669f7fff7fSCullen Rhodes 
67c4251243SBenjamin Maxwell     // Lower non-zero constants to a loop of 'arm_sme.insert_tile_slice'
689f7fff7fSCullen Rhodes     // ops that broadcast the constant to each tile slice.
699f7fff7fSCullen Rhodes     auto loc = constantOp.getLoc();
709f7fff7fSCullen Rhodes 
719f7fff7fSCullen Rhodes     // To fill a tile with a constant, we create a 1-D splat of the constant,
729f7fff7fSCullen Rhodes     // then move that into each tile slice (the largest unit we can set at once,
739f7fff7fSCullen Rhodes     // outside of operations like the outerproduct).
749f7fff7fSCullen Rhodes     VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
759f7fff7fSCullen Rhodes     auto denseAttr1D = DenseElementsAttr::get(
769f7fff7fSCullen Rhodes         tileSliceType, denseAttr.getSplatValue<Attribute>());
779f7fff7fSCullen Rhodes     auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
789f7fff7fSCullen Rhodes 
799f7fff7fSCullen Rhodes     auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
809f7fff7fSCullen Rhodes     auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
819f7fff7fSCullen Rhodes                             Value currentTile) {
82c4251243SBenjamin Maxwell       // Create 'arm_sme.insert_tile_slice' to write vector to tile
839f7fff7fSCullen Rhodes       // slice.
84c4251243SBenjamin Maxwell       auto nextTile = b.create<arm_sme::InsertTileSliceOp>(
859f7fff7fSCullen Rhodes           loc, tileType, constantOp1D, currentTile, tileSliceIndex);
869f7fff7fSCullen Rhodes       return nextTile.getResult();
879f7fff7fSCullen Rhodes     };
889f7fff7fSCullen Rhodes     auto forOp = mlir::arm_sme::createLoopOverTileSlices(
899f7fff7fSCullen Rhodes         rewriter, loc, initTile, makeLoopBody);
909f7fff7fSCullen Rhodes     rewriter.replaceOp(constantOp, forOp.getResult(0));
919f7fff7fSCullen Rhodes 
929f7fff7fSCullen Rhodes     return success();
939f7fff7fSCullen Rhodes   }
949f7fff7fSCullen Rhodes };
959f7fff7fSCullen Rhodes 
969f7fff7fSCullen Rhodes } // namespace
979f7fff7fSCullen Rhodes 
989f7fff7fSCullen Rhodes //===----------------------------------------------------------------------===//
999f7fff7fSCullen Rhodes // Pattern population
1009f7fff7fSCullen Rhodes //===----------------------------------------------------------------------===//
1019f7fff7fSCullen Rhodes 
1029f7fff7fSCullen Rhodes void mlir::arith::populateArithToArmSMEConversionPatterns(
1039f7fff7fSCullen Rhodes     RewritePatternSet &patterns) {
1049f7fff7fSCullen Rhodes   patterns.add<ConstantOpToArmSMELowering>(patterns.getContext());
1059f7fff7fSCullen Rhodes }
1069f7fff7fSCullen Rhodes 
1079f7fff7fSCullen Rhodes //===----------------------------------------------------------------------===//
1089f7fff7fSCullen Rhodes // Pass definition
1099f7fff7fSCullen Rhodes //===----------------------------------------------------------------------===//
1109f7fff7fSCullen Rhodes 
1119f7fff7fSCullen Rhodes namespace {
1129f7fff7fSCullen Rhodes struct ArithToArmSMEConversionPass final
1139f7fff7fSCullen Rhodes     : impl::ArithToArmSMEConversionPassBase<ArithToArmSMEConversionPass> {
1149f7fff7fSCullen Rhodes   using impl::ArithToArmSMEConversionPassBase<
1159f7fff7fSCullen Rhodes       ArithToArmSMEConversionPass>::ArithToArmSMEConversionPassBase;
1169f7fff7fSCullen Rhodes 
1179f7fff7fSCullen Rhodes   void runOnOperation() override {
1189f7fff7fSCullen Rhodes     RewritePatternSet patterns(&getContext());
1199f7fff7fSCullen Rhodes     arith::populateArithToArmSMEConversionPatterns(patterns);
120*09dfc571SJacques Pienaar     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
1219f7fff7fSCullen Rhodes       return signalPassFailure();
1229f7fff7fSCullen Rhodes   }
1239f7fff7fSCullen Rhodes };
1249f7fff7fSCullen Rhodes } // namespace
125