1 //===- OuterProductFusion.cpp - Fuse 'arm_sme.outerproduct' ops -----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements rewrites that fuse 'arm_sme.outerproduct' operations 10 // into the 2-way or 4-way widening outerproduct operations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/ArmSME/IR/ArmSME.h" 15 #include "mlir/Dialect/ArmSME/Transforms/Passes.h" 16 #include "mlir/Dialect/ArmSME/Transforms/Transforms.h" 17 #include "mlir/Dialect/Func/IR/FuncOps.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 #include "llvm/ADT/TypeSwitch.h" 21 22 #define DEBUG_TYPE "arm-sme-outerproduct-fusion" 23 24 namespace mlir::arm_sme { 25 #define GEN_PASS_DEF_OUTERPRODUCTFUSION 26 #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc" 27 } // namespace mlir::arm_sme 28 29 using namespace mlir; 30 using namespace mlir::arm_sme; 31 32 namespace { 33 34 // Common match failure reasons. 35 static constexpr StringLiteral 36 kMatchFailureNoAccumulator("no accumulator operand"); 37 static constexpr StringLiteral kMatchFailureExpectedOuterProductDefOp( 38 "defining op of accumulator must be 'arm_sme.outerproduct'"); 39 static constexpr StringLiteral kMatchFailureInconsistentCombiningKind( 40 "combining kind (add or sub) of outer products must match"); 41 static constexpr StringLiteral kMatchFailureInconsistentMasking( 42 "unsupported masking, either both outerproducts are masked " 43 "or neither"); 44 static constexpr StringLiteral kMatchFailureOuterProductNotSingleUse( 45 "outer product(s) not single use and cannot be removed, no benefit to " 46 "fusing"); 47 48 // An outer product is compatible if all of the following are true: 49 // - the result type matches `resultType`. 50 // - the defining operation of LHS is of the type `LhsExtOp`. 51 // - the defining operation of RHS is of the type `RhsExtOp`. 52 // - the input types of the defining operations are identical and match 53 // `inputType`. 54 template <typename LhsExtOp, typename RhsExtOp = LhsExtOp> 55 static LogicalResult isCompatible(PatternRewriter &rewriter, 56 arm_sme::OuterProductOp op, 57 VectorType resultType, VectorType inputType) { 58 if (op.getResultType() != resultType) 59 return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) { 60 diag << "unsupported result type, expected " << resultType; 61 }); 62 63 auto lhsDefOp = op.getLhs().getDefiningOp<LhsExtOp>(); 64 auto rhsDefOp = op.getRhs().getDefiningOp<RhsExtOp>(); 65 66 if (!lhsDefOp || !rhsDefOp) 67 return rewriter.notifyMatchFailure( 68 op, "defining op of outerproduct operands must be one of: " 69 "'arith.extf' or 'arith.extsi' or 'arith.extui'"); 70 71 auto lhsInType = cast<VectorType>(lhsDefOp.getIn().getType()); 72 auto rhsInType = cast<VectorType>(rhsDefOp.getIn().getType()); 73 74 if (lhsInType != inputType || rhsInType != inputType) 75 return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) { 76 diag << "unsupported input type, expected " << inputType; 77 }); 78 79 return success(); 80 } 81 82 // Fuse two 'arm_sme.outerproduct' operations that are chained via the 83 // accumulator into 2-way outer product operation. 84 // 85 // For example: 86 // 87 // %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32> 88 // %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32> 89 // %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, 90 // vector<[4]xf32> 91 // 92 // %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32> 93 // %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32> 94 // %1 = arm_sme.outerproduct %a1_ext, %b1_ext, %0 : vector<[4]xf32>, 95 // vector<[4]xf32> 96 // 97 // Becomes: 98 // 99 // %a_packed = vector.interleave %a0, %a1 : vector<[4]xf16> -> vector<[8]xf16> 100 // %b_packed = vector.interleave %b0, %b1 : vector<[4]xf16> -> vector<[8]xf16> 101 // %0 = arm_sme.fmopa_2way %a_packed, %b_packed 102 // : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32> 103 class OuterProductFusion2Way 104 : public OpRewritePattern<arm_sme::OuterProductOp> { 105 public: 106 using OpRewritePattern::OpRewritePattern; 107 108 LogicalResult matchAndRewrite(arm_sme::OuterProductOp op, 109 PatternRewriter &rewriter) const override { 110 Value acc = op.getAcc(); 111 if (!acc) 112 return rewriter.notifyMatchFailure(op, kMatchFailureNoAccumulator); 113 114 arm_sme::OuterProductOp op1 = acc.getDefiningOp<arm_sme::OuterProductOp>(); 115 arm_sme::OuterProductOp op2 = op; 116 if (!op1) 117 return rewriter.notifyMatchFailure( 118 op, kMatchFailureExpectedOuterProductDefOp); 119 120 if (op1.getKind() != op2.getKind()) 121 return rewriter.notifyMatchFailure( 122 op, kMatchFailureInconsistentCombiningKind); 123 124 if (!op1->hasOneUse()) { 125 // If the first outer product has uses other than as the input to another 126 // outer product, it can't be erased after fusion. 127 return rewriter.notifyMatchFailure(op, 128 kMatchFailureOuterProductNotSingleUse); 129 } 130 131 if (bool(op1.getLhsMask()) != bool(op2.getLhsMask())) 132 return rewriter.notifyMatchFailure(op, kMatchFailureInconsistentMasking); 133 134 if (failed(canFuseOuterProducts(rewriter, op1, op2))) 135 return failure(); 136 137 auto loc = op.getLoc(); 138 auto packInputs = [&](Value lhs, Value rhs) { 139 return rewriter.create<vector::InterleaveOp>(loc, lhs, rhs); 140 }; 141 142 auto lhs = packInputs(op1.getLhs().getDefiningOp()->getOperand(0), 143 op2.getLhs().getDefiningOp()->getOperand(0)); 144 auto rhs = packInputs(op1.getRhs().getDefiningOp()->getOperand(0), 145 op2.getRhs().getDefiningOp()->getOperand(0)); 146 147 Value lhsMask, rhsMask; 148 if (op1.getLhsMask() || op2.getLhsMask()) { 149 lhsMask = packInputs(op1.getLhsMask(), op2.getLhsMask()); 150 rhsMask = packInputs(op1.getRhsMask(), op2.getRhsMask()); 151 } 152 153 auto extOp = op.getLhs().getDefiningOp(); 154 155 arm_sme::CombiningKind kind = op.getKind(); 156 if (kind == arm_sme::CombiningKind::Add) { 157 TypeSwitch<Operation *>(extOp) 158 .Case<arith::ExtFOp>([&](auto) { 159 rewriter.replaceOpWithNewOp<arm_sme::FMopa2WayOp>( 160 op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, 161 op1.getAcc()); 162 }) 163 .Case<arith::ExtSIOp>([&](auto) { 164 rewriter.replaceOpWithNewOp<arm_sme::SMopa2WayOp>( 165 op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, 166 op1.getAcc()); 167 }) 168 .Case<arith::ExtUIOp>([&](auto) { 169 rewriter.replaceOpWithNewOp<arm_sme::UMopa2WayOp>( 170 op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, 171 op1.getAcc()); 172 }) 173 .Default([&](auto) { llvm_unreachable("unexpected extend op!"); }); 174 } else if (kind == arm_sme::CombiningKind::Sub) { 175 TypeSwitch<Operation *>(extOp) 176 .Case<arith::ExtFOp>([&](auto) { 177 rewriter.replaceOpWithNewOp<arm_sme::FMops2WayOp>( 178 op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, 179 op1.getAcc()); 180 }) 181 .Case<arith::ExtSIOp>([&](auto) { 182 rewriter.replaceOpWithNewOp<arm_sme::SMops2WayOp>( 183 op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, 184 op1.getAcc()); 185 }) 186 .Case<arith::ExtUIOp>([&](auto) { 187 rewriter.replaceOpWithNewOp<arm_sme::UMops2WayOp>( 188 op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, 189 op1.getAcc()); 190 }) 191 .Default([&](auto) { llvm_unreachable("unexpected extend op!"); }); 192 } else { 193 llvm_unreachable("unexpected arm_sme::CombiningKind!"); 194 } 195 196 return success(); 197 } 198 199 private: 200 // A pair of outer product can be fused if all of the following are true: 201 // - input and result types match. 202 // - the defining operations of the inputs are identical extensions, 203 // specifically either: 204 // - a signed or unsigned extension for integer types. 205 // - a floating-point extension for floating-point types. 206 // - the types and extension are supported, i.e. there's a 2-way operation 207 // they can be fused into. 208 LogicalResult canFuseOuterProducts(PatternRewriter &rewriter, 209 arm_sme::OuterProductOp op1, 210 arm_sme::OuterProductOp op2) const { 211 // Supported result types. 212 auto nxnxv4i32 = 213 VectorType::get({4, 4}, rewriter.getI32Type(), {true, true}); 214 auto nxnxv4f32 = 215 VectorType::get({4, 4}, rewriter.getF32Type(), {true, true}); 216 // Supported input types. 217 // Note: this is before packing so these have half the number of elements 218 // of the input vector types of the 2-way operations. 219 auto nxv4i16 = VectorType::get({4}, rewriter.getI16Type(), true); 220 auto nxv4f16 = VectorType::get({4}, rewriter.getF16Type(), true); 221 auto nxv4bf16 = VectorType::get({4}, rewriter.getBF16Type(), true); 222 if ((failed( 223 isCompatible<arith::ExtFOp>(rewriter, op1, nxnxv4f32, nxv4f16)) || 224 failed( 225 isCompatible<arith::ExtFOp>(rewriter, op2, nxnxv4f32, nxv4f16))) && 226 (failed( 227 isCompatible<arith::ExtFOp>(rewriter, op1, nxnxv4f32, nxv4bf16)) || 228 failed(isCompatible<arith::ExtFOp>(rewriter, op2, nxnxv4f32, 229 nxv4bf16))) && 230 (failed( 231 isCompatible<arith::ExtSIOp>(rewriter, op1, nxnxv4i32, nxv4i16)) || 232 failed(isCompatible<arith::ExtSIOp>(rewriter, op2, nxnxv4i32, 233 nxv4i16))) && 234 (failed( 235 isCompatible<arith::ExtUIOp>(rewriter, op1, nxnxv4i32, nxv4i16)) || 236 failed( 237 isCompatible<arith::ExtUIOp>(rewriter, op2, nxnxv4i32, nxv4i16)))) 238 return failure(); 239 240 return success(); 241 } 242 }; 243 244 // Fuse four 'arm_sme.outerproduct' operations that are chained via the 245 // accumulator into 4-way outer product operation. 246 class OuterProductFusion4Way 247 : public OpRewritePattern<arm_sme::OuterProductOp> { 248 public: 249 using OpRewritePattern::OpRewritePattern; 250 251 LogicalResult matchAndRewrite(arm_sme::OuterProductOp op, 252 PatternRewriter &rewriter) const override { 253 SmallVector<arm_sme::OuterProductOp, 4> outerProductChain; 254 outerProductChain.push_back(op); 255 256 for (int i = 0; i < 3; ++i) { 257 auto currentOp = outerProductChain.back(); 258 auto acc = currentOp.getAcc(); 259 if (!acc) 260 return rewriter.notifyMatchFailure(op, kMatchFailureNoAccumulator); 261 auto previousOp = acc.getDefiningOp<arm_sme::OuterProductOp>(); 262 if (!previousOp) 263 return rewriter.notifyMatchFailure( 264 op, kMatchFailureExpectedOuterProductDefOp); 265 if (!previousOp->hasOneUse()) 266 return rewriter.notifyMatchFailure( 267 op, kMatchFailureOuterProductNotSingleUse); 268 if (previousOp.getKind() != currentOp.getKind()) 269 return rewriter.notifyMatchFailure( 270 op, kMatchFailureInconsistentCombiningKind); 271 if (bool(previousOp.getLhsMask()) != bool(currentOp.getLhsMask())) 272 return rewriter.notifyMatchFailure( 273 op, kMatchFailureInconsistentCombiningKind); 274 outerProductChain.push_back(previousOp); 275 } 276 277 if (failed(canFuseOuterProducts(rewriter, outerProductChain))) 278 return failure(); 279 280 arm_sme::OuterProductOp op1 = outerProductChain[3]; 281 arm_sme::OuterProductOp op2 = outerProductChain[2]; 282 arm_sme::OuterProductOp op3 = outerProductChain[1]; 283 arm_sme::OuterProductOp op4 = outerProductChain[0]; 284 285 auto loc = op.getLoc(); 286 auto packInputs = [&](Value lhs, Value rhs) { 287 return rewriter.create<vector::InterleaveOp>(loc, lhs, rhs); 288 }; 289 290 auto lhs0 = packInputs(op1.getLhs().getDefiningOp()->getOperand(0), 291 op3.getLhs().getDefiningOp()->getOperand(0)); 292 auto lhs1 = packInputs(op2.getLhs().getDefiningOp()->getOperand(0), 293 op4.getLhs().getDefiningOp()->getOperand(0)); 294 auto lhs = packInputs(lhs0, lhs1); 295 296 auto rhs0 = packInputs(op1.getRhs().getDefiningOp()->getOperand(0), 297 op3.getRhs().getDefiningOp()->getOperand(0)); 298 auto rhs1 = packInputs(op2.getRhs().getDefiningOp()->getOperand(0), 299 op4.getRhs().getDefiningOp()->getOperand(0)); 300 auto rhs = packInputs(rhs0, rhs1); 301 302 Value lhsMask, rhsMask; 303 if (op1.getLhsMask() || op2.getLhsMask() || op3.getLhsMask() || 304 op4.getLhsMask()) { 305 auto lhs0Mask = packInputs(op1.getLhsMask(), op3.getLhsMask()); 306 auto lhs1Mask = packInputs(op2.getLhsMask(), op4.getLhsMask()); 307 lhsMask = packInputs(lhs0Mask, lhs1Mask); 308 309 auto rhs0Mask = packInputs(op1.getRhsMask(), op3.getRhsMask()); 310 auto rhs1Mask = packInputs(op2.getRhsMask(), op4.getRhsMask()); 311 rhsMask = packInputs(rhs0Mask, rhs1Mask); 312 } 313 314 auto lhsExtOp = op.getLhs().getDefiningOp(); 315 auto rhsExtOp = op.getRhs().getDefiningOp(); 316 317 arm_sme::CombiningKind kind = op.getKind(); 318 if (kind == arm_sme::CombiningKind::Add) { 319 if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp)) { 320 // signed 321 rewriter.replaceOpWithNewOp<arm_sme::SMopa4WayOp>( 322 op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc()); 323 } else if (isa<arith::ExtUIOp>(lhsExtOp) && 324 isa<arith::ExtUIOp>(rhsExtOp)) { 325 // unsigned 326 rewriter.replaceOpWithNewOp<arm_sme::UMopa4WayOp>( 327 op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc()); 328 } else if (isa<arith::ExtSIOp>(lhsExtOp) && 329 isa<arith::ExtUIOp>(rhsExtOp)) { 330 // signed by unsigned 331 rewriter.replaceOpWithNewOp<arm_sme::SuMopa4WayOp>( 332 op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc()); 333 } else if (isa<arith::ExtUIOp>(lhsExtOp) && 334 isa<arith::ExtSIOp>(rhsExtOp)) { 335 // unsigned by signed 336 rewriter.replaceOpWithNewOp<arm_sme::UsMopa4WayOp>( 337 op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc()); 338 } else { 339 llvm_unreachable("unexpected extend op!"); 340 } 341 } else if (kind == arm_sme::CombiningKind::Sub) { 342 if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp)) { 343 // signed 344 rewriter.replaceOpWithNewOp<arm_sme::SMops4WayOp>( 345 op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc()); 346 } else if (isa<arith::ExtUIOp>(lhsExtOp) && 347 isa<arith::ExtUIOp>(rhsExtOp)) { 348 // unsigned 349 rewriter.replaceOpWithNewOp<arm_sme::UMops4WayOp>( 350 op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc()); 351 } else if (isa<arith::ExtSIOp>(lhsExtOp) && 352 isa<arith::ExtUIOp>(rhsExtOp)) { 353 // signed by unsigned 354 rewriter.replaceOpWithNewOp<arm_sme::SuMops4WayOp>( 355 op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc()); 356 } else if (isa<arith::ExtUIOp>(lhsExtOp) && 357 isa<arith::ExtSIOp>(rhsExtOp)) { 358 // unsigned by signed 359 rewriter.replaceOpWithNewOp<arm_sme::UsMops4WayOp>( 360 op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc()); 361 } else { 362 llvm_unreachable("unexpected extend op!"); 363 } 364 } else { 365 llvm_unreachable("unexpected arm_sme::CombiningKind!"); 366 } 367 368 return success(); 369 } 370 371 private: 372 // Four outer products can be fused if all of the following are true: 373 // - input and result types match. 374 // - the defining operations of the inputs are identical extensions, 375 // specifically either: 376 // - a signed or unsigned extension for integer types. 377 // - a floating-point extension for floating-point types. 378 // - the types and extension are supported, i.e. there's a 4-way operation 379 // they can be fused into. 380 LogicalResult 381 canFuseOuterProducts(PatternRewriter &rewriter, 382 ArrayRef<arm_sme::OuterProductOp> ops) const { 383 // Supported result types. 384 auto nxnxv4i32 = 385 VectorType::get({4, 4}, rewriter.getI32Type(), {true, true}); 386 auto nxnxv2i64 = 387 VectorType::get({2, 2}, rewriter.getI64Type(), {true, true}); 388 389 // Supported input types. 390 // Note: this is before packing so these have 1/4 the number of elements 391 // of the input vector types of the 4-way operations. 392 auto nxv4i8 = VectorType::get({4}, rewriter.getI8Type(), true); 393 auto nxv2i16 = VectorType::get({2}, rewriter.getI16Type(), true); 394 395 auto failedToMatch = [&](VectorType resultType, VectorType inputType, 396 auto lhsExtendOp, auto rhsExtendOp) { 397 using LhsExtendOpTy = decltype(lhsExtendOp); 398 using RhsExtendOpTy = decltype(rhsExtendOp); 399 for (auto op : ops) { 400 if (failed(isCompatible<LhsExtendOpTy, RhsExtendOpTy>( 401 rewriter, op, resultType, inputType))) 402 return true; 403 } 404 return false; 405 }; 406 407 if (failedToMatch(nxnxv4i32, nxv4i8, arith::ExtSIOp{}, arith::ExtSIOp{}) && 408 failedToMatch(nxnxv4i32, nxv4i8, arith::ExtUIOp{}, arith::ExtUIOp{}) && 409 failedToMatch(nxnxv4i32, nxv4i8, arith::ExtSIOp{}, arith::ExtUIOp{}) && 410 failedToMatch(nxnxv4i32, nxv4i8, arith::ExtUIOp{}, arith::ExtSIOp{}) && 411 failedToMatch(nxnxv2i64, nxv2i16, arith::ExtSIOp{}, arith::ExtSIOp{}) && 412 failedToMatch(nxnxv2i64, nxv2i16, arith::ExtUIOp{}, arith::ExtUIOp{}) && 413 failedToMatch(nxnxv2i64, nxv2i16, arith::ExtSIOp{}, arith::ExtUIOp{}) && 414 failedToMatch(nxnxv2i64, nxv2i16, arith::ExtUIOp{}, arith::ExtSIOp{})) 415 return failure(); 416 417 return success(); 418 } 419 }; 420 421 // Rewrites: vector.extract(arith.extend) -> arith.extend(vector.extract). 422 // 423 // This transforms IR like: 424 // %0 = arith.extsi %src : vector<4x[8]xi8> to vector<4x[8]xi32> 425 // %1 = vector.extract %0[0] : vector<[8]xi32> from vector<4x[8]xi32> 426 // Into: 427 // %0 = vector.extract %src[0] : vector<[8]xi8> from vector<4x[8]xi8> 428 // %1 = arith.extsi %0 : vector<[8]xi8> to vector<[8]xi32> 429 // 430 // This enables outer product fusion in the `-arm-sme-outer-product-fusion` 431 // pass when the result is the input to an outer product. 432 struct SwapVectorExtractOfArithExtend 433 : public OpRewritePattern<vector::ExtractOp> { 434 using OpRewritePattern::OpRewritePattern; 435 436 LogicalResult matchAndRewrite(vector::ExtractOp extractOp, 437 PatternRewriter &rewriter) const override { 438 VectorType resultType = llvm::dyn_cast<VectorType>(extractOp.getType()); 439 if (!resultType) 440 return rewriter.notifyMatchFailure(extractOp, 441 "extracted type is not a vector type"); 442 443 auto numScalableDims = resultType.getNumScalableDims(); 444 if (numScalableDims != 1) 445 return rewriter.notifyMatchFailure( 446 extractOp, "extracted type is not a 1-D scalable vector type"); 447 448 auto *extendOp = extractOp.getVector().getDefiningOp(); 449 if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>( 450 extendOp)) 451 return rewriter.notifyMatchFailure(extractOp, 452 "extract not from extend op"); 453 454 auto loc = extractOp.getLoc(); 455 StringAttr extendOpName = extendOp->getName().getIdentifier(); 456 Value extendSource = extendOp->getOperand(0); 457 458 // Create new extract from source of extend. 459 Value newExtract = rewriter.create<vector::ExtractOp>( 460 loc, extendSource, extractOp.getMixedPosition()); 461 462 // Extend new extract to original result type. 463 Operation *newExtend = 464 rewriter.create(loc, extendOpName, Value(newExtract), resultType); 465 466 rewriter.replaceOp(extractOp, newExtend); 467 468 return success(); 469 } 470 }; 471 472 // Same as above, but for vector.scalable.extract. 473 // 474 // This transforms IR like: 475 // %0 = arith.extsi %src : vector<[8]xi8> to vector<[8]xi32> 476 // %1 = vector.scalable.extract %0[0] : vector<[4]xi32> from vector<[8]xi32> 477 // Into: 478 // %0 = vector.scalable.extract %src[0] : vector<[4]xi8> from vector<[8]xi8> 479 // %1 = arith.extsi %0 : vector<[4]xi8> to vector<[4]xi32> 480 // 481 // This enables outer product fusion in the `-arm-sme-outer-product-fusion` 482 // pass when the result is the input to an outer product. 483 struct SwapVectorScalableExtractOfArithExtend 484 : public OpRewritePattern<vector::ScalableExtractOp> { 485 using OpRewritePattern::OpRewritePattern; 486 487 LogicalResult matchAndRewrite(vector::ScalableExtractOp extractOp, 488 PatternRewriter &rewriter) const override { 489 auto *extendOp = extractOp.getSource().getDefiningOp(); 490 if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>( 491 extendOp)) 492 return rewriter.notifyMatchFailure(extractOp, 493 "extract not from extend op"); 494 495 auto loc = extractOp.getLoc(); 496 VectorType resultType = extractOp.getResultVectorType(); 497 498 Value extendSource = extendOp->getOperand(0); 499 StringAttr extendOpName = extendOp->getName().getIdentifier(); 500 VectorType extendSourceVectorType = 501 cast<VectorType>(extendSource.getType()); 502 503 // Create new extract from source of extend. 504 VectorType extractResultVectorType = 505 resultType.clone(extendSourceVectorType.getElementType()); 506 Value newExtract = rewriter.create<vector::ScalableExtractOp>( 507 loc, extractResultVectorType, extendSource, extractOp.getPos()); 508 509 // Extend new extract to original result type. 510 Operation *newExtend = 511 rewriter.create(loc, extendOpName, Value(newExtract), resultType); 512 513 rewriter.replaceOp(extractOp, newExtend); 514 515 return success(); 516 } 517 }; 518 519 struct OuterProductFusionPass 520 : public arm_sme::impl::OuterProductFusionBase<OuterProductFusionPass> { 521 522 void runOnOperation() override { 523 RewritePatternSet patterns(&getContext()); 524 populateOuterProductFusionPatterns(patterns); 525 526 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) 527 signalPassFailure(); 528 } 529 }; 530 531 } // namespace 532 533 void mlir::arm_sme::populateOuterProductFusionPatterns( 534 RewritePatternSet &patterns) { 535 MLIRContext *context = patterns.getContext(); 536 // Note: High benefit to ensure extract(extend) are swapped first. 537 patterns.add<SwapVectorExtractOfArithExtend, 538 SwapVectorScalableExtractOfArithExtend>(context, 1024); 539 patterns.add<OuterProductFusion2Way, OuterProductFusion4Way>(context); 540 } 541 542 std::unique_ptr<Pass> mlir::arm_sme::createOuterProductFusionPass() { 543 return std::make_unique<OuterProductFusionPass>(); 544 } 545