//===- LowerVectorContract.cpp - Lower 'vector.contract' operation --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.contract' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Support/LogicalResult.h"

#define DEBUG_TYPE "vector-contract-lowering"

using namespace mlir;
using namespace mlir::vector;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

// Helper to find the position at which a given dimension index appears in the
// results of an affine map.
static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      return i;
  }
  return std::nullopt;
}

// Helper to construct iterator types with one index removed.
static SmallVector<Attribute> adjustIter(ArrayAttr iteratorTypes,
                                         int64_t index) {
  SmallVector<Attribute> results;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    int64_t idx = it.index();
    if (idx == index)
      continue;
    results.push_back(it.value());
  }
  return results;
}

// Helper to construct an affine map with one index removed.
static AffineMap adjustMap(AffineMap map, int64_t index,
                           PatternRewriter &rewriter) {
  auto *ctx = rewriter.getContext();
  SmallVector<AffineExpr> results;
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      continue;
    // Re-insert remaining indices, but renamed when occurring
    // after the removed index.
    auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
    results.push_back(targetExpr);
  }
  return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
}

// Helper method to possibly drop a dimension in a load.
// TODO
static Value reshapeLoad(Location loc, Value val, VectorType type,
                         int64_t index, int64_t pos,
                         PatternRewriter &rewriter) {
  if (index == -1)
    return val;

  // At extraction dimension?
  if (index == 0)
    return rewriter.create<vector::ExtractOp>(loc, val, pos);

  // Unroll leading dimensions.
  VectorType vType = VectorType::Builder(type).dropDim(0);
  VectorType resType = VectorType::Builder(type).dropDim(index);
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resType, rewriter.getZeroAttr(resType));
  for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
    Value ext = rewriter.create<vector::ExtractOp>(loc, val, d);
    Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, load, result, d);
  }
  return result;
}
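// For intuition, a hypothetical example (shapes are illustrative only): given
// `val : vector<2x3x4xf32>`, `index = 1` and `pos = p`, reshapeLoad extracts
// each of the two leading slices, extracts element `p` from the former middle
// dimension of each, and reassembles the results into a `vector<2x4xf32>`.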
// Helper method to possibly drop a dimension in a store.
// TODO
static Value reshapeStore(Location loc, Value val, Value result,
                          VectorType type, int64_t index, int64_t pos,
                          PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;
  // At insertion dimension?
  if (index == 0)
    return rewriter.create<vector::InsertOp>(loc, val, result, pos);

  // Unroll leading dimensions.
  VectorType vType = VectorType::Builder(type).dropDim(0);
  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
    Value ext = rewriter.create<vector::ExtractOp>(loc, result, d);
    Value ins = rewriter.create<vector::ExtractOp>(loc, val, d);
    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, sto, result, d);
  }
  return result;
}

/// Helper to create arithmetic operation associated with a kind of contraction.
static std::optional<Value>
createContractArithOp(Location loc, Value x, Value y, Value acc,
                      vector::CombiningKind kind, PatternRewriter &rewriter,
                      bool isInt, Value mask = Value()) {
  using vector::CombiningKind;
  Value mul;

  if (isInt) {
    if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF ||
        kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
      // Only valid for floating point types.
      return std::nullopt;
    mul = rewriter.create<arith::MulIOp>(loc, x, y);
  } else {
    // Float case.
    if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
        kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
        kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
        kind == CombiningKind::XOR)
      // Only valid for integer types.
      return std::nullopt;
    // Special case for fused multiply-add.
    if (acc && isa<VectorType>(acc.getType()) && kind == CombiningKind::ADD) {
      Value fma = rewriter.create<vector::FMAOp>(loc, x, y, acc);
      if (mask)
        // The fma op doesn't need explicit masking. However, fma ops used in
        // reductions must preserve previous 'acc' values for masked-out lanes.
        fma = selectPassthru(rewriter, mask, fma, acc);
      return fma;
    }
    mul = rewriter.create<arith::MulFOp>(loc, x, y);
  }

  if (!acc)
    return std::optional<Value>(mul);

  return makeArithReduction(rewriter, loc, kind, mul, acc, mask);
}
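// For example (illustrative): for two `vector<8xf32>` operands combined with
// `CombiningKind::ADD` into a vector accumulator, createContractArithOp emits
// a single `vector.fma`; when a mask is present, masked-off lanes are restored
// to the previous accumulator value through `selectPassthru`, preserving the
// reduction semantics.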
/// Return the positions of the reductions in the given map.
static SmallVector<int64_t> getReductionIndex(AffineMap map,
                                              ArrayAttr iteratorTypes) {
  SmallVector<int64_t> dimsIdx;
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    if (isReductionIterator(iteratorTypes[map.getDimPosition(i)]))
      dimsIdx.push_back(i);
  }
  return dimsIdx;
}

/// Look for a given dimension in an affine map and return its position. Return
/// std::nullopt if the dimension is not in the map results.
static std::optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    if (map.getDimPosition(i) == dim)
      return i;
  }
  return std::nullopt;
}

/// Creates an arith::AddIOp if `isInt` is true, otherwise an arith::AddFOp,
/// using operands `x` and `y`.
static Value createAdd(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::AddIOp>(loc, x, y);
  return rewriter.create<arith::AddFOp>(loc, x, y);
}

/// Creates an arith::MulIOp if `isInt` is true, otherwise an arith::MulFOp,
/// using operands `x` and `y`.
static Value createMul(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::MulIOp>(loc, x, y);
  return rewriter.create<arith::MulFOp>(loc, x, y);
}

namespace {

/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
///    %flattened_a = vector.shape_cast %a
///    %flattened_b = vector.shape_cast %b
///    %flattened_d = vector.matmul %flattened_a, %flattened_b
///    %d = vector.shape_cast %flattened_d
///    %e = add %c, %d
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
///
/// This only kicks in when VectorTransformsOptions is set to Matmul and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToMatmulOpLowering
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToMatmulOpLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  FilterConstraintType filter;
};
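// A concrete sketch of the rewrite above (hypothetical shapes, pseudo-IR):
// with %a : vector<2x4xf32> and %b : vector<4x3xf32>, the pattern emits
//   %fa = vector.shape_cast %a : vector<2x4xf32> to vector<8xf32>
//   %fb = vector.shape_cast %b : vector<4x3xf32> to vector<12xf32>
//   %fd = vector.matmul %fa, %fb  // lhs_rows = 2, lhs_columns = 4,
//                                 // rhs_columns = 3, yields vector<6xf32>
//   %d  = vector.shape_cast %fd : vector<6xf32> to vector<2x3xf32>
// followed by an addition with the accumulator %c.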
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
///    %at = vector.transpose %a, [1, 0]
///    %bRow0 = vector.extract %b[0]
///    %atRow0 = vector.extract %at[0]
///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
///    ...
///    %bRowK = vector.extract %b[K]
///    %atRowK = vector.extract %at[K]
///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToOuterProductOpLowering
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToOuterProductOpLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  FilterConstraintType filter;
};

/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to an output-size-unrolled sequence:
/// ```
///    %out = arith.constant ... : vector<MxNxelt_type>
///    %bt = vector.transpose %b, [1, 0]
///    %aRow0 = vector.extract %a[0]
///    %btRow0 = vector.extract %bt[0]
///    %c00 = vector.reduce %aRow0, %btRow0
///    %out00 = vector.insert %c00, %out[0, 0]
///    ...
///    %aRowLast = vector.extract %a[M-1]
///    %btRowLast = vector.extract %bt[N-1]
///    %cLastLast = vector.reduce %aRowLast, %btRowLast
///    %outLastLast = vector.insert %cLastLast, %out[M-1, N-1]
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to Dot and
/// the vector.contract op is a row-major matmul or matvec.
class ContractionOpToDotLowering
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToDotLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      const FilterConstraintType &constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  FilterConstraintType filter;
};
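// A concrete sketch of the dot-unrolled form (hypothetical shapes, pseudo-IR):
// for %a : vector<2x4xf32> and %b : vector<4x3xf32>, each of the 2x3 result
// elements is computed as
//   %m  = arith.mulf %aRowR, %btRowC : vector<4xf32>
//   %rc = vector.reduction <add>, %m : vector<4xf32> into f32
// and inserted at position [R, C] of a zero-initialized result vector.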
/// Progressive lowering of ContractionOp.
///
/// One:
///   %x = vector.contract with at least one free/batch dimension
/// is replaced by:
///   %a = vector.contract with one less free/batch dimension
///   %b = vector.contract with one less free/batch dimension
///   ..
///   %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
///
/// This only kicks in when either VectorTransformsOptions is set
/// to Dot or when other contraction patterns fail.
class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
                        MLIRContext *context, PatternBenefit benefit = 1,
                        FilterConstraintType constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  FilterConstraintType filter;
  // Lower one parallel dimension.
  FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
                                 vector::ContractionOp op, int64_t lhsIndex,
                                 int64_t rhsIndex, Value mask) const;
  // Lower one reduction dimension.
  FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
                                  vector::ContractionOp op, Value mask) const;
};

/// Generate a vector implementation for matmat, matvec and tmatvec.
/// This unrolls outer-products along the reduction dimension.
struct UnrolledOuterProductGenerator
    : public StructuredGenerator<vector::ContractionOp, vector::IteratorType> {
  UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
      : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(b, op),
        kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
        res(op.getAcc()), lhsType(op.getLhsType()) {
    auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
    if (maskableOp.isMasked())
      mask = maskableOp.getMaskingOp().getMask();
  }

  Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
    if (!v)
      return v;
    return rewriter.create<vector::TransposeOp>(loc, v, perm);
  }

  Value promote(Value v, Type dstElementType) {
    Type elementType = v.getType();
    auto vecType = dyn_cast<VectorType>(elementType);
    if (vecType)
      elementType = vecType.getElementType();
    if (elementType == dstElementType)
      return v;
    Type promotedType = dstElementType;
    if (vecType)
      promotedType = vecType.clone(promotedType);
    if (isa<FloatType>(dstElementType))
      return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
    return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
  }
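  // For example (illustrative): promoting a `vector<4xi8>` operand to an i32
  // result element type emits `arith.extsi` to `vector<4xi32>`, while a
  // `vector<4xf16>` operand promoted to f32 goes through `arith.extf`; values
  // already matching the destination element type are returned unchanged.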
  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
                             VectorType lhsType, int reductionDim,
                             std::optional<Value> maybeMask = std::nullopt) {
    // Unrolling a scalable dimension would be incorrect - bail out.
    if (lhsType.getScalableDims()[reductionDim])
      return failure();

    int reductionSize = lhsType.getDimSize(reductionDim);
    assert(reductionSize > 0 &&
           "Reduction dim must be a known static size to allow unrolling");

    // Incremental support for masking.
    if (mask && !maybeMask.has_value())
      return failure();

    Type resElementType = cast<VectorType>(res.getType()).getElementType();
    for (int64_t k = 0; k < reductionSize; ++k) {
      Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
      Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
      extractA = promote(extractA, resElementType);
      extractB = promote(extractB, resElementType);
      Value extractMask;
      if (maybeMask.has_value() && maybeMask.value())
        extractMask =
            rewriter.create<vector::ExtractOp>(loc, maybeMask.value(), k);

      Operation *outerProdOp = rewriter.create<vector::OuterProductOp>(
          loc, res.getType(), extractA, extractB, res, kind);
      res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0);
    }
    return res;
  }

  /// Two outer parallel, one inner reduction (matmat flavor).
  FailureOr<Value> matmat() {
    if (!iters({Par(), Par(), Red()}))
      return failure();
    // Set up the parallel/reduction structure in the right form.
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);
    Value transposedMask = t(mask, {2, 0, 1});
    // Classical row-major matmul: Just permute the lhs.
    if (layout({{m, k}, {k, n}, {m, n}}))
      return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1,
                       transposedMask);
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {m, n}})) {
      Value tlhs = t(lhs);
      return outerProd(tlhs, t(rhs), res, lhsType, /*reductionDim=*/1,
                       transposedMask);
    }
    // No need to permute anything.
    if (layout({{k, m}, {k, n}, {m, n}}))
      return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0,
                       transposedMask);
    // Just permute the rhs.
    if (layout({{k, m}, {n, k}, {m, n}}))
      return outerProd(lhs, t(rhs), res, lhsType, /*reductionDim=*/0,
                       transposedMask);
    // Transposed output: swap RHS and LHS.
    // Classical row-major matmul: permute the lhs.
    if (layout({{m, k}, {k, n}, {n, m}}))
      return outerProd(rhs, t(lhs), res, lhsType, /*reductionDim=*/1,
                       transposedMask);
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {n, m}})) {
      Value trhs = t(rhs);
      return outerProd(trhs, t(lhs), res, lhsType, /*reductionDim=*/1,
                       transposedMask);
    }
    if (layout({{k, m}, {k, n}, {n, m}}))
      return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0,
                       transposedMask);
    if (layout({{k, m}, {n, k}, {n, m}}))
      return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0,
                       transposedMask);
    return failure();
  }
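  // A concrete sketch of the unrolling done by matmat (hypothetical shapes,
  // pseudo-IR): a row-major matmat with %a : vector<2x3xf32> and
  // %b : vector<3x4xf32> first transposes %a, then emits one outer product
  // per reduction step k:
  //   %at = vector.transpose %a, [1, 0] : vector<3x2xf32>
  //   %a0 = vector.extract %at[0] : vector<2xf32>
  //   %b0 = vector.extract %b[0] : vector<4xf32>
  //   %c0 = vector.outerproduct %a0, %b0, %acc : vector<2x4xf32>
  //   ... (k = 1, 2 chain through %c1 and %c2)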
  //
  // One outer parallel, one inner reduction (matvec flavor).
  // Mask needs to be transposed everywhere to turn the reduction dimension
  // outermost as required by outerproduct.
  //
  FailureOr<Value> matvec() {
    if (!iters({Par(), Red()}))
      return failure();
    AffineExpr m, k;
    bindDims(rewriter.getContext(), m, k);
    Value transposedMask = t(mask);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1,
                       transposedMask);
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0,
                       transposedMask);
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0,
                       transposedMask);
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0,
                       transposedMask);
    return failure();
  }

  //
  // One outer reduction, one inner parallel (tmatvec flavor).
  // Mask already has the shape of the outer product.
  //
  FailureOr<Value> tmatvec() {
    if (!iters({Red(), Par()}))
      return failure();
    AffineExpr k, m;
    bindDims(rewriter.getContext(), k, m);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1, mask);
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0, mask);
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0, mask);
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0, mask);
    return failure();
  }

private:
  vector::CombiningKind kind;
  Value lhs, rhs, res, mask;
  VectorType lhsType;
};

/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
///    %at = vector.transpose %a, [1, 0]
///    %bRow0 = vector.extract %b[0]
///    %atRow0 = vector.extract %at[0]
///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
///    ...
///    %bRowK = vector.extract %b[K]
///    %atRowK = vector.extract %at[K]
///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
/// otherwise supports any layout permutation of the matrix-multiply.
LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
    vector::ContractionOp op, PatternRewriter &rewriter) const {
  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::OuterProduct)
    return failure();

  if (failed(filter(op)))
    return failure();

  // Vector mask setup.
  OpBuilder::InsertionGuard guard(rewriter);
  auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
  Operation *rootOp;
  if (maskableOp.isMasked()) {
    rewriter.setInsertionPoint(maskableOp.getMaskingOp());
    rootOp = maskableOp.getMaskingOp();
  } else {
    rootOp = op;
  }

  UnrolledOuterProductGenerator e(rewriter, op);
  FailureOr<Value> matmatRes = e.matmat();
  if (succeeded(matmatRes)) {
    rewriter.replaceOp(rootOp, *matmatRes);
    return success();
  }
  FailureOr<Value> matvecRes = e.matvec();
  if (succeeded(matvecRes)) {
    rewriter.replaceOp(rootOp, *matvecRes);
    return success();
  }
  FailureOr<Value> tmatvecRes = e.tmatvec();
  if (succeeded(tmatvecRes)) {
    rewriter.replaceOp(rootOp, *tmatvecRes);
    return success();
  }

  return failure();
}
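// Illustrative masked case: a contraction wrapped as
//   %0 = vector.mask %m { vector.contract ... }
// is rewritten by replacing the enclosing vector.mask op (the root), with the
// mask transposed as needed and threaded into every generated
// vector.outerproduct by outerProd above.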
LogicalResult
ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
                                            PatternRewriter &rewriter) const {
  // TODO: Support vector.mask.
  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
  if (maskableOp.isMasked())
    return failure();

  if (failed(filter(op)))
    return failure();

  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Dot)
    return failure();

  auto iteratorTypes = op.getIteratorTypes().getValue();
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  Location loc = op.getLoc();
  Value lhs = op.getLhs(), rhs = op.getRhs();

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(rewriter.getContext(), m, n, k);
  SmallVector<AffineMap> maps = op.getIndexingMapsArray();
  //
  // In the following we wish to make the reduction dimension innermost so we
  // can load vectors and just fmul + reduce into a scalar.
  //
  if (isParallelIterator(iteratorTypes[0]) &&
      isParallelIterator(iteratorTypes[1]) &&
      isReductionIterator(iteratorTypes[2])) {
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      // This is the classical row-major matmul. Just permute the rhs.
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      // No need to permute anything.
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      // Transposed output: swap the operands; the new lhs is the transposed
      // rhs.
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = tmp;
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      Value tmp = rhs;
      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      lhs = tmp;
    } else {
      return failure();
    }
  } else if (isParallelIterator(iteratorTypes[0]) &&
             isReductionIterator(iteratorTypes[1])) {
    //
    // One outer parallel, one inner reduction (matvec flavor).
    //
    if (maps == infer({{m, n}, {n}, {m}})) {
      // No need to permute anything.
    } else if (maps == infer({{n, m}, {n}, {m}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{n}, {m, n}, {m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{n}, {n, m}, {m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else {
      return failure();
    }
  } else {
    return failure();
  }
  VectorType dstType = cast<VectorType>(op.getResultType());
  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
         "Expected dst type of rank 1 or 2");

  unsigned rank = dstType.getRank();
  unsigned dstRows = dstType.getShape()[0];
  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];

  // ExtractOp does not allow dynamic indexing; we must unroll explicitly.
  Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
                                                 rewriter.getZeroAttr(dstType));
  bool isInt = isa<IntegerType>(dstType.getElementType());
  for (unsigned r = 0; r < dstRows; ++r) {
    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
    for (unsigned c = 0; c < dstColumns; ++c) {
      Value b = rank == 1
                    ? rhs
                    : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
      Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
      Value reduced = rewriter.create<vector::ReductionOp>(
          op.getLoc(), vector::CombiningKind::ADD, m);

      SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
                                              : SmallVector<int64_t, 2>{r, c};
      res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
    }
  }
  if (auto acc = op.getAcc())
    res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
  rewriter.replaceOp(op, res);
  return success();
}

/// Lower vector.contract with all size one reduction dimensions to
/// elementwise ops when possible.
struct ContractOpToElementwise
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }
  ContractOpToElementwise(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      const FilterConstraintType &constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    // TODO: Support vector.mask.
    auto maskableOp = cast<MaskableOpInterface>(contractOp.getOperation());
    if (maskableOp.isMasked())
      return failure();

    if (failed(filter(contractOp)))
      return failure();

    if (vectorTransformOptions.vectorContractLowering !=
        vector::VectorContractLowering::ParallelArith)
      return failure();

    ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
    ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
    AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
    AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
    SmallVector<int64_t> lhsReductionDims =
        getReductionIndex(lhsMap, contractOp.getIteratorTypes());
    SmallVector<int64_t> rhsReductionDims =
        getReductionIndex(rhsMap, contractOp.getIteratorTypes());
    // All the reduction dimensions must have size 1.
    for (int64_t dim : lhsReductionDims) {
      if (lhsShape[dim] != 1)
        return failure();
    }
    for (int64_t dim : rhsReductionDims) {
      if (rhsShape[dim] != 1)
        return failure();
    }
    AffineMap accMap = contractOp.getIndexingMapsArray()[2];
    unsigned numParallelDims = accMap.getNumResults();
    unsigned numLhsDimToBroadcast =
        numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
    unsigned numRhsDimToBroadcast =
        numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
    SmallVector<int64_t> lhsDims;
    SmallVector<int64_t> lhsTranspose;
    SmallVector<int64_t> rhsDims;
    SmallVector<int64_t> rhsTranspose;
    for (int64_t dim : lhsReductionDims)
      lhsTranspose.push_back(numLhsDimToBroadcast + dim);
    for (int64_t dim : rhsReductionDims)
      rhsTranspose.push_back(numRhsDimToBroadcast + dim);
    // Loop through the parallel dimensions to calculate the dimensions to
    // broadcast and to permute in order to extract only parallel dimensions.
    for (unsigned i = 0; i < numParallelDims; i++) {
      std::optional<unsigned> lhsDim =
          getDimPosition(lhsMap, accMap.getDimPosition(i));
      if (lhsDim) {
        lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
      } else {
        // If the parallel dimension doesn't exist we will have to broadcast it.
        lhsDims.push_back(
            cast<VectorType>(contractOp.getResultType()).getDimSize(i));
        lhsTranspose.push_back(lhsDims.size() - 1);
      }
      std::optional<unsigned> rhsDim =
          getDimPosition(rhsMap, accMap.getDimPosition(i));
      if (rhsDim) {
        rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
      } else {
        // If the parallel dimension doesn't exist we will have to broadcast it.
        rhsDims.push_back(
            cast<VectorType>(contractOp.getResultType()).getDimSize(i));
        rhsTranspose.push_back(rhsDims.size() - 1);
      }
    }
    Value newLhs = contractOp.getLhs();
    Value newRhs = contractOp.getRhs();
    Location loc = contractOp.getLoc();
    if (!lhsDims.empty()) {
      lhsDims.append(lhsShape.begin(), lhsShape.end());
      auto expandedType =
          VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
      newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs);
    }
    if (!rhsDims.empty()) {
      rhsDims.append(rhsShape.begin(), rhsShape.end());
      auto expandedType =
          VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
      newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs);
    }
    bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
    newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
    newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
    SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0);
    SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0);
    newLhs = rewriter.create<vector::ExtractOp>(loc, newLhs, lhsOffsets);
    newRhs = rewriter.create<vector::ExtractOp>(loc, newRhs, rhsOffsets);
    std::optional<Value> result =
        createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
                              contractOp.getKind(), rewriter, isInt);
    if (!result)
      // The combining kind is incompatible with the operand element type.
      return failure();
    rewriter.replaceOp(contractOp, {*result});
    return success();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  FilterConstraintType filter;
};
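// Illustrative example of ContractOpToElementwise (hypothetical shapes):
// %lhs : vector<4x1xf32> and %rhs : vector<1x4xf32> with maps
// {m, k}, {k, n}, {m, n} carry a unit k dimension; each operand is broadcast
// to rank 3 and transposed so the unit reduction dim leads
// (vector<1x4x4xf32>), that dim is peeled off with vector.extract, and the
// contraction collapses to an elementwise multiply combined with the
// accumulator on vector<4x4xf32>.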
/// Progressive lowering of ContractionOp.
/// One:
///   %x = vector.contract with at least one free/batch dimension
/// is replaced by:
///   %a = vector.contract with one less free/batch dimension
///   %b = vector.contract with one less free/batch dimension
///   ..
///   %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
///
/// This only kicks in when either VectorTransformsOptions is set
/// to Dot or when other contraction patterns fail.
//
// TODO: break down into transpose/reshape/cast ops
//       when they become available to avoid code dup
// TODO: investigate lowering order impact on performance
LogicalResult
ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
                                       PatternRewriter &rewriter) const {
  if (failed(filter(op)))
    return failure();

  // TODO: support mixed mode contract lowering.
  if (op.getLhsType().getElementType() !=
          getElementTypeOrSelf(op.getAccType()) ||
      op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
    return failure();

  // TODO: the code below assumes the default contraction; make sure it
  // supports other kinds before enabling this lowering.
  if (op.getKind() != vector::CombiningKind::ADD) {
    return rewriter.notifyMatchFailure(
        op, "contractions other than 'add' not supported");
  }

  // TODO: implement benefits, cost models.
  MLIRContext *ctx = op.getContext();
  ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
  if (succeeded(pat1.matchAndRewrite(op, rewriter)))
    return success();
  ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
  if (succeeded(pat2.matchAndRewrite(op, rewriter)))
    return success();
  ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
  if (succeeded(pat3.matchAndRewrite(op, rewriter)))
    return success();
  ContractOpToElementwise pat4(vectorTransformOptions, ctx);
  if (succeeded(pat4.matchAndRewrite(op, rewriter)))
    return success();

  // Vector mask setup.
  OpBuilder::InsertionGuard guard(rewriter);
  Operation *rootOp = op;
  Value mask;
  if (op.isMasked()) {
    rewriter.setInsertionPoint(op.getMaskingOp());
    rootOp = op.getMaskingOp();
    mask = op.getMaskingOp().getMask();
  }
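  // Illustrative (hypothetical shapes): a batched contraction with
  // %lhs : vector<2x3x4xf32> and %rhs : vector<2x4x5xf32> first unrolls its
  // batch dimension into two rank-reduced vector.contract ops over
  // vector<3x4xf32> / vector<4x5xf32>, reinserting each result into a
  // vector<2x3x5xf32>; recursion then continues on each sub-contraction.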
  // Find first batch dimension in LHS/RHS, and lower when found.
  std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
  if (!batchDimMap.empty()) {
    int64_t lhsIndex = batchDimMap[0].first;
    int64_t rhsIndex = batchDimMap[0].second;
    auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
    if (failed(newOp))
      return failure();
    rewriter.replaceOp(rootOp, *newOp);
    return success();
  }

  // Collect contracting dimensions.
  std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
      op.getContractingDimMap();
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }

  // Find first free dimension in LHS, and lower when found.
  VectorType lhsType = op.getLhsType();
  for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
    if (lhsContractingDimSet.count(lhsIndex) == 0) {
      auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask);
      if (failed(newOp))
        return failure();
      rewriter.replaceOp(rootOp, *newOp);
      return success();
    }
  }

  // Find first free dimension in RHS, and lower when found.
  VectorType rhsType = op.getRhsType();
  for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
    if (rhsContractingDimSet.count(rhsIndex) == 0) {
      auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask);
      if (failed(newOp))
        return failure();
      rewriter.replaceOp(rootOp, *newOp);
      return success();
    }
  }

  // Lower the first remaining reduction dimension.
  if (!contractingDimMap.empty()) {
    auto newOp = lowerReduction(rewriter, op, mask);
    if (failed(newOp))
      return failure();
    rewriter.replaceOp(rootOp, *newOp);
    return success();
  }

  return failure();
}
// Lower one parallel dimension.
// Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
// TODO: consider reusing existing contract unrolling
FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
                                                      vector::ContractionOp op,
                                                      int64_t lhsIndex,
                                                      int64_t rhsIndex,
                                                      Value mask) const {
  VectorType lhsType = op.getLhsType();
  VectorType rhsType = op.getRhsType();
  VectorType resType = cast<VectorType>(op.getResultType());
  // Find the iterator type index and result index.
  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
  int64_t iterIndex = -1;
  int64_t dimSize = -1;
  if (lhsIndex >= 0) {
    iterIndex = iMap[0].getDimPosition(lhsIndex);
    if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
             << " to map to the same dimension";
      });
    if (lhsType.getScalableDims()[lhsIndex])
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "Unrolling scalable dimension (lhsIndex=" << lhsIndex
             << ") is not supported yet";
      });
    dimSize = lhsType.getDimSize(lhsIndex);
  } else if (rhsIndex >= 0) {
    iterIndex = iMap[1].getDimPosition(rhsIndex);
    if (rhsType.getScalableDims()[rhsIndex])
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "Unrolling scalable dimension (rhsIndex=" << rhsIndex
             << ") is not supported yet";
      });
    dimSize = rhsType.getDimSize(rhsIndex);
  }
  if (iterIndex < 0)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected either lhsIndex=" << lhsIndex
           << " or rhsIndex=" << rhsIndex << " to be nonnegative";
    });
  // value_or(-1) means that we tolerate a dimension not appearing
  // in the result map. That can't happen for actual parallel iterators, but
  // the caller ContractionOpLowering::matchAndRewrite is currently calling
  // lowerParallel also for the case of unit-size reduction dims appearing only
  // on one of LHS or RHS, not both. At the moment, such cases are created by
  // CastAwayContractionLeadingOneDim, so we need to either support that or
  // modify that pattern.
  int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1);
  if (resIndex == -1 && dimSize != 1)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected the dimension for iterIndex=" << iterIndex
           << " to either appear in the result map, or to be a unit dimension";
    });

  // Construct new iterator types and affine map array attribute.
  std::array<AffineMap, 3> lowIndexingMaps = {
      adjustMap(iMap[0], iterIndex, rewriter),
      adjustMap(iMap[1], iterIndex, rewriter),
      adjustMap(iMap[2], iterIndex, rewriter)};
  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
  auto lowIter =
      rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
  // Unroll into a series of lower dimensional vector.contract ops.
  Location loc = op.getLoc();
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resType, rewriter.getZeroAttr(resType));

  for (int64_t d = 0; d < dimSize; ++d) {
    auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
    auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
    auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);

    Value lowMask;
    if (mask)
      lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
                            iterIndex, d, rewriter);

    Operation *lowContract = rewriter.create<vector::ContractionOp>(
        loc, lhs, rhs, acc, lowAffine, lowIter);
    lowContract = maskOperation(rewriter, lowContract, lowMask);
    result = reshapeStore(loc, lowContract->getResult(0), result, resType,
                          resIndex, d, rewriter);
  }
  return result;
}
// Lower one reduction dimension.
FailureOr<Value> ContractionOpLowering::lowerReduction(
    PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const {
  auto loc = op.getLoc();
  VectorType lhsType = op.getLhsType();
  VectorType rhsType = op.getRhsType();
  Type resType = op.getResultType();
  if (isa<VectorType>(resType))
    return rewriter.notifyMatchFailure(op,
                                       "did not expect a VectorType result");
  bool isInt = isa<IntegerType>(resType);
  // Use iterator index 0.
  int64_t iterIndex = 0;
  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
  std::optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
  std::optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
  if (!lookupLhs.has_value())
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected iterIndex=" << iterIndex
           << " to map to a LHS dimension";
    });
  if (!lookupRhs.has_value())
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected iterIndex=" << iterIndex
           << " to map to a RHS dimension";
    });
  int64_t lhsIndex = *lookupLhs;
  int64_t rhsIndex = *lookupRhs;
  int64_t dimSize = lhsType.getDimSize(lhsIndex);
  if (dimSize != rhsType.getDimSize(rhsIndex))
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected LHS dimension " << lhsIndex
           << " to have the same size as RHS dimension " << rhsIndex;
    });
  // Base case.
  if (lhsType.getRank() == 1) {
    if (rhsType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "When LHS has rank 1, expected RHS to also have rank 1");
    Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
    auto kind = vector::CombiningKind::ADD;

    Value acc = op.getAcc();
    Operation *reductionOp =
        acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
            : rewriter.create<vector::ReductionOp>(loc, kind, m);
    return maskOperation(rewriter, reductionOp, mask)->getResult(0);
  }
  // Construct new iterator types and affine map array attribute.
  std::array<AffineMap, 3> lowIndexingMaps = {
      adjustMap(iMap[0], iterIndex, rewriter),
      adjustMap(iMap[1], iterIndex, rewriter),
      adjustMap(iMap[2], iterIndex, rewriter)};
  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
  auto lowIter =
      rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
  // Unroll into a series of lower dimensional vector.contract ops.
  // By feeding the initial accumulator into the first contraction,
  // and the result of each contraction into the next, eventually
  // the sum of all reductions is computed.
  Value result = op.getAcc();
  for (int64_t d = 0; d < dimSize; ++d) {
    auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
    auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
    Value newMask;
    if (mask)
      newMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
                            iterIndex, d, rewriter);

    Operation *newContract = rewriter.create<vector::ContractionOp>(
        loc, lhs, rhs, result, lowAffine, lowIter);
    result = maskOperation(rewriter, newContract, newMask)->getResult(0);
  }
  return result;
}
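// Illustrative (hypothetical shapes): contracting %lhs : vector<2x3xf32> with
// %rhs : vector<2x3xf32> over both dimensions into a scalar unrolls the outer
// reduction dim into two vector<3xf32> sub-contractions chained through the
// accumulator; the rank-1 base case then lowers to an elementwise multiply
// followed by `vector.reduction <add>`.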
/// Progressive lowering of OuterProductOp.
/// One:
///   %x = vector.outerproduct %lhs, %rhs, %acc
/// is replaced by:
///   %z = zero-result
///   %0 = vector.extract %lhs[0]
///   %1 = vector.broadcast %0
///   %2 = vector.extract %acc[0]
///   %3 = vector.fma %1, %rhs, %2
///   %4 = vector.insert %3, %z[0]
///   ..
///   %x = vector.insert %.., %..[N-1]
///
class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::OuterProductOp op,
                                PatternRewriter &rewriter) const override {
    VectorType resType = op.getResultVectorType();
    if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
      return failure();

    auto loc = op.getLoc();

    VectorType lhsType = op.getOperandVectorTypeLHS();
    VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
    Type eltType = resType.getElementType();
    bool isInt = isa<IntegerType, IndexType>(eltType);
    Value acc = op.getAcc();
    vector::CombiningKind kind = op.getKind();

    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
    Operation *rootOp;
    Value mask;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = op;
    }

    if (!rhsType) {
      // Special case: AXPY operation.
      Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
      std::optional<Value> mult = createContractArithOp(
          loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask);
      if (!mult.has_value())
        return failure();
      rewriter.replaceOp(rootOp, *mult);
      return success();
    }

    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
      Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), d);
      Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
      Value r = nullptr;
      if (acc)
        r = rewriter.create<vector::ExtractOp>(loc, acc, d);
      Value extrMask;
      if (mask)
        extrMask = rewriter.create<vector::ExtractOp>(loc, mask, d);

      std::optional<Value> m = createContractArithOp(
          loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
      if (!m.has_value())
        return failure();
      result = rewriter.create<vector::InsertOp>(loc, *m, result, d);
    }

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};
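// Illustrative (hypothetical shapes, default `add` kind):
//   %0 = vector.outerproduct %x, %y, %acc
// with %x : vector<2xf32> and %y : vector<3xf32> unrolls into two iterations
// of broadcast + vector.fma on vector<3xf32>, inserted into rows 0 and 1 of
// the vector<2x3xf32> result.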
/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
///    %mta = maybe_transpose
///    %mtb = maybe_transpose
///    %flattened_a = vector.shape_cast %mta
///    %flattened_b = vector.shape_cast %mtb
///    %flattened_d = vector.matmul %flattened_a, %flattened_b
///    %mtd = vector.shape_cast %flattened_d
///    %d = maybe_untranspose %mtd
///    %e = add %c, %d
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
///
/// This only kicks in when VectorTransformsOptions is set to `Matmul`.
/// vector.transpose operations are inserted if the vector.contract op is not a
/// row-major matrix multiply.
LogicalResult
ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
                                                 PatternRewriter &rew) const {
  // TODO: Support vector.mask.
  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
  if (maskableOp.isMasked())
    return failure();

  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Matmul)
    return failure();
  if (failed(filter(op)))
    return failure();

  auto iteratorTypes = op.getIteratorTypes().getValue();
  if (!isParallelIterator(iteratorTypes[0]) ||
      !isParallelIterator(iteratorTypes[1]) ||
      !isReductionIterator(iteratorTypes[2]))
    return failure();

  Type elementType = op.getLhsType().getElementType();
  if (!elementType.isIntOrFloat())
    return failure();

  Type dstElementType = op.getType();
  if (auto vecType = dyn_cast<VectorType>(dstElementType))
    dstElementType = vecType.getElementType();
  if (elementType != dstElementType)
    return failure();

  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
  // Bail out if the contraction cannot be put in this form.
  MLIRContext *ctx = op.getContext();
  Location loc = op.getLoc();
  AffineExpr m, n, k;
  bindDims(rew.getContext(), m, n, k);
  // LHS must be A(m, k) or A(k, m).
  Value lhs = op.getLhs();
  auto lhsMap = op.getIndexingMapsArray()[0];
  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
    lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
    return failure();

  // RHS must be B(k, n) or B(n, k).
  Value rhs = op.getRhs();
  auto rhsMap = op.getIndexingMapsArray()[1];
  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
    rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
    return failure();

  // At this point lhs and rhs are in row-major.
  VectorType lhsType = cast<VectorType>(lhs.getType());
  VectorType rhsType = cast<VectorType>(rhs.getType());
  int64_t lhsRows = lhsType.getDimSize(0);
  int64_t lhsColumns = lhsType.getDimSize(1);
  int64_t rhsColumns = rhsType.getDimSize(1);

  Type flattenedLHSType =
      VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);

  Type flattenedRHSType =
      VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);

  Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
                                           rhsColumns);
  mul = rew.create<vector::ShapeCastOp>(
      loc,
      VectorType::get({lhsRows, rhsColumns},
                      getElementTypeOrSelf(op.getAcc().getType())),
      mul);

  // ACC must be C(m, n) or C(n, m).
  auto accMap = op.getIndexingMapsArray()[2];
  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
    mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
    llvm_unreachable("invalid contraction semantics");

  Value res =
      isa<IntegerType>(elementType)
          ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
          : static_cast<Value>(
                rew.create<arith::AddFOp>(loc, op.getAcc(), mul));

  rew.replaceOp(op, res);
  return success();
}
} // namespace

void mlir::vector::populateVectorContractLoweringPatterns(
    RewritePatternSet &patterns, VectorTransformsOptions options,
    PatternBenefit benefit, bool disableOuterProductLowering) {
  if (!disableOuterProductLowering)
    patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
  patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
               ContractionOpToOuterProductOpLowering>(
      options, patterns.getContext(), benefit);
}

void mlir::vector::populateVectorOuterProductLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
}
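// Typical usage from a lowering pass (a sketch, assuming the greedy rewrite
// driver; `funcOp` stands for whatever op anchors the rewrite):
//   vector::VectorTransformsOptions options;
//   options.vectorContractLowering =
//       vector::VectorContractLowering::OuterProduct;
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorContractLoweringPatterns(patterns, options);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));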