//===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the Linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Utils/Utils.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>

#define DEBUG_TYPE "linalg-utils"

using namespace mlir;
using namespace presburger;
using namespace mlir::linalg;
using namespace mlir::scf;

namespace {

// Helper visitor to determine whether an AffineExpr is tiled.
// This is achieved by traversing every AffineDimExpr with position `pos` and
// checking whether the corresponding `tileSizes[pos]` is non-zero.
// This also enforces that only positive coefficients occur in multiplications.
//
// Example:
//   `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
//
struct TileCheck : public AffineExprVisitor<TileCheck> {
  TileCheck(ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}

  void visitDimExpr(AffineDimExpr expr) {
    isTiled |= !isZeroIndex(tileSizes[expr.getPosition()]);
  }
  void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
    visit(expr.getLHS());
    visit(expr.getRHS());
    if (expr.getKind() == mlir::AffineExprKind::Mul)
      assert(expr.getRHS().cast<AffineConstantExpr>().getValue() > 0 &&
             "nonpositive multiplying coefficient");
  }
  bool isTiled = false;
  ArrayRef<OpFoldResult> tileSizes;
};

} // namespace

static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes) {
  if (!expr)
    return false;
  TileCheck t(tileSizes);
  t.visit(expr);
  return t.isTiled;
}

// Checks whether `map` varies with respect to a non-zero `tileSize`.
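// Illustrative example (not part of the original code): with tile sizes
// [8, 0], the map (d0, d1) -> (d0) is considered tiled because d0 has a
// non-zero tile size, whereas (d0, d1) -> (d1) is not.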
static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes) {
  if (!map)
    return false;
  for (unsigned r = 0; r < map.getNumResults(); ++r)
    if (isTiled(map.getResult(r), tileSizes))
      return true;
  return false;
}

std::optional<RegionMatcher::BinaryOpKind>
RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
  auto &region = op.getRegion();
  if (!llvm::hasSingleElement(region))
    return std::nullopt;

  Block &block = region.front();
  if (block.getNumArguments() != 2 ||
      !block.getArgument(0).getType().isSignlessIntOrFloat() ||
      !block.getArgument(1).getType().isSignlessIntOrFloat())
    return std::nullopt;

  auto &ops = block.getOperations();
  if (!llvm::hasSingleElement(block.without_terminator()))
    return std::nullopt;

  using mlir::matchers::m_Val;
  auto a = m_Val(block.getArgument(0));
  auto b = m_Val(block.getArgument(1));

  auto addPattern = m_Op<linalg::YieldOp>(m_Op<arith::AddIOp>(a, b));
  if (addPattern.match(&ops.back()))
    return BinaryOpKind::IAdd;

  return std::nullopt;
}

/// Explicit instantiation of loop nest generator for different loop types.
template struct mlir::linalg::GenerateLoopNest<scf::ForOp>;
template struct mlir::linalg::GenerateLoopNest<scf::ParallelOp>;
template struct mlir::linalg::GenerateLoopNest<AffineForOp>;

/// Given a list of subview ranges, extract individual values for lower, upper
/// bounds and steps and put them into the corresponding vectors.
static void unpackRanges(OpBuilder &builder, Location loc,
                         ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
                         SmallVectorImpl<Value> &ubs,
                         SmallVectorImpl<Value> &steps) {
  for (Range range : ranges) {
    lbs.emplace_back(
        getValueOrCreateConstantIndexOp(builder, loc, range.offset));
    ubs.emplace_back(getValueOrCreateConstantIndexOp(builder, loc, range.size));
    steps.emplace_back(
        getValueOrCreateConstantIndexOp(builder, loc, range.stride));
  }
}

//===----------------------------------------------------------------------===//
// Utilities for inferring various semantics properties of Linalg ops.
//===----------------------------------------------------------------------===//

DenseSet<int64_t> mlir::linalg::findPermutationsIndexingOperand(
    LinalgOp linalgOp, OpOperand *opOperand, utils::IteratorType iter) {
  DenseSet<int64_t> res;
  assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner");
  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
  for (AffineExpr e : indexingMap.getResults()) {
    if (auto d = e.dyn_cast<AffineDimExpr>()) {
      if (linalgOp.getIteratorTypesArray()[d.getPosition()] == iter &&
          llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
            return e.isFunctionOfDim(d.getPosition());
          }) == 1)
        res.insert(d.getPosition());
    }
  }
  return res;
}

namespace {
auto par = utils::IteratorType::parallel;
auto red = utils::IteratorType::reduction;
} // namespace

bool mlir::linalg::containsMostMinorMatmul(LinalgOp linalgOp) {
  FailureOr<EmbeddedMatmulDimsCandidates> res = inferMatmulDims(linalgOp);
  if (failed(res))
    return false;
  int64_t numLoops = linalgOp.getNumLoops();
  for (const DenseSet<int64_t> &s : {res->mPos, res->nPos, res->kPos}) {
    if (s.contains(numLoops - 3) || s.contains(numLoops - 2) ||
        s.contains(numLoops - 1))
      continue;
    return false;
  }
  return true;
}

FailureOr<EmbeddedMatmulDimsCandidates>
mlir::linalg::inferMatmulDims(LinalgOp linalgOp) {
  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
    return failure();

  DenseSet<int64_t> a = findPermutationsIndexingOperand(
      linalgOp, linalgOp.getDpsInputOperand(0), par);
  DenseSet<int64_t> b = findPermutationsIndexingOperand(
      linalgOp, linalgOp.getDpsInputOperand(1), par);
  DenseSet<int64_t> c = findPermutationsIndexingOperand(
      linalgOp, linalgOp.getDpsInitOperand(0), par);

  // (A & C) - B are the iterators involved in an outer-product along A (the
  // LHS).
  DenseSet<int64_t> ac = a;
  llvm::set_intersect(ac, c);
  llvm::set_subtract(ac, b);
  // (B & C) - A are the iterators involved in an outer-product along B (the
  // RHS).
  DenseSet<int64_t> bc = b;
  llvm::set_intersect(bc, c);
  llvm::set_subtract(bc, a);

  // Note: if we ever need them, A & B & C would be "batch" dimensions.

  // The reduction iterators common to A and B are the reduction dimensions.
  DenseSet<int64_t> ra = findPermutationsIndexingOperand(
      linalgOp, linalgOp.getDpsInputOperand(0), red);
  DenseSet<int64_t> rb = findPermutationsIndexingOperand(
      linalgOp, linalgOp.getDpsInputOperand(1), red);
  llvm::set_intersect(ra, rb);

  if (ac.empty() || bc.empty() || ra.empty())
    return failure();

  // Pick the first one in each set.
  // TODO: Better heuristic (e.g., pick dims based on a packing-based metric).
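  // Worked example (illustrative, not from the original source): for a plain
  // linalg.matmul with indexing maps (d0, d2), (d2, d1), (d0, d1) and
  // iterators [parallel, parallel, reduction], the sets computed above are
  // ac = {0} (m dims), bc = {1} (n dims), and ra = {2} (k dims).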
  return EmbeddedMatmulDimsCandidates{ac, bc, ra};
}

//===----------------------------------------------------------------------===//
// General utilities
//===----------------------------------------------------------------------===//

namespace mlir {
namespace linalg {

bool allIndexingsAreProjectedPermutation(LinalgOp op) {
  return llvm::all_of(op.getIndexingMapsArray(), [](AffineMap m) {
    return m.isProjectedPermutation(/*allowZeroInResults=*/true);
  });
}

bool hasOnlyScalarElementwiseOp(Region &r) {
  if (!llvm::hasSingleElement(r))
    return false;
  for (Operation &op : r.front()) {
    if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp,
              linalg::YieldOp, linalg::IndexOp, AffineApplyOp>(op) ||
          OpTrait::hasElementwiseMappableTraits(&op)) ||
        llvm::any_of(op.getResultTypes(),
                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
      return false;
  }
  return true;
}

bool isElementwise(LinalgOp op) {
  if (op.getNumLoops() != op.getNumParallelLoops())
    return false;

  if (!allIndexingsAreProjectedPermutation(op))
    return false;

  // TODO: relax the restrictions on indexing map.
  for (OpOperand *opOperand : op.getDpsInitOperands()) {
    if (!op.getMatchingIndexingMap(opOperand).isPermutation())
      return false;
  }
  return hasOnlyScalarElementwiseOp(op->getRegion(0));
}

bool isParallelIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::parallel;
}

bool isReductionIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::reduction;
}

/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim) {
  if (source.getType().isa<UnrankedMemRefType, MemRefType>())
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (source.getType().isa<UnrankedTensorType, RankedTensorType>())
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value source,
                               int64_t dim) {
  auto shapedType = source.getType().cast<ShapedType>();
  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
    return createOrFoldDimOp(b, loc, source, dim);
  return b.getIndexAttr(shapedType.getDimSize(dim));
}

/// Given an operation, retrieves the value of each dynamic dimension through
/// constructing the necessary DimOp operators.
SmallVector<Value, 4> getDynOperands(Location loc, Value val, OpBuilder &b) {
  SmallVector<Value, 4> dynOperands;
  auto shapedType = val.getType().cast<ShapedType>();
  for (const auto &dim : llvm::enumerate(shapedType.getShape())) {
    if (dim.value() == ShapedType::kDynamic)
      dynOperands.push_back(createOrFoldDimOp(b, loc, val, dim.index()));
  }
  return dynOperands;
}

Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                            Value source, Value pad, bool nofold) {
  // Exit if `source` is not defined by an ExtractSliceOp.
  auto sliceOp = source.getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Search the `source` use-def chain for padded LinalgOps.
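  // Illustrative sketch of the pattern being matched (names are invented, not
  // from the original source):
  //   %inner  = tensor.extract_slice %t ...
  //   %padded = tensor.pad %inner ...            (high padding only)
  //   %r      = linalg.generic ... outs(%padded) (possibly a chain of ops)
  //   %source = tensor.extract_slice %r ...
  // If the sizes and padding value line up, %r is reused directly instead of
  // emitting a fresh tensor::PadOp around `source`.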
  Value current = sliceOp.getSource();
  while (current) {
    auto linalgOp = current.getDefiningOp<LinalgOp>();
    if (!linalgOp)
      break;
    OpResult opResult = current.cast<OpResult>();
    current = linalgOp.getDpsInitOperand(opResult.getResultNumber())->get();
  }
  auto padOp = current ? current.getDefiningOp<tensor::PadOp>() : nullptr;

  // Exit if the search fails to match a tensor::PadOp at the end of the
  // matched LinalgOp sequence.
  if (!padOp)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the padded result type does not match.
  if (sliceOp.getSource().getType() != type)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the LinalgOps are not high padded.
  if (llvm::any_of(padOp.getMixedLowPad(), [](OpFoldResult ofr) {
        return getConstantIntValue(ofr) != static_cast<int64_t>(0);
      }))
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if `padOpSliceOp`, which defines the slice used by `padOp`, is
  // rank-reducing.
  auto padOpSliceOp = padOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
  if (!padOpSliceOp ||
      sliceOp.getMixedSizes().size() != padOpSliceOp.getMixedSizes().size())
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the sizes of the slice extracted by `sliceOp` do not match the
  // sizes of the slice padded by `padOp`.
  if (llvm::any_of(
          llvm::zip(sliceOp.getMixedSizes(), padOpSliceOp.getMixedSizes()),
          [](std::tuple<OpFoldResult, OpFoldResult> it) {
            return !isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it));
          }))
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the padding values do not match.
  Attribute padOpPadAttr, padAttr;
  Value padOpPad = padOp.getConstantPaddingValue();
  if (!padOpPad || !matchPattern(padOpPad, m_Constant(&padOpPadAttr)) ||
      !matchPattern(pad, m_Constant(&padAttr)) || padOpPadAttr != padAttr)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Return the padded result if the padding values and sizes match.
  return sliceOp.getSource();
}

GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
                          Value outputTensor,
                          ArrayRef<int64_t> transposeVector) {
  auto resultTensorType = outputTensor.getType().cast<RankedTensorType>();
  Type elementType = resultTensorType.getElementType();

  assert(isPermutationVector(transposeVector) &&
         "expect transpose vector to be a permutation");
  assert(transposeVector.size() ==
             static_cast<size_t>(resultTensorType.getRank()) &&
         "expect transpose vector size to match result tensor rank");

  // Compute the transpose and the identity indexing maps.
  SmallVector<AffineMap> indexingMaps = {
      inversePermutation(AffineMap::getPermutationMap(
          SmallVector<unsigned>(transposeVector.begin(), transposeVector.end()),
          b.getContext())),
      AffineMap::getMultiDimIdentityMap(transposeVector.size(),
                                        b.getContext())};
  SmallVector<utils::IteratorType> iteratorTypes(transposeVector.size(),
                                                 utils::IteratorType::parallel);

  // Create a GenericOp to transpose `inputTensor` into `outputTensor`.
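  // Illustrative example (not from the original source): for
  // transposeVector = [1, 0] the maps computed above are (d0, d1) -> (d1, d0)
  // for the input and the identity map for the output, i.e. the generated
  // generic writes output[i, j] = input[j, i].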
  auto transposeOp =
      b.create<GenericOp>(loc, resultTensorType, inputTensor, outputTensor,
                          indexingMaps, iteratorTypes);
  Region &body = transposeOp.getRegion();
  body.push_back(new Block());
  body.front().addArguments({elementType, elementType}, {loc, loc});

  // Create the body of the transpose operation.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPointToEnd(&body.front());
  b.create<YieldOp>(loc, transposeOp.getRegion().front().getArgument(0));
  return transposeOp;
}

GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
  auto memrefTypeTo = to.getType().cast<MemRefType>();
#ifndef NDEBUG
  auto memrefTypeFrom = from.getType().cast<MemRefType>();
  assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() &&
         "`from` and `to` memref must have the same rank");
#endif // NDEBUG

  AffineMap id =
      AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
  SmallVector<utils::IteratorType> iteratorTypes(memrefTypeTo.getRank(),
                                                 utils::IteratorType::parallel);
  return b.create<linalg::GenericOp>(
      loc,
      /*inputs=*/from,
      /*outputs=*/to,
      /*indexingMaps=*/llvm::ArrayRef({id, id}),
      /*iteratorTypes=*/iteratorTypes,
      [](OpBuilder &b, Location loc, ValueRange args) {
        b.create<linalg::YieldOp>(loc, args.front());
      });
}

/// Specialization to build an scf "for" nest.
template <>
void GenerateLoopNest<scf::ForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> procInfo) {
  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
         "expected as many entries for proc info as number of loops, even if "
         "they are null entries");
  SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
                                             ? SmallVector<Value>{}
                                             : linalgOp.getDpsInitOperands();

  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
  LoopNest loopNest = mlir::scf::buildLoopNest(
      b, loc, lbs, ubs, steps, iterArgInitValues,
      [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) {
        assert(iterArgs.size() == iterArgInitValues.size() &&
               "expect the number of output tensors and iter args to match");
        SmallVector<Value> operandValuesToUse = linalgOp->getOperands();
        if (!iterArgs.empty()) {
          operandValuesToUse = linalgOp.getDpsInputOperands();
          operandValuesToUse.append(iterArgs.begin(), iterArgs.end());
        }
        return bodyBuilderFn(b, loc, ivs, operandValuesToUse);
      });

  if (loopNest.loops.empty() || procInfo.empty())
    return;

  // Filter out scf.for loops that were created out of parallel dimensions.
  for (const auto &loop : llvm::enumerate(loopNest.loops)) {
    if (procInfo[loop.index()].distributionMethod ==
        DistributionMethod::Cyclic) {
      mapLoopToProcessorIds(loop.value(), procInfo[loop.index()].procId,
                            procInfo[loop.index()].nprocs);
    }
  }
}

/// Specialization to build affine "for" nest.
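/// Note (added clarification, based on the assertions below): this path
/// expects buffer semantics (no tensor iter_args) and constant loop steps,
/// e.g. strides produced by arith.constant; a dynamically computed stride
/// trips the "Affine loops require constant steps" assertion.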
template <>
void GenerateLoopNest<AffineForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> /*procInfo*/) {
  SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
                                             ? SmallVector<Value>{}
                                             : linalgOp.getDpsInitOperands();
  assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(b, loc, loopRanges, lbs, ubs, steps);

  // Affine loops require constant steps.
  SmallVector<int64_t, 4> constantSteps;
  constantSteps.reserve(steps.size());
  for (Value v : steps) {
    auto op = v.getDefiningOp<arith::ConstantIndexOp>();
    assert(op && "Affine loops require constant steps");
    constantSteps.push_back(op.value());
  }

  mlir::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps,
                            [&](OpBuilder &b, Location loc, ValueRange ivs) {
                              bodyBuilderFn(b, loc, ivs,
                                            linalgOp->getOperands());
                            });
}

/// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and
/// `step`.
void updateBoundsForCyclicDistribution(OpBuilder &b, Location loc, Value procId,
                                       Value nprocs, Value &lb, Value &ub,
                                       Value &step) {
  AffineExpr d0, d1;
  bindDims(b.getContext(), d0, d1);
  AffineExpr s0 = getAffineSymbolExpr(0, b.getContext());
  lb = makeComposedAffineApply(b, loc, d0 + d1 * s0, {lb, procId, step});
  step = makeComposedAffineApply(b, loc, d0 * s0, {nprocs, step});
}

/// Generates a loop nest consisting of scf.parallel and scf.for, depending
/// on the `iteratorTypes`. Consecutive parallel loops create a single
/// scf.parallel operation; each sequential loop creates a new scf.for
/// operation. The body of the innermost loop is populated by
/// `bodyBuilderFn` that accepts a range of induction variables for all
/// loops. `ivStorage` is used to store the partial list of induction
/// variables.
// TODO: this function can be made iterative instead. However, it
// will have at most as many recursive calls as nested loops, which rarely
// exceeds 10.
static void generateParallelLoopNest(
    OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
    ValueRange steps, ArrayRef<utils::IteratorType> iteratorTypes,
    ArrayRef<linalg::ProcInfo> procInfo,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
    SmallVectorImpl<Value> &ivStorage) {
  assert(lbs.size() == ubs.size());
  assert(lbs.size() == steps.size());
  assert(lbs.size() == iteratorTypes.size());
  assert(procInfo.empty() || (lbs.size() == procInfo.size()));

  // If there are no (more) loops to be generated, generate the body and be
  // done with it.
  if (iteratorTypes.empty()) {
    bodyBuilderFn(b, loc, ivStorage);
    return;
  }

  // If the outermost remaining loop is sequential, generate one scf.for loop
  // and recurse.
  if (!isParallelIterator(iteratorTypes.front())) {
    LoopNest singleLoop = buildLoopNest(
        b, loc, lbs.take_front(), ubs.take_front(), steps.take_front(),
        [&](OpBuilder &b, Location loc, ValueRange ivs) {
          ivStorage.append(ivs.begin(), ivs.end());
          generateParallelLoopNest(
              b, loc, lbs.drop_front(), ubs.drop_front(), steps.drop_front(),
              iteratorTypes.drop_front(),
              procInfo.empty() ? procInfo : procInfo.drop_front(),
              bodyBuilderFn, ivStorage);
        });
    return;
  }

  unsigned nLoops = iteratorTypes.size();
  unsigned numProcessed = 0;
  DistributionMethod distributionMethod = DistributionMethod::None;
  if (procInfo.empty()) {
    numProcessed = nLoops - iteratorTypes.drop_while(isParallelIterator).size();
  } else {
    distributionMethod = procInfo.front().distributionMethod;
    numProcessed =
        nLoops - procInfo
                     .drop_while([&](linalg::ProcInfo p) {
                       return p.distributionMethod == distributionMethod;
                     })
                     .size();
  }

  auto remainderProcInfo =
      procInfo.empty() ? procInfo : procInfo.drop_front(numProcessed);
  switch (distributionMethod) {
  case DistributionMethod::None: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
        steps.take_front(numProcessed),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
              iteratorTypes.drop_front(numProcessed), remainderProcInfo,
              bodyBuilderFn, ivStorage);
        });
    return;
  }
  case DistributionMethod::Cyclic: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
        steps.take_front(numProcessed),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
              iteratorTypes.drop_front(numProcessed), remainderProcInfo,
              bodyBuilderFn, ivStorage);
        });
    return;
  }
  case DistributionMethod::CyclicNumProcsGeNumIters: {
    // Check (for the processed loops) that the iteration is in-bounds.
    ArithBuilder ab(b, loc);
    Value cond = ab.slt(lbs[0], ubs[0]);
    for (unsigned i = 1; i < numProcessed; ++i)
      cond = ab._and(cond, ab.slt(lbs[i], ubs[i]));
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    b.create<scf::IfOp>(loc, cond, [&](OpBuilder &b, Location loc) {
      generateParallelLoopNest(b, loc, lbs.drop_front(numProcessed),
                               ubs.drop_front(numProcessed),
                               steps.drop_front(numProcessed),
                               iteratorTypes.drop_front(numProcessed),
                               remainderProcInfo, bodyBuilderFn, ivStorage);
      b.create<scf::YieldOp>(loc, ValueRange{});
    });
    return;
  }
  case DistributionMethod::CyclicNumProcsEqNumIters:
    // No check/loops needed here. Set the `%iv` to be the `%lb` and proceed
    // with inner loop generation.
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    generateParallelLoopNest(
        b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
        steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed),
        remainderProcInfo, bodyBuilderFn, ivStorage);
    return;
  }
}

/// Specialization for generating a mix of parallel and sequential scf loops.
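/// Illustrative example (not from the original source): without distribution
/// info, iterator types [parallel, parallel, reduction, parallel] produce an
/// scf.parallel over the first two loops, containing an scf.for for the
/// reduction, which in turn contains another scf.parallel for the trailing
/// parallel loop.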
template <>
void GenerateLoopNest<scf::ParallelOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> procInfo) {
  SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
                                             ? SmallVector<Value>{}
                                             : linalgOp.getDpsInitOperands();
  assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
  // This function may be passed more iterator types than ranges.
  assert(iteratorTypes.size() >= loopRanges.size() &&
         "expected iterator type for all ranges");
  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
         "expected proc information for all loops when present");
  iteratorTypes = iteratorTypes.take_front(loopRanges.size());
  SmallVector<Value, 8> lbsStorage, ubsStorage, stepsStorage, ivs;
  unsigned numLoops = iteratorTypes.size();
  ivs.reserve(numLoops);
  lbsStorage.reserve(numLoops);
  ubsStorage.reserve(numLoops);
  stepsStorage.reserve(numLoops);

  // Get the loop lb, ub, and step.
  unpackRanges(b, loc, loopRanges, lbsStorage, ubsStorage, stepsStorage);

  // Modify the lb, ub, and step based on the distribution options.
  for (const auto &it : llvm::enumerate(procInfo)) {
    if (it.value().distributionMethod != linalg::DistributionMethod::None) {
      updateBoundsForCyclicDistribution(
          b, loc, it.value().procId, it.value().nprocs, lbsStorage[it.index()],
          ubsStorage[it.index()], stepsStorage[it.index()]);
    }
  }
  ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
  generateParallelLoopNest(
      b, loc, lbs, ubs, steps, iteratorTypes, procInfo,
      [&](OpBuilder &b, Location loc, ValueRange ivs) {
        bodyBuilderFn(b, loc, ivs, linalgOp->getOperands());
      },
      ivs);

  assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
}

static Value materializeTiledShape(OpBuilder &builder, Location loc,
                                   Value valueToTile,
                                   const SliceParameters &sliceParams) {
  auto shapedType = valueToTile.getType().dyn_cast<ShapedType>();
  auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
                      .Case([&](MemRefType) {
                        return builder.create<memref::SubViewOp>(
                            loc, valueToTile, sliceParams.offsets,
                            sliceParams.sizes, sliceParams.strides);
                      })
                      .Case([&](RankedTensorType) {
                        return builder.create<tensor::ExtractSliceOp>(
                            loc, valueToTile, sliceParams.offsets,
                            sliceParams.sizes, sliceParams.strides);
                      })
                      .Default([](ShapedType) -> Operation * {
                        llvm_unreachable("Unexpected shaped type");
                      });
  return sliceOp->getResult(0);
}

Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
                     ArrayRef<OpFoldResult> tileSizes, AffineMap map,
                     ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
                     ArrayRef<OpFoldResult> subShapeSizes,
                     bool omitPartialTileCheck) {
  SliceParameters sliceParams =
      computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
                             ubs, subShapeSizes, omitPartialTileCheck);
  return materializeTiledShape(builder, loc, valueToTile, sliceParams);
}

SliceParameters
computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
                       ArrayRef<OpFoldResult> tileSizes, AffineMap map,
                       ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
                       ArrayRef<OpFoldResult> subShapeSizes,
                       bool omitPartialTileCheck) {
  auto shapedType = valueToTile.getType().dyn_cast<ShapedType>();
  assert(shapedType && "only shaped types can be tiled");
  ArrayRef<int64_t> shape = shapedType.getShape();
  int64_t rank = shapedType.getRank();

  // Compute offsets/sizes/strides for the tile.
  SliceParameters sliceParams;
  sliceParams.offsets.reserve(rank);
  sliceParams.sizes.reserve(rank);
  sliceParams.strides.reserve(rank);
  for (unsigned r = 0; r < rank; ++r) {
    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: for dim#" << r);
    if (!isTiled(map.getSubMap({r}), tileSizes)) {
      sliceParams.offsets.push_back(builder.getIndexAttr(0));
      OpFoldResult dim = createFoldedDimOp(builder, loc, valueToTile, r);
      sliceParams.sizes.push_back(dim);
      sliceParams.strides.push_back(builder.getIndexAttr(1));
      LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subsize...\n");

    // Tiling creates a new slice at the proper index, the slice step is 1
    // (i.e. the op does not subsample, stepping occurs in the loop).
    auto m = map.getSubMap({r});
    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: submap: " << m << "\n");
    IRRewriter rewriter(builder);
    OpFoldResult offset = makeComposedFoldedAffineApply(rewriter, loc, m, lbs);
    sliceParams.offsets.push_back(offset);
    OpFoldResult closedIntSize =
        makeComposedFoldedAffineApply(rewriter, loc, m, subShapeSizes);
    // Resulting size needs to be made half open interval again.
    AffineExpr s0 = getAffineSymbolExpr(0, builder.getContext());
    OpFoldResult size =
        makeComposedFoldedAffineApply(rewriter, loc, s0 + 1, closedIntSize);
    LLVM_DEBUG(llvm::dbgs()
               << "computeSliceParameters: raw size: " << size << "\n");
    LLVM_DEBUG(llvm::dbgs()
               << "computeSliceParameters: new offset: " << offset << "\n");
    sliceParams.strides.push_back(builder.getIndexAttr(1));

    if (omitPartialTileCheck) {
      // We statically know that the partial/boundary tile condition is
      // unnecessary.
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
      sliceParams.sizes.push_back(size);
      continue;
    }

    // The size of the subview / extract_slice should be trimmed to avoid
    // out-of-bounds accesses, unless:
    // a. We statically know the subshape size divides the shape size evenly.
    // b. The subshape size is 1. According to the way the loops are set up,
    //    tensors with "0" dimensions would never be constructed.
    int64_t shapeSize = shape[r];
    std::optional<int64_t> sizeCst = getConstantIntValue(size);
    auto hasTileSizeOne = sizeCst && *sizeCst == 1;
    auto dividesEvenly = sizeCst && !ShapedType::isDynamic(shapeSize) &&
                         ((shapeSize % *sizeCst) == 0);
    if (!hasTileSizeOne && !dividesEvenly) {
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: shapeSize=" << shapeSize
                              << ", size: " << size
                              << ": make sure in bound with affine.min\n");

      AffineExpr dim0, dim1, dim2;
      bindDims(builder.getContext(), dim0, dim1, dim2);

      // Get the dimension size for this dimension. We need to first compute
      // the maximum accessed index and then add one. This is important
      // because for convolution ops, the input window dimension's affine map
      // is of the form `(d0 * s0 + d1)`, where `d0`/`d1` is an output/filter
      // window dimension and `s0` is the stride.
      // Directly using the dimension size of the output/filter window
      // dimensions would lead to an incorrect calculation.
      AffineMap minusOneMap =
          AffineMap::inferFromExprList({ArrayRef<AffineExpr>{dim0 - 1}})
              .front();
      AffineMap plusOneMap =
          AffineMap::inferFromExprList({ArrayRef<AffineExpr>{dim0 + 1}})
              .front();
      SmallVector<OpFoldResult> maxIndices =
          llvm::to_vector(llvm::map_range(ubs, [&](OpFoldResult ub) {
            return makeComposedFoldedAffineApply(rewriter, loc, minusOneMap,
                                                 {ub});
          }));
      OpFoldResult maxIndex =
          makeComposedFoldedAffineApply(rewriter, loc, m, maxIndices);
      OpFoldResult d =
          makeComposedFoldedAffineApply(rewriter, loc, plusOneMap, {maxIndex});

      // Compute min(dim - offset, size) to avoid out-of-bounds accesses.
      AffineMap minMap = AffineMap::inferFromExprList(
                             {ArrayRef<AffineExpr>{dim1 - dim2, dim0}})
                             .front();
      size =
          makeComposedFoldedAffineMin(rewriter, loc, minMap, {size, d, offset});
    }
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
    sliceParams.sizes.push_back(size);
  }
  return sliceParams;
}

SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
                                             ArrayRef<OpFoldResult> ivs,
                                             ArrayRef<OpFoldResult> tileSizes) {
  SmallVector<OpFoldResult> offsets;
  for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
    bool isTiled = !isZeroIndex(tileSizes[idx]);
    offsets.push_back(isTiled ? ivs[idxIvs++] : b.getIndexAttr(0));
    LLVM_DEBUG(llvm::dbgs()
               << "computeTileOffsets: " << offsets.back() << "\n");
  }
  return offsets;
}

SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc,
                                           ArrayRef<OpFoldResult> tileSizes,
                                           ArrayRef<OpFoldResult> sizeBounds) {
  SmallVector<OpFoldResult> sizes;
  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
    bool isTiled = !isZeroIndex(tileSizes[idx]);
    // Before composing, we need to make the range a closed interval.
    OpFoldResult size = isTiled ? tileSizes[idx] : sizeBounds[idx];
    AffineExpr d0 = getAffineDimExpr(0, b.getContext());
    IRRewriter rewriter(b);
    sizes.push_back(makeComposedFoldedAffineApply(rewriter, loc, d0 - 1, size));
    LLVM_DEBUG(llvm::dbgs() << "computeTileSizes: " << sizes.back() << "\n");
  }
  return sizes;
}

SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
  if (op.hasBufferSemantics())
    return {};
  return llvm::to_vector(
      llvm::map_range(op.getDpsInitOperands(), [&](OpOperand *opOperand) {
        return operands[opOperand->getOperandNumber()].getType();
      }));
}

SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
                                    LinalgOp op, ValueRange operands,
                                    ValueRange results) {
  if (op.hasBufferSemantics())
    return {};
  SmallVector<Value> tensorResults;
  tensorResults.reserve(results.size());
  // Insert an insert_slice for each output tensor.
  unsigned resultIdx = 0;
  for (OpOperand *opOperand : op.getDpsInitOperands()) {
    // TODO: use an interface/adaptor to avoid leaking position in
    // `tiledOperands`.
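    // Illustrative sketch (names invented, not from the original source): if
    // the tiled output operand was produced by
    //   %tile = tensor.extract_slice %init[%off][%sz][%st]
    // then the matching tile result %res is written back as
    //   tensor.insert_slice %res into %init[%off][%sz][%st]
    // reusing the same offsets, sizes and strides.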
    Value outputTensor = operands[opOperand->getOperandNumber()];
    if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
      Value inserted = builder.create<tensor::InsertSliceOp>(
          loc, sliceOp.getSource().getType(), results[resultIdx],
          sliceOp.getSource(), sliceOp.getOffsets(), sliceOp.getSizes(),
          sliceOp.getStrides(), sliceOp.getStaticOffsets(),
          sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
      tensorResults.push_back(inserted);
    } else {
      tensorResults.push_back(results[resultIdx]);
    }
    ++resultIdx;
  }
  return tensorResults;
}

SmallVector<std::optional<SliceParameters>>
computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
                          ValueRange valuesToTile, ArrayRef<OpFoldResult> ivs,
                          ArrayRef<OpFoldResult> tileSizes,
                          ArrayRef<OpFoldResult> sizeBounds,
                          bool omitPartialTileCheck) {
  assert(ivs.size() == static_cast<size_t>(llvm::count_if(
                           llvm::make_range(tileSizes.begin(), tileSizes.end()),
                           [](OpFoldResult v) { return !isZeroIndex(v); })) &&
         "expected as many ivs as non-zero sizes");

  // Construct (potentially temporary) mins and maxes on which to apply maps
  // that define tile subshapes.
  SmallVector<OpFoldResult> lbs =
      computeTileOffsets(builder, loc, ivs, tileSizes);
  SmallVector<OpFoldResult> subShapeSizes =
      computeTileSizes(builder, loc, tileSizes, sizeBounds);

  assert(static_cast<int64_t>(valuesToTile.size()) <=
             linalgOp->getNumOperands() &&
         "more values to tile than operands.");
  SmallVector<std::optional<SliceParameters>> allSliceParams;
  allSliceParams.reserve(valuesToTile.size());
  for (auto [opOperand, val] :
       llvm::zip(linalgOp->getOpOperands(), valuesToTile)) {
    Value shapedOp = val;
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
    AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
    // Use `opOperand` as is if it is not tiled and not an output tensor.
    // Having an extract/insert slice pair for all output tensors simplifies
    // follow-up transformations such as padding and bufferization since the
    // extract/insert slice pairs make the accessed iteration argument
    // subdomains explicit.
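    // Added note (illustrative): an untiled input therefore keeps its full
    // shape (no slice parameters are computed for it), whereas an untiled
    // tensor init still receives slice parameters so the extract/insert slice
    // pair is materialized.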

    Type operandType = opOperand.get().getType();
    if (!isTiled(map, tileSizes) && !(operandType.isa<RankedTensorType>() &&
                                      linalgOp.isDpsInit(&opOperand))) {
      allSliceParams.push_back(std::nullopt);
      LLVM_DEBUG(llvm::dbgs()
                 << ": not tiled: use shape: " << operandType << "\n");
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");

    allSliceParams.push_back(computeSliceParameters(
        builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
        omitPartialTileCheck));
  }

  return allSliceParams;
}

SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc,
                                   LinalgOp linalgOp, ValueRange valuesToTile,
                                   ArrayRef<OpFoldResult> ivs,
                                   ArrayRef<OpFoldResult> tileSizes,
                                   ArrayRef<OpFoldResult> sizeBounds,
                                   bool omitPartialTileCheck) {
  SmallVector<std::optional<SliceParameters>> allSliceParameter =
      computeAllSliceParameters(builder, loc, linalgOp, valuesToTile, ivs,
                                tileSizes, sizeBounds, omitPartialTileCheck);
  SmallVector<Value> tiledShapes;
  for (auto item : llvm::zip(valuesToTile, allSliceParameter)) {
    Value valueToTile = std::get<0>(item);
    std::optional<SliceParameters> sliceParams = std::get<1>(item);
    tiledShapes.push_back(
        sliceParams.has_value()
            ? materializeTiledShape(builder, loc, valueToTile, *sliceParams)
            : valueToTile);
  }
  return tiledShapes;
}

void offsetIndices(OpBuilder &b, LinalgOp linalgOp,
                   ArrayRef<OpFoldResult> offsets) {
  IRRewriter rewriter(b);
  offsetIndices(rewriter, linalgOp, offsets);
}

void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
                   ArrayRef<OpFoldResult> offsets) {
  if (!linalgOp.hasIndexSemantics())
    return;

  for (IndexOp indexOp : linalgOp.getBlock()->getOps<IndexOp>()) {
    if (indexOp.getDim() >= offsets.size() || !offsets[indexOp.getDim()])
      continue;
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointAfter(indexOp);
    AffineExpr index, offset;
    bindDims(b.getContext(), index, offset);
    OpFoldResult applied = makeComposedFoldedAffineApply(
        b, indexOp.getLoc(), index + offset,
        {getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]});
    Value materialized =
        getValueOrCreateConstantIndexOp(b, indexOp.getLoc(), applied);
    b.replaceOpWithIf(indexOp, materialized, [&](OpOperand &use) {
      return use.getOwner() != materialized.getDefiningOp();
    });
  }
}

/// Get the reassociation maps to fold the result of an extract_slice (or
/// source of an insert_slice) operation with the given offsets and sizes to
/// its rank-reduced version. This is only done for the cases where the size
/// is 1 and the offset is 0. Strictly speaking, offset 0 is not required in
/// general, but non-zero offsets are not handled by the SPIR-V backend at
/// this point (and potentially cannot be handled).
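/// Worked example (illustrative, not from the original source): for
/// mixedSizes = [1, 4, 1, 8] the result is [[0, 1], [2, 3]]; for
/// mixedSizes = [4, 1, 1] the trailing unit dims are folded into the last
/// group, giving [[0, 1, 2]]; if every size is the static constant 1, the
/// returned reassociation is empty, folding the value to rank 0.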
std::optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<ReassociationIndices> reassociation;
  ReassociationIndices curr;
  for (const auto &it : llvm::enumerate(mixedSizes)) {
    auto dim = it.index();
    auto size = it.value();
    curr.push_back(dim);
    auto attr = size.dyn_cast<Attribute>();
    if (attr && attr.cast<IntegerAttr>().getInt() == 1)
      continue;
    reassociation.emplace_back(ReassociationIndices{});
    std::swap(reassociation.back(), curr);
  }
  // When the reassociations are not empty, fold the remaining unit dimensions
  // into the last dimension. If the reassociations so far are empty, leave
  // them empty. This will fold everything to a rank-0 tensor.
  if (!curr.empty() && !reassociation.empty())
    reassociation.back().append(curr.begin(), curr.end());
  return reassociation;
}

/// Return the identity numeric value associated with the given op.
std::optional<Attribute> getNeutralElement(Operation *op) {
  // Builder only used as helper for attribute creation.
  OpBuilder b(op->getContext());
  Type resultType = op->getResult(0).getType();
  if (auto floatType = resultType.dyn_cast<FloatType>()) {
    const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
    if (isa<arith::AddFOp>(op))
      return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
    if (isa<arith::MulFOp>(op))
      return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
    if (isa<arith::MaxFOp>(op))
      return b.getFloatAttr(resultType,
                            llvm::APFloat::getInf(semantic, /*Negative=*/true));
    if (isa<arith::MinFOp>(op))
      return b.getFloatAttr(
          resultType, llvm::APFloat::getInf(semantic, /*Negative=*/false));
    return std::nullopt;
  }
  if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp>(op))
    return b.getIntegerAttr(resultType, 0);
  if (isa<arith::AndIOp>(op))
    return b.getIntegerAttr(resultType, -1);
  if (isa<arith::MaxSIOp>(op))
    return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::min());
  if (isa<arith::MinSIOp>(op))
    return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::max());
  if (isa<arith::MulIOp>(op))
    return b.getIntegerAttr(resultType, 1);
  return std::nullopt;
}

} // namespace linalg
} // namespace mlir