1 //===- Utils.cpp - Utilities to support the ArmSME dialect ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for the ArmSME dialect.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
14 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
15
16 namespace mlir::arm_sme {
17
getSMETileSliceMinNumElts(Type type)18 unsigned getSMETileSliceMinNumElts(Type type) {
19 assert(isValidSMETileElementType(type) && "invalid tile type!");
20 return MinStreamingVectorLengthInBits / type.getIntOrFloatBitWidth();
21 }
22
isValidSMETileElementType(Type type)23 bool isValidSMETileElementType(Type type) {
24 return type.isInteger(8) || type.isInteger(16) || type.isInteger(32) ||
25 type.isInteger(64) || type.isInteger(128) || type.isF16() ||
26 type.isBF16() || type.isF32() || type.isF64() || type.isF128();
27 }
28
isValidSMETileVectorType(VectorType vType)29 bool isValidSMETileVectorType(VectorType vType) {
30 if ((vType.getRank() != 2) || !vType.allDimsScalable())
31 return false;
32
33 auto elemType = vType.getElementType();
34 if (!isValidSMETileElementType(elemType))
35 return false;
36
37 unsigned minNumElts = getSMETileSliceMinNumElts(elemType);
38 if (vType.getShape() != ArrayRef<int64_t>({minNumElts, minNumElts}))
39 return false;
40
41 return true;
42 }
43
getSMETileType(VectorType type)44 std::optional<ArmSMETileType> getSMETileType(VectorType type) {
45 if (!isValidSMETileVectorType(type))
46 return {};
47 switch (type.getElementTypeBitWidth()) {
48 case 8:
49 return ArmSMETileType::ZAB;
50 case 16:
51 return ArmSMETileType::ZAH;
52 case 32:
53 return ArmSMETileType::ZAS;
54 case 64:
55 return ArmSMETileType::ZAD;
56 case 128:
57 return ArmSMETileType::ZAQ;
58 default:
59 llvm_unreachable("unknown SME tile type");
60 }
61 }
62
verifyOperationHasValidTileId(Operation * op)63 LogicalResult verifyOperationHasValidTileId(Operation *op) {
64 auto tileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op);
65 if (!tileOp)
66 return success(); // Not a tile op (no need to check).
67 auto tileId = tileOp.getTileId();
68 if (!tileId)
69 return success(); // Not having a tile ID (yet) is okay.
70 if (!tileId.getType().isSignlessInteger(32))
71 return tileOp.emitOpError("tile ID should be a 32-bit signless integer");
72 return success();
73 }
74
createLoopOverTileSlices(PatternRewriter & rewriter,Location loc,Value initTile,std::function<Value (OpBuilder &,Location,Value,Value)> makeLoopBody)75 scf::ForOp createLoopOverTileSlices(
76 PatternRewriter &rewriter, Location loc, Value initTile,
77 std::function<Value(OpBuilder &, Location, Value, Value)> makeLoopBody) {
78 OpBuilder::InsertionGuard g(rewriter);
79 auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
80 auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
81 loc, llvm::cast<VectorType>(initTile.getType()).getDimSize(0));
82 auto vscale =
83 rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
84 auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
85 auto numTileSlices =
86 rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
87 auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step,
88 ValueRange{initTile});
89 rewriter.setInsertionPointToStart(forOp.getBody());
90 Value nextTile =
91 makeLoopBody(rewriter, loc, /*tileSliceIndex=*/forOp.getInductionVar(),
92 /*currentTile=*/forOp.getRegionIterArg(0));
93 rewriter.create<scf::YieldOp>(loc, nextTile);
94 return forOp;
95 }
96
isMultipleOfSMETileVectorType(VectorType vType)97 bool isMultipleOfSMETileVectorType(VectorType vType) {
98 if (vType.getRank() != 2 || !vType.allDimsScalable())
99 return false;
100
101 auto elementType = vType.getElementType();
102 if (!isValidSMETileElementType(elementType))
103 return false;
104
105 unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
106
107 int64_t vectorRows = vType.getDimSize(0);
108 int64_t vectorCols = vType.getDimSize(1);
109
110 return (vectorRows > minNumElts || vectorCols > minNumElts) &&
111 vectorRows % minNumElts == 0 && vectorCols % minNumElts == 0;
112 }
113
getSMETileTypeForElement(Type elementType)114 VectorType getSMETileTypeForElement(Type elementType) {
115 unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
116 return VectorType::get({minNumElts, minNumElts}, elementType, {true, true});
117 }
118
eraseTriviallyDeadTileOps(IRRewriter & rewriter,FunctionOpInterface function)119 void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
120 FunctionOpInterface function) {
121 SmallVector<Operation *> worklist;
122 function->walk([&](Operation *op) {
123 auto armSMEOp = dyn_cast<arm_sme::ArmSMETileOpInterface>(op);
124 if (armSMEOp && isOpTriviallyDead(armSMEOp))
125 worklist.push_back(armSMEOp);
126 });
127 while (!worklist.empty()) {
128 Operation *op = worklist.pop_back_val();
129 if (!isOpTriviallyDead(op))
130 continue;
131 for (Value value : op->getOperands()) {
132 if (auto armSMEOp = value.getDefiningOp<arm_sme::ArmSMETileOpInterface>())
133 worklist.push_back(armSMEOp);
134 }
135 rewriter.eraseOp(op);
136 }
137 }
138
isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp)139 bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp) {
140 return tileOp && tileOp->getNumResults() == 1 &&
141 tileOp->getNumOperands() == 0 && isPure(tileOp);
142 }
143
hasTileResult(arm_sme::ArmSMETileOpInterface tileOp)144 bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp) {
145 for (Value result : tileOp->getResults()) {
146 if (arm_sme::isValidSMETileVectorType(result.getType()))
147 return true;
148 }
149 return false;
150 }
151
getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp)152 OpOperand *getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp) {
153 if (!tileOp)
154 return nullptr;
155 auto isTileOperandType = [](OpOperand &operand) {
156 return arm_sme::isValidSMETileVectorType(operand.get().getType());
157 };
158 assert(llvm::count_if(tileOp->getOpOperands(), isTileOperandType) <= 1 &&
159 "expected at most one tile operand");
160 OpOperand *tileOperand =
161 llvm::find_if(tileOp->getOpOperands(), isTileOperandType);
162 if (tileOperand == tileOp->getOpOperands().end())
163 return nullptr;
164 return tileOperand;
165 }
166
isTileTypeGreaterOrEqual(ArmSMETileType typeA,ArmSMETileType typeB)167 bool isTileTypeGreaterOrEqual(ArmSMETileType typeA, ArmSMETileType typeB) {
168 // Note: This is <= due to how tile types are numbered in ArmSMEOps.td.
169 return static_cast<unsigned>(typeA) <= static_cast<unsigned>(typeB);
170 }
171
172 } // namespace mlir::arm_sme
173