xref: /llvm-project/mlir/lib/CAPI/IR/BuiltinAttributes.cpp (revision 5d3ae5161210c068d01ffba36c8e0761e9971179)
1 //===- BuiltinAttributes.cpp - C Interface to MLIR Builtin Attributes -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/BuiltinAttributes.h"
10 #include "mlir-c/Support.h"
11 #include "mlir/CAPI/AffineMap.h"
12 #include "mlir/CAPI/IR.h"
13 #include "mlir/CAPI/IntegerSet.h"
14 #include "mlir/CAPI/Support.h"
15 #include "mlir/IR/AsmState.h"
16 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/BuiltinAttributes.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 
20 using namespace mlir;
21 
22 MlirAttribute mlirAttributeGetNull() { return {nullptr}; }
23 
24 //===----------------------------------------------------------------------===//
25 // Location attribute.
26 //===----------------------------------------------------------------------===//
27 
28 bool mlirAttributeIsALocation(MlirAttribute attr) {
29   return llvm::isa<LocationAttr>(unwrap(attr));
30 }
31 
32 //===----------------------------------------------------------------------===//
33 // Affine map attribute.
34 //===----------------------------------------------------------------------===//
35 
36 bool mlirAttributeIsAAffineMap(MlirAttribute attr) {
37   return llvm::isa<AffineMapAttr>(unwrap(attr));
38 }
39 
40 MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map) {
41   return wrap(AffineMapAttr::get(unwrap(map)));
42 }
43 
44 MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr) {
45   return wrap(llvm::cast<AffineMapAttr>(unwrap(attr)).getValue());
46 }
47 
48 MlirTypeID mlirAffineMapAttrGetTypeID(void) {
49   return wrap(AffineMapAttr::getTypeID());
50 }
51 
52 //===----------------------------------------------------------------------===//
53 // Array attribute.
54 //===----------------------------------------------------------------------===//
55 
56 bool mlirAttributeIsAArray(MlirAttribute attr) {
57   return llvm::isa<ArrayAttr>(unwrap(attr));
58 }
59 
60 MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements,
61                                MlirAttribute const *elements) {
62   SmallVector<Attribute, 8> attrs;
63   return wrap(
64       ArrayAttr::get(unwrap(ctx), unwrapList(static_cast<size_t>(numElements),
65                                              elements, attrs)));
66 }
67 
68 intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) {
69   return static_cast<intptr_t>(llvm::cast<ArrayAttr>(unwrap(attr)).size());
70 }
71 
72 MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos) {
73   return wrap(llvm::cast<ArrayAttr>(unwrap(attr)).getValue()[pos]);
74 }
75 
76 MlirTypeID mlirArrayAttrGetTypeID(void) { return wrap(ArrayAttr::getTypeID()); }
77 
78 //===----------------------------------------------------------------------===//
79 // Dictionary attribute.
80 //===----------------------------------------------------------------------===//
81 
82 bool mlirAttributeIsADictionary(MlirAttribute attr) {
83   return llvm::isa<DictionaryAttr>(unwrap(attr));
84 }
85 
86 MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements,
87                                     MlirNamedAttribute const *elements) {
88   SmallVector<NamedAttribute, 8> attributes;
89   attributes.reserve(numElements);
90   for (intptr_t i = 0; i < numElements; ++i)
91     attributes.emplace_back(unwrap(elements[i].name),
92                             unwrap(elements[i].attribute));
93   return wrap(DictionaryAttr::get(unwrap(ctx), attributes));
94 }
95 
96 intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) {
97   return static_cast<intptr_t>(llvm::cast<DictionaryAttr>(unwrap(attr)).size());
98 }
99 
100 MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr,
101                                                 intptr_t pos) {
102   NamedAttribute attribute =
103       llvm::cast<DictionaryAttr>(unwrap(attr)).getValue()[pos];
104   return {wrap(attribute.getName()), wrap(attribute.getValue())};
105 }
106 
107 MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr,
108                                                  MlirStringRef name) {
109   return wrap(llvm::cast<DictionaryAttr>(unwrap(attr)).get(unwrap(name)));
110 }
111 
112 MlirTypeID mlirDictionaryAttrGetTypeID(void) {
113   return wrap(DictionaryAttr::getTypeID());
114 }
115 
116 //===----------------------------------------------------------------------===//
117 // Floating point attribute.
118 //===----------------------------------------------------------------------===//
119 
120 bool mlirAttributeIsAFloat(MlirAttribute attr) {
121   return llvm::isa<FloatAttr>(unwrap(attr));
122 }
123 
124 MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type,
125                                      double value) {
126   return wrap(FloatAttr::get(unwrap(type), value));
127 }
128 
129 MlirAttribute mlirFloatAttrDoubleGetChecked(MlirLocation loc, MlirType type,
130                                             double value) {
131   return wrap(FloatAttr::getChecked(unwrap(loc), unwrap(type), value));
132 }
133 
134 double mlirFloatAttrGetValueDouble(MlirAttribute attr) {
135   return llvm::cast<FloatAttr>(unwrap(attr)).getValueAsDouble();
136 }
137 
138 MlirTypeID mlirFloatAttrGetTypeID(void) { return wrap(FloatAttr::getTypeID()); }
139 
140 //===----------------------------------------------------------------------===//
141 // Integer attribute.
142 //===----------------------------------------------------------------------===//
143 
144 bool mlirAttributeIsAInteger(MlirAttribute attr) {
145   return llvm::isa<IntegerAttr>(unwrap(attr));
146 }
147 
148 MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value) {
149   return wrap(IntegerAttr::get(unwrap(type), value));
150 }
151 
152 int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr) {
153   return llvm::cast<IntegerAttr>(unwrap(attr)).getInt();
154 }
155 
156 int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr) {
157   return llvm::cast<IntegerAttr>(unwrap(attr)).getSInt();
158 }
159 
160 uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr) {
161   return llvm::cast<IntegerAttr>(unwrap(attr)).getUInt();
162 }
163 
164 MlirTypeID mlirIntegerAttrGetTypeID(void) {
165   return wrap(IntegerAttr::getTypeID());
166 }
167 
168 //===----------------------------------------------------------------------===//
169 // Bool attribute.
170 //===----------------------------------------------------------------------===//
171 
172 bool mlirAttributeIsABool(MlirAttribute attr) {
173   return llvm::isa<BoolAttr>(unwrap(attr));
174 }
175 
176 MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) {
177   return wrap(BoolAttr::get(unwrap(ctx), value));
178 }
179 
180 bool mlirBoolAttrGetValue(MlirAttribute attr) {
181   return llvm::cast<BoolAttr>(unwrap(attr)).getValue();
182 }
183 
184 //===----------------------------------------------------------------------===//
185 // Integer set attribute.
186 //===----------------------------------------------------------------------===//
187 
188 bool mlirAttributeIsAIntegerSet(MlirAttribute attr) {
189   return llvm::isa<IntegerSetAttr>(unwrap(attr));
190 }
191 
192 MlirTypeID mlirIntegerSetAttrGetTypeID(void) {
193   return wrap(IntegerSetAttr::getTypeID());
194 }
195 
196 MlirAttribute mlirIntegerSetAttrGet(MlirIntegerSet set) {
197   return wrap(IntegerSetAttr::get(unwrap(set)));
198 }
199 
200 MlirIntegerSet mlirIntegerSetAttrGetValue(MlirAttribute attr) {
201   return wrap(llvm::cast<IntegerSetAttr>(unwrap(attr)).getValue());
202 }
203 
204 //===----------------------------------------------------------------------===//
205 // Opaque attribute.
206 //===----------------------------------------------------------------------===//
207 
208 bool mlirAttributeIsAOpaque(MlirAttribute attr) {
209   return llvm::isa<OpaqueAttr>(unwrap(attr));
210 }
211 
212 MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace,
213                                 intptr_t dataLength, const char *data,
214                                 MlirType type) {
215   return wrap(
216       OpaqueAttr::get(StringAttr::get(unwrap(ctx), unwrap(dialectNamespace)),
217                       StringRef(data, dataLength), unwrap(type)));
218 }
219 
220 MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) {
221   return wrap(
222       llvm::cast<OpaqueAttr>(unwrap(attr)).getDialectNamespace().strref());
223 }
224 
225 MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr) {
226   return wrap(llvm::cast<OpaqueAttr>(unwrap(attr)).getAttrData());
227 }
228 
229 MlirTypeID mlirOpaqueAttrGetTypeID(void) {
230   return wrap(OpaqueAttr::getTypeID());
231 }
232 
233 //===----------------------------------------------------------------------===//
234 // String attribute.
235 //===----------------------------------------------------------------------===//
236 
237 bool mlirAttributeIsAString(MlirAttribute attr) {
238   return llvm::isa<StringAttr>(unwrap(attr));
239 }
240 
241 MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) {
242   return wrap((Attribute)StringAttr::get(unwrap(ctx), unwrap(str)));
243 }
244 
245 MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) {
246   return wrap((Attribute)StringAttr::get(unwrap(str), unwrap(type)));
247 }
248 
249 MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) {
250   return wrap(llvm::cast<StringAttr>(unwrap(attr)).getValue());
251 }
252 
253 MlirTypeID mlirStringAttrGetTypeID(void) {
254   return wrap(StringAttr::getTypeID());
255 }
256 
257 //===----------------------------------------------------------------------===//
258 // SymbolRef attribute.
259 //===----------------------------------------------------------------------===//
260 
261 bool mlirAttributeIsASymbolRef(MlirAttribute attr) {
262   return llvm::isa<SymbolRefAttr>(unwrap(attr));
263 }
264 
265 MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol,
266                                    intptr_t numReferences,
267                                    MlirAttribute const *references) {
268   SmallVector<FlatSymbolRefAttr, 4> refs;
269   refs.reserve(numReferences);
270   for (intptr_t i = 0; i < numReferences; ++i)
271     refs.push_back(llvm::cast<FlatSymbolRefAttr>(unwrap(references[i])));
272   auto symbolAttr = StringAttr::get(unwrap(ctx), unwrap(symbol));
273   return wrap(SymbolRefAttr::get(symbolAttr, refs));
274 }
275 
276 MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) {
277   return wrap(
278       llvm::cast<SymbolRefAttr>(unwrap(attr)).getRootReference().getValue());
279 }
280 
281 MlirStringRef mlirSymbolRefAttrGetLeafReference(MlirAttribute attr) {
282   return wrap(
283       llvm::cast<SymbolRefAttr>(unwrap(attr)).getLeafReference().getValue());
284 }
285 
286 intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr) {
287   return static_cast<intptr_t>(
288       llvm::cast<SymbolRefAttr>(unwrap(attr)).getNestedReferences().size());
289 }
290 
291 MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr,
292                                                   intptr_t pos) {
293   return wrap(
294       llvm::cast<SymbolRefAttr>(unwrap(attr)).getNestedReferences()[pos]);
295 }
296 
297 MlirTypeID mlirSymbolRefAttrGetTypeID(void) {
298   return wrap(SymbolRefAttr::getTypeID());
299 }
300 
301 MlirAttribute mlirDisctinctAttrCreate(MlirAttribute referencedAttr) {
302   return wrap(mlir::DistinctAttr::create(unwrap(referencedAttr)));
303 }
304 
305 //===----------------------------------------------------------------------===//
306 // Flat SymbolRef attribute.
307 //===----------------------------------------------------------------------===//
308 
309 bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) {
310   return llvm::isa<FlatSymbolRefAttr>(unwrap(attr));
311 }
312 
313 MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) {
314   return wrap(FlatSymbolRefAttr::get(unwrap(ctx), unwrap(symbol)));
315 }
316 
317 MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) {
318   return wrap(llvm::cast<FlatSymbolRefAttr>(unwrap(attr)).getValue());
319 }
320 
321 //===----------------------------------------------------------------------===//
322 // Type attribute.
323 //===----------------------------------------------------------------------===//
324 
325 bool mlirAttributeIsAType(MlirAttribute attr) {
326   return llvm::isa<TypeAttr>(unwrap(attr));
327 }
328 
329 MlirAttribute mlirTypeAttrGet(MlirType type) {
330   return wrap(TypeAttr::get(unwrap(type)));
331 }
332 
333 MlirType mlirTypeAttrGetValue(MlirAttribute attr) {
334   return wrap(llvm::cast<TypeAttr>(unwrap(attr)).getValue());
335 }
336 
337 MlirTypeID mlirTypeAttrGetTypeID(void) { return wrap(TypeAttr::getTypeID()); }
338 
339 //===----------------------------------------------------------------------===//
340 // Unit attribute.
341 //===----------------------------------------------------------------------===//
342 
343 bool mlirAttributeIsAUnit(MlirAttribute attr) {
344   return llvm::isa<UnitAttr>(unwrap(attr));
345 }
346 
347 MlirAttribute mlirUnitAttrGet(MlirContext ctx) {
348   return wrap(UnitAttr::get(unwrap(ctx)));
349 }
350 
351 MlirTypeID mlirUnitAttrGetTypeID(void) { return wrap(UnitAttr::getTypeID()); }
352 
353 //===----------------------------------------------------------------------===//
354 // Elements attributes.
355 //===----------------------------------------------------------------------===//
356 
357 bool mlirAttributeIsAElements(MlirAttribute attr) {
358   return llvm::isa<ElementsAttr>(unwrap(attr));
359 }
360 
361 MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank,
362                                        uint64_t *idxs) {
363   return wrap(llvm::cast<ElementsAttr>(unwrap(attr))
364                   .getValues<Attribute>()[llvm::ArrayRef(idxs, rank)]);
365 }
366 
367 bool mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank,
368                                   uint64_t *idxs) {
369   return llvm::cast<ElementsAttr>(unwrap(attr))
370       .isValidIndex(llvm::ArrayRef(idxs, rank));
371 }
372 
373 int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) {
374   return llvm::cast<ElementsAttr>(unwrap(attr)).getNumElements();
375 }
376 
377 //===----------------------------------------------------------------------===//
378 // Dense array attribute.
379 //===----------------------------------------------------------------------===//
380 
381 MlirTypeID mlirDenseArrayAttrGetTypeID() {
382   return wrap(DenseArrayAttr::getTypeID());
383 }
384 
385 //===----------------------------------------------------------------------===//
386 // IsA support.
387 //===----------------------------------------------------------------------===//
388 
389 bool mlirAttributeIsADenseBoolArray(MlirAttribute attr) {
390   return llvm::isa<DenseBoolArrayAttr>(unwrap(attr));
391 }
392 bool mlirAttributeIsADenseI8Array(MlirAttribute attr) {
393   return llvm::isa<DenseI8ArrayAttr>(unwrap(attr));
394 }
395 bool mlirAttributeIsADenseI16Array(MlirAttribute attr) {
396   return llvm::isa<DenseI16ArrayAttr>(unwrap(attr));
397 }
398 bool mlirAttributeIsADenseI32Array(MlirAttribute attr) {
399   return llvm::isa<DenseI32ArrayAttr>(unwrap(attr));
400 }
401 bool mlirAttributeIsADenseI64Array(MlirAttribute attr) {
402   return llvm::isa<DenseI64ArrayAttr>(unwrap(attr));
403 }
404 bool mlirAttributeIsADenseF32Array(MlirAttribute attr) {
405   return llvm::isa<DenseF32ArrayAttr>(unwrap(attr));
406 }
407 bool mlirAttributeIsADenseF64Array(MlirAttribute attr) {
408   return llvm::isa<DenseF64ArrayAttr>(unwrap(attr));
409 }
410 
411 //===----------------------------------------------------------------------===//
412 // Constructors.
413 //===----------------------------------------------------------------------===//
414 
415 MlirAttribute mlirDenseBoolArrayGet(MlirContext ctx, intptr_t size,
416                                     int const *values) {
417   SmallVector<bool, 4> elements(values, values + size);
418   return wrap(DenseBoolArrayAttr::get(unwrap(ctx), elements));
419 }
420 MlirAttribute mlirDenseI8ArrayGet(MlirContext ctx, intptr_t size,
421                                   int8_t const *values) {
422   return wrap(
423       DenseI8ArrayAttr::get(unwrap(ctx), ArrayRef<int8_t>(values, size)));
424 }
425 MlirAttribute mlirDenseI16ArrayGet(MlirContext ctx, intptr_t size,
426                                    int16_t const *values) {
427   return wrap(
428       DenseI16ArrayAttr::get(unwrap(ctx), ArrayRef<int16_t>(values, size)));
429 }
430 MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx, intptr_t size,
431                                    int32_t const *values) {
432   return wrap(
433       DenseI32ArrayAttr::get(unwrap(ctx), ArrayRef<int32_t>(values, size)));
434 }
435 MlirAttribute mlirDenseI64ArrayGet(MlirContext ctx, intptr_t size,
436                                    int64_t const *values) {
437   return wrap(
438       DenseI64ArrayAttr::get(unwrap(ctx), ArrayRef<int64_t>(values, size)));
439 }
440 MlirAttribute mlirDenseF32ArrayGet(MlirContext ctx, intptr_t size,
441                                    float const *values) {
442   return wrap(
443       DenseF32ArrayAttr::get(unwrap(ctx), ArrayRef<float>(values, size)));
444 }
445 MlirAttribute mlirDenseF64ArrayGet(MlirContext ctx, intptr_t size,
446                                    double const *values) {
447   return wrap(
448       DenseF64ArrayAttr::get(unwrap(ctx), ArrayRef<double>(values, size)));
449 }
450 
451 //===----------------------------------------------------------------------===//
452 // Accessors.
453 //===----------------------------------------------------------------------===//
454 
455 intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) {
456   return llvm::cast<DenseArrayAttr>(unwrap(attr)).size();
457 }
458 
459 //===----------------------------------------------------------------------===//
460 // Indexed accessors.
461 //===----------------------------------------------------------------------===//
462 
463 bool mlirDenseBoolArrayGetElement(MlirAttribute attr, intptr_t pos) {
464   return llvm::cast<DenseBoolArrayAttr>(unwrap(attr))[pos];
465 }
466 int8_t mlirDenseI8ArrayGetElement(MlirAttribute attr, intptr_t pos) {
467   return llvm::cast<DenseI8ArrayAttr>(unwrap(attr))[pos];
468 }
469 int16_t mlirDenseI16ArrayGetElement(MlirAttribute attr, intptr_t pos) {
470   return llvm::cast<DenseI16ArrayAttr>(unwrap(attr))[pos];
471 }
472 int32_t mlirDenseI32ArrayGetElement(MlirAttribute attr, intptr_t pos) {
473   return llvm::cast<DenseI32ArrayAttr>(unwrap(attr))[pos];
474 }
475 int64_t mlirDenseI64ArrayGetElement(MlirAttribute attr, intptr_t pos) {
476   return llvm::cast<DenseI64ArrayAttr>(unwrap(attr))[pos];
477 }
478 float mlirDenseF32ArrayGetElement(MlirAttribute attr, intptr_t pos) {
479   return llvm::cast<DenseF32ArrayAttr>(unwrap(attr))[pos];
480 }
481 double mlirDenseF64ArrayGetElement(MlirAttribute attr, intptr_t pos) {
482   return llvm::cast<DenseF64ArrayAttr>(unwrap(attr))[pos];
483 }
484 
485 //===----------------------------------------------------------------------===//
486 // Dense elements attribute.
487 //===----------------------------------------------------------------------===//
488 
489 //===----------------------------------------------------------------------===//
490 // IsA support.
491 //===----------------------------------------------------------------------===//
492 
493 bool mlirAttributeIsADenseElements(MlirAttribute attr) {
494   return llvm::isa<DenseElementsAttr>(unwrap(attr));
495 }
496 
497 bool mlirAttributeIsADenseIntElements(MlirAttribute attr) {
498   return llvm::isa<DenseIntElementsAttr>(unwrap(attr));
499 }
500 
501 bool mlirAttributeIsADenseFPElements(MlirAttribute attr) {
502   return llvm::isa<DenseFPElementsAttr>(unwrap(attr));
503 }
504 
505 MlirTypeID mlirDenseIntOrFPElementsAttrGetTypeID(void) {
506   return wrap(DenseIntOrFPElementsAttr::getTypeID());
507 }
508 
509 //===----------------------------------------------------------------------===//
510 // Constructors.
511 //===----------------------------------------------------------------------===//
512 
513 MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType,
514                                        intptr_t numElements,
515                                        MlirAttribute const *elements) {
516   SmallVector<Attribute, 8> attributes;
517   return wrap(
518       DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
519                              unwrapList(numElements, elements, attributes)));
520 }
521 
522 MlirAttribute mlirDenseElementsAttrRawBufferGet(MlirType shapedType,
523                                                 size_t rawBufferSize,
524                                                 const void *rawBuffer) {
525   auto shapedTypeCpp = llvm::cast<ShapedType>(unwrap(shapedType));
526   ArrayRef<char> rawBufferCpp(static_cast<const char *>(rawBuffer),
527                               rawBufferSize);
528   bool isSplat = false;
529   if (!DenseElementsAttr::isValidRawBuffer(shapedTypeCpp, rawBufferCpp,
530                                            isSplat))
531     return mlirAttributeGetNull();
532   return wrap(DenseElementsAttr::getFromRawBuffer(shapedTypeCpp, rawBufferCpp));
533 }
534 
535 MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType,
536                                             MlirAttribute element) {
537   return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
538                                      unwrap(element)));
539 }
540 MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType,
541                                                 bool element) {
542   return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
543                                      element));
544 }
545 MlirAttribute mlirDenseElementsAttrUInt8SplatGet(MlirType shapedType,
546                                                  uint8_t element) {
547   return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
548                                      element));
549 }
550 MlirAttribute mlirDenseElementsAttrInt8SplatGet(MlirType shapedType,
551                                                 int8_t element) {
552   return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
553                                      element));
554 }
555 MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType,
556                                                   uint32_t element) {
557   return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
558                                      element));
559 }
560 MlirAttribute mlirDenseElementsAttrInt32SplatGet(MlirType shapedType,
561                                                  int32_t element) {
562   return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
563                                      element));
564 }
565 MlirAttribute mlirDenseElementsAttrUInt64SplatGet(MlirType shapedType,
566                                                   uint64_t element) {
567   return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
568                                      element));
569 }
570 MlirAttribute mlirDenseElementsAttrInt64SplatGet(MlirType shapedType,
571                                                  int64_t element) {
572   return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
573                                      element));
574 }
575 MlirAttribute mlirDenseElementsAttrFloatSplatGet(MlirType shapedType,
576                                                  float element) {
577   return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
578                                      element));
579 }
580 MlirAttribute mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType,
581                                                   double element) {
582   return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
583                                      element));
584 }
585 
586 MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType,
587                                            intptr_t numElements,
588                                            const int *elements) {
589   SmallVector<bool, 8> values(elements, elements + numElements);
590   return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
591                                      values));
592 }
593 
594 /// Creates a dense attribute with elements of the type deduced by templates.
595 template <typename T>
596 static MlirAttribute getDenseAttribute(MlirType shapedType,
597                                        intptr_t numElements,
598                                        const T *elements) {
599   return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
600                                      llvm::ArrayRef(elements, numElements)));
601 }
602 
603 MlirAttribute mlirDenseElementsAttrUInt8Get(MlirType shapedType,
604                                             intptr_t numElements,
605                                             const uint8_t *elements) {
606   return getDenseAttribute(shapedType, numElements, elements);
607 }
608 MlirAttribute mlirDenseElementsAttrInt8Get(MlirType shapedType,
609                                            intptr_t numElements,
610                                            const int8_t *elements) {
611   return getDenseAttribute(shapedType, numElements, elements);
612 }
613 MlirAttribute mlirDenseElementsAttrUInt16Get(MlirType shapedType,
614                                              intptr_t numElements,
615                                              const uint16_t *elements) {
616   return getDenseAttribute(shapedType, numElements, elements);
617 }
618 MlirAttribute mlirDenseElementsAttrInt16Get(MlirType shapedType,
619                                             intptr_t numElements,
620                                             const int16_t *elements) {
621   return getDenseAttribute(shapedType, numElements, elements);
622 }
623 MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType,
624                                              intptr_t numElements,
625                                              const uint32_t *elements) {
626   return getDenseAttribute(shapedType, numElements, elements);
627 }
628 MlirAttribute mlirDenseElementsAttrInt32Get(MlirType shapedType,
629                                             intptr_t numElements,
630                                             const int32_t *elements) {
631   return getDenseAttribute(shapedType, numElements, elements);
632 }
633 MlirAttribute mlirDenseElementsAttrUInt64Get(MlirType shapedType,
634                                              intptr_t numElements,
635                                              const uint64_t *elements) {
636   return getDenseAttribute(shapedType, numElements, elements);
637 }
638 MlirAttribute mlirDenseElementsAttrInt64Get(MlirType shapedType,
639                                             intptr_t numElements,
640                                             const int64_t *elements) {
641   return getDenseAttribute(shapedType, numElements, elements);
642 }
643 MlirAttribute mlirDenseElementsAttrFloatGet(MlirType shapedType,
644                                             intptr_t numElements,
645                                             const float *elements) {
646   return getDenseAttribute(shapedType, numElements, elements);
647 }
648 MlirAttribute mlirDenseElementsAttrDoubleGet(MlirType shapedType,
649                                              intptr_t numElements,
650                                              const double *elements) {
651   return getDenseAttribute(shapedType, numElements, elements);
652 }
653 MlirAttribute mlirDenseElementsAttrBFloat16Get(MlirType shapedType,
654                                                intptr_t numElements,
655                                                const uint16_t *elements) {
656   size_t bufferSize = numElements * 2;
657   const void *buffer = static_cast<const void *>(elements);
658   return mlirDenseElementsAttrRawBufferGet(shapedType, bufferSize, buffer);
659 }
660 MlirAttribute mlirDenseElementsAttrFloat16Get(MlirType shapedType,
661                                               intptr_t numElements,
662                                               const uint16_t *elements) {
663   size_t bufferSize = numElements * 2;
664   const void *buffer = static_cast<const void *>(elements);
665   return mlirDenseElementsAttrRawBufferGet(shapedType, bufferSize, buffer);
666 }
667 
668 MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType,
669                                              intptr_t numElements,
670                                              MlirStringRef *strs) {
671   SmallVector<StringRef, 8> values;
672   values.reserve(numElements);
673   for (intptr_t i = 0; i < numElements; ++i)
674     values.push_back(unwrap(strs[i]));
675 
676   return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
677                                      values));
678 }
679 
680 MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr,
681                                               MlirType shapedType) {
682   return wrap(llvm::cast<DenseElementsAttr>(unwrap(attr))
683                   .reshape(llvm::cast<ShapedType>(unwrap(shapedType))));
684 }
685 
686 //===----------------------------------------------------------------------===//
687 // Splat accessors.
688 //===----------------------------------------------------------------------===//
689 
690 bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) {
691   return llvm::cast<DenseElementsAttr>(unwrap(attr)).isSplat();
692 }
693 
694 MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) {
695   return wrap(
696       llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<Attribute>());
697 }
698 int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) {
699   return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<bool>();
700 }
701 int8_t mlirDenseElementsAttrGetInt8SplatValue(MlirAttribute attr) {
702   return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<int8_t>();
703 }
704 uint8_t mlirDenseElementsAttrGetUInt8SplatValue(MlirAttribute attr) {
705   return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<uint8_t>();
706 }
707 int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr) {
708   return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<int32_t>();
709 }
710 uint32_t mlirDenseElementsAttrGetUInt32SplatValue(MlirAttribute attr) {
711   return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<uint32_t>();
712 }
713 int64_t mlirDenseElementsAttrGetInt64SplatValue(MlirAttribute attr) {
714   return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<int64_t>();
715 }
716 uint64_t mlirDenseElementsAttrGetUInt64SplatValue(MlirAttribute attr) {
717   return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<uint64_t>();
718 }
719 float mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr) {
720   return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<float>();
721 }
722 double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr) {
723   return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<double>();
724 }
725 MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) {
726   return wrap(
727       llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<StringRef>());
728 }
729 
730 //===----------------------------------------------------------------------===//
731 // Indexed accessors.
732 //===----------------------------------------------------------------------===//
733 
734 bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) {
735   return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<bool>()[pos];
736 }
737 int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) {
738   return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<int8_t>()[pos];
739 }
740 uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) {
741   return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<uint8_t>()[pos];
742 }
743 int16_t mlirDenseElementsAttrGetInt16Value(MlirAttribute attr, intptr_t pos) {
744   return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<int16_t>()[pos];
745 }
746 uint16_t mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr, intptr_t pos) {
747   return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<uint16_t>()[pos];
748 }
749 int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) {
750   return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<int32_t>()[pos];
751 }
752 uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) {
753   return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<uint32_t>()[pos];
754 }
755 int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) {
756   return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<int64_t>()[pos];
757 }
758 uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) {
759   return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<uint64_t>()[pos];
760 }
761 uint64_t mlirDenseElementsAttrGetIndexValue(MlirAttribute attr, intptr_t pos) {
762   return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<uint64_t>()[pos];
763 }
764 float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) {
765   return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<float>()[pos];
766 }
767 double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) {
768   return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<double>()[pos];
769 }
770 MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr,
771                                                   intptr_t pos) {
772   return wrap(
773       llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<StringRef>()[pos]);
774 }
775 
776 //===----------------------------------------------------------------------===//
777 // Raw data accessors.
778 //===----------------------------------------------------------------------===//
779 
780 const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) {
781   return static_cast<const void *>(
782       llvm::cast<DenseElementsAttr>(unwrap(attr)).getRawData().data());
783 }
784 
785 //===----------------------------------------------------------------------===//
786 // Resource blob attributes.
787 //===----------------------------------------------------------------------===//
788 
789 bool mlirAttributeIsADenseResourceElements(MlirAttribute attr) {
790   return llvm::isa<DenseResourceElementsAttr>(unwrap(attr));
791 }
792 
793 MlirAttribute mlirUnmanagedDenseResourceElementsAttrGet(
794     MlirType shapedType, MlirStringRef name, void *data, size_t dataLength,
795     size_t dataAlignment, bool dataIsMutable,
796     void (*deleter)(void *userData, const void *data, size_t size,
797                     size_t align),
798     void *userData) {
799   AsmResourceBlob::DeleterFn cppDeleter = {};
800   if (deleter) {
801     cppDeleter = [deleter, userData](void *data, size_t size, size_t align) {
802       deleter(userData, data, size, align);
803     };
804   }
805   AsmResourceBlob blob(
806       llvm::ArrayRef(static_cast<const char *>(data), dataLength),
807       dataAlignment, std::move(cppDeleter), dataIsMutable);
808   return wrap(
809       DenseResourceElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
810                                      unwrap(name), std::move(blob)));
811 }
812 
813 template <typename U, typename T>
814 static MlirAttribute getDenseResource(MlirType shapedType, MlirStringRef name,
815                                       intptr_t numElements, const T *elements) {
816   return wrap(U::get(llvm::cast<ShapedType>(unwrap(shapedType)), unwrap(name),
817                      UnmanagedAsmResourceBlob::allocateInferAlign(
818                          llvm::ArrayRef(elements, numElements))));
819 }
820 
821 MlirAttribute mlirUnmanagedDenseBoolResourceElementsAttrGet(
822     MlirType shapedType, MlirStringRef name, intptr_t numElements,
823     const int *elements) {
824   return getDenseResource<DenseBoolResourceElementsAttr>(shapedType, name,
825                                                          numElements, elements);
826 }
827 MlirAttribute mlirUnmanagedDenseUInt8ResourceElementsAttrGet(
828     MlirType shapedType, MlirStringRef name, intptr_t numElements,
829     const uint8_t *elements) {
830   return getDenseResource<DenseUI8ResourceElementsAttr>(shapedType, name,
831                                                         numElements, elements);
832 }
833 MlirAttribute mlirUnmanagedDenseUInt16ResourceElementsAttrGet(
834     MlirType shapedType, MlirStringRef name, intptr_t numElements,
835     const uint16_t *elements) {
836   return getDenseResource<DenseUI16ResourceElementsAttr>(shapedType, name,
837                                                          numElements, elements);
838 }
839 MlirAttribute mlirUnmanagedDenseUInt32ResourceElementsAttrGet(
840     MlirType shapedType, MlirStringRef name, intptr_t numElements,
841     const uint32_t *elements) {
842   return getDenseResource<DenseUI32ResourceElementsAttr>(shapedType, name,
843                                                          numElements, elements);
844 }
845 MlirAttribute mlirUnmanagedDenseUInt64ResourceElementsAttrGet(
846     MlirType shapedType, MlirStringRef name, intptr_t numElements,
847     const uint64_t *elements) {
848   return getDenseResource<DenseUI64ResourceElementsAttr>(shapedType, name,
849                                                          numElements, elements);
850 }
851 MlirAttribute mlirUnmanagedDenseInt8ResourceElementsAttrGet(
852     MlirType shapedType, MlirStringRef name, intptr_t numElements,
853     const int8_t *elements) {
854   return getDenseResource<DenseUI8ResourceElementsAttr>(shapedType, name,
855                                                         numElements, elements);
856 }
857 MlirAttribute mlirUnmanagedDenseInt16ResourceElementsAttrGet(
858     MlirType shapedType, MlirStringRef name, intptr_t numElements,
859     const int16_t *elements) {
860   return getDenseResource<DenseUI16ResourceElementsAttr>(shapedType, name,
861                                                          numElements, elements);
862 }
863 MlirAttribute mlirUnmanagedDenseInt32ResourceElementsAttrGet(
864     MlirType shapedType, MlirStringRef name, intptr_t numElements,
865     const int32_t *elements) {
866   return getDenseResource<DenseUI32ResourceElementsAttr>(shapedType, name,
867                                                          numElements, elements);
868 }
869 MlirAttribute mlirUnmanagedDenseInt64ResourceElementsAttrGet(
870     MlirType shapedType, MlirStringRef name, intptr_t numElements,
871     const int64_t *elements) {
872   return getDenseResource<DenseUI64ResourceElementsAttr>(shapedType, name,
873                                                          numElements, elements);
874 }
875 MlirAttribute mlirUnmanagedDenseFloatResourceElementsAttrGet(
876     MlirType shapedType, MlirStringRef name, intptr_t numElements,
877     const float *elements) {
878   return getDenseResource<DenseF32ResourceElementsAttr>(shapedType, name,
879                                                         numElements, elements);
880 }
881 MlirAttribute mlirUnmanagedDenseDoubleResourceElementsAttrGet(
882     MlirType shapedType, MlirStringRef name, intptr_t numElements,
883     const double *elements) {
884   return getDenseResource<DenseF64ResourceElementsAttr>(shapedType, name,
885                                                         numElements, elements);
886 }
887 template <typename U, typename T>
888 static T getDenseResourceVal(MlirAttribute attr, intptr_t pos) {
889   return (*llvm::cast<U>(unwrap(attr)).tryGetAsArrayRef())[pos];
890 }
891 
892 bool mlirDenseBoolResourceElementsAttrGetValue(MlirAttribute attr,
893                                                intptr_t pos) {
894   return getDenseResourceVal<DenseBoolResourceElementsAttr, uint8_t>(attr, pos);
895 }
896 uint8_t mlirDenseUInt8ResourceElementsAttrGetValue(MlirAttribute attr,
897                                                    intptr_t pos) {
898   return getDenseResourceVal<DenseUI8ResourceElementsAttr, uint8_t>(attr, pos);
899 }
900 uint16_t mlirDenseUInt16ResourceElementsAttrGetValue(MlirAttribute attr,
901                                                      intptr_t pos) {
902   return getDenseResourceVal<DenseUI16ResourceElementsAttr, uint16_t>(attr,
903                                                                       pos);
904 }
905 uint32_t mlirDenseUInt32ResourceElementsAttrGetValue(MlirAttribute attr,
906                                                      intptr_t pos) {
907   return getDenseResourceVal<DenseUI32ResourceElementsAttr, uint32_t>(attr,
908                                                                       pos);
909 }
910 uint64_t mlirDenseUInt64ResourceElementsAttrGetValue(MlirAttribute attr,
911                                                      intptr_t pos) {
912   return getDenseResourceVal<DenseUI64ResourceElementsAttr, uint64_t>(attr,
913                                                                       pos);
914 }
915 int8_t mlirDenseInt8ResourceElementsAttrGetValue(MlirAttribute attr,
916                                                  intptr_t pos) {
917   return getDenseResourceVal<DenseUI8ResourceElementsAttr, int8_t>(attr, pos);
918 }
919 int16_t mlirDenseInt16ResourceElementsAttrGetValue(MlirAttribute attr,
920                                                    intptr_t pos) {
921   return getDenseResourceVal<DenseUI16ResourceElementsAttr, int16_t>(attr, pos);
922 }
923 int32_t mlirDenseInt32ResourceElementsAttrGetValue(MlirAttribute attr,
924                                                    intptr_t pos) {
925   return getDenseResourceVal<DenseUI32ResourceElementsAttr, int32_t>(attr, pos);
926 }
927 int64_t mlirDenseInt64ResourceElementsAttrGetValue(MlirAttribute attr,
928                                                    intptr_t pos) {
929   return getDenseResourceVal<DenseUI64ResourceElementsAttr, int64_t>(attr, pos);
930 }
931 float mlirDenseFloatResourceElementsAttrGetValue(MlirAttribute attr,
932                                                  intptr_t pos) {
933   return getDenseResourceVal<DenseF32ResourceElementsAttr, float>(attr, pos);
934 }
935 double mlirDenseDoubleResourceElementsAttrGetValue(MlirAttribute attr,
936                                                    intptr_t pos) {
937   return getDenseResourceVal<DenseF64ResourceElementsAttr, double>(attr, pos);
938 }
939 
940 //===----------------------------------------------------------------------===//
941 // Sparse elements attribute.
942 //===----------------------------------------------------------------------===//
943 
944 bool mlirAttributeIsASparseElements(MlirAttribute attr) {
945   return llvm::isa<SparseElementsAttr>(unwrap(attr));
946 }
947 
948 MlirAttribute mlirSparseElementsAttribute(MlirType shapedType,
949                                           MlirAttribute denseIndices,
950                                           MlirAttribute denseValues) {
951   return wrap(SparseElementsAttr::get(
952       llvm::cast<ShapedType>(unwrap(shapedType)),
953       llvm::cast<DenseElementsAttr>(unwrap(denseIndices)),
954       llvm::cast<DenseElementsAttr>(unwrap(denseValues))));
955 }
956 
957 MlirAttribute mlirSparseElementsAttrGetIndices(MlirAttribute attr) {
958   return wrap(llvm::cast<SparseElementsAttr>(unwrap(attr)).getIndices());
959 }
960 
961 MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) {
962   return wrap(llvm::cast<SparseElementsAttr>(unwrap(attr)).getValues());
963 }
964 
965 MlirTypeID mlirSparseElementsAttrGetTypeID(void) {
966   return wrap(SparseElementsAttr::getTypeID());
967 }
968 
969 //===----------------------------------------------------------------------===//
970 // Strided layout attribute.
971 //===----------------------------------------------------------------------===//
972 
973 bool mlirAttributeIsAStridedLayout(MlirAttribute attr) {
974   return llvm::isa<StridedLayoutAttr>(unwrap(attr));
975 }
976 
977 MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset,
978                                        intptr_t numStrides,
979                                        const int64_t *strides) {
980   return wrap(StridedLayoutAttr::get(unwrap(ctx), offset,
981                                      ArrayRef<int64_t>(strides, numStrides)));
982 }
983 
984 int64_t mlirStridedLayoutAttrGetOffset(MlirAttribute attr) {
985   return llvm::cast<StridedLayoutAttr>(unwrap(attr)).getOffset();
986 }
987 
988 intptr_t mlirStridedLayoutAttrGetNumStrides(MlirAttribute attr) {
989   return static_cast<intptr_t>(
990       llvm::cast<StridedLayoutAttr>(unwrap(attr)).getStrides().size());
991 }
992 
993 int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, intptr_t pos) {
994   return llvm::cast<StridedLayoutAttr>(unwrap(attr)).getStrides()[pos];
995 }
996 
997 MlirTypeID mlirStridedLayoutAttrGetTypeID(void) {
998   return wrap(StridedLayoutAttr::getTypeID());
999 }
1000