xref: /llvm-project/mlir/lib/Dialect/ArmSME/IR/Utils.cpp (revision 041baf2f60ac3e107399641aea04c77019e7eab8)
1 //===- Utils.cpp - Utilities to support the ArmSME dialect ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for the ArmSME dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "llvm/ADT/SetVector.h"
15 
16 namespace mlir::arm_sme {
17 
getSMETileSliceMinNumElts(Type type)18 unsigned getSMETileSliceMinNumElts(Type type) {
19   assert(isValidSMETileElementType(type) && "invalid tile type!");
20   return MinStreamingVectorLengthInBits / type.getIntOrFloatBitWidth();
21 }
22 
isValidSMETileElementType(Type type)23 bool isValidSMETileElementType(Type type) {
24   return type.isInteger(8) || type.isInteger(16) || type.isInteger(32) ||
25          type.isInteger(64) || type.isInteger(128) || type.isF16() ||
26          type.isBF16() || type.isF32() || type.isF64() || type.isF128();
27 }
28 
isValidSMETileVectorType(VectorType vType)29 bool isValidSMETileVectorType(VectorType vType) {
30   if ((vType.getRank() != 2) || !vType.allDimsScalable())
31     return false;
32 
33   auto elemType = vType.getElementType();
34   if (!isValidSMETileElementType(elemType))
35     return false;
36 
37   unsigned minNumElts = getSMETileSliceMinNumElts(elemType);
38   if (vType.getShape() != ArrayRef<int64_t>({minNumElts, minNumElts}))
39     return false;
40 
41   return true;
42 }
43 
getSMETileType(VectorType type)44 std::optional<ArmSMETileType> getSMETileType(VectorType type) {
45   if (!isValidSMETileVectorType(type))
46     return {};
47   switch (type.getElementTypeBitWidth()) {
48   case 8:
49     return ArmSMETileType::ZAB;
50   case 16:
51     return ArmSMETileType::ZAH;
52   case 32:
53     return ArmSMETileType::ZAS;
54   case 64:
55     return ArmSMETileType::ZAD;
56   case 128:
57     return ArmSMETileType::ZAQ;
58   default:
59     llvm_unreachable("unknown SME tile type");
60   }
61 }
62 
verifyOperationHasValidTileId(Operation * op)63 LogicalResult verifyOperationHasValidTileId(Operation *op) {
64   auto tileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op);
65   if (!tileOp)
66     return success(); // Not a tile op (no need to check).
67   auto tileId = tileOp.getTileId();
68   if (!tileId)
69     return success(); // Not having a tile ID (yet) is okay.
70   if (!tileId.getType().isSignlessInteger(32))
71     return tileOp.emitOpError("tile ID should be a 32-bit signless integer");
72   return success();
73 }
74 
createLoopOverTileSlices(PatternRewriter & rewriter,Location loc,Value initTile,std::function<Value (OpBuilder &,Location,Value,Value)> makeLoopBody)75 scf::ForOp createLoopOverTileSlices(
76     PatternRewriter &rewriter, Location loc, Value initTile,
77     std::function<Value(OpBuilder &, Location, Value, Value)> makeLoopBody) {
78   OpBuilder::InsertionGuard g(rewriter);
79   auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
80   auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
81       loc, llvm::cast<VectorType>(initTile.getType()).getDimSize(0));
82   auto vscale =
83       rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
84   auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
85   auto numTileSlices =
86       rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
87   auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step,
88                                            ValueRange{initTile});
89   rewriter.setInsertionPointToStart(forOp.getBody());
90   Value nextTile =
91       makeLoopBody(rewriter, loc, /*tileSliceIndex=*/forOp.getInductionVar(),
92                    /*currentTile=*/forOp.getRegionIterArg(0));
93   rewriter.create<scf::YieldOp>(loc, nextTile);
94   return forOp;
95 }
96 
isMultipleOfSMETileVectorType(VectorType vType)97 bool isMultipleOfSMETileVectorType(VectorType vType) {
98   if (vType.getRank() != 2 || !vType.allDimsScalable())
99     return false;
100 
101   auto elementType = vType.getElementType();
102   if (!isValidSMETileElementType(elementType))
103     return false;
104 
105   unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
106 
107   int64_t vectorRows = vType.getDimSize(0);
108   int64_t vectorCols = vType.getDimSize(1);
109 
110   return (vectorRows > minNumElts || vectorCols > minNumElts) &&
111          vectorRows % minNumElts == 0 && vectorCols % minNumElts == 0;
112 }
113 
getSMETileTypeForElement(Type elementType)114 VectorType getSMETileTypeForElement(Type elementType) {
115   unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
116   return VectorType::get({minNumElts, minNumElts}, elementType, {true, true});
117 }
118 
eraseTriviallyDeadTileOps(IRRewriter & rewriter,FunctionOpInterface function)119 void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
120                                FunctionOpInterface function) {
121   SmallVector<Operation *> worklist;
122   function->walk([&](Operation *op) {
123     auto armSMEOp = dyn_cast<arm_sme::ArmSMETileOpInterface>(op);
124     if (armSMEOp && isOpTriviallyDead(armSMEOp))
125       worklist.push_back(armSMEOp);
126   });
127   while (!worklist.empty()) {
128     Operation *op = worklist.pop_back_val();
129     if (!isOpTriviallyDead(op))
130       continue;
131     for (Value value : op->getOperands()) {
132       if (auto armSMEOp = value.getDefiningOp<arm_sme::ArmSMETileOpInterface>())
133         worklist.push_back(armSMEOp);
134     }
135     rewriter.eraseOp(op);
136   }
137 }
138 
isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp)139 bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp) {
140   return tileOp && tileOp->getNumResults() == 1 &&
141          tileOp->getNumOperands() == 0 && isPure(tileOp);
142 }
143 
hasTileResult(arm_sme::ArmSMETileOpInterface tileOp)144 bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp) {
145   for (Value result : tileOp->getResults()) {
146     if (arm_sme::isValidSMETileVectorType(result.getType()))
147       return true;
148   }
149   return false;
150 }
151 
getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp)152 OpOperand *getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp) {
153   if (!tileOp)
154     return nullptr;
155   auto isTileOperandType = [](OpOperand &operand) {
156     return arm_sme::isValidSMETileVectorType(operand.get().getType());
157   };
158   assert(llvm::count_if(tileOp->getOpOperands(), isTileOperandType) <= 1 &&
159          "expected at most one tile operand");
160   OpOperand *tileOperand =
161       llvm::find_if(tileOp->getOpOperands(), isTileOperandType);
162   if (tileOperand == tileOp->getOpOperands().end())
163     return nullptr;
164   return tileOperand;
165 }
166 
isTileTypeGreaterOrEqual(ArmSMETileType typeA,ArmSMETileType typeB)167 bool isTileTypeGreaterOrEqual(ArmSMETileType typeA, ArmSMETileType typeB) {
168   // Note: This is <= due to how tile types are numbered in ArmSMEOps.td.
169   return static_cast<unsigned>(typeA) <= static_cast<unsigned>(typeB);
170 }
171 
172 } // namespace mlir::arm_sme
173