xref: /llvm-project/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py (revision f9008e6366c2496b1ca1785b891d5578174ad63e)
1#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2#  See https://llvm.org/LICENSE.txt for license information.
3#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4"""DSL for constructing affine expressions and maps.
5
6These python wrappers allow construction of affine expressions in a more
7pythonic fashion that is later instantiated as an IR AffineExpr. Separating the
8AST from construction of the map allows for manipulations of symbols and dims
9beyond the scope of one expression.
10
11Affine expression construction:
12  >>> with _ir.Context():
13  ...   s = AffineBuildState()
14  ...   (S.K + S.M).build(s)
15  ...   (S.K * S.M).build(s)
16  ...   (S.K // S.M).build(s)
17  ...   (S.K / S.M).build(s)
18  ...   (S.K % 4).build(s)
19  ...   (D.i + D.j * 4).build(s)
20  ...   s
21  AffineExpr(s0 + s1)
22  AffineExpr(s0 * s1)
23  AffineExpr(s0 floordiv s1)
24  AffineExpr(s0 ceildiv s1)
25  AffineExpr(s0 mod 4)
26  AffineExpr(d0 + d1 * 4)
27  AffineBuildState<
28    symbols={'K': 0, 'M': 1}
29    dims={'i': 0, 'j': 1}>
30
31In the DSL, dimensions and symbols are name-uniqued instances of DimDef and
32SymbolDef. There are shortcut "expando" instances that will create a
33corresponding DimDef/SymbolDef upon accessing an attribute:
34
35Referencing a named dimension:
36
37  >>> D.i
38  Dim(i)
39  >>> D.a is D.b
40  False
41  >>> D.a is D.a
42  True
43
44Referencing a named symbol:
45
46  >>> S.foobar
47  Symbol(foobar)
48  >>> S.a is S.b
49  False
50  >>> S.a is S.a
51  True
52"""
53
54from typing import Callable, Dict, Optional, Tuple, Union
55
56from ..... import ir as _ir
57
# Names exported by a star-import of this module. The concrete expression node
# classes AffineConstantExpr and AffineBinaryExprDef are not listed here.
__all__ = [
    "AffineBuildState",
    "AffineExprDef",
    "D",
    "DimDef",
    "S",
    "SymbolDef",
]
66
67
class AffineBuildState:
    """Internal state for the AffineExprDef._create impls.

    The state maps dim and symbol names to integer positions. Positions are
    handed out in first-use order (the next position is the current size of
    the corresponding "all" map).

    Note that a "local" AffineBuildState can be created relative to a "global"
    AffineBuildState. In that case, any affine expressions built will inherit
    symbol and dim bindings from the global state and will update both as new
    ones are discovered. This allows for building expressions across contexts
    which share a common symbol and dim space.
    """

    def __init__(
        self,
        *,
        global_state: Optional["AffineBuildState"] = None,
        allow_new_symbols: bool = True,
        allow_new_dims: bool = True,
    ):
        """Creates a build state.

        Args:
          global_state: Optional state whose `all_symbols` / `all_dims` maps
            are shared (aliased, not copied) with this one.
          allow_new_symbols: Whether `get_symbol` may bind unseen names.
          allow_new_dims: Whether `get_dim` may bind unseen names.
        """
        if global_state is None:
            self.all_symbols = {}  # type: Dict[str, int]
            self.all_dims = {}  # type: Dict[str, int]
        else:
            # Alias the global dicts so new bindings are visible to both.
            self.all_symbols = global_state.all_symbols
            self.all_dims = global_state.all_dims

        # Map of symbols and dims referenced in the current build.
        self.local_symbols = {}  # type: Dict[str, int]
        self.local_dims = {}  # type: Dict[str, int]
        self.allow_new_symbols = allow_new_symbols
        self.allow_new_dims = allow_new_dims

    def get_dim(self, dimname: str) -> int:
        """Gets the dim position given a name.

        An unseen name is bound to the next free position unless
        `allow_new_dims` is False. The name is always recorded in
        `local_dims` on success.

        Raises:
          ValueError: If the name is unknown and new dims are not allowed.
        """
        pos = self.all_dims.get(dimname)
        if pos is None:
            if not self.allow_new_dims:
                raise ValueError(
                    f"New dimensions not allowed in the current affine expression: "
                    f"Requested '{dimname}', Available: {self.all_dims}"
                )
            pos = len(self.all_dims)
            self.all_dims[dimname] = pos
        self.local_dims[dimname] = pos
        return pos

    def get_symbol(self, symname: str) -> int:
        """Gets a symbol position given a name.

        An unseen name is bound to the next free position unless
        `allow_new_symbols` is False. The name is always recorded in
        `local_symbols` on success.

        Raises:
          ValueError: If the name is unknown and new symbols are not allowed.
        """
        pos = self.all_symbols.get(symname)
        if pos is None:
            if not self.allow_new_symbols:
                raise ValueError(
                    f"New symbols not allowed in the current affine expression: "
                    f"Requested '{symname}', Available: {self.all_symbols}"
                )
            pos = len(self.all_symbols)
            self.all_symbols[symname] = pos
        self.local_symbols[symname] = pos
        return pos

    @property
    def local_dim_count(self) -> int:
        """Number of distinct dims referenced through this state."""
        return len(self.local_dims)

    @property
    def local_symbol_count(self) -> int:
        """Number of distinct symbols referenced through this state."""
        return len(self.local_symbols)

    @property
    def dim_count(self) -> int:
        """Number of dims bound in the (possibly shared) dim map."""
        return len(self.all_dims)

    @property
    def symbol_count(self) -> int:
        """Number of symbols bound in the (possibly shared) symbol map."""
        return len(self.all_symbols)

    def __repr__(self):
        # Only the local bindings are printed.
        lines = [
            "AffineBuildState<",
            f"  symbols={self.local_symbols}",
            f"  dims={self.local_dims}>",
        ]
        return "\n".join(lines)
148
149
class AffineExprDef:
    """Base class for an affine expression being defined.

    Subclasses implement `_create`. The arithmetic operators build
    AffineBinaryExprDef nodes and accept a plain `int` on either side of `+`
    and `*`, and on the right-hand side of `%`, `//` and `/`.
    """

    def build(self, state: Optional[AffineBuildState] = None) -> _ir.AffineExpr:
        """Builds the corresponding _ir.AffineExpr from the definitions.

        Args:
          state: Build state to resolve dims/symbols against. A fresh
            AffineBuildState is created when omitted.
        """
        state = AffineBuildState() if state is None else state
        expr = self._create(state)
        return expr

    def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
        """Creates the IR expression; must be overridden by subclasses."""
        raise NotImplementedError()

    @staticmethod
    def coerce_from(py_value):
        """Returns `py_value` as an AffineExprDef.

        An `int` is wrapped in an AffineConstantExpr; anything else must
        already be an AffineExprDef and is returned unchanged.
        """
        if isinstance(py_value, int):
            return AffineConstantExpr(py_value)
        assert isinstance(py_value, AffineExprDef)
        return py_value

    def visit_affine_exprs(self, callback):
        """Visits all AffineExprDefs including self."""
        callback(self)

    def __add__(lhs, rhs):
        rhs = AffineExprDef.coerce_from(rhs)
        return AffineBinaryExprDef(_ir.AffineAddExpr, lhs, rhs)

    def __radd__(rhs, lhs):
        # Reached for `int + expr`: the bound instance is the right operand.
        # Operand order is kept as written.
        if not isinstance(lhs, int):
            return NotImplemented
        lhs = AffineExprDef.coerce_from(lhs)
        return AffineBinaryExprDef(_ir.AffineAddExpr, lhs, rhs)

    def __mul__(lhs, rhs):
        rhs = AffineExprDef.coerce_from(rhs)
        return AffineBinaryExprDef(_ir.AffineMulExpr, lhs, rhs)

    def __rmul__(rhs, lhs):
        # Reached for `int * expr`: the bound instance is the right operand.
        # Operand order is kept as written.
        if not isinstance(lhs, int):
            return NotImplemented
        lhs = AffineExprDef.coerce_from(lhs)
        return AffineBinaryExprDef(_ir.AffineMulExpr, lhs, rhs)

    def __mod__(lhs, rhs):
        rhs = AffineExprDef.coerce_from(rhs)
        return AffineBinaryExprDef(_ir.AffineModExpr, lhs, rhs)

    def __floordiv__(lhs, rhs):
        rhs = AffineExprDef.coerce_from(rhs)
        return AffineBinaryExprDef(_ir.AffineFloorDivExpr, lhs, rhs)

    def __truediv__(lhs, rhs):
        # TODO: Not really a ceil div - taking liberties for the DSL.
        rhs = AffineExprDef.coerce_from(rhs)
        return AffineBinaryExprDef(_ir.AffineCeilDivExpr, lhs, rhs)
193
194
class AffineConstantExpr(AffineExprDef):
    """Affine expression definition that wraps a fixed integer value."""

    def __init__(self, value: int):
        # Reject anything that is not an int up front.
        assert isinstance(value, int)
        self.value = value

    def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
        """Creates the IR constant; `state` is not consulted."""
        ir_ctor = _ir.AffineConstantExpr
        return ir_ctor.get(self.value)

    def __repr__(self):
        return "Const({})".format(self.value)
207
208
class AffineBinaryExprDef(AffineExprDef):
    """Definition of a two-operand affine expression.

    `ir_ctor` is the IR expression class whose `get` combines the two built
    operands.
    """

    def __init__(self, ir_ctor, lhs: AffineExprDef, rhs: AffineExprDef):
        self.ir_ctor, self.lhs, self.rhs = ir_ctor, lhs, rhs

    def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
        """Builds both operands against `state` and combines them."""
        combine = self.ir_ctor.get
        # Left is built before right: build order drives the positions that
        # the state assigns to newly seen dims and symbols.
        left_ir = self.lhs._create(state)
        right_ir = self.rhs._create(state)
        return combine(left_ir, right_ir)

    def visit_affine_exprs(self, callback):
        """Invokes `callback` on this node, then on the lhs and rhs subtrees."""
        super().visit_affine_exprs(callback)
        for operand in (self.lhs, self.rhs):
            operand.visit_affine_exprs(callback)

    def __repr__(self):
        return f"{self.ir_ctor.__name__}({self.lhs!r}, {self.rhs!r})"
228
229
class DimDef(AffineExprDef):
    """Represents a named dimension.

    Instances are uniqued by name: constructing a DimDef with a name that was
    used before returns the existing instance.

    >>> d1 = DimDef("d1")
    >>> d1
    Dim(d1)
    >>> d1 is DimDef("d2")
    False
    >>> d1 is DimDef("d1")
    True
    """

    # Registry of every DimDef created so far, keyed by name.
    ALL_DIMS = dict()  # type: Dict[str, "DimDef"]

    def __new__(cls, dimname: str):
        existing = cls.ALL_DIMS.get(dimname)
        if existing is not None:
            return existing
        new = super().__new__(cls)
        new.dimname = dimname
        cls.ALL_DIMS[dimname] = new
        return new

    def __repr__(self):
        return f"Dim({self.dimname})"

    def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
        """Resolves this dim's position in `state` and builds the IR dim expr."""
        pos = state.get_dim(self.dimname)
        return _ir.AffineDimExpr.get(position=pos)

    @classmethod
    def create_expando(cls):
        """Create an expando object that creates unique dims based on attr access."""

        class ExpandoDims:
            def __getattr__(self, n):
                # __getattr__ runs for every attribute that is not found
                # normally, which includes protocol probes such as
                # hasattr(obj, "__wrapped__"). Reject dunder names so those
                # probes do not register bogus dims in ALL_DIMS.
                if n.startswith("__") and n.endswith("__"):
                    raise AttributeError(n)
                return cls(n)

        return ExpandoDims()
260
261
class SymbolDef(AffineExprDef):
    """Represents a named symbol.

    Instances are uniqued by name: constructing a SymbolDef with a name that
    was used before returns the existing instance.

    >>> s1 = SymbolDef("s1")
    >>> s1
    Symbol(s1)
    >>> s2 = SymbolDef("s2")
    >>> s1 is s2
    False
    >>> s1 is SymbolDef("s1")
    True
    """

    # Registry of every SymbolDef created so far, keyed by name.
    ALL_SYMBOLS = dict()  # type: Dict[str, "SymbolDef"]

    def __new__(cls, symname: str):
        existing = cls.ALL_SYMBOLS.get(symname)
        if existing is not None:
            return existing
        new = super().__new__(cls)
        new.symname = symname
        cls.ALL_SYMBOLS[symname] = new
        return new

    def __repr__(self):
        return f"Symbol({self.symname})"

    def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
        """Resolves this symbol's position in `state` and builds the IR symbol expr."""
        pos = state.get_symbol(self.symname)
        return _ir.AffineSymbolExpr.get(position=pos)

    @classmethod
    def create_expando(cls):
        """Create an expando object that creates unique symbols based on attr access."""

        class ExpandoSymbols:
            def __getattr__(self, n):
                # __getattr__ runs for every attribute that is not found
                # normally, which includes protocol probes such as
                # hasattr(obj, "__wrapped__"). Reject dunder names so those
                # probes do not register bogus symbols in ALL_SYMBOLS.
                if n.startswith("__") and n.endswith("__"):
                    raise AttributeError(n)
                return cls(n)

        return ExpandoSymbols()
302
303
# Global accessors for on-demand dims and symbols: attribute access such as
# `D.i` or `S.K` returns the name-uniqued DimDef / SymbolDef for that name,
# creating it on first use.
D = DimDef.create_expando()
S = SymbolDef.create_expando()
307