# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Callable, Optional, Sequence, Union

from ....extras.meta import region_op
from .... import ir
from ... import transform
from .. import (
    AnyOpType,
    AnyParamType,
    AnyValueType,
    OperationType,
    ParamType,
    NamedSequenceOp,
    YieldOp,
    SequenceOp,
    ApplyPatternsOp,
)
from .. import structured


class Handle(ir.Value):
    """
    Base class for wrappers around different types of transform handle with
    methods to chain further transforms.

    The fields `children` and `parent` are used to capture the relation of
    handles statically in order to enable further analysis. The payload
    operation of a child handle is nested into a region of the payload operation
    of the corresponding parent handle.
    """

    def __init__(
        self,
        v: ir.Value,
        *,
        parent: Optional["Handle"] = None,
        children: Optional[Sequence["Handle"]] = None,
    ):
        """
        Wraps the transform value `v` and records its static relations.

        Args:
          v: the transform value to wrap.
          parent: the handle whose payload op contains this handle's payload
            op, if known.
          children: initial child handles; copied into a new list.
        """
        super().__init__(v)
        self.parent = parent
        # Always own a fresh mutable list: `OpHandle.match_ops` appends to it,
        # which would fail for a tuple and would silently mutate a list owned
        # by the caller.
        self.children = list(children) if children is not None else []


@ir.register_value_caster(AnyOpType.get_static_typeid())
@ir.register_value_caster(OperationType.get_static_typeid())
class OpHandle(Handle):
    """
    Wrapper around a transform operation handle with methods to chain further
    transforms.
    """

    def __init__(
        self,
        v: ir.Value,
        *,
        parent: Optional[Handle] = None,
        children: Optional[Sequence[Handle]] = None,
    ):
        super().__init__(v, parent=parent, children=children)

    def get_result(
        self, indices: Optional[Sequence[int]] = None
    ) -> "ValueHandle":
        """
        Emits a `transform.GetResultOp`.
        Returns a handle to the result of the payload operation at the given
        indices. When `indices` is omitted (or `None`), index 0 is used.
        """
        # A `None` sentinel replaces the former mutable `[0]` default so that
        # no list object is shared between calls.
        if indices is None:
            indices = [0]
        get_result_op = transform.GetResultOp(
            AnyValueType.get(),
            self,
            indices,
        )
        return get_result_op.result

    def match_ops(
        self,
        ops: Union[
            str,
            ir.OpView,
            structured.MatchInterfaceEnum,
            Sequence[Union[str, ir.OpView]],
        ],
    ) -> "OpHandle":
        """
        Emits a `transform.structured.MatchOp`.
        Returns a handle to payload ops that match the given names, types, or
        interface. If only a single type is given, the value wrapped by the
        resulting handle is populated with the respective type.

        The returned handle is recorded as a child of this handle.
        """
        # Handle interface, given either as enum value or as enum member name.
        if isinstance(ops, structured.MatchInterfaceEnum) or (
            isinstance(ops, str) and ops in structured.MatchInterfaceEnum.__members__
        ):
            if isinstance(ops, str):
                ops = structured.MatchInterfaceEnum[ops]
            match_op = structured.MatchOp(
                AnyOpType.get(),
                self,
                interface=ops,
            )

        # Handle op name(s), either given directly as string or given as op.
        else:
            # Note: `str` is itself a `Sequence`, so it must be tested first.
            if isinstance(ops, str):
                op_type = OperationType.get(ops)
                op_names = [ops]
            elif isinstance(ops, Sequence):
                op_type = AnyOpType.get()
                op_names = [
                    op if isinstance(op, str) else op.OPERATION_NAME for op in ops
                ]
            else:
                op_type = OperationType.get(ops.OPERATION_NAME)
                op_names = [ops.OPERATION_NAME]
            match_op = structured.MatchOp.match_op_names(
                op_type,
                self,
                op_names,
            )

        handle = OpHandle(match_op.results_, parent=self)
        self.children.append(handle)
        return handle

    def print(self, name: Optional[str] = None) -> "OpHandle":
        """
        Emits a `transform.PrintOp` to print this handle and an optional message.
        Returns the existing handle to facilitate further chaining.
        """
        transform.PrintOp(target=self, name=name)
        return self


