xref: /llvm-project/mlir/lib/Dialect/Arith/Utils/Utils.cpp (revision 1f5335c1db5d54b4465677c224b48e0ffc78e6d9)
//===- Utils.cpp - Utilities to support the Arith dialect -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements utilities for the Arith dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Arith/Utils/Utils.h"
14 #include "mlir/Dialect/Arith/IR/Arith.h"
15 #include "mlir/Dialect/Complex/IR/Complex.h"
16 #include "mlir/Dialect/Utils/StaticValueUtils.h"
17 #include "mlir/IR/ImplicitLocOpBuilder.h"
18 #include "llvm/ADT/SmallBitVector.h"
19 #include <numeric>
20 
21 using namespace mlir;
22 
23 std::optional<SmallVector<OpFoldResult>>
24 mlir::inferExpandShapeOutputShape(OpBuilder &b, Location loc,
25                                   ShapedType expandedType,
26                                   ArrayRef<ReassociationIndices> reassociation,
27                                   ArrayRef<OpFoldResult> inputShape) {
28 
29   SmallVector<Value> outputShapeValues;
30   SmallVector<int64_t> outputShapeInts;
31   // For zero-rank inputs, all dims in result shape are unit extent.
32   if (inputShape.empty()) {
33     outputShapeInts.resize(expandedType.getRank(), 1);
34     return getMixedValues(outputShapeInts, outputShapeValues, b);
35   }
36 
37   // Check for all static shapes.
38   if (expandedType.hasStaticShape()) {
39     ArrayRef<int64_t> staticShape = expandedType.getShape();
40     outputShapeInts.assign(staticShape.begin(), staticShape.end());
41     return getMixedValues(outputShapeInts, outputShapeValues, b);
42   }
43 
44   outputShapeInts.resize(expandedType.getRank(), ShapedType::kDynamic);
45   for (const auto &it : llvm::enumerate(reassociation)) {
46     ReassociationIndices indexGroup = it.value();
47 
48     int64_t indexGroupStaticSizesProductInt = 1;
49     bool foundDynamicShape = false;
50     for (int64_t index : indexGroup) {
51       int64_t outputDimSize = expandedType.getDimSize(index);
52       // Cannot infer expanded shape with multiple dynamic dims in the
53       // same reassociation group!
54       if (ShapedType::isDynamic(outputDimSize)) {
55         if (foundDynamicShape)
56           return std::nullopt;
57         foundDynamicShape = true;
58       } else {
59         outputShapeInts[index] = outputDimSize;
60         indexGroupStaticSizesProductInt *= outputDimSize;
61       }
62     }
63     if (!foundDynamicShape)
64       continue;
65 
66     int64_t inputIndex = it.index();
67     // Call get<Value>() under the assumption that we're not casting
68     // dynamism.
69     Value indexGroupSize = cast<Value>(inputShape[inputIndex]);
70     Value indexGroupStaticSizesProduct =
71         b.create<arith::ConstantIndexOp>(loc, indexGroupStaticSizesProductInt);
72     Value dynamicDimSize = b.createOrFold<arith::DivSIOp>(
73         loc, indexGroupSize, indexGroupStaticSizesProduct);
74     outputShapeValues.push_back(dynamicDimSize);
75   }
76 
77   if ((int64_t)outputShapeValues.size() !=
78       llvm::count(outputShapeInts, ShapedType::kDynamic))
79     return std::nullopt;
80 
81   return getMixedValues(outputShapeInts, outputShapeValues, b);
82 }
83 
84 /// Matches a ConstantIndexOp.
85 /// TODO: This should probably just be a general matcher that uses matchConstant
86 /// and checks the operation for an index type.
87 detail::op_matcher<arith::ConstantIndexOp> mlir::matchConstantIndex() {
88   return detail::op_matcher<arith::ConstantIndexOp>();
89 }
90 
91 llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank,
92                                                   ArrayRef<int64_t> shape) {
93   llvm::SmallBitVector dimsToProject(shape.size());
94   for (unsigned pos = 0, e = shape.size(); pos < e && rank > 0; ++pos) {
95     if (shape[pos] == 1) {
96       dimsToProject.set(pos);
97       --rank;
98     }
99   }
100   return dimsToProject;
101 }
102 
103 Value mlir::getValueOrCreateConstantIntOp(OpBuilder &b, Location loc,
104                                           OpFoldResult ofr) {
105   if (auto value = dyn_cast_if_present<Value>(ofr))
106     return value;
107   auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
108   return b.create<arith::ConstantOp>(
109       loc, b.getIntegerAttr(attr.getType(), attr.getValue().getSExtValue()));
110 }
111 
112 Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
113                                             OpFoldResult ofr) {
114   if (auto value = dyn_cast_if_present<Value>(ofr))
115     return value;
116   auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
117   return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue());
118 }
119 
120 Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
121                                             Type targetType, Value value) {
122   if (targetType == value.getType())
123     return value;
124 
125   bool targetIsIndex = targetType.isIndex();
126   bool valueIsIndex = value.getType().isIndex();
127   if (targetIsIndex ^ valueIsIndex)
128     return b.create<arith::IndexCastOp>(loc, targetType, value);
129 
130   auto targetIntegerType = dyn_cast<IntegerType>(targetType);
131   auto valueIntegerType = dyn_cast<IntegerType>(value.getType());
132   assert(targetIntegerType && valueIntegerType &&
133          "unexpected cast between types other than integers and index");
134   assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
135 
136   if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
137     return b.create<arith::ExtSIOp>(loc, targetIntegerType, value);
138   return b.create<arith::TruncIOp>(loc, targetIntegerType, value);
139 }
140 
141 static Value convertScalarToIntDtype(ImplicitLocOpBuilder &b, Value operand,
142                                      IntegerType toType, bool isUnsigned) {
143   // If operand is floating point, cast directly to the int type.
144   if (isa<FloatType>(operand.getType())) {
145     if (isUnsigned)
146       return b.create<arith::FPToUIOp>(toType, operand);
147     return b.create<arith::FPToSIOp>(toType, operand);
148   }
149   // Cast index operands directly to the int type.
150   if (operand.getType().isIndex())
151     return b.create<arith::IndexCastOp>(toType, operand);
152   if (auto fromIntType = dyn_cast<IntegerType>(operand.getType())) {
153     // Either extend or truncate.
154     if (toType.getWidth() > fromIntType.getWidth()) {
155       if (isUnsigned)
156         return b.create<arith::ExtUIOp>(toType, operand);
157       return b.create<arith::ExtSIOp>(toType, operand);
158     }
159     if (toType.getWidth() < fromIntType.getWidth())
160       return b.create<arith::TruncIOp>(toType, operand);
161     return operand;
162   }
163 
164   return {};
165 }
166 
167 static Value convertScalarToFpDtype(ImplicitLocOpBuilder &b, Value operand,
168                                     FloatType toType, bool isUnsigned) {
169   // If operand is integer, cast directly to the float type.
170   // Note that it is unclear how to cast from BF16<->FP16.
171   if (isa<IntegerType>(operand.getType())) {
172     if (isUnsigned)
173       return b.create<arith::UIToFPOp>(toType, operand);
174     return b.create<arith::SIToFPOp>(toType, operand);
175   }
176   if (auto fromFpTy = dyn_cast<FloatType>(operand.getType())) {
177     if (toType.getWidth() > fromFpTy.getWidth())
178       return b.create<arith::ExtFOp>(toType, operand);
179     if (toType.getWidth() < fromFpTy.getWidth())
180       return b.create<arith::TruncFOp>(toType, operand);
181     return operand;
182   }
183 
184   return {};
185 }
186 
187 static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand,
188                                          ComplexType targetType,
189                                          bool isUnsigned) {
190   if (auto fromComplexType = dyn_cast<ComplexType>(operand.getType())) {
191     if (isa<FloatType>(targetType.getElementType()) &&
192         isa<FloatType>(fromComplexType.getElementType())) {
193       Value real = b.create<complex::ReOp>(operand);
194       Value imag = b.create<complex::ImOp>(operand);
195       Type targetETy = targetType.getElementType();
196       if (targetType.getElementType().getIntOrFloatBitWidth() <
197           fromComplexType.getElementType().getIntOrFloatBitWidth()) {
198         real = b.create<arith::TruncFOp>(targetETy, real);
199         imag = b.create<arith::TruncFOp>(targetETy, imag);
200       } else {
201         real = b.create<arith::ExtFOp>(targetETy, real);
202         imag = b.create<arith::ExtFOp>(targetETy, imag);
203       }
204       return b.create<complex::CreateOp>(targetType, real, imag);
205     }
206   }
207 
208   if (dyn_cast<FloatType>(operand.getType())) {
209     FloatType toFpTy = cast<FloatType>(targetType.getElementType());
210     auto toBitwidth = toFpTy.getIntOrFloatBitWidth();
211     Value from = operand;
212     if (from.getType().getIntOrFloatBitWidth() < toBitwidth) {
213       from = b.create<arith::ExtFOp>(toFpTy, from);
214     }
215     if (from.getType().getIntOrFloatBitWidth() > toBitwidth) {
216       from = b.create<arith::TruncFOp>(toFpTy, from);
217     }
218     Value zero = b.create<mlir::arith::ConstantFloatOp>(
219         mlir::APFloat(toFpTy.getFloatSemantics(), 0), toFpTy);
220     return b.create<complex::CreateOp>(targetType, from, zero);
221   }
222 
223   if (dyn_cast<IntegerType>(operand.getType())) {
224     FloatType toFpTy = cast<FloatType>(targetType.getElementType());
225     Value from = operand;
226     if (isUnsigned) {
227       from = b.create<arith::UIToFPOp>(toFpTy, from);
228     } else {
229       from = b.create<arith::SIToFPOp>(toFpTy, from);
230     }
231     Value zero = b.create<mlir::arith::ConstantFloatOp>(
232         mlir::APFloat(toFpTy.getFloatSemantics(), 0), toFpTy);
233     return b.create<complex::CreateOp>(targetType, from, zero);
234   }
235 
236   return {};
237 }
238 
239 Value mlir::convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
240                                  Type toType, bool isUnsignedCast) {
241   if (operand.getType() == toType)
242     return operand;
243   ImplicitLocOpBuilder ib(loc, b);
244   Value result;
245   if (auto intTy = dyn_cast<IntegerType>(toType)) {
246     result = convertScalarToIntDtype(ib, operand, intTy, isUnsignedCast);
247   } else if (auto floatTy = dyn_cast<FloatType>(toType)) {
248     result = convertScalarToFpDtype(ib, operand, floatTy, isUnsignedCast);
249   } else if (auto complexTy = dyn_cast<ComplexType>(toType)) {
250     result =
251         convertScalarToComplexDtype(ib, operand, complexTy, isUnsignedCast);
252   }
253 
254   if (result)
255     return result;
256 
257   emitWarning(loc) << "could not cast operand of type " << operand.getType()
258                    << " to " << toType;
259   return operand;
260 }
261 
262 SmallVector<Value>
263 mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
264                                       ArrayRef<OpFoldResult> valueOrAttrVec) {
265   return llvm::to_vector<4>(
266       llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
267         return getValueOrCreateConstantIndexOp(b, loc, value);
268       }));
269 }
270 
271 Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc,
272                                         Type type, const APInt &value) {
273   TypedAttr attr;
274   if (isa<IntegerType>(type)) {
275     attr = builder.getIntegerAttr(type, value);
276   } else {
277     auto vecTy = cast<ShapedType>(type);
278     attr = SplatElementsAttr::get(vecTy, value);
279   }
280 
281   return builder.create<arith::ConstantOp>(loc, attr);
282 }
283 
284 Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc,
285                                         Type type, int64_t value) {
286   unsigned elementBitWidth = 0;
287   if (auto intTy = dyn_cast<IntegerType>(type))
288     elementBitWidth = intTy.getWidth();
289   else
290     elementBitWidth = cast<ShapedType>(type).getElementTypeBitWidth();
291 
292   return createScalarOrSplatConstant(builder, loc, type,
293                                      APInt(elementBitWidth, value));
294 }
295 
296 Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc,
297                                         Type type, const APFloat &value) {
298   if (isa<FloatType>(type))
299     return builder.createOrFold<arith::ConstantOp>(
300         loc, type, builder.getFloatAttr(type, value));
301   TypedAttr splat = SplatElementsAttr::get(cast<ShapedType>(type), value);
302   return builder.createOrFold<arith::ConstantOp>(loc, type, splat);
303 }
304 
305 Type mlir::getType(OpFoldResult ofr) {
306   if (auto value = dyn_cast_if_present<Value>(ofr))
307     return value.getType();
308   auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
309   return attr.getType();
310 }
311 
312 Value ArithBuilder::_and(Value lhs, Value rhs) {
313   return b.create<arith::AndIOp>(loc, lhs, rhs);
314 }
315 Value ArithBuilder::add(Value lhs, Value rhs) {
316   if (isa<FloatType>(lhs.getType()))
317     return b.create<arith::AddFOp>(loc, lhs, rhs);
318   return b.create<arith::AddIOp>(loc, lhs, rhs);
319 }
320 Value ArithBuilder::sub(Value lhs, Value rhs) {
321   if (isa<FloatType>(lhs.getType()))
322     return b.create<arith::SubFOp>(loc, lhs, rhs);
323   return b.create<arith::SubIOp>(loc, lhs, rhs);
324 }
325 Value ArithBuilder::mul(Value lhs, Value rhs) {
326   if (isa<FloatType>(lhs.getType()))
327     return b.create<arith::MulFOp>(loc, lhs, rhs);
328   return b.create<arith::MulIOp>(loc, lhs, rhs);
329 }
330 Value ArithBuilder::sgt(Value lhs, Value rhs) {
331   if (isa<FloatType>(lhs.getType()))
332     return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, lhs, rhs);
333   return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, lhs, rhs);
334 }
335 Value ArithBuilder::slt(Value lhs, Value rhs) {
336   if (isa<FloatType>(lhs.getType()))
337     return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, lhs, rhs);
338   return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, lhs, rhs);
339 }
340 Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) {
341   return b.create<arith::SelectOp>(loc, cmp, lhs, rhs);
342 }
343 
344 namespace mlir::arith {
345 
346 Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values) {
347   return createProduct(builder, loc, values, values.front().getType());
348 }
349 
350 Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
351                     Type resultType) {
352   Value one = builder.create<ConstantOp>(loc, resultType,
353                                          builder.getOneAttr(resultType));
354   ArithBuilder arithBuilder(builder, loc);
355   return std::accumulate(
356       values.begin(), values.end(), one,
357       [&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); });
358 }
359 
360 /// Map strings to float types.
361 std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name) {
362   Builder b(ctx);
363   return llvm::StringSwitch<std::optional<FloatType>>(name)
364       .Case("f4E2M1FN", b.getType<Float4E2M1FNType>())
365       .Case("f6E2M3FN", b.getType<Float6E2M3FNType>())
366       .Case("f6E3M2FN", b.getType<Float6E3M2FNType>())
367       .Case("f8E5M2", b.getType<Float8E5M2Type>())
368       .Case("f8E4M3", b.getType<Float8E4M3Type>())
369       .Case("f8E4M3FN", b.getType<Float8E4M3FNType>())
370       .Case("f8E5M2FNUZ", b.getType<Float8E5M2FNUZType>())
371       .Case("f8E4M3FNUZ", b.getType<Float8E4M3FNUZType>())
372       .Case("f8E3M4", b.getType<Float8E3M4Type>())
373       .Case("f8E8M0FNU", b.getType<Float8E8M0FNUType>())
374       .Case("bf16", b.getType<BFloat16Type>())
375       .Case("f16", b.getType<Float16Type>())
376       .Case("f32", b.getType<Float32Type>())
377       .Case("f64", b.getType<Float64Type>())
378       .Case("f80", b.getType<Float80Type>())
379       .Case("f128", b.getType<Float128Type>())
380       .Default(std::nullopt);
381 }
382 
383 } // namespace mlir::arith
384