# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Represents configured ops as emitted for code generation.

Classes in this module generally are directly serializable to YAML for use
by the code generator.

TODO: These should just be dumb containers or serialization code but they
currently encode too many details of how the language is interpreted. Move this
to helpers on the comprehension objects themselves.
"""

# `Sequence` and `Tuple` appear in annotations that are evaluated at definition
# time, so import them explicitly rather than relying on the star imports below.
# Note: `TypeVar` is intentionally NOT imported from typing; the name used in
# this module refers to the OpDSL type variable brought in by the star imports.
from typing import Dict, List, Optional, Sequence, Set, Tuple

from ..... import ir as _ir
from .comprehension import *
from .yaml_helper import *

__all__ = ["LinalgStructuredOpConfig", "LinalgOpConfig", "OperandDefConfig"]


def _serialize_affine_map(affine_map: _ir.AffineMap) -> str:
    """Returns the textual form of `affine_map` used in the YAML output.

    The map is printed through an `AffineMapAttr` created in the map's own
    context.
    """
    with affine_map.context:
        # Affine map printing/parsing is via an AffineMap attr.
        attr = _ir.AffineMapAttr.get(affine_map)
        return str(attr)


class TensorUseConfig:
    """Wrapper around a TensorUse with additional context-bound state."""

    def __init__(self, tensor_use: TensorUse, indexing_map: _ir.AffineMap):
        # The tensor use this config describes.
        self.tensor_use = tensor_use
        # Indexing map built for the use; may later be replaced by a normalized
        # map with the final dim/symbol counts.
        self.indexing_map = indexing_map

    def __repr__(self):
        return f"Use({self.tensor_use}, indexing_map={self.indexing_map})"


class OperandDefConfig(YAMLObject):
    """Wrapper containing an operand definition with additional state."""

    yaml_tag = "!LinalgOperandDefConfig"

    def __init__(
        self,
        operand_def: OperandDef,
        shape_map: Optional[_ir.AffineMap] = None,
        index_attr_map: Optional[_ir.AffineMap] = None,
    ):
        self.operand_def = operand_def
        self.shape_map = shape_map  # type: Optional[_ir.AffineMap]
        self.index_attr_map = index_attr_map  # type: Optional[_ir.AffineMap]
        # Filled in later by LinalgStructuredOpConfig once uses are analyzed.
        self.indexing_map = None  # type: Optional[_ir.AffineMap]

    @property
    def name(self) -> str:
        return self.operand_def.name

    @property
    def kind(self) -> OperandKind:
        return self.operand_def.kind

    @property
    def type_var(self) -> TypeVar:
        return self.operand_def.type_var

    def to_yaml_custom_dict(self):
        """Returns the dict serialized to YAML for this operand.

        `name` and the lower-cased `kind` are always present; the remaining
        entries are only emitted when the corresponding value is truthy.
        """
        self_dict = dict(name=self.name, kind=self.operand_def.kind.name.lower())
        if self.type_var:
            self_dict["type_var"] = self.type_var.name
        if self.shape_map:
            self_dict["shape_map"] = _serialize_affine_map(self.shape_map)
        if self.index_attr_map:
            self_dict["index_attr_map"] = _serialize_affine_map(self.index_attr_map)
        if self.operand_def.default_indices:
            self_dict["default_indices"] = self.operand_def.default_indices
        if self.operand_def.default_fn:
            self_dict["default_fn"] = self.operand_def.default_fn
        return self_dict

    def __repr__(self):
        return (
            f"OperandDefConfig({self.operand_def}, "
            f"shape_map={self.shape_map}, "
            f"index_attr_map={self.index_attr_map}, "
            f"indexing_map={self.indexing_map})"
        )


class LinalgIndexingMapsConfig(YAMLObject):
    """Abstracts the style of indexing maps that the op exports.

    Presently only static (tied to the op name) indexing maps are supported. In
    the future, it is expected that we will have additional variants:
      - Dynamic based on attributes
      - Dynamic based on operands
    Each is expected to require a different variant of specification.
    """

    yaml_tag = "!LinalgIndexingMapsConfig"

    def __init__(self, static_indexing_maps: Optional[Sequence[_ir.AffineMap]] = None):
        self.static_indexing_maps = static_indexing_maps

    def to_yaml_custom_dict(self):
        """Returns the YAML dict for the configured indexing-map variant.

        Raises:
          ValueError: if no variant (currently: static maps) is configured.
        """
        if self.static_indexing_maps is not None:
            return dict(
                static_indexing_maps=[
                    _serialize_affine_map(m) for m in self.static_indexing_maps
                ]
            )
        raise ValueError(
            "LinalgIndexingMapsConfig must have one type of indexing map (got none)"
        )


class LinalgStructuredOpConfig(YAMLObject):
    """Configuration for metadata sufficient to construct a linalg named op."""

    yaml_tag = "!LinalgStructuredOpConfig"

    def __init__(
        self,
        comprehension: Comprehension,
        domain: Sequence[DimDef],
        registered_operands: Sequence[OperandDef],
        context: Optional[_ir.Context] = None,
    ):
        """Analyzes `comprehension` and builds the operand/indexing-map config.

        Args:
          comprehension: Provides the writes (`definitions`), the written
            expressions (`values`) and the reduction dims of the op body.
          domain: Iteration-domain dimensions in order. If empty, the used
            dimensions sorted by name are taken instead.
          registered_operands: All registered operand definitions. Attributes
            are always configured; every non-attribute operand must end up with
            an indexing map.
          context: MLIR context used to build affine maps; a fresh one is
            created when None.

        Raises:
          ValueError: if the domain does not match the used dimensions, an
            index dim is unused, a tensor is written more than once, a tensor
            is read with different accesses, an operand gets no indexing map,
            writes disagree on reduction dims, or an index refers to a
            dimension outside the iteration domain.
        """
        self.context = context if context is not None else _ir.Context()
        self.affine_state = AffineBuildState()
        self.writes = list()  # type: List[Tuple[TensorUse, TensorExpression]]
        self.operands = dict()  # type: Dict[OperandDef, OperandDefConfig]
        self.uses = dict()  # type: Dict[TensorUse, TensorUseConfig]

        # Compute the ordered set of writes and collect the tensor, capture,
        # dims, and index uses.
        collected_tensor_uses = set()
        collected_scalar_uses = set()
        collected_dim_uses = set()
        collected_indices = set()
        self.writes.extend(zip(comprehension.definitions, comprehension.values))

        for write_use, read_use in self.writes:
            collected_tensor_uses.add(write_use)
            read_use.collect_tensor_uses(collected_tensor_uses)
            read_use.collect_scalar_uses(collected_scalar_uses)
            read_use.collect_dim_uses(collected_dim_uses)
            write_use.collect_dim_uses(collected_dim_uses)
            read_use.collect_indices(collected_indices)

        # Set domain to the sorted list of uses if no domain annotation is given.
        if not domain:
            domain = sorted(collected_dim_uses, key=lambda dim: dim.dimname)

        # Verify the domain dimensions match the used dimensions. The set
        # comparison (on top of the length check) also rejects a domain that
        # repeats one dimension while omitting another used one.
        if len(domain) != len(collected_dim_uses) or set(domain) != collected_dim_uses:
            raise ValueError(
                f"Expected the annotated domain dimensions {domain} to "
                f"match the set of dimension used by the tensor "
                f"comprehension {collected_dim_uses}"
            )

        # Instantiate the dimensions in the given order so dim positions follow
        # the domain annotation.
        with self.context:
            local_state = AffineBuildState(
                global_state=self.affine_state, allow_new_symbols=False
            )
            for dim in domain:
                dim.build(state=local_state)

        # Collect all attribute definitions.
        collected_attr_defs = [
            operand for operand in registered_operands if operand.is_attribute()
        ]

        # Collect all tensors with manual indexing annotation.
        collected_index_defs = list()
        for operand in registered_operands:
            if operand.index_dims:
                if any(dim not in collected_dim_uses for dim in operand.index_dims):
                    raise ValueError(
                        f"Expected all index dims {operand.index_dims} of "
                        f"operand {operand.name} to have uses."
                    )
                collected_index_defs.append(operand)

        # Collect the operand definitions of all tensor/scalar uses, attributes,
        # and shape-only tensors.
        all_operand_defs = [use.operand_def for use in collected_tensor_uses]
        all_operand_defs.extend(use.operand_def for use in collected_scalar_uses)
        all_operand_defs.extend(collected_attr_defs)
        all_operand_defs.extend(collected_index_defs)

        # Add all operands in registration order to ensure the symbols are
        # registered in the order they appear. (add_operand ignores duplicates.)
        for operand_def in sorted(
            all_operand_defs, key=lambda operand_def: operand_def.registered_index
        ):
            self.add_operand(operand_def)

        # Add all shape-only tensor index_dim annotations and all tensor uses.
        for definition in collected_index_defs:
            self.add_indexed_operand(definition)
        for use in collected_tensor_uses:
            self.add_tensor_use(use)

        # Normalize all shape and indexing maps now that full count of dims and
        # symbols are known.
        for cuse in self.uses.values():
            cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map)
        for definition in collected_index_defs:
            self.operands[definition].indexing_map = self._normalize_affine_map(
                self.operands[definition].indexing_map
            )
        for operand_config in self.operands.values():
            if operand_config.shape_map:
                operand_config.shape_map = self._normalize_affine_map(
                    operand_config.shape_map, with_dims=False
                )
            if operand_config.index_attr_map:
                operand_config.index_attr_map = self._normalize_affine_map(
                    operand_config.index_attr_map, with_dims=False
                )

        # Now for each write use, propagate the indexing maps from the use to the
        # tensor, ensuring that there are not conflicts.
        for write_use, _ in self.writes:
            write_tensor_config = self.operands[write_use.operand_def]
            if write_tensor_config.indexing_map:
                raise ValueError(
                    f"Unexpected multi-write to a single tensor: {write_tensor_config}"
                )
            write_tensor_config.indexing_map = self.uses[write_use].indexing_map

        # For each read use, propagate the indexing maps from the use to the
        # tensor, ensuring that there are not conflicts.
        for _, read_expr in self.writes:
            read_uses = set()  # type: Set[TensorUse]
            read_expr.collect_tensor_uses(read_uses)
            for read_use in read_uses:
                read_operand_config = self.operands[read_use.operand_def]
                if (
                    read_operand_config.indexing_map
                    and read_operand_config.indexing_map
                    != self.uses[read_use].indexing_map
                ):
                    raise ValueError(
                        f"Unexpected multi-read of a tensor with different accesses: "
                        f"{read_operand_config} vs {read_use}"
                    )
                read_operand_config.indexing_map = self.uses[read_use].indexing_map

        # Set the indexing map of all scalar uses to the empty map.
        for operand_config in self.operands.values():
            if operand_config.operand_def.kind == OperandKind.SCALAR:
                operand_config.indexing_map = self._get_scalar_map()

        # Check all registered tensor and scalar operands have an indexing map.
        for operand in registered_operands:
            if operand.is_attribute():
                continue
            if not (operand in self.operands and self.operands[operand].indexing_map):
                raise ValueError(
                    f"Failed to compute an indexing map for operand {operand.name}"
                )

        # Collect reduction dims and ensure all the same.
        all_reduction_dims = set(comprehension.all_reduction_dims)
        if len(all_reduction_dims) != 1:
            raise ValueError(
                f"All writes within a generic must have the same reduction "
                f"dims. Got: {all_reduction_dims}"
            )
        self.reduction_dims = next(iter(all_reduction_dims))

        # Check the index dimension exists and resolve.
        for index in collected_indices:
            if index.dim_def.dimname not in self.affine_state.all_dims:
                raise ValueError(
                    f"The dimension {index.dim_def.dimname} is not part of the "
                    f"iteration domain {self.affine_state.all_dims}"
                )
            index.resolve_dimension_name(self.affine_state)

        # Generate the scalar assignments (used to build a body).
        self.assignments = [
            ScalarAssign(write_use.tensor_name, read_expr.to_scalar_expression())
            for write_use, read_expr in self.writes
        ]

    @property
    def ordered_operands(self) -> Sequence[OperandDefConfig]:
        """Operand configs sorted by their registration index."""
        return sorted(
            self.operands.values(),
            key=lambda operand: operand.operand_def.registered_index,
        )

    @property
    def ordered_dims(self) -> Sequence[Tuple[str, int]]:
        """Gets the ordered list of dim bindings (symbolic name, position).

        TODO: The original parser relies on parse ordering to arrive at the
        iterator types, but that ordering is not defined on the Python side, so
        this may be ambiguous.
        """
        return list(self.affine_state.all_dims.items())

    @property
    def indexing_maps(self) -> Sequence[_ir.AffineMap]:
        """Indexing maps of the ordered operands, skipping operands without one."""
        return [o.indexing_map for o in self.ordered_operands if o.indexing_map]

    @property
    def iterator_types(self) -> Sequence[str]:
        """Per ordered dim: "reduction" if it is a reduction dim, else "parallel"."""

        def get_type(symbolic_name, position):
            for reduction_dim_expr in self.reduction_dims:
                if reduction_dim_expr.dimname == symbolic_name:
                    return "reduction"
            return "parallel"

        return [get_type(*dim) for dim in self.ordered_dims]

    def add_operand(self, operand_def: OperandDef):
        """Registers an OperandDefConfig for `operand_def` (no-op if present).

        Tensors and index attributes get a symbol-only affine map built from
        their `size_exprs` (stored as `shape_map` for tensors and as
        `index_attr_map` for index attributes); other operands get a bare config.
        """
        if operand_def in self.operands:
            return
        if not (operand_def.is_tensor() or operand_def.kind == OperandKind.INDEX_ATTR):
            self.operands[operand_def] = OperandDefConfig(operand_def)
            return
        with self.context:
            local_state = AffineBuildState(
                global_state=self.affine_state, allow_new_dims=False
            )
            exprs = [expr.build(state=local_state) for expr in operand_def.size_exprs]
            # Size expressions may only reference symbols, never dims.
            assert local_state.local_dim_count == 0
            affine_map = _ir.AffineMap.get(
                dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs
            )
            if operand_def.kind == OperandKind.INDEX_ATTR:
                self.operands[operand_def] = OperandDefConfig(
                    operand_def, index_attr_map=affine_map
                )
            else:
                self.operands[operand_def] = OperandDefConfig(
                    operand_def, shape_map=affine_map
                )

    def add_indexed_operand(self, operand_def: OperandDef):
        """Builds the indexing map of an already-registered operand from its
        `index_dims` annotation."""
        with self.context:
            local_state = AffineBuildState(
                global_state=self.affine_state, allow_new_symbols=False
            )
            exprs = [expr.build(state=local_state) for expr in operand_def.index_dims]
            self.operands[operand_def].indexing_map = _ir.AffineMap.get(
                dim_count=local_state.dim_count,
                symbol_count=local_state.symbol_count,
                exprs=exprs,
            )

    def add_tensor_use(self, tensor_use: TensorUse):
        """Records a TensorUseConfig for `tensor_use` (no-op if present), with
        an indexing map built from the use's index expressions."""
        if tensor_use in self.uses:
            return
        with self.context:
            local_state = AffineBuildState(
                global_state=self.affine_state, allow_new_symbols=False
            )
            exprs = [expr.build(state=local_state) for expr in tensor_use.indices]
            indexing_map = _ir.AffineMap.get(
                dim_count=local_state.dim_count,
                symbol_count=local_state.symbol_count,
                exprs=exprs,
            )

            use_config = TensorUseConfig(tensor_use, indexing_map)
            self.uses[tensor_use] = use_config

    def _get_scalar_map(self) -> _ir.AffineMap:
        """Create an empty affine map used to index a scalar."""
        with self.context:
            return _ir.AffineMap.get(
                dim_count=self.affine_state.dim_count,
                symbol_count=self.affine_state.symbol_count,
                exprs=list(),
            )

    def _normalize_affine_map(
        self, affine_map: _ir.AffineMap, with_dims: bool = True
    ) -> _ir.AffineMap:
        """Normalizes an indexing map to have the max known symbols and dims.

        With `with_dims=False` the result has zero dims (used for shape and
        index-attribute maps).
        """
        with self.context:
            return _ir.AffineMap.get(
                dim_count=self.affine_state.dim_count if with_dims else 0,
                symbol_count=self.affine_state.symbol_count,
                exprs=list(affine_map.results),
            )

    def to_yaml_custom_dict(self):
        """Returns the dict serialized to YAML for this structured op."""
        self_dict = dict(args=self.ordered_operands)
        # TODO: Refactor the hierarchy internally when supporting more
        # than static (preserving this serialized form).
        self_dict["indexing_maps"] = LinalgIndexingMapsConfig(
            static_indexing_maps=self.indexing_maps
        )
        self_dict["iterator_types"] = self.iterator_types
        self_dict["assignments"] = self.assignments
        return self_dict

    def __repr__(self):
        lines = [f"LinalgStructuredOpConfig(reduction_dims={self.reduction_dims},"]
        lines.append("operands=[")
        for def_config in self.ordered_operands:
            lines.append(f"  {repr(def_config)}")
        lines.append("], indexing_maps=[")
        for m in self.indexing_maps:
            lines.append(f"  {repr(m)}")
        lines.append("], iterator_types=[")
        for t in self.iterator_types:
            lines.append(f"  {t}")
        lines.append("])")
        return "\n".join(lines)