@ir.register_value_caster(AnyParamType.get_static_typeid())
@ir.register_value_caster(ParamType.get_static_typeid())
class ParamHandle(Handle):
    """Wrapper around a transform param handle."""

    def __init__(
        self,
        v: ir.Value,
        *,
        parent: Optional[Handle] = None,
        children: Optional[Sequence[Handle]] = None,
    ):
        super().__init__(v, parent=parent, children=children)


@ir.register_value_caster(AnyValueType.get_static_typeid())
class ValueHandle(Handle):
    """
    Wrapper around a transform value handle with methods to chain further
    transforms.
    """

    def __init__(
        self,
        v: ir.Value,
        *,
        parent: Optional[Handle] = None,
        children: Optional[Sequence[Handle]] = None,
    ):
        super().__init__(v, parent=parent, children=children)

    def get_defining_op(self) -> OpHandle:
        """
        Emits a `transform.GetDefiningOp`.
        Returns a handle to the defining op of the wrapped value.
        """
        get_defining_op = transform.GetDefiningOp(
            AnyOpType.get(),
            self,
        )
        return get_defining_op.result


def constant_param(value: Union[ir.Attribute, int]) -> ParamHandle:
    """
    Emits a `transform.ParamConstantOp`.
    Returns a handle to the newly created parameter. Python `int`s are first
    converted to a signless 64-bit `IntegerAttr`. The type of the parameter
    is `transform.any_param` if the value's type is not an integer type,
    otherwise it is `transform.param` parametrized with the corresponding
    integer type.
    """
    if isinstance(value, int):
        value = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), value)
    if isinstance(value.type, ir.IntegerType):
        param_type = ParamType.get(value.type)
    else:
        param_type = AnyParamType.get()
    op = transform.ParamConstantOp(param_type, value)
    return op.param


def insert_transform_script(
    block_or_insertion_point: Union[ir.Block, ir.InsertionPoint],
    script: Callable[[OpHandle], None],
    dump_script: bool = False,
) -> None:
    """
    Inserts the transform script of the schedule into the module. The script
    should accept an instance of OpHandle as argument, which will be called with
    the block arg of the newly created named_sequence op.

    Args:
      block_or_insertion_point: where to create the `__transform_main`
        named sequence; a block is targeted at its beginning.
      script: callback that emits the transform ops; it receives the entry
        block argument of the named sequence. A `transform.yield` is appended
        after it returns.
      dump_script: if true, print the created named sequence op.

    Example:
    This python code
    ```
    module = ir.Module.create()
    def script(op: OpHandle):
        op.match_ops(scf.ForOp)
    insert_transform_script(module.body, script)
    ```
    generates the following IR:
    ```
    module {
        transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
            %0 = transform.structured.match ops{["scf.for"]} in %arg0
              : (!transform.any_op) -> !transform.op<"scf.for">
            transform.yield
        }
    }
    ```
    """
    if isinstance(block_or_insertion_point, ir.Block):
        context = block_or_insertion_point.owner.context
        insertion_point = ir.InsertionPoint.at_block_begin(block_or_insertion_point)
    else:
        context = block_or_insertion_point.block.owner.context
        insertion_point = block_or_insertion_point

    with context, ir.Location.unknown(context):
        with insertion_point:
            named_sequence_op = NamedSequenceOp(
                "__transform_main", [AnyOpType.get()], []
            )
        with ir.InsertionPoint(named_sequence_op.body):
            script(named_sequence_op.bodyTarget)
            YieldOp([])

    if dump_script:
        print(named_sequence_op)


# Region-building helpers produced by `region_op`. `SequenceOp.__base__` is
# passed rather than `SequenceOp` itself -- presumably to use the generated
# base builder instead of the extended Python wrapper; TODO confirm.
sequence = region_op(SequenceOp.__base__, terminator=YieldOp)
named_sequence = region_op(NamedSequenceOp, terminator=YieldOp)
apply_patterns = region_op(ApplyPatternsOp)