1 //===- LowerVectorContract.cpp - Lower 'vector.contract' operation --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements target-independent rewrites and utilities to lower the 10 // 'vector.contract' operation. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/Arith/IR/Arith.h" 16 #include "mlir/Dialect/Arith/Utils/Utils.h" 17 #include "mlir/Dialect/Linalg/IR/Linalg.h" 18 #include "mlir/Dialect/MemRef/IR/MemRef.h" 19 #include "mlir/Dialect/SCF/IR/SCF.h" 20 #include "mlir/Dialect/Tensor/IR/Tensor.h" 21 #include "mlir/Dialect/Utils/IndexingUtils.h" 22 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 23 #include "mlir/Dialect/Vector/IR/VectorOps.h" 24 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 25 #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 26 #include "mlir/IR/BuiltinAttributeInterfaces.h" 27 #include "mlir/IR/BuiltinTypes.h" 28 #include "mlir/IR/ImplicitLocOpBuilder.h" 29 #include "mlir/IR/Location.h" 30 #include "mlir/IR/Matchers.h" 31 #include "mlir/IR/PatternMatch.h" 32 #include "mlir/IR/TypeUtilities.h" 33 #include "mlir/Interfaces/VectorInterfaces.h" 34 #include "mlir/Support/LogicalResult.h" 35 36 #define DEBUG_TYPE "vector-contract-lowering" 37 38 using namespace mlir; 39 using namespace mlir::vector; 40 41 //===----------------------------------------------------------------------===// 42 // Helper functions 43 //===----------------------------------------------------------------------===// 44 45 // Helper to find an index in an affine map. 
46 static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) { 47 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { 48 int64_t idx = map.getDimPosition(i); 49 if (idx == index) 50 return i; 51 } 52 return std::nullopt; 53 } 54 55 // Helper to construct iterator types with one index removed. 56 static SmallVector<Attribute> adjustIter(ArrayAttr iteratorTypes, 57 int64_t index) { 58 SmallVector<Attribute> results; 59 for (const auto &it : llvm::enumerate(iteratorTypes)) { 60 int64_t idx = it.index(); 61 if (idx == index) 62 continue; 63 results.push_back(it.value()); 64 } 65 return results; 66 } 67 68 // Helper to construct an affine map with one index removed. 69 static AffineMap adjustMap(AffineMap map, int64_t index, 70 PatternRewriter &rewriter) { 71 auto *ctx = rewriter.getContext(); 72 SmallVector<AffineExpr> results; 73 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { 74 int64_t idx = map.getDimPosition(i); 75 if (idx == index) 76 continue; 77 // Re-insert remaining indices, but renamed when occurring 78 // after the removed index. 79 auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx); 80 results.push_back(targetExpr); 81 } 82 return AffineMap::get(map.getNumDims() - 1, 0, results, ctx); 83 } 84 85 // Helper method to possibly drop a dimension in a load. 86 // TODO 87 static Value reshapeLoad(Location loc, Value val, VectorType type, 88 int64_t index, int64_t pos, 89 PatternRewriter &rewriter) { 90 if (index == -1) 91 return val; 92 Type lowType = VectorType::Builder(type).dropDim(0); 93 // At extraction dimension? 94 if (index == 0) 95 return rewriter.create<vector::ExtractOp>(loc, lowType, val, pos); 96 // Unroll leading dimensions. 
97 VectorType vType = cast<VectorType>(lowType); 98 Type resType = VectorType::Builder(type).dropDim(index); 99 auto resVectorType = cast<VectorType>(resType); 100 Value result = rewriter.create<arith::ConstantOp>( 101 loc, resVectorType, rewriter.getZeroAttr(resVectorType)); 102 for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) { 103 Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, d); 104 Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter); 105 result = 106 rewriter.create<vector::InsertOp>(loc, resVectorType, load, result, d); 107 } 108 return result; 109 } 110 111 // Helper method to possibly drop a dimension in a store. 112 // TODO 113 static Value reshapeStore(Location loc, Value val, Value result, 114 VectorType type, int64_t index, int64_t pos, 115 PatternRewriter &rewriter) { 116 // Unmodified? 117 if (index == -1) 118 return val; 119 // At insertion dimension? 120 if (index == 0) 121 return rewriter.create<vector::InsertOp>(loc, type, val, result, pos); 122 // Unroll leading dimensions. 123 Type lowType = VectorType::Builder(type).dropDim(0); 124 VectorType vType = cast<VectorType>(lowType); 125 Type insType = VectorType::Builder(vType).dropDim(0); 126 for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) { 127 Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, d); 128 Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, d); 129 Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter); 130 result = rewriter.create<vector::InsertOp>(loc, type, sto, result, d); 131 } 132 return result; 133 } 134 135 /// Helper to create arithmetic operation associated with a kind of contraction. 
136 static std::optional<Value> 137 createContractArithOp(Location loc, Value x, Value y, Value acc, 138 vector::CombiningKind kind, PatternRewriter &rewriter, 139 bool isInt, Value mask = Value()) { 140 using vector::CombiningKind; 141 Value mul; 142 143 if (isInt) { 144 if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF) 145 // Only valid for floating point types. 146 return std::nullopt; 147 mul = rewriter.create<arith::MulIOp>(loc, x, y); 148 } else { 149 // Float case. 150 if (kind == CombiningKind::AND || kind == CombiningKind::MINUI || 151 kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI || 152 kind == CombiningKind::MAXSI || kind == CombiningKind::OR || 153 kind == CombiningKind::XOR) 154 // Only valid for integer types. 155 return std::nullopt; 156 // Special case for fused multiply-add. 157 if (acc && isa<VectorType>(acc.getType()) && kind == CombiningKind::ADD) { 158 Value fma = rewriter.create<vector::FMAOp>(loc, x, y, acc); 159 if (mask) 160 // The fma op doesn't need explicit masking. However, fma ops used in 161 // reductions must preserve previous 'acc' values for masked-out lanes. 162 fma = selectPassthru(rewriter, mask, fma, acc); 163 return fma; 164 } 165 mul = rewriter.create<arith::MulFOp>(loc, x, y); 166 } 167 168 if (!acc) 169 return std::optional<Value>(mul); 170 171 return makeArithReduction(rewriter, loc, kind, mul, acc, mask); 172 } 173 174 /// Return the positions of the reductions in the given map. 175 static SmallVector<int64_t> getReductionIndex(AffineMap map, 176 ArrayAttr iteratorTypes) { 177 SmallVector<int64_t> dimsIdx; 178 for (unsigned i = 0, e = map.getNumResults(); i < e; i++) { 179 if (isReductionIterator(iteratorTypes[map.getDimPosition(i)])) 180 dimsIdx.push_back(i); 181 } 182 return dimsIdx; 183 } 184 185 /// Look for a given dimension in an affine map and return its position. Return 186 /// std::nullopt if the dimension is not in the map results. 
187 static std::optional<unsigned> getDimPosition(AffineMap map, unsigned dim) { 188 for (unsigned i = 0, e = map.getNumResults(); i < e; i++) { 189 if (map.getDimPosition(i) == dim) 190 return i; 191 } 192 return std::nullopt; 193 } 194 195 /// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using 196 /// operands `x` and `y`. 197 static Value createAdd(Location loc, Value x, Value y, bool isInt, 198 PatternRewriter &rewriter) { 199 if (isInt) 200 return rewriter.create<arith::AddIOp>(loc, x, y); 201 return rewriter.create<arith::AddFOp>(loc, x, y); 202 } 203 204 /// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using 205 /// operands `x and `y`. 206 static Value createMul(Location loc, Value x, Value y, bool isInt, 207 PatternRewriter &rewriter) { 208 if (isInt) 209 return rewriter.create<arith::MulIOp>(loc, x, y); 210 return rewriter.create<arith::MulFOp>(loc, x, y); 211 } 212 213 namespace { 214 215 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul 216 /// semantics to: 217 /// ``` 218 /// %flattened_a = vector.shape_cast %a 219 /// %flattened_b = vector.shape_cast %b 220 /// %flattened_d = vector.matmul %flattened_a, %flattened_b 221 /// %d = vector.shape_cast %%flattened_d 222 /// %e = add %c, %d 223 /// ``` 224 /// `vector.matmul` later lowers to `llvm.matrix.multiply`. 225 // 226 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and 227 /// the vector.contract op is a row-major matrix multiply. 
class ContractionOpToMatmulOpLowering
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  /// Callback deciding whether a given contraction should be rewritten by
  /// this pattern; returning failure() makes the pattern bail out.
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  /// Default filter: accept every contraction.
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToMatmulOpLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  /// User-provided predicate restricting which contractions are rewritten.
  FilterConstraintType filter;
};

