
Searched refs: getDimSize (Results 1–25 of 87), sorted by relevance


/llvm-project/mlir/lib/Dialect/AMX/IR/
AMXDialect.cpp
42 unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
43 if (tp.getDimSize(0) > kMaxRows) in verifyMultShape()
44 return op->emitOpError("bad row height: ") << tp.getDimSize(0); in verifyMultShape()
54 unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
55 unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale; in verify()
56 unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1); in verify()
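The AMX verifier matches above read both tile dimensions and scale the column count by the element bit width. Below is a minimal sketch of that kind of check; the limit constants and the helper name are assumptions for illustration, not the values in AMXDialect.cpp.

#include "mlir/IR/BuiltinTypes.h"

// Sketch: does a 2-D vector type fit a (kMaxRows x kMaxBitsPerRow) tile?
// The limits below are placeholders, not the AMX dialect's real constants.
static bool fitsTile(mlir::VectorType tp) {
  constexpr int64_t kMaxRows = 16;
  constexpr int64_t kMaxBitsPerRow = 512;
  if (tp.getRank() != 2)
    return false;
  int64_t rows = tp.getDimSize(0);
  int64_t rowBits =
      tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
  return rows <= kMaxRows && rowBits <= kMaxBitsPerRow;
}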
/llvm-project/mlir/lib/Dialect/Tosa/IR/
TosaOps.cpp
601 auto dim2 = shape.getDimSize(i); in inferReturnTypeComponents()
636 outShape.push_back(inputShape.getDimSize(i)); in inferReturnTypeComponents()
654 outputShape[0] = inputShape.getDimSize(0); in inferReturnTypeComponents()
655 outputShape[1] = inputShape.getDimSize(1); in inferReturnTypeComponents()
656 int64_t inWidth = inputShape.getDimSize(2); in inferReturnTypeComponents()
703 outputShape[i] = operandShape.getDimSize(i); in isCompatibleReturnTypes()
704 if (outputShape[i] != operandShape.getDimSize(i)) in isCompatibleReturnTypes()
732 concatDimSize += operandShape.getDimSize(axis); in inferReturnTypeComponents()
777 outShape[0] = inputShape.getDimSize(0); in inferReturnTypeComponents()
781 outShape[1] = weightShape.getDimSize( in inferReturnTypeComponents()
[all...]
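A recurring idiom in the TosaOps.cpp matches is copying each known dimension of an input ShapeAdaptor into the inferred output shape. Here is a minimal sketch of that idiom, with an illustrative helper name rather than the TOSA implementation:

#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "llvm/ADT/SmallVector.h"

// Copy every dimension of a ranked shape adaptor; dynamic sizes propagate
// unchanged.
static llvm::SmallVector<int64_t> copyDims(mlir::ShapeAdaptor inputShape) {
  llvm::SmallVector<int64_t> outShape;
  if (!inputShape.hasRank())
    return outShape;
  for (int64_t i = 0, e = inputShape.getRank(); i < e; ++i)
    outShape.push_back(inputShape.getDimSize(i));
  return outShape;
}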
TosaCanonicalizations.cpp
166 auto sz = inputTy.getDimSize(idx); in matchAndRewrite()
179 newShape.push_back(inputTy.getDimSize(permValues[i])); in matchAndRewrite()
474 inputType.getDimSize(axis)) {
486 sliceStarts[axis] -= inputType.getDimSize(axis);
591 if (inputTy.getDimSize(getAxis()) == 1) in mulBinaryFolder()
898 if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \ in fold()
1013 (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1)) in fold()
/llvm-project/mlir/lib/Dialect/Tosa/Transforms/
TosaDecomposeTransposeConv.cpp
58 int64_t kernelHeight = weightTy.getDimSize(1); in createOpAndInfer()
59 int64_t kernelWidth = weightTy.getDimSize(2); in createOpAndInfer()
130 int64_t batch = inputTy.getDimSize(0); in matchAndRewrite()
132 int64_t outputChannels = weightTy.getDimSize(0); in matchAndRewrite()
133 int64_t weightHeight = weightTy.getDimSize(1); in matchAndRewrite()
134 int64_t weightWidth = weightTy.getDimSize(2); in matchAndRewrite()
135 int64_t inputChannels = weightTy.getDimSize(3); in matchAndRewrite()
161 weightHeight = weightTy.getDimSize(1); in matchAndRewrite()
162 weightWidth = weightTy.getDimSize(2); in matchAndRewrite()
200 inputPadding[2] += restridedWeightTy.getDimSize( in matchAndRewrite()
[all...]
TosaDecomposeDepthwise.cpp
125 inputType.getDimSize(0), inputType.getDimSize(1), in matchAndRewrite()
126 inputType.getDimSize(2), inputType.getDimSize(3), weightShape[3]}; in matchAndRewrite()
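The depthwise decomposition above rebuilds a tensor shape from the four input dimensions plus a trailing size taken from the weight shape. A minimal sketch of constructing such a type follows; the helper name and the NHWC assumption are illustrative only.

#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVector.h"

// Append a multiplier dimension to a rank-4 (e.g. NHWC) tensor type.
static mlir::RankedTensorType
expandWithMultiplier(mlir::RankedTensorType inputType, int64_t multiplier) {
  llvm::SmallVector<int64_t> shape{
      inputType.getDimSize(0), inputType.getDimSize(1),
      inputType.getDimSize(2), inputType.getDimSize(3), multiplier};
  return mlir::RankedTensorType::get(shape, inputType.getElementType());
}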
/llvm-project/mlir/unittests/Interfaces/
InferTypeOpInterfaceTest.cpp
62 EXPECT_EQ(range.getValueAsShape(1).getDimSize(0), 10); in TEST_F()
64 EXPECT_EQ(range.getShape(1).getDimSize(0), 1); in TEST_F()
80 EXPECT_EQ(range.getValueAsShape(0).getDimSize(0), 30); in TEST_F()
84 EXPECT_EQ(range.getValueAsShape(1).getDimSize(0), 10); in TEST_F()
99 EXPECT_EQ(range.getShape(0).getDimSize(0), 10); in TEST_F()
100 EXPECT_EQ(range.getShape(0).getDimSize(1), 20); in TEST_F()
103 EXPECT_EQ(range.getShape(1).getDimSize(0), 1); in TEST_F()
/llvm-project/mlir/lib/Dialect/Linalg/Transforms/
NamedOpConversions.cpp
53 if (kernelTy.getDimSize(3) != 1) in matchAndReplaceDepthwiseConv()
60 {kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)}, in matchAndReplaceDepthwiseConv()
70 RankedTensorType::get({initTy.getDimSize(0), initTy.getDimSize(1), in matchAndReplaceDepthwiseConv()
71 initTy.getDimSize(2), initTy.getDimSize(3)}, in matchAndReplaceDepthwiseConv()
/llvm-project/mlir/lib/Dialect/Vector/Transforms/
LowerVectorBroadcast.cpp
91 for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) in matchAndRewrite()
101 if (srcType.getDimSize(r) != dstType.getDimSize(r)) { in matchAndRewrite()
136 for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) in matchAndRewrite()
144 for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) { in matchAndRewrite()
LowerVectorContract.cpp
100 for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) { in reshapeLoad()
122 for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) { in reshapeStore()
464 int64_t reductionSize = vecType.getDimSize(reductionDim); in getReductionSize()
854 cast<VectorType>(contractOp.getResultType()).getDimSize(i)); in matchAndRewriteMaskableOp()
864 cast<VectorType>(contractOp.getResultType()).getDimSize(i)); in matchAndRewriteMaskableOp()
1054 dimSize = lhsType.getDimSize(lhsIndex); in lowerParallel()
1062 dimSize = rhsType.getDimSize(rhsIndex); in lowerParallel()
1141 int64_t dimSize = lhsType.getDimSize(lhsIndex); in lowerReduction()
1142 if (dimSize != rhsType.getDimSize(rhsIndex)) in lowerReduction()
1247 for (int64_t d = 0, e = resType.getDimSize( in matchAndRewrite()
[all...]
LowerVectorMask.cpp
66 int64_t dim = dstType.getDimSize(0); in matchAndRewrite()
125 if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) { in matchAndRewrite()
135 SmallVector<bool> values(dstType.getDimSize(0), false); in matchAndRewrite()
VectorDistribute.cpp
47 if (sequentialType.getDimSize(i) != distributedType.getDimSize(i)) in calculateImplicitMap()
75 int64_t distributedSize = distributedVectorType.getDimSize(index); in buildDistributedOffset()
456 rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos)); in getDistributedType()
790 int64_t scale = distributedType.getDimSize(vectorPos); in delinearizeLaneId()
1127 if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) { in matchAndRewrite()
1141 distributedType.getDimSize(i);
1415 if (distrDestType.getDimSize(i) != yieldedType.getDimSize(
[all...]
/llvm-project/mlir/lib/Dialect/NVGPU/IR/
NVGPUDialect.cpp
601 if (failed(isAllowedSizeN(matrixB.getDimSize(1), matrixA.getElementType()))) { in verify()
603 << matrixB.getDimSize(1) << ", it is not supported"; in verify()
627 if (vtype.getDimSize(0) != dstMemrefType.getDimSize(0) ||
628 vtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
629 return emitOpError() << "results [" << vtype << "][" << vtype.getDimSize(1)
631 << dstMemrefType.getDimSize(0) << "][" in verify()
632 << dstMemrefType.getDimSize(1) in verify()
645 int64_t sizeM = accType.getFragmented().getDimSize( in verify()
[all...]
/llvm-project/mlir/lib/Dialect/ArmNeon/Transforms/
LowerContractionToSMMLAPattern.cpp
59 auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0); in matchAndRewrite()
60 auto dimN = rhsType.getDimSize(0); in matchAndRewrite()
61 auto dimK = rhsType.getDimSize(1); in matchAndRewrite()
63 if (lhsType.getDimSize(lhsType.getRank() - 1) != in matchAndRewrite()
64 rhsType.getDimSize(rhsType.getRank() - 1)) { in matchAndRewrite()
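The SMMLA pattern above recovers the matmul dimensions directly from the operand vector types: M from the lhs (treated as 1 when the lhs is rank-1), and N and K from the rhs. A minimal sketch of that extraction, assuming the same (N x K) rhs layout; names are illustrative.

#include "mlir/IR/BuiltinTypes.h"
#include <tuple>

// Recover (M, N, K) for a contraction whose rhs is (N x K) and whose lhs is
// either (M x K) or a rank-1 vector of length K (M == 1).
static std::tuple<int64_t, int64_t, int64_t>
getContractionSizes(mlir::VectorType lhsType, mlir::VectorType rhsType) {
  int64_t m = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
  int64_t n = rhsType.getDimSize(0);
  int64_t k = rhsType.getDimSize(1);
  return {m, n, k};
}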
/llvm-project/mlir/lib/Dialect/Mesh/IR/
MeshOps.cpp
256 using Dim = std::decay_t<decltype(shape.getDimSize(0))>; in maybeInsertSourceShardingAnnotation()
956 auto operandDimSize = DimensionSize(operandType.getDimSize(axis)); in getAsmResultNames()
957 auto resultDimSize = DimensionSize(resultType.getDimSize(axis)); in getAsmResultNames()
976 result.getLoc(), operandType.getDimSize(axis), in verifySymbolUses()
977 resultType.getDimSize(axis), axis))) { in verifySymbolUses()
989 auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
990 auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
1001 resultType.getDimSize(concatAxis), concatAxis))) { in verifySymbolUses()
1006 resultType.getDimSize(splitAxis), splitAxis))) {
1021 result.getLoc(), operandType.getDimSize(axi in verifySymbolUses()
[all...]
/llvm-project/mlir/lib/Dialect/ArmSME/IR/
Utils.cpp
81 loc, llvm::cast<VectorType>(initTile.getType()).getDimSize(0)); in createLoopOverTileSlices()
107 int64_t vectorRows = vType.getDimSize(0); in isMultipleOfSMETileVectorType()
108 int64_t vectorCols = vType.getDimSize(1); in isMultipleOfSMETileVectorType()
/llvm-project/mlir/lib/Dialect/ArmSME/Transforms/
VectorLegalization.cpp
150 {std::min(type.getDimSize(0), smeTileType.getDimSize(0)), in decomposeToSMETiles()
151 std::min(type.getDimSize(1), smeTileType.getDimSize(1))}), in decomposeToSMETiles()
166 int64_t vectorRows = type.getDimSize(0); in getNumberOfSMETilesForVectorType()
167 int64_t vectorCols = type.getDimSize(1); in getNumberOfSMETilesForVectorType()
443 if (!isSupportedMaskOp(mask) || (mask && (vectorType.getDimSize(0) > 16 || in matchAndRewrite()
444 vectorType.getDimSize(1) > 16))) in matchAndRewrite()
454 auto minTileSlices = smeTileType.getDimSize(0); in matchAndRewrite()
761 if (sourceType.getRank() != 2 || sourceType.getDimSize( in matchAndRewrite()
[all...]
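getNumberOfSMETilesForVectorType above divides the vector's two dimensions by the minimum SME tile sizes. Below is a minimal sketch of that computation, assuming the vector shape is an exact multiple of the tile shape; it is an illustration, not the VectorLegalization.cpp implementation.

#include "mlir/IR/BuiltinTypes.h"
#include <cassert>

// Number of (minTileRows x minTileCols) tiles covering a 2-D vector whose
// shape is an exact multiple of the tile shape.
static int64_t numTilesForVector(mlir::VectorType type, int64_t minTileRows,
                                 int64_t minTileCols) {
  assert(type.getRank() == 2 && "expected a 2-D vector type");
  int64_t vectorRows = type.getDimSize(0);
  int64_t vectorCols = type.getDimSize(1);
  return (vectorRows / minTileRows) * (vectorCols / minTileCols);
}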
/llvm-project/mlir/lib/Dialect/Vector/IR/
VectorOps.cpp
257 int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset); in isDisjointTransferIndices()
694 if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1) in matchAndRewrite()
874 lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second)) in verifyOutputShape()
900 expectedResultDims.push_back(lhsType.getDimSize(i)); in verifyOutputShape()
907 expectedResultDims.push_back(rhsType.getDimSize(i)); in verifyOutputShape()
1367 constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) { in foldExtractOpFromExtractChain()
1849 extractStridedSliceOp.getType().getDimSize(lastOffset) != in foldExtractStridedOpFromInsertChain()
1850 extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset)) in foldExtractStridedOpFromInsertChain()
1913 : insertOp.getSourceVectorType().getDimSize(di in foldScalarExtractFromFromElements()
[all...]
/llvm-project/mlir/lib/Dialect/NVGPU/Transforms/
OptimizeSharedMemory.cpp
57 1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) * in permuteVectorOffset()
72 int64_t m = llvm::Log2_64(memrefTy.getDimSize(tgtDim)); in permuteVectorOffset()
168 const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1); in optimizeSharedMemoryReadsAndWrites()
/llvm-project/mlir/lib/Dialect/MemRef/Transforms/
ExpandRealloc.cpp
71 cast<BaseMemRefType>(op.getSource().getType()).getDimSize(0); in matchAndRewrite()
82 cast<BaseMemRefType>(op.getResult().getType()).getDimSize(0); in matchAndRewrite()
ExpandOps.cpp
86 int64_t rank = cast<MemRefType>(shapeType).getDimSize(0); in matchAndRewrite()
105 auto sizeAttr = rewriter.getIndexAttr(op.getType().getDimSize(i)); in matchAndRewrite()
122 staticStride *= op.getType().getDimSize(i); in runOnOperation()
/llvm-project/mlir/lib/Dialect/AMX/Transforms/
LegalizeForLLVMExport.cpp
34 auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0)); in getTileSizes()
35 auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes); in getTileSizes()
/llvm-project/mlir/include/mlir/Interfaces/
InferTypeOpInterface.h
68 int64_t getDimSize(int index) const;
73 return ShapedType::isDynamic(getDimSize(index)); in isDynamicDim()
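The header match above shows the core accessor behind most of these hits: getDimSize(index) returns a dimension's static size, and isDynamicDim simply tests that size with ShapedType::isDynamic, as line 73 shows. A minimal usage sketch on a ranked ShapedType follows (not code from this header):

#include "mlir/IR/BuiltinTypes.h"

// Product of all static dimension sizes, or ShapedType::kDynamic as soon as
// any dimension is dynamic. Assumes a ranked type.
static int64_t staticNumElements(mlir::ShapedType type) {
  int64_t count = 1;
  for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
    if (type.isDynamicDim(i))
      return mlir::ShapedType::kDynamic;
    count *= type.getDimSize(i);
  }
  return count;
}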
/llvm-project/mlir/lib/Dialect/Tensor/Utils/
Utils.cpp
160 if (resultType.getDimSize(resultDim) != 1) in isCastLikeExtractSliceOp()
184 if (sourceType.getDimSize(dim) != 1)
/llvm-project/mlir/lib/Dialect/Tensor/IR/
TensorOps.cpp
63 return builder.getIndexAttr(tensorType.getDimSize(dim)); in getMixedSize()
402 join.push_back(two.getDimSize(i)); in joinShapes()
406 join.push_back(one.getDimSize(i)); in joinShapes()
409 if (one.getDimSize(i) != two.getDimSize(i)) in joinShapes()
411 join.push_back(one.getDimSize(i)); in joinShapes()
538 size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i))); in inferResultType()
544 concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim)); in inferResultType()
589 size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i))); in verify()
600 concatSize + SaturatedInteger::wrap(tensorType.getDimSize(di in verify()
[all...]
/llvm-project/mlir/lib/Dialect/MemRef/Utils/
MemRefUtils.cpp
39 runningStride *= type.getDimSize(curDim); in isStaticShapeAndContiguousRowMajor()
44 while (curDim >= 0 && type.getDimSize(curDim) == 1) { in isStaticShapeAndContiguousRowMajor()
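The two MemRefUtils.cpp lines accumulate a running stride over trailing dimensions to decide whether a static shape is contiguous row-major, with line 44 skipping unit dimensions whose strides do not matter. A minimal sketch of the stride check alone (without the unit-dimension handling); the strides are assumed to be already resolved, and this is not the exact MemRefUtils.cpp code.

#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/ArrayRef.h"

// A static shape is contiguous row-major when each stride equals the product
// of the dimension sizes to its right.
static bool isContiguousRowMajor(mlir::MemRefType type,
                                 llvm::ArrayRef<int64_t> strides) {
  int64_t runningStride = 1;
  for (int64_t d = type.getRank() - 1; d >= 0; --d) {
    if (strides[d] != runningStride)
      return false;
    runningStride *= type.getDimSize(d);
  }
  return true;
}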
