xref: /llvm-project/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp (revision c42512436b23ab50e7637f239abe8371407104a1)
1 //===- ArmSMEToSCF.cpp - Convert ArmSME to SCF dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of ArmSME operations to SCF.
10 //
11 //===----------------------------------------------------------------------===//
12 #include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h"
13 
14 #include "mlir/Dialect/Arith/IR/Arith.h"
15 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
16 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
17 #include "mlir/Dialect/SCF/IR/SCF.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 
21 namespace mlir {
22 #define GEN_PASS_DEF_CONVERTARMSMETOSCF
23 #include "mlir/Conversion/Passes.h.inc"
24 } // namespace mlir
25 
26 using namespace mlir;
27 
28 namespace {
29 /// Returns adjusted (1-D or 2-D) `indices` for a tile slice as follows:
30 ///   rank 1: (indices[0] + (tileSliceIndex * tileSliceNumElts))
31 ///   rank 2: (indices[0] + tileSliceIndex, indices[1])
32 SmallVector<Value, 2> getMemrefIndices(ValueRange indices, unsigned rank,
33                                        Value tileSliceIndex,
34                                        Value tileSliceNumElts, Location loc,
35                                        PatternRewriter &rewriter) {
36   assert((rank == 1 || rank == 2) && "memref has unexpected rank!");
37   SmallVector<Value, 2> outIndices;
38 
39   auto tileSliceOffset = tileSliceIndex;
40   if (rank == 1)
41     tileSliceOffset =
42         rewriter.create<arith::MulIOp>(loc, tileSliceOffset, tileSliceNumElts);
43 
44   auto baseIndexPlusTileSliceOffset =
45       rewriter.create<arith::AddIOp>(loc, indices[0], tileSliceOffset);
46   outIndices.push_back(baseIndexPlusTileSliceOffset);
47 
48   if (rank == 2)
49     outIndices.push_back(indices[1]);
50 
51   return outIndices;
52 }
53 
54 /// Creates an scf.for for the load/store of an ArmSME tile.
55 FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
56     PatternRewriter &rewriter, Location loc, VectorType tileType,
57     ValueRange memrefIndices, int memrefRank, Value mask, Value initTile,
58     function_ref<Value(/*index=*/Value, ValueRange, /*predicate=*/Value,
59                        /*currentTile=*/Value)>
60         makeLoopBody) {
61   PatternRewriter::InsertionGuard guard(rewriter);
62 
63   auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
64       loc, arm_sme::getSMETileSliceMinNumElts(tileType.getElementType()));
65   auto vscale =
66       rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
67   auto predicateType =
68       VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
69 
70   // This describes both the number of ZA tile slices and the number of
71   // elements in a vector of SVL bits for a given element type (SVL_B,
72   // SVL_H, ..., SVL_Q).
73   auto numTileSlices =
74       rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
75 
76   Value predicate;
77   Value upperBound;
78   if (mask) {
79     auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
80     if (!createMaskOp)
81       return rewriter.notifyMatchFailure(
82           loc, "unsupported mask op, only 'vector.create_mask' is "
83                "currently supported");
84 
85     auto maskDim0 = createMaskOp.getOperands()[0];
86     auto maskDim1 = createMaskOp.getOperands()[1];
87 
88     // The upper bound of the loop must be clamped at `numTileSlices` as
89     // `vector.create_mask` allows operands to be greater than the size of a
90     // dimension.
91     auto numRowI64 = rewriter.create<arith::IndexCastOp>(
92         loc, rewriter.getI64Type(), maskDim0);
93     auto numTileSlicesI64 = rewriter.create<arith::IndexCastOp>(
94         loc, rewriter.getI64Type(), numTileSlices);
95     auto upperBoundI64 =
96         rewriter.create<arith::MinSIOp>(loc, numRowI64, numTileSlicesI64);
97     upperBound = rewriter.create<arith::IndexCastOp>(
98         loc, rewriter.getIndexType(), upperBoundI64);
99 
100     predicate =
101         rewriter.create<vector::CreateMaskOp>(loc, predicateType, maskDim1);
102   } else {
103     upperBound = numTileSlices;
104     // No mask. Create an 'all true' predicate for the tile slice.
105     predicate = rewriter.create<arith::ConstantOp>(
106         loc, DenseElementsAttr::get(predicateType, true));
107   }
108 
109   bool hasCarriedArgs = bool(initTile);
110   auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
111   auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
112   auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step,
113                                            hasCarriedArgs ? ValueRange{initTile}
114                                                           : ValueRange{});
115 
116   rewriter.setInsertionPointToStart(forOp.getBody());
117   Value tileSliceIndex = forOp.getInductionVar();
118 
119   auto adjustedIndices = getMemrefIndices(
120       memrefIndices, memrefRank, tileSliceIndex, numTileSlices, loc, rewriter);
121   auto nextTile = makeLoopBody(
122       tileSliceIndex, adjustedIndices, predicate,
123       /*currentTile=*/hasCarriedArgs ? forOp.getRegionIterArg(0) : Value{});
124 
125   assert(bool(nextTile) == hasCarriedArgs);
126   if (nextTile)
127     rewriter.create<scf::YieldOp>(loc, nextTile);
128 
129   return forOp;
130 }
131 
132 FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
133     PatternRewriter &rewriter, Location loc, VectorType tileType,
134     ValueRange memrefIndices, int memrefRank, Value mask,
135     function_ref<void(/*index=*/Value, ValueRange, /*predicate=*/Value)>
136         makeLoopBody) {
137   return createLoadStoreForOverTileSlices(
138       rewriter, loc, tileType, memrefIndices, memrefRank, mask, Value{},
139       [&](Value index, ValueRange adjustedIndices, Value predicate,
140           Value) -> Value {
141         makeLoopBody(index, adjustedIndices, predicate);
142         return {};
143       });
144 }
145 
146 /// Lower `arm_sme.tile_load` without a mask, or with a mask and a zero pad.
147 ///
148 ///  With a mask:
149 ///
150 ///  BEFORE:
151 ///  ```mlir
152 ///  %pad = arith.constant 0 : i32
153 ///  %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
154 ///  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
155 ///    memref<?x?xi32>, vector<[4]x[4]xi32>
156 ///  ```
157 ///
158 ///  AFTER:
159 ///  ```mlir
160 ///  %init_tile = arm_sme.zero : vector<[4]x[4]xi32>
161 ///  %mask_cols = vector.create_mask %num_cols : vector<[4]xi1>
162 ///  %loop_rows = arith.minsi %num_rows, %svl_s : index
163 ///  %tile = scf.for %tile_slice_idx = %c0 to %loop_rows step %c1
164 ///                iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
165 ///    %tile_update = arm_sme.load_tile_slice
166 ///      %src[%tile_slice_idx], %num_cols, %iter_tile, %tile_slice_idx :
167 ///      memref<?x?xi32>, vector<[1]xi32>, vector<[4]x[4]xi32>
168 ///    scf.yield %tile_update : vector<[4]x[4]xi32>
169 ///  }
170 ///  ```
171 ///
172 /// Without a mask the lowering is pretty much identical. The only difference is
173 /// %mask_cols becomes an all-true mask, and %loop_rows becomes %svl_s.
174 ///
175 /// NOTE: Only mask of 'vector.create_mask' op is currently supported.
176 struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
177   using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
178 
179   LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
180                                 PatternRewriter &rewriter) const override {
181     auto loc = tileLoadOp.getLoc();
182     auto tileType = tileLoadOp.getVectorType();
183     auto mask = tileLoadOp.getMask();
184 
185     Value initTile;
186     if (mask) {
187       auto padOp = tileLoadOp.getPadding();
188       assert(padOp && "expected padding when masking!");
189 
190       auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
191       if (!constPadOp || constPadOp.getValue() !=
192                              rewriter.getZeroAttr(tileType.getElementType()))
193         return rewriter.notifyMatchFailure(
194             tileLoadOp, "op has non-zero pad, needs non-zero pad pattern");
195 
196       // Initialize tile with zero to satisfy padding. Inactive cols will be
197       // zeroed anyway since the loads use zeroing predication. For inactive
198       // rows however, no load will occur so these need to be zeroed.
199       initTile = rewriter.create<arm_sme::ZeroOp>(loc, tileType);
200     } else {
201       initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
202     }
203 
204     // Create a loop to load the active tile slices from memory.
205     auto forOp = createLoadStoreForOverTileSlices(
206         rewriter, loc, tileType, tileLoadOp.getIndices(),
207         tileLoadOp.getMemRefType().getRank(), mask, initTile,
208         [&](Value tileSliceIndex, ValueRange memrefIndices, Value predicate,
209             Value currentTile) -> Value {
210           // Create 'arm_sme.load_tile_slice' to load tile slice from memory
211           // into tile.
212           return rewriter.create<arm_sme::LoadTileSliceOp>(
213               loc, tileType, tileLoadOp.getBase(), predicate, currentTile,
214               memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
215         });
216 
217     if (failed(forOp))
218       return forOp;
219 
220     // Replace 'arm_sme.tile_load' with the result.
221     rewriter.replaceOp(tileLoadOp, forOp->getResult(0));
222 
223     return success();
224   }
225 };
226 
227 /// Lower `arm_sme.tile_load` with mask and non-zero pad.
228 ///
229 ///  BEFORE:
230 ///  ```mlir
231 ///  %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
232 ///  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
233 ///    memref<?x?xi32>, vector<[4]x[4]xi32>
234 ///  ```
235 ///
236 ///  AFTER:
237 ///  ```mlir
238 ///  ...
239 ///  %pad_1d = vector.splat %pad : vector<[4]xi32>
240 ///  %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
241 ///                iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
242 ///    ...
243 ///    %mask_1d = vector.create_mask <combined_mask> : vector<[4]xi1>
244 ///    %slice = vector.maskedload %base[%tile_slice_idx, %c0], %mask_1d, %pad_1d
245 ///      : memref<?x?xi32>, vector<[4]xi1>,
246 ///        vector<[4]xi32> into vector<[4]xi32>
247 ///    // Insert slice into tile
248 ///    %tile_update = arm_sme.insert_tile_slice
249 ///      %slice, %iter_tile[%tile_slice_idx] :
250 ///      vector<[4]xi32> into vector<[4]x[4]xi32>
251 ///    scf.yield %tile_update : vector<[4]x[4]xi32>
252 ///  }
253 ///  ```
254 struct TileLoadOpWithMaskAndPadNonZeroConversion
255     : public OpRewritePattern<arm_sme::TileLoadOp> {
256   using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
257 
258   LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
259                                 PatternRewriter &rewriter) const override {
260     OpBuilder::InsertionGuard g(rewriter);
261     auto loc = tileLoadOp.getLoc();
262     auto tileType = tileLoadOp.getVectorType();
263     auto tileElementType = tileType.getElementType();
264 
265     auto maskOp = tileLoadOp.getMask();
266     if (!maskOp)
267       return rewriter.notifyMatchFailure(
268           tileLoadOp, "op has no mask, needs unmasked pattern");
269 
270     auto padOp = tileLoadOp.getPadding();
271     assert(padOp && "expected padding when masking!");
272 
273     auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
274     if (!createMaskOp)
275       return rewriter.notifyMatchFailure(
276           tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
277                       "currently supported");
278 
279     auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
280     if (constPadOp &&
281         constPadOp.getValue() == rewriter.getZeroAttr(tileElementType))
282       return rewriter.notifyMatchFailure(
283           tileLoadOp, "op has constant zero pad, needs zero pad pattern");
284 
285     auto numRows = createMaskOp.getOperands()[0];
286     auto numCols = createMaskOp.getOperands()[1];
287 
288     auto numColsI32 = rewriter.create<arith::IndexCastUIOp>(
289         loc, rewriter.getI32Type(), numCols);
290 
291     auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
292 
293     // Create a loop that loads each ZA tile slice from memory.
294     auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
295     auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
296         loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
297     auto vscale =
298         rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
299     auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
300     auto numTileSlices =
301         rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
302     auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices,
303                                              step, ValueRange{initTile});
304 
305     rewriter.setInsertionPointToStart(forOp.getBody());
306 
307     auto tileSliceIndex = forOp.getInductionVar();
308     auto currentTile = forOp.getRegionIterArg(0);
309 
310     // Combine masks.
311     auto rowIsActive = rewriter.create<arith::CmpIOp>(
312         loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
313     auto rowIsActiveI32 = rewriter.create<arith::ExtSIOp>(
314         loc, rewriter.getI32Type(), rowIsActive);
315     auto mask = rewriter.create<arith::AndIOp>(loc, rowIsActiveI32, numColsI32);
316     auto maskIndex =
317         rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), mask);
318     auto predicateType =
319         VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
320     auto maskOp1D = rewriter.create<vector::CreateMaskOp>(
321         loc, predicateType, maskIndex.getResult());
322 
323     auto memrefIndices = getMemrefIndices(
324         tileLoadOp.getIndices(), tileLoadOp.getMemRefType().getRank(),
325         tileSliceIndex, numTileSlices, loc, rewriter);
326 
327     // Splat pad into 1-D vector matching type of tile slice.
328     VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
329     auto pad1DOp = rewriter.create<vector::SplatOp>(loc, tileSliceType, padOp);
330 
331     auto loadSlice = rewriter.create<vector::MaskedLoadOp>(
332         loc, tileSliceType, tileLoadOp.getBase(), memrefIndices, maskOp1D,
333         /*passthru=*/pad1DOp);
334 
335     // Create 'arm_sme.insert_tile_slice' to insert slice into tile.
336     auto insertSlice = rewriter.create<arm_sme::InsertTileSliceOp>(
337         loc, tileType, loadSlice->getResult(0), currentTile, tileSliceIndex,
338         tileLoadOp.getLayout());
339     rewriter.create<scf::YieldOp>(loc, insertSlice.getResult());
340 
341     rewriter.setInsertionPointAfter(forOp);
342 
343     // Replace 'arm_sme.tile_load' with the result.
344     rewriter.replaceOp(tileLoadOp, forOp.getResult(0));
345 
346     return success();
347   }
348 };
349 
350 /// Lower `arm_sme.tile_store` to a loop over the tile slices and store each
351 /// slice using `arm_sme.store_tile_slice`.
352 ///
353 ///  BEFORE:
354 ///  ```mlir
355 ///  arm_sme.tile_store %tile, %dest[%c0, %c0] layout<vertical>
356 ///    : memref<?x?xi32>, vector<[4]x[4]xi32
357 ///  ```
358 ///
359 ///  AFTER:
360 ///  ```mlir
361 ///  %vscale = vector.vscale
362 ///  %c0 = arith.constant 0 : index
363 ///  %c1 = arith.constant 1 : index
364 ///  %min_svl_s = arith.constant 4 : index
365 ///  %svl_s = arith.muli %min_svl_s, %vscale : index
366 ///  scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
367 ///    arm_sme.store_tile_slice %tile, %tile_slice_idx, %dest[%tile_slice_idx],
368 ///      layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
369 ///  }
370 ///  ```
371 struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
372   using OpRewritePattern<arm_sme::TileStoreOp>::OpRewritePattern;
373 
374   LogicalResult matchAndRewrite(arm_sme::TileStoreOp tileStoreOp,
375                                 PatternRewriter &rewriter) const override {
376     // Create a loop that stores each active ZA tile slice from memory.
377     return createLoadStoreForOverTileSlices(
378         rewriter, tileStoreOp.getLoc(), tileStoreOp.getVectorType(),
379         tileStoreOp.getIndices(), tileStoreOp.getMemRefType().getRank(),
380         tileStoreOp.getMask(),
381         [&](Value tileSliceIndex, ValueRange memrefIndices, Value predicate) {
382           rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
383               tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
384               predicate, tileStoreOp.getBase(), memrefIndices,
385               tileStoreOp.getLayout());
386         });
387   }
388 };
389 
390 } // namespace
391 
392 void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
393   patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadNonZeroConversion,
394                TileStoreOpConversion>(patterns.getContext());
395 }
396 
397 namespace {
398 
399 struct ConvertArmSMEToSCFPass
400     : public impl::ConvertArmSMEToSCFBase<ConvertArmSMEToSCFPass> {
401   void runOnOperation() override {
402     RewritePatternSet patterns(&getContext());
403     ConversionTarget target(getContext());
404     populateArmSMEToSCFConversionPatterns(patterns);
405     target.addLegalDialect<arm_sme::ArmSMEDialect, vector::VectorDialect,
406                            arith::ArithDialect, scf::SCFDialect>();
407     target.addIllegalOp<arm_sme::TileLoadOp, arm_sme::TileStoreOp>();
408     if (failed(applyPartialConversion(getOperation(), target,
409                                       std::move(patterns))))
410       signalPassFailure();
411   }
412 };
413 
414 } // namespace
415 
416 std::unique_ptr<Pass> mlir::createConvertArmSMEToSCFPass() {
417   return std::make_unique<ConvertArmSMEToSCFPass>();
418 }
419