//===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the Linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Utils/Utils.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>

#define DEBUG_TYPE "linalg-utils"

using namespace mlir;
using namespace presburger;
using namespace mlir::affine;
using namespace mlir::linalg;
using namespace mlir::scf;

namespace {

// Helper visitor to determine whether an AffineExpr is tiled.
// This is achieved by traversing every AffineDimExpr with position `pos` and
// checking whether the corresponding `tileSizes[pos]` is non-zero.
// This also enforces that only positive coefficients occur in multiplications.
//
// Example:
//   `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
//
struct TileCheck : public AffineExprVisitor<TileCheck> {
  TileCheck(ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}

  void visitDimExpr(AffineDimExpr expr) {
    isTiled |= !isZeroIndex(tileSizes[expr.getPosition()]);
  }
  void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
    visit(expr.getLHS());
    visit(expr.getRHS());
    if (expr.getKind() == mlir::AffineExprKind::Mul)
      assert(expr.getRHS().cast<AffineConstantExpr>().getValue() > 0 &&
             "nonpositive multiplying coefficient");
  }
  bool isTiled = false;
  ArrayRef<OpFoldResult> tileSizes;
};

} // namespace

static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes) {
  if (!expr)
    return false;
  TileCheck t(tileSizes);
  t.visit(expr);
  return t.isTiled;
}

// Checks whether `map` varies with respect to a non-zero `tileSize`.
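// For example (illustrative): with tileSizes = [8, 0], the map
// (d0, d1) -> (d1) is not considered tiled (d1's tile size is zero), whereas
// (d0, d1) -> (d0 + d1) is tiled through d0.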
static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes) {
  if (!map)
    return false;
  for (unsigned r = 0; r < map.getNumResults(); ++r)
    if (isTiled(map.getResult(r), tileSizes))
      return true;
  return false;
}

std::optional<RegionMatcher::BinaryOpKind>
RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
  auto &region = op.getRegion();
  if (!llvm::hasSingleElement(region))
    return std::nullopt;

  Block &block = region.front();
  if (block.getNumArguments() != 2 ||
      !block.getArgument(0).getType().isSignlessIntOrFloat() ||
      !block.getArgument(1).getType().isSignlessIntOrFloat())
    return std::nullopt;

  auto &ops = block.getOperations();
  if (!llvm::hasSingleElement(block.without_terminator()))
    return std::nullopt;

  using mlir::matchers::m_Val;
  auto a = m_Val(block.getArgument(0));
  auto b = m_Val(block.getArgument(1));

  auto addPattern = m_Op<linalg::YieldOp>(m_Op<arith::AddIOp>(a, b));
  if (addPattern.match(&ops.back()))
    return BinaryOpKind::IAdd;

  return std::nullopt;
}

/// Explicit instantiation of loop nest generator for different loop types.
template struct mlir::linalg::GenerateLoopNest<scf::ForOp>;
template struct mlir::linalg::GenerateLoopNest<scf::ParallelOp>;
template struct mlir::linalg::GenerateLoopNest<AffineForOp>;

/// Given a list of subview ranges, extract individual values for lower, upper
/// bounds and steps and put them into the corresponding vectors.
static void unpackRanges(OpBuilder &builder, Location loc,
                         ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
                         SmallVectorImpl<Value> &ubs,
                         SmallVectorImpl<Value> &steps) {
  for (Range range : ranges) {
    lbs.emplace_back(
        getValueOrCreateConstantIndexOp(builder, loc, range.offset));
    ubs.emplace_back(getValueOrCreateConstantIndexOp(builder, loc, range.size));
    steps.emplace_back(
        getValueOrCreateConstantIndexOp(builder, loc, range.stride));
  }
}

//===----------------------------------------------------------------------===//
// Utilities for inferring various semantics properties of Linalg ops.
//===----------------------------------------------------------------------===//

DenseSet<int64_t> mlir::linalg::findPermutationsIndexingOperand(
    LinalgOp linalgOp, OpOperand *opOperand, utils::IteratorType iter) {
  DenseSet<int64_t> res;
  assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner");
  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
  for (AffineExpr e : indexingMap.getResults()) {
    if (auto d = e.dyn_cast<AffineDimExpr>()) {
      if (linalgOp.getIteratorTypesArray()[d.getPosition()] == iter &&
          llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
            return e.isFunctionOfDim(d.getPosition());
          }) == 1)
        res.insert(d.getPosition());
    }
  }
  return res;
}

namespace {
auto par = utils::IteratorType::parallel;
auto red = utils::IteratorType::reduction;
} // namespace

bool mlir::linalg::containsMostMinorMatmul(LinalgOp linalgOp) {
  FailureOr<EmbeddedMatmulDimsCandidates> res = inferMatmulDims(linalgOp);
  if (failed(res))
    return false;
  int64_t numLoops = linalgOp.getNumLoops();
  for (const DenseSet<int64_t> &s : {res->mPos, res->nPos, res->kPos}) {
    if (s.contains(numLoops - 3) || s.contains(numLoops - 2) ||
        s.contains(numLoops - 1))
      continue;
    return false;
  }
  return true;
}

FailureOr<EmbeddedMatmulDimsCandidates>
mlir::linalg::inferMatmulDims(LinalgOp linalgOp) {
  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
    return failure();

  DenseSet<int64_t> a = findPermutationsIndexingOperand(
      linalgOp, linalgOp.getDpsInputOperand(0), par);
  DenseSet<int64_t> b = findPermutationsIndexingOperand(
      linalgOp, linalgOp.getDpsInputOperand(1), par);
  DenseSet<int64_t> c = findPermutationsIndexingOperand(
      linalgOp, linalgOp.getDpsInitOperand(0), par);

  // A & C - B are the iterators involved in an outer-product along A (the
  // LHS).
  DenseSet<int64_t> ac = a;
  llvm::set_intersect(ac, c);
  llvm::set_subtract(ac, b);
  // B & C - A are the iterators involved in an outer-product along B (the
  // RHS).
  DenseSet<int64_t> bc = b;
  llvm::set_intersect(bc, c);
  llvm::set_subtract(bc, a);

  // Note: if we ever need them, A & B & C would be "batch" dimensions.

  // The reduction dimensions are those appearing as reductions in both A and
  // B.
  DenseSet<int64_t> ra = findPermutationsIndexingOperand(
      linalgOp, linalgOp.getDpsInputOperand(0), red);
  DenseSet<int64_t> rb = findPermutationsIndexingOperand(
      linalgOp, linalgOp.getDpsInputOperand(1), red);
  llvm::set_intersect(ra, rb);

  if (ac.empty() || bc.empty() || ra.empty())
    return failure();

  // Pick the first one in each set.
  // TODO: Better heuristic (e.g., pick dims based on a packing-based metric).
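  // For example (illustrative), for a plain matmul
  //   C(m, n) += A(m, k) * B(k, n)
  // with iterator types [parallel, parallel, reduction], the candidate sets
  // are ac = {0} (the m dims), bc = {1} (the n dims) and ra = {2} (the k
  // dims).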
  return EmbeddedMatmulDimsCandidates{ac, bc, ra};
}

//===----------------------------------------------------------------------===//
// General utilities
//===----------------------------------------------------------------------===//

namespace mlir {
namespace linalg {

bool allIndexingsAreProjectedPermutation(LinalgOp op) {
  return llvm::all_of(op.getIndexingMapsArray(), [](AffineMap m) {
    return m.isProjectedPermutation(/*allowZeroInResults=*/true);
  });
}

bool hasOnlyScalarElementwiseOp(Region &r) {
  if (!llvm::hasSingleElement(r))
    return false;
  for (Operation &op : r.front()) {
    if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp,
              linalg::YieldOp, linalg::IndexOp, AffineApplyOp>(op) ||
          OpTrait::hasElementwiseMappableTraits(&op)) ||
        llvm::any_of(op.getResultTypes(),
                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
      return false;
  }
  return true;
}

bool isElementwise(LinalgOp op) {
  if (op.getNumLoops() != op.getNumParallelLoops())
    return false;

  if (!allIndexingsAreProjectedPermutation(op))
    return false;

  // TODO: relax the restrictions on indexing map.
  for (OpOperand *opOperand : op.getDpsInitOperands()) {
    if (!op.getMatchingIndexingMap(opOperand).isPermutation())
      return false;
  }
  return hasOnlyScalarElementwiseOp(op->getRegion(0));
}

bool isParallelIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::parallel;
}

bool isReductionIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::reduction;
}

Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                            Value source, Value pad, bool nofold) {
  // Exit if `source` is not defined by an ExtractSliceOp.
  auto sliceOp = source.getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Search the `source` use-def chain for padded LinalgOps.
  Value current = sliceOp.getSource();
  while (current) {
    auto linalgOp = current.getDefiningOp<LinalgOp>();
    if (!linalgOp)
      break;
    OpResult opResult = cast<OpResult>(current);
    current = linalgOp.getDpsInitOperand(opResult.getResultNumber())->get();
  }
  auto padOp = current ? current.getDefiningOp<tensor::PadOp>() : nullptr;

  // Exit if the search fails to match a tensor::PadOp at the end of the
  // matched LinalgOp sequence.
  if (!padOp)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the padded result type does not match.
  if (sliceOp.getSource().getType() != type)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the LinalgOps are not high padded.
  if (llvm::any_of(padOp.getMixedLowPad(), [](OpFoldResult ofr) {
        return getConstantIntValue(ofr) != static_cast<int64_t>(0);
      }))
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if `padOpSliceOp`, which defines the slice used by
  // `padOp`, is rank-reducing.
  auto padOpSliceOp = padOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
  if (!padOpSliceOp ||
      sliceOp.getMixedSizes().size() != padOpSliceOp.getMixedSizes().size())
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the sizes of `sliceOp` do not match the sizes of the slice padded
  // by `padOp`.
  if (llvm::any_of(
          llvm::zip(sliceOp.getMixedSizes(), padOpSliceOp.getMixedSizes()),
          [](std::tuple<OpFoldResult, OpFoldResult> it) {
            return !isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it));
          }))
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the padding values do not match.
  Attribute padOpPadAttr, padAttr;
  Value padOpPad = padOp.getConstantPaddingValue();
  if (!padOpPad || !matchPattern(padOpPad, m_Constant(&padOpPadAttr)) ||
      !matchPattern(pad, m_Constant(&padAttr)) || padOpPadAttr != padAttr)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Return the padded result if the padding values and sizes match.
  return sliceOp.getSource();
}

GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
                          Value outputTensor,
                          ArrayRef<int64_t> transposeVector) {
  auto resultTensorType = cast<RankedTensorType>(outputTensor.getType());
  Type elementType = resultTensorType.getElementType();

  assert(isPermutationVector(transposeVector) &&
         "expect transpose vector to be a permutation");
  assert(transposeVector.size() ==
             static_cast<size_t>(resultTensorType.getRank()) &&
         "expect transpose vector size to match result tensor rank");

  // Compute the transpose and the identity indexing maps.
  SmallVector<AffineMap> indexingMaps = {
      inversePermutation(AffineMap::getPermutationMap(
          SmallVector<unsigned>(transposeVector.begin(), transposeVector.end()),
          b.getContext())),
      AffineMap::getMultiDimIdentityMap(transposeVector.size(),
                                        b.getContext())};
  SmallVector<utils::IteratorType> iteratorTypes(transposeVector.size(),
                                                 utils::IteratorType::parallel);

  // Create a GenericOp to transpose `inputTensor` into `outputTensor`.
  auto transposeOp =
      b.create<GenericOp>(loc, resultTensorType, inputTensor, outputTensor,
                          indexingMaps, iteratorTypes);
  Region &body = transposeOp.getRegion();
  body.push_back(new Block());
  body.front().addArguments({elementType, elementType}, {loc, loc});

  // Create the body of the transpose operation.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPointToEnd(&body.front());
  b.create<YieldOp>(loc, transposeOp.getRegion().front().getArgument(0));
  return transposeOp;
}

GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
  auto memrefTypeTo = cast<MemRefType>(to.getType());
#ifndef NDEBUG
  auto memrefTypeFrom = cast<MemRefType>(from.getType());
  assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() &&
         "`from` and `to` memref must have the same rank");
#endif // NDEBUG

  AffineMap id =
      AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
  SmallVector<utils::IteratorType> iteratorTypes(memrefTypeTo.getRank(),
                                                 utils::IteratorType::parallel);
  return b.create<linalg::GenericOp>(
      loc,
      /*inputs=*/from,
      /*outputs=*/to,
      /*indexingMaps=*/llvm::ArrayRef({id, id}),
      /*iteratorTypes=*/iteratorTypes,
      [](OpBuilder &b, Location loc, ValueRange args) {
        b.create<linalg::YieldOp>(loc, args.front());
      });
}

/// Specialization to build an scf "for" nest.
template <>
void GenerateLoopNest<scf::ForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> procInfo) {
  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
         "expected as many entries for proc info as number of loops, even if "
         "they are null entries");
  SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
                                             ? SmallVector<Value>{}
                                             : linalgOp.getDpsInitOperands();

  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
  LoopNest loopNest = mlir::scf::buildLoopNest(
      b, loc, lbs, ubs, steps, iterArgInitValues,
      [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) {
        assert(iterArgs.size() == iterArgInitValues.size() &&
               "expect the number of output tensors and iter args to match");
        SmallVector<Value> operandValuesToUse = linalgOp->getOperands();
        if (!iterArgs.empty()) {
          operandValuesToUse = linalgOp.getDpsInputOperands();
          operandValuesToUse.append(iterArgs.begin(), iterArgs.end());
        }
        return bodyBuilderFn(b, loc, ivs, operandValuesToUse);
      });

  if (loopNest.loops.empty() || procInfo.empty())
    return;

  // Filter out scf.for loops that were created out of parallel dimensions.
  for (const auto &loop : llvm::enumerate(loopNest.loops)) {
    if (procInfo[loop.index()].distributionMethod ==
        DistributionMethod::Cyclic) {
      mapLoopToProcessorIds(loop.value(), procInfo[loop.index()].procId,
                            procInfo[loop.index()].nprocs);
    }
  }
}

/// Specialization to build affine "for" nest.
template <>
void GenerateLoopNest<AffineForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> /*procInfo*/) {
  SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
                                             ? SmallVector<Value>{}
                                             : linalgOp.getDpsInitOperands();
  assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(b, loc, loopRanges, lbs, ubs, steps);

  // Affine loops require constant steps.
  SmallVector<int64_t, 4> constantSteps;
  constantSteps.reserve(steps.size());
  for (Value v : steps) {
    auto op = v.getDefiningOp<arith::ConstantIndexOp>();
    assert(op && "Affine loops require constant steps");
    constantSteps.push_back(op.value());
  }

  affine::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps,
                              [&](OpBuilder &b, Location loc, ValueRange ivs) {
                                bodyBuilderFn(b, loc, ivs,
                                              linalgOp->getOperands());
                              });
}

/// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and
/// `step`.
void updateBoundsForCyclicDistribution(OpBuilder &b, Location loc, Value procId,
                                       Value nprocs, Value &lb, Value &ub,
                                       Value &step) {
  AffineExpr d0, d1;
  bindDims(b.getContext(), d0, d1);
  AffineExpr s0 = getAffineSymbolExpr(0, b.getContext());
  lb =
      affine::makeComposedAffineApply(b, loc, d0 + d1 * s0, {lb, procId, step});
  step = affine::makeComposedAffineApply(b, loc, d0 * s0, {nprocs, step});
}

/// Generates a loop nest consisting of scf.parallel and scf.for, depending
/// on the `iteratorTypes`. Consecutive parallel loops create a single
/// scf.parallel operation; each sequential loop creates a new scf.for
/// operation. The body of the innermost loop is populated by
/// `bodyBuilderFn` that accepts a range of induction variables for all
/// loops. `ivStorage` is used to store the partial list of induction
/// variables.
// TODO: this function can be made iterative instead. However, it
// will have at most as many recursive calls as nested loops, which rarely
// exceeds 10.
static void generateParallelLoopNest(
    OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
    ValueRange steps, ArrayRef<utils::IteratorType> iteratorTypes,
    ArrayRef<linalg::ProcInfo> procInfo,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
    SmallVectorImpl<Value> &ivStorage) {
  assert(lbs.size() == ubs.size());
  assert(lbs.size() == steps.size());
  assert(lbs.size() == iteratorTypes.size());
  assert(procInfo.empty() || (lbs.size() == procInfo.size()));

  // If there are no (more) loops to be generated, generate the body and be
  // done with it.
  if (iteratorTypes.empty()) {
    bodyBuilderFn(b, loc, ivStorage);
    return;
  }

  // If there are no outer parallel loops, generate one sequential loop and
  // recurse.
  if (!isParallelIterator(iteratorTypes.front())) {
    LoopNest singleLoop = buildLoopNest(
        b, loc, lbs.take_front(), ubs.take_front(), steps.take_front(),
        [&](OpBuilder &b, Location loc, ValueRange ivs) {
          ivStorage.append(ivs.begin(), ivs.end());
          generateParallelLoopNest(
              b, loc, lbs.drop_front(), ubs.drop_front(), steps.drop_front(),
              iteratorTypes.drop_front(),
              procInfo.empty() ? procInfo : procInfo.drop_front(),
              bodyBuilderFn, ivStorage);
        });
    return;
  }

  unsigned nLoops = iteratorTypes.size();
  unsigned numProcessed = 0;
  DistributionMethod distributionMethod = DistributionMethod::None;
  if (procInfo.empty()) {
    numProcessed = nLoops - iteratorTypes.drop_while(isParallelIterator).size();
  } else {
    distributionMethod = procInfo.front().distributionMethod;
    numProcessed =
        nLoops - procInfo
                     .drop_while([&](linalg::ProcInfo p) {
                       return p.distributionMethod == distributionMethod;
                     })
                     .size();
  }

  auto remainderProcInfo =
      procInfo.empty() ? procInfo : procInfo.drop_front(numProcessed);
  switch (distributionMethod) {
  case DistributionMethod::None: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
        steps.take_front(numProcessed),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
              iteratorTypes.drop_front(numProcessed), remainderProcInfo,
              bodyBuilderFn, ivStorage);
        });
    return;
  }
  case DistributionMethod::Cyclic: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
        steps.take_front(numProcessed),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
              iteratorTypes.drop_front(numProcessed), remainderProcInfo,
              bodyBuilderFn, ivStorage);
        });
    return;
  }
  case DistributionMethod::CyclicNumProcsGeNumIters: {
    // Check (for the processed loops) that the iteration is in-bounds.
    ArithBuilder ab(b, loc);
    Value cond = ab.slt(lbs[0], ubs[0]);
    for (unsigned i = 1; i < numProcessed; ++i)
      cond = ab._and(cond, ab.slt(lbs[i], ubs[i]));
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    b.create<scf::IfOp>(loc, cond, [&](OpBuilder &b, Location loc) {
      generateParallelLoopNest(b, loc, lbs.drop_front(numProcessed),
                               ubs.drop_front(numProcessed),
                               steps.drop_front(numProcessed),
                               iteratorTypes.drop_front(numProcessed),
                               remainderProcInfo, bodyBuilderFn, ivStorage);
      b.create<scf::YieldOp>(loc, ValueRange{});
    });
    return;
  }
  case DistributionMethod::CyclicNumProcsEqNumIters:
    // No check/loops needed here. Set the `%iv` to be the `%lb` and proceed
    // with inner loop generation.
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    generateParallelLoopNest(
        b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
        steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed),
        remainderProcInfo, bodyBuilderFn, ivStorage);
    return;
  }
}

/// Specialization for generating a mix of parallel and sequential scf loops.
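/// For example (sketch), with iteratorTypes = [parallel, parallel, reduction]
/// and no distribution info, the generated nest is roughly:
///   scf.parallel (%i, %j) = (%lb0, %lb1) to (%ub0, %ub1) step (%s0, %s1) {
///     scf.for %k = %lb2 to %ub2 step %s2 {
///       <body>
///     }
///   }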
template <>
void GenerateLoopNest<scf::ParallelOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> procInfo) {
  SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
                                             ? SmallVector<Value>{}
                                             : linalgOp.getDpsInitOperands();
  assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
  // This function may be passed more iterator types than ranges.
  assert(iteratorTypes.size() >= loopRanges.size() &&
         "expected iterator type for all ranges");
  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
         "expected proc information for all loops when present");
  iteratorTypes = iteratorTypes.take_front(loopRanges.size());
  SmallVector<Value, 8> lbsStorage, ubsStorage, stepsStorage, ivs;
  unsigned numLoops = iteratorTypes.size();
  ivs.reserve(numLoops);
  lbsStorage.reserve(numLoops);
  ubsStorage.reserve(numLoops);
  stepsStorage.reserve(numLoops);

  // Get the loop lb, ub, and step.
  unpackRanges(b, loc, loopRanges, lbsStorage, ubsStorage, stepsStorage);

  // Modify the lb, ub, and step based on the distribution options.
  for (const auto &it : llvm::enumerate(procInfo)) {
    if (it.value().distributionMethod != linalg::DistributionMethod::None) {
      updateBoundsForCyclicDistribution(
          b, loc, it.value().procId, it.value().nprocs, lbsStorage[it.index()],
          ubsStorage[it.index()], stepsStorage[it.index()]);
    }
  }
  ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
  generateParallelLoopNest(
      b, loc, lbs, ubs, steps, iteratorTypes, procInfo,
      [&](OpBuilder &b, Location loc, ValueRange ivs) {
        bodyBuilderFn(b, loc, ivs, linalgOp->getOperands());
      },
      ivs);

  assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
}

static Value materializeTiledShape(OpBuilder &builder, Location loc,
                                   Value valueToTile,
                                   const SliceParameters &sliceParams) {
  auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
  auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
                      .Case([&](MemRefType) {
                        return builder.create<memref::SubViewOp>(
                            loc, valueToTile, sliceParams.offsets,
                            sliceParams.sizes, sliceParams.strides);
                      })
                      .Case([&](RankedTensorType) {
                        return builder.create<tensor::ExtractSliceOp>(
                            loc, valueToTile, sliceParams.offsets,
                            sliceParams.sizes, sliceParams.strides);
                      })
                      .Default([](ShapedType) -> Operation * {
                        llvm_unreachable("Unexpected shaped type");
                      });
  return sliceOp->getResult(0);
}

Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
                     ArrayRef<OpFoldResult> tileSizes, AffineMap map,
                     ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
                     ArrayRef<OpFoldResult> subShapeSizes,
                     bool omitPartialTileCheck) {
  SliceParameters sliceParams =
      computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
                             ubs, subShapeSizes, omitPartialTileCheck);
  return materializeTiledShape(builder, loc, valueToTile, sliceParams);
}

SliceParameters
computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
                       ArrayRef<OpFoldResult> tileSizes, AffineMap map,
                       ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
                       ArrayRef<OpFoldResult> subShapeSizes,
                       bool omitPartialTileCheck) {
  auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
  assert(shapedType && "only shaped types can be tiled");
  ArrayRef<int64_t> shape = shapedType.getShape();
  int64_t rank = shapedType.getRank();

  // Compute offsets/sizes/strides for the tile.
  SliceParameters sliceParams;
  sliceParams.offsets.reserve(rank);
  sliceParams.sizes.reserve(rank);
  sliceParams.strides.reserve(rank);
  for (unsigned r = 0; r < rank; ++r) {
    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: for dim#" << r);
    if (!isTiled(map.getSubMap({r}), tileSizes)) {
      sliceParams.offsets.push_back(builder.getIndexAttr(0));
      OpFoldResult dim = createFoldedDimOp(builder, loc, valueToTile, r);
      sliceParams.sizes.push_back(dim);
      sliceParams.strides.push_back(builder.getIndexAttr(1));
      LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subsize...\n");

    // Tiling creates a new slice at the proper index; the slice step is 1
    // (i.e. the op does not subsample, stepping occurs in the loop).
    auto m = map.getSubMap({r});
    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: submap: " << m << "\n");
    IRRewriter rewriter(builder);
    OpFoldResult offset = makeComposedFoldedAffineApply(rewriter, loc, m, lbs);
    sliceParams.offsets.push_back(offset);
    OpFoldResult closedIntSize =
        makeComposedFoldedAffineApply(rewriter, loc, m, subShapeSizes);
    // The resulting size needs to be made a half-open interval again.
    AffineExpr s0 = getAffineSymbolExpr(0, builder.getContext());
    OpFoldResult size =
        makeComposedFoldedAffineApply(rewriter, loc, s0 + 1, closedIntSize);
    LLVM_DEBUG(llvm::dbgs()
               << "computeSliceParameters: raw size: " << size << "\n");
    LLVM_DEBUG(llvm::dbgs()
               << "computeSliceParameters: new offset: " << offset << "\n");
    sliceParams.strides.push_back(builder.getIndexAttr(1));

    if (omitPartialTileCheck) {
      // We statically know that the partial/boundary tile condition is
      // unnecessary.
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
      sliceParams.sizes.push_back(size);
      continue;
    }

    // The size of the subview / extract_slice should be trimmed to avoid
    // out-of-bounds accesses, unless:
    // a. We statically know the subshape size divides the shape size evenly.
    // b. The subshape size is 1. According to the way the loops are set up,
    //    tensors with "0" dimensions would never be constructed.
    int64_t shapeSize = shape[r];
    std::optional<int64_t> sizeCst = getConstantIntValue(size);
    auto hasTileSizeOne = sizeCst && *sizeCst == 1;
    auto dividesEvenly = sizeCst && !ShapedType::isDynamic(shapeSize) &&
                         ((shapeSize % *sizeCst) == 0);
    if (!hasTileSizeOne && !dividesEvenly) {
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: shapeSize=" << shapeSize
                              << ", size: " << size
                              << ": make sure in bound with affine.min\n");

      AffineExpr dim0, dim1, dim2;
      bindDims(builder.getContext(), dim0, dim1, dim2);

      // Get the dimension size for this dimension. We need to first compute
      // the max index and then add one. This is important because for
      // convolution ops, the affine map of an input window dimension has the
      // form `(d0 * s0 + d1)`, where `d0`/`d1` is an output/filter window
      // dimension and `s0` is the stride.
      // Directly using the dimension size of the output/filter window
      // dimensions would give an incorrect result.
      AffineMap minusOneMap =
          AffineMap::inferFromExprList({ArrayRef<AffineExpr>{dim0 - 1}})
              .front();
      AffineMap plusOneMap =
          AffineMap::inferFromExprList({ArrayRef<AffineExpr>{dim0 + 1}})
              .front();
      SmallVector<OpFoldResult> maxIndices =
          llvm::to_vector(llvm::map_range(ubs, [&](OpFoldResult ub) {
            return makeComposedFoldedAffineApply(rewriter, loc, minusOneMap,
                                                 {ub});
          }));
      OpFoldResult maxIndex =
          makeComposedFoldedAffineApply(rewriter, loc, m, maxIndices);
      OpFoldResult d =
          makeComposedFoldedAffineApply(rewriter, loc, plusOneMap, {maxIndex});

      // Compute min(dim - offset, size) to avoid out-of-bounds accesses.
      AffineMap minMap = AffineMap::inferFromExprList(
                             {ArrayRef<AffineExpr>{dim1 - dim2, dim0}})
                             .front();
      size =
          makeComposedFoldedAffineMin(rewriter, loc, minMap, {size, d, offset});
    }
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
    sliceParams.sizes.push_back(size);
  }
  return sliceParams;
}

SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
                                             ArrayRef<OpFoldResult> ivs,
                                             ArrayRef<OpFoldResult> tileSizes) {
  SmallVector<OpFoldResult> offsets;
  for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
    bool isTiled = !isZeroIndex(tileSizes[idx]);
    offsets.push_back(isTiled ? ivs[idxIvs++] : b.getIndexAttr(0));
    LLVM_DEBUG(llvm::dbgs()
               << "computeTileOffsets: " << offsets.back() << "\n");
  }
  return offsets;
}

SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc,
                                           ArrayRef<OpFoldResult> tileSizes,
                                           ArrayRef<OpFoldResult> sizeBounds) {
  SmallVector<OpFoldResult> sizes;
  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
    bool isTiled = !isZeroIndex(tileSizes[idx]);
    // Before composing, we need to make the range a closed interval.
    OpFoldResult size = isTiled ? tileSizes[idx] : sizeBounds[idx];
    AffineExpr d0 = getAffineDimExpr(0, b.getContext());
    IRRewriter rewriter(b);
    sizes.push_back(makeComposedFoldedAffineApply(rewriter, loc, d0 - 1, size));
    LLVM_DEBUG(llvm::dbgs() << "computeTileSizes: " << sizes.back() << "\n");
  }
  return sizes;
}

SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
  if (op.hasBufferSemantics())
    return {};
  return llvm::to_vector(
      llvm::map_range(op.getDpsInitOperands(), [&](OpOperand *opOperand) {
        return operands[opOperand->getOperandNumber()].getType();
      }));
}

SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
                                    LinalgOp op, ValueRange operands,
                                    ValueRange results) {
  if (op.hasBufferSemantics())
    return {};
  SmallVector<Value> tensorResults;
  tensorResults.reserve(results.size());
  // Insert an insert_slice for each output tensor.
  unsigned resultIdx = 0;
  for (OpOperand *opOperand : op.getDpsInitOperands()) {
    // TODO: use an interface/adaptor to avoid leaking position in
    // `tiledOperands`.
    Value outputTensor = operands[opOperand->getOperandNumber()];
    if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
      Value inserted = builder.create<tensor::InsertSliceOp>(
          loc, sliceOp.getSource().getType(), results[resultIdx],
          sliceOp.getSource(), sliceOp.getOffsets(), sliceOp.getSizes(),
          sliceOp.getStrides(), sliceOp.getStaticOffsets(),
          sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
      tensorResults.push_back(inserted);
    } else {
      tensorResults.push_back(results[resultIdx]);
    }
    ++resultIdx;
  }
  return tensorResults;
}

SmallVector<std::optional<SliceParameters>>
computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
                          ValueRange valuesToTile, ArrayRef<OpFoldResult> ivs,
                          ArrayRef<OpFoldResult> tileSizes,
                          ArrayRef<OpFoldResult> sizeBounds,
                          bool omitPartialTileCheck) {
  assert(ivs.size() == static_cast<size_t>(llvm::count_if(
                           llvm::make_range(tileSizes.begin(), tileSizes.end()),
                           [](OpFoldResult v) { return !isZeroIndex(v); })) &&
         "expected as many ivs as non-zero sizes");

  // Construct (potentially temporary) mins and maxes on which to apply maps
  // that define tile subshapes.
  SmallVector<OpFoldResult> lbs =
      computeTileOffsets(builder, loc, ivs, tileSizes);
  SmallVector<OpFoldResult> subShapeSizes =
      computeTileSizes(builder, loc, tileSizes, sizeBounds);

  assert(static_cast<int64_t>(valuesToTile.size()) <=
             linalgOp->getNumOperands() &&
         "more values to tile than operands");
  SmallVector<std::optional<SliceParameters>> allSliceParams;
  allSliceParams.reserve(valuesToTile.size());
  for (auto [opOperand, val] :
       llvm::zip(linalgOp->getOpOperands(), valuesToTile)) {
    Value shapedOp = val;
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
    AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
    // Use `opOperand` as is if it is not tiled and not an output tensor.
    // Having an extract/insert slice pair for all output tensors simplifies
    // follow-up transformations such as padding and bufferization, since the
    // extract/insert slice pairs make the accessed iteration argument
    // subdomains explicit.
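    // For instance (sketch), a tiled output tensor typically appears as:
    //   %slice  = tensor.extract_slice %out[%offs] [%sizes] [1] ...
    //   %tiled  = <tiled linalg op> ... outs(%slice : ...)
    //   %result = tensor.insert_slice %tiled into %out[%offs] [%sizes] [1] ...
    // which makes the written subdomain of %out explicit.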

    Type operandType = opOperand.get().getType();
    if (!isTiled(map, tileSizes) && !(isa<RankedTensorType>(operandType) &&
                                      linalgOp.isDpsInit(&opOperand))) {
      allSliceParams.push_back(std::nullopt);
      LLVM_DEBUG(llvm::dbgs()
                 << ": not tiled: use shape: " << operandType << "\n");
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");

    allSliceParams.push_back(computeSliceParameters(
        builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
        omitPartialTileCheck));
  }

  return allSliceParams;
}

SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc,
                                   LinalgOp linalgOp, ValueRange valuesToTile,
                                   ArrayRef<OpFoldResult> ivs,
                                   ArrayRef<OpFoldResult> tileSizes,
                                   ArrayRef<OpFoldResult> sizeBounds,
                                   bool omitPartialTileCheck) {
  SmallVector<std::optional<SliceParameters>> allSliceParameter =
      computeAllSliceParameters(builder, loc, linalgOp, valuesToTile, ivs,
                                tileSizes, sizeBounds, omitPartialTileCheck);
  SmallVector<Value> tiledShapes;
  for (auto item : llvm::zip(valuesToTile, allSliceParameter)) {
    Value valueToTile = std::get<0>(item);
    std::optional<SliceParameters> sliceParams = std::get<1>(item);
    tiledShapes.push_back(
        sliceParams.has_value()
            ? materializeTiledShape(builder, loc, valueToTile, *sliceParams)
            : valueToTile);
  }
  return tiledShapes;
}

void offsetIndices(OpBuilder &b, LinalgOp linalgOp,
                   ArrayRef<OpFoldResult> offsets) {
  IRRewriter rewriter(b);
  offsetIndices(rewriter, linalgOp, offsets);
}

void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
                   ArrayRef<OpFoldResult> offsets) {
  if (!linalgOp.hasIndexSemantics())
    return;

  for (IndexOp indexOp : linalgOp.getBlock()->getOps<IndexOp>()) {
    if (indexOp.getDim() >= offsets.size() || !offsets[indexOp.getDim()])
      continue;
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointAfter(indexOp);
    AffineExpr index, offset;
    bindDims(b.getContext(), index, offset);
    OpFoldResult applied = makeComposedFoldedAffineApply(
        b, indexOp.getLoc(), index + offset,
        {getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]});
    Value materialized =
        getValueOrCreateConstantIndexOp(b, indexOp.getLoc(), applied);
    b.replaceOpWithIf(indexOp, materialized, [&](OpOperand &use) {
      return use.getOwner() != materialized.getDefiningOp();
    });
  }
}

/// Get the reassociation maps to fold the result of an extract_slice (or the
/// source of an insert_slice) operation with the given offsets and sizes to
/// its rank-reduced version. This is only done for the cases where the size
/// is 1 and the offset is 0. Strictly speaking the offset 0 is not required
/// in general, but non-zero offsets are not handled by the SPIR-V backend at
/// this point (and potentially cannot be handled).
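/// For example (illustrative), mixedSizes = [1, 4, 1, 8] yields the
/// reassociation [[0, 1], [2, 3]]: each leading unit dimension is folded into
/// the following non-unit dimension, and trailing unit dimensions are folded
/// into the last group (e.g. [4, 1] -> [[0, 1]]).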
std::optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<ReassociationIndices> reassociation;
  ReassociationIndices curr;
  for (const auto &it : llvm::enumerate(mixedSizes)) {
    auto dim = it.index();
    auto size = it.value();
    curr.push_back(dim);
    auto attr = size.dyn_cast<Attribute>();
    if (attr && cast<IntegerAttr>(attr).getInt() == 1)
      continue;
    reassociation.emplace_back(ReassociationIndices{});
    std::swap(reassociation.back(), curr);
  }
  // When the reassociations are not empty, fold the remaining unit dimensions
  // into the last dimension. If the reassociations so far are empty, leave
  // them empty; this will fold everything to a rank-0 tensor.
  if (!curr.empty() && !reassociation.empty())
    reassociation.back().append(curr.begin(), curr.end());
  return reassociation;
}

/// Return the identity numeric value associated with the given op.
std::optional<TypedAttr> getNeutralElement(Operation *op) {
  // Builder only used as helper for attribute creation.
  OpBuilder b(op->getContext());
  Type resultType = op->getResult(0).getType();
  if (auto floatType = dyn_cast<FloatType>(resultType)) {
    const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
    if (isa<arith::AddFOp>(op))
      return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
    if (isa<arith::MulFOp>(op))
      return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
    if (isa<arith::MaxFOp>(op))
      return b.getFloatAttr(resultType,
                            llvm::APFloat::getInf(semantic, /*Negative=*/true));
    if (isa<arith::MinFOp>(op))
      return b.getFloatAttr(
          resultType, llvm::APFloat::getInf(semantic, /*Negative=*/false));
    return std::nullopt;
  }
  if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp>(op))
    return b.getIntegerAttr(resultType, 0);
  if (isa<arith::AndIOp>(op))
    return b.getIntegerAttr(resultType, -1);
  if (isa<arith::MaxSIOp>(op))
    return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::min());
  if (isa<arith::MinSIOp>(op))
    return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::max());
  if (isa<arith::MulIOp>(op))
    return b.getIntegerAttr(resultType, 1);
  return std::nullopt;
}

} // namespace linalg
} // namespace mlir