xref: /llvm-project/mlir/lib/IR/BuiltinAttributeInterfaces.cpp (revision c1fa60b4cde512964544ab66404dea79dbc5dcb4)
1d80d3a35SRiver Riddle //===- BuiltinAttributeInterfaces.cpp -------------------------------------===//
2d80d3a35SRiver Riddle //
3d80d3a35SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4d80d3a35SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5d80d3a35SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6d80d3a35SRiver Riddle //
7d80d3a35SRiver Riddle //===----------------------------------------------------------------------===//
8d80d3a35SRiver Riddle 
9d80d3a35SRiver Riddle #include "mlir/IR/BuiltinAttributeInterfaces.h"
10d80d3a35SRiver Riddle #include "mlir/IR/BuiltinTypes.h"
11e41ebbecSVladislav Vinogradov #include "mlir/IR/Diagnostics.h"
12d80d3a35SRiver Riddle #include "llvm/ADT/Sequence.h"
13d80d3a35SRiver Riddle 
14d80d3a35SRiver Riddle using namespace mlir;
15d80d3a35SRiver Riddle using namespace mlir::detail;
16d80d3a35SRiver Riddle 
17d80d3a35SRiver Riddle //===----------------------------------------------------------------------===//
18d80d3a35SRiver Riddle /// Tablegen Interface Definitions
19d80d3a35SRiver Riddle //===----------------------------------------------------------------------===//
20d80d3a35SRiver Riddle 
21d80d3a35SRiver Riddle #include "mlir/IR/BuiltinAttributeInterfaces.cpp.inc"
22d80d3a35SRiver Riddle 
23d80d3a35SRiver Riddle //===----------------------------------------------------------------------===//
24d80d3a35SRiver Riddle // ElementsAttr
25d80d3a35SRiver Riddle //===----------------------------------------------------------------------===//
26d80d3a35SRiver Riddle 
getElementType(ElementsAttr elementsAttr)27e1795322SJeff Niu Type ElementsAttr::getElementType(ElementsAttr elementsAttr) {
288db947daSRahul Kayaith   return elementsAttr.getShapedType().getElementType();
29d80d3a35SRiver Riddle }
30d80d3a35SRiver Riddle 
getNumElements(ElementsAttr elementsAttr)31e1795322SJeff Niu int64_t ElementsAttr::getNumElements(ElementsAttr elementsAttr) {
328db947daSRahul Kayaith   return elementsAttr.getShapedType().getNumElements();
33d80d3a35SRiver Riddle }
34d80d3a35SRiver Riddle 
isValidIndex(ShapedType type,ArrayRef<uint64_t> index)35d80d3a35SRiver Riddle bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef<uint64_t> index) {
36d80d3a35SRiver Riddle   // Verify that the rank of the indices matches the held type.
37d80d3a35SRiver Riddle   int64_t rank = type.getRank();
38d80d3a35SRiver Riddle   if (rank == 0 && index.size() == 1 && index[0] == 0)
39d80d3a35SRiver Riddle     return true;
40d80d3a35SRiver Riddle   if (rank != static_cast<int64_t>(index.size()))
41d80d3a35SRiver Riddle     return false;
42d80d3a35SRiver Riddle 
43d80d3a35SRiver Riddle   // Verify that all of the indices are within the shape dimensions.
44d80d3a35SRiver Riddle   ArrayRef<int64_t> shape = type.getShape();
45d80d3a35SRiver Riddle   return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
46d80d3a35SRiver Riddle     int64_t dim = static_cast<int64_t>(index[i]);
47d80d3a35SRiver Riddle     return 0 <= dim && dim < shape[i];
48d80d3a35SRiver Riddle   });
49d80d3a35SRiver Riddle }
isValidIndex(ElementsAttr elementsAttr,ArrayRef<uint64_t> index)50e1795322SJeff Niu bool ElementsAttr::isValidIndex(ElementsAttr elementsAttr,
51d80d3a35SRiver Riddle                                 ArrayRef<uint64_t> index) {
528db947daSRahul Kayaith   return isValidIndex(elementsAttr.getShapedType(), index);
53d80d3a35SRiver Riddle }
54d80d3a35SRiver Riddle 
getFlattenedIndex(Type type,ArrayRef<uint64_t> index)55ae40d625SRiver Riddle uint64_t ElementsAttr::getFlattenedIndex(Type type, ArrayRef<uint64_t> index) {
56*c1fa60b4STres Popp   ShapedType shapeType = llvm::cast<ShapedType>(type);
57ae40d625SRiver Riddle   assert(isValidIndex(shapeType, index) &&
58ae40d625SRiver Riddle          "expected valid multi-dimensional index");
59d80d3a35SRiver Riddle 
60d80d3a35SRiver Riddle   // Reduce the provided multidimensional index into a flattended 1D row-major
61d80d3a35SRiver Riddle   // index.
62ae40d625SRiver Riddle   auto rank = shapeType.getRank();
63ae40d625SRiver Riddle   ArrayRef<int64_t> shape = shapeType.getShape();
64d80d3a35SRiver Riddle   uint64_t valueIndex = 0;
65d80d3a35SRiver Riddle   uint64_t dimMultiplier = 1;
66d80d3a35SRiver Riddle   for (int i = rank - 1; i >= 0; --i) {
67d80d3a35SRiver Riddle     valueIndex += index[i] * dimMultiplier;
68d80d3a35SRiver Riddle     dimMultiplier *= shape[i];
69d80d3a35SRiver Riddle   }
70d80d3a35SRiver Riddle   return valueIndex;
71d80d3a35SRiver Riddle }
72e41ebbecSVladislav Vinogradov 
73e41ebbecSVladislav Vinogradov //===----------------------------------------------------------------------===//
74e41ebbecSVladislav Vinogradov // MemRefLayoutAttrInterface
75e41ebbecSVladislav Vinogradov //===----------------------------------------------------------------------===//
76e41ebbecSVladislav Vinogradov 
verifyAffineMapAsLayout(AffineMap m,ArrayRef<int64_t> shape,function_ref<InFlightDiagnostic ()> emitError)77e41ebbecSVladislav Vinogradov LogicalResult mlir::detail::verifyAffineMapAsLayout(
78e41ebbecSVladislav Vinogradov     AffineMap m, ArrayRef<int64_t> shape,
79e41ebbecSVladislav Vinogradov     function_ref<InFlightDiagnostic()> emitError) {
80e41ebbecSVladislav Vinogradov   if (m.getNumDims() != shape.size())
81e41ebbecSVladislav Vinogradov     return emitError() << "memref layout mismatch between rank and affine map: "
82e41ebbecSVladislav Vinogradov                        << shape.size() << " != " << m.getNumDims();
83e41ebbecSVladislav Vinogradov 
84e41ebbecSVladislav Vinogradov   return success();
85e41ebbecSVladislav Vinogradov }
86