xref: /llvm-project/mlir/lib/IR/BuiltinTypes.cpp (revision 6aaa8f25b66dc1fef4e465f274ee40b82d632988)
1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinTypes.h"
10 #include "TypeDetail.h"
11 #include "mlir/IR/AffineExpr.h"
12 #include "mlir/IR/AffineMap.h"
13 #include "mlir/IR/BuiltinAttributes.h"
14 #include "mlir/IR/BuiltinDialect.h"
15 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/TensorEncoding.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "llvm/ADT/APFloat.h"
20 #include "llvm/ADT/BitVector.h"
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/ADT/Twine.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 
25 using namespace mlir;
26 using namespace mlir::detail;
27 
28 //===----------------------------------------------------------------------===//
29 /// Tablegen Type Definitions
30 //===----------------------------------------------------------------------===//
31 
32 #define GET_TYPEDEF_CLASSES
33 #include "mlir/IR/BuiltinTypes.cpp.inc"
34 
35 namespace mlir {
36 #include "mlir/IR/BuiltinTypeConstraints.cpp.inc"
37 } // namespace mlir
38 
39 //===----------------------------------------------------------------------===//
40 // BuiltinDialect
41 //===----------------------------------------------------------------------===//
42 
/// Register all TableGen-generated builtin types with the dialect. The list
/// of type classes is expanded from BuiltinTypes.cpp.inc via GET_TYPEDEF_LIST.
void BuiltinDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}
49 
50 //===----------------------------------------------------------------------===//
51 /// ComplexType
52 //===----------------------------------------------------------------------===//
53 
/// Verify the construction of a complex type: the element type must be a
/// builtin integer or float type.
LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  Type elementType) {
  if (!elementType.isIntOrFloat())
    return emitError() << "invalid element type for complex";
  return success();
}
61 
62 //===----------------------------------------------------------------------===//
63 // Integer Type
64 //===----------------------------------------------------------------------===//
65 
66 /// Verify the construction of an integer type.
67 LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
68                                   unsigned width,
69                                   SignednessSemantics signedness) {
70   if (width > IntegerType::kMaxWidth) {
71     return emitError() << "integer bitwidth is limited to "
72                        << IntegerType::kMaxWidth << " bits";
73   }
74   return success();
75 }
76 
/// Return the bitwidth stored in this type's storage instance.
unsigned IntegerType::getWidth() const { return getImpl()->width; }
78 
/// Return the signedness semantics stored in this type's storage instance.
IntegerType::SignednessSemantics IntegerType::getSignedness() const {
  return getImpl()->signedness;
}
82 
83 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
84   if (!scale)
85     return IntegerType();
86   return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
87 }
88 
89 //===----------------------------------------------------------------------===//
90 // Float Types
91 //===----------------------------------------------------------------------===//
92 
// Mapping from MLIR FloatType to APFloat semantics.
// Each expansion defines TYPE::getFloatSemantics(), forwarding to the
// matching llvm::APFloat semantics accessor for that format.
#define FLOAT_TYPE_SEMANTICS(TYPE, SEM)                                        \
  const llvm::fltSemantics &TYPE::getFloatSemantics() const {                  \
    return APFloat::SEM();                                                     \
  }
