1 //===- Utils.cpp - Utilities to support the Linalg dialect ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements utilities for the Linalg dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Arith/Utils/Utils.h" 14 #include "mlir/Dialect/Arith/IR/Arith.h" 15 #include "mlir/Dialect/Complex/IR/Complex.h" 16 #include "mlir/Dialect/Utils/StaticValueUtils.h" 17 #include "mlir/IR/ImplicitLocOpBuilder.h" 18 #include "llvm/ADT/SmallBitVector.h" 19 #include <numeric> 20 21 using namespace mlir; 22 23 std::optional<SmallVector<OpFoldResult>> 24 mlir::inferExpandShapeOutputShape(OpBuilder &b, Location loc, 25 ShapedType expandedType, 26 ArrayRef<ReassociationIndices> reassociation, 27 ArrayRef<OpFoldResult> inputShape) { 28 29 SmallVector<Value> outputShapeValues; 30 SmallVector<int64_t> outputShapeInts; 31 // For zero-rank inputs, all dims in result shape are unit extent. 32 if (inputShape.empty()) { 33 outputShapeInts.resize(expandedType.getRank(), 1); 34 return getMixedValues(outputShapeInts, outputShapeValues, b); 35 } 36 37 // Check for all static shapes. 38 if (expandedType.hasStaticShape()) { 39 ArrayRef<int64_t> staticShape = expandedType.getShape(); 40 outputShapeInts.assign(staticShape.begin(), staticShape.end()); 41 return getMixedValues(outputShapeInts, outputShapeValues, b); 42 } 43 44 outputShapeInts.resize(expandedType.getRank(), ShapedType::kDynamic); 45 for (const auto &it : llvm::enumerate(reassociation)) { 46 ReassociationIndices indexGroup = it.value(); 47 48 int64_t indexGroupStaticSizesProductInt = 1; 49 bool foundDynamicShape = false; 50 for (int64_t index : indexGroup) { 51 int64_t outputDimSize = expandedType.getDimSize(index); 52 // Cannot infer expanded shape with multiple dynamic dims in the 53 // same reassociation group! 54 if (ShapedType::isDynamic(outputDimSize)) { 55 if (foundDynamicShape) 56 return std::nullopt; 57 foundDynamicShape = true; 58 } else { 59 outputShapeInts[index] = outputDimSize; 60 indexGroupStaticSizesProductInt *= outputDimSize; 61 } 62 } 63 if (!foundDynamicShape) 64 continue; 65 66 int64_t inputIndex = it.index(); 67 // Call get<Value>() under the assumption that we're not casting 68 // dynamism. 69 Value indexGroupSize = cast<Value>(inputShape[inputIndex]); 70 Value indexGroupStaticSizesProduct = 71 b.create<arith::ConstantIndexOp>(loc, indexGroupStaticSizesProductInt); 72 Value dynamicDimSize = b.createOrFold<arith::DivSIOp>( 73 loc, indexGroupSize, indexGroupStaticSizesProduct); 74 outputShapeValues.push_back(dynamicDimSize); 75 } 76 77 if ((int64_t)outputShapeValues.size() != 78 llvm::count(outputShapeInts, ShapedType::kDynamic)) 79 return std::nullopt; 80 81 return getMixedValues(outputShapeInts, outputShapeValues, b); 82 } 83 84 /// Matches a ConstantIndexOp. 85 /// TODO: This should probably just be a general matcher that uses matchConstant 86 /// and checks the operation for an index type. 87 detail::op_matcher<arith::ConstantIndexOp> mlir::matchConstantIndex() { 88 return detail::op_matcher<arith::ConstantIndexOp>(); 89 } 90 91 llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank, 92 ArrayRef<int64_t> shape) { 93 llvm::SmallBitVector dimsToProject(shape.size()); 94 for (unsigned pos = 0, e = shape.size(); pos < e && rank > 0; ++pos) { 95 if (shape[pos] == 1) { 96 dimsToProject.set(pos); 97 --rank; 98 } 99 } 100 return dimsToProject; 101 } 102 103 Value mlir::getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, 104 OpFoldResult ofr) { 105 if (auto value = dyn_cast_if_present<Value>(ofr)) 106 return value; 107 auto attr = cast<IntegerAttr>(cast<Attribute>(ofr)); 108 return b.create<arith::ConstantOp>( 109 loc, b.getIntegerAttr(attr.getType(), attr.getValue().getSExtValue())); 110 } 111 112 Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, 113 OpFoldResult ofr) { 114 if (auto value = dyn_cast_if_present<Value>(ofr)) 115 return value; 116 auto attr = cast<IntegerAttr>(cast<Attribute>(ofr)); 117 return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue()); 118 } 119 120 Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, 121 Type targetType, Value value) { 122 if (targetType == value.getType()) 123 return value; 124 125 bool targetIsIndex = targetType.isIndex(); 126 bool valueIsIndex = value.getType().isIndex(); 127 if (targetIsIndex ^ valueIsIndex) 128 return b.create<arith::IndexCastOp>(loc, targetType, value); 129 130 auto targetIntegerType = dyn_cast<IntegerType>(targetType); 131 auto valueIntegerType = dyn_cast<IntegerType>(value.getType()); 132 assert(targetIntegerType && valueIntegerType && 133 "unexpected cast between types other than integers and index"); 134 assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness()); 135 136 if (targetIntegerType.getWidth() > valueIntegerType.getWidth()) 137 return b.create<arith::ExtSIOp>(loc, targetIntegerType, value); 138 return b.create<arith::TruncIOp>(loc, targetIntegerType, value); 139 } 140 141 static Value convertScalarToIntDtype(ImplicitLocOpBuilder &b, Value operand, 142 IntegerType toType, bool isUnsigned) { 143 // If operand is floating point, cast directly to the int type. 144 if (isa<FloatType>(operand.getType())) { 145 if (isUnsigned) 146 return b.create<arith::FPToUIOp>(toType, operand); 147 return b.create<arith::FPToSIOp>(toType, operand); 148 } 149 // Cast index operands directly to the int type. 150 if (operand.getType().isIndex()) 151 return b.create<arith::IndexCastOp>(toType, operand); 152 if (auto fromIntType = dyn_cast<IntegerType>(operand.getType())) { 153 // Either extend or truncate. 154 if (toType.getWidth() > fromIntType.getWidth()) { 155 if (isUnsigned) 156 return b.create<arith::ExtUIOp>(toType, operand); 157 return b.create<arith::ExtSIOp>(toType, operand); 158 } 159 if (toType.getWidth() < fromIntType.getWidth()) 160 return b.create<arith::TruncIOp>(toType, operand); 161 return operand; 162 } 163 164 return {}; 165 } 166 167 static Value convertScalarToFpDtype(ImplicitLocOpBuilder &b, Value operand, 168 FloatType toType, bool isUnsigned) { 169 // If operand is integer, cast directly to the float type. 170 // Note that it is unclear how to cast from BF16<->FP16. 171 if (isa<IntegerType>(operand.getType())) { 172 if (isUnsigned) 173 return b.create<arith::UIToFPOp>(toType, operand); 174 return b.create<arith::SIToFPOp>(toType, operand); 175 } 176 if (auto fromFpTy = dyn_cast<FloatType>(operand.getType())) { 177 if (toType.getWidth() > fromFpTy.getWidth()) 178 return b.create<arith::ExtFOp>(toType, operand); 179 if (toType.getWidth() < fromFpTy.getWidth()) 180 return b.create<arith::TruncFOp>(toType, operand); 181 return operand; 182 } 183 184 return {}; 185 } 186 187 static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand, 188 ComplexType targetType, 189 bool isUnsigned) { 190 if (auto fromComplexType = dyn_cast<ComplexType>(operand.getType())) { 191 if (isa<FloatType>(targetType.getElementType()) && 192 isa<FloatType>(fromComplexType.getElementType())) { 193 Value real = b.create<complex::ReOp>(operand); 194 Value imag = b.create<complex::ImOp>(operand); 195 Type targetETy = targetType.getElementType(); 196 if (targetType.getElementType().getIntOrFloatBitWidth() < 197 fromComplexType.getElementType().getIntOrFloatBitWidth()) { 198 real = b.create<arith::TruncFOp>(targetETy, real); 199 imag = b.create<arith::TruncFOp>(targetETy, imag); 200 } else { 201 real = b.create<arith::ExtFOp>(targetETy, real); 202 imag = b.create<arith::ExtFOp>(targetETy, imag); 203 } 204 return b.create<complex::CreateOp>(targetType, real, imag); 205 } 206 } 207 208 if (dyn_cast<FloatType>(operand.getType())) { 209 FloatType toFpTy = cast<FloatType>(targetType.getElementType()); 210 auto toBitwidth = toFpTy.getIntOrFloatBitWidth(); 211 Value from = operand; 212 if (from.getType().getIntOrFloatBitWidth() < toBitwidth) { 213 from = b.create<arith::ExtFOp>(toFpTy, from); 214 } 215 if (from.getType().getIntOrFloatBitWidth() > toBitwidth) { 216 from = b.create<arith::TruncFOp>(toFpTy, from); 217 } 218 Value zero = b.create<mlir::arith::ConstantFloatOp>( 219 mlir::APFloat(toFpTy.getFloatSemantics(), 0), toFpTy); 220 return b.create<complex::CreateOp>(targetType, from, zero); 221 } 222 223 if (dyn_cast<IntegerType>(operand.getType())) { 224 FloatType toFpTy = cast<FloatType>(targetType.getElementType()); 225 Value from = operand; 226 if (isUnsigned) { 227 from = b.create<arith::UIToFPOp>(toFpTy, from); 228 } else { 229 from = b.create<arith::SIToFPOp>(toFpTy, from); 230 } 231 Value zero = b.create<mlir::arith::ConstantFloatOp>( 232 mlir::APFloat(toFpTy.getFloatSemantics(), 0), toFpTy); 233 return b.create<complex::CreateOp>(targetType, from, zero); 234 } 235 236 return {}; 237 } 238 239 Value mlir::convertScalarToDtype(OpBuilder &b, Location loc, Value operand, 240 Type toType, bool isUnsignedCast) { 241 if (operand.getType() == toType) 242 return operand; 243 ImplicitLocOpBuilder ib(loc, b); 244 Value result; 245 if (auto intTy = dyn_cast<IntegerType>(toType)) { 246 result = convertScalarToIntDtype(ib, operand, intTy, isUnsignedCast); 247 } else if (auto floatTy = dyn_cast<FloatType>(toType)) { 248 result = convertScalarToFpDtype(ib, operand, floatTy, isUnsignedCast); 249 } else if (auto complexTy = dyn_cast<ComplexType>(toType)) { 250 result = 251 convertScalarToComplexDtype(ib, operand, complexTy, isUnsignedCast); 252 } 253 254 if (result) 255 return result; 256 257 emitWarning(loc) << "could not cast operand of type " << operand.getType() 258 << " to " << toType; 259 return operand; 260 } 261 262 SmallVector<Value> 263 mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, 264 ArrayRef<OpFoldResult> valueOrAttrVec) { 265 return llvm::to_vector<4>( 266 llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value { 267 return getValueOrCreateConstantIndexOp(b, loc, value); 268 })); 269 } 270 271 Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc, 272 Type type, const APInt &value) { 273 TypedAttr attr; 274 if (isa<IntegerType>(type)) { 275 attr = builder.getIntegerAttr(type, value); 276 } else { 277 auto vecTy = cast<ShapedType>(type); 278 attr = SplatElementsAttr::get(vecTy, value); 279 } 280 281 return builder.create<arith::ConstantOp>(loc, attr); 282 } 283 284 Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc, 285 Type type, int64_t value) { 286 unsigned elementBitWidth = 0; 287 if (auto intTy = dyn_cast<IntegerType>(type)) 288 elementBitWidth = intTy.getWidth(); 289 else 290 elementBitWidth = cast<ShapedType>(type).getElementTypeBitWidth(); 291 292 return createScalarOrSplatConstant(builder, loc, type, 293 APInt(elementBitWidth, value)); 294 } 295 296 Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc, 297 Type type, const APFloat &value) { 298 if (isa<FloatType>(type)) 299 return builder.createOrFold<arith::ConstantOp>( 300 loc, type, builder.getFloatAttr(type, value)); 301 TypedAttr splat = SplatElementsAttr::get(cast<ShapedType>(type), value); 302 return builder.createOrFold<arith::ConstantOp>(loc, type, splat); 303 } 304 305 Type mlir::getType(OpFoldResult ofr) { 306 if (auto value = dyn_cast_if_present<Value>(ofr)) 307 return value.getType(); 308 auto attr = cast<IntegerAttr>(cast<Attribute>(ofr)); 309 return attr.getType(); 310 } 311 312 Value ArithBuilder::_and(Value lhs, Value rhs) { 313 return b.create<arith::AndIOp>(loc, lhs, rhs); 314 } 315 Value ArithBuilder::add(Value lhs, Value rhs) { 316 if (isa<FloatType>(lhs.getType())) 317 return b.create<arith::AddFOp>(loc, lhs, rhs); 318 return b.create<arith::AddIOp>(loc, lhs, rhs); 319 } 320 Value ArithBuilder::sub(Value lhs, Value rhs) { 321 if (isa<FloatType>(lhs.getType())) 322 return b.create<arith::SubFOp>(loc, lhs, rhs); 323 return b.create<arith::SubIOp>(loc, lhs, rhs); 324 } 325 Value ArithBuilder::mul(Value lhs, Value rhs) { 326 if (isa<FloatType>(lhs.getType())) 327 return b.create<arith::MulFOp>(loc, lhs, rhs); 328 return b.create<arith::MulIOp>(loc, lhs, rhs); 329 } 330 Value ArithBuilder::sgt(Value lhs, Value rhs) { 331 if (isa<FloatType>(lhs.getType())) 332 return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, lhs, rhs); 333 return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, lhs, rhs); 334 } 335 Value ArithBuilder::slt(Value lhs, Value rhs) { 336 if (isa<FloatType>(lhs.getType())) 337 return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, lhs, rhs); 338 return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, lhs, rhs); 339 } 340 Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) { 341 return b.create<arith::SelectOp>(loc, cmp, lhs, rhs); 342 } 343 344 namespace mlir::arith { 345 346 Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values) { 347 return createProduct(builder, loc, values, values.front().getType()); 348 } 349 350 Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values, 351 Type resultType) { 352 Value one = builder.create<ConstantOp>(loc, resultType, 353 builder.getOneAttr(resultType)); 354 ArithBuilder arithBuilder(builder, loc); 355 return std::accumulate( 356 values.begin(), values.end(), one, 357 [&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); }); 358 } 359 360 /// Map strings to float types. 361 std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name) { 362 Builder b(ctx); 363 return llvm::StringSwitch<std::optional<FloatType>>(name) 364 .Case("f4E2M1FN", b.getType<Float4E2M1FNType>()) 365 .Case("f6E2M3FN", b.getType<Float6E2M3FNType>()) 366 .Case("f6E3M2FN", b.getType<Float6E3M2FNType>()) 367 .Case("f8E5M2", b.getType<Float8E5M2Type>()) 368 .Case("f8E4M3", b.getType<Float8E4M3Type>()) 369 .Case("f8E4M3FN", b.getType<Float8E4M3FNType>()) 370 .Case("f8E5M2FNUZ", b.getType<Float8E5M2FNUZType>()) 371 .Case("f8E4M3FNUZ", b.getType<Float8E4M3FNUZType>()) 372 .Case("f8E3M4", b.getType<Float8E3M4Type>()) 373 .Case("f8E8M0FNU", b.getType<Float8E8M0FNUType>()) 374 .Case("bf16", b.getType<BFloat16Type>()) 375 .Case("f16", b.getType<Float16Type>()) 376 .Case("f32", b.getType<Float32Type>()) 377 .Case("f64", b.getType<Float64Type>()) 378 .Case("f80", b.getType<Float80Type>()) 379 .Case("f128", b.getType<Float128Type>()) 380 .Default(std::nullopt); 381 } 382 383 } // namespace mlir::arith 384