/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
///    %at = vector.transpose %a, [1, 0]
///    %bRow0 = vector.extract %b[0]
///    %atRow0 = vector.extract %at[0]
///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
///    ...
///    %bRowK = vector.extract %b[K]
///    %atRowK = vector.extract %at[K]
///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToOuterProductOpLowering
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  /// Callback deciding whether a given contraction should be rewritten by
  /// this pattern; returning failure() makes the pattern bail out.
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  /// Default filter: accept every contraction.
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToOuterProductOpLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  /// User-provided predicate restricting which contractions are rewritten.
  FilterConstraintType filter;
};

/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to an output-size-unrolled sequence:
/// ```
///    %out = arith.constant ... : vector<MxNxelt_type>
///    %bt = vector.transpose %b, [1, 0]
///    %aRow0 = vector.extract %a[0]
///    %btRow0 = vector.extract %bt[0]
///    %c00 = vector.reduce %aRow0, %btRow0
///    %out00 = vector.insert %c00, %out[0, 0]
///    ...
///    %aRowLast = vector.extract %a[M-1]
///    %btRowLast = vector.extract %bt[N-1]
///    %cLastLast = vector.reduce %aRowLast, %btRowLast
///    %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to Dot and
/// the vector.contract op is a row-major matmul or matvec.
319 class ContractionOpToDotLowering 320 : public OpRewritePattern<vector::ContractionOp> { 321 public: 322 using OpRewritePattern::OpRewritePattern; 323 324 using FilterConstraintType = 325 std::function<LogicalResult(vector::ContractionOp op)>; 326 327 static LogicalResult defaultFilter(vector::ContractionOp op) { 328 return success(); 329 } 330 331 ContractionOpToDotLowering( 332 vector::VectorTransformsOptions vectorTransformOptions, 333 MLIRContext *context, PatternBenefit benefit = 1, 334 const FilterConstraintType &constraint = defaultFilter) 335 : OpRewritePattern<vector::ContractionOp>(context, benefit), 336 vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {} 337 338 LogicalResult matchAndRewrite(vector::ContractionOp op, 339 PatternRewriter &rewriter) const override; 340 341 private: 342 /// Options to control the vector patterns. 343 vector::VectorTransformsOptions vectorTransformOptions; 344 FilterConstraintType filter; 345 }; 346 347 /// Progressive lowering of ContractionOp. 348 /// 349 /// One: 350 /// %x = vector.contract with at least one free/batch dimension 351 /// is replaced by: 352 /// %a = vector.contract with one less free/batch dimension 353 /// %b = vector.contract with one less free/batch dimension 354 /// .. 355 /// %x = combine %a %b .. 356 /// until a pure contraction is reached (no free/batch dimensions), 357 /// which is replaced by a dot-product. 358 /// 359 /// This only kicks in when either VectorTransformsOptions is set 360 /// to Dot or when other contraction patterns fail. 
class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  /// Callback deciding whether a given contraction should be rewritten by
  /// this pattern; returning failure() makes the pattern bail out.
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  /// Default filter: accept every contraction.
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
                        MLIRContext *context, PatternBenefit benefit = 1,
                        FilterConstraintType constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  /// User-provided predicate restricting which contractions are rewritten.
  FilterConstraintType filter;
  // Lower one parallel dimension.
  FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
                                 vector::ContractionOp op, int64_t lhsIndex,
                                 int64_t rhsIndex, Value mask) const;
  // Lower one reduction dimension.
  FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
                                  vector::ContractionOp op, Value mask) const;
};

