# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import sparse_tensor as st, tensor
import textwrap


def run(f):
    """Print a ``TEST <name>`` banner, invoke ``f`` immediately, and return it.

    Used as a decorator so every test below executes when the script is run,
    with its output preceded by a label line that the per-test labels anchor on.
    """
    print("\nTEST:", f.__name__)
    f()
    return f


# CHECK-LABEL: TEST: testEncodingAttr1D
@run
def testEncodingAttr1D():
    """Round-trip a 1-D compressed encoding that carries an explicit value.

    Parses the attribute, reads back each accessor, then rebuilds it through
    ``EncodingAttr.get`` with zero bit widths and the same explicit value.
    """
    with Context():
        parsed = Attribute.parse(
            textwrap.dedent(
                """\
                #sparse_tensor.encoding<{
                  map = (d0) -> (d0 : compressed),
                  posWidth = 16,
                  crdWidth = 32,
                  explicitVal = 1.0 : f64
                }>\
                """
            )
        )
        # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 16, crdWidth = 32, explicitVal = 1.000000e+00 : f64 }>
        print(parsed)

        casted = st.EncodingAttr(parsed)
        # CHECK: equal: True
        print(f"equal: {casted == parsed}")

        # CHECK: lvl_types: [262144]
        print(f"lvl_types: {casted.lvl_types}")
        # CHECK: dim_to_lvl: (d0) -> (d0)
        print(f"dim_to_lvl: {casted.dim_to_lvl}")
        # CHECK: lvl_to_dim: (d0) -> (d0)
        print(f"lvl_to_dim: {casted.lvl_to_dim}")
        # CHECK: pos_width: 16
        print(f"pos_width: {casted.pos_width}")
        # CHECK: crd_width: 32
        print(f"crd_width: {casted.crd_width}")
        # CHECK: explicit_val: 1.000000e+00
        print(f"explicit_val: {casted.explicit_val}")
        # CHECK: implicit_val: None
        print(f"implicit_val: {casted.implicit_val}")

        # Rebuild with no maps and 0 for both widths. The expected printed form
        # below omits posWidth/crdWidth, so the result must differ from `casted`.
        new_explicit_val = FloatAttr.get_f64(1.0)
        created = st.EncodingAttr.get(
            casted.lvl_types, None, None, 0, 0, new_explicit_val
        )
        # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), explicitVal = 1.000000e+00 : f64 }>
        print(created)
        # CHECK: created_equal: False
        print(f"created_equal: {created == casted}")

        # Verify that the factory creates an instance of the proper type.
        # CHECK: is_proper_instance: True
        print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}")
        # CHECK: created_pos_width: 0
        print(f"created_pos_width: {created.pos_width}")


# CHECK-LABEL: TEST: testEncodingAttrStructure
@run
def testEncodingAttrStructure():
    """Exercise a 2:4 structured-sparsity encoding.

    Checks the structured accessors (level formats, n, m) and that building the
    same encoding by hand from level types and affine maps matches the one
    produced from the parsed attribute's own fields.
    """
    with Context():
        parsed = Attribute.parse(
            textwrap.dedent(
                """\
                #sparse_tensor.encoding<{
                  map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense,
                  d1 mod 4 : structured[2, 4]),
                  posWidth = 16,
                  crdWidth = 32,
                }>\
                """
            )
        )
        # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : structured[2, 4]), posWidth = 16, crdWidth = 32 }>
        print(parsed)

        casted = st.EncodingAttr(parsed)
        # CHECK: equal: True
        print(f"equal: {casted == parsed}")

        # CHECK: lvl_types: [65536, 65536, 4406638542848]
        print(f"lvl_types: {casted.lvl_types}")
        # CHECK: lvl_formats_enum: [{{65536|LevelFormat.dense}}, {{65536|LevelFormat.dense}}, {{2097152|LevelFormat.n_out_of_m}}]
        print(f"lvl_formats_enum: {casted.lvl_formats_enum}")
        # CHECK: structured_n: 2
        print(f"structured_n: {casted.structured_n}")
        # CHECK: structured_m: 4
        print(f"structured_m: {casted.structured_m}")
        # CHECK: dim_to_lvl: (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
        print(f"dim_to_lvl: {casted.dim_to_lvl}")
        # CHECK: lvl_to_dim: (d0, d1, d2) -> (d0, d1 * 4 + d2)
        print(f"lvl_to_dim: {casted.lvl_to_dim}")
        # CHECK: pos_width: 16
        print(f"pos_width: {casted.pos_width}")
        # CHECK: crd_width: 32
        print(f"crd_width: {casted.crd_width}")

        created = st.EncodingAttr.get(
            casted.lvl_types, casted.dim_to_lvl, casted.lvl_to_dim, 0, 0
        )
        # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : structured[2, 4]) }>
        print(created)
        # CHECK: created_equal: False
        print(f"created_equal: {created == casted}")

        # Construct the same encoding from scratch: level types via
        # build_level_type, and the dim<->lvl maps as explicit affine exprs.
        built_2_4 = st.EncodingAttr.build_level_type(
            st.LevelFormat.n_out_of_m, [], 2, 4
        )
        built_dense = st.EncodingAttr.build_level_type(st.LevelFormat.dense)
        # (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
        dim_to_lvl = AffineMap.get(
            2,
            0,
            [
                AffineExpr.get_dim(0),
                AffineExpr.get_floor_div(AffineExpr.get_dim(1), 4),
                AffineExpr.get_mod(AffineExpr.get_dim(1), 4),
            ],
        )
        # (d0, d1, d2) -> (d0, d1 * 4 + d2)
        lvl_to_dim = AffineMap.get(
            3,
            0,
            [
                AffineExpr.get_dim(0),
                AffineExpr.get_add(
                    AffineExpr.get_mul(AffineExpr.get_dim(1), 4),
                    AffineExpr.get_dim(2),
                ),
            ],
        )
        built = st.EncodingAttr.get(
            [built_dense, built_dense, built_2_4],
            dim_to_lvl,
            lvl_to_dim,
            0,
            0,
        )
        # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : structured[2, 4]) }>
        print(built)
        # CHECK: built_equal: True
        print(f"built_equal: {built == created}")

        # Verify that the factory creates an instance of the proper type.
        # CHECK: is_proper_instance: True
        print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}")
        # CHECK: created_pos_width: 0
        print(f"created_pos_width: {created.pos_width}")


