//===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the Linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Utils/Utils.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>

#define DEBUG_TYPE "linalg-utils"

using namespace mlir;
using namespace presburger;
using namespace mlir::linalg;
using namespace mlir::scf;

namespace {

// Helper visitor to determine whether an AffineExpr is tiled.
// This is achieved by traversing every AffineDimExpr with position `pos` and
// checking whether the corresponding `tileSizes[pos]` is non-zero.
// This also enforces that only positive coefficients occur in multiplications.
//
// Example:
//   `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
//
struct TileCheck : public AffineExprVisitor<TileCheck> {
  TileCheck(ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}

  void visitDimExpr(AffineDimExpr expr) {
    isTiled |= !isZeroIndex(tileSizes[expr.getPosition()]);
  }
  void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
    visit(expr.getLHS());
    visit(expr.getRHS());
    if (expr.getKind() == mlir::AffineExprKind::Mul)
      assert(expr.getRHS().cast<AffineConstantExpr>().getValue() > 0 &&
             "nonpositive multiplying coefficient");
  }
  bool isTiled = false;
  ArrayRef<OpFoldResult> tileSizes;
};

} // namespace

static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes) {
  if (!expr)
    return false;
  TileCheck t(tileSizes);
  t.visit(expr);
  return t.isTiled;
}

// Checks whether the `map` varies with respect to a non-zero `tileSize`.
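// For example (illustrative), the map `(d0, d1) -> (d1)` is tiled by
// tileSizes = [0, 4] but not by [4, 0], since only `d1` appears in the map's
// results.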
static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes) {
  if (!map)
    return false;
  for (unsigned r = 0; r < map.getNumResults(); ++r)
    if (isTiled(map.getResult(r), tileSizes))
      return true;
  return false;
}

std::optional<RegionMatcher::BinaryOpKind>
RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
  auto &region = op.getRegion();
  if (!llvm::hasSingleElement(region))
    return std::nullopt;

  Block &block = region.front();
  if (block.getNumArguments() != 2 ||
      !block.getArgument(0).getType().isSignlessIntOrFloat() ||
      !block.getArgument(1).getType().isSignlessIntOrFloat())
    return std::nullopt;

  auto &ops = block.getOperations();
  if (!llvm::hasSingleElement(block.without_terminator()))
    return std::nullopt;

  using mlir::matchers::m_Val;
  auto a = m_Val(block.getArgument(0));
  auto b = m_Val(block.getArgument(1));

  auto addPattern = m_Op<linalg::YieldOp>(m_Op<arith::AddIOp>(a, b));
  if (addPattern.match(&ops.back()))
    return BinaryOpKind::IAdd;

  return std::nullopt;
}

/// Explicit instantiation of loop nest generator for different loop types.
template struct mlir::linalg::GenerateLoopNest<scf::ForOp>;
template struct mlir::linalg::GenerateLoopNest<scf::ParallelOp>;
template struct mlir::linalg::GenerateLoopNest<AffineForOp>;

/// Given a list of subview ranges, extract individual values for lower, upper
/// bounds and steps and put them into the corresponding vectors.
static void unpackRanges(OpBuilder &builder, Location loc,
                         ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
                         SmallVectorImpl<Value> &ubs,
                         SmallVectorImpl<Value> &steps) {
  for (Range range : ranges) {
    lbs.emplace_back(
        getValueOrCreateConstantIndexOp(builder, loc, range.offset));
    ubs.emplace_back(getValueOrCreateConstantIndexOp(builder, loc, range.size));
    steps.emplace_back(
        getValueOrCreateConstantIndexOp(builder, loc, range.stride));
  }
}

namespace mlir {
namespace linalg {

bool allIndexingsAreProjectedPermutation(LinalgOp op) {
  return llvm::all_of(op.getIndexingMapsArray(), [](AffineMap m) {
    return m.isProjectedPermutation(/*allowZeroInResults=*/true);
  });
}

bool hasOnlyScalarElementwiseOp(Region &r) {
  if (!llvm::hasSingleElement(r))
    return false;
  for (Operation &op : r.front()) {
    if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp,
              linalg::YieldOp, linalg::IndexOp, AffineApplyOp>(op) ||
          OpTrait::hasElementwiseMappableTraits(&op)) ||
        llvm::any_of(op.getResultTypes(),
                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
      return false;
  }
  return true;
}

bool isElementwise(LinalgOp op) {
  if (op.getNumLoops() != op.getNumParallelLoops())
    return false;

  if (!allIndexingsAreProjectedPermutation(op))
    return false;

  // TODO: relax the restrictions on indexing map.
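  // Output (DPS init) operands must additionally be indexed by a full
  // permutation so that every output element is written exactly once.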
  for (OpOperand *opOperand : op.getDpsInitOperands()) {
    if (!op.getMatchingIndexingMap(opOperand).isPermutation())
      return false;
  }
  return hasOnlyScalarElementwiseOp(op->getRegion(0));
}

bool isParallelIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::parallel;
}

bool isReductionIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::reduction;
}

/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim) {
  if (source.getType().isa<UnrankedMemRefType, MemRefType>())
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (source.getType().isa<UnrankedTensorType, RankedTensorType>())
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value source,
                               int64_t dim) {
  auto shapedType = source.getType().cast<ShapedType>();
  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
    return createOrFoldDimOp(b, loc, source, dim);
  return b.getIndexAttr(shapedType.getDimSize(dim));
}

/// Given a value, retrieves the sizes of its dynamic dimensions by
/// constructing the necessary DimOp operations.
SmallVector<Value, 4> getDynOperands(Location loc, Value val, OpBuilder &b) {
  SmallVector<Value, 4> dynOperands;
  auto shapedType = val.getType().cast<ShapedType>();
  for (const auto &dim : llvm::enumerate(shapedType.getShape())) {
    if (dim.value() == ShapedType::kDynamic)
      dynOperands.push_back(createOrFoldDimOp(b, loc, val, dim.index()));
  }
  return dynOperands;
}

void getUpperBoundForIndex(Value value, AffineMap &boundMap,
                           SmallVectorImpl<Value> &boundOperands,
                           bool constantRequired) {
  // Initialize `boundMap` and `boundOperands` to the identity returning
  // `value`. This combination is the default result of the method if no
  // simplification is possible.
  assert(value.getType().isIndex() && "expect value to have index type");
  boundMap = AffineMap::getMultiDimIdentityMap(1, value.getContext());
  boundOperands.assign({value});
  canonicalizeMapAndOperands(&boundMap, &boundOperands);

  // Continue only if there is an affine index computation to simplify.
  Operation *definingOp = value.getDefiningOp();
  if (!definingOp || !isa<AffineApplyOp, AffineMinOp>(definingOp))
    return;

  // Get the backward slice containing the affine index computation.
  SetVector<Operation *> backwardSlice;
  getBackwardSlice(definingOp, &backwardSlice, [](Operation *op) {
    return isa<AffineApplyOp, AffineMinOp>(op);
  });
  backwardSlice.insert(definingOp);

  // Setup a system of affine constraints that describe the index computation.
  FlatAffineValueConstraints constraints;

  // Helper to find or create an identifier for the given value.
  auto findOrCreateId = [&](Value value) {
    if (!constraints.containsVar(value)) {
      constraints.appendDimVar(value);
      return true;
    }
    unsigned pos;
    constraints.findVar(value, &pos);
    return pos < constraints.getNumDimVars();
  };
  // Helper to get the position for the given value.
  auto getPosition = [&](Value value) {
    unsigned pos;
    bool exists = constraints.findVar(value, &pos);
    (void)exists;
    assert(exists && "expect to find the identifier");
    return pos;
  };

  // Add the affine operations in `backwardSlice` to the constraints.
  for (Operation *op : llvm::reverse(backwardSlice)) {
    // Add an identifier for all op results and operands.
    if (!(llvm::all_of(op->getResults(), findOrCreateId) &&
          llvm::all_of(op->getOperands(), findOrCreateId)))
      return;

    // Add AffineApplyOps to the constraints.
    if (auto applyOp = dyn_cast<AffineApplyOp>(op)) {
      AffineMap map = constraints.computeAlignedMap(applyOp.getAffineMap(),
                                                    applyOp.getOperands());
      if (failed(constraints.addBound(IntegerPolyhedron::EQ,
                                      getPosition(applyOp.getResult()), map)))
        return;
      continue;
    }
    // Add AffineMinOps to the constraints.
    auto minOp = cast<AffineMinOp>(op);
    AffineMap map = constraints.computeAlignedMap(minOp.getAffineMap(),
                                                  minOp.getOperands());
    if (failed(constraints.addBound(IntegerPolyhedron::UB,
                                    getPosition(minOp.getResult()), map,
                                    /*isClosedBound=*/true)))
      return;
  }

  // Obtain an upper bound for the affine index computation by projecting out
  // all temporary results and expressing the upper bound for `value` in terms
  // of the terminals of the index computation.
  unsigned pos = getPosition(value);
  if (constantRequired) {
    auto ubConst = constraints.getConstantBound64(
        FlatAffineValueConstraints::BoundType::UB, pos);
    if (!ubConst)
      return;

    boundMap = AffineMap::getConstantMap(*ubConst, value.getContext());
    return;
  }

  SmallVector<AffineMap> lowerBounds(1), upperBounds(1);
  constraints.getSliceBounds(pos, 1, value.getContext(), &lowerBounds,
                             &upperBounds,
                             /*getClosedUB=*/true);
  // Verify `upperBounds[0]` is valid and has at least one result.
  if (!upperBounds[0] || upperBounds[0].getNumResults() == 0)
    return;

  // Set `boundMap` and `boundOperands` to the computed upper bound.
  boundMap = upperBounds[0];
  constraints.getAllValues(&boundOperands);
  erase_value(boundOperands, value);
  canonicalizeMapAndOperands(&boundMap, &boundOperands);
}

FailureOr<int64_t> getConstantUpperBoundForIndex(Value value) {
  // Compute an upper bound for `value`.
  AffineMap boundMap;
  SmallVector<Value> boundOperands;
  getUpperBoundForIndex(value, boundMap, boundOperands,
                        /*constantRequired=*/true);

  // Search the results of `boundMap` for constant upper bounds.
  SmallVector<int64_t> constantBounds;
  for (AffineExpr result : boundMap.getResults())
    if (auto constExpr = result.dyn_cast<AffineConstantExpr>())
      constantBounds.push_back(constExpr.getValue());

  // Return the minimal upper bound or failure if none is found.
  if (constantBounds.empty())
    return failure();
  return *std::min_element(constantBounds.begin(), constantBounds.end());
}

Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                            Value source, Value pad, bool nofold) {
  // Exit if `source` is not defined by an ExtractSliceOp.
  auto sliceOp = source.getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Search the `source` use-def chain for padded LinalgOps.
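  // Walk up the chain of producing LinalgOps through their matching init
  // operands; the walk stops at the first producer that is not a LinalgOp.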
  Value current = sliceOp.getSource();
  while (current) {
    auto linalgOp = current.getDefiningOp<LinalgOp>();
    if (!linalgOp)
      break;
    OpResult opResult = current.cast<OpResult>();
    current = linalgOp.getDpsInitOperand(opResult.getResultNumber())->get();
  }
  auto padOp = current ? current.getDefiningOp<tensor::PadOp>() : nullptr;

  // Exit if the search fails to match a tensor::PadOp at the end of the
  // matched LinalgOp sequence.
  if (!padOp)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the padded result type does not match.
  if (sliceOp.getSource().getType() != type)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the LinalgOps are not high padded.
  if (llvm::any_of(padOp.getMixedLowPad(), [](OpFoldResult ofr) {
        return getConstantIntValue(ofr) != static_cast<int64_t>(0);
      }))
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if `padOpSliceOp`, which defines the slice used by `padOp`, is
  // rank-reducing.
  auto padOpSliceOp = padOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
  if (!padOpSliceOp ||
      sliceOp.getMixedSizes().size() != padOpSliceOp.getMixedSizes().size())
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the sizes of `sliceOp` do not match the sizes of the slice padded
  // by `padOp`.
  if (llvm::any_of(
          llvm::zip(sliceOp.getMixedSizes(), padOpSliceOp.getMixedSizes()),
          [](std::tuple<OpFoldResult, OpFoldResult> it) {
            return !isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it));
          }))
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the padding values do not match.
  Attribute padOpPadAttr, padAttr;
  Value padOpPad = padOp.getConstantPaddingValue();
  if (!padOpPad || !matchPattern(padOpPad, m_Constant(&padOpPadAttr)) ||
      !matchPattern(pad, m_Constant(&padAttr)) || padOpPadAttr != padAttr)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Return the padded result if the padding values and sizes match.
  return sliceOp.getSource();
}

GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
                          Value outputTensor,
                          ArrayRef<int64_t> transposeVector) {
  auto resultTensorType = outputTensor.getType().cast<RankedTensorType>();
  Type elementType = resultTensorType.getElementType();

  assert(isPermutationVector(transposeVector) &&
         "expect transpose vector to be a permutation");
  assert(transposeVector.size() ==
             static_cast<size_t>(resultTensorType.getRank()) &&
         "expect transpose vector size to match result tensor rank");

  // Compute the transpose and the identity indexing maps.
  SmallVector<AffineMap> indexingMaps = {
      inversePermutation(AffineMap::getPermutationMap(
          SmallVector<unsigned>(transposeVector.begin(), transposeVector.end()),
          b.getContext())),
      AffineMap::getMultiDimIdentityMap(transposeVector.size(),
                                        b.getContext())};
  SmallVector<utils::IteratorType> iteratorTypes(transposeVector.size(),
                                                 utils::IteratorType::parallel);

  // Create a GenericOp to transpose `inputTensor` into `outputTensor`.
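  // Illustrative sketch (not verbatim IR) of the result for a 2-D input with
  // transposeVector = [1, 0]:
  //   linalg.generic {
  //       indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
  //                        affine_map<(d0, d1) -> (d0, d1)>],
  //       iterator_types = ["parallel", "parallel"]}
  //       ins(%inputTensor) outs(%outputTensor) {
  //     ^bb0(%in, %out):
  //       linalg.yield %in
  //   }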
  auto transposeOp =
      b.create<GenericOp>(loc, resultTensorType, inputTensor, outputTensor,
                          indexingMaps, iteratorTypes);
  Region &body = transposeOp.getRegion();
  body.push_back(new Block());
  body.front().addArguments({elementType, elementType}, {loc, loc});

  // Create the body of the transpose operation.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPointToEnd(&body.front());
  b.create<YieldOp>(loc, transposeOp.getRegion().front().getArgument(0));
  return transposeOp;
}

GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
  auto memrefTypeTo = to.getType().cast<MemRefType>();
#ifndef NDEBUG
  auto memrefTypeFrom = from.getType().cast<MemRefType>();
  assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() &&
         "`from` and `to` memref must have the same rank");
#endif // NDEBUG

  AffineMap id =
      AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
  SmallVector<utils::IteratorType> iteratorTypes(memrefTypeTo.getRank(),
                                                 utils::IteratorType::parallel);
  return b.create<linalg::GenericOp>(
      loc,
      /*inputs=*/from,
      /*outputs=*/to,
      /*indexingMaps=*/llvm::ArrayRef({id, id}),
      /*iteratorTypes=*/iteratorTypes,
      [](OpBuilder &b, Location loc, ValueRange args) {
        b.create<linalg::YieldOp>(loc, args.front());
      });
}

/// Specialization to build an scf "for" nest.
template <>
void GenerateLoopNest<scf::ForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> procInfo) {
  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
         "expected as many entries for proc info as number of loops, even if "
         "they are null entries");
  SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
                                             ? SmallVector<Value>{}
                                             : linalgOp.getDpsInitOperands();

  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
  LoopNest loopNest = mlir::scf::buildLoopNest(
      b, loc, lbs, ubs, steps, iterArgInitValues,
      [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) {
        assert(iterArgs.size() == iterArgInitValues.size() &&
               "expect the number of output tensors and iter args to match");
        SmallVector<Value> operandValuesToUse = linalgOp->getOperands();
        if (!iterArgs.empty()) {
          operandValuesToUse = linalgOp.getDpsInputOperands();
          operandValuesToUse.append(iterArgs.begin(), iterArgs.end());
        }
        return bodyBuilderFn(b, loc, ivs, operandValuesToUse);
      });

  if (loopNest.loops.empty() || procInfo.empty())
    return;

  // Filter out scf.for loops that were created out of parallel dimensions.
  for (const auto &loop : llvm::enumerate(loopNest.loops)) {
    if (procInfo[loop.index()].distributionMethod ==
        DistributionMethod::Cyclic) {
      mapLoopToProcessorIds(loop.value(), procInfo[loop.index()].procId,
                            procInfo[loop.index()].nprocs);
    }
  }
}

/// Specialization to build affine "for" nest.
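/// Note: this path requires buffer semantics (no iter_args) and constant loop
/// steps, as asserted below.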
template <>
void GenerateLoopNest<AffineForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> /*procInfo*/) {
  SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
                                             ? SmallVector<Value>{}
                                             : linalgOp.getDpsInitOperands();
  assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(b, loc, loopRanges, lbs, ubs, steps);

  // Affine loops require constant steps.
  SmallVector<int64_t, 4> constantSteps;
  constantSteps.reserve(steps.size());
  for (Value v : steps) {
    auto op = v.getDefiningOp<arith::ConstantIndexOp>();
    assert(op && "Affine loops require constant steps");
    constantSteps.push_back(op.value());
  }

  mlir::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps,
                            [&](OpBuilder &b, Location loc, ValueRange ivs) {
                              bodyBuilderFn(b, loc, ivs,
                                            linalgOp->getOperands());
                            });
}

/// Update the `lb`, `ub` and `step` to get per-processor `lb`, `ub` and
/// `step`.
void updateBoundsForCyclicDistribution(OpBuilder &b, Location loc, Value procId,
                                       Value nprocs, Value &lb, Value &ub,
                                       Value &step) {
  AffineExpr d0, d1;
  bindDims(b.getContext(), d0, d1);
  AffineExpr s0 = getAffineSymbolExpr(0, b.getContext());
  lb = makeComposedAffineApply(b, loc, d0 + d1 * s0, {lb, procId, step});
  step = makeComposedAffineApply(b, loc, d0 * s0, {nprocs, step});
}

/// Generates a loop nest consisting of scf.parallel and scf.for, depending
/// on the `iteratorTypes`. Consecutive parallel loops create a single
/// scf.parallel operation; each sequential loop creates a new scf.for
/// operation. The body of the innermost loop is populated by
/// `bodyBuilderFn`, which accepts a range of induction variables for all
/// loops. `ivStorage` is used to store the partial list of induction
/// variables.
// TODO: this function can be made iterative instead. However, it
// will have at most as many recursive calls as nested loops, which rarely
// exceeds 10.
static void generateParallelLoopNest(
    OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
    ValueRange steps, ArrayRef<utils::IteratorType> iteratorTypes,
    ArrayRef<linalg::ProcInfo> procInfo,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
    SmallVectorImpl<Value> &ivStorage) {
  assert(lbs.size() == ubs.size());
  assert(lbs.size() == steps.size());
  assert(lbs.size() == iteratorTypes.size());
  assert(procInfo.empty() || (lbs.size() == procInfo.size()));

  // If there are no (more) loops to be generated, generate the body and be
  // done with it.
  if (iteratorTypes.empty()) {
    bodyBuilderFn(b, loc, ivStorage);
    return;
  }

  // If there are no outer parallel loops, generate one sequential loop and
  // recurse.
  if (!isParallelIterator(iteratorTypes.front())) {
    LoopNest singleLoop = buildLoopNest(
        b, loc, lbs.take_front(), ubs.take_front(), steps.take_front(),
        [&](OpBuilder &b, Location loc, ValueRange ivs) {
          ivStorage.append(ivs.begin(), ivs.end());
          generateParallelLoopNest(
              b, loc, lbs.drop_front(), ubs.drop_front(), steps.drop_front(),
              iteratorTypes.drop_front(),
              procInfo.empty() ? procInfo : procInfo.drop_front(),
              bodyBuilderFn, ivStorage);
        });
    return;
  }

  unsigned nLoops = iteratorTypes.size();
  unsigned numProcessed = 0;
  DistributionMethod distributionMethod = DistributionMethod::None;
  if (procInfo.empty()) {
    numProcessed = nLoops - iteratorTypes.drop_while(isParallelIterator).size();
  } else {
    distributionMethod = procInfo.front().distributionMethod;
    numProcessed =
        nLoops - procInfo
                     .drop_while([&](linalg::ProcInfo p) {
                       return p.distributionMethod == distributionMethod;
                     })
                     .size();
  }

  auto remainderProcInfo =
      procInfo.empty() ? procInfo : procInfo.drop_front(numProcessed);
  switch (distributionMethod) {
  case DistributionMethod::None: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
        steps.take_front(numProcessed),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
              iteratorTypes.drop_front(numProcessed), remainderProcInfo,
              bodyBuilderFn, ivStorage);
        });
    return;
  }
  case DistributionMethod::Cyclic: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
        steps.take_front(numProcessed),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
              iteratorTypes.drop_front(numProcessed), remainderProcInfo,
              bodyBuilderFn, ivStorage);
        });
    return;
  }
  case DistributionMethod::CyclicNumProcsGeNumIters: {
    // Check (for the processed loops) that the iteration is in-bounds.
    ArithBuilder ab(b, loc);
    Value cond = ab.slt(lbs[0], ubs[0]);
    for (unsigned i = 1; i < numProcessed; ++i)
      cond = ab._and(cond, ab.slt(lbs[i], ubs[i]));
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    b.create<scf::IfOp>(loc, cond, [&](OpBuilder &b, Location loc) {
      generateParallelLoopNest(b, loc, lbs.drop_front(numProcessed),
                               ubs.drop_front(numProcessed),
                               steps.drop_front(numProcessed),
                               iteratorTypes.drop_front(numProcessed),
                               remainderProcInfo, bodyBuilderFn, ivStorage);
      b.create<scf::YieldOp>(loc, ValueRange{});
    });
    return;
  }
  case DistributionMethod::CyclicNumProcsEqNumIters:
    // No check/loops needed here. Set the `%iv` to be the `%lb` and proceed
    // with inner loop generation.
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    generateParallelLoopNest(
        b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
        steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed),
        remainderProcInfo, bodyBuilderFn, ivStorage);
    return;
  }
}

/// Specialization for generating a mix of parallel and sequential scf loops.
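/// For example (illustrative), iterator types ["parallel", "parallel",
/// "reduction"] produce one 2-D scf.parallel loop wrapping a single scf.for
/// loop when no distribution is requested.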
template <>
void GenerateLoopNest<scf::ParallelOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> procInfo) {
  SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
                                             ? SmallVector<Value>{}
                                             : linalgOp.getDpsInitOperands();
  assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
  // This function may be passed more iterator types than ranges.
  assert(iteratorTypes.size() >= loopRanges.size() &&
         "expected iterator type for all ranges");
  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
         "expected proc information for all loops when present");
  iteratorTypes = iteratorTypes.take_front(loopRanges.size());
  SmallVector<Value, 8> lbsStorage, ubsStorage, stepsStorage, ivs;
  unsigned numLoops = iteratorTypes.size();
  ivs.reserve(numLoops);
  lbsStorage.reserve(numLoops);
  ubsStorage.reserve(numLoops);
  stepsStorage.reserve(numLoops);

  // Get the loop lb, ub, and step.
  unpackRanges(b, loc, loopRanges, lbsStorage, ubsStorage, stepsStorage);

  // Modify the lb, ub, and step based on the distribution options.
  for (const auto &it : llvm::enumerate(procInfo)) {
    if (it.value().distributionMethod != linalg::DistributionMethod::None) {
      updateBoundsForCyclicDistribution(
          b, loc, it.value().procId, it.value().nprocs, lbsStorage[it.index()],
          ubsStorage[it.index()], stepsStorage[it.index()]);
    }
  }
  ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
  generateParallelLoopNest(
      b, loc, lbs, ubs, steps, iteratorTypes, procInfo,
      [&](OpBuilder &b, Location loc, ValueRange ivs) {
        bodyBuilderFn(b, loc, ivs, linalgOp->getOperands());
      },
      ivs);

  assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
}

static Value materializeTiledShape(OpBuilder &builder, Location loc,
                                   Value valueToTile,
                                   const SliceParameters &sliceParams) {
  auto shapedType = valueToTile.getType().dyn_cast<ShapedType>();
  auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
                      .Case([&](MemRefType) {
                        return builder.create<memref::SubViewOp>(
                            loc, valueToTile, sliceParams.offsets,
                            sliceParams.sizes, sliceParams.strides);
                      })
                      .Case([&](RankedTensorType) {
                        return builder.create<tensor::ExtractSliceOp>(
                            loc, valueToTile, sliceParams.offsets,
                            sliceParams.sizes, sliceParams.strides);
                      })
                      .Default([](ShapedType) -> Operation * {
                        llvm_unreachable("Unexpected shaped type");
                      });
  return sliceOp->getResult(0);
}

Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
                     ArrayRef<OpFoldResult> tileSizes, AffineMap map,
                     ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
                     ArrayRef<OpFoldResult> subShapeSizes,
                     bool omitPartialTileCheck) {
  SliceParameters sliceParams =
      computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
                             ubs, subShapeSizes, omitPartialTileCheck);
  return materializeTiledShape(builder, loc, valueToTile, sliceParams);
}

SliceParameters
computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
                       ArrayRef<OpFoldResult> tileSizes, AffineMap map,
                       ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
                       ArrayRef<OpFoldResult> subShapeSizes,
                       bool omitPartialTileCheck) {
  auto shapedType = valueToTile.getType().dyn_cast<ShapedType>();
  assert(shapedType && "only shaped types can be tiled");
  ArrayRef<int64_t> shape = shapedType.getShape();
  int64_t rank = shapedType.getRank();

  // Compute offsets/sizes/strides for the tile.
  SliceParameters sliceParams;
  sliceParams.offsets.reserve(rank);
  sliceParams.sizes.reserve(rank);
  sliceParams.strides.reserve(rank);
  for (unsigned r = 0; r < rank; ++r) {
    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: for dim#" << r);
    if (!isTiled(map.getSubMap({r}), tileSizes)) {
      sliceParams.offsets.push_back(builder.getIndexAttr(0));
      OpFoldResult dim = createFoldedDimOp(builder, loc, valueToTile, r);
      sliceParams.sizes.push_back(dim);
      sliceParams.strides.push_back(builder.getIndexAttr(1));
      LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subsize...\n");

    // Tiling creates a new slice at the proper index, the slice step is 1
    // (i.e. the op does not subsample; stepping occurs in the loop).
    auto m = map.getSubMap({r});
    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: submap: " << m << "\n");
    IRRewriter rewriter(builder);
    OpFoldResult offset = makeComposedFoldedAffineApply(rewriter, loc, m, lbs);
    sliceParams.offsets.push_back(offset);
    OpFoldResult closedIntSize =
        makeComposedFoldedAffineApply(rewriter, loc, m, subShapeSizes);
    // The resulting size needs to be made a half-open interval again.
    AffineExpr s0 = getAffineSymbolExpr(0, builder.getContext());
    OpFoldResult size =
        makeComposedFoldedAffineApply(rewriter, loc, s0 + 1, closedIntSize);
    LLVM_DEBUG(llvm::dbgs()
               << "computeSliceParameters: raw size: " << size << "\n");
    LLVM_DEBUG(llvm::dbgs()
               << "computeSliceParameters: new offset: " << offset << "\n");
    sliceParams.strides.push_back(builder.getIndexAttr(1));

    if (omitPartialTileCheck) {
      // We statically know that the partial/boundary tile condition is
      // unnecessary.
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
      sliceParams.sizes.push_back(size);
      continue;
    }

    // The size of the subview / extract_slice should be trimmed to avoid
    // out-of-bounds accesses, unless:
    // a. We statically know the subshape size divides the shape size evenly.
    // b. The subshape size is 1. According to the way the loops are set up,
    //    tensors with "0" dimensions would never be constructed.
    int64_t shapeSize = shape[r];
    std::optional<int64_t> sizeCst = getConstantIntValue(size);
    auto hasTileSizeOne = sizeCst && *sizeCst == 1;
    auto dividesEvenly = sizeCst && !ShapedType::isDynamic(shapeSize) &&
                         ((shapeSize % *sizeCst) == 0);
    if (!hasTileSizeOne && !dividesEvenly) {
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: shapeSize=" << shapeSize
                              << ", size: " << size
                              << ": make sure in bound with affine.min\n");

      AffineExpr dim0, dim1, dim2;
      bindDims(builder.getContext(), dim0, dim1, dim2);

      // Get the dimension size for this dimension. We need to first compute
      // the max index and then add one. This is important because for
      // convolution ops, the input window dimension's affine map has the
      // form `(d0 * s0 + d1)`, where `d0`/`d1` is an output/filter window
      // dimension and `s0` is the stride. Directly using the dimension size
      // of the output/filter window dimensions would give incorrect results.
      AffineMap minusOneMap =
          AffineMap::inferFromExprList({ArrayRef<AffineExpr>{dim0 - 1}})
              .front();
      AffineMap plusOneMap =
          AffineMap::inferFromExprList({ArrayRef<AffineExpr>{dim0 + 1}})
              .front();
      SmallVector<OpFoldResult> maxIndices =
          llvm::to_vector(llvm::map_range(ubs, [&](OpFoldResult ub) {
            return makeComposedFoldedAffineApply(rewriter, loc, minusOneMap,
                                                 {ub});
          }));
      OpFoldResult maxIndex =
          makeComposedFoldedAffineApply(rewriter, loc, m, maxIndices);
      OpFoldResult d =
          makeComposedFoldedAffineApply(rewriter, loc, plusOneMap, {maxIndex});

      // Compute min(dim - offset, size) to avoid out-of-bounds accesses.
      AffineMap minMap = AffineMap::inferFromExprList(
                             {ArrayRef<AffineExpr>{dim1 - dim2, dim0}})
                             .front();
      size =
          makeComposedFoldedAffineMin(rewriter, loc, minMap, {size, d, offset});
    }
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
    sliceParams.sizes.push_back(size);
  }
  return sliceParams;
}

SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
                                             ArrayRef<OpFoldResult> ivs,
                                             ArrayRef<OpFoldResult> tileSizes) {
  SmallVector<OpFoldResult> offsets;
  for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
    bool isTiled = !isZeroIndex(tileSizes[idx]);
    offsets.push_back(isTiled ? ivs[idxIvs++] : b.getIndexAttr(0));
    LLVM_DEBUG(llvm::dbgs()
               << "computeTileOffsets: " << offsets.back() << "\n");
  }
  return offsets;
}

SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc,
                                           ArrayRef<OpFoldResult> tileSizes,
                                           ArrayRef<OpFoldResult> sizeBounds) {
  SmallVector<OpFoldResult> sizes;
  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
    bool isTiled = !isZeroIndex(tileSizes[idx]);
    // Before composing, we need to make the range a closed interval.
    OpFoldResult size = isTiled ? tileSizes[idx] : sizeBounds[idx];
    AffineExpr d0 = getAffineDimExpr(0, b.getContext());
    IRRewriter rewriter(b);
    sizes.push_back(makeComposedFoldedAffineApply(rewriter, loc, d0 - 1, size));
    LLVM_DEBUG(llvm::dbgs() << "computeTileSizes: " << sizes.back() << "\n");
  }
  return sizes;
}

SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
  if (op.hasBufferSemantics())
    return {};
  return llvm::to_vector(
      llvm::map_range(op.getDpsInitOperands(), [&](OpOperand *opOperand) {
        return operands[opOperand->getOperandNumber()].getType();
      }));
}

SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
                                    LinalgOp op, ValueRange operands,
                                    ValueRange results) {
  if (op.hasBufferSemantics())
    return {};
  SmallVector<Value> tensorResults;
  tensorResults.reserve(results.size());
  // Insert an insert_slice for each output tensor.
  unsigned resultIdx = 0;
  for (OpOperand *opOperand : op.getDpsInitOperands()) {
    // TODO: use an interface/adaptor to avoid leaking position in
    // `tiledOperands`.
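    // If the tiled output was obtained via an extract_slice, write the
    // computed tile back into the original tensor with a matching
    // insert_slice; otherwise forward the result unchanged.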
    Value outputTensor = operands[opOperand->getOperandNumber()];
    if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
      Value inserted = builder.create<tensor::InsertSliceOp>(
          loc, sliceOp.getSource().getType(), results[resultIdx],
          sliceOp.getSource(), sliceOp.getOffsets(), sliceOp.getSizes(),
          sliceOp.getStrides(), sliceOp.getStaticOffsets(),
          sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
      tensorResults.push_back(inserted);
    } else {
      tensorResults.push_back(results[resultIdx]);
    }
    ++resultIdx;
  }
  return tensorResults;
}

SmallVector<std::optional<SliceParameters>>
computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
                          ValueRange valuesToTile, ArrayRef<OpFoldResult> ivs,
                          ArrayRef<OpFoldResult> tileSizes,
                          ArrayRef<OpFoldResult> sizeBounds,
                          bool omitPartialTileCheck) {
  assert(ivs.size() == static_cast<size_t>(llvm::count_if(
                           llvm::make_range(tileSizes.begin(), tileSizes.end()),
                           [](OpFoldResult v) { return !isZeroIndex(v); })) &&
         "expected as many ivs as non-zero sizes");

  // Construct (potentially temporary) mins and maxes on which to apply maps
  // that define tile subshapes.
  SmallVector<OpFoldResult> lbs =
      computeTileOffsets(builder, loc, ivs, tileSizes);
  SmallVector<OpFoldResult> subShapeSizes =
      computeTileSizes(builder, loc, tileSizes, sizeBounds);

  assert(static_cast<int64_t>(valuesToTile.size()) <=
             linalgOp->getNumOperands() &&
         "more values to tile than operands");
  SmallVector<std::optional<SliceParameters>> allSliceParams;
  allSliceParams.reserve(valuesToTile.size());
  for (auto [opOperand, val] :
       llvm::zip(linalgOp->getOpOperands(), valuesToTile)) {
    Value shapedOp = val;
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
    AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
    // Use `opOperand` as is if it is not tiled and not an output tensor.
    // Having an extract/insert slice pair for all output tensors simplifies
    // follow-up transformations such as padding and bufferization, since the
    // extract/insert slice pairs make the accessed iteration argument
    // subdomains explicit.
    Type operandType = opOperand.get().getType();
    if (!isTiled(map, tileSizes) && !(operandType.isa<RankedTensorType>() &&
                                      linalgOp.isDpsInit(&opOperand))) {
      allSliceParams.push_back(std::nullopt);
      LLVM_DEBUG(llvm::dbgs()
                 << ": not tiled: use shape: " << operandType << "\n");
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");

    allSliceParams.push_back(computeSliceParameters(
        builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
        omitPartialTileCheck));
  }

  return allSliceParams;
}

SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc,
                                   LinalgOp linalgOp, ValueRange valuesToTile,
                                   ArrayRef<OpFoldResult> ivs,
                                   ArrayRef<OpFoldResult> tileSizes,
                                   ArrayRef<OpFoldResult> sizeBounds,
                                   bool omitPartialTileCheck) {
  SmallVector<std::optional<SliceParameters>> allSliceParameter =
      computeAllSliceParameters(builder, loc, linalgOp, valuesToTile, ivs,
                                tileSizes, sizeBounds, omitPartialTileCheck);
  SmallVector<Value> tiledShapes;
  for (auto item : llvm::zip(valuesToTile, allSliceParameter)) {
    Value valueToTile = std::get<0>(item);
    std::optional<SliceParameters> sliceParams = std::get<1>(item);
    tiledShapes.push_back(
        sliceParams.has_value()
            ? materializeTiledShape(builder, loc, valueToTile, *sliceParams)
            : valueToTile);
  }
  return tiledShapes;
}

void offsetIndices(OpBuilder &b, LinalgOp linalgOp,
                   ArrayRef<OpFoldResult> offsets) {
  IRRewriter rewriter(b);
  offsetIndices(rewriter, linalgOp, offsets);
}

void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
                   ArrayRef<OpFoldResult> offsets) {
  if (!linalgOp.hasIndexSemantics())
    return;

  for (IndexOp indexOp : linalgOp.getBlock()->getOps<IndexOp>()) {
    if (indexOp.getDim() >= offsets.size() || !offsets[indexOp.getDim()])
      continue;
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointAfter(indexOp);
    AffineExpr index, offset;
    bindDims(b.getContext(), index, offset);
    OpFoldResult applied = makeComposedFoldedAffineApply(
        b, indexOp.getLoc(), index + offset,
        {getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]});
    Value materialized =
        getValueOrCreateConstantIndexOp(b, indexOp.getLoc(), applied);
    b.replaceOpWithIf(indexOp, materialized, [&](OpOperand &use) {
      return use.getOwner() != materialized.getDefiningOp();
    });
  }
}

/// Get the reassociation maps to fold the result of an extract_slice (or
/// source of an insert_slice) operation with the given offsets and sizes to
/// its rank-reduced version. This is only done for the cases where the size
/// is 1 and the offset is 0. Strictly speaking the offset 0 is not required
/// in general, but non-zero offsets are not handled by the SPIR-V backend at
/// this point (and potentially cannot be handled).
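/// For example (illustrative), mixedSizes = [1, 4, 1, 8] yields the
/// reassociation [[0, 1], [2, 3]]: each unit dimension is folded into the
/// next non-unit dimension, and trailing unit dimensions into the last group.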
std::optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<ReassociationIndices> reassociation;
  ReassociationIndices curr;
  for (const auto &it : llvm::enumerate(mixedSizes)) {
    auto dim = it.index();
    auto size = it.value();
    curr.push_back(dim);
    auto attr = size.dyn_cast<Attribute>();
    if (attr && attr.cast<IntegerAttr>().getInt() == 1)
      continue;
    reassociation.emplace_back(ReassociationIndices{});
    std::swap(reassociation.back(), curr);
  }
  // When the reassociations are not empty, fold the remaining unit dimensions
  // into the last dimension. If the reassociations so far are empty, leave
  // them empty. This will fold everything to a rank-0 tensor.
  if (!curr.empty() && !reassociation.empty())
    reassociation.back().append(curr.begin(), curr.end());
  return reassociation;
}

/// Return the identity numeric value associated with the given op.
std::optional<Attribute> getNeutralElement(Operation *op) {
  // Builder only used as helper for attribute creation.
  OpBuilder b(op->getContext());
  Type resultType = op->getResult(0).getType();
  if (auto floatType = resultType.dyn_cast<FloatType>()) {
    const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
    if (isa<arith::AddFOp>(op))
      return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
    if (isa<arith::MulFOp>(op))
      return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
    if (isa<arith::MaxFOp>(op))
      return b.getFloatAttr(resultType,
                            llvm::APFloat::getInf(semantic, /*Negative=*/true));
    if (isa<arith::MinFOp>(op))
      return b.getFloatAttr(
          resultType, llvm::APFloat::getInf(semantic, /*Negative=*/false));
    return std::nullopt;
  }
  if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp>(op))
    return b.getIntegerAttr(resultType, 0);
  if (isa<arith::AndIOp>(op))
    return b.getIntegerAttr(resultType, -1);
  if (isa<arith::MaxSIOp>(op))
    return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::min());
  if (isa<arith::MinSIOp>(op))
    return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::max());
  if (isa<arith::MulIOp>(op))
    return b.getIntegerAttr(resultType, 1);
  return std::nullopt;
}

} // namespace linalg
} // namespace mlir