xref: /llvm-project/mlir/test/CAPI/sparse_tensor.c (revision a10d67f9fb559d0c35a12b2d26974636bbf642c0)
1bcfa7baeSStella Laurenzo //===- sparse_tensor.c - Test of sparse_tensor APIs -----------------------===//
2bcfa7baeSStella Laurenzo //
3bcfa7baeSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4bcfa7baeSStella Laurenzo // Exceptions.
5bcfa7baeSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
6bcfa7baeSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7bcfa7baeSStella Laurenzo //
8bcfa7baeSStella Laurenzo //===----------------------------------------------------------------------===//
9bcfa7baeSStella Laurenzo 
10bcfa7baeSStella Laurenzo // RUN: mlir-capi-sparse-tensor-test 2>&1 | FileCheck %s
11bcfa7baeSStella Laurenzo 
12bcfa7baeSStella Laurenzo #include "mlir-c/Dialect/SparseTensor.h"
13bcfa7baeSStella Laurenzo #include "mlir-c/IR.h"
145e83a5b4SStella Laurenzo #include "mlir-c/RegisterEverything.h"
15bcfa7baeSStella Laurenzo 
16bcfa7baeSStella Laurenzo #include <assert.h>
1787ff65b0SJie Fu #include <inttypes.h>
18bcfa7baeSStella Laurenzo #include <math.h>
19bcfa7baeSStella Laurenzo #include <stdio.h>
20bcfa7baeSStella Laurenzo #include <stdlib.h>
21bcfa7baeSStella Laurenzo #include <string.h>
22bcfa7baeSStella Laurenzo 
23bcfa7baeSStella Laurenzo // CHECK-LABEL: testRoundtripEncoding()
testRoundtripEncoding(MlirContext ctx)24bcfa7baeSStella Laurenzo static int testRoundtripEncoding(MlirContext ctx) {
25bcfa7baeSStella Laurenzo   fprintf(stderr, "testRoundtripEncoding()\n");
26bcfa7baeSStella Laurenzo   // clang-format off
27bcfa7baeSStella Laurenzo   const char *originalAsm =
28bcfa7baeSStella Laurenzo     "#sparse_tensor.encoding<{ "
29256ac461SYinying Li     "map = [s0](d0, d1) -> (s0 : dense, d0 : compressed, d1 : compressed), "
30*a10d67f9SYinying Li     "posWidth = 32, crdWidth = 64, explicitVal = 1 : i64}>";
31bcfa7baeSStella Laurenzo   // clang-format on
32bcfa7baeSStella Laurenzo   MlirAttribute originalAttr =
33bcfa7baeSStella Laurenzo       mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString(originalAsm));
34bcfa7baeSStella Laurenzo   // CHECK: isa: 1
35bcfa7baeSStella Laurenzo   fprintf(stderr, "isa: %d\n",
36bcfa7baeSStella Laurenzo           mlirAttributeIsASparseTensorEncodingAttr(originalAttr));
3776647fceSwren romano   MlirAffineMap dimToLvl =
3876647fceSwren romano       mlirSparseTensorEncodingAttrGetDimToLvl(originalAttr);
39c48e9087SAart Bik   // CHECK: (d0, d1)[s0] -> (s0, d0, d1)
4076647fceSwren romano   mlirAffineMapDump(dimToLvl);
41e5924d64SYinying Li   // CHECK: level_type: 65536
4256d58295SPeiming Liu   // CHECK: level_type: 262144
4356d58295SPeiming Liu   // CHECK: level_type: 262144
44d4088e7dSYinying Li   MlirAffineMap lvlToDim =
45d4088e7dSYinying Li       mlirSparseTensorEncodingAttrGetLvlToDim(originalAttr);
4684cd51bbSwren romano   int lvlRank = mlirSparseTensorEncodingGetLvlRank(originalAttr);
47cd481fa8SYinying Li   MlirSparseTensorLevelType *lvlTypes =
48cd481fa8SYinying Li       malloc(sizeof(MlirSparseTensorLevelType) * lvlRank);
4984cd51bbSwren romano   for (int l = 0; l < lvlRank; ++l) {
50a0615d02Swren romano     lvlTypes[l] = mlirSparseTensorEncodingAttrGetLvlType(originalAttr, l);
5187ff65b0SJie Fu     fprintf(stderr, "level_type: %" PRIu64 "\n", lvlTypes[l]);
52bcfa7baeSStella Laurenzo   }
5384cd51bbSwren romano   // CHECK: posWidth: 32
5484cd51bbSwren romano   int posWidth = mlirSparseTensorEncodingAttrGetPosWidth(originalAttr);
5584cd51bbSwren romano   fprintf(stderr, "posWidth: %d\n", posWidth);
5684cd51bbSwren romano   // CHECK: crdWidth: 64
5784cd51bbSwren romano   int crdWidth = mlirSparseTensorEncodingAttrGetCrdWidth(originalAttr);
5884cd51bbSwren romano   fprintf(stderr, "crdWidth: %d\n", crdWidth);
59*a10d67f9SYinying Li 
60*a10d67f9SYinying Li   // CHECK: explicitVal: 1 : i64
61*a10d67f9SYinying Li   MlirAttribute explicitVal =
62*a10d67f9SYinying Li       mlirSparseTensorEncodingAttrGetExplicitVal(originalAttr);
63*a10d67f9SYinying Li   fprintf(stderr, "explicitVal: ");
64*a10d67f9SYinying Li   mlirAttributeDump(explicitVal);
65*a10d67f9SYinying Li   // CHECK: implicitVal: <<NULL ATTRIBUTE>>
66*a10d67f9SYinying Li   MlirAttribute implicitVal =
67*a10d67f9SYinying Li       mlirSparseTensorEncodingAttrGetImplicitVal(originalAttr);
68*a10d67f9SYinying Li   fprintf(stderr, "implicitVal: ");
69*a10d67f9SYinying Li   mlirAttributeDump(implicitVal);
70*a10d67f9SYinying Li 
71a0615d02Swren romano   MlirAttribute newAttr = mlirSparseTensorEncodingAttrGet(
72*a10d67f9SYinying Li       ctx, lvlRank, lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
73*a10d67f9SYinying Li       explicitVal, implicitVal);
74bcfa7baeSStella Laurenzo   mlirAttributeDump(newAttr); // For debugging filecheck output.
75bcfa7baeSStella Laurenzo   // CHECK: equal: 1
76bcfa7baeSStella Laurenzo   fprintf(stderr, "equal: %d\n", mlirAttributeEqual(originalAttr, newAttr));
77a0615d02Swren romano   free(lvlTypes);
78bcfa7baeSStella Laurenzo   return 0;
79bcfa7baeSStella Laurenzo }
80bcfa7baeSStella Laurenzo 
main(void)815d91f79fSTom Eccles int main(void) {
82bcfa7baeSStella Laurenzo   MlirContext ctx = mlirContextCreate();
83bcfa7baeSStella Laurenzo   mlirDialectHandleRegisterDialect(mlirGetDialectHandle__sparse_tensor__(),
84bcfa7baeSStella Laurenzo                                    ctx);
85bcfa7baeSStella Laurenzo   if (testRoundtripEncoding(ctx))
86bcfa7baeSStella Laurenzo     return 1;
87bcfa7baeSStella Laurenzo 
88bcfa7baeSStella Laurenzo   mlirContextDestroy(ctx);
89bcfa7baeSStella Laurenzo   return 0;
90bcfa7baeSStella Laurenzo }
91