# CHECK-LABEL: TEST: testEncodingAttr2D
@run
def testEncodingAttr2D():
    """Round-trip a 2-D encoding whose dim-to-level map is a permutation.

    Rebuilding through ``EncodingAttr.get`` with the parsed maps and the same
    widths (8 and 32) is expected to yield an attribute equal to the original.
    """
    with Context():
        parsed = Attribute.parse(
            textwrap.dedent(
                """\
                #sparse_tensor.encoding<{
                  map = (d0, d1) -> (d1 : dense, d0 : compressed),
                  posWidth = 8,
                  crdWidth = 32,
                }>\
                """
            )
        )
        # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
        print(parsed)

        casted = st.EncodingAttr(parsed)
        # CHECK: equal: True
        print(f"equal: {casted == parsed}")

        # CHECK: lvl_types: [65536, 262144]
        print(f"lvl_types: {casted.lvl_types}")
        # CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
        print(f"dim_to_lvl: {casted.dim_to_lvl}")
        # CHECK: lvl_to_dim: (d0, d1) -> (d1, d0)
        print(f"lvl_to_dim: {casted.lvl_to_dim}")
        # CHECK: pos_width: 8
        print(f"pos_width: {casted.pos_width}")
        # CHECK: crd_width: 32
        print(f"crd_width: {casted.crd_width}")

        created = st.EncodingAttr.get(
            casted.lvl_types,
            casted.dim_to_lvl,
            casted.lvl_to_dim,
            8,
            32,
        )
        # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
        print(created)
        # CHECK: created_equal: True
        print(f"created_equal: {created == casted}")


# CHECK-LABEL: TEST: testEncodingAttrOnTensorType
@run
def testEncodingAttrOnTensorType():
    """Attach a sparse encoding to a ranked tensor type and read it back."""
    with Context(), Location.unknown():
        encoding = st.EncodingAttr(
            Attribute.parse(
                textwrap.dedent(
                    """\
                    #sparse_tensor.encoding<{
                      map = (d0) -> (d0 : compressed),
                      posWidth = 64,
                      crdWidth = 32,
                    }>\
                    """
                )
            )
        )
        tt = RankedTensorType.get((1024,), F32Type.get(), encoding=encoding)
        # CHECK: tensor<1024xf32, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 64, crdWidth = 32 }>>
        print(tt)
        # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 64, crdWidth = 32 }>
        print(tt.encoding)
        assert tt.encoding == encoding


# CHECK-LABEL: TEST: testEncodingEmptyTensor
@run
def testEncodingEmptyTensor():
    """Emit a ``tensor.empty`` op whose result type carries a sparse encoding.

    The level list holds a raw ``LevelFormat`` member rather than a value from
    ``build_level_type``, so this also covers passing the enum directly.
    """
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            levels = [st.LevelFormat.compressed]
            ordering = AffineMap.get_permutation([0])
            encoding = st.EncodingAttr.get(levels, ordering, ordering, 32, 32)
            tensor.empty((1024,), F32Type.get(), encoding=encoding)

        # CHECK: #sparse = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 32, crdWidth = 32 }>
        # CHECK: module {
        # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<1024xf32, #sparse>
        # CHECK: }
        print(module)