FLOAT_TYPE_SEMANTICS(Float4E2M1FNType, Float4E2M1FN)
FLOAT_TYPE_SEMANTICS(Float6E2M3FNType, Float6E2M3FN)
FLOAT_TYPE_SEMANTICS(Float6E3M2FNType, Float6E3M2FN)
FLOAT_TYPE_SEMANTICS(Float8E5M2Type, Float8E5M2)
FLOAT_TYPE_SEMANTICS(Float8E4M3Type, Float8E4M3)
FLOAT_TYPE_SEMANTICS(Float8E4M3FNType, Float8E4M3FN)
FLOAT_TYPE_SEMANTICS(Float8E5M2FNUZType, Float8E5M2FNUZ)
FLOAT_TYPE_SEMANTICS(Float8E4M3FNUZType, Float8E4M3FNUZ)
FLOAT_TYPE_SEMANTICS(Float8E4M3B11FNUZType, Float8E4M3B11FNUZ)
FLOAT_TYPE_SEMANTICS(Float8E3M4Type, Float8E3M4)
FLOAT_TYPE_SEMANTICS(Float8E8M0FNUType, Float8E8M0FNU)
FLOAT_TYPE_SEMANTICS(BFloat16Type, BFloat)
FLOAT_TYPE_SEMANTICS(Float16Type, IEEEhalf)
FLOAT_TYPE_SEMANTICS(FloatTF32Type, FloatTF32)
FLOAT_TYPE_SEMANTICS(Float32Type, IEEEsingle)
FLOAT_TYPE_SEMANTICS(Float64Type, IEEEdouble)
FLOAT_TYPE_SEMANTICS(Float80Type, x87DoubleExtended)
FLOAT_TYPE_SEMANTICS(Float128Type, IEEEquad)
#undef FLOAT_TYPE_SEMANTICS
117 
118 FloatType Float16Type::scaleElementBitwidth(unsigned scale) const {
119   if (scale == 2)
120     return Float32Type::get(getContext());
121   if (scale == 4)
122     return Float64Type::get(getContext());
123   return FloatType();
124 }
125 
126 FloatType BFloat16Type::scaleElementBitwidth(unsigned scale) const {
127   if (scale == 2)
128     return Float32Type::get(getContext());
129   if (scale == 4)
130     return Float64Type::get(getContext());
131   return FloatType();
132 }
133 
134 FloatType Float32Type::scaleElementBitwidth(unsigned scale) const {
135   if (scale == 2)
136     return Float64Type::get(getContext());
137   return FloatType();
138 }
139 
140 //===----------------------------------------------------------------------===//
141 // FunctionType
142 //===----------------------------------------------------------------------===//
143 
/// Return the number of input types, read from the storage instance.
unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
145 
/// Return the list of input types, owned by the storage instance.
ArrayRef<Type> FunctionType::getInputs() const {
  return getImpl()->getInputs();
}
149 
/// Return the number of result types, read from the storage instance.
unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }
151 
/// Return the list of result types, owned by the storage instance.
ArrayRef<Type> FunctionType::getResults() const {
  return getImpl()->getResults();
}
155 
/// Return a function type in the same context with the given inputs/results.
FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const {
  return get(getContext(), inputs, results);
}
159 
/// Returns a new function type with the specified arguments and results
/// inserted. `argIndices`/`resultIndices` give the positions (in the new
/// lists) at which the corresponding new types are placed.
FunctionType FunctionType::getWithArgsAndResults(
    ArrayRef<unsigned> argIndices, TypeRange argTypes,
    ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
  // Backing storage must outlive the TypeRanges produced by insertTypesInto.
  SmallVector<Type> argStorage, resultStorage;
  TypeRange newArgTypes =
      insertTypesInto(getInputs(), argIndices, argTypes, argStorage);
  TypeRange newResultTypes =
      insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
  return clone(newArgTypes, newResultTypes);
}
172 
/// Returns a new function type without the specified arguments and results.
/// Set bits in `argIndices`/`resultIndices` mark the positions to drop.
FunctionType
FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
                                       const BitVector &resultIndices) {
  // Backing storage must outlive the TypeRanges produced by filterTypesOut.
  SmallVector<Type> argStorage, resultStorage;
  TypeRange newArgTypes = filterTypesOut(getInputs(), argIndices, argStorage);
  TypeRange newResultTypes =
      filterTypesOut(getResults(), resultIndices, resultStorage);
  return clone(newArgTypes, newResultTypes);
}
183 
184 //===----------------------------------------------------------------------===//
185 // OpaqueType
186 //===----------------------------------------------------------------------===//
187 
/// Verify the construction of an opaque type: the dialect namespace must be
/// syntactically valid, and the dialect must either be loaded in the context
/// or unregistered dialects must be explicitly allowed.
LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 StringAttr dialect, StringRef typeData) {
  if (!Dialect::isValidNamespace(dialect.strref()))
    return emitError() << "invalid dialect namespace '" << dialect << "'";

  // Check that the dialect is actually registered.
  MLIRContext *context = dialect.getContext();
  if (!context->allowsUnregisteredDialects() &&
      !context->getLoadedDialect(dialect.strref())) {
    return emitError()
           << "`!" << dialect << "<\"" << typeData << "\">"
           << "` type created with unregistered dialect. If this is "
              "intended, please call allowUnregisteredDialects() on the "
              "MLIRContext, or use -allow-unregistered-dialect with "
              "the MLIR opt tool used";
  }

  return success();
}
208 
209 //===----------------------------------------------------------------------===//
210 // VectorType
211 //===----------------------------------------------------------------------===//
212 
/// Return true if `t` is an acceptable vector element type (delegates to the
/// TableGen-generated constraint).
bool VectorType::isValidElementType(Type t) {
  return isValidVectorTypeElementType(t);
}
216 
/// Verify the construction of a vector type: valid element type, strictly
/// positive dimension sizes, and one scalable-dim flag per dimension.
LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType,
                                 ArrayRef<bool> scalableDims) {
  if (!isValidElementType(elementType))
    return emitError()
           << "vector elements must be int/index/float type but got "
           << elementType;

  // Vector shapes are always static and positive (no dynamic/zero dims).
  if (any_of(shape, [](int64_t i) { return i <= 0; }))
    return emitError()
           << "vector types must have positive constant sizes but got "
           << shape;

  // Each dimension carries exactly one scalability flag.
  if (scalableDims.size() != shape.size())
    return emitError() << "number of dims must match, got "
                       << scalableDims.size() << " and " << shape.size();

  return success();
}
236 
237 VectorType VectorType::scaleElementBitwidth(unsigned scale) {
238   if (!scale)
239     return VectorType();
240   if (auto et = llvm::dyn_cast<IntegerType>(getElementType()))
241     if (auto scaledEt = et.scaleElementBitwidth(scale))
242       return VectorType::get(getShape(), scaledEt, getScalableDims());
243   if (auto et = llvm::dyn_cast<FloatType>(getElementType()))
244     if (auto scaledEt = et.scaleElementBitwidth(scale))
245       return VectorType::get(getShape(), scaledEt, getScalableDims());
246   return VectorType();
247 }
248 
/// Clone this vector with an optional new shape and a new element type; the
/// scalable-dim flags are always carried over.
VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
                                 Type elementType) const {
  return VectorType::get(shape.value_or(getShape()), elementType,
                         getScalableDims());
}
254 
255 //===----------------------------------------------------------------------===//
256 // TensorType
257 //===----------------------------------------------------------------------===//
258 
/// Dispatch to the concrete ranked/unranked tensor type for the element type.
Type TensorType::getElementType() const {
  return llvm::TypeSwitch<TensorType, Type>(*this)
      .Case<RankedTensorType, UnrankedTensorType>(
          [](auto type) { return type.getElementType(); });
}
264 
/// A tensor has a rank unless it is the unranked variant.
bool TensorType::hasRank() const {
  return !llvm::isa<UnrankedTensorType>(*this);
}
268 
/// Return the shape; the cast asserts that this tensor is ranked.
ArrayRef<int64_t> TensorType::getShape() const {
  return llvm::cast<RankedTensorType>(*this).getShape();
}
272 
273 TensorType TensorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
274                                  Type elementType) const {
275   if (llvm::dyn_cast<UnrankedTensorType>(*this)) {
276     if (shape)
277       return RankedTensorType::get(*shape, elementType);
278     return UnrankedTensorType::get(elementType);
279   }
280 
281   auto rankedTy = llvm::cast<RankedTensorType>(*this);
282   if (!shape)
283     return RankedTensorType::get(rankedTy.getShape(), elementType,
284                                  rankedTy.getEncoding());
285   return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType,
286                                rankedTy.getEncoding());
287 }
288 
/// Clone with a new shape and element type; the result is always ranked.
RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape,
                                   Type elementType) const {
  return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType));
}
293 
/// Clone with a new shape, keeping the element type; always ranked.
RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape) const {
  return ::llvm::cast<RankedTensorType>(cloneWith(shape, getElementType()));
}
297 
298 // Check if "elementType" can be an element type of a tensor.
299 static LogicalResult
300 checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
301                        Type elementType) {
302   if (!TensorType::isValidElementType(elementType))
303     return emitError() << "invalid tensor element type: " << elementType;
304   return success();
305 }
306 
/// Return true if the specified element type is ok in a tensor.
bool TensorType::isValidElementType(Type type) {
  // Note: Non standard/builtin types are allowed to exist within tensor
  // types. Dialects are expected to verify that tensor types have a valid
  // element type within that dialect.
  return llvm::isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
                   IndexType>(type) ||
         !llvm::isa<BuiltinDialect>(type.getDialect());
}
316 
317 //===----------------------------------------------------------------------===//
318 // RankedTensorType
319 //===----------------------------------------------------------------------===//
320 
/// Verify the construction of a ranked tensor: dimension sizes must be
/// non-negative or the dynamic-size sentinel; a verifiable encoding, if
/// present, must accept the shape/element type; the element type must be
/// valid for tensors.
LogicalResult
RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                         ArrayRef<int64_t> shape, Type elementType,
                         Attribute encoding) {
  for (int64_t s : shape)
    if (s < 0 && !ShapedType::isDynamic(s))
      return emitError() << "invalid tensor dimension size";
  if (auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding))
    if (failed(v.verifyEncoding(shape, elementType, emitError)))
      return failure();
  return checkTensorElementType(emitError, elementType);
}
333 
334 //===----------------------------------------------------------------------===//
335 // UnrankedTensorType
336 //===----------------------------------------------------------------------===//
337 
/// Verify the construction of an unranked tensor: only the element type needs
/// checking since there is no shape.
LogicalResult
UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType) {
  return checkTensorElementType(emitError, elementType);
}
343 
344 //===----------------------------------------------------------------------===//
345 // BaseMemRefType
346 //===----------------------------------------------------------------------===//
347 
/// Dispatch to the concrete ranked/unranked memref type for the element type.
Type BaseMemRefType::getElementType() const {
  return llvm::TypeSwitch<BaseMemRefType, Type>(*this)
      .Case<MemRefType, UnrankedMemRefType>(
          [](auto type) { return type.getElementType(); });
}
353 
/// A memref has a rank unless it is the unranked variant.
bool BaseMemRefType::hasRank() const {
  return !llvm::isa<UnrankedMemRefType>(*this);
}
357 
/// Return the shape; the cast asserts that this memref is ranked.
ArrayRef<int64_t> BaseMemRefType::getShape() const {
  return llvm::cast<MemRefType>(*this).getShape();
}
361 
/// Clone this memref type with an optional new shape and a new element type.
/// The memory space is always carried over; giving a shape to an unranked
/// memref produces a ranked one.
BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
                                         Type elementType) const {
  if (llvm::dyn_cast<UnrankedMemRefType>(*this)) {
    if (!shape)
      return UnrankedMemRefType::get(elementType, getMemorySpace());
    MemRefType::Builder builder(*shape, elementType);
    builder.setMemorySpace(getMemorySpace());
    return builder;
  }

  // Ranked case: start from the existing type and override shape/element type.
  MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
  if (shape)
    builder.setShape(*shape);
  builder.setElementType(elementType);
  return builder;
}
378 
/// Clone with a new shape and element type; the result is always ranked.
MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape,
                                 Type elementType) const {
  return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
}
383 
/// Clone with a new shape, keeping the element type; always ranked.
MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape) const {
  return ::llvm::cast<MemRefType>(cloneWith(shape, getElementType()));
}
387 
388 Attribute BaseMemRefType::getMemorySpace() const {
389   if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
390     return rankedMemRefTy.getMemorySpace();
391   return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace();
392 }
393 
394 unsigned BaseMemRefType::getMemorySpaceAsInt() const {
395   if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
396     return rankedMemRefTy.getMemorySpaceAsInt();
397   return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt();
398 }
399 
400 //===----------------------------------------------------------------------===//
401 // MemRefType
402 //===----------------------------------------------------------------------===//
403 
/// Compute which dimensions of `originalShape` were dropped to obtain
/// `reducedShape`. Returns the set of dropped dimension indices, or
/// std::nullopt when `reducedShape` cannot be derived from `originalShape`
/// by dropping only unit dimensions (greedy left-to-right matching).
std::optional<llvm::SmallDenseSet<unsigned>>
mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
                               ArrayRef<int64_t> reducedShape,
                               bool matchDynamic) {
  size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
  llvm::SmallDenseSet<unsigned> unusedDims;
  unsigned reducedIdx = 0;
  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
    // Greedily insert `originalIdx` if match.
    int64_t origSize = originalShape[originalIdx];
    // if `matchDynamic`, count dynamic dims as a match, unless `origSize` is 1.
    if (matchDynamic && reducedIdx < reducedRank && origSize != 1 &&
        (ShapedType::isDynamic(reducedShape[reducedIdx]) ||
         ShapedType::isDynamic(origSize))) {
      reducedIdx++;
      continue;
    }
    if (reducedIdx < reducedRank && origSize == reducedShape[reducedIdx]) {
      reducedIdx++;
      continue;
    }

    unusedDims.insert(originalIdx);
    // If no match on `originalIdx`, the `originalShape` at this dimension
    // must be 1, otherwise we bail.
    if (origSize != 1)
      return std::nullopt;
  }
  // The whole reducedShape must be scanned, otherwise we bail.
  if (reducedIdx != reducedRank)
    return std::nullopt;
  return unusedDims;
}
437 
438 SliceVerificationResult
439 mlir::isRankReducedType(ShapedType originalType,
440                         ShapedType candidateReducedType) {
441   if (originalType == candidateReducedType)
442     return SliceVerificationResult::Success;
443 
444   ShapedType originalShapedType = llvm::cast<ShapedType>(originalType);
445   ShapedType candidateReducedShapedType =
446       llvm::cast<ShapedType>(candidateReducedType);
447 
448   // Rank and size logic is valid for all ShapedTypes.
449   ArrayRef<int64_t> originalShape = originalShapedType.getShape();
450   ArrayRef<int64_t> candidateReducedShape =
451       candidateReducedShapedType.getShape();
452   unsigned originalRank = originalShape.size(),
453            candidateReducedRank = candidateReducedShape.size();
454   if (candidateReducedRank > originalRank)
455     return SliceVerificationResult::RankTooLarge;
456 
457   auto optionalUnusedDimsMask =
458       computeRankReductionMask(originalShape, candidateReducedShape);
459 
460   // Sizes cannot be matched in case empty vector is returned.
461   if (!optionalUnusedDimsMask)
462     return SliceVerificationResult::SizeMismatch;
463 
464   if (originalShapedType.getElementType() !=
465       candidateReducedShapedType.getElementType())
466     return SliceVerificationResult::ElemTypeMismatch;
467 
468   return SliceVerificationResult::Success;
469 }
470 
471 bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
472   // Empty attribute is allowed as default memory space.
473   if (!memorySpace)
474     return true;
475 
476   // Supported built-in attributes.
477   if (llvm::isa<IntegerAttr, StringAttr, DictionaryAttr>(memorySpace))
478     return true;
479 
480   // Allow custom dialect attributes.
481   if (!isa<BuiltinDialect>(memorySpace.getDialect()))
482     return true;
483 
484   return false;
485 }
486 
/// Wrap a legacy integer memory space into an i64 IntegerAttr; the default
/// space 0 maps to the null attribute.
Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace,
                                               MLIRContext *ctx) {
  if (memorySpace == 0)
    return nullptr;

  return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
}
494 
495 Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) {
496   IntegerAttr intMemorySpace = llvm::dyn_cast_or_null<IntegerAttr>(memorySpace);
497   if (intMemorySpace && intMemorySpace.getValue() == 0)
498     return nullptr;
499 
500   return memorySpace;
501 }
502 
/// Convert a memory-space attribute to its integer value. A null attribute
/// denotes the default space 0; otherwise the attribute must be an
/// IntegerAttr.
unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) {
  if (!memorySpace)
    return 0;

  // NOTE(review): the assert message says `getMemorySpaceInteger`, which does
  // not match this function's name — confirm the intended wording.
  assert(llvm::isa<IntegerAttr>(memorySpace) &&
         "Using `getMemorySpaceInteger` with non-Integer attribute");

  return static_cast<unsigned>(llvm::cast<IntegerAttr>(memorySpace).getInt());
}
512 
/// Return this memref's memory space as an integer (0 for the default space).
unsigned MemRefType::getMemorySpaceAsInt() const {
  return detail::getMemorySpaceAsInt(getMemorySpace());
}
516 
/// Build a MemRefType from a layout-interface attribute. A null layout means
/// the identity layout; the default memory space is normalized to the null
/// attribute before uniquing.
MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           MemRefLayoutAttrInterface layout,
                           Attribute memorySpace) {
  // Use default layout for empty attribute.
  if (!layout)
    layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
        shape.size(), elementType.getContext()));

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}
531 
/// Checked variant of get(): same normalization of layout and memory space,
/// but diagnoses verification failures through `emitErrorFn` instead of
/// asserting.
MemRefType MemRefType::getChecked(
    function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
    Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {

  // Use default layout for empty attribute.
  if (!layout)
    layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
        shape.size(), elementType.getContext()));

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}
547 
/// Build a MemRefType from a raw AffineMap layout. A null map means the
/// identity layout; the map is wrapped into an AffineMapAttr and the default
/// memory space is normalized away.
MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           AffineMap map, Attribute memorySpace) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  auto layout = AffineMapAttr::get(map);

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}
565 
/// Checked variant of the AffineMap-based get(): same normalization, but
/// reports verification failures through `emitErrorFn`.
MemRefType
MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                       ArrayRef<int64_t> shape, Type elementType, AffineMap map,
                       Attribute memorySpace) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  auto layout = AffineMapAttr::get(map);

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}
585 
/// Build a MemRefType from a raw AffineMap layout and a deprecated integer
/// memory space, which is converted to an attribute (0 becomes the null
/// attribute, i.e. the default space).
MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           AffineMap map, unsigned memorySpaceInd) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  auto layout = AffineMapAttr::get(map);

  // Convert deprecated integer-like memory space to Attribute.
  Attribute memorySpace =
      wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}
