Lines Matching +full:non +full:- +full:batch

1 //===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements patterns/pass to remove usage of unit-extent dimensions
13 //===----------------------------------------------------------------------===//
41 #define DEBUG_TYPE "linalg-drop-unit-dims"
48 /// blockArgument corresponding to init is used in the region. This is a fix-up
55 /// outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
56 /// %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
57 /// affine_map<(d0) -> (0, d0)>],
64 /// } -> tensor<1x1xf32>
71 /// outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
73 /// %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
74 /// affine_map<(d0) -> (0, d0)>,
75 /// affine_map<(d0) -> (0, d0)>],
82 /// } -> tensor<1x1xf32>
111 newInputOperands.push_back(op->get());
122 rewriter.setInsertionPointAfterValue(op->get());
123 auto elemType = cast<ShapedType>(op->get().getType()).getElementType();
125 loc, tensor::getMixedSizes(rewriter, loc, op->get()), elemType);
128 newOutputOperands[op->getOperandNumber() - start] = empty.getResult();
141 mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
145 mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
151 block->addArgument(bbarg.getType(), loc);
153 mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
156 for (auto &op : genericOp.getBody()->getOperations()) {
166 //===---------------------------------------------------------------------===//
167 // Drop loops that are unit-extents within Linalg operations.
168 //===---------------------------------------------------------------------===//
170 /// Implements a pass that canonicalizes the uses of unit-extent dimensions for
175 /// affine_map<(d0, d1) -> (0, d1)>,
176 /// affine_map<(d0, d1) -> (d0, 0)>,
177 /// affine_map<(d0, d1) -> (d0, d1)>
186 /// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
189 /// %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
191 /// %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
197 /// } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
205 /// affine_map<(d0, d1) -> (d1)>,
206 /// affine_map<(d0, d1) -> (d0)>,
207 /// affine_map<(d0, d1) -> (d0, d1)>
216 /// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
223 /// } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
233 llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
244 indexOp.getDim() - droppedDims);
290 assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
310 assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
366 assert(!isUnitDim(dim) && "expected non unit-extent");
372 // Fold all following dimensions that are unit-extent.
392 // 1. Check if any of the iteration dimensions are unit-trip count. They will
393 // end up being unit-trip count if they are used to index into a unit-dim
420 // 2. Compute the iterator types of the modified op by dropping the one-trip
441 // - modified affine map to use.
442 // - shape of the operands after the unit-dims are dropped.
443 // - the reassociation indices used to convert from the original
448 // access a unit-extent tensor. Consider moving this out of this specific
449 // transformation as a stand-alone transformation. Kept here right now due
465 for (OpOperand &opOperand : genericOp->getOpOperands()) {
497 // either through use of reshapes or rank-reducing slices as
500 for (OpOperand &opOperand : genericOp->getOpOperands()) {
535 Value origDest = genericOp.getDpsInitOperand(index)->get();
562 rewriter.replaceOp(genericOp, result->replacements);
571 //===---------------------------------------------------------------------===//
572 // Drop dimensions that are unit-extents within tensor operations.
573 //===---------------------------------------------------------------------===//
595 // Fail for non-constant padding values. The body of the pad could
598 // TODO: Support non-constant padding values.
602 padOp, "unimplemented: non-constant padding value");
642 assert(!unitDims.contains(dim) && "expected non unit-extent");
645 // Fold all following dimensions that are unit-extent.
672 rewriter, padOp.getLoc(), newPadOp, dim - numUnitDims));
692 /// Convert `extract_slice` operations to rank-reduced versions.
705 reassociation->size() == static_cast<size_t>(resultType.getRank()))
713 reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
725 /// Convert `insert_slice` operations to rank-reduced versions.
739 reassociation->size() == static_cast<size_t>(sourceType.getRank()))
750 rewriter.setInsertionPoint(insertSliceOp->getParentOp());
763 /// Patterns that are used to canonicalize the use of unit-extent dims for
817 /// Pass that removes unit-extent dims within generic ops.
825 MLIRContext *context = op->getContext();
846 SmallVector<ReassociationIndices> reassociation(rank - 1, {0, 1});
847 bool lastDim = pos == rank - 1;
849 for (int64_t i = 0; i < rank - 1; i++) {
850 if (i == pos || (lastDim && i == pos - 1))
937 for (auto attr : contractionOp->getAttrs()) {
940 collapsedOp->setAttr(attr.getName(), attr.getValue());
971 /// Look for unit batch dims to collapse.
983 if (contractionDims.batch.size() != 1)
985 auto batchDim = contractionDims.batch[0];
1003 /// Patterns for reducing non-batch dimensions
1020 /// Look for non-batch spatial dims to collapse.
1042 operandUnitDims = SmallVector<int64_t>{std::get<1>(mOperands[0]), -1,
1056 operandUnitDims = SmallVector<int64_t>{-1, std::get<1>(nOperands[0]),
1071 // Unbatching patterns for unit batch size
1082 // Non-batch rank 1 reducing patterns
1087 // Batch rank 1 reducing patterns
1095 // Non-batch rank 0 reducing patterns