1 //===- ReshapeOpsUtils.cpp - Utilities used by structured ops -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" 10 11 #include "mlir/IR/AffineMap.h" 12 #include "mlir/IR/Builders.h" 13 14 #include <numeric> 15 #include <optional> 16 17 using namespace mlir; 18 19 std::optional<SmallVector<ReassociationIndices>> 20 mlir::getReassociationIndicesForReshape(ShapedType sourceType, 21 ShapedType targetType) { 22 if (sourceType.getRank() > targetType.getRank()) 23 return getReassociationIndicesForCollapse(sourceType.getShape(), 24 targetType.getShape()); 25 if (sourceType.getRank() < targetType.getRank()) 26 return getReassociationIndicesForCollapse(targetType.getShape(), 27 sourceType.getShape()); 28 return std::nullopt; 29 } 30 31 std::optional<SmallVector<ReassociationIndices>> 32 mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape, 33 ArrayRef<int64_t> targetShape) { 34 if (sourceShape.size() <= targetShape.size()) 35 return std::nullopt; 36 unsigned sourceDim = 0; 37 SmallVector<ReassociationIndices> reassociationMap; 38 reassociationMap.reserve(targetShape.size()); 39 40 ReassociationIndices currIndices; 41 int64_t prodOfCollapsedDims = 1; 42 while (sourceDim < sourceShape.size()) { 43 unsigned targetDim = reassociationMap.size(); 44 // If we have mapped all the target dimensions stop and handle the remaining 45 // tail of size-1 dimensions explicitly. 46 if (targetDim == targetShape.size()) 47 break; 48 49 int64_t currTargetShape = targetShape[targetDim]; 50 while (sourceDim < (sourceShape.size() - 1) && 51 sourceShape[sourceDim] != ShapedType::kDynamic && 52 prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) { 53 prodOfCollapsedDims *= sourceShape[sourceDim]; 54 currIndices.push_back(sourceDim++); 55 } 56 57 // If the current expanded dimension is dynamic, then the collapsed 58 // dimensions should also be dynamic and product of all previous unprocessed 59 // dimensions of the expanded shape should be 1. 60 if (sourceShape[sourceDim] == ShapedType::kDynamic && 61 (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1)) 62 return std::nullopt; 63 64 // If the collapsed dim is dynamic, the current expanded dim should also 65 // be dynamic. 66 if (currTargetShape == ShapedType::kDynamic && 67 sourceShape[sourceDim] != ShapedType::kDynamic) 68 return std::nullopt; 69 70 // For static shapes, if the product of dimensions of the expanded shape 71 // should match the collapsed dimension shape. 72 if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape) 73 return std::nullopt; 74 75 currIndices.push_back(sourceDim++); 76 reassociationMap.emplace_back(ReassociationIndices{}); 77 std::swap(reassociationMap.back(), currIndices); 78 prodOfCollapsedDims = 1; 79 } 80 // All the dimensions in the target must have been processed. 81 if (reassociationMap.size() != targetShape.size()) 82 return std::nullopt; 83 // Process any remaining entries in the source shape. They all need to be 84 // 1 or dynamic. 85 for (; sourceDim < sourceShape.size(); sourceDim++) { 86 if (sourceShape[sourceDim] != ShapedType::kDynamic && 87 sourceShape[sourceDim] != 1) 88 return std::nullopt; 89 // The map is empty when the target type is a scalar. 90 if (!reassociationMap.empty()) 91 reassociationMap.back().push_back(sourceDim); 92 } 93 return reassociationMap; 94 } 95 96 std::optional<SmallVector<ReassociationIndices>> 97 mlir::composeReassociationIndices( 98 ArrayRef<ReassociationIndices> producerReassociations, 99 ArrayRef<ReassociationIndices> consumerReassociations, 100 MLIRContext *context) { 101 SmallVector<ReassociationIndices> composedIndices; 102 // Make the producer the larger sized vector. If they are of same size, the 103 // resulting reshape is not a supported reshape op. 104 if (producerReassociations.size() == consumerReassociations.size()) 105 return std::nullopt; 106 if (producerReassociations.size() < consumerReassociations.size()) 107 std::swap(producerReassociations, consumerReassociations); 108 109 // Handle the corner case of the result being a rank 0 shaped type. Return an 110 // empty reassociation. 111 if (consumerReassociations.empty()) 112 return composedIndices; 113 114 size_t consumerDims = std::accumulate( 115 consumerReassociations.begin(), consumerReassociations.end(), 0, 116 [](size_t all, ReassociationIndicesRef indices) { 117 return all + indices.size(); 118 }); 119 if (producerReassociations.size() != consumerDims) 120 return std::nullopt; 121 122 for (ReassociationIndicesRef consumerIndices : consumerReassociations) { 123 ReassociationIndices reassociations; 124 for (int64_t consumerIndex : consumerIndices) { 125 llvm::append_range(reassociations, producerReassociations[consumerIndex]); 126 } 127 composedIndices.push_back(std::move(reassociations)); 128 } 129 return composedIndices; 130 } 131 132 SmallVector<SmallVector<AffineExpr, 2>, 2> 133 mlir::convertReassociationIndicesToExprs( 134 MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices) { 135 SmallVector<SmallVector<AffineExpr, 2>, 2> reassociationMaps; 136 for (const auto &indices : reassociationIndices) { 137 SmallVector<AffineExpr, 2> reassociationMap; 138 reassociationMap.reserve(indices.size()); 139 for (int64_t index : indices) 140 reassociationMap.push_back(mlir::getAffineDimExpr(index, context)); 141 reassociationMaps.push_back(std::move(reassociationMap)); 142 } 143 return reassociationMaps; 144 } 145 146 template <typename AffineExprTy> 147 unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) { 148 unsigned pos = 0; 149 for (const auto &exprs : exprArrays) { 150 for (auto expr : exprs) { 151 expr.walk([&pos](AffineExpr e) { 152 if (auto d = dyn_cast<AffineExprTy>(e)) 153 pos = std::max(pos, d.getPosition()); 154 }); 155 } 156 } 157 return pos; 158 } 159 160 ArrayAttr mlir::getReassociationIndicesAttribute( 161 OpBuilder &b, ArrayRef<ReassociationIndices> reassociation) { 162 SmallVector<Attribute, 4> reassociationAttr = 163 llvm::to_vector<4>(llvm::map_range( 164 reassociation, [&](const ReassociationIndices &indices) -> Attribute { 165 return cast<Attribute>(b.getI64ArrayAttr(indices)); 166 })); 167 return b.getArrayAttr(reassociationAttr); 168 } 169 170 SmallVector<ReassociationIndices, 2> mlir::convertReassociationMapsToIndices( 171 ArrayRef<ReassociationExprs> reassociationExprs) { 172 SmallVector<ReassociationIndices, 2> reassociationIndices; 173 for (const auto &exprs : reassociationExprs) { 174 ReassociationIndices indices; 175 indices.reserve(exprs.size()); 176 for (const auto &expr : exprs) 177 indices.push_back(cast<AffineDimExpr>(expr).getPosition()); 178 reassociationIndices.push_back(indices); 179 } 180 return reassociationIndices; 181 } 182 183 SmallVector<AffineMap, 4> 184 mlir::getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) { 185 unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation); 186 assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 && 187 "Expected symbol-less expressions"); 188 SmallVector<AffineMap, 4> maps; 189 maps.reserve(reassociation.size()); 190 for (const auto &exprs : reassociation) { 191 assert(!exprs.empty()); 192 maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext())); 193 } 194 return maps; 195 } 196 197 bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation, 198 int *invalidIndex) { 199 if (reassociation.empty()) 200 return true; 201 unsigned nDims = reassociation[0].getNumDims(); 202 unsigned nextExpectedDim = 0; 203 for (const auto &it : llvm::enumerate(reassociation)) { 204 auto m = it.value(); 205 if (m.getNumDims() != nDims || m.getNumSymbols() != 0) { 206 if (invalidIndex) 207 *invalidIndex = it.index(); 208 return false; 209 } 210 for (auto e : m.getResults()) { 211 auto d = dyn_cast<AffineDimExpr>(e); 212 if (!d || d.getPosition() != nextExpectedDim++) { 213 if (invalidIndex) 214 *invalidIndex = it.index(); 215 return false; 216 } 217 } 218 } 219 if (nextExpectedDim != nDims) { 220 if (invalidIndex) 221 *invalidIndex = reassociation.size() - 1; 222 return false; 223 } 224 return true; 225 } 226 227 LogicalResult mlir::reshapeLikeShapesAreCompatible( 228 function_ref<LogicalResult(const Twine &)> emitError, 229 ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape, 230 ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape) { 231 unsigned expandedDimStart = 0; 232 for (const auto &map : llvm::enumerate(reassociationMaps)) { 233 bool foundDynamicShape = false; 234 int64_t linearizedStaticShape = 1; 235 236 for (const auto &dim : llvm::enumerate( 237 expandedShape.slice(expandedDimStart, map.value().size()))) { 238 if (ShapedType::isDynamic(dim.value())) 239 foundDynamicShape = true; 240 else 241 linearizedStaticShape *= dim.value(); 242 } 243 if (foundDynamicShape) { 244 if (!ShapedType::isDynamic(collapsedShape[map.index()])) { 245 return emitError( 246 "expected dimension " + Twine(map.index()) + 247 " of collapsed type to be dynamic since one or more of the " 248 "corresponding dimensions in the expanded type is dynamic"); 249 } 250 } else { 251 if (collapsedShape[map.index()] != linearizedStaticShape) { 252 return emitError("expected dimension " + Twine(map.index()) + 253 " of collapsed type to be static value of " + 254 Twine(linearizedStaticShape)); 255 } 256 } 257 expandedDimStart += map.value().size(); 258 } 259 return success(); 260 } 261 262 bool mlir::hasNonIdentityLayout(Type type) { 263 if (auto memrefType = dyn_cast<MemRefType>(type)) 264 return !memrefType.getLayout().isIdentity(); 265 return false; 266 } 267 268 llvm::SmallBitVector 269 mlir::getSlicedDimensions(ArrayRef<OpFoldResult> sliceInputShape, 270 ArrayRef<Range> sliceParams) { 271 assert(sliceParams.size() == sliceInputShape.size() && 272 "only supports non rank-reducing case"); 273 llvm::SmallBitVector mask(sliceInputShape.size()); 274 unsigned idx = 0; 275 for (const auto &[offset, size, stride] : sliceParams) { 276 std::optional<int64_t> offsetConst = getConstantIntValue(offset); 277 std::optional<int64_t> strideConst = getConstantIntValue(stride); 278 mask[idx] = !isEqualConstantIntOrValue(size, sliceInputShape[idx]) || 279 (!strideConst || *strideConst != 1) || 280 (!offsetConst || *offsetConst != 0); 281 idx++; 282 } 283 return mask; 284 } 285 286 llvm::SmallBitVector mlir::getLinearizedDimensions( 287 ArrayRef<ReassociationIndices> reassociationIndices) { 288 llvm::SmallBitVector result(reassociationIndices.size()); 289 for (const auto &it : llvm::enumerate(reassociationIndices)) 290 result[it.index()] = it.value().size() > 1; 291 return result; 292 } 293 294 SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams( 295 MLIRContext *ctx, ArrayRef<ValueRange> multiIndices) { 296 unsigned loopIdx = 0; 297 auto oneAttr = IntegerAttr::get(IndexType::get(ctx), 1); 298 auto zeroAttr = IntegerAttr::get(IndexType::get(ctx), 0); 299 SmallVector<Range> offsetsSizesAndStrides; 300 offsetsSizesAndStrides.reserve(collapseShapeInputShape.size()); 301 for (const auto &it : llvm::enumerate(reassociationIndices)) { 302 // Case 1: Linearized dimensions that have also been sliced. These 303 // are size of 1 because we are iterating over these dimensions. The 304 // offsets are exactly the de-linearized multi-indices. 305 if (slicedDimensions[it.index()] && linearizedDimensions[it.index()]) { 306 llvm::append_range( 307 offsetsSizesAndStrides, 308 llvm::map_range(multiIndices[loopIdx++], [&](Value v) -> Range { 309 return Range{getAsOpFoldResult(v), oneAttr, oneAttr}; 310 })); 311 continue; 312 } 313 314 // Case 2: One or possibly multiple combined input dimensions, but we 315 // have proven that these are not sliced. In this case we just take 316 // the full extent of each dimension in the reassociation list. 317 if (linearizedDimensions[it.index()]) { 318 llvm::append_range( 319 offsetsSizesAndStrides, 320 llvm::map_range(it.value(), [&](int64_t idx) -> Range { 321 return {zeroAttr, collapseShapeInputShape[idx], oneAttr}; 322 })); 323 continue; 324 } 325 326 // Case 3: A single index, but it may be sliced. 327 offsetsSizesAndStrides.push_back(sliceParams[it.index()]); 328 } 329 return offsetsSizesAndStrides; 330 } 331 332 SmallVector<Range> 333 SliceFromCollapseHelper::getInsertSliceParams(MLIRContext *ctx, 334 ValueRange tileIndices) { 335 auto one = IntegerAttr::get(IndexType::get(ctx), 1); 336 auto zero = IntegerAttr::get(IndexType::get(ctx), 0); 337 SmallVector<Range> insertParams; 338 insertParams.reserve(linearizedDimensions.size()); 339 unsigned loopIdx = 0; 340 for (unsigned i = 0; i < linearizedDimensions.size(); i++) { 341 if (linearizedDimensions[i] && slicedDimensions[i]) { 342 insertParams.push_back(Range{tileIndices[loopIdx++], one, one}); 343 continue; 344 } 345 insertParams.push_back(Range{zero, sliceParams[i].size, one}); 346 } 347 return insertParams; 348 } 349 350 /// Returns the index of the only non-unit dimension among `indices` of `shape`, 351 /// if such a dimension exists and `indices` has more than one element. 352 /// Otherwise, return std::nullopt. 353 static std::optional<int64_t> getUniqueNonUnitDim(ArrayRef<int64_t> indices, 354 ArrayRef<int64_t> shape) { 355 // Return false if more than one of the dimensions in this group are not 1. 356 std::optional<int64_t> dimIndex; 357 if (indices.size() < 2) 358 return std::nullopt; 359 for (int64_t idx : indices) { 360 if (shape[idx] != 1) { 361 if (dimIndex != std::nullopt) 362 return std::nullopt; 363 dimIndex = idx; 364 } 365 } 366 return dimIndex; 367 } 368 369 // For each segment in the reassociation indices, check whether we can 370 // simplify that segment with a rank-reducing extract slice. We can do this if 371 // all but (exactly) one of the corresponding source dims is 1. 372 static SmallVector<std::optional<int64_t>> getCollapseShapeTrivialSegments( 373 RankedTensorType sourceType, 374 ArrayRef<ReassociationIndices> reassociationIndices) { 375 SmallVector<std::optional<int64_t>> trivialSegments; 376 for (const auto &indices : reassociationIndices) 377 trivialSegments.push_back( 378 getUniqueNonUnitDim(indices, sourceType.getShape())); 379 return trivialSegments; 380 } 381 382 /// Returns true if any of the segments of the reassociation indices for a 383 /// collapsing reshape can be simplified using a rank-reducing slice. 384 static FailureOr<SmallVector<std::optional<int64_t>>> 385 canCollapseShapeBeSimplifiedByRankReducingSlice( 386 RankedTensorType sourceType, 387 ArrayRef<ReassociationIndices> reassociationIndices) { 388 SmallVector<std::optional<int64_t>> trivialSegments = 389 getCollapseShapeTrivialSegments(sourceType, reassociationIndices); 390 if (!llvm::any_of(trivialSegments, [](const std::optional<int64_t> &idx) { 391 return idx.has_value(); 392 })) 393 return failure(); 394 return trivialSegments; 395 } 396 397 FailureOr<CollapseShapeRankReducingSliceSimplificationInfo> 398 mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo( 399 RankedTensorType sourceType, 400 ArrayRef<ReassociationIndices> reassociationIndices) { 401 FailureOr<SmallVector<std::optional<int64_t>>> trivialSegments = 402 canCollapseShapeBeSimplifiedByRankReducingSlice(sourceType, 403 reassociationIndices); 404 if (failed(trivialSegments)) 405 return failure(); 406 407 // Create the expected result shape of the rank-reducing slice. 408 SmallVector<int64_t> sliceShape; 409 for (const auto &[nonUnitDim, indices] : 410 llvm::zip(*trivialSegments, reassociationIndices)) { 411 if (nonUnitDim) { 412 sliceShape.push_back(sourceType.getDimSize(*nonUnitDim)); 413 continue; 414 } 415 llvm::append_range(sliceShape, llvm::map_range(indices, [&](int64_t idx) { 416 return sourceType.getDimSize(idx); 417 })); 418 } 419 auto sliceType = 420 RankedTensorType::get(sliceShape, sourceType.getElementType()); 421 422 // If the rank-reducing slice simplified every segment, then we are done. 423 if (sliceShape.size() == reassociationIndices.size()) 424 return CollapseShapeRankReducingSliceSimplificationInfo{sliceType, 425 std::nullopt}; 426 427 // Otherwise, we need to create a new collapse_shape op for the segments that 428 // weren't covered by the slice. By design, the new reassociation indices has 429 // the same number of groups as the old reassociation indices. 430 SmallVector<ReassociationIndices> newReassociationIndices; 431 SmallVector<int64_t, 2> reassociation; 432 int64_t groupIdx = 0; 433 for (int64_t dimIdx = 0; dimIdx < sliceType.getRank(); dimIdx++) { 434 reassociation.push_back(dimIdx); 435 if ((*trivialSegments)[groupIdx] || 436 reassociation.size() == reassociationIndices[groupIdx].size()) { 437 newReassociationIndices.push_back(reassociation); 438 reassociation.clear(); 439 groupIdx++; 440 } 441 } 442 443 return CollapseShapeRankReducingSliceSimplificationInfo{ 444 sliceType, newReassociationIndices}; 445 } 446 447 PackingMetadata mlir::computePackingMetadata(int64_t packedRank, 448 ArrayRef<int64_t> innerDimPos) { 449 PackingMetadata res; 450 res.insertPositions.reserve(innerDimPos.size()); 451 // The pack insert position is the position + the number of previously 452 // inserted positions + offset. 453 // The offset controls whether the packing dimension is the first or last. 454 // 455 // Example 456 // ======= 457 // Consider packing from a hypothetical ABCD layout to ABCDba whose 458 // pack.inner_dims is [1, 0]. The first step consists in undoing the 459 // permutation and producing AaBbCD. This is achieved purely by computing the 460 // insert positions of `b` and `a` into `ABCD`, starting from [1, 0]. One 461 // possibility, is to produce insert positions [2, 0], this would result in an 462 // aAbBCD layout (i.e. offset 0). The other possibility, is to produce insert 463 // positions [3, 1], this would result in an AaBbCD layout (i.e. offset 1). 464 // The latter is what we expect from packing. 465 int64_t offset = 1; 466 for (int64_t pos : innerDimPos) { 467 int64_t numInsertedBefore = llvm::count_if( 468 innerDimPos, [&pos](int64_t pos2) { return pos > pos2; }); 469 res.insertPositions.push_back(pos + numInsertedBefore + offset); 470 } 471 472 DenseSet<int64_t> posSet(res.insertPositions.begin(), 473 res.insertPositions.end()); 474 res.reassociations.reserve(packedRank); 475 for (int64_t i = 1; i <= packedRank; ++i) { 476 res.outerPositions.push_back(i - 1); 477 if (!posSet.contains(i)) { 478 res.reassociations.push_back(ReassociationIndices{i - 1}); 479 continue; 480 } 481 res.reassociations.push_back(ReassociationIndices{i - 1, i}); 482 ++i; 483 } 484 return res; 485 } 486