//===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the Linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Utils/Utils.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>

#define DEBUG_TYPE "linalg-utils"

using namespace mlir;
using namespace presburger;
using namespace mlir::linalg;
using namespace mlir::scf;

static bool isZero(OpFoldResult v) {
  if (!v)
    return false;
  if (auto attr = v.dyn_cast<Attribute>()) {
    IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
    return intAttr && intAttr.getValue().isZero();
  }
  if (auto cst = v.get<Value>().getDefiningOp<arith::ConstantIndexOp>())
    return cst.value() == 0;
  return false;
}

namespace {

// Helper visitor to determine whether an AffineExpr is tiled.
// This is achieved by traversing every AffineDimExpr with position `pos` and
// checking whether the corresponding `tileSizes[pos]` is non-zero.
// This also enforces that only positive coefficients occur in multiplications.
//
// Example:
//   `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
//
struct TileCheck : public AffineExprVisitor<TileCheck> {
  TileCheck(ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}

  void visitDimExpr(AffineDimExpr expr) {
    isTiled |= !isZero(tileSizes[expr.getPosition()]);
  }
  void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
    visit(expr.getLHS());
    visit(expr.getRHS());
    if (expr.getKind() == mlir::AffineExprKind::Mul)
      assert(expr.getRHS().cast<AffineConstantExpr>().getValue() > 0 &&
             "nonpositive multiplying coefficient");
  }
  bool isTiled = false;
  ArrayRef<OpFoldResult> tileSizes;
};

} // namespace

static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes) {
  if (!expr)
    return false;
  TileCheck t(tileSizes);
  t.visit(expr);
  return t.isTiled;
}

// Checks whether the `map` varies with respect to a non-zero `tileSize`.
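//
// For example (an illustrative sketch, not exhaustive): the map
// `(d0, d1, d2) -> (d0, d1 + d2)` is tiled by tileSizes [0, 2, 0] because `d1`
// appears in a result and has a non-zero tile size, but it is not tiled by
// [0, 0, 0].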
static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes) {
  if (!map)
    return false;
  for (unsigned r = 0; r < map.getNumResults(); ++r)
    if (isTiled(map.getResult(r), tileSizes))
      return true;
  return false;
}

std::optional<RegionMatcher::BinaryOpKind>
RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
  auto &region = op.getRegion();
  if (!llvm::hasSingleElement(region))
    return std::nullopt;

  Block &block = region.front();
  if (block.getNumArguments() != 2 ||
      !block.getArgument(0).getType().isSignlessIntOrFloat() ||
      !block.getArgument(1).getType().isSignlessIntOrFloat())
    return std::nullopt;

  auto &ops = block.getOperations();
  if (!llvm::hasSingleElement(block.without_terminator()))
    return std::nullopt;

  using mlir::matchers::m_Val;
  auto a = m_Val(block.getArgument(0));
  auto b = m_Val(block.getArgument(1));

  auto addPattern = m_Op<linalg::YieldOp>(m_Op<arith::AddIOp>(a, b));
  if (addPattern.match(&ops.back()))
    return BinaryOpKind::IAdd;

  return std::nullopt;
}

/// Explicit instantiation of loop nest generator for different loop types.
template struct mlir::linalg::GenerateLoopNest<scf::ForOp>;
template struct mlir::linalg::GenerateLoopNest<scf::ParallelOp>;
template struct mlir::linalg::GenerateLoopNest<AffineForOp>;

/// Given a list of subview ranges, extract individual values for lower, upper
/// bounds and steps and put them into the corresponding vectors.
static void unpackRanges(OpBuilder &builder, Location loc,
                         ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
                         SmallVectorImpl<Value> &ubs,
                         SmallVectorImpl<Value> &steps) {
  for (Range range : ranges) {
    lbs.emplace_back(
        getValueOrCreateConstantIndexOp(builder, loc, range.offset));
    ubs.emplace_back(getValueOrCreateConstantIndexOp(builder, loc, range.size));
    steps.emplace_back(
        getValueOrCreateConstantIndexOp(builder, loc, range.stride));
  }
}

namespace mlir {
namespace linalg {

bool allIndexingsAreProjectedPermutation(LinalgOp op) {
  return llvm::all_of(op.getIndexingMapsArray(), [](AffineMap m) {
    return m.isProjectedPermutation(/*allowZeroInResults=*/true);
  });
}

bool hasOnlyScalarElementwiseOp(Region &r) {
  if (!llvm::hasSingleElement(r))
    return false;
  for (Operation &op : r.front()) {
    if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp,
              linalg::YieldOp, linalg::IndexOp>(op) ||
          OpTrait::hasElementwiseMappableTraits(&op)) ||
        llvm::any_of(op.getResultTypes(),
                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
      return false;
  }
  return true;
}

bool isElementwise(LinalgOp op) {
  if (op.getNumLoops() != op.getNumParallelLoops())
    return false;

  if (!allIndexingsAreProjectedPermutation(op))
    return false;

  // TODO: relax the restrictions on indexing map.
  for (OpOperand *opOperand : op.getDpsInitOperands()) {
    if (!op.getMatchingIndexingMap(opOperand).isPermutation())
      return false;
  }
  return hasOnlyScalarElementwiseOp(op->getRegion(0));
}

bool isParallelIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::parallel;
}

bool isReductionIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::reduction;
}

/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim) {
  if (source.getType().isa<UnrankedMemRefType, MemRefType>())
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (source.getType().isa<UnrankedTensorType, RankedTensorType>())
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value source,
                               int64_t dim) {
  auto shapedType = source.getType().cast<ShapedType>();
  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
    return createOrFoldDimOp(b, loc, source, dim);
  return b.getIndexAttr(shapedType.getDimSize(dim));
}

/// Given an operation, retrieves the value of each dynamic dimension by
/// constructing the necessary DimOps.
SmallVector<Value, 4> getDynOperands(Location loc, Value val, OpBuilder &b) {
  SmallVector<Value, 4> dynOperands;
  auto shapedType = val.getType().cast<ShapedType>();
  for (const auto &dim : llvm::enumerate(shapedType.getShape())) {
    if (dim.value() == ShapedType::kDynamic)
      dynOperands.push_back(createOrFoldDimOp(b, loc, val, dim.index()));
  }
  return dynOperands;
}

void getUpperBoundForIndex(Value value, AffineMap &boundMap,
                           SmallVectorImpl<Value> &boundOperands,
                           bool constantRequired) {
  // Initialize `boundMap` and `boundOperands` to the identity returning
  // `value`. This combination is the default result of the method if no
  // simplification is possible.
  assert(value.getType().isIndex() && "expect value to have index type");
  boundMap = AffineMap::getMultiDimIdentityMap(1, value.getContext());
  boundOperands.assign({value});
  canonicalizeMapAndOperands(&boundMap, &boundOperands);

  // Continue only if there is an affine index computation to simplify.
  Operation *definingOp = value.getDefiningOp();
  if (!definingOp || !isa<AffineApplyOp, AffineMinOp>(definingOp))
    return;

  // Get the backward slice containing the affine index computation.
  SetVector<Operation *> backwardSlice;
  getBackwardSlice(definingOp, &backwardSlice, [](Operation *op) {
    return isa<AffineApplyOp, AffineMinOp>(op);
  });
  backwardSlice.insert(definingOp);

  // Set up a system of affine constraints that describe the index computation.
  FlatAffineValueConstraints constraints;

  // Helper to find or create an identifier for the given value.
  auto findOrCreateId = [&](Value value) {
    if (!constraints.containsVar(value)) {
      constraints.appendDimVar(value);
      return true;
    }
    unsigned pos;
    constraints.findVar(value, &pos);
    return pos < constraints.getNumDimVars();
  };
  // Helper to get the position for the given value.
  auto getPosition = [&](Value value) {
    unsigned pos;
    bool exists = constraints.findVar(value, &pos);
    (void)exists;
    assert(exists && "expect to find the identifier");
    return pos;
  };

  // Add the affine operations in `backwardSlice` to the constraints.
  for (Operation *op : llvm::reverse(backwardSlice)) {
    // Add an identifier for all op results and operands.
    if (!(llvm::all_of(op->getResults(), findOrCreateId) &&
          llvm::all_of(op->getOperands(), findOrCreateId)))
      return;

    // Add AffineApplyOps to the constraints.
    if (auto applyOp = dyn_cast<AffineApplyOp>(op)) {
      AffineMap map = constraints.computeAlignedMap(applyOp.getAffineMap(),
                                                    applyOp.getOperands());
      if (failed(constraints.addBound(IntegerPolyhedron::EQ,
                                      getPosition(applyOp.getResult()), map)))
        return;
      continue;
    }
    // Add AffineMinOps to the constraints.
    auto minOp = cast<AffineMinOp>(op);
    AffineMap map = constraints.computeAlignedMap(minOp.getAffineMap(),
                                                  minOp.getOperands());
    if (failed(constraints.addBound(IntegerPolyhedron::UB,
                                    getPosition(minOp.getResult()), map,
                                    /*isClosedBound=*/true)))
      return;
  }

  // Obtain an upper bound for the affine index computation by projecting out
  // all temporary results and expressing the upper bound for `value` in terms
  // of the terminals of the index computation.
  unsigned pos = getPosition(value);
  if (constantRequired) {
    auto ubConst = constraints.getConstantBound64(
        FlatAffineValueConstraints::BoundType::UB, pos);
    if (!ubConst)
      return;

    boundMap = AffineMap::getConstantMap(*ubConst, value.getContext());
    return;
  }

  SmallVector<AffineMap> lowerBounds(1), upperBounds(1);
  constraints.getSliceBounds(pos, 1, value.getContext(), &lowerBounds,
                             &upperBounds,
                             /*getClosedUB=*/true);
  // Verify `upperBounds[0]` is valid and has at least one result.
  if (!upperBounds[0] || upperBounds[0].getNumResults() == 0)
    return;

  // Set `boundMap` and `boundOperands` to the computed upper bound.
  boundMap = upperBounds[0];
  constraints.getAllValues(&boundOperands);
  erase_value(boundOperands, value);
  canonicalizeMapAndOperands(&boundMap, &boundOperands);
}

FailureOr<int64_t> getConstantUpperBoundForIndex(Value value) {
  // Compute an upper bound for `value`.
  AffineMap boundMap;
  SmallVector<Value> boundOperands;
  getUpperBoundForIndex(value, boundMap, boundOperands,
                        /*constantRequired=*/true);

  // Search the results of `boundMap` for constant upper bounds.
  SmallVector<int64_t> constantBounds;
  for (AffineExpr result : boundMap.getResults())
    if (auto constExpr = result.dyn_cast<AffineConstantExpr>())
      constantBounds.push_back(constExpr.getValue());

  // Return the minimal upper bound or failure if none is found.
  if (constantBounds.empty())
    return failure();
  return *std::min_element(constantBounds.begin(), constantBounds.end());
}

Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                            Value source, Value pad, bool nofold) {
  // Exit if `source` is not defined by an ExtractSliceOp.
  auto sliceOp = source.getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Search the `source` use-def chain for padded LinalgOps.
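  // The matched chain, roughly (an illustrative sketch only), looks like:
  //   %padded = tensor.pad %small ... (high padding only)
  //   %result = linalg.op ... outs(%padded) ...
  //   %source = tensor.extract_slice %result ...
  // in which case the already padded `%result` can be reused instead of
  // creating a new tensor.pad.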
  Value current = sliceOp.getSource();
  while (current) {
    auto linalgOp = current.getDefiningOp<LinalgOp>();
    if (!linalgOp)
      break;
    OpResult opResult = current.cast<OpResult>();
    current = linalgOp.getDpsInitOperand(opResult.getResultNumber())->get();
  }
  auto padOp = current ? current.getDefiningOp<tensor::PadOp>() : nullptr;

  // Exit if the search fails to match a tensor::PadOp at the end of the
  // matched LinalgOp sequence.
  if (!padOp)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the padded result type does not match.
  if (sliceOp.getSource().getType() != type)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the LinalgOps are not high padded.
  if (llvm::any_of(padOp.getMixedLowPad(), [](OpFoldResult ofr) {
        return getConstantIntValue(ofr) != static_cast<int64_t>(0);
      }))
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if `padOpSliceOp`, which defines the slice used by `padOp`, is
  // rank-reducing.
  auto padOpSliceOp = padOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
  if (!padOpSliceOp ||
      sliceOp.getMixedSizes().size() != padOpSliceOp.getMixedSizes().size())
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the sizes of `sliceOp` do not match the sizes of the slice padded
  // by `padOp`.
  if (llvm::any_of(
          llvm::zip(sliceOp.getMixedSizes(), padOpSliceOp.getMixedSizes()),
          [](std::tuple<OpFoldResult, OpFoldResult> it) {
            return !isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it));
          }))
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the padding values do not match.
  Attribute padOpPadAttr, padAttr;
  Value padOpPad = padOp.getConstantPaddingValue();
  if (!padOpPad || !matchPattern(padOpPad, m_Constant(&padOpPadAttr)) ||
      !matchPattern(pad, m_Constant(&padAttr)) || padOpPadAttr != padAttr)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Return the padded result if the padding values and sizes match.
  return sliceOp.getSource();
}

GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
                          Value outputTensor,
                          ArrayRef<int64_t> transposeVector) {
  auto resultTensorType = outputTensor.getType().cast<RankedTensorType>();
  Type elementType = resultTensorType.getElementType();

  assert(isPermutationVector(transposeVector) &&
         "expect transpose vector to be a permutation");
  assert(transposeVector.size() ==
             static_cast<size_t>(resultTensorType.getRank()) &&
         "expect transpose vector size to match result tensor rank");

  // Compute the transpose and the identity indexing maps.
  SmallVector<AffineMap> indexingMaps = {
      inversePermutation(AffineMap::getPermutationMap(
          SmallVector<unsigned>(transposeVector.begin(), transposeVector.end()),
          b.getContext())),
      AffineMap::getMultiDimIdentityMap(transposeVector.size(),
                                        b.getContext())};
  SmallVector<utils::IteratorType> iteratorTypes(transposeVector.size(),
                                                 utils::IteratorType::parallel);

  // Create a GenericOp to transpose `inputTensor` into `outputTensor`.
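  // For example, with transposeVector = [1, 0] on a 2-d tensor this builds,
  // roughly (illustrative only), a linalg.generic with indexing maps
  // (d0, d1) -> (d1, d0) for the input and (d0, d1) -> (d0, d1) for the
  // output, whose body simply yields the input element.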
  auto transposeOp =
      b.create<GenericOp>(loc, resultTensorType, inputTensor, outputTensor,
                          indexingMaps, iteratorTypes);
  Region &body = transposeOp.getRegion();
  body.push_back(new Block());
  body.front().addArguments({elementType, elementType}, {loc, loc});

  // Create the body of the transpose operation.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPointToEnd(&body.front());
  b.create<YieldOp>(loc, transposeOp.getRegion().front().getArgument(0));
  return transposeOp;
}

GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
  auto memrefTypeTo = to.getType().cast<MemRefType>();
#ifndef NDEBUG
  auto memrefTypeFrom = from.getType().cast<MemRefType>();
  assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() &&
         "`from` and `to` memref must have the same rank");
#endif // NDEBUG

  AffineMap id =
      AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
  SmallVector<utils::IteratorType> iteratorTypes(memrefTypeTo.getRank(),
                                                 utils::IteratorType::parallel);
  return b.create<linalg::GenericOp>(
      loc,
      /*inputs=*/from,
      /*outputs=*/to,
      /*indexingMaps=*/llvm::ArrayRef({id, id}),
      /*iteratorTypes=*/iteratorTypes,
      [](OpBuilder &b, Location loc, ValueRange args) {
        b.create<linalg::YieldOp>(loc, args.front());
      });
}

/// Specialization to build an scf "for" nest.
template <>
void GenerateLoopNest<scf::ForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> procInfo) {
  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
         "expected as many entries for proc info as number of loops, even if "
         "they are null entries");
  SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
                                             ? SmallVector<Value>{}
                                             : linalgOp.getDpsInitOperands();

  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
  LoopNest loopNest = mlir::scf::buildLoopNest(
      b, loc, lbs, ubs, steps, iterArgInitValues,
      [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) {
        assert(iterArgs.size() == iterArgInitValues.size() &&
               "expect the number of output tensors and iter args to match");
        SmallVector<Value> operandValuesToUse = linalgOp->getOperands();
        if (!iterArgs.empty()) {
          operandValuesToUse = linalgOp.getDpsInputOperands();
          operandValuesToUse.append(iterArgs.begin(), iterArgs.end());
        }
        return bodyBuilderFn(b, loc, ivs, operandValuesToUse);
      });

  if (loopNest.loops.empty() || procInfo.empty())
    return;

  // Filter out scf.for loops that were created out of parallel dimensions.
  for (const auto &loop : llvm::enumerate(loopNest.loops)) {
    if (procInfo[loop.index()].distributionMethod ==
        DistributionMethod::Cyclic) {
      mapLoopToProcessorIds(loop.value(), procInfo[loop.index()].procId,
                            procInfo[loop.index()].nprocs);
    }
  }
}

/// Specialization to build an affine "for" nest.
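/// The generated IR has, roughly, the following shape for a single loop range
/// (an illustrative sketch only; the steps must be constants):
///   affine.for %iv = %lb to %ub step <cst> {
///     ... body emitted by `bodyBuilderFn` ...
///   }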
template <>
void GenerateLoopNest<AffineForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> /*procInfo*/) {
  SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
                                             ? SmallVector<Value>{}
                                             : linalgOp.getDpsInitOperands();
  assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(b, loc, loopRanges, lbs, ubs, steps);

  // Affine loops require constant steps.
  SmallVector<int64_t, 4> constantSteps;
  constantSteps.reserve(steps.size());
  for (Value v : steps) {
    auto op = v.getDefiningOp<arith::ConstantIndexOp>();
    assert(op && "Affine loops require constant steps");
    constantSteps.push_back(op.value());
  }

  mlir::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps,
                            [&](OpBuilder &b, Location loc, ValueRange ivs) {
                              bodyBuilderFn(b, loc, ivs,
                                            linalgOp->getOperands());
                            });
}

/// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`.
void updateBoundsForCyclicDistribution(OpBuilder &b, Location loc, Value procId,
                                       Value nprocs, Value &lb, Value &ub,
                                       Value &step) {
  AffineExpr d0, d1;
  bindDims(b.getContext(), d0, d1);
  AffineExpr s0 = getAffineSymbolExpr(0, b.getContext());
  lb = makeComposedAffineApply(b, loc, d0 + d1 * s0, {lb, procId, step});
  step = makeComposedAffineApply(b, loc, d0 * s0, {nprocs, step});
}

/// Generates a loop nest consisting of scf.parallel and scf.for, depending
/// on the `iteratorTypes`. Consecutive parallel loops create a single
/// scf.parallel operation; each sequential loop creates a new scf.for
/// operation. The body of the innermost loop is populated by
/// `bodyBuilderFn` that accepts a range of induction variables for all
/// loops. `ivStorage` is used to store the partial list of induction
/// variables.
// TODO: this function can be made iterative instead. However, it
// will have at most as many recursive calls as nested loops, which rarely
// exceeds 10.
static void generateParallelLoopNest(
    OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
    ValueRange steps, ArrayRef<utils::IteratorType> iteratorTypes,
    ArrayRef<linalg::ProcInfo> procInfo,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
    SmallVectorImpl<Value> &ivStorage) {
  assert(lbs.size() == ubs.size());
  assert(lbs.size() == steps.size());
  assert(lbs.size() == iteratorTypes.size());
  assert(procInfo.empty() || (lbs.size() == procInfo.size()));

  // If there are no (more) loops to be generated, generate the body and be
  // done with it.
  if (iteratorTypes.empty()) {
    bodyBuilderFn(b, loc, ivStorage);
    return;
  }

  // If there are no outer parallel loops, generate one sequential loop and
  // recurse.
  if (!isParallelIterator(iteratorTypes.front())) {
    LoopNest singleLoop = buildLoopNest(
        b, loc, lbs.take_front(), ubs.take_front(), steps.take_front(),
        [&](OpBuilder &b, Location loc, ValueRange ivs) {
          ivStorage.append(ivs.begin(), ivs.end());
          generateParallelLoopNest(
              b, loc, lbs.drop_front(), ubs.drop_front(), steps.drop_front(),
              iteratorTypes.drop_front(),
              procInfo.empty() ? procInfo : procInfo.drop_front(),
              bodyBuilderFn, ivStorage);
        });
    return;
  }

  unsigned nLoops = iteratorTypes.size();
  unsigned numProcessed = 0;
  DistributionMethod distributionMethod = DistributionMethod::None;
  if (procInfo.empty()) {
    numProcessed = nLoops - iteratorTypes.drop_while(isParallelIterator).size();
  } else {
    distributionMethod = procInfo.front().distributionMethod;
    numProcessed =
        nLoops - procInfo
                     .drop_while([&](linalg::ProcInfo p) {
                       return p.distributionMethod == distributionMethod;
                     })
                     .size();
  }

  auto remainderProcInfo =
      procInfo.empty() ? procInfo : procInfo.drop_front(numProcessed);
  switch (distributionMethod) {
  case DistributionMethod::None: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
        steps.take_front(numProcessed),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
              iteratorTypes.drop_front(numProcessed), remainderProcInfo,
              bodyBuilderFn, ivStorage);
        });
    return;
  }
  case DistributionMethod::Cyclic: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
        steps.take_front(numProcessed),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
              iteratorTypes.drop_front(numProcessed), remainderProcInfo,
              bodyBuilderFn, ivStorage);
        });
    return;
  }
  case DistributionMethod::CyclicNumProcsGeNumIters: {
    // Check (for the processed loops) that the iteration is in-bounds.
    ArithBuilder ab(b, loc);
    Value cond = ab.slt(lbs[0], ubs[0]);
    for (unsigned i = 1; i < numProcessed; ++i)
      cond = ab._and(cond, ab.slt(lbs[i], ubs[i]));
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    b.create<scf::IfOp>(loc, cond, [&](OpBuilder &b, Location loc) {
      generateParallelLoopNest(b, loc, lbs.drop_front(numProcessed),
                               ubs.drop_front(numProcessed),
                               steps.drop_front(numProcessed),
                               iteratorTypes.drop_front(numProcessed),
                               remainderProcInfo, bodyBuilderFn, ivStorage);
      b.create<scf::YieldOp>(loc, ValueRange{});
    });
    return;
  }
  case DistributionMethod::CyclicNumProcsEqNumIters:
    // No check/loops needed here. Set the `%iv` to be the `%lb` and proceed
    // with inner loop generation.
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    generateParallelLoopNest(
        b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
        steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed),
        remainderProcInfo, bodyBuilderFn, ivStorage);
    return;
  }
}

/// Specialization for generating a mix of parallel and sequential scf loops.
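/// For instance, with iterator types [parallel, parallel, reduction] and no
/// distribution info, the emitted nest looks roughly like (an illustrative
/// sketch only):
///   scf.parallel (%i, %j) = (%lb0, %lb1) to (%ub0, %ub1) step (%s0, %s1) {
///     scf.for %k = %lb2 to %ub2 step %s2 {
///       ... body emitted by `bodyBuilderFn` ...
///     }
///   }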
template <>
void GenerateLoopNest<scf::ParallelOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> procInfo) {
  SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
                                             ? SmallVector<Value>{}
                                             : linalgOp.getDpsInitOperands();
  assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
  // This function may be passed more iterator types than ranges.
  assert(iteratorTypes.size() >= loopRanges.size() &&
         "expected iterator type for all ranges");
  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
         "expected proc information for all loops when present");
  iteratorTypes = iteratorTypes.take_front(loopRanges.size());
  SmallVector<Value, 8> lbsStorage, ubsStorage, stepsStorage, ivs;
  unsigned numLoops = iteratorTypes.size();
  ivs.reserve(numLoops);
  lbsStorage.reserve(numLoops);
  ubsStorage.reserve(numLoops);
  stepsStorage.reserve(numLoops);

  // Get the loop lb, ub, and step.
  unpackRanges(b, loc, loopRanges, lbsStorage, ubsStorage, stepsStorage);

  // Modify the lb, ub, and step based on the distribution options.
  for (const auto &it : llvm::enumerate(procInfo)) {
    if (it.value().distributionMethod != linalg::DistributionMethod::None) {
      updateBoundsForCyclicDistribution(
          b, loc, it.value().procId, it.value().nprocs, lbsStorage[it.index()],
          ubsStorage[it.index()], stepsStorage[it.index()]);
    }
  }
  ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
  generateParallelLoopNest(
      b, loc, lbs, ubs, steps, iteratorTypes, procInfo,
      [&](OpBuilder &b, Location loc, ValueRange ivs) {
        bodyBuilderFn(b, loc, ivs, linalgOp->getOperands());
      },
      ivs);

  assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
}

static Value materializeTiledShape(OpBuilder &builder, Location loc,
                                   Value valueToTile,
                                   const SliceParameters &sliceParams) {
  auto shapedType = valueToTile.getType().dyn_cast<ShapedType>();
  auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
                      .Case([&](MemRefType) {
                        return builder.create<memref::SubViewOp>(
                            loc, valueToTile, sliceParams.offsets,
                            sliceParams.sizes, sliceParams.strides);
                      })
                      .Case([&](RankedTensorType) {
                        return builder.create<tensor::ExtractSliceOp>(
                            loc, valueToTile, sliceParams.offsets,
                            sliceParams.sizes, sliceParams.strides);
                      })
                      .Default([](ShapedType) -> Operation * {
                        llvm_unreachable("Unexpected shaped type");
                      });
  return sliceOp->getResult(0);
}

Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
                     ArrayRef<OpFoldResult> tileSizes, AffineMap map,
                     ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
                     ArrayRef<OpFoldResult> subShapeSizes,
                     bool omitPartialTileCheck) {
  SliceParameters sliceParams =
      computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
                             ubs, subShapeSizes, omitPartialTileCheck);
  return materializeTiledShape(builder, loc, valueToTile, sliceParams);
}

SliceParameters
computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
                       ArrayRef<OpFoldResult> tileSizes, AffineMap map,
                       ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
                       ArrayRef<OpFoldResult> subShapeSizes,
                       bool omitPartialTileCheck) {
  auto shapedType = valueToTile.getType().dyn_cast<ShapedType>();
  assert(shapedType && "only shaped types can be tiled");
  ArrayRef<int64_t> shape = shapedType.getShape();
  int64_t rank = shapedType.getRank();

  // Compute offsets/sizes/strides for the tile.
  SliceParameters sliceParams;
  sliceParams.offsets.reserve(rank);
  sliceParams.sizes.reserve(rank);
  sliceParams.strides.reserve(rank);
  for (unsigned r = 0; r < rank; ++r) {
    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: for dim#" << r);
    if (!isTiled(map.getSubMap({r}), tileSizes)) {
      sliceParams.offsets.push_back(builder.getIndexAttr(0));
      OpFoldResult dim = createFoldedDimOp(builder, loc, valueToTile, r);
      sliceParams.sizes.push_back(dim);
      sliceParams.strides.push_back(builder.getIndexAttr(1));
      LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subsize...\n");

    // Tiling creates a new slice at the proper index, the slice step is 1
    // (i.e. the op does not subsample, stepping occurs in the loop).
    auto m = map.getSubMap({r});
    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: submap: " << m << "\n");
    IRRewriter rewriter(builder);
    OpFoldResult offset = makeComposedFoldedAffineApply(rewriter, loc, m, lbs);
    sliceParams.offsets.push_back(offset);
    OpFoldResult closedIntSize =
        makeComposedFoldedAffineApply(rewriter, loc, m, subShapeSizes);
    // Resulting size needs to be made half open interval again.
    AffineExpr s0 = getAffineSymbolExpr(0, builder.getContext());
    OpFoldResult size =
        makeComposedFoldedAffineApply(rewriter, loc, s0 + 1, closedIntSize);
    LLVM_DEBUG(llvm::dbgs()
               << "computeSliceParameters: raw size: " << size << "\n");
    LLVM_DEBUG(llvm::dbgs()
               << "computeSliceParameters: new offset: " << offset << "\n");
    sliceParams.strides.push_back(builder.getIndexAttr(1));

    if (omitPartialTileCheck) {
      // We statically know that the partial/boundary tile condition is
      // unnecessary.
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
      sliceParams.sizes.push_back(size);
      continue;
    }

    // The size of the subview / extract_slice should be trimmed to avoid
    // out-of-bounds accesses, unless:
    // a. We statically know the subshape size divides the shape size evenly.
    // b. The subshape size is 1. According to the way the loops are set up,
    //    tensors with "0" dimensions would never be constructed.
    int64_t shapeSize = shape[r];
    std::optional<int64_t> sizeCst = getConstantIntValue(size);
    auto hasTileSizeOne = sizeCst && *sizeCst == 1;
    auto dividesEvenly = sizeCst && !ShapedType::isDynamic(shapeSize) &&
                         ((shapeSize % *sizeCst) == 0);
    if (!hasTileSizeOne && !dividesEvenly) {
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: shapeSize=" << shapeSize
                              << ", size: " << size
                              << ": make sure in bound with affine.min\n");

      AffineExpr dim0, dim1, dim2;
      bindDims(builder.getContext(), dim0, dim1, dim2);

      // Get the dimension size for this dimension. We need to first calculate
      // the max index and then add one. This is important because for
      // convolution ops, we have its input window dimension's affine map of
      // the form `(d0 * s0 + d1)`, where `d0`/`d1` is an output/filter window
      // dimension and `s0` is stride. Directly using the dimension size of the
      // output/filter window dimensions would cause incorrect calculations.
      AffineMap minusOneMap =
          AffineMap::inferFromExprList({ArrayRef<AffineExpr>{dim0 - 1}})
              .front();
      AffineMap plusOneMap =
          AffineMap::inferFromExprList({ArrayRef<AffineExpr>{dim0 + 1}})
              .front();
      SmallVector<OpFoldResult> maxIndices =
          llvm::to_vector(llvm::map_range(ubs, [&](OpFoldResult ub) {
            return makeComposedFoldedAffineApply(rewriter, loc, minusOneMap,
                                                 {ub});
          }));
      OpFoldResult maxIndex =
          makeComposedFoldedAffineApply(rewriter, loc, m, maxIndices);
      OpFoldResult d =
          makeComposedFoldedAffineApply(rewriter, loc, plusOneMap, {maxIndex});

      // Compute min(dim - offset, size) to avoid out-of-bounds accesses.
      AffineMap minMap = AffineMap::inferFromExprList(
                             {ArrayRef<AffineExpr>{dim1 - dim2, dim0}})
                             .front();
      size =
          makeComposedFoldedAffineMin(rewriter, loc, minMap, {size, d, offset});
    }
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
    sliceParams.sizes.push_back(size);
  }
  return sliceParams;
}

SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
                                             ArrayRef<OpFoldResult> ivs,
                                             ArrayRef<OpFoldResult> tileSizes) {
  SmallVector<OpFoldResult> offsets;
  for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
    bool isTiled = !isZero(tileSizes[idx]);
    offsets.push_back(isTiled ? ivs[idxIvs++] : b.getIndexAttr(0));
    LLVM_DEBUG(llvm::dbgs()
               << "computeTileOffsets: " << offsets.back() << "\n");
  }
  return offsets;
}

SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc,
                                           ArrayRef<OpFoldResult> tileSizes,
                                           ArrayRef<OpFoldResult> sizeBounds) {
  SmallVector<OpFoldResult> sizes;
  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
    bool isTiled = !isZero(tileSizes[idx]);
    // Before composing, we need to make the range a closed interval.
    OpFoldResult size = isTiled ? tileSizes[idx] : sizeBounds[idx];
    AffineExpr d0 = getAffineDimExpr(0, b.getContext());
    IRRewriter rewriter(b);
    sizes.push_back(makeComposedFoldedAffineApply(rewriter, loc, d0 - 1, size));
    LLVM_DEBUG(llvm::dbgs() << "computeTileSizes: " << sizes.back() << "\n");
  }
  return sizes;
}

SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
  if (op.hasBufferSemantics())
    return {};
  return llvm::to_vector(
      llvm::map_range(op.getDpsInitOperands(), [&](OpOperand *opOperand) {
        return operands[opOperand->getOperandNumber()].getType();
      }));
}

SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
                                    LinalgOp op, ValueRange operands,
                                    ValueRange results) {
  if (op.hasBufferSemantics())
    return {};
  SmallVector<Value> tensorResults;
  tensorResults.reserve(results.size());
  // Insert an insert_slice for each output tensor.
  unsigned resultIdx = 0;
  for (OpOperand *opOperand : op.getDpsInitOperands()) {
    // TODO: use an interface/adaptor to avoid leaking position in
    // `tiledOperands`.
    Value outputTensor = operands[opOperand->getOperandNumber()];
    if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
      Value inserted = builder.create<tensor::InsertSliceOp>(
          loc, sliceOp.getSource().getType(), results[resultIdx],
          sliceOp.getSource(), sliceOp.getOffsets(), sliceOp.getSizes(),
          sliceOp.getStrides(), sliceOp.getStaticOffsets(),
          sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
      tensorResults.push_back(inserted);
    } else {
      tensorResults.push_back(results[resultIdx]);
    }
    ++resultIdx;
  }
  return tensorResults;
}

SmallVector<std::optional<SliceParameters>>
computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
                          ValueRange valuesToTile, ArrayRef<OpFoldResult> ivs,
                          ArrayRef<OpFoldResult> tileSizes,
                          ArrayRef<OpFoldResult> sizeBounds,
                          bool omitPartialTileCheck) {
  assert(ivs.size() == static_cast<size_t>(llvm::count_if(
                           llvm::make_range(tileSizes.begin(), tileSizes.end()),
                           [](OpFoldResult v) { return !isZero(v); })) &&
         "expected as many ivs as non-zero sizes");

  // Construct (potentially temporary) mins and maxes on which to apply maps
  // that define tile subshapes.
  SmallVector<OpFoldResult> lbs =
      computeTileOffsets(builder, loc, ivs, tileSizes);
  SmallVector<OpFoldResult> subShapeSizes =
      computeTileSizes(builder, loc, tileSizes, sizeBounds);

  assert(static_cast<int64_t>(valuesToTile.size()) <=
             linalgOp->getNumOperands() &&
         "more values to tile than operands.");
  SmallVector<std::optional<SliceParameters>> allSliceParams;
  allSliceParams.reserve(valuesToTile.size());
  for (auto [opOperand, val] :
       llvm::zip(linalgOp->getOpOperands(), valuesToTile)) {
    Value shapedOp = val;
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
    AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
    // Use `opOperand` as is if it is not tiled and not an output tensor.
    // Having an extract/insert slice pair for all output tensors simplifies
    // follow up transformations such as padding and bufferization since the
    // extract/insert slice pairs make the accessed iteration argument
    // subdomains explicit.

    Type operandType = opOperand.get().getType();
    if (!isTiled(map, tileSizes) && !(operandType.isa<RankedTensorType>() &&
                                      linalgOp.isDpsInit(&opOperand))) {
      allSliceParams.push_back(std::nullopt);
      LLVM_DEBUG(llvm::dbgs()
                 << ": not tiled: use shape: " << operandType << "\n");
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");

    allSliceParams.push_back(computeSliceParameters(
        builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
        omitPartialTileCheck));
  }

  return allSliceParams;
}

SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc,
                                   LinalgOp linalgOp, ValueRange valuesToTile,
                                   ArrayRef<OpFoldResult> ivs,
                                   ArrayRef<OpFoldResult> tileSizes,
                                   ArrayRef<OpFoldResult> sizeBounds,
                                   bool omitPartialTileCheck) {
  SmallVector<std::optional<SliceParameters>> allSliceParameter =
      computeAllSliceParameters(builder, loc, linalgOp, valuesToTile, ivs,
                                tileSizes, sizeBounds, omitPartialTileCheck);
  SmallVector<Value> tiledShapes;
  for (auto item : llvm::zip(valuesToTile, allSliceParameter)) {
    Value valueToTile = std::get<0>(item);
    std::optional<SliceParameters> sliceParams = std::get<1>(item);
    tiledShapes.push_back(
        sliceParams.has_value()
            ? materializeTiledShape(builder, loc, valueToTile, *sliceParams)
            : valueToTile);
  }
  return tiledShapes;
}

void offsetIndices(OpBuilder &b, LinalgOp linalgOp,
                   ArrayRef<OpFoldResult> offsets) {
  IRRewriter rewriter(b);
  offsetIndices(rewriter, linalgOp, offsets);
}

void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
                   ArrayRef<OpFoldResult> offsets) {
  if (!linalgOp.hasIndexSemantics())
    return;

  for (IndexOp indexOp : linalgOp.getBlock()->getOps<IndexOp>()) {
    if (indexOp.getDim() >= offsets.size() || !offsets[indexOp.getDim()])
      continue;
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointAfter(indexOp);
    AffineExpr index, offset;
    bindDims(b.getContext(), index, offset);
    OpFoldResult applied = makeComposedFoldedAffineApply(
        b, indexOp.getLoc(), index + offset,
        {getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]});
    Value materialized =
        getValueOrCreateConstantIndexOp(b, indexOp.getLoc(), applied);
    b.replaceOpWithIf(indexOp, materialized, [&](OpOperand &use) {
      return use.getOwner() != materialized.getDefiningOp();
    });
  }
}

/// Get the reassociation maps to fold the result of an extract_slice (or
/// source of an insert_slice) operation with given offsets, and sizes to its
/// rank-reduced version. This is only done for the cases where the size is 1
/// and offset is 0. Strictly speaking the offset 0 is not required in general,
/// but non-zero offsets are not handled by the SPIR-V backend at this point
/// (and potentially cannot be handled).
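///
/// For example (assumed shapes, for illustration only): for mixedSizes
/// [1, %d, 1, 4] this returns the reassociation [[0, 1], [2, 3]], folding each
/// leading unit dimension into the next non-unit dimension.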
std::optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<ReassociationIndices> reassociation;
  ReassociationIndices curr;
  for (const auto &it : llvm::enumerate(mixedSizes)) {
    auto dim = it.index();
    auto size = it.value();
    curr.push_back(dim);
    auto attr = size.dyn_cast<Attribute>();
    if (attr && attr.cast<IntegerAttr>().getInt() == 1)
      continue;
    reassociation.emplace_back(ReassociationIndices{});
    std::swap(reassociation.back(), curr);
  }
  // When the reassociations are not empty, then fold the remaining
  // unit-dimensions into the last dimension. If the reassociations so far are
  // empty, then leave it empty. This will fold everything to a rank-0 tensor.
  if (!curr.empty() && !reassociation.empty())
    reassociation.back().append(curr.begin(), curr.end());
  return reassociation;
}

/// Return the identity numeric value associated with the given op.
std::optional<Attribute> getNeutralElement(Operation *op) {
  // Builder only used as helper for attribute creation.
  OpBuilder b(op->getContext());
  Type resultType = op->getResult(0).getType();
  if (auto floatType = resultType.dyn_cast<FloatType>()) {
    const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
    if (isa<arith::AddFOp>(op))
      return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
    if (isa<arith::MulFOp>(op))
      return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
    if (isa<arith::MaxFOp>(op))
      return b.getFloatAttr(resultType,
                            llvm::APFloat::getInf(semantic, /*Negative=*/true));
    if (isa<arith::MinFOp>(op))
      return b.getFloatAttr(
          resultType, llvm::APFloat::getInf(semantic, /*Negative=*/false));
    return std::nullopt;
  }
  if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp>(op))
    return b.getIntegerAttr(resultType, 0);
  if (isa<arith::AndIOp>(op))
    return b.getIntegerAttr(resultType, -1);
  if (isa<arith::MaxSIOp>(op))
    return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::min());
  if (isa<arith::MinSIOp>(op))
    return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::max());
  if (isa<arith::MulIOp>(op))
    return b.getIntegerAttr(resultType, 1);
  return std::nullopt;
}

} // namespace linalg
} // namespace mlir