xref: /llvm-project/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===- ArithToArmSME.cpp - Arith to ArmSME dialect conversion -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
10 
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
13 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16 
17 namespace mlir {
18 #define GEN_PASS_DEF_ARITHTOARMSMECONVERSIONPASS
19 #include "mlir/Conversion/Passes.h.inc"
20 } // namespace mlir
21 
22 #define DEBUG_TYPE "arith-to-arm-sme"
23 
24 using namespace mlir;
25 
26 //===----------------------------------------------------------------------===//
27 // Conversion helpers
28 //===----------------------------------------------------------------------===//
29 
30 /// Returns true if 'val' is a splat of zero, false otherwise.
31 static bool isSplatZero(Type elemType, DenseElementsAttr val) {
32   if (llvm::isa<FloatType>(elemType))
33     return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
34   if (llvm::isa<IntegerType>(elemType))
35     return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
36   return false;
37 }
38 
39 namespace {
40 
41 //===----------------------------------------------------------------------===//
42 // ConstantOp
43 //===----------------------------------------------------------------------===//
44 
45 /// Conversion pattern for dense arith.constant.
46 struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
47   using OpRewritePattern<arith::ConstantOp>::OpRewritePattern;
48 
49   LogicalResult matchAndRewrite(arith::ConstantOp constantOp,
50                                 PatternRewriter &rewriter) const final {
51     auto tileType = dyn_cast<VectorType>(constantOp.getType());
52     if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
53       return failure();
54 
55     auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
56     if (!denseAttr || !denseAttr.isSplat())
57       return failure();
58 
59     auto tileElementType = tileType.getElementType();
60 
61     // Lower 'arith.constant dense<0>' to 'arm_sme.zero' op.
62     if (isSplatZero(tileElementType, denseAttr)) {
63       rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, tileType);
64       return success();
65     }
66 
67     // Lower non-zero constants to a loop of 'arm_sme.insert_tile_slice'
68     // ops that broadcast the constant to each tile slice.
69     auto loc = constantOp.getLoc();
70 
71     // To fill a tile with a constant, we create a 1-D splat of the constant,
72     // then move that into each tile slice (the largest unit we can set at once,
73     // outside of operations like the outerproduct).
74     VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
75     auto denseAttr1D = DenseElementsAttr::get(
76         tileSliceType, denseAttr.getSplatValue<Attribute>());
77     auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
78 
79     auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
80     auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
81                             Value currentTile) {
82       // Create 'arm_sme.insert_tile_slice' to write vector to tile
83       // slice.
84       auto nextTile = b.create<arm_sme::InsertTileSliceOp>(
85           loc, tileType, constantOp1D, currentTile, tileSliceIndex);
86       return nextTile.getResult();
87     };
88     auto forOp = mlir::arm_sme::createLoopOverTileSlices(
89         rewriter, loc, initTile, makeLoopBody);
90     rewriter.replaceOp(constantOp, forOp.getResult(0));
91 
92     return success();
93   }
94 };
95 
96 } // namespace
97 
98 //===----------------------------------------------------------------------===//
99 // Pattern population
100 //===----------------------------------------------------------------------===//
101 
102 void mlir::arith::populateArithToArmSMEConversionPatterns(
103     RewritePatternSet &patterns) {
104   patterns.add<ConstantOpToArmSMELowering>(patterns.getContext());
105 }
106 
107 //===----------------------------------------------------------------------===//
108 // Pass definition
109 //===----------------------------------------------------------------------===//
110 
111 namespace {
112 struct ArithToArmSMEConversionPass final
113     : impl::ArithToArmSMEConversionPassBase<ArithToArmSMEConversionPass> {
114   using impl::ArithToArmSMEConversionPassBase<
115       ArithToArmSMEConversionPass>::ArithToArmSMEConversionPassBase;
116 
117   void runOnOperation() override {
118     RewritePatternSet patterns(&getContext());
119     arith::populateArithToArmSMEConversionPatterns(patterns);
120     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
121       return signalPassFailure();
122   }
123 };
124 } // namespace
125