1bcfa7baeSStella Laurenzo //===- sparse_tensor.c - Test of sparse_tensor APIs -----------------------===//
2bcfa7baeSStella Laurenzo //
3bcfa7baeSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4bcfa7baeSStella Laurenzo // Exceptions.
5bcfa7baeSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
6bcfa7baeSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7bcfa7baeSStella Laurenzo //
8bcfa7baeSStella Laurenzo //===----------------------------------------------------------------------===//
9bcfa7baeSStella Laurenzo
10bcfa7baeSStella Laurenzo // RUN: mlir-capi-sparse-tensor-test 2>&1 | FileCheck %s
11bcfa7baeSStella Laurenzo
12bcfa7baeSStella Laurenzo #include "mlir-c/Dialect/SparseTensor.h"
13bcfa7baeSStella Laurenzo #include "mlir-c/IR.h"
145e83a5b4SStella Laurenzo #include "mlir-c/RegisterEverything.h"
15bcfa7baeSStella Laurenzo
16bcfa7baeSStella Laurenzo #include <assert.h>
1787ff65b0SJie Fu #include <inttypes.h>
18bcfa7baeSStella Laurenzo #include <math.h>
19bcfa7baeSStella Laurenzo #include <stdio.h>
20bcfa7baeSStella Laurenzo #include <stdlib.h>
21bcfa7baeSStella Laurenzo #include <string.h>
22bcfa7baeSStella Laurenzo
23bcfa7baeSStella Laurenzo // CHECK-LABEL: testRoundtripEncoding()
/// Round-trips a `#sparse_tensor.encoding` attribute through the C API.
///
/// Parses the encoding from text and reads back each of its components:
/// dimToLvl, lvlToDim, per-level types, posWidth, crdWidth, explicitVal and
/// implicitVal. It then rebuilds an attribute from those components with
/// mlirSparseTensorEncodingAttrGet and prints whether the rebuilt attribute
/// equals the parsed one.
///
/// All diagnostics go to stderr and are matched by the FileCheck directives
/// interleaved in this function. Their relative order must follow the order
/// in which the output is printed.
///
/// Returns 0 when the round trip ran to completion. Returns 1 when the text
/// does not parse to a sparse tensor encoding, or when the level-type buffer
/// cannot be allocated. Continuing in either case would hand a null
/// attribute or null buffer to the accessors below.
static int testRoundtripEncoding(MlirContext ctx) {
  fprintf(stderr, "testRoundtripEncoding()\n");
  // clang-format off
  const char *originalAsm =
    "#sparse_tensor.encoding<{ "
    "map = [s0](d0, d1) -> (s0 : dense, d0 : compressed, d1 : compressed), "
    "posWidth = 32, crdWidth = 64, explicitVal = 1 : i64}>";
  // clang-format on
  MlirAttribute originalAttr =
      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString(originalAsm));
  if (mlirAttributeIsNull(originalAttr)) {
    fprintf(stderr, "failed to parse sparse tensor encoding\n");
    return 1;
  }
  // CHECK: isa: 1
  int isEncoding = mlirAttributeIsASparseTensorEncodingAttr(originalAttr);
  fprintf(stderr, "isa: %d\n", isEncoding);
  // The accessors below are only valid on a sparse tensor encoding.
  if (!isEncoding)
    return 1;

  MlirAffineMap dimToLvl =
      mlirSparseTensorEncodingAttrGetDimToLvl(originalAttr);
  // CHECK: (d0, d1)[s0] -> (s0, d0, d1)
  mlirAffineMapDump(dimToLvl);
  MlirAffineMap lvlToDim =
      mlirSparseTensorEncodingAttrGetLvlToDim(originalAttr);

  // CHECK: level_type: 65536
  // CHECK: level_type: 262144
  // CHECK: level_type: 262144
  // The C API reports and accepts the level rank as intptr_t; keep that width.
  intptr_t lvlRank = mlirSparseTensorEncodingGetLvlRank(originalAttr);
  MlirSparseTensorLevelType *lvlTypes =
      malloc(sizeof *lvlTypes * (size_t)lvlRank);
  // malloc(0) may legitimately return NULL, so only a non-empty request that
  // yields NULL counts as an allocation failure.
  if (lvlRank > 0 && !lvlTypes) {
    fprintf(stderr, "failed to allocate level types\n");
    return 1;
  }
  for (intptr_t l = 0; l < lvlRank; ++l) {
    lvlTypes[l] = mlirSparseTensorEncodingAttrGetLvlType(originalAttr, l);
    fprintf(stderr, "level_type: %" PRIu64 "\n", lvlTypes[l]);
  }

  // CHECK: posWidth: 32
  int posWidth = mlirSparseTensorEncodingAttrGetPosWidth(originalAttr);
  fprintf(stderr, "posWidth: %d\n", posWidth);
  // CHECK: crdWidth: 64
  int crdWidth = mlirSparseTensorEncodingAttrGetCrdWidth(originalAttr);
  fprintf(stderr, "crdWidth: %d\n", crdWidth);

  // CHECK: explicitVal: 1 : i64
  MlirAttribute explicitVal =
      mlirSparseTensorEncodingAttrGetExplicitVal(originalAttr);
  fprintf(stderr, "explicitVal: ");
  mlirAttributeDump(explicitVal);
  // The source text sets no implicitVal, so a null attribute is expected.
  // CHECK: implicitVal: <<NULL ATTRIBUTE>>
  MlirAttribute implicitVal =
      mlirSparseTensorEncodingAttrGetImplicitVal(originalAttr);
  fprintf(stderr, "implicitVal: ");
  mlirAttributeDump(implicitVal);

  MlirAttribute newAttr = mlirSparseTensorEncodingAttrGet(
      ctx, lvlRank, lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
      explicitVal, implicitVal);
  mlirAttributeDump(newAttr); // For debugging filecheck output.
  // CHECK: equal: 1
  fprintf(stderr, "equal: %d\n", mlirAttributeEqual(originalAttr, newAttr));
  free(lvlTypes);
  return 0;
}
80bcfa7baeSStella Laurenzo
main(void)815d91f79fSTom Eccles int main(void) {
82bcfa7baeSStella Laurenzo MlirContext ctx = mlirContextCreate();
83bcfa7baeSStella Laurenzo mlirDialectHandleRegisterDialect(mlirGetDialectHandle__sparse_tensor__(),
84bcfa7baeSStella Laurenzo ctx);
85bcfa7baeSStella Laurenzo if (testRoundtripEncoding(ctx))
86bcfa7baeSStella Laurenzo return 1;
87bcfa7baeSStella Laurenzo
88bcfa7baeSStella Laurenzo mlirContextDestroy(ctx);
89bcfa7baeSStella Laurenzo return 0;
90bcfa7baeSStella Laurenzo }
91