class LinalgOpConfig(YAMLObject):
    """Container for any supported linalg op type.

    This includes the concrete type by name for ease of parsing by systems
    that ignore tags.
    """

    yaml_tag = "!LinalgOpConfig"

    def __init__(
        self,
        metadata: OpMetadataDef,
        *,
        structured_op: Optional[LinalgStructuredOpConfig] = None,
    ):
        self.metadata = metadata
        self.structured_op = structured_op

    def to_yaml_custom_dict(self):
        """Returns the YAML dict: `metadata` plus `structured_op` when set."""
        self_dict = dict(
            metadata=self.metadata,
        )
        if self.structured_op:
            self_dict["structured_op"] = self.structured_op
        return self_dict

    @staticmethod
    def from_linalg_op_def(
        op_def: LinalgOpDef, context: Optional[_ir.Context] = None
    ) -> Sequence["LinalgOpConfig"]:
        """Expands a LinalgOpDef into corresponding Linalg configured ops."""
        # TODO: Many LinalgOpDef patterns need to expand to multiple generics.
        assert len(op_def.comprehensions) == 1, "Only one comprehension supported"
        return [
            LinalgOpConfig(
                op_def.metadata,
                structured_op=LinalgStructuredOpConfig(
                    op_def.comprehensions[0],
                    op_def.domain,
                    op_def.registered_operands.values(),
                    context,
                ),
            ),
        ]

    def __repr__(self):
        return (
            f"LinalgOpConfig(metadata={self.metadata},\n"
            f"structured_op={self.structured_op})"
        )