# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""DSL for constructing affine expressions and maps.

These python wrappers allow construction of affine expressions in a more
pythonic fashion that is later instantiated as an IR AffineExpr. Separating the
AST from construction of the map allows for manipulations of symbols and dims
beyond the scope of one expression.

Affine expression construction:
  >>> with _ir.Context():
  ...   s = AffineBuildState()
  ...   (S.K + S.M).build(s)
  ...   (S.K * S.M).build(s)
  ...   (S.K // S.M).build(s)
  ...   (S.K / S.M).build(s)
  ...   (S.K % 4).build(s)
  ...   (D.i + D.j * 4).build(s)
  ...   s
  AffineExpr(s0 + s1)
  AffineExpr(s0 * s1)
  AffineExpr(s0 floordiv s1)
  AffineExpr(s0 ceildiv s1)
  AffineExpr(s0 mod 4)
  AffineExpr(d0 + d1 * 4)
  AffineBuildState<
   symbols={'K': 0, 'M': 1}
   dims={'i': 0, 'j': 1}>

In the DSL, dimensions and symbols are name-uniqued instances of DimDef and
SymbolDef. There are shortcut "expando" instances that will create a
corresponding DimDef/SymbolDef upon accessing an attribute:

Referencing a named dimension:

  >>> D.i
  Dim(i)
  >>> D.a is D.b
  False
  >>> D.a is D.a
  True

Referencing a named symbol:

  >>> S.foobar
  Symbol(foobar)
  >>> S.a is S.b
  False
  >>> S.a is S.a
  True
"""

from typing import Callable, Dict, Optional, Tuple, Union

from ..... import ir as _ir

__all__ = [
    "AffineBuildState",
    "AffineExprDef",
    "D",
    "DimDef",
    "S",
    "SymbolDef",
]


class AffineBuildState:
    """Internal state for the AffineExprDef._create impls.

    Note that a "local" AffineBuildState can be created relative to a "global"
    AffineBuildState. In that case, any affine expressions built will inherit
    symbol and dim bindings from the global state and will update both as new
    ones are discovered. This allows for building expressions across contexts
    which share a common symbol and dim space.
    """

    def __init__(
        self,
        *,
        global_state: Optional["AffineBuildState"] = None,
        allow_new_symbols: bool = True,
        allow_new_dims: bool = True,
    ):
        """Creates a build state, optionally sharing bindings with another.

        Args:
          global_state: When given, its `all_symbols` and `all_dims` dicts are
            aliased (not copied), so new bindings are visible to both states.
          allow_new_symbols: Whether `get_symbol` may bind unknown names.
          allow_new_dims: Whether `get_dim` may bind unknown names.
        """
        if global_state is None:
            self.all_symbols: Dict[str, int] = {}
            self.all_dims: Dict[str, int] = {}
        else:
            # Alias the global dicts so that both states stay in sync.
            self.all_symbols = global_state.all_symbols
            self.all_dims = global_state.all_dims

        # Map of symbols and dims in the current build.
        self.local_symbols: Dict[str, int] = {}
        self.local_dims: Dict[str, int] = {}
        self.allow_new_symbols = allow_new_symbols
        self.allow_new_dims = allow_new_dims

    def get_dim(self, dimname: str) -> int:
        """Gets the dim position given a name.

        An unknown name is bound to the next free position in `all_dims`.
        The name is also recorded in `local_dims`, whether or not it was
        already known.

        Raises:
          ValueError: If the name is unknown and new dims are not allowed.
        """
        pos = self.all_dims.get(dimname)
        if pos is None:
            if not self.allow_new_dims:
                raise ValueError(
                    "New dimensions not allowed in the current affine expression: "
                    f"Requested '{dimname}', Availble: {self.all_dims}"
                )
            pos = len(self.all_dims)
            self.all_dims[dimname] = pos
        # Recorded even for names inherited from an aliased global dict.
        self.local_dims[dimname] = pos
        return pos

    def get_symbol(self, symname: str) -> int:
        """Gets a symbol position given a name.

        An unknown name is bound to the next free position in `all_symbols`.
        The name is also recorded in `local_symbols`, whether or not it was
        already known.

        Raises:
          ValueError: If the name is unknown and new symbols are not allowed.
        """
        pos = self.all_symbols.get(symname)
        if pos is None:
            if not self.allow_new_symbols:
                raise ValueError(
                    "New symbols not allowed in the current affine expression: "
                    f"Requested '{symname}', Availble: {self.all_symbols}"
                )
            pos = len(self.all_symbols)
            self.all_symbols[symname] = pos
        # Recorded even for names inherited from an aliased global dict.
        self.local_symbols[symname] = pos
        return pos

    @property
    def local_dim_count(self) -> int:
        """Number of entries in `local_dims`."""
        return len(self.local_dims)

    @property
    def local_symbol_count(self) -> int:
        """Number of entries in `local_symbols`."""
        return len(self.local_symbols)

    @property
    def dim_count(self) -> int:
        """Number of entries in `all_dims`."""
        return len(self.all_dims)

    @property
    def symbol_count(self) -> int:
        """Number of entries in `all_symbols`."""
        return len(self.all_symbols)

    def __repr__(self):
        return "\n".join(
            (
                "AffineBuildState<",
                f" symbols={self.local_symbols}",
                f" dims={self.local_dims}>",
            )
        )


class AffineExprDef:
    """Base class for an affine expression being defined.

    The operators `+`, `*`, `%`, `//` and `/` combine two definitions into an
    AffineBinaryExprDef; a plain `int` on the right-hand side is wrapped in an
    AffineConstantExpr. For `+` and `*`, an `int` may also appear on the
    left-hand side.
    """

    def build(self, state: Optional[AffineBuildState] = None) -> _ir.AffineExpr:
        """Builds the corresponding _ir.AffineExpr from the definitions.

        Args:
          state: Build state to resolve names against. A fresh
            AffineBuildState is used when omitted.
        """
        if state is None:
            state = AffineBuildState()
        return self._create(state)

    def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
        """Subclass hook that materializes this definition as IR."""
        raise NotImplementedError()

    @staticmethod
    def coerce_from(py_value):
        """Converts an operand to an AffineExprDef.

        An `int` becomes an AffineConstantExpr; an AffineExprDef is returned
        unchanged.

        Raises:
          TypeError: If the value is neither an int nor an AffineExprDef.
        """
        if isinstance(py_value, int):
            return AffineConstantExpr(py_value)
        # Raise explicitly rather than `assert` so that the check is not
        # stripped when Python runs with -O.
        if not isinstance(py_value, AffineExprDef):
            raise TypeError(
                "Expected an int or AffineExprDef but got "
                f"{type(py_value).__name__}: {py_value!r}"
            )
        return py_value

    def visit_affine_exprs(self, callback):
        """Visits all AffineExprDefs including self."""
        callback(self)

    def __add__(lhs, rhs):
        rhs = AffineExprDef.coerce_from(rhs)
        return AffineBinaryExprDef(_ir.AffineAddExpr, lhs, rhs)

    def __radd__(rhs, lhs):
        # Reflected form for `int + expr`; anything else is left to Python's
        # normal operator dispatch.
        if not isinstance(lhs, int):
            return NotImplemented
        return AffineBinaryExprDef(_ir.AffineAddExpr, AffineConstantExpr(lhs), rhs)

    def __mul__(lhs, rhs):
        rhs = AffineExprDef.coerce_from(rhs)
        return AffineBinaryExprDef(_ir.AffineMulExpr, lhs, rhs)

    def __rmul__(rhs, lhs):
        # Reflected form for `int * expr`; anything else is left to Python's
        # normal operator dispatch.
        if not isinstance(lhs, int):
            return NotImplemented
        return AffineBinaryExprDef(_ir.AffineMulExpr, AffineConstantExpr(lhs), rhs)

    def __mod__(lhs, rhs):
        rhs = AffineExprDef.coerce_from(rhs)
        return AffineBinaryExprDef(_ir.AffineModExpr, lhs, rhs)

    def __floordiv__(lhs, rhs):
        rhs = AffineExprDef.coerce_from(rhs)
        return AffineBinaryExprDef(_ir.AffineFloorDivExpr, lhs, rhs)

    def __truediv__(lhs, rhs):
        # TODO: Not really a ceil div - taking liberties for the DSL.
        rhs = AffineExprDef.coerce_from(rhs)
        return AffineBinaryExprDef(_ir.AffineCeilDivExpr, lhs, rhs)


class AffineConstantExpr(AffineExprDef):
    """An affine constant being defined."""

    def __init__(self, value: int):
        """Stores the constant.

        Raises:
          TypeError: If `value` is not an int.
        """
        # Raise explicitly rather than `assert` so that the check is not
        # stripped when Python runs with -O.
        if not isinstance(value, int):
            raise TypeError(
                f"Expected an int constant but got {type(value).__name__}: {value!r}"
            )
        self.value = value

    def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
        return _ir.AffineConstantExpr.get(self.value)

    def __repr__(self):
        return f"Const({self.value})"


class AffineBinaryExprDef(AffineExprDef):
    """An affine binary expression being defined."""

    def __init__(self, ir_ctor, lhs: AffineExprDef, rhs: AffineExprDef):
        """Stores the operands and the IR class whose `get` combines them."""
        self.ir_ctor = ir_ctor
        self.lhs = lhs
        self.rhs = rhs

    def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
        # The left operand is built first, so names it introduces receive the
        # lower positions in `state`.
        built_lhs = self.lhs._create(state)
        built_rhs = self.rhs._create(state)
        return self.ir_ctor.get(built_lhs, built_rhs)

    def visit_affine_exprs(self, callback):
        """Visits all AffineExprDefs including self."""
        super().visit_affine_exprs(callback)
        self.lhs.visit_affine_exprs(callback)
        self.rhs.visit_affine_exprs(callback)

    def __repr__(self):
        return f"{self.ir_ctor.__name__}({repr(self.lhs)}, {repr(self.rhs)})"


class DimDef(AffineExprDef):
    """Represents a named dimension.

    Instances are uniqued by name: constructing a DimDef with a name that was
    used before returns the existing instance.
    """

    # Registry of every DimDef created so far, keyed by name.
    ALL_DIMS: Dict[str, "DimDef"] = {}

    def __new__(cls, dimname: str):
        existing = cls.ALL_DIMS.get(dimname)
        if existing is not None:
            return existing
        new = super().__new__(cls)
        new.dimname = dimname
        cls.ALL_DIMS[dimname] = new
        return new

    def __repr__(self):
        return f"Dim({self.dimname})"

    def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
        pos = state.get_dim(self.dimname)
        return _ir.AffineDimExpr.get(position=pos)

    @classmethod
    def create_expando(cls):
        """Create an expando object that creates unique dims based on attr access."""

        class ExpandoDims:
            def __getattr__(self, n):
                return cls(n)

        return ExpandoDims()


class SymbolDef(AffineExprDef):
    """Represents a named symbol.

    >>> s1 = SymbolDef("s1")
    >>> s1
    Symbol(s1)
    >>> s2 = SymbolDef("s2")
    >>> s1 is s2
    False
    >>> s1 is SymbolDef("s1")
    True
    """

    # Registry of every SymbolDef created so far, keyed by name.
    ALL_SYMBOLS: Dict[str, "SymbolDef"] = {}

    def __new__(cls, symname: str):
        existing = cls.ALL_SYMBOLS.get(symname)
        if existing is not None:
            return existing
        new = super().__new__(cls)
        new.symname = symname
        cls.ALL_SYMBOLS[symname] = new
        return new

    def __repr__(self):
        return f"Symbol({self.symname})"

    def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
        pos = state.get_symbol(self.symname)
        return _ir.AffineSymbolExpr.get(position=pos)

    @classmethod
    def create_expando(cls):
        """Create an expando object that creates unique symbols based on attr access."""

        class ExpandoSymbols:
            def __getattr__(self, n):
                return cls(n)

        return ExpandoSymbols()


# Global accessor for on-demand dims and symbols.
D = DimDef.create_expando()
S = SymbolDef.create_expando()