/// Generate a vector implementation for matmat, matvec and tmatvec.
/// This unrolls outer-products along the reduction dimension.
struct UnrolledOuterProductGenerator
    : public StructuredGenerator<vector::ContractionOp, vector::IteratorType> {
  UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
      : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(b, op),
        kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
        res(op.getAcc()), lhsType(op.getLhsType()) {
    // Capture the mask when the contraction is wrapped in a vector.mask op.
    auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
    if (maskableOp.isMasked())
      mask = maskableOp.getMaskingOp().getMask();
  }

  /// Transpose `v` with permutation `perm`. Null values pass through
  /// unchanged, so this can be applied to an optional mask.
  Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
    if (!v)
      return v;
    return rewriter.create<vector::TransposeOp>(loc, v, perm);
  }

  /// Promote `v` (scalar or vector) to the destination element type, using
  /// extf for float targets and extsi for integer targets. No-op when the
  /// element types already match.
  Value promote(Value v, Type dstElementType) {
    Type elementType = v.getType();
    auto vecType = dyn_cast<VectorType>(elementType);
    if (vecType)
      elementType = vecType.getElementType();
    if (elementType == dstElementType)
      return v;
    Type promotedType = dstElementType;
    if (vecType)
      promotedType = VectorType::get(vecType.getShape(), promotedType);
    if (isa<FloatType>(dstElementType))
      return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
    return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
  }

  /// Emit `reductionSize` chained vector.outerproduct ops accumulating into
  /// `res`. `maybeMask`, when provided, is extracted alongside the operands
  /// and used to mask each outer product.
  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res, int reductionSize,
                             std::optional<Value> maybeMask = std::nullopt) {
    assert(reductionSize > 0);
    // Incremental support for masking: bail out if the op is masked but this
    // particular layout has not been taught how to transpose the mask yet.
    if (mask && !maybeMask.has_value())
      return failure();

    Type resElementType = cast<VectorType>(res.getType()).getElementType();
    for (int64_t k = 0; k < reductionSize; ++k) {
      // Extract the k-th slice of each operand and widen to the result
      // element type before forming the outer product.
      Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
      Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
      extractA = promote(extractA, resElementType);
      extractB = promote(extractB, resElementType);
      Value extractMask;
      if (maybeMask.has_value() && maybeMask.value())
        extractMask =
            rewriter.create<vector::ExtractOp>(loc, maybeMask.value(), k);

      Operation *outerProdOp = rewriter.create<vector::OuterProductOp>(
          loc, res.getType(), extractA, extractB, res, kind);
      res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0);
    }
    return res;
  }

  /// Two outer parallel, one inner reduction (matmat flavor).
  FailureOr<Value> matmat() {
    if (!iters({Par(), Par(), Red()}))
      return failure();
    // Set up the parallel/reduction structure in the right form.
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);
    // Classical row-major matmul: Just permute the lhs.
    if (layout({{m, k}, {k, n}, {m, n}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1),
                       t(mask, {2, 0, 1}));
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {m, n}})) {
      // Materialize t(lhs) first so the two transposes are created in a
      // deterministic order (function argument evaluation order is
      // unspecified in C++).
      Value tlhs = t(lhs);
      return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
    }
    // No need to permute anything.
    if (layout({{k, m}, {k, n}, {m, n}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Just permute the rhs.
    if (layout({{k, m}, {n, k}, {m, n}}))
      return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
    // Transposed output: swap RHS and LHS.
    // Classical row-major matmul: permute the lhs.
    if (layout({{m, k}, {k, n}, {n, m}}))
      return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {n, m}})) {
      // Same deterministic-ordering trick as above.
      Value trhs = t(rhs);
      return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
    }
    if (layout({{k, m}, {k, n}, {n, m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    if (layout({{k, m}, {n, k}, {n, m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    return failure();
  }

  /// One outer parallel, one inner reduction (matvec flavor).
  FailureOr<Value> matvec() {
    if (!iters({Par(), Red()}))
      return failure();
    AffineExpr m, k;
    bindDims(rewriter.getContext(), m, k);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), t(mask));
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    return failure();
  }

  //
  // One outer reduction, one inner parallel (tmatvec flavor)
  //
  FailureOr<Value> tmatvec() {
    if (!iters({Red(), Par()}))
      return failure();
    AffineExpr k, m;
    bindDims(rewriter.getContext(), k, m);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    return failure();
  }

private:
  vector::CombiningKind kind;
  Value lhs, rhs, res, mask;
  VectorType lhsType;
};

/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
///    %at = vector.transpose %a, [1, 0]
///    %bRow0 = vector.extract %b[0]
///    %atRow0 = vector.extract %at[0]
///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
///    ...
///    %bRowK = vector.extract %b[K]
///    %atRowK = vector.extract %at[K]
///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
/// otherwise supports any layout permutation of the matrix-multiply.
LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
    vector::ContractionOp op, PatternRewriter &rewriter) const {
  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::OuterProduct)
    return failure();

  if (failed(filter(op)))
    return failure();

  // Vector mask setup: when the contraction is masked, new ops must be
  // created at (and replace) the surrounding vector.mask op instead.
  OpBuilder::InsertionGuard guard(rewriter);
  auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
  Operation *rootOp;
  if (maskableOp.isMasked()) {
    rewriter.setInsertionPoint(maskableOp.getMaskingOp());
    rootOp = maskableOp.getMaskingOp();
  } else {
    rootOp = op;
  }

  // Try the supported contraction shapes in order: matmat, matvec, tmatvec.
  UnrolledOuterProductGenerator e(rewriter, op);
  FailureOr<Value> matmatRes = e.matmat();
  if (succeeded(matmatRes)) {
    rewriter.replaceOp(rootOp, *matmatRes);
    return success();
  }
  FailureOr<Value> matvecRes = e.matvec();
  if (succeeded(matvecRes)) {
    rewriter.replaceOp(rootOp, *matvecRes);
    return success();
  }
  FailureOr<Value> tmatvecRes = e.tmatvec();
  if (succeeded(tmatvecRes)) {
    rewriter.replaceOp(rootOp, *tmatvecRes);
    return success();
  }

  return failure();
}

LogicalResult
ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
                                            PatternRewriter &rewriter) const {
  // TODO: Support vector.mask.
  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
  if (maskableOp.isMasked())
    return failure();

  if (failed(filter(op)))
    return failure();

  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Dot)
    return failure();

  auto iteratorTypes = op.getIteratorTypes().getValue();
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  Location loc = op.getLoc();
  Value lhs = op.getLhs(), rhs = op.getRhs();

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(rewriter.getContext(), m, n, k);
  SmallVector<AffineMap> maps = op.getIndexingMapsArray();
  //
  // In the following we wish to make the reduction dimension innermost so we
  // can load vectors and just fmul + reduce into a scalar.
  //
  if (isParallelIterator(iteratorTypes[0]) &&
      isParallelIterator(iteratorTypes[1]) &&
      isReductionIterator(iteratorTypes[2])) {
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    // Normalize both operands so that the reduction dimension is innermost
    // for each, transposing and/or swapping as needed per layout.
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      // No need to permute anything.
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      // This is the classical row-major matmul. Just permute the lhs.
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = tmp;
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      Value tmp = rhs;
      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      lhs = tmp;
    } else {
      return failure();
    }
  } else if (isParallelIterator(iteratorTypes[0]) &&
             isReductionIterator(iteratorTypes[1])) {
    //
    // One outer parallel, one inner reduction (matvec flavor)
    //
    if (maps == infer({{m, n}, {n}, {m}})) {
      // No need to permute anything.
    } else if (maps == infer({{n, m}, {n}, {m}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{n}, {m, n}, {m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{n}, {n, m}, {m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else {
      return failure();
    }
  } else {
    return failure();
  }

  VectorType dstType = cast<VectorType>(op.getResultType());
  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
         "Expected dst type of rank 1 or 2");

  unsigned rank = dstType.getRank();
  unsigned dstRows = dstType.getShape()[0];
  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];

  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
  // Each result element is an elementwise multiply of a row of lhs with a
  // row of (the normalized) rhs, reduced with an ADD reduction.
  Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
                                                 rewriter.getZeroAttr(dstType));
  bool isInt = isa<IntegerType>(dstType.getElementType());
  for (unsigned r = 0; r < dstRows; ++r) {
    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
    for (unsigned c = 0; c < dstColumns; ++c) {
      Value b = rank == 1
                    ? rhs
                    : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
      Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
      Value reduced = rewriter.create<vector::ReductionOp>(
          op.getLoc(), vector::CombiningKind::ADD, m);

      SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
                                              : SmallVector<int64_t, 2>{r, c};
      res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
    }
  }
  // Fold in the accumulator, if any.
  if (auto acc = op.getAcc())
    res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
  rewriter.replaceOp(op, res);
  return success();
}

/// Lower vector.contract with all size one reduction dimensions to
/// elementwise ops when possible.
716 struct ContractOpToElementwise 717 : public OpRewritePattern<vector::ContractionOp> { 718 using OpRewritePattern::OpRewritePattern; 719 using FilterConstraintType = 720 std::function<LogicalResult(vector::ContractionOp op)>; 721 static LogicalResult defaultFilter(vector::ContractionOp op) { 722 return success(); 723 } 724 ContractOpToElementwise( 725 vector::VectorTransformsOptions vectorTransformOptions, 726 MLIRContext *context, PatternBenefit benefit = 1, 727 const FilterConstraintType &constraint = defaultFilter) 728 : OpRewritePattern<vector::ContractionOp>(context, benefit), 729 vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {} 730 731 LogicalResult matchAndRewrite(vector::ContractionOp contractOp, 732 PatternRewriter &rewriter) const override { 733 // TODO: Support vector.mask. 734 auto maskableOp = cast<MaskableOpInterface>(contractOp.getOperation()); 735 if (maskableOp.isMasked()) 736 return failure(); 737 738 if (failed(filter(contractOp))) 739 return failure(); 740 741 if (vectorTransformOptions.vectorContractLowering != 742 vector::VectorContractLowering::ParallelArith) 743 return failure(); 744 745 ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape(); 746 ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape(); 747 AffineMap lhsMap = contractOp.getIndexingMapsArray()[0]; 748 AffineMap rhsMap = contractOp.getIndexingMapsArray()[1]; 749 SmallVector<int64_t> lhsReductionDims = 750 getReductionIndex(lhsMap, contractOp.getIteratorTypes()); 751 SmallVector<int64_t> rhsReductionDims = 752 getReductionIndex(rhsMap, contractOp.getIteratorTypes()); 753 // All the reduction dimensions must be a size 1. 
754 for (int64_t dim : lhsReductionDims) { 755 if (lhsShape[dim] != 1) 756 return failure(); 757 } 758 for (int64_t dim : rhsReductionDims) { 759 if (rhsShape[dim] != 1) 760 return failure(); 761 } 762 AffineMap accMap = contractOp.getIndexingMapsArray()[2]; 763 unsigned numParallelDims = accMap.getNumResults(); 764 unsigned numLhsDimToBroadcast = 765 numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size()); 766 unsigned numRhsDimToBroadcast = 767 numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size()); 768 SmallVector<int64_t> lhsDims; 769 SmallVector<int64_t> lhsTranspose; 770 SmallVector<int64_t> rhsDims; 771 SmallVector<int64_t> rhsTranspose; 772 for (int64_t dim : lhsReductionDims) 773 lhsTranspose.push_back(numLhsDimToBroadcast + dim); 774 for (int64_t dim : rhsReductionDims) 775 rhsTranspose.push_back(numRhsDimToBroadcast + dim); 776 // Loop through the parallel dimensions to calculate the dimensions to 777 // broadcast and to permute in order to extract only parallel dimensions. 778 for (unsigned i = 0; i < numParallelDims; i++) { 779 std::optional<unsigned> lhsDim = 780 getDimPosition(lhsMap, accMap.getDimPosition(i)); 781 if (lhsDim) { 782 lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim); 783 } else { 784 // If the parallel dimension doesn't exist we will have to broadcast it. 785 lhsDims.push_back( 786 cast<VectorType>(contractOp.getResultType()).getDimSize(i)); 787 lhsTranspose.push_back(lhsDims.size() - 1); 788 } 789 std::optional<unsigned> rhsDim = 790 getDimPosition(rhsMap, accMap.getDimPosition(i)); 791 if (rhsDim) { 792 rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim); 793 } else { 794 // If the parallel dimension doesn't exist we will have to broadcast it. 
795 rhsDims.push_back( 796 cast<VectorType>(contractOp.getResultType()).getDimSize(i)); 797 rhsTranspose.push_back(rhsDims.size() - 1); 798 } 799 } 800 Value newLhs = contractOp.getLhs(); 801 Value newRhs = contractOp.getRhs(); 802 Location loc = contractOp.getLoc(); 803 if (!lhsDims.empty()) { 804 lhsDims.append(lhsShape.begin(), lhsShape.end()); 805 auto expandedType = 806 VectorType::get(lhsDims, contractOp.getLhsType().getElementType()); 807 newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs); 808 } 809 if (!rhsDims.empty()) { 810 rhsDims.append(rhsShape.begin(), rhsShape.end()); 811 auto expandedType = 812 VectorType::get(rhsDims, contractOp.getRhsType().getElementType()); 813 newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs); 814 } 815 bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex(); 816 newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose); 817 newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose); 818 SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0); 819 SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0); 820 newLhs = rewriter.create<vector::ExtractOp>(loc, newLhs, lhsOffsets); 821 newRhs = rewriter.create<vector::ExtractOp>(loc, newRhs, rhsOffsets); 822 std::optional<Value> result = 823 createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(), 824 contractOp.getKind(), rewriter, isInt); 825 rewriter.replaceOp(contractOp, {*result}); 826 return success(); 827 } 828 829 private: 830 /// Options to control the vector patterns. 831 vector::VectorTransformsOptions vectorTransformOptions; 832 FilterConstraintType filter; 833 }; 834 835 /// Progressive lowering of ContractionOp. 836 /// One: 837 /// %x = vector.contract with at least one free/batch dimension 838 /// is replaced by: 839 /// %a = vector.contract with one less free/batch dimension 840 /// %b = vector.contract with one less free/batch dimension 841 /// .. 
842 /// %x = combine %a %b .. 843 /// until a pure contraction is reached (no free/batch dimensions), 844 /// which is replaced by a dot-product. 845 /// 846 /// This only kicks in when either VectorTransformsOptions is set 847 /// to DOT or when other contraction patterns fail. 848 // 849 // TODO: break down into transpose/reshape/cast ops 850 // when they become available to avoid code dup 851 // TODO: investigate lowering order impact on performance 852 LogicalResult 853 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op, 854 PatternRewriter &rewriter) const { 855 if (failed(filter(op))) 856 return failure(); 857 858 // TODO: support mixed mode contract lowering. 859 if (op.getLhsType().getElementType() != 860 getElementTypeOrSelf(op.getAccType()) || 861 op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType())) 862 return failure(); 863 864 // TODO: the code below assumes the default contraction, make sure it supports 865 // other kinds before enabling this lowering. 866 if (op.getKind() != vector::CombiningKind::ADD) { 867 return rewriter.notifyMatchFailure( 868 op, "contractions other than 'add' not supported"); 869 } 870 871 // TODO: implement benefits, cost models. 872 MLIRContext *ctx = op.getContext(); 873 ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx); 874 if (succeeded(pat1.matchAndRewrite(op, rewriter))) 875 return success(); 876 ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx); 877 if (succeeded(pat2.matchAndRewrite(op, rewriter))) 878 return success(); 879 ContractionOpToDotLowering pat3(vectorTransformOptions, ctx); 880 if (succeeded(pat3.matchAndRewrite(op, rewriter))) 881 return success(); 882 ContractOpToElementwise pat4(vectorTransformOptions, ctx); 883 if (succeeded(pat4.matchAndRewrite(op, rewriter))) 884 return success(); 885 886 // Vector mask setup. 
887 OpBuilder::InsertionGuard guard(rewriter); 888 Operation *rootOp = op; 889 Value mask; 890 if (op.isMasked()) { 891 rewriter.setInsertionPoint(op.getMaskingOp()); 892 rootOp = op.getMaskingOp(); 893 mask = op.getMaskingOp().getMask(); 894 } 895 896 // Find first batch dimension in LHS/RHS, and lower when found. 897 std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap(); 898 if (!batchDimMap.empty()) { 899 int64_t lhsIndex = batchDimMap[0].first; 900 int64_t rhsIndex = batchDimMap[0].second; 901 auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask); 902 if (failed(newOp)) 903 return failure(); 904 rewriter.replaceOp(rootOp, *newOp); 905 return success(); 906 } 907 908 // Collect contracting dimensions. 909 std::vector<std::pair<int64_t, int64_t>> contractingDimMap = 910 op.getContractingDimMap(); 911 DenseSet<int64_t> lhsContractingDimSet; 912 DenseSet<int64_t> rhsContractingDimSet; 913 for (auto &dimPair : contractingDimMap) { 914 lhsContractingDimSet.insert(dimPair.first); 915 rhsContractingDimSet.insert(dimPair.second); 916 } 917 918 // Find first free dimension in LHS, and lower when found. 919 VectorType lhsType = op.getLhsType(); 920 for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) { 921 if (lhsContractingDimSet.count(lhsIndex) == 0) { 922 auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask); 923 if (failed(newOp)) 924 return failure(); 925 rewriter.replaceOp(rootOp, *newOp); 926 return success(); 927 } 928 } 929 930 // Find first free dimension in RHS, and lower when found. 
931 VectorType rhsType = op.getRhsType(); 932 for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) { 933 if (rhsContractingDimSet.count(rhsIndex) == 0) { 934 auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask); 935 if (failed(newOp)) 936 return failure(); 937 rewriter.replaceOp(rootOp, *newOp); 938 return success(); 939 } 940 } 941 942 // Lower the first remaining reduction dimension. 943 if (!contractingDimMap.empty()) { 944 auto newOp = lowerReduction(rewriter, op, mask); 945 if (failed(newOp)) 946 return failure(); 947 rewriter.replaceOp(rootOp, *newOp); 948 return success(); 949 } 950 951 return failure(); 952 } 953 954 // Lower one parallel dimension. 955 // Incidentally also tolerates unit-size (hence trivial) reduction dimensions. 956 // TODO: consider reusing existing contract unrolling 957 FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter, 958 vector::ContractionOp op, 959 int64_t lhsIndex, 960 int64_t rhsIndex, 961 Value mask) const { 962 VectorType lhsType = op.getLhsType(); 963 VectorType rhsType = op.getRhsType(); 964 VectorType resType = cast<VectorType>(op.getResultType()); 965 // Find the iterator type index and result index. 
966 SmallVector<AffineMap> iMap = op.getIndexingMapsArray(); 967 int64_t iterIndex = -1; 968 int64_t dimSize = -1; 969 if (lhsIndex >= 0) { 970 iterIndex = iMap[0].getDimPosition(lhsIndex); 971 if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex)) 972 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 973 diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex 974 << " to map to the same dimension"; 975 }); 976 dimSize = lhsType.getDimSize(lhsIndex); 977 } else if (rhsIndex >= 0) { 978 iterIndex = iMap[1].getDimPosition(rhsIndex); 979 dimSize = rhsType.getDimSize(rhsIndex); 980 } 981 if (iterIndex < 0) 982 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 983 diag << "expected either lhsIndex=" << lhsIndex 984 << " or rhsIndex=" << rhsIndex << " to be nonnegative"; 985 }); 986 // value_or(-1) means that we tolerate a dimension not appearing 987 // in the result map. That can't happen for actual parallel iterators, but 988 // the caller ContractionOpLowering::matchAndRewrite is currently calling 989 // lowerParallel also for the case of unit-size reduction dims appearing only 990 // on one of LHS or RHS, not both. At the moment, such cases are created by 991 // CastAwayContractionLeadingOneDim, so we need to either support that or 992 // modify that pattern. 993 int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1); 994 if (resIndex == -1 && dimSize != 1) 995 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 996 diag << "expected the dimension for iterIndex=" << iterIndex 997 << " to either appear in the result map, or to be a unit dimension"; 998 }); 999 1000 // Construct new iterator types and affine map array attribute. 
1001 std::array<AffineMap, 3> lowIndexingMaps = { 1002 adjustMap(iMap[0], iterIndex, rewriter), 1003 adjustMap(iMap[1], iterIndex, rewriter), 1004 adjustMap(iMap[2], iterIndex, rewriter)}; 1005 auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); 1006 auto lowIter = 1007 rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex)); 1008 // Unroll into a series of lower dimensional vector.contract ops. 1009 Location loc = op.getLoc(); 1010 Value result = rewriter.create<arith::ConstantOp>( 1011 loc, resType, rewriter.getZeroAttr(resType)); 1012 1013 for (int64_t d = 0; d < dimSize; ++d) { 1014 auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter); 1015 auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter); 1016 auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter); 1017 1018 Value lowMask; 1019 if (mask) 1020 lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()), 1021 iterIndex, d, rewriter); 1022 1023 Operation *lowContract = rewriter.create<vector::ContractionOp>( 1024 loc, lhs, rhs, acc, lowAffine, lowIter); 1025 lowContract = maskOperation(rewriter, lowContract, lowMask); 1026 result = reshapeStore(loc, lowContract->getResult(0), result, resType, 1027 resIndex, d, rewriter); 1028 } 1029 return result; 1030 } 1031 1032 // Lower one reduction dimension. 1033 FailureOr<Value> ContractionOpLowering::lowerReduction( 1034 PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const { 1035 auto loc = op.getLoc(); 1036 VectorType lhsType = op.getLhsType(); 1037 VectorType rhsType = op.getRhsType(); 1038 Type resType = op.getResultType(); 1039 if (isa<VectorType>(resType)) 1040 return rewriter.notifyMatchFailure(op, 1041 "did not expect a VectorType result"); 1042 bool isInt = isa<IntegerType>(resType); 1043 // Use iterator index 0. 
1044 int64_t iterIndex = 0; 1045 SmallVector<AffineMap> iMap = op.getIndexingMapsArray(); 1046 std::optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex); 1047 std::optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex); 1048 if (!lookupLhs.has_value()) 1049 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 1050 diag << "expected iterIndex=" << iterIndex << "to map to a LHS dimension"; 1051 }); 1052 if (!lookupRhs.has_value()) 1053 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 1054 diag << "expected iterIndex=" << iterIndex << "to map to a RHS dimension"; 1055 }); 1056 int64_t lhsIndex = *lookupLhs; 1057 int64_t rhsIndex = *lookupRhs; 1058 int64_t dimSize = lhsType.getDimSize(lhsIndex); 1059 if (dimSize != rhsType.getDimSize(rhsIndex)) 1060 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 1061 diag << "expect LHS dimension " << lhsIndex 1062 << " to have the same size as RHS dimension " << rhsIndex; 1063 }); 1064 // Base case. 1065 if (lhsType.getRank() == 1) { 1066 if (rhsType.getRank() != 1) 1067 return rewriter.notifyMatchFailure( 1068 op, "When LHS has rank 1, expected also RHS to have rank 1"); 1069 Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter); 1070 auto kind = vector::CombiningKind::ADD; 1071 1072 Value acc = op.getAcc(); 1073 Operation *reductionOp = 1074 acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc) 1075 : rewriter.create<vector::ReductionOp>(loc, kind, m); 1076 return maskOperation(rewriter, reductionOp, mask)->getResult(0); 1077 } 1078 // Construct new iterator types and affine map array attribute. 
1079 std::array<AffineMap, 3> lowIndexingMaps = { 1080 adjustMap(iMap[0], iterIndex, rewriter), 1081 adjustMap(iMap[1], iterIndex, rewriter), 1082 adjustMap(iMap[2], iterIndex, rewriter)}; 1083 auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); 1084 auto lowIter = 1085 rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex)); 1086 // Unroll into a series of lower dimensional vector.contract ops. 1087 // By feeding the initial accumulator into the first contraction, 1088 // and the result of each contraction into the next, eventually 1089 // the sum of all reductions is computed. 1090 Value result = op.getAcc(); 1091 for (int64_t d = 0; d < dimSize; ++d) { 1092 auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter); 1093 auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter); 1094 Value newMask; 1095 if (mask) 1096 newMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()), 1097 iterIndex, d, rewriter); 1098 1099 Operation *newContract = rewriter.create<vector::ContractionOp>( 1100 loc, lhs, rhs, result, lowAffine, lowIter); 1101 result = maskOperation(rewriter, newContract, newMask)->getResult(0); 1102 } 1103 return result; 1104 } 1105 1106 /// Progressive lowering of OuterProductOp. 1107 /// One: 1108 /// %x = vector.outerproduct %lhs, %rhs, %acc 1109 /// is replaced by: 1110 /// %z = zero-result 1111 /// %0 = vector.extract %lhs[0] 1112 /// %1 = vector.broadcast %0 1113 /// %2 = vector.extract %acc[0] 1114 /// %3 = vector.fma %1, %rhs, %2 1115 /// %4 = vector.insert %3, %z[0] 1116 /// .. 
1117 /// %x = vector.insert %.., %..[N-1] 1118 /// 1119 class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> { 1120 public: 1121 using OpRewritePattern::OpRewritePattern; 1122 1123 LogicalResult matchAndRewrite(vector::OuterProductOp op, 1124 PatternRewriter &rewriter) const override { 1125 auto loc = op.getLoc(); 1126 1127 VectorType lhsType = op.getOperandVectorTypeLHS(); 1128 VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS()); 1129 VectorType resType = op.getResultVectorType(); 1130 Type eltType = resType.getElementType(); 1131 bool isInt = isa<IntegerType, IndexType>(eltType); 1132 Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0]; 1133 vector::CombiningKind kind = op.getKind(); 1134 1135 // Vector mask setup. 1136 OpBuilder::InsertionGuard guard(rewriter); 1137 auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation()); 1138 Operation *rootOp; 1139 Value mask; 1140 if (maskableOp.isMasked()) { 1141 rewriter.setInsertionPoint(maskableOp.getMaskingOp()); 1142 rootOp = maskableOp.getMaskingOp(); 1143 mask = maskableOp.getMaskingOp().getMask(); 1144 } else { 1145 rootOp = op; 1146 } 1147 1148 if (!rhsType) { 1149 // Special case: AXPY operation. 
1150 Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs()); 1151 std::optional<Value> mult = createContractArithOp( 1152 loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask); 1153 if (!mult.has_value()) 1154 return failure(); 1155 rewriter.replaceOp(rootOp, *mult); 1156 return success(); 1157 } 1158 1159 Value result = rewriter.create<arith::ConstantOp>( 1160 loc, resType, rewriter.getZeroAttr(resType)); 1161 for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) { 1162 Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), d); 1163 Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x); 1164 Value r = nullptr; 1165 if (acc) 1166 r = rewriter.create<vector::ExtractOp>(loc, acc, d); 1167 Value extrMask; 1168 if (mask) 1169 extrMask = rewriter.create<vector::ExtractOp>(loc, mask, d); 1170 1171 std::optional<Value> m = createContractArithOp( 1172 loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask); 1173 if (!m.has_value()) 1174 return failure(); 1175 result = rewriter.create<vector::InsertOp>(loc, resType, *m, result, d); 1176 } 1177 1178 rewriter.replaceOp(rootOp, result); 1179 return success(); 1180 } 1181 }; 1182 1183 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul 1184 /// semantics to: 1185 /// ``` 1186 /// %mta = maybe_transpose 1187 /// %mtb = maybe_transpose 1188 /// %flattened_a = vector.shape_cast %mta 1189 /// %flattened_b = vector.shape_cast %mtb 1190 /// %flattened_d = vector.matmul %flattened_a, %flattened_b 1191 /// %mtd = vector.shape_cast %flattened_d 1192 /// %d = maybe_untranspose %mtd 1193 /// %e = add %c, %d 1194 /// ``` 1195 /// `vector.matmul` later lowers to `llvm.matrix.multiply`. 1196 // 1197 /// This only kicks in when VectorTransformsOptions is set to `Matmul`. 1198 /// vector.transpose operations are inserted if the vector.contract op is not a 1199 /// row-major matrix multiply. 
LogicalResult
ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
                                                 PatternRewriter &rew) const {
  // TODO: Support vector.mask.
  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
  if (maskableOp.isMasked())
    return failure();

  // Only fire when this lowering was explicitly requested.
  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Matmul)
    return failure();
  if (failed(filter(op)))
    return failure();

  // Expect exactly the (parallel, parallel, reduction) iterator structure of
  // a matmul.
  auto iteratorTypes = op.getIteratorTypes().getValue();
  if (!isParallelIterator(iteratorTypes[0]) ||
      !isParallelIterator(iteratorTypes[1]) ||
      !isReductionIterator(iteratorTypes[2]))
    return failure();

  Type elementType = op.getLhsType().getElementType();
  if (!elementType.isIntOrFloat())
    return failure();

  // Mixed-precision contracts are not supported by vector.matmul.
  Type dstElementType = op.getType();
  if (auto vecType = dyn_cast<VectorType>(dstElementType))
    dstElementType = vecType.getElementType();
  if (elementType != dstElementType)
    return failure();

  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
  // Bail out if the contraction cannot be put in this form.
  MLIRContext *ctx = op.getContext();
  Location loc = op.getLoc();
  AffineExpr m, n, k;
  bindDims(rew.getContext(), m, n, k);
  // LHS must be A(m, k) or A(k, m).
  Value lhs = op.getLhs();
  auto lhsMap = op.getIndexingMapsArray()[0];
  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
    lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
    return failure();

  // RHS must be B(k, n) or B(n, k).
  Value rhs = op.getRhs();
  auto rhsMap = op.getIndexingMapsArray()[1];
  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
    rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
    return failure();

  // At this point lhs and rhs are in row-major.
  VectorType lhsType = cast<VectorType>(lhs.getType());
  VectorType rhsType = cast<VectorType>(rhs.getType());
  int64_t lhsRows = lhsType.getDimSize(0);
  int64_t lhsColumns = lhsType.getDimSize(1);
  int64_t rhsColumns = rhsType.getDimSize(1);

  // vector.matmul operates on flattened (1-D) operands.
  Type flattenedLHSType =
      VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);

  Type flattenedRHSType =
      VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);

  Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
                                           rhsColumns);
  // Cast the flat product back to the 2-D accumulator shape.
  mul = rew.create<vector::ShapeCastOp>(
      loc,
      VectorType::get({lhsRows, rhsColumns},
                      getElementTypeOrSelf(op.getAcc().getType())),
      mul);

  // ACC must be C(m, n) or C(n, m); any other map was rejected by the
  // contract verifier, hence the unreachable below.
  auto accMap = op.getIndexingMapsArray()[2];
  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
    mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
    llvm_unreachable("invalid contraction semantics");

  // Fold in the accumulator with the matching (integer/float) add.
  Value res =
      isa<IntegerType>(elementType)
          ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
          : static_cast<Value>(
                rew.create<arith::AddFOp>(loc, op.getAcc(), mul));

  rew.replaceOp(op, res);
  return success();
}
} // namespace

// Registers all vector.contract lowering patterns; the OuterProductOpLowering
// can be excluded (e.g. when a target-specific lowering handles it).
void mlir::vector::populateVectorContractLoweringPatterns(
    RewritePatternSet &patterns, VectorTransformsOptions options,
    PatternBenefit benefit, bool disableOuterProductLowering) {
  if (!disableOuterProductLowering)
    patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
  patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
               ContractionOpToOuterProductOpLowering>(
      options, patterns.getContext(), benefit);
}

// Registers only the vector.outerproduct unrolling pattern.
void mlir::vector::populateVectorOuterProductLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
}