xref: /llvm-project/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===- OuterProductFusion.cpp - Fuse 'arm_sme.outerproduct' ops -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewrites that fuse 'arm_sme.outerproduct' operations
10 // into the 2-way or 4-way widening outerproduct operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
15 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
16 #include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 
22 #define DEBUG_TYPE "arm-sme-outerproduct-fusion"
23 
24 namespace mlir::arm_sme {
25 #define GEN_PASS_DEF_OUTERPRODUCTFUSION
26 #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
27 } // namespace mlir::arm_sme
28 
29 using namespace mlir;
30 using namespace mlir::arm_sme;
31 
32 namespace {
33 
34 // Common match failure reasons.
35 static constexpr StringLiteral
36     kMatchFailureNoAccumulator("no accumulator operand");
37 static constexpr StringLiteral kMatchFailureExpectedOuterProductDefOp(
38     "defining op of accumulator must be 'arm_sme.outerproduct'");
39 static constexpr StringLiteral kMatchFailureInconsistentCombiningKind(
40     "combining kind (add or sub) of outer products must match");
41 static constexpr StringLiteral kMatchFailureInconsistentMasking(
42     "unsupported masking, either both outerproducts are masked "
43     "or neither");
44 static constexpr StringLiteral kMatchFailureOuterProductNotSingleUse(
45     "outer product(s) not single use and cannot be removed, no benefit to "
46     "fusing");
47 
48 // An outer product is compatible if all of the following are true:
49 // - the result type matches `resultType`.
50 // - the defining operation of LHS is of the type `LhsExtOp`.
51 // - the defining operation of RHS is of the type `RhsExtOp`.
52 // - the input types of the defining operations are identical and match
53 //   `inputType`.
54 template <typename LhsExtOp, typename RhsExtOp = LhsExtOp>
55 static LogicalResult isCompatible(PatternRewriter &rewriter,
56                                   arm_sme::OuterProductOp op,
57                                   VectorType resultType, VectorType inputType) {
58   if (op.getResultType() != resultType)
59     return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) {
60       diag << "unsupported result type, expected " << resultType;
61     });
62 
63   auto lhsDefOp = op.getLhs().getDefiningOp<LhsExtOp>();
64   auto rhsDefOp = op.getRhs().getDefiningOp<RhsExtOp>();
65 
66   if (!lhsDefOp || !rhsDefOp)
67     return rewriter.notifyMatchFailure(
68         op, "defining op of outerproduct operands must be one of: "
69             "'arith.extf' or 'arith.extsi' or 'arith.extui'");
70 
71   auto lhsInType = cast<VectorType>(lhsDefOp.getIn().getType());
72   auto rhsInType = cast<VectorType>(rhsDefOp.getIn().getType());
73 
74   if (lhsInType != inputType || rhsInType != inputType)
75     return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) {
76       diag << "unsupported input type, expected " << inputType;
77     });
78 
79   return success();
80 }
81 
82 // Fuse two 'arm_sme.outerproduct' operations that are chained via the
83 // accumulator into 2-way outer product operation.
84 //
85 // For example:
86 //
87 //  %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
88 //  %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
89 //  %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>,
90 //                                               vector<[4]xf32>
91 //
92 //  %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
93 //  %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
94 //  %1 = arm_sme.outerproduct %a1_ext, %b1_ext, %0 : vector<[4]xf32>,
95 //                                                   vector<[4]xf32>
96 //
97 // Becomes:
98 //
99 //  %a_packed = vector.interleave %a0, %a1 : vector<[4]xf16> -> vector<[8]xf16>
100 //  %b_packed = vector.interleave %b0, %b1 : vector<[4]xf16> -> vector<[8]xf16>
101 //  %0 = arm_sme.fmopa_2way %a_packed, %b_packed
102 //    : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
103 class OuterProductFusion2Way
104     : public OpRewritePattern<arm_sme::OuterProductOp> {
105 public:
106   using OpRewritePattern::OpRewritePattern;
107 
108   LogicalResult matchAndRewrite(arm_sme::OuterProductOp op,
109                                 PatternRewriter &rewriter) const override {
110     Value acc = op.getAcc();
111     if (!acc)
112       return rewriter.notifyMatchFailure(op, kMatchFailureNoAccumulator);
113 
114     arm_sme::OuterProductOp op1 = acc.getDefiningOp<arm_sme::OuterProductOp>();
115     arm_sme::OuterProductOp op2 = op;
116     if (!op1)
117       return rewriter.notifyMatchFailure(
118           op, kMatchFailureExpectedOuterProductDefOp);
119 
120     if (op1.getKind() != op2.getKind())
121       return rewriter.notifyMatchFailure(
122           op, kMatchFailureInconsistentCombiningKind);
123 
124     if (!op1->hasOneUse()) {
125       // If the first outer product has uses other than as the input to another
126       // outer product, it can't be erased after fusion.
127       return rewriter.notifyMatchFailure(op,
128                                          kMatchFailureOuterProductNotSingleUse);
129     }
130 
131     if (bool(op1.getLhsMask()) != bool(op2.getLhsMask()))
132       return rewriter.notifyMatchFailure(op, kMatchFailureInconsistentMasking);
133 
134     if (failed(canFuseOuterProducts(rewriter, op1, op2)))
135       return failure();
136 
137     auto loc = op.getLoc();
138     auto packInputs = [&](Value lhs, Value rhs) {
139       return rewriter.create<vector::InterleaveOp>(loc, lhs, rhs);
140     };
141 
142     auto lhs = packInputs(op1.getLhs().getDefiningOp()->getOperand(0),
143                           op2.getLhs().getDefiningOp()->getOperand(0));
144     auto rhs = packInputs(op1.getRhs().getDefiningOp()->getOperand(0),
145                           op2.getRhs().getDefiningOp()->getOperand(0));
146 
147     Value lhsMask, rhsMask;
148     if (op1.getLhsMask() || op2.getLhsMask()) {
149       lhsMask = packInputs(op1.getLhsMask(), op2.getLhsMask());
150       rhsMask = packInputs(op1.getRhsMask(), op2.getRhsMask());
151     }
152 
153     auto extOp = op.getLhs().getDefiningOp();
154 
155     arm_sme::CombiningKind kind = op.getKind();
156     if (kind == arm_sme::CombiningKind::Add) {
157       TypeSwitch<Operation *>(extOp)
158           .Case<arith::ExtFOp>([&](auto) {
159             rewriter.replaceOpWithNewOp<arm_sme::FMopa2WayOp>(
160                 op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
161                 op1.getAcc());
162           })
163           .Case<arith::ExtSIOp>([&](auto) {
164             rewriter.replaceOpWithNewOp<arm_sme::SMopa2WayOp>(
165                 op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
166                 op1.getAcc());
167           })
168           .Case<arith::ExtUIOp>([&](auto) {
169             rewriter.replaceOpWithNewOp<arm_sme::UMopa2WayOp>(
170                 op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
171                 op1.getAcc());
172           })
173           .Default([&](auto) { llvm_unreachable("unexpected extend op!"); });
174     } else if (kind == arm_sme::CombiningKind::Sub) {
175       TypeSwitch<Operation *>(extOp)
176           .Case<arith::ExtFOp>([&](auto) {
177             rewriter.replaceOpWithNewOp<arm_sme::FMops2WayOp>(
178                 op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
179                 op1.getAcc());
180           })
181           .Case<arith::ExtSIOp>([&](auto) {
182             rewriter.replaceOpWithNewOp<arm_sme::SMops2WayOp>(
183                 op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
184                 op1.getAcc());
185           })
186           .Case<arith::ExtUIOp>([&](auto) {
187             rewriter.replaceOpWithNewOp<arm_sme::UMops2WayOp>(
188                 op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
189                 op1.getAcc());
190           })
191           .Default([&](auto) { llvm_unreachable("unexpected extend op!"); });
192     } else {
193       llvm_unreachable("unexpected arm_sme::CombiningKind!");
194     }
195 
196     return success();
197   }
198 
199 private:
200   // A pair of outer product can be fused if all of the following are true:
201   // - input and result types match.
202   // - the defining operations of the inputs are identical extensions,
203   //   specifically either:
204   //     - a signed or unsigned extension for integer types.
205   //     - a floating-point extension for floating-point types.
206   // - the types and extension are supported, i.e. there's a 2-way operation
207   //   they can be fused into.
208   LogicalResult canFuseOuterProducts(PatternRewriter &rewriter,
209                                      arm_sme::OuterProductOp op1,
210                                      arm_sme::OuterProductOp op2) const {
211     // Supported result types.
212     auto nxnxv4i32 =
213         VectorType::get({4, 4}, rewriter.getI32Type(), {true, true});
214     auto nxnxv4f32 =
215         VectorType::get({4, 4}, rewriter.getF32Type(), {true, true});
216     // Supported input types.
217     // Note: this is before packing so these have half the number of elements
218     // of the input vector types of the 2-way operations.
219     auto nxv4i16 = VectorType::get({4}, rewriter.getI16Type(), true);
220     auto nxv4f16 = VectorType::get({4}, rewriter.getF16Type(), true);
221     auto nxv4bf16 = VectorType::get({4}, rewriter.getBF16Type(), true);
222     if ((failed(
223              isCompatible<arith::ExtFOp>(rewriter, op1, nxnxv4f32, nxv4f16)) ||
224          failed(
225              isCompatible<arith::ExtFOp>(rewriter, op2, nxnxv4f32, nxv4f16))) &&
226         (failed(
227              isCompatible<arith::ExtFOp>(rewriter, op1, nxnxv4f32, nxv4bf16)) ||
228          failed(isCompatible<arith::ExtFOp>(rewriter, op2, nxnxv4f32,
229                                             nxv4bf16))) &&
230         (failed(
231              isCompatible<arith::ExtSIOp>(rewriter, op1, nxnxv4i32, nxv4i16)) ||
232          failed(isCompatible<arith::ExtSIOp>(rewriter, op2, nxnxv4i32,
233                                              nxv4i16))) &&
234         (failed(
235              isCompatible<arith::ExtUIOp>(rewriter, op1, nxnxv4i32, nxv4i16)) ||
236          failed(
237              isCompatible<arith::ExtUIOp>(rewriter, op2, nxnxv4i32, nxv4i16))))
238       return failure();
239 
240     return success();
241   }
242 };
243 
244 // Fuse four 'arm_sme.outerproduct' operations that are chained via the
245 // accumulator into 4-way outer product operation.
246 class OuterProductFusion4Way
247     : public OpRewritePattern<arm_sme::OuterProductOp> {
248 public:
249   using OpRewritePattern::OpRewritePattern;
250 
251   LogicalResult matchAndRewrite(arm_sme::OuterProductOp op,
252                                 PatternRewriter &rewriter) const override {
253     SmallVector<arm_sme::OuterProductOp, 4> outerProductChain;
254     outerProductChain.push_back(op);
255 
256     for (int i = 0; i < 3; ++i) {
257       auto currentOp = outerProductChain.back();
258       auto acc = currentOp.getAcc();
259       if (!acc)
260         return rewriter.notifyMatchFailure(op, kMatchFailureNoAccumulator);
261       auto previousOp = acc.getDefiningOp<arm_sme::OuterProductOp>();
262       if (!previousOp)
263         return rewriter.notifyMatchFailure(
264             op, kMatchFailureExpectedOuterProductDefOp);
265       if (!previousOp->hasOneUse())
266         return rewriter.notifyMatchFailure(
267             op, kMatchFailureOuterProductNotSingleUse);
268       if (previousOp.getKind() != currentOp.getKind())
269         return rewriter.notifyMatchFailure(
270             op, kMatchFailureInconsistentCombiningKind);
271       if (bool(previousOp.getLhsMask()) != bool(currentOp.getLhsMask()))
272         return rewriter.notifyMatchFailure(
273             op, kMatchFailureInconsistentCombiningKind);
274       outerProductChain.push_back(previousOp);
275     }
276 
277     if (failed(canFuseOuterProducts(rewriter, outerProductChain)))
278       return failure();
279 
280     arm_sme::OuterProductOp op1 = outerProductChain[3];
281     arm_sme::OuterProductOp op2 = outerProductChain[2];
282     arm_sme::OuterProductOp op3 = outerProductChain[1];
283     arm_sme::OuterProductOp op4 = outerProductChain[0];
284 
285     auto loc = op.getLoc();
286     auto packInputs = [&](Value lhs, Value rhs) {
287       return rewriter.create<vector::InterleaveOp>(loc, lhs, rhs);
288     };
289 
290     auto lhs0 = packInputs(op1.getLhs().getDefiningOp()->getOperand(0),
291                            op3.getLhs().getDefiningOp()->getOperand(0));
292     auto lhs1 = packInputs(op2.getLhs().getDefiningOp()->getOperand(0),
293                            op4.getLhs().getDefiningOp()->getOperand(0));
294     auto lhs = packInputs(lhs0, lhs1);
295 
296     auto rhs0 = packInputs(op1.getRhs().getDefiningOp()->getOperand(0),
297                            op3.getRhs().getDefiningOp()->getOperand(0));
298     auto rhs1 = packInputs(op2.getRhs().getDefiningOp()->getOperand(0),
299                            op4.getRhs().getDefiningOp()->getOperand(0));
300     auto rhs = packInputs(rhs0, rhs1);
301 
302     Value lhsMask, rhsMask;
303     if (op1.getLhsMask() || op2.getLhsMask() || op3.getLhsMask() ||
304         op4.getLhsMask()) {
305       auto lhs0Mask = packInputs(op1.getLhsMask(), op3.getLhsMask());
306       auto lhs1Mask = packInputs(op2.getLhsMask(), op4.getLhsMask());
307       lhsMask = packInputs(lhs0Mask, lhs1Mask);
308 
309       auto rhs0Mask = packInputs(op1.getRhsMask(), op3.getRhsMask());
310       auto rhs1Mask = packInputs(op2.getRhsMask(), op4.getRhsMask());
311       rhsMask = packInputs(rhs0Mask, rhs1Mask);
312     }
313 
314     auto lhsExtOp = op.getLhs().getDefiningOp();
315     auto rhsExtOp = op.getRhs().getDefiningOp();
316 
317     arm_sme::CombiningKind kind = op.getKind();
318     if (kind == arm_sme::CombiningKind::Add) {
319       if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp)) {
320         // signed
321         rewriter.replaceOpWithNewOp<arm_sme::SMopa4WayOp>(
322             op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
323       } else if (isa<arith::ExtUIOp>(lhsExtOp) &&
324                  isa<arith::ExtUIOp>(rhsExtOp)) {
325         // unsigned
326         rewriter.replaceOpWithNewOp<arm_sme::UMopa4WayOp>(
327             op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
328       } else if (isa<arith::ExtSIOp>(lhsExtOp) &&
329                  isa<arith::ExtUIOp>(rhsExtOp)) {
330         // signed by unsigned
331         rewriter.replaceOpWithNewOp<arm_sme::SuMopa4WayOp>(
332             op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
333       } else if (isa<arith::ExtUIOp>(lhsExtOp) &&
334                  isa<arith::ExtSIOp>(rhsExtOp)) {
335         // unsigned by signed
336         rewriter.replaceOpWithNewOp<arm_sme::UsMopa4WayOp>(
337             op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
338       } else {
339         llvm_unreachable("unexpected extend op!");
340       }
341     } else if (kind == arm_sme::CombiningKind::Sub) {
342       if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp)) {
343         // signed
344         rewriter.replaceOpWithNewOp<arm_sme::SMops4WayOp>(
345             op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
346       } else if (isa<arith::ExtUIOp>(lhsExtOp) &&
347                  isa<arith::ExtUIOp>(rhsExtOp)) {
348         // unsigned
349         rewriter.replaceOpWithNewOp<arm_sme::UMops4WayOp>(
350             op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
351       } else if (isa<arith::ExtSIOp>(lhsExtOp) &&
352                  isa<arith::ExtUIOp>(rhsExtOp)) {
353         // signed by unsigned
354         rewriter.replaceOpWithNewOp<arm_sme::SuMops4WayOp>(
355             op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
356       } else if (isa<arith::ExtUIOp>(lhsExtOp) &&
357                  isa<arith::ExtSIOp>(rhsExtOp)) {
358         // unsigned by signed
359         rewriter.replaceOpWithNewOp<arm_sme::UsMops4WayOp>(
360             op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
361       } else {
362         llvm_unreachable("unexpected extend op!");
363       }
364     } else {
365       llvm_unreachable("unexpected arm_sme::CombiningKind!");
366     }
367 
368     return success();
369   }
370 
371 private:
372   // Four outer products can be fused if all of the following are true:
373   // - input and result types match.
374   // - the defining operations of the inputs are identical extensions,
375   //   specifically either:
376   //     - a signed or unsigned extension for integer types.
377   //     - a floating-point extension for floating-point types.
378   // - the types and extension are supported, i.e. there's a 4-way operation
379   //   they can be fused into.
380   LogicalResult
381   canFuseOuterProducts(PatternRewriter &rewriter,
382                        ArrayRef<arm_sme::OuterProductOp> ops) const {
383     // Supported result types.
384     auto nxnxv4i32 =
385         VectorType::get({4, 4}, rewriter.getI32Type(), {true, true});
386     auto nxnxv2i64 =
387         VectorType::get({2, 2}, rewriter.getI64Type(), {true, true});
388 
389     // Supported input types.
390     // Note: this is before packing so these have 1/4 the number of elements
391     // of the input vector types of the 4-way operations.
392     auto nxv4i8 = VectorType::get({4}, rewriter.getI8Type(), true);
393     auto nxv2i16 = VectorType::get({2}, rewriter.getI16Type(), true);
394 
395     auto failedToMatch = [&](VectorType resultType, VectorType inputType,
396                              auto lhsExtendOp, auto rhsExtendOp) {
397       using LhsExtendOpTy = decltype(lhsExtendOp);
398       using RhsExtendOpTy = decltype(rhsExtendOp);
399       for (auto op : ops) {
400         if (failed(isCompatible<LhsExtendOpTy, RhsExtendOpTy>(
401                 rewriter, op, resultType, inputType)))
402           return true;
403       }
404       return false;
405     };
406 
407     if (failedToMatch(nxnxv4i32, nxv4i8, arith::ExtSIOp{}, arith::ExtSIOp{}) &&
408         failedToMatch(nxnxv4i32, nxv4i8, arith::ExtUIOp{}, arith::ExtUIOp{}) &&
409         failedToMatch(nxnxv4i32, nxv4i8, arith::ExtSIOp{}, arith::ExtUIOp{}) &&
410         failedToMatch(nxnxv4i32, nxv4i8, arith::ExtUIOp{}, arith::ExtSIOp{}) &&
411         failedToMatch(nxnxv2i64, nxv2i16, arith::ExtSIOp{}, arith::ExtSIOp{}) &&
412         failedToMatch(nxnxv2i64, nxv2i16, arith::ExtUIOp{}, arith::ExtUIOp{}) &&
413         failedToMatch(nxnxv2i64, nxv2i16, arith::ExtSIOp{}, arith::ExtUIOp{}) &&
414         failedToMatch(nxnxv2i64, nxv2i16, arith::ExtUIOp{}, arith::ExtSIOp{}))
415       return failure();
416 
417     return success();
418   }
419 };
420 
421 // Rewrites: vector.extract(arith.extend) -> arith.extend(vector.extract).
422 //
423 // This transforms IR like:
424 //   %0 = arith.extsi %src : vector<4x[8]xi8> to vector<4x[8]xi32>
425 //   %1 = vector.extract %0[0] : vector<[8]xi32> from vector<4x[8]xi32>
426 // Into:
427 //   %0 = vector.extract %src[0] : vector<[8]xi8> from vector<4x[8]xi8>
428 //   %1 = arith.extsi %0 : vector<[8]xi8> to vector<[8]xi32>
429 //
430 // This enables outer product fusion in the `-arm-sme-outer-product-fusion`
431 // pass when the result is the input to an outer product.
432 struct SwapVectorExtractOfArithExtend
433     : public OpRewritePattern<vector::ExtractOp> {
434   using OpRewritePattern::OpRewritePattern;
435 
436   LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
437                                 PatternRewriter &rewriter) const override {
438     VectorType resultType = llvm::dyn_cast<VectorType>(extractOp.getType());
439     if (!resultType)
440       return rewriter.notifyMatchFailure(extractOp,
441                                          "extracted type is not a vector type");
442 
443     auto numScalableDims = resultType.getNumScalableDims();
444     if (numScalableDims != 1)
445       return rewriter.notifyMatchFailure(
446           extractOp, "extracted type is not a 1-D scalable vector type");
447 
448     auto *extendOp = extractOp.getVector().getDefiningOp();
449     if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
450             extendOp))
451       return rewriter.notifyMatchFailure(extractOp,
452                                          "extract not from extend op");
453 
454     auto loc = extractOp.getLoc();
455     StringAttr extendOpName = extendOp->getName().getIdentifier();
456     Value extendSource = extendOp->getOperand(0);
457 
458     // Create new extract from source of extend.
459     Value newExtract = rewriter.create<vector::ExtractOp>(
460         loc, extendSource, extractOp.getMixedPosition());
461 
462     // Extend new extract to original result type.
463     Operation *newExtend =
464         rewriter.create(loc, extendOpName, Value(newExtract), resultType);
465 
466     rewriter.replaceOp(extractOp, newExtend);
467 
468     return success();
469   }
470 };
471 
472 // Same as above, but for vector.scalable.extract.
473 //
474 // This transforms IR like:
475 //   %0 = arith.extsi %src : vector<[8]xi8> to vector<[8]xi32>
476 //   %1 = vector.scalable.extract %0[0] : vector<[4]xi32> from vector<[8]xi32>
477 // Into:
478 //   %0 = vector.scalable.extract %src[0] : vector<[4]xi8> from vector<[8]xi8>
479 //   %1 = arith.extsi %0 : vector<[4]xi8> to vector<[4]xi32>
480 //
481 // This enables outer product fusion in the `-arm-sme-outer-product-fusion`
482 // pass when the result is the input to an outer product.
483 struct SwapVectorScalableExtractOfArithExtend
484     : public OpRewritePattern<vector::ScalableExtractOp> {
485   using OpRewritePattern::OpRewritePattern;
486 
487   LogicalResult matchAndRewrite(vector::ScalableExtractOp extractOp,
488                                 PatternRewriter &rewriter) const override {
489     auto *extendOp = extractOp.getSource().getDefiningOp();
490     if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
491             extendOp))
492       return rewriter.notifyMatchFailure(extractOp,
493                                          "extract not from extend op");
494 
495     auto loc = extractOp.getLoc();
496     VectorType resultType = extractOp.getResultVectorType();
497 
498     Value extendSource = extendOp->getOperand(0);
499     StringAttr extendOpName = extendOp->getName().getIdentifier();
500     VectorType extendSourceVectorType =
501         cast<VectorType>(extendSource.getType());
502 
503     // Create new extract from source of extend.
504     VectorType extractResultVectorType =
505         resultType.clone(extendSourceVectorType.getElementType());
506     Value newExtract = rewriter.create<vector::ScalableExtractOp>(
507         loc, extractResultVectorType, extendSource, extractOp.getPos());
508 
509     // Extend new extract to original result type.
510     Operation *newExtend =
511         rewriter.create(loc, extendOpName, Value(newExtract), resultType);
512 
513     rewriter.replaceOp(extractOp, newExtend);
514 
515     return success();
516   }
517 };
518 
519 struct OuterProductFusionPass
520     : public arm_sme::impl::OuterProductFusionBase<OuterProductFusionPass> {
521 
522   void runOnOperation() override {
523     RewritePatternSet patterns(&getContext());
524     populateOuterProductFusionPatterns(patterns);
525 
526     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
527       signalPassFailure();
528   }
529 };
530 
531 } // namespace
532 
533 void mlir::arm_sme::populateOuterProductFusionPatterns(
534     RewritePatternSet &patterns) {
535   MLIRContext *context = patterns.getContext();
536   // Note: High benefit to ensure extract(extend) are swapped first.
537   patterns.add<SwapVectorExtractOfArithExtend,
538                SwapVectorScalableExtractOfArithExtend>(context, 1024);
539   patterns.add<OuterProductFusion2Way, OuterProductFusion4Way>(context);
540 }
541 
542 std::unique_ptr<Pass> mlir::arm_sme::createOuterProductFusionPass() {
543   return std::make_unique<OuterProductFusionPass>();
544 }
545