604 
/// Checked variant of the integer-memory-space get(): same conversions, but
/// reports verification failures through `emitErrorFn`.
MemRefType
MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                       ArrayRef<int64_t> shape, Type elementType, AffineMap map,
                       unsigned memorySpaceInd) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  auto layout = AffineMapAttr::get(map);

  // Convert deprecated integer-like memory space to Attribute.
  Attribute memorySpace =
      wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}
625 
/// Verify the construction of a memref type: valid element type, sizes that
/// are non-negative or the dynamic sentinel, a layout that verifies against
/// the shape, and a supported memory-space attribute.
LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType,
                                 MemRefLayoutAttrInterface layout,
                                 Attribute memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";

  // Negative sizes are not allowed except for `kDynamic`.
  for (int64_t s : shape)
    if (s < 0 && !ShapedType::isDynamic(s))
      return emitError() << "invalid memref size";

  // The builders always substitute a default layout for a null one.
  assert(layout && "missing layout specification");
  if (failed(layout.verifyLayout(shape, emitError)))
    return failure();

  if (!isSupportedMemorySpace(memorySpace))
    return emitError() << "unsupported memory space Attribute";

  return success();
}
647 
648 bool MemRefType::areTrailingDimsContiguous(int64_t n) {
649   if (!isLastDimUnitStride())
650     return false;
651 
652   auto memrefShape = getShape().take_back(n);
653   if (ShapedType::isDynamicShape(memrefShape))
654     return false;
655 
656   if (getLayout().isIdentity())
657     return true;
658 
659   int64_t offset;
660   SmallVector<int64_t> stridesFull;
661   if (!succeeded(getStridesAndOffset(stridesFull, offset)))
662     return false;
663   auto strides = ArrayRef<int64_t>(stridesFull).take_back(n);
664 
665   if (strides.empty())
666     return true;
667 
668   // Check whether strides match "flattened" dims.
669   SmallVector<int64_t> flattenedDims;
670   auto dimProduct = 1;
671   for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
672     dimProduct *= dim;
673     flattenedDims.push_back(dimProduct);
674   }
675 
676   strides = strides.drop_back(1);
677   return llvm::equal(strides, llvm::reverse(flattenedDims));
678 }
679 
/// Return this memref with its strided layout put into canonical form: the
/// identity layout (empty attribute) when the layout is equivalent to the
/// canonical strided expression for the shape, otherwise a simplified layout.
MemRefType MemRefType::canonicalizeStridedLayout() {
  AffineMap m = getLayout().getAffineMap();

  // Already in canonical form.
  if (m.isIdentity())
    return *this;

  // Can't reduce to canonical identity form, return in canonical form.
  if (m.getNumResults() > 1)
    return *this;

  // Corner-case for 0-D affine maps.
  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
    if (auto cst = llvm::dyn_cast<AffineConstantExpr>(m.getResult(0)))
      if (cst.getValue() == 0)
        return MemRefType::Builder(*this).setLayout({});
    return *this;
  }

  // 0-D corner case for empty shape that still have an affine map. Example:
  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose
  // offset needs to remain, just return t.
  if (getShape().empty())
    return *this;

  // If the canonical strided layout for the sizes of `t` is equal to the
  // simplified layout of `t` we can just return an empty layout. Otherwise,
  // just simplify the existing layout.
  AffineExpr expr = makeCanonicalStridedLayoutExpr(getShape(), getContext());
  auto simplifiedLayoutExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (expr != simplifiedLayoutExpr)
    return MemRefType::Builder(*this).setLayout(
        AffineMapAttr::get(AffineMap::get(m.getNumDims(), m.getNumSymbols(),
                                          simplifiedLayoutExpr)));
  return MemRefType::Builder(*this).setLayout({});
}
717 
// Fallback cases for terminal dim/sym/cst that are not part of a binary op (
// i.e. single term). Accumulate the AffineExpr into the existing one.
// A bare dimension contributes `multiplicativeFactor` to that dimension's
// stride; any other terminal (symbol/constant) contributes to the offset.
static void extractStridesFromTerm(AffineExpr e,
                                   AffineExpr multiplicativeFactor,
                                   MutableArrayRef<AffineExpr> strides,
                                   AffineExpr &offset) {
  if (auto dim = dyn_cast<AffineDimExpr>(e))
    strides[dim.getPosition()] =
        strides[dim.getPosition()] + multiplicativeFactor;
  else
    offset = offset + e * multiplicativeFactor;
}
730 
/// Takes a single AffineExpr `e` and populates the `strides` array with the
/// strides expressions for each dim position.
/// The convention is that the strides for dimensions d0, .. dn appear in
/// order to make indexing intuitive into the result.
/// Fails (returns failure) on ceildiv/floordiv/mod, which cannot be part of
/// a strided form.
static LogicalResult extractStrides(AffineExpr e,
                                    AffineExpr multiplicativeFactor,
                                    MutableArrayRef<AffineExpr> strides,
                                    AffineExpr &offset) {
  auto bin = dyn_cast<AffineBinaryOpExpr>(e);
  if (!bin) {
    // Terminal dim/symbol/constant: accumulate directly.
    extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
    return success();
  }

  if (bin.getKind() == AffineExprKind::CeilDiv ||
      bin.getKind() == AffineExprKind::FloorDiv ||
      bin.getKind() == AffineExprKind::Mod)
    return failure();

  if (bin.getKind() == AffineExprKind::Mul) {
    auto dim = dyn_cast<AffineDimExpr>(bin.getLHS());
    if (dim) {
      // d_i * rhs: rhs (scaled by the accumulated factor) is d_i's stride.
      strides[dim.getPosition()] =
          strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
      return success();
    }
    // LHS and RHS may both contain complex expressions of dims. Try one path
    // and if it fails try the other. This is guaranteed to succeed because
    // only one path may have a `dim`, otherwise this is not an AffineExpr in
    // the first place.
    if (bin.getLHS().isSymbolicOrConstant())
      return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
                            strides, offset);
    return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
                          strides, offset);
  }

  if (bin.getKind() == AffineExprKind::Add) {
    // Recurse into both addends with the same factor; both must succeed.
    auto res1 =
        extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
    auto res2 =
        extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
    return success(succeeded(res1) && succeeded(res2));
  }

  llvm_unreachable("unexpected binary operation");
}
778 
779 /// A stride specification is a list of integer values that are either static
780 /// or dynamic (encoded with ShapedType::kDynamic). Strides encode
781 /// the distance in the number of elements between successive entries along a
782 /// particular dimension.
783 ///
/// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
/// non-contiguous memory region of `42` by `16` `f32` elements in which the
/// distance between two consecutive elements along the outer dimension is `64`
/// and the distance between two consecutive elements along the inner dimension
/// is `1`.
789 ///
790 /// The convention is that the strides for dimensions d0, .. dn appear in
791 /// order to make indexing intuitive into the result.
792 static LogicalResult getStridesAndOffset(MemRefType t,
793                                          SmallVectorImpl<AffineExpr> &strides,
794                                          AffineExpr &offset) {
795   AffineMap m = t.getLayout().getAffineMap();
796 
797   if (m.getNumResults() != 1 && !m.isIdentity())
798     return failure();
799 
800   auto zero = getAffineConstantExpr(0, t.getContext());
801   auto one = getAffineConstantExpr(1, t.getContext());
802   offset = zero;
803   strides.assign(t.getRank(), zero);
804 
805   // Canonical case for empty map.
806   if (m.isIdentity()) {
807     // 0-D corner case, offset is already 0.
808     if (t.getRank() == 0)
809       return success();
810     auto stridedExpr =
811         makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
812     if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
813       return success();
814     assert(false && "unexpected failure: extract strides in canonical layout");
815   }
816 
817   // Non-canonical case requires more work.
818   auto stridedExpr =
819       simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
820   if (failed(extractStrides(stridedExpr, one, strides, offset))) {
821     offset = AffineExpr();
822     strides.clear();
823     return failure();
824   }
825 
826   // Simplify results to allow folding to constants and simple checks.
827   unsigned numDims = m.getNumDims();
828   unsigned numSymbols = m.getNumSymbols();
829   offset = simplifyAffineExpr(offset, numDims, numSymbols);
830   for (auto &stride : strides)
831     stride = simplifyAffineExpr(stride, numDims, numSymbols);
832 
833   return success();
834 }
835 
836 LogicalResult MemRefType::getStridesAndOffset(SmallVectorImpl<int64_t> &strides,
837                                               int64_t &offset) {
838   // Happy path: the type uses the strided layout directly.
839   if (auto strided = llvm::dyn_cast<StridedLayoutAttr>(getLayout())) {
840     llvm::append_range(strides, strided.getStrides());
841     offset = strided.getOffset();
842     return success();
843   }
844 
845   // Otherwise, defer to the affine fallback as layouts are supposed to be
846   // convertible to affine maps.
847   AffineExpr offsetExpr;
848   SmallVector<AffineExpr, 4> strideExprs;
849   if (failed(::getStridesAndOffset(*this, strideExprs, offsetExpr)))
850     return failure();
851   if (auto cst = llvm::dyn_cast<AffineConstantExpr>(offsetExpr))
852     offset = cst.getValue();
853   else
854     offset = ShapedType::kDynamic;
855   for (auto e : strideExprs) {
856     if (auto c = llvm::dyn_cast<AffineConstantExpr>(e))
857       strides.push_back(c.getValue());
858     else
859       strides.push_back(ShapedType::kDynamic);
860   }
861   return success();
862 }
863 
864 std::pair<SmallVector<int64_t>, int64_t> MemRefType::getStridesAndOffset() {
865   SmallVector<int64_t> strides;
866   int64_t offset;
867   LogicalResult status = getStridesAndOffset(strides, offset);
868   (void)status;
869   assert(succeeded(status) && "Invalid use of check-free getStridesAndOffset");
870   return {strides, offset};
871 }
872 
873 bool MemRefType::isStrided() {
874   int64_t offset;
875   SmallVector<int64_t, 4> strides;
876   auto res = getStridesAndOffset(strides, offset);
877   return succeeded(res);
878 }
879 
880 bool MemRefType::isLastDimUnitStride() {
881   int64_t offset;
882   SmallVector<int64_t> strides;
883   auto successStrides = getStridesAndOffset(strides, offset);
884   return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
885 }
886 
887 //===----------------------------------------------------------------------===//
888 // UnrankedMemRefType
889 //===----------------------------------------------------------------------===//
890 
891 unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
892   return detail::getMemorySpaceAsInt(getMemorySpace());
893 }
894 
895 LogicalResult
896 UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
897                            Type elementType, Attribute memorySpace) {
898   if (!BaseMemRefType::isValidElementType(elementType))
899     return emitError() << "invalid memref element type";
900 
901   if (!isSupportedMemorySpace(memorySpace))
902     return emitError() << "unsupported memory space Attribute";
903 
904   return success();
905 }
906 
907 //===----------------------------------------------------------------------===//
908 /// TupleType
909 //===----------------------------------------------------------------------===//
910 
911 /// Return the elements types for this tuple.
912 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
913 
914 /// Accumulate the types contained in this tuple and tuples nested within it.
915 /// Note that this only flattens nested tuples, not any other container type,
916 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
917 /// (i32, tensor<i32>, f32, i64)
918 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
919   for (Type type : getTypes()) {
920     if (auto nestedTuple = llvm::dyn_cast<TupleType>(type))
921       nestedTuple.getFlattenedTypes(types);
922     else
923       types.push_back(type);
924   }
925 }
926 
927 /// Return the number of element types.
928 size_t TupleType::size() const { return getImpl()->size(); }
929 
930 //===----------------------------------------------------------------------===//
931 // Type Utilities
932 //===----------------------------------------------------------------------===//
933 
934 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
935                                                 ArrayRef<AffineExpr> exprs,
936                                                 MLIRContext *context) {
937   // Size 0 corner case is useful for canonicalizations.
938   if (sizes.empty())
939     return getAffineConstantExpr(0, context);
940 
941   assert(!exprs.empty() && "expected exprs");
942   auto maps = AffineMap::inferFromExprList(exprs, context);
943   assert(!maps.empty() && "Expected one non-empty map");
944   unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
945 
946   AffineExpr expr;
947   bool dynamicPoisonBit = false;
948   int64_t runningSize = 1;
949   for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
950     int64_t size = std::get<1>(en);
951     AffineExpr dimExpr = std::get<0>(en);
952     AffineExpr stride = dynamicPoisonBit
953                             ? getAffineSymbolExpr(nSymbols++, context)
954                             : getAffineConstantExpr(runningSize, context);
955     expr = expr ? expr + dimExpr * stride : dimExpr * stride;
956     if (size > 0) {
957       runningSize *= size;
958       assert(runningSize > 0 && "integer overflow in size computation");
959     } else {
960       dynamicPoisonBit = true;
961     }
962   }
963   return simplifyAffineExpr(expr, numDims, nSymbols);
964 }
965 
966 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
967                                                 MLIRContext *context) {
968   SmallVector<AffineExpr, 4> exprs;
969   exprs.reserve(sizes.size());
970   for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
971     exprs.push_back(getAffineDimExpr(dim, context));
972   return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
973 }
974