xref: /netbsd-src/external/mit/isl/dist/isl_stride.c (revision 5971e316fdea024efff6be8f03536623db06833e)
1*5971e316Smrg /*
2*5971e316Smrg  * Copyright 2012-2013 Ecole Normale Superieure
3*5971e316Smrg  *
4*5971e316Smrg  * Use of this software is governed by the MIT license
5*5971e316Smrg  *
6*5971e316Smrg  * Written by Sven Verdoolaege,
7*5971e316Smrg  * Ecole Normale Superieure, 45 rue d'Ulm, 75230 Paris, France
8*5971e316Smrg  */
9*5971e316Smrg 
10*5971e316Smrg #include <isl/val.h>
11*5971e316Smrg #include <isl_map_private.h>
12*5971e316Smrg #include <isl_aff_private.h>
13*5971e316Smrg #include <isl/constraint.h>
14*5971e316Smrg #include <isl/set.h>
15*5971e316Smrg 
16*5971e316Smrg /* Stride information about a specific set dimension.
17*5971e316Smrg  * The values of the set dimension are equal to
18*5971e316Smrg  * "offset" plus a multiple of "stride".
19*5971e316Smrg  */
20*5971e316Smrg struct isl_stride_info {
21*5971e316Smrg 	isl_val *stride;
22*5971e316Smrg 	isl_aff *offset;
23*5971e316Smrg };
24*5971e316Smrg 
25*5971e316Smrg /* Return the ctx to which "si" belongs.
26*5971e316Smrg  */
isl_stride_info_get_ctx(__isl_keep isl_stride_info * si)27*5971e316Smrg isl_ctx *isl_stride_info_get_ctx(__isl_keep isl_stride_info *si)
28*5971e316Smrg {
29*5971e316Smrg 	if (!si)
30*5971e316Smrg 		return NULL;
31*5971e316Smrg 
32*5971e316Smrg 	return isl_val_get_ctx(si->stride);
33*5971e316Smrg }
34*5971e316Smrg 
35*5971e316Smrg /* Free "si" and return NULL.
36*5971e316Smrg  */
isl_stride_info_free(__isl_take isl_stride_info * si)37*5971e316Smrg __isl_null isl_stride_info *isl_stride_info_free(
38*5971e316Smrg 	__isl_take isl_stride_info *si)
39*5971e316Smrg {
40*5971e316Smrg 	if (!si)
41*5971e316Smrg 		return NULL;
42*5971e316Smrg 	isl_val_free(si->stride);
43*5971e316Smrg 	isl_aff_free(si->offset);
44*5971e316Smrg 	free(si);
45*5971e316Smrg 	return NULL;
46*5971e316Smrg }
47*5971e316Smrg 
48*5971e316Smrg /* Construct an isl_stride_info object with given offset and stride.
49*5971e316Smrg  */
isl_stride_info_alloc(__isl_take isl_val * stride,__isl_take isl_aff * offset)50*5971e316Smrg __isl_give isl_stride_info *isl_stride_info_alloc(
51*5971e316Smrg 	__isl_take isl_val *stride, __isl_take isl_aff *offset)
52*5971e316Smrg {
53*5971e316Smrg 	struct isl_stride_info *si;
54*5971e316Smrg 
55*5971e316Smrg 	if (!stride || !offset)
56*5971e316Smrg 		goto error;
57*5971e316Smrg 	si = isl_alloc_type(isl_val_get_ctx(stride), struct isl_stride_info);
58*5971e316Smrg 	if (!si)
59*5971e316Smrg 		goto error;
60*5971e316Smrg 	si->stride = stride;
61*5971e316Smrg 	si->offset = offset;
62*5971e316Smrg 	return si;
63*5971e316Smrg error:
64*5971e316Smrg 	isl_val_free(stride);
65*5971e316Smrg 	isl_aff_free(offset);
66*5971e316Smrg 	return NULL;
67*5971e316Smrg }
68*5971e316Smrg 
69*5971e316Smrg /* Make a copy of "si" and return it.
70*5971e316Smrg  */
isl_stride_info_copy(__isl_keep isl_stride_info * si)71*5971e316Smrg __isl_give isl_stride_info *isl_stride_info_copy(
72*5971e316Smrg 	__isl_keep isl_stride_info *si)
73*5971e316Smrg {
74*5971e316Smrg 	if (!si)
75*5971e316Smrg 		return NULL;
76*5971e316Smrg 
77*5971e316Smrg 	return isl_stride_info_alloc(isl_val_copy(si->stride),
78*5971e316Smrg 		isl_aff_copy(si->offset));
79*5971e316Smrg }
80*5971e316Smrg 
81*5971e316Smrg /* Return the stride of "si".
82*5971e316Smrg  */
isl_stride_info_get_stride(__isl_keep isl_stride_info * si)83*5971e316Smrg __isl_give isl_val *isl_stride_info_get_stride(__isl_keep isl_stride_info *si)
84*5971e316Smrg {
85*5971e316Smrg 	if (!si)
86*5971e316Smrg 		return NULL;
87*5971e316Smrg 	return isl_val_copy(si->stride);
88*5971e316Smrg }
89*5971e316Smrg 
90*5971e316Smrg /* Return the offset of "si".
91*5971e316Smrg  */
isl_stride_info_get_offset(__isl_keep isl_stride_info * si)92*5971e316Smrg __isl_give isl_aff *isl_stride_info_get_offset(__isl_keep isl_stride_info *si)
93*5971e316Smrg {
94*5971e316Smrg 	if (!si)
95*5971e316Smrg 		return NULL;
96*5971e316Smrg 	return isl_aff_copy(si->offset);
97*5971e316Smrg }
98*5971e316Smrg 
99*5971e316Smrg /* Information used inside detect_stride.
100*5971e316Smrg  *
101*5971e316Smrg  * "pos" is the set dimension at which the stride is being determined.
102*5971e316Smrg  * "want_offset" is set if the offset should be computed.
103*5971e316Smrg  * "found" is set if some stride was found already.
104*5971e316Smrg  * "stride" and "offset" contain the (combined) stride and offset
105*5971e316Smrg  * found so far and are NULL when "found" is not set.
106*5971e316Smrg  * If "want_offset" is not set, then "offset" remains NULL.
107*5971e316Smrg  */
108*5971e316Smrg struct isl_detect_stride_data {
109*5971e316Smrg 	int pos;
110*5971e316Smrg 	int want_offset;
111*5971e316Smrg 	int found;
112*5971e316Smrg 	isl_val *stride;
113*5971e316Smrg 	isl_aff *offset;
114*5971e316Smrg };
115*5971e316Smrg 
116*5971e316Smrg /* Set the stride and offset of data->pos to the given
117*5971e316Smrg  * value and expression.
118*5971e316Smrg  *
119*5971e316Smrg  * If we had already found a stride before, then the two strides
120*5971e316Smrg  * are combined into a single stride.
121*5971e316Smrg  *
122*5971e316Smrg  * In particular, if the new stride information is of the form
123*5971e316Smrg  *
124*5971e316Smrg  *	i = f + s (...)
125*5971e316Smrg  *
126*5971e316Smrg  * and the old stride information is of the form
127*5971e316Smrg  *
128*5971e316Smrg  *	i = f2 + s2 (...)
129*5971e316Smrg  *
130*5971e316Smrg  * then we compute the extended gcd of s and s2
131*5971e316Smrg  *
132*5971e316Smrg  *	a s + b s2 = g,
133*5971e316Smrg  *
134*5971e316Smrg  * with g = gcd(s,s2), multiply the first equation with t1 = b s2/g
135*5971e316Smrg  * and the second with t2 = a s1/g.
136*5971e316Smrg  * This results in
137*5971e316Smrg  *
138*5971e316Smrg  *	i = (b s2 + a s1)/g i = t1 f + t2 f2 + (s s2)/g (...)
139*5971e316Smrg  *
140*5971e316Smrg  * so that t1 f + t2 f2 is the combined offset and (s s2)/g = lcm(s,s2)
141*5971e316Smrg  * is the combined stride.
142*5971e316Smrg  */
set_stride(struct isl_detect_stride_data * data,__isl_take isl_val * stride,__isl_take isl_aff * offset)143*5971e316Smrg static isl_stat set_stride(struct isl_detect_stride_data *data,
144*5971e316Smrg 	__isl_take isl_val *stride, __isl_take isl_aff *offset)
145*5971e316Smrg {
146*5971e316Smrg 	if (!stride || !offset)
147*5971e316Smrg 		goto error;
148*5971e316Smrg 
149*5971e316Smrg 	if (data->found) {
150*5971e316Smrg 		isl_val *stride2, *a, *b, *g;
151*5971e316Smrg 		isl_aff *offset2;
152*5971e316Smrg 
153*5971e316Smrg 		stride2 = data->stride;
154*5971e316Smrg 		g = isl_val_gcdext(isl_val_copy(stride), isl_val_copy(stride2),
155*5971e316Smrg 					&a, &b);
156*5971e316Smrg 		a = isl_val_mul(a, isl_val_copy(stride));
157*5971e316Smrg 		a = isl_val_div(a, isl_val_copy(g));
158*5971e316Smrg 		stride2 = isl_val_div(stride2, g);
159*5971e316Smrg 		b = isl_val_mul(b, isl_val_copy(stride2));
160*5971e316Smrg 		stride = isl_val_mul(stride, stride2);
161*5971e316Smrg 
162*5971e316Smrg 		if (!data->want_offset) {
163*5971e316Smrg 			isl_val_free(a);
164*5971e316Smrg 			isl_val_free(b);
165*5971e316Smrg 		} else {
166*5971e316Smrg 			offset2 = data->offset;
167*5971e316Smrg 			offset2 = isl_aff_scale_val(offset2, a);
168*5971e316Smrg 			offset = isl_aff_scale_val(offset, b);
169*5971e316Smrg 			offset = isl_aff_add(offset, offset2);
170*5971e316Smrg 		}
171*5971e316Smrg 	}
172*5971e316Smrg 
173*5971e316Smrg 	data->found = 1;
174*5971e316Smrg 	data->stride = stride;
175*5971e316Smrg 	if (data->want_offset)
176*5971e316Smrg 		data->offset = offset;
177*5971e316Smrg 	else
178*5971e316Smrg 		isl_aff_free(offset);
179*5971e316Smrg 	if (!data->stride || (data->want_offset && !data->offset))
180*5971e316Smrg 		return isl_stat_error;
181*5971e316Smrg 
182*5971e316Smrg 	return isl_stat_ok;
183*5971e316Smrg error:
184*5971e316Smrg 	isl_val_free(stride);
185*5971e316Smrg 	isl_aff_free(offset);
186*5971e316Smrg 	return isl_stat_error;
187*5971e316Smrg }
188*5971e316Smrg 
189*5971e316Smrg /* Check if constraint "c" imposes any stride on dimension data->pos
190*5971e316Smrg  * and, if so, update the stride information in "data".
191*5971e316Smrg  *
192*5971e316Smrg  * In order to impose a stride on the dimension, "c" needs to be an equality
193*5971e316Smrg  * and it needs to involve the dimension.  Note that "c" may also be
194*5971e316Smrg  * a div constraint and thus an inequality that we cannot use.
195*5971e316Smrg  *
196*5971e316Smrg  * Let c be of the form
197*5971e316Smrg  *
198*5971e316Smrg  *	h(p) + g * v * i + g * stride * f(alpha) = 0
199*5971e316Smrg  *
200*5971e316Smrg  * with h(p) an expression in terms of the parameters and other dimensions
201*5971e316Smrg  * and f(alpha) an expression in terms of the existentially quantified
202*5971e316Smrg  * variables.
203*5971e316Smrg  *
204*5971e316Smrg  * If "stride" is not zero and not one, then it represents a non-trivial stride
205*5971e316Smrg  * on "i".  We compute a and b such that
206*5971e316Smrg  *
207*5971e316Smrg  *	a v + b stride = 1
208*5971e316Smrg  *
209*5971e316Smrg  * We have
210*5971e316Smrg  *
211*5971e316Smrg  *	g v i = -h(p) + g stride f(alpha)
212*5971e316Smrg  *
213*5971e316Smrg  *	a g v i = -a h(p) + g stride f(alpha)
214*5971e316Smrg  *
215*5971e316Smrg  *	a g v i + b g stride i = -a h(p) + g stride * (...)
216*5971e316Smrg  *
217*5971e316Smrg  *	g i = -a h(p) + g stride * (...)
218*5971e316Smrg  *
219*5971e316Smrg  *	i = -a h(p)/g + stride * (...)
220*5971e316Smrg  *
221*5971e316Smrg  * The expression "-a h(p)/g" can therefore be used as offset.
222*5971e316Smrg  */
detect_stride(__isl_take isl_constraint * c,void * user)223*5971e316Smrg static isl_stat detect_stride(__isl_take isl_constraint *c, void *user)
224*5971e316Smrg {
225*5971e316Smrg 	struct isl_detect_stride_data *data = user;
226*5971e316Smrg 	int i;
227*5971e316Smrg 	isl_size n_div;
228*5971e316Smrg 	isl_ctx *ctx;
229*5971e316Smrg 	isl_stat r = isl_stat_ok;
230*5971e316Smrg 	isl_val *v, *stride, *m;
231*5971e316Smrg 	isl_bool is_eq, relevant, has_stride;
232*5971e316Smrg 
233*5971e316Smrg 	is_eq = isl_constraint_is_equality(c);
234*5971e316Smrg 	relevant = isl_constraint_involves_dims(c, isl_dim_set, data->pos, 1);
235*5971e316Smrg 	if (is_eq < 0 || relevant < 0)
236*5971e316Smrg 		goto error;
237*5971e316Smrg 	if (!is_eq || !relevant) {
238*5971e316Smrg 		isl_constraint_free(c);
239*5971e316Smrg 		return isl_stat_ok;
240*5971e316Smrg 	}
241*5971e316Smrg 
242*5971e316Smrg 	n_div = isl_constraint_dim(c, isl_dim_div);
243*5971e316Smrg 	if (n_div < 0)
244*5971e316Smrg 		goto error;
245*5971e316Smrg 	ctx = isl_constraint_get_ctx(c);
246*5971e316Smrg 	stride = isl_val_zero(ctx);
247*5971e316Smrg 	for (i = 0; i < n_div; ++i) {
248*5971e316Smrg 		v = isl_constraint_get_coefficient_val(c, isl_dim_div, i);
249*5971e316Smrg 		stride = isl_val_gcd(stride, v);
250*5971e316Smrg 	}
251*5971e316Smrg 
252*5971e316Smrg 	v = isl_constraint_get_coefficient_val(c, isl_dim_set, data->pos);
253*5971e316Smrg 	m = isl_val_gcd(isl_val_copy(stride), isl_val_copy(v));
254*5971e316Smrg 	stride = isl_val_div(stride, isl_val_copy(m));
255*5971e316Smrg 	v = isl_val_div(v, isl_val_copy(m));
256*5971e316Smrg 
257*5971e316Smrg 	has_stride = isl_val_gt_si(stride, 1);
258*5971e316Smrg 	if (has_stride >= 0 && has_stride) {
259*5971e316Smrg 		isl_aff *aff;
260*5971e316Smrg 		isl_val *gcd, *a, *b;
261*5971e316Smrg 
262*5971e316Smrg 		gcd = isl_val_gcdext(v, isl_val_copy(stride), &a, &b);
263*5971e316Smrg 		isl_val_free(gcd);
264*5971e316Smrg 		isl_val_free(b);
265*5971e316Smrg 
266*5971e316Smrg 		aff = isl_constraint_get_aff(c);
267*5971e316Smrg 		for (i = 0; i < n_div; ++i)
268*5971e316Smrg 			aff = isl_aff_set_coefficient_si(aff,
269*5971e316Smrg 							 isl_dim_div, i, 0);
270*5971e316Smrg 		aff = isl_aff_set_coefficient_si(aff, isl_dim_in, data->pos, 0);
271*5971e316Smrg 		aff = isl_aff_remove_unused_divs(aff);
272*5971e316Smrg 		a = isl_val_neg(a);
273*5971e316Smrg 		aff = isl_aff_scale_val(aff, a);
274*5971e316Smrg 		aff = isl_aff_scale_down_val(aff, m);
275*5971e316Smrg 		r = set_stride(data, stride, aff);
276*5971e316Smrg 	} else {
277*5971e316Smrg 		isl_val_free(stride);
278*5971e316Smrg 		isl_val_free(m);
279*5971e316Smrg 		isl_val_free(v);
280*5971e316Smrg 	}
281*5971e316Smrg 
282*5971e316Smrg 	isl_constraint_free(c);
283*5971e316Smrg 	if (has_stride < 0)
284*5971e316Smrg 		return isl_stat_error;
285*5971e316Smrg 	return r;
286*5971e316Smrg error:
287*5971e316Smrg 	isl_constraint_free(c);
288*5971e316Smrg 	return isl_stat_error;
289*5971e316Smrg }
290*5971e316Smrg 
291*5971e316Smrg /* Check if the constraints in "set" imply any stride on set dimension "pos" and
292*5971e316Smrg  * store the results in data->stride and data->offset.
293*5971e316Smrg  *
294*5971e316Smrg  * In particular, compute the affine hull and then check if
295*5971e316Smrg  * any of the constraints in the hull impose any stride on the dimension.
296*5971e316Smrg  * If no such constraint can be found, then the offset is taken
297*5971e316Smrg  * to be the zero expression and the stride is taken to be one.
298*5971e316Smrg  */
set_detect_stride(__isl_keep isl_set * set,int pos,struct isl_detect_stride_data * data)299*5971e316Smrg static void set_detect_stride(__isl_keep isl_set *set, int pos,
300*5971e316Smrg 	struct isl_detect_stride_data *data)
301*5971e316Smrg {
302*5971e316Smrg 	isl_basic_set *hull;
303*5971e316Smrg 
304*5971e316Smrg 	hull = isl_set_affine_hull(isl_set_copy(set));
305*5971e316Smrg 
306*5971e316Smrg 	data->pos = pos;
307*5971e316Smrg 	data->found = 0;
308*5971e316Smrg 	data->stride = NULL;
309*5971e316Smrg 	data->offset = NULL;
310*5971e316Smrg 	if (isl_basic_set_foreach_constraint(hull, &detect_stride, data) < 0)
311*5971e316Smrg 		goto error;
312*5971e316Smrg 
313*5971e316Smrg 	if (!data->found) {
314*5971e316Smrg 		data->stride = isl_val_one(isl_set_get_ctx(set));
315*5971e316Smrg 		if (data->want_offset) {
316*5971e316Smrg 			isl_space *space;
317*5971e316Smrg 			isl_local_space *ls;
318*5971e316Smrg 
319*5971e316Smrg 			space = isl_set_get_space(set);
320*5971e316Smrg 			ls = isl_local_space_from_space(space);
321*5971e316Smrg 			data->offset = isl_aff_zero_on_domain(ls);
322*5971e316Smrg 		}
323*5971e316Smrg 	}
324*5971e316Smrg 	isl_basic_set_free(hull);
325*5971e316Smrg 	return;
326*5971e316Smrg error:
327*5971e316Smrg 	isl_basic_set_free(hull);
328*5971e316Smrg 	data->stride = isl_val_free(data->stride);
329*5971e316Smrg 	data->offset = isl_aff_free(data->offset);
330*5971e316Smrg }
331*5971e316Smrg 
332*5971e316Smrg /* Check if the constraints in "set" imply any stride on set dimension "pos" and
333*5971e316Smrg  * return the results in the form of an offset and a stride.
334*5971e316Smrg  */
isl_set_get_stride_info(__isl_keep isl_set * set,int pos)335*5971e316Smrg __isl_give isl_stride_info *isl_set_get_stride_info(__isl_keep isl_set *set,
336*5971e316Smrg 	int pos)
337*5971e316Smrg {
338*5971e316Smrg 	struct isl_detect_stride_data data;
339*5971e316Smrg 
340*5971e316Smrg 	data.want_offset = 1;
341*5971e316Smrg 	set_detect_stride(set, pos, &data);
342*5971e316Smrg 
343*5971e316Smrg 	return isl_stride_info_alloc(data.stride, data.offset);
344*5971e316Smrg }
345*5971e316Smrg 
346*5971e316Smrg /* Check if the constraints in "set" imply any stride on set dimension "pos" and
347*5971e316Smrg  * return this stride.
348*5971e316Smrg  */
isl_set_get_stride(__isl_keep isl_set * set,int pos)349*5971e316Smrg __isl_give isl_val *isl_set_get_stride(__isl_keep isl_set *set, int pos)
350*5971e316Smrg {
351*5971e316Smrg 	struct isl_detect_stride_data data;
352*5971e316Smrg 
353*5971e316Smrg 	data.want_offset = 0;
354*5971e316Smrg 	set_detect_stride(set, pos, &data);
355*5971e316Smrg 
356*5971e316Smrg 	return data.stride;
357*5971e316Smrg }
358*5971e316Smrg 
359*5971e316Smrg /* Check if the constraints in "map" imply any stride on output dimension "pos",
360*5971e316Smrg  * independently of any other output dimensions, and
361*5971e316Smrg  * return the results in the form of an offset and a stride.
362*5971e316Smrg  *
363*5971e316Smrg  * Convert the input to a set with only the input dimensions and
364*5971e316Smrg  * the single output dimension such that it be passed to
365*5971e316Smrg  * isl_set_get_stride_info and convert the result back to
366*5971e316Smrg  * an expression defined over the domain of "map".
367*5971e316Smrg  */
isl_map_get_range_stride_info(__isl_keep isl_map * map,int pos)368*5971e316Smrg __isl_give isl_stride_info *isl_map_get_range_stride_info(
369*5971e316Smrg 	__isl_keep isl_map *map, int pos)
370*5971e316Smrg {
371*5971e316Smrg 	isl_stride_info *si;
372*5971e316Smrg 	isl_set *set;
373*5971e316Smrg 	isl_size n_in;
374*5971e316Smrg 
375*5971e316Smrg 	n_in = isl_map_dim(map, isl_dim_in);
376*5971e316Smrg 	if (n_in < 0)
377*5971e316Smrg 		return NULL;
378*5971e316Smrg 	map = isl_map_copy(map);
379*5971e316Smrg 	map = isl_map_project_onto(map, isl_dim_out, pos, 1);
380*5971e316Smrg 	set = isl_map_wrap(map);
381*5971e316Smrg 	si = isl_set_get_stride_info(set, n_in);
382*5971e316Smrg 	isl_set_free(set);
383*5971e316Smrg 	if (!si)
384*5971e316Smrg 		return NULL;
385*5971e316Smrg 	si->offset = isl_aff_domain_factor_domain(si->offset);
386*5971e316Smrg 	if (!si->offset)
387*5971e316Smrg 		return isl_stride_info_free(si);
388*5971e316Smrg 	return si;
389*5971e316Smrg }
390