1181254a7Smrg // Written in the D programming language.
2181254a7Smrg
3181254a7Smrg /**
4181254a7Smrg This module is a port of a growing fragment of the $(D_PARAM numeric)
5*b1e83836Smrg header in Alexander Stepanov's $(LINK2 https://en.wikipedia.org/wiki/Standard_Template_Library,
6181254a7Smrg Standard Template Library), with a few additions.
7181254a7Smrg
8181254a7Smrg Macros:
9181254a7Smrg Copyright: Copyright Andrei Alexandrescu 2008 - 2009.
10181254a7Smrg License: $(HTTP www.boost.org/LICENSE_1_0.txt, Boost License 1.0).
11181254a7Smrg Authors: $(HTTP erdani.org, Andrei Alexandrescu),
12181254a7Smrg Don Clugston, Robert Jacques, Ilya Yaroshenko
13*b1e83836Smrg Source: $(PHOBOSSRC std/numeric.d)
14181254a7Smrg */
15181254a7Smrg /*
16181254a7Smrg Copyright Andrei Alexandrescu 2008 - 2009.
17181254a7Smrg Distributed under the Boost Software License, Version 1.0.
18181254a7Smrg (See accompanying file LICENSE_1_0.txt or copy at
19181254a7Smrg http://www.boost.org/LICENSE_1_0.txt)
20181254a7Smrg */
21181254a7Smrg module std.numeric;
22181254a7Smrg
23181254a7Smrg import std.complex;
24181254a7Smrg import std.math;
25*b1e83836Smrg import core.math : fabs, ldexp, sin, sqrt;
26181254a7Smrg import std.range.primitives;
27181254a7Smrg import std.traits;
28181254a7Smrg import std.typecons;
29181254a7Smrg
/// Format flags for CustomFloat.
public enum CustomFloatFlags
{
    /// Adds a sign bit to allow for signed numbers.
    signed = 1 << 0,

    /**
     * Store values in normalized form by default. The actual precision of the
     * significand is extended by 1 bit by assuming an implicit leading bit of 1
     * instead of 0. i.e. `1.nnnn` instead of `0.nnnn`.
     * True for all $(LINK2 https://en.wikipedia.org/wiki/IEEE_floating_point, IEEE754) types
     */
    storeNormalized = 1 << 1,

    /**
     * Stores the significand in $(LINK2 https://en.wikipedia.org/wiki/IEEE_754-1985#Denormalized_numbers,
     * IEEE754 denormalized) form when the exponent is 0. Required to express the value 0.
     */
    allowDenorm = 1 << 2,

    /**
     * Allows the storage of $(LINK2 https://en.wikipedia.org/wiki/IEEE_754-1985#Positive_and_negative_infinity,
     * IEEE754 _infinity) values.
     */
    infinity = 1 << 3,

    /// Allows the storage of $(LINK2 https://en.wikipedia.org/wiki/NaN, IEEE754 Not a Number) values.
    nan = 1 << 4,

    /**
     * If set, select an exponent bias such that max_exp = 1.
     * i.e. so that the maximum value is >= 1.0 and < 2.0.
     * Ignored if the exponent bias is manually specified.
     */
    probability = 1 << 5,

    /// If set, unsigned custom floats are assumed to be negative.
    negativeUnsigned = 1 << 6,

    /**
     * If set, 0 is the only allowed $(LINK2 https://en.wikipedia.org/wiki/IEEE_754-1985#Denormalized_numbers,
     * IEEE754 denormalized) number.
     * Requires allowDenorm and storeNormalized, so both bits are part of the value.
     */
    allowDenormZeroOnly = (1 << 7) | allowDenorm | storeNormalized,

    /// Include _all of the $(LINK2 https://en.wikipedia.org/wiki/IEEE_floating_point, IEEE754) options.
    ieee = signed | storeNormalized | allowDenorm | infinity | nan,

    /// Include none of the above options.
    none = 0
}
81181254a7Smrg
// Maps one of the predefined total widths (8, 16, 32, 64 or 80 bits) onto a
// parameter list by forwarding to the three-argument CustomFloatParams.
// For any other width no eponymous alias is declared.
private template CustomFloatParams(uint bits)
{
    // All predefined formats use the ieee flag set; the 80-bit format
    // additionally has storeNormalized toggled off.
    static if (bits == 80)
        enum CustomFloatFlags flags = CustomFloatFlags.ieee ^ CustomFloatFlags.storeNormalized;
    else
        enum CustomFloatFlags flags = CustomFloatFlags.ieee;

    //                                                     precision, exponent width
    static if (bits == 8)
        alias CustomFloatParams = CustomFloatParams!( 4,  3, flags);
    else static if (bits == 16)
        alias CustomFloatParams = CustomFloatParams!(10,  5, flags);
    else static if (bits == 32)
        alias CustomFloatParams = CustomFloatParams!(23,  8, flags);
    else static if (bits == 64)
        alias CustomFloatParams = CustomFloatParams!(52, 11, flags);
    else static if (bits == 80)
        alias CustomFloatParams = CustomFloatParams!(64, 15, flags);
}
92181254a7Smrg
// Expands to the sequence (precision, exponentWidth, flags, bias), where the
// exponent bias is derived from the exponent width and the flags.
private template CustomFloatParams(uint precision, uint exponentWidth, CustomFloatFlags flags)
{
    import std.meta : AliasSeq;

    enum bool isProbability = (flags & flags.probability) != 0;
    enum bool hasSpecials = (flags & (flags.nan | flags.infinity)) != 0;

    // 2^(exponentWidth - 1), or 2^exponentWidth for probability formats,
    // reduced by one for nan/infinity support and by one more for probability.
    enum bias = (1 << (exponentWidth - !isProbability)) - hasSpecials - isProbability;

    alias CustomFloatParams = AliasSeq!(precision, exponentWidth, flags, bias);
}
105181254a7Smrg
/**
 * Allows user code to define custom floating-point formats. These formats are
 * for storage only; all operations on them are performed by first implicitly
 * extracting them to `real`. After the operation is completed the
 * result can be stored in a custom floating-point value via assignment.
 *
 * This overload selects one of the predefined layouts by its total size in
 * bits (8, 16, 32, 64 or 80).
 */
template CustomFloat(uint bits)
if (bits == 8 || bits == 16 || bits == 32 || bits == 64 || bits == 80)
{
    alias CustomFloat = CustomFloat!(CustomFloatParams!(bits));
}

/// ditto
// This overload takes the number of significand and exponent bits explicitly.
// The constraint requires at least one bit in total and that the bits,
// together with the optional sign bit, fill a whole number of bytes.
template CustomFloat(uint precision, uint exponentWidth, CustomFloatFlags flags = CustomFloatFlags.ieee)
if (((flags & flags.signed) + precision + exponentWidth) % 8 == 0 && precision + exponentWidth > 0)
{
    alias CustomFloat = CustomFloat!(CustomFloatParams!(precision, exponentWidth, flags));
}
124181254a7Smrg
///
@safe unittest
{
    import std.math.trigonometry : sin, cos;

    // Define 16-bit floating point values
    CustomFloat!16 x; // Using the number of bits
    CustomFloat!(10, 5) y; // Using the precision and exponent width
    CustomFloat!(10, 5,CustomFloatFlags.ieee) z; // Using the precision, exponent width and format flags
    CustomFloat!(10, 5,CustomFloatFlags.ieee, 15) w; // Using the precision, exponent width, format flags and exponent offset bias

    // Use the 16-bit floats mostly like normal numbers
    w = x*y - 1;

    // Function calls require conversion
    z = sin(+x) + cos(+y); // Use unary plus to concisely convert to a real
    z = sin(x.get!float) + cos(y.get!float); // Or use get!T
    z = sin(cast(float) x) + cos(cast(float) y); // Or use cast(T) to explicitly convert

    // Define an 8-bit custom float for storing probabilities
    alias Probability = CustomFloat!(4, 4, CustomFloatFlags.ieee^CustomFloatFlags.probability^CustomFloatFlags.signed );
    auto p = Probability(0.5);
}
148181254a7Smrg
// Facilitate converting numeric types to custom float: the same storage is
// viewed either as the built-in type F or as the CustomFloat with F's layout.
private union ToBinary(F)
if (is(typeof(CustomFloatParams!(F.sizeof*8))) || is(F == real))
{
    import std.algorithm.comparison : min;

    // The value as the built-in floating-point type.
    F set;

    // The value as a CustomFloat. The width is capped at 80 bits so that on
    // Linux or Mac, where 80-bit reals are padded, the padding is ignored.
    CustomFloat!(CustomFloatParams!(min(F.sizeof*8, 80))) get;

    alias get this;

    // Convert F to the correct binary type.
    static typeof(get) opCall(F value)
    {
        ToBinary converter;
        converter.set = value;
        return converter.get;
    }
}
169*b1e83836Smrg
170181254a7Smrg /// ditto
171181254a7Smrg struct CustomFloat(uint precision, // fraction bits (23 for float)
172181254a7Smrg uint exponentWidth, // exponent bits (8 for float) Exponent width
173181254a7Smrg CustomFloatFlags flags,
174181254a7Smrg uint bias)
175*b1e83836Smrg if (isCorrectCustomFloat(precision, exponentWidth, flags))
176181254a7Smrg {
177181254a7Smrg import std.bitmanip : bitfields;
178181254a7Smrg import std.meta : staticIndexOf;
179181254a7Smrg private:
// Pick the unsigned bitfield type for a field of `bits` bits: size_t when it
// is wide enough, otherwise ulong (to support > 32 bits).
template uType(uint bits)
{
    static if (bits > size_t.sizeof*8)
        alias uType = ulong;
    else
        alias uType = size_t;
}
186181254a7Smrg
// Pick the signed bitfield type for a field of `bits` bits: ptrdiff_t when the
// field fits beside its sign bit, otherwise long (to support > 32 bits).
template sType(uint bits)
{
    static if (bits > ptrdiff_t.sizeof*8-1)
        alias sType = long;
    else
        alias sType = ptrdiff_t;
}
193181254a7Smrg
194181254a7Smrg alias T_sig = uType!precision;
195181254a7Smrg alias T_exp = uType!exponentWidth;
196181254a7Smrg alias T_signed_exp = sType!exponentWidth;
197181254a7Smrg
198181254a7Smrg alias Flags = CustomFloatFlags;
199181254a7Smrg
// Shift `sig` right by `shift` bits with IEEE round-to-nearest, where a tie
// (the discarded bits are exactly one half) is rounded to the even value.
//
// Params:
//   sig   = value to shift, updated in place
//   shift = number of low bits to discard; 0 leaves `sig` untouched and a
//           count of at least the bit width of T yields 0
void roundedShift(T,U)(ref T sig, U shift)
{
    enum bitWidth = T.sizeof*8;

    if (shift == 0)
    {
        // Nothing is discarded. Returning here also keeps the shift counts
        // used below (bitWidth - shift and shift - 1) within the legal range.
        return;
    }

    if (shift >= bitWidth)
    {
        // avoid illegal shift
        sig = 0;
    }
    else if (sig << (bitWidth - shift) == cast(T) 1uL << (bitWidth - 1))
    {
        // The discarded bits are exactly one half: round to even
        sig >>= shift;
        sig += sig & 1;
    }
    else
    {
        // Perform standard rounding: keep one extra bit, add it to the
        // result, then drop it
        sig >>= shift - 1;
        sig += sig & 1;
        sig >>= 1;
    }
}
222181254a7Smrg
// Convert the current value to signed exponent, normalized form.
//
// On return `sig` holds the stored significand moved to the top bits of T and
// `exp` the stored exponent with the bias removed. Two exponent values are
// used as markers: exp.max for infinity/NaN and exp.min for zero.
void toNormalized(T,U)(ref T sig, ref U exp)
{
    enum typeBits = T.sizeof*8;

    // Left shift that moves a `precision`-bit field to the top of T
    auto lshift = typeBits - precision;

    sig = significand;
    exp = exponent;

    static if (flags&(Flags.infinity|Flags.nan))
    {
        // Handle inf or nan
        if (exp == exponent_max)
        {
            exp = exp.max;
            sig <<= lshift;
            static if (flags&Flags.storeNormalized)
            {
                // Save inf/nan in denormalized format
                sig >>= 1;
                sig += cast(T) 1uL << (typeBits - 1);
            }
            return;
        }
    }

    // Convert denormalized form to normalized form
    if ((~flags&Flags.storeNormalized) ||
        ((flags&Flags.allowDenorm) && exp == 0))
    {
        if (sig <= 0) // value = 0.0
        {
            exp = exp.min;
            return;
        }

        import core.bitop : bsr;

        // Extra shift needed to move the highest set bit out past the top
        auto extra = precision - bsr(sig);
        exp -= extra-1;
        lshift += extra;
    }
    sig <<= lshift;
    exp -= bias;
}
265181254a7Smrg
// Set the current value from signed exponent, normalized form.
//
// `sig` and `exp` are expected in the form produced by toNormalized: the
// significand in the top bits of T, an unbiased exponent, exp.max marking
// infinity/NaN and exp.min marking zero. On return they hold the values to be
// stored in this format's significand and exponent fields.
//
// Values this format cannot represent (infinity, NaN, underflow, overflow)
// trip an assertion when the corresponding flag is not set.
void fromNormalized(T,U)(ref T sig, ref U exp)
{
    auto shift = (T.sizeof*8) - precision;
    if (exp == exp.max)
    {
        // infinity or nan
        exp = exponent_max;

        // convert back to normalized form
        static if (flags & Flags.storeNormalized)
            sig <<= 1;

        static if (~flags & Flags.infinity)
            // No infinity support?
            assert(sig != 0, "Infinity floating point value assigned to a "
                    ~ typeof(this).stringof ~ " (no infinity support).");

        static if (~flags & Flags.nan)  // No NaN support?
            assert(sig == 0, "NaN floating point value assigned to a " ~
                    typeof(this).stringof ~ " (no nan support).");
        sig >>= shift;
        return;
    }
    if (exp == exp.min) // 0.0
    {
        exp = 0;
        sig = 0;
        return;
    }

    exp += bias;
    if (exp <= 0)
    {
        static if ((flags&Flags.allowDenorm) ||
                   (~flags&Flags.storeNormalized))
        {
            // Convert from normalized form to denormalized
            shift += -exp;
            roundedShift(sig,1);
            // Add the leading 1
            sig += cast(T) 1uL << (T.sizeof*8 - 1);
            exp = 0;
        }
        else
            assert((flags&Flags.storeNormalized) && exp == 0,
                "Underflow occurred assigning to a " ~
                typeof(this).stringof ~ " (no denormal support).");
    }
    else
    {
        static if (~flags&Flags.storeNormalized)
        {
            // Convert from normalized form to denormalized
            roundedShift(sig,1);
            // Add the leading 1
            sig += cast(T) 1uL << (T.sizeof*8 - 1);
        }
    }

    // Bring the significand down to `precision` bits, rounding as we go
    if (shift > 0)
        roundedShift(sig,shift);
    if (sig > significand_max)
    {
        // handle significand overflow (should only be 1 bit)
        static if (~flags&Flags.storeNormalized)
        {
            sig >>= 1;
        }
        else
            sig &= significand_max;
        exp++;
    }
    static if ((flags&Flags.allowDenormZeroOnly)==Flags.allowDenormZeroOnly)
    {
        // disallow non-zero denormals
        if (exp == 0)
        {
            sig <<= 1;
            // Check and round to even
            if (sig > significand_max && (sig&significand_max) > 0)
                exp++;
            sig = 0;
        }
    }

    if (exp >= exponent_max)
    {
        static if (flags&(Flags.infinity|Flags.nan))
        {
            sig = 0;
            exp = exponent_max;
            static if (~flags&(Flags.infinity))
                assert(0, "Overflow occurred assigning to a " ~
                    typeof(this).stringof ~ " (no infinity support).");
        }
        else
            assert(exp == exponent_max, "Overflow occurred assigning to a "
                ~ typeof(this).stringof ~ " (no infinity support).");
    }
}
366181254a7Smrg
367181254a7Smrg public:
368181254a7Smrg static if (precision == 64) // CustomFloat!80 support hack
369181254a7Smrg {
370181254a7Smrg ulong significand;
371181254a7Smrg enum ulong significand_max = ulong.max;
372181254a7Smrg mixin(bitfields!(
373181254a7Smrg T_exp , "exponent", exponentWidth,
374181254a7Smrg bool , "sign" , flags & flags.signed ));
375181254a7Smrg }
376181254a7Smrg else
377181254a7Smrg {
378181254a7Smrg mixin(bitfields!(
379181254a7Smrg T_sig, "significand", precision,
380181254a7Smrg T_exp, "exponent" , exponentWidth,
381181254a7Smrg bool , "sign" , flags & flags.signed ));
382181254a7Smrg }
383181254a7Smrg
/// Returns: infinity value
static if (flags & Flags.infinity)
    static @property CustomFloat infinity()
    {
        CustomFloat result;
        // Positive sign, maximum exponent, zero significand
        result.exponent = exponent_max;
        result.significand = 0;
        static if (flags & Flags.signed)
            result.sign = 0;
        return result;
    }
395181254a7Smrg
/// Returns: NaN value
static if (flags & Flags.nan)
    static @property CustomFloat nan()
    {
        CustomFloat result;
        // Positive sign, maximum exponent, highest significand bit set
        result.exponent = exponent_max;
        result.significand = cast(typeof(significand_max)) 1L << (precision-1);
        static if (flags & Flags.signed)
            result.sign = 0;
        return result;
    }
407181254a7Smrg
/// Returns: number of decimal digits of precision
static @property size_t dig()
{
    // One bit fewer counts when the format is not stored normalized
    immutable bitCount = precision - ((flags&Flags.storeNormalized) == 0);

    // 1uL << 64 is not a valid shift, so that case is answered directly
    if (bitCount == 64)
        return 19;

    return cast(size_t) log10(1uL << bitCount);
}
414181254a7Smrg
/// Returns: smallest increment to the value 1
static @property CustomFloat epsilon()
{
    auto unit = CustomFloat(1);

    // Copy of 1 with the lowest significand bit set
    // (spelled out in full because |= does not work here)
    auto next = unit;
    next.significand = next.significand | 1;

    return CustomFloat(next - unit);
}
424181254a7Smrg
/// the number of bits in mantissa
enum mant_dig = precision + ((flags&Flags.storeNormalized) != 0);

/// Returns: maximum int value such that 10<sup>max_10_exp</sup> is representable
static @property int max_10_exp()
{
    // Unary plus converts the CustomFloat to real before taking the logarithm
    return cast(int) log10(+max);
}

/// maximum int value such that 2<sup>max_exp-1</sup> is representable
enum max_exp = exponent_max - bias - ((flags & (Flags.infinity | Flags.nan)) != 0) + 1;

/// Returns: minimum int value such that 10<sup>min_10_exp</sup> is representable
static @property int min_10_exp()
{
    // Unary plus converts the CustomFloat to real before taking the logarithm
    return cast(int) log10(+min_normal);
}

/// minimum int value such that 2<sup>min_exp-1</sup> is representable as a normalized value
enum min_exp = cast(T_signed_exp) -(cast(long) bias) + 1 + ((flags & Flags.allowDenorm) != 0);
439181254a7Smrg
/// Returns: largest representable value that's not infinity
static @property CustomFloat max()
{
    // One exponent value is given up when infinity or NaN can be stored
    enum bool reserved = (flags&(flags.infinity|flags.nan)) != 0;

    CustomFloat result;
    result.significand = significand_max;
    result.exponent = exponent_max - reserved;
    static if (flags & Flags.signed)
        result.sign = 0;
    return result;
}
450181254a7Smrg
/// Returns: smallest representable normalized value that's not 0
static @property CustomFloat min_normal()
{
    CustomFloat result;
    static if (flags & Flags.signed)
        result.sign = 0;

    // Exponent 1 when denormals are allowed, otherwise exponent 0
    result.exponent = (flags & Flags.allowDenorm) != 0;

    static if (~flags & Flags.storeNormalized)
        // Not stored normalized: set the highest significand bit
        result.significand = cast(T_sig) 1uL << (precision - 1);
    else
        result.significand = 0;
    return result;
}
464181254a7Smrg
/// Returns: real part
@property CustomFloat re()
{
    return this;
}

/// Returns: imaginary part
static @property CustomFloat im()
{
    // Always zero
    return CustomFloat(0.0f);
}
470181254a7Smrg
/// Initialize from any `real` compatible type.
// Accepts every type that can be cast to `real` and delegates the conversion
// to the assignment operator.
this(F)(F input) if (__traits(compiles, cast(real) input ))
{
    this = input;
}
476181254a7Smrg
/// Self assignment
void opAssign(F:CustomFloat)(F input)
{
    // Same format on both sides, so the fields are copied as they are
    significand = input.significand;
    exponent = input.exponent;
    static if (flags & Flags.signed)
        sign = input.sign;
}
485181254a7Smrg
/// Assigns from any `real` compatible type.
void opAssign(F)(F input)
if (__traits(compiles, cast(real) input))
{
    import std.conv : text;

    // float, double and real are taken as they are; any other type is
    // handed over as a real
    static if (staticIndexOf!(immutable F, immutable float, immutable double, immutable real) >= 0)
        alias Source = Unqual!F;
    else
        alias Source = real;
    auto src = ToBinary!Source(input);

    // Assign the sign bit
    static if (~flags & Flags.signed)
        assert((!src.sign) ^ ((flags&flags.negativeUnsigned) > 0),
            "Incorrectly signed floating point value assigned to a " ~
            typeof(this).stringof ~ " (no sign support).");
    else
        sign = src.sign;

    // Working variables in the common type of source and target fields
    CommonType!(T_signed_exp, src.T_signed_exp) e = src.exponent;
    CommonType!(T_sig, src.T_sig) s = src.significand;

    // Source format -> normalized form -> this format
    src.toNormalized(s, e);
    fromNormalized(s, e);

    assert(e <= exponent_max, text(typeof(this).stringof ~
        " exponent too large: " ,e," > ",exponent_max, "\t",input,"\t",s));
    assert(s <= significand_max, text(typeof(this).stringof ~
        " significand too large: ",s," > ",significand_max,
        "\t",input,"\t",e," ",exponent_max));
    exponent = cast(T_exp) e;
    significand = cast(T_sig) s;
}
519181254a7Smrg
/// Fetches the stored value either as a `float`, `double` or `real`.
@property F get(F)()
if (staticIndexOf!(immutable F, immutable float, immutable double, immutable real) >= 0)
{
    import std.conv : text;

    ToBinary!F raw;
    alias Target = typeof(raw.get);

    // Unsigned formats report the sign chosen by negativeUnsigned
    static if (flags&Flags.signed)
        raw.sign = sign;
    else
        raw.sign = (flags&flags.negativeUnsigned) > 0;

    // Assign the exponent and fraction, in the common type of both formats
    CommonType!(T_signed_exp, Target.T_signed_exp) e = exponent;
    CommonType!(T_sig, Target.T_sig) s = significand;

    // This format -> normalized form -> layout of F
    toNormalized(s, e);
    raw.fromNormalized(s, e);
    assert(e <= raw.exponent_max, text("get exponent too large: " ,e," > ",raw.exponent_max) );
    assert(s <= raw.significand_max, text("get significand too large: ",s," > ",raw.significand_max) );
    raw.exponent = cast(Target.T_exp) e;
    raw.significand = cast(Target.T_sig) s;
    return raw.set;
}
544181254a7Smrg
545181254a7Smrg ///ditto
546*b1e83836Smrg alias opCast = get;
547181254a7Smrg
/// Convert the CustomFloat to a real and perform the relevant operator on the result
real opUnary(string op)()
if (__traits(compiles, mixin(op~`(get!real)`)) || op=="++" || op=="--")
{
    static if (op != "++" && op != "--")
    {
        // Plain unary operator: apply it to the extracted real
        return mixin(op~`get!real`);
    }
    else
    {
        // Step a real copy, store it back, and return the stepped value
        auto value = get!real;
        this = mixin(op~`value`);
        return value;
    }
}
561181254a7Smrg
/// ditto
// Define an opBinary `CustomFloat op CustomFloat` so that those below
// do not match equally, which is disallowed by the spec:
// https://dlang.org/spec/operatoroverloading.html#binary
// This overload applies when the right operand provides `get!real` itself;
// both operands are extracted to `real` before `op` is applied.
real opBinary(string op,T)(T b)
if (__traits(compiles, mixin(`get!real`~op~`b.get!real`)))
{
    return mixin(`get!real`~op~`b.get!real`);
}

/// ditto
// This overload applies when `b` can be combined with a `real` directly and
// does not provide `get!real`; only the left operand is converted.
real opBinary(string op,T)(T b)
if ( __traits(compiles, mixin(`get!real`~op~`b`)) &&
    !__traits(compiles, mixin(`get!real`~op~`b.get!real`)))
{
    return mixin(`get!real`~op~`b`);
}
579181254a7Smrg
/// ditto
// Handles `a op this`: this value is extracted to `real` and `op` is applied
// with `a` as the left operand.
// NOTE(review): the two negated constraint clauses refer to `b`, but the
// parameter of this overload is named `a`. Unless some other `b` is in scope,
// those mixins cannot compile and both clauses are always true. Confirm the
// intent before renaming: substituting `a` as-is would also reject plain
// `real` left operands.
real opBinaryRight(string op,T)(T a)
if ( __traits(compiles, mixin(`a`~op~`get!real`)) &&
    !__traits(compiles, mixin(`get!real`~op~`b`)) &&
    !__traits(compiles, mixin(`get!real`~op~`b.get!real`)))
{
    return mixin(`a`~op~`get!real`);
}
588181254a7Smrg
/// ditto
int opCmp(T)(auto ref T b)
if (__traits(compiles, cast(real) b))
{
    // Compare as reals. The result is -1, 0 or 1; it is also 0 when neither
    // comparison holds.
    immutable real lhs = get!real;
    immutable real rhs = cast(real) b;
    return (lhs >= rhs) - (lhs <= rhs);
}
597181254a7Smrg
/// ditto
void opOpAssign(string op, T)(auto ref T b)
if (__traits(compiles, mixin(`get!real`~op~`cast(real) b`)))
{
    // Evaluate `this op b` and store the result back into this value
    mixin(`this = this `~op~` cast(real) b;`);
}
604181254a7Smrg
/// ditto
// Formats the value by extracting it to `real` and passing that, together
// with the format specification, to formatValue.
template toString()
{
    import std.format.spec : FormatSpec;
    import std.format.write : formatValue;
    // Needs to be a template because of https://issues.dlang.org/show_bug.cgi?id=13737.
    void toString()(scope void delegate(const(char)[]) sink, scope const ref FormatSpec!char fmt)
    {
        sink.formatValue(get!real, fmt);
    }
}
616181254a7Smrg }
617181254a7Smrg
// Arithmetic through op-assignment must give exact results in several formats.
@safe unittest
{
    import std.meta;
    alias FPTypes =
        AliasSeq!(
            CustomFloat!(5, 10),
            CustomFloat!(5, 11, CustomFloatFlags.ieee ^ CustomFloatFlags.signed),
            CustomFloat!(1, 7, CustomFloatFlags.ieee ^ CustomFloatFlags.signed),
            // `^` binds tighter than `|`, so the flags are ieee | (probability ^ signed)
            CustomFloat!(4, 3, CustomFloatFlags.ieee | CustomFloatFlags.probability ^ CustomFloatFlags.signed)
        );

    foreach (F; FPTypes)
    {
        // Construction and read-back
        auto x = F(0.125);
        assert(x.get!float == 0.125F);
        assert(x.get!double == 0.125);

        // Each op-assignment stores its result back into x
        x -= 0.0625;
        assert(x.get!float == 0.0625F);
        assert(x.get!double == 0.0625);

        x *= 2;
        assert(x.get!float == 0.125F);
        assert(x.get!double == 0.125);

        x /= 4;
        assert(x.get!float == 0.03125);
        assert(x.get!double == 0.03125);

        x = 0.5;
        x ^^= 4;
        assert(x.get!float == 1 / 16.0F);
        assert(x.get!double == 1 / 16.0);
    }
}
653181254a7Smrg
// Conversion to string goes through toString and prints the value as a real.
@system unittest
{
    // @system due to to!string(CustomFloat)
    import std.conv;
    CustomFloat!(5, 10) y = CustomFloat!(5, 10)(0.125);
    assert(y.to!string == "0.125");
}
661181254a7Smrg
// Properties of a small format: 5 significand bits and 2 exponent bits with
// the default ieee flags (which include the sign bit).
@safe unittest
{
    alias cf = CustomFloat!(5, 2);

    // infinity: maximum exponent, zero significand
    auto a = cf.infinity;
    assert(a.sign == 0);
    assert(a.exponent == 3);
    assert(a.significand == 0);

    // nan: maximum exponent, non-zero significand
    auto b = cf.nan;
    assert(b.exponent == 3);
    assert(b.significand != 0);

    assert(cf.dig == 1);

    auto c = cf.epsilon;
    assert(c.sign == 0);
    assert(c.exponent == 0);
    assert(c.significand == 1);

    // 5 stored bits plus the implicit leading bit
    assert(cf.mant_dig == 6);

    assert(cf.max_10_exp == 0);
    assert(cf.max_exp == 2);
    assert(cf.min_10_exp == 0);
    assert(cf.min_exp == 1);

    // max: largest exponent below the infinity/nan one, all significand bits set
    auto d = cf.max;
    assert(d.sign == 0);
    assert(d.exponent == 2);
    assert(d.significand == 31);

    auto e = cf.min_normal;
    assert(e.sign == 0);
    assert(e.exponent == 1);
    assert(e.significand == 0);

    assert(e.re == e);
    assert(e.im == cf(0.0));
}
702*b1e83836Smrg
// check whether CustomFloats identical to float/double behave like float/double
@safe unittest
{
    import std.conv : to;

    // 23 significand bits and 8 exponent bits: the parameters used for 32 bits
    alias myFloat = CustomFloat!(23, 8);

    static assert(myFloat.dig == float.dig);
    static assert(myFloat.mant_dig == float.mant_dig);
    assert(myFloat.max_10_exp == float.max_10_exp);
    static assert(myFloat.max_exp == float.max_exp);
    assert(myFloat.min_10_exp == float.min_10_exp);
    static assert(myFloat.min_exp == float.min_exp);
    assert(to!float(myFloat.epsilon) == float.epsilon);
    assert(to!float(myFloat.max) == float.max);
    assert(to!float(myFloat.min_normal) == float.min_normal);

    // 52 significand bits and 11 exponent bits: the parameters used for 64 bits
    alias myDouble = CustomFloat!(52, 11);

    static assert(myDouble.dig == double.dig);
    static assert(myDouble.mant_dig == double.mant_dig);
    assert(myDouble.max_10_exp == double.max_10_exp);
    static assert(myDouble.max_exp == double.max_exp);
    assert(myDouble.min_10_exp == double.min_10_exp);
    static assert(myDouble.min_exp == double.min_exp);
    assert(to!double(myDouble.epsilon) == double.epsilon);
    assert(to!double(myDouble.max) == double.max);
    assert(to!double(myDouble.min_normal) == double.min_normal);
}
732*b1e83836Smrg
// testing .dig
@safe unittest
{
    // Default flags (stored normalized): all precision bits count
    static assert(CustomFloat!(1, 6).dig == 0);
    static assert(CustomFloat!(9, 6).dig == 2);
    static assert(CustomFloat!(10, 5).dig == 3);
    // Flags none (not stored normalized): one bit fewer counts
    static assert(CustomFloat!(10, 6, CustomFloatFlags.none).dig == 2);
    static assert(CustomFloat!(11, 5, CustomFloatFlags.none).dig == 3);
    // 64 counted bits are special-cased to 19
    static assert(CustomFloat!(64, 7).dig == 19);
}
743*b1e83836Smrg
// .mant_dig with default flags and with no flags
@safe unittest
{
    alias WithDefaults = CustomFloat!(10, 5);
    alias WithoutFlags = CustomFloat!(10, 6, CustomFloatFlags.none);

    static assert(WithDefaults.mant_dig == 11);
    static assert(WithoutFlags.mant_dig == 10);
}
750*b1e83836Smrg
// .max_exp depends only on the exponent width in all of the checked layouts
@safe unittest
{
    enum flagsNone = CustomFloatFlags.none;
    enum flagsNan = CustomFloatFlags.nan;

    // 6 exponent bits
    static assert(2^^5 == CustomFloat!(1, 6).max_exp);
    static assert(2^^5 == CustomFloat!(2, 6, flagsNone).max_exp);
    static assert(2^^5 == CustomFloat!(2, 6, flagsNan).max_exp);

    // 10 exponent bits
    static assert(2^^9 == CustomFloat!(5, 10).max_exp);
    static assert(2^^9 == CustomFloat!(6, 10, flagsNone).max_exp);
    static assert(2^^9 == CustomFloat!(6, 10, flagsNan).max_exp);
}
761*b1e83836Smrg
// .min_exp for default flags, no flags, nan only and allowDenorm only
@safe unittest
{
    enum flagsNone = CustomFloatFlags.none;
    enum flagsNan = CustomFloatFlags.nan;
    enum flagsDenorm = CustomFloatFlags.allowDenorm;

    // 6 exponent bits
    static assert(CustomFloat!(1, 6).min_exp == -(2^^5) + 3);
    static assert(CustomFloat!(2, 6, flagsNone).min_exp == -(2^^5) + 1);
    static assert(CustomFloat!(2, 6, flagsNan).min_exp == -(2^^5) + 2);
    static assert(CustomFloat!(2, 6, flagsDenorm).min_exp == -(2^^5) + 2);

    // 10 exponent bits
    static assert(CustomFloat!(5, 10).min_exp == -(2^^9) + 3);
    static assert(CustomFloat!(6, 10, flagsNone).min_exp == -(2^^9) + 1);
    static assert(CustomFloat!(6, 10, flagsNan).min_exp == -(2^^9) + 2);
    static assert(CustomFloat!(6, 10, flagsDenorm).min_exp == -(2^^9) + 2);
}
774*b1e83836Smrg
// .max_10_exp (checked at run time)
@safe unittest
{
    enum flagsNone = CustomFloatFlags.none;
    enum flagsNan = CustomFloatFlags.nan;

    // 6 exponent bits
    assert(9 == CustomFloat!(1, 6).max_10_exp);
    assert(9 == CustomFloat!(2, 6, flagsNone).max_10_exp);
    assert(9 == CustomFloat!(2, 6, flagsNan).max_10_exp);

    // 10 exponent bits
    assert(154 == CustomFloat!(5, 10).max_10_exp);
    assert(154 == CustomFloat!(6, 10, flagsNone).max_10_exp);
    assert(154 == CustomFloat!(6, 10, flagsNan).max_10_exp);
}
785*b1e83836Smrg
// .min_10_exp (checked at run time)
@safe unittest
{
    enum flagsNone = CustomFloatFlags.none;
    enum flagsNan = CustomFloatFlags.nan;
    enum flagsDenorm = CustomFloatFlags.allowDenorm;

    // 6 exponent bits
    assert(-9 == CustomFloat!(1, 6).min_10_exp);
    assert(-9 == CustomFloat!(2, 6, flagsNone).min_10_exp);
    assert(-9 == CustomFloat!(2, 6, flagsNan).min_10_exp);
    assert(-9 == CustomFloat!(2, 6, flagsDenorm).min_10_exp);

    // 10 exponent bits
    assert(-153 == CustomFloat!(5, 10).min_10_exp);
    assert(-154 == CustomFloat!(6, 10, flagsNone).min_10_exp);
    assert(-153 == CustomFloat!(6, 10, flagsNan).min_10_exp);
    assert(-153 == CustomFloat!(6, 10, flagsDenorm).min_10_exp);
}
798*b1e83836Smrg
// .epsilon: sign, exponent and significand fields for five layouts
@safe unittest
{
    auto eps16 = CustomFloat!(1,6).epsilon;
    assert(eps16.sign == 0 && eps16.exponent == 30 && eps16.significand == 0);

    auto eps25 = CustomFloat!(2,5).epsilon;
    assert(eps25.sign == 0 && eps25.exponent == 13 && eps25.significand == 0);

    auto eps34 = CustomFloat!(3,4).epsilon;
    assert(eps34.sign == 0 && eps34.exponent == 4 && eps34.significand == 0);

    // the following epsilons are only available, when denormalized numbers are allowed:
    auto eps43 = CustomFloat!(4,3).epsilon;
    assert(eps43.sign == 0 && eps43.exponent == 0 && eps43.significand == 4);

    auto eps52 = CustomFloat!(5,2).epsilon;
    assert(eps52.sign == 0 && eps52.exponent == 0 && eps52.significand == 1);
}
819*b1e83836Smrg
// .max: sign, exponent and significand fields, all checked at compile time
@safe unittest
{
    alias F52 = CustomFloat!(5,2);
    alias F43 = CustomFloat!(4,3);
    alias F34 = CustomFloat!(3,4);
    alias F25 = CustomFloat!(2,5);
    alias F16 = CustomFloat!(1,6);
    alias Bare = CustomFloat!(3,5, CustomFloatFlags.none);

    static assert(F52.max.sign == 0 && F52.max.exponent == 2 && F52.max.significand == 31);
    static assert(F43.max.sign == 0 && F43.max.exponent == 6 && F43.max.significand == 15);
    static assert(F34.max.sign == 0 && F34.max.exponent == 14 && F34.max.significand == 7);
    static assert(F25.max.sign == 0 && F25.max.exponent == 30 && F25.max.significand == 3);
    static assert(F16.max.sign == 0 && F16.max.exponent == 62 && F16.max.significand == 1);

    // without any flags only exponent and significand are checked
    static assert(Bare.max.exponent == 31 && Bare.max.significand == 7);
}
841*b1e83836Smrg
// .min_normal: sign, exponent and significand fields, all checked at compile time
@safe unittest
{
    alias F52 = CustomFloat!(5,2);
    alias F43 = CustomFloat!(4,3);
    alias F34 = CustomFloat!(3,4);
    alias F25 = CustomFloat!(2,5);
    alias F16 = CustomFloat!(1,6);
    alias Bare = CustomFloat!(3,5, CustomFloatFlags.none);

    // with the default flags every layout expects exponent 1 and an empty significand
    static assert(F52.min_normal.sign == 0 && F52.min_normal.exponent == 1 && F52.min_normal.significand == 0);
    static assert(F43.min_normal.sign == 0 && F43.min_normal.exponent == 1 && F43.min_normal.significand == 0);
    static assert(F34.min_normal.sign == 0 && F34.min_normal.exponent == 1 && F34.min_normal.significand == 0);
    static assert(F25.min_normal.sign == 0 && F25.min_normal.exponent == 1 && F25.min_normal.significand == 0);
    static assert(F16.min_normal.sign == 0 && F16.min_normal.exponent == 1 && F16.min_normal.significand == 0);

    // without any flags only exponent and significand are checked
    static assert(Bare.min_normal.exponent == 0 && Bare.min_normal.significand == 4);
}
863*b1e83836Smrg
@safe unittest
{
    import std.math.traits : isNaN;

    alias Tiny = CustomFloat!(5, 2);

    // nan is still NaN after conversion to float
    assert(isNaN(Tiny.nan.get!float()));

    // assigning real.max yields infinity
    Tiny value;
    value = real.max;
    assert(value == Tiny.infinity);

    // 0.015625 is stored with exponent and significand both zero
    value = 0.015625;
    assert(value.exponent == 0 && value.significand == 0);

    // 0.984375 is stored with exponent 1 and an empty significand
    value = 0.984375;
    assert(value.exponent == 1 && value.significand == 0);
}
885*b1e83836Smrg
@system unittest
{
    import core.exception : AssertError;
    import std.exception : assertThrown;

    // with no flags set, assigning real.max is expected to raise an AssertError
    CustomFloat!(3, 5, CustomFloatFlags.none) value;
    assertThrown!AssertError(value = real.max);
}
896*b1e83836Smrg
@system unittest
{
    import core.exception : AssertError;
    import std.exception : assertThrown;

    // with only the nan flag set, assigning real.max is expected to raise an AssertError
    CustomFloat!(3, 5, CustomFloatFlags.nan) value;
    assertThrown!AssertError(value = real.max);
}
907*b1e83836Smrg
@system unittest
{
    import core.exception : AssertError;
    import std.exception : assertThrown;

    // with no flags set, assigning float.infinity is expected to raise an AssertError
    CustomFloat!(24, 8, CustomFloatFlags.none) value;
    assertThrown!AssertError(value = float.infinity);
}
918*b1e83836Smrg
/*
 * Checks whether the combination of mantissa width, exponent width and flags
 * passes all of the restrictions listed in the body.
 *
 * Params:
 *     precision = width of the stored mantissa in bits
 *     exponentWidth = width of the exponent in bits
 *     flags = format flags of the candidate type
 *
 * Returns: `true` if every restriction holds, `false` otherwise.
 */
private bool isCorrectCustomFloat(uint precision, uint exponentWidth, CustomFloatFlags flags) @safe pure nothrow @nogc
{
    // mantissa should have at least one bit
    if (precision == 0) return false;

    // exponent should have at least one bit, in some cases two.
    // This is tested before the shift below so that `exponentWidth - 1`
    // can never wrap around to uint.max.
    immutable bool needsSecondBit =
        (flags & (CustomFloatFlags.allowDenorm | CustomFloatFlags.infinity | CustomFloatFlags.nan)) != 0;
    immutable uint minExponentWidth = needsSecondBit ? 2 : 1;
    if (exponentWidth < minExponentWidth) return false;

    // Restrictions from bitfield
    // due to CustomFloat!80 support hack precision with 64 bits is handled specially
    auto length = (flags & CustomFloatFlags.signed) + exponentWidth + ((precision == 64) ? 0 : precision);
    if (length != 8 && length != 16 && length != 32 && length != 64) return false;

    // mantissa needs to fit into real mantissa
    if (precision > real.mant_dig - 1 && precision != 64) return false;

    // exponent needs to fit into real exponent.
    // Widths above 63 are rejected explicitly: `1L << 63` is negative and
    // would slip through the comparison, larger counts are not valid shifts.
    if (exponentWidth > 63 || (1L << (exponentWidth - 1)) > real.max_exp) return false;

    return true;
}
940*b1e83836Smrg
@safe pure nothrow @nogc unittest
{
    enum ieee = CustomFloatFlags.ieee;
    enum none = CustomFloatFlags.none;

    // accepted layouts
    assert(isCorrectCustomFloat(3, 4, ieee));
    assert(isCorrectCustomFloat(3, 5, none));
    assert(isCorrectCustomFloat(64, 7, ieee));
    assert(isCorrectCustomFloat(7, 1, none));

    // rejected layouts
    assert(!isCorrectCustomFloat(3, 3, ieee));
    assert(!isCorrectCustomFloat(64, 4, ieee));
    assert(!isCorrectCustomFloat(508, 3, ieee));
    assert(!isCorrectCustomFloat(3, 100, ieee));
    assert(!isCorrectCustomFloat(0, 7, ieee));
    assert(!isCorrectCustomFloat(6, 1, ieee));
    assert(!isCorrectCustomFloat(8, 0, none));
}
955*b1e83836Smrg
/**
Defines the fastest type to use when storing temporaries of a
calculation intended to ultimately yield a result of type `F`
(where `F` must be one of `float`, `double`, or `real`). When doing a
multi-step computation, you may want to store intermediate results as
`FPTemporary!F`.

The necessity of `FPTemporary` stems from the optimized
floating-point operations and registers present in virtually all
processors. When adding numbers in the example below, the addition may
in fact be done in `real` precision internally. In that case,
storing the intermediate `result` in `double` format is not only
less precise, it is also (surprisingly) slower, because a conversion
from `real` to `double` is performed every pass through the
loop. This being a lose-lose situation, `FPTemporary!F` has been
defined as the $(I fastest) type to use for calculations at precision
`F`. There is no need to define a type for the $(I most accurate)
calculations, as that is always `real`.

Finally, there is no guarantee that using `FPTemporary!F` will
always be fastest, as the speed of floating-point calculations depends
on very many factors.
*/
template FPTemporary(F)
if (isFloatingPoint!F)
{
    version (X86)
    {
        // X86 targets: temporaries are always kept as real
        alias FPTemporary = real;
    }
    else
    {
        // everywhere else: F itself, stripped of its qualifiers
        alias FPTemporary = Unqual!F;
    }
}
987181254a7Smrg
///
@safe unittest
{
    import std.math.operations : isClose;

    // Average numbers in an array
    double avg(in double[] values)
    {
        if (!values.length)
            return 0;

        // accumulate in the temporary type, convert only once on return
        FPTemporary!double sum = 0;
        foreach (i; 0 .. values.length)
            sum += values[i];
        return sum / values.length;
    }

    auto samples = [1.0, 2.0, 3.0];
    assert(isClose(avg(samples), 2));
}
1005181254a7Smrg
/**
Implements the $(HTTP tinyurl.com/2zb9yr, secant method) for finding a
root of the function `fun` starting from points $(D [xn_1, xn])
(ideally close to the root). `Num` may be `float`, `double`,
or `real`.
*/
template secantMethod(alias fun)
{
    import std.functional : unaryFun;
    Num secantMethod(Num)(Num xn_1, Num xn)
    {
        // function value at the newest point and the step between the two newest points
        auto fCurrent = unaryFun!(fun)(xn_1);
        auto step = xn_1 - xn;
        typeof(fCurrent) fPrevious;

        xn = xn_1;
        for (;;)
        {
            // stop once the step is within 1e-5 of zero or is no longer finite
            if (isClose(step, 0, 0.0, 1e-5) || !isFinite(step))
                break;

            xn_1 = xn;
            xn -= step;
            fPrevious = fCurrent;
            fCurrent = unaryFun!(fun)(xn);
            step *= -fCurrent / (fCurrent - fPrevious);
        }
        return xn;
    }
}
1032181254a7Smrg
///
@safe unittest
{
    import std.math.operations : isClose;
    import std.math.trigonometry : cos;

    // function whose root is searched between 0 and 1
    float f(float x)
    {
        immutable cube = x*x*x;
        return cos(x) - cube;
    }

    auto root = secantMethod!(f)(0f, 1f);
    assert(isClose(root, 0.865474));
}
1046181254a7Smrg
@system unittest
{
    // @system because of __gshared stderr
    import std.stdio;
    scope(failure) stderr.writeln("Failure testing secantMethod");

    float f(float x)
    {
        immutable cube = x*x*x;
        return cos(x) - cube;
    }

    // pass the nested function directly ...
    immutable viaFunction = secantMethod!(f)(0f, 1f);
    assert(isClose(viaFunction, 0.865474));

    // ... and through a delegate variable
    auto d = &f;
    immutable viaDelegate = secantMethod!(d)(0f, 1f);
    assert(isClose(viaDelegate, 0.865474));
}
1062181254a7Smrg
1063181254a7Smrg
/**
 * Reports whether the sign bits of a and b differ.
 */
private bool oppositeSigns(T1, T2)(T1 a, T2 b)
{
    immutable signOfA = signbit(a);
    immutable signOfB = signbit(b);
    return signOfA != signOfB;
}
1071181254a7Smrg
1072181254a7Smrg public:
1073181254a7Smrg
/** Find a real root of a real function f(x) via bracketing.
 *
 * Given a function `f` and a range `[a .. b]` such that `f(a)`
 * and `f(b)` have opposite signs or at least one of them equals ±0,
 * returns the value of `x` in
 * the range which is closest to a root of `f(x)`. If `f(x)`
 * has more than one root in the range, one will be chosen
 * arbitrarily. If `f(x)` returns NaN, NaN will be returned;
 * otherwise, this algorithm is guaranteed to succeed.
 *
 * Uses an algorithm based on TOMS748, which uses inverse cubic
 * interpolation whenever possible, otherwise reverting to parabolic
 * or secant interpolation. Compared to TOMS748, this implementation
 * improves worst-case performance by a factor of more than 100, and
 * typical performance by a factor of 2. For 80-bit reals, most
 * problems require 8 to 15 calls to `f(x)` to achieve full machine
 * precision. The worst-case performance (pathological cases) is
 * approximately twice the number of bits.
 *
 * References: "On Enclosing Simple Roots of Nonlinear Equations",
 * G. Alefeld, F.A. Potra, Yixun Shi, Mathematics of Computation 61,
 * pp733-744 (1993).  Fortran code available from $(HTTP
 * www.netlib.org,www.netlib.org) as algorithm TOMS748.
 *
 */
T findRoot(T, DF, DT)(scope DF f, const T a, const T b,
    scope DT tolerance) //= (T a, T b) => false)
if (
    isFloatingPoint!T &&
    is(typeof(tolerance(T.init, T.init)) : bool) &&
    is(typeof(f(T.init)) == R, R) && isFloatingPoint!R
    )
{
    // An endpoint that already evaluates to zero is returned immediately;
    // f(b) is only evaluated when f(a) is not zero.
    immutable valueAtA = f(a);
    if (valueAtA == 0)
        return a;
    immutable valueAtB = f(b);
    if (valueAtB == 0)
        return b;

    immutable bounds = findRoot(f, a, b, valueAtA, valueAtB, tolerance);

    // Return the first value if it is smaller or NaN
    if (fabs(bounds[2]) > fabs(bounds[3]))
        return bounds[1];
    return bounds[0];
}
1117181254a7Smrg
///ditto
T findRoot(T, DF)(scope DF f, const T a, const T b)
{
    // a tolerance that never accepts the bounds
    auto neverAccept = (T lo, T hi) => false;
    return findRoot(f, a, b, neverAccept);
}
1123181254a7Smrg
/** Find root of a real function f(x) by bracketing, allowing the
 * termination condition to be specified.
 *
 * Params:
 *
 * f = Function to be analyzed
 *
 * ax = Left bound of initial range of `f` known to contain the
 *      root.
 *
 * bx = Right bound of initial range of `f` known to contain the
 *      root. `ax` and `bx` may also be given in reverse order.
 *
 * fax = Value of `f(ax)`.
 *
 * fbx = Value of `f(bx)`. `fax` and `fbx` should have opposite signs.
 *       (`f(ax)` and `f(bx)` are commonly known in advance.)
 *
 *
 * tolerance = Defines an early termination condition. Receives the
 *             current upper and lower bounds on the root. The
 *             delegate must return `true` when these bounds are
 *             acceptable. If this function always returns `false`,
 *             full machine precision will be achieved.
 *
 * Returns:
 *
 * A tuple consisting of two ranges. The first two elements are the
 * range (in `x`) of the root, while the second pair of elements
 * are the corresponding function values at those points. If `fax` or
 * `fbx` is already 0 (or NaN), both of the first two elements contain
 * that endpoint and both of the last two elements its function value.
 * If an exact root (or NaN) is hit while iterating, it is stored in the
 * first element and its function value in the third; the second and
 * fourth elements keep the last upper bound and its function value.
 */
Tuple!(T, T, R, R) findRoot(T, R, DF, DT)(scope DF f,
    const T ax, const T bx, const R fax, const R fbx,
    scope DT tolerance) // = (T a, T b) => false)
if (
    isFloatingPoint!T &&
    is(typeof(tolerance(T.init, T.init)) : bool) &&
    is(typeof(f(T.init)) == R) && isFloatingPoint!R
    )
in
{
    assert(!ax.isNaN() && !bx.isNaN(), "Limits must not be NaN");
    assert(signbit(fax) != signbit(fbx), "Parameters must bracket the root.");
}
do
{
    // Author: Don Clugston. This code is (heavily) modified from TOMS748
    // (www.netlib.org). The changes to improve the worst-case performance are
    // entirely original.

    T a, b, d;  // [a .. b] is our current bracket. d is the third best guess.
    R fa, fb, fd; // Values of f at a, b, d.
    bool done = false; // Has a root been found?

    // Allow ax and bx to be provided in reverse order
    if (ax <= bx)
    {
        a = ax; fa = fax;
        b = bx; fb = fbx;
    }
    else
    {
        a = bx; fa = fbx;
        b = ax; fb = fax;
    }

    // Test the function at point c; update brackets accordingly
    void bracket(T c)
    {
        R fc = f(c);
        if (fc == 0 || fc.isNaN()) // Exact solution, or NaN
        {
            a = c;
            fa = fc;
            d = c;
            fd = fc;
            done = true;
            return;
        }

        // Determine new enclosing interval
        if (signbit(fa) != signbit(fc))
        {
            d = b;
            fd = fb;
            b = c;
            fb = fc;
        }
        else
        {
            d = a;
            fd = fa;
            a = c;
            fa = fc;
        }
    }

    /* Perform a secant interpolation. If the result would lie on a or b, or if
     a and b differ so wildly in magnitude that the result would be meaningless,
     perform a bisection instead.
    */
    static T secant_interpolate(T a, T b, R fa, R fb)
    {
        if (( ((a - b) == a) && b != 0) || (a != 0 && ((b - a) == b)))
        {
            // Catastrophic cancellation
            if (a == 0)
                a = copysign(T(0), b);
            else if (b == 0)
                b = copysign(T(0), a);
            else if (signbit(a) != signbit(b))
                return 0;
            T c = ieeeMean(a, b);
            return c;
        }
        // avoid overflow
        if (b - a > T.max)
            return b / 2 + a / 2;
        // fb - fa overflowed, so the secant weight cannot be computed:
        // bisect. The result must stay inside [a .. b], hence `a +`;
        // `a - (b - a) / 2` would lie below a, outside the bracket.
        if (fb - fa > R.max)
            return a + (b - a) / 2;
        T c = a - (fa / (fb - fa)) * (b - a);
        if (c == a || c == b)
            return (a + b) / 2;
        return c;
    }

    /* Uses 'numsteps' newton steps to approximate the zero in [a .. b] of the
       quadratic polynomial interpolating f(x) at a, b, and d.
       Returns:
         The approximate zero in [a .. b] of the quadratic polynomial.
    */
    T newtonQuadratic(int numsteps)
    {
        // Find the coefficients of the quadratic polynomial.
        immutable T a0 = fa;
        immutable T a1 = (fb - fa)/(b - a);
        immutable T a2 = ((fd - fb)/(d - b) - a1)/(d - a);

        // Determine the starting point of newton steps:
        // a if a2 and fa have opposite signs, b otherwise.
        T c = (signbit(a2) != signbit(fa)) ? a : b;

        // start the safeguarded newton steps.
        foreach (int i; 0 .. numsteps)
        {
            immutable T pc = a0 + (a1 + a2 * (c - b))*(c - a);
            immutable T pdc = a1 + a2*((2 * c) - (a + b));
            if (pdc == 0)
                return a - a0 / a1;
            else
                c = c - pc / pdc;
        }
        return c;
    }

    // On the first iteration we take a secant step:
    if (fa == 0 || fa.isNaN())
    {
        done = true;
        b = a;
        fb = fa;
    }
    else if (fb == 0 || fb.isNaN())
    {
        done = true;
        a = b;
        fa = fb;
    }
    else
    {
        bracket(secant_interpolate(a, b, fa, fb));
    }

    // Starting with the second iteration, higher-order interpolation can
    // be used.
    int itnum = 1;   // Iteration number
    int baditer = 1; // Num bisections to take if an iteration is bad.
    T c, e;  // e is our fourth best guess
    R fe;

whileloop:
    while (!done && (b != nextUp(a)) && !tolerance(a, b))
    {
        T a0 = a, b0 = b; // record the brackets

        // Do two higher-order (cubic or parabolic) interpolation steps.
        foreach (int QQ; 0 .. 2)
        {
            // Cubic inverse interpolation requires that
            // all four function values fa, fb, fd, and fe are distinct;
            // otherwise use quadratic interpolation.
            bool distinct = (fa != fb) && (fa != fd) && (fa != fe)
                         && (fb != fd) && (fb != fe) && (fd != fe);
            // The first time, cubic interpolation is impossible.
            if (itnum<2) distinct = false;
            bool ok = distinct;
            if (distinct)
            {
                // Cubic inverse interpolation of f(x) at a, b, d, and e
                immutable q11 = (d - e) * fd / (fe - fd);
                immutable q21 = (b - d) * fb / (fd - fb);
                immutable q31 = (a - b) * fa / (fb - fa);
                immutable d21 = (b - d) * fd / (fd - fb);
                immutable d31 = (a - b) * fb / (fb - fa);

                immutable q22 = (d21 - q11) * fb / (fe - fb);
                immutable q32 = (d31 - q21) * fa / (fd - fa);
                immutable d32 = (d31 - q21) * fd / (fd - fa);
                immutable q33 = (d32 - q22) * fa / (fe - fa);
                c = a + (q31 + q32 + q33);
                if (c.isNaN() || (c <= a) || (c >= b))
                {
                    // DAC: If the interpolation predicts a or b, it's
                    // probable that it's the actual root. Only allow this if
                    // we're already close to the root.
                    if (c == a && a - b != a)
                    {
                        c = nextUp(a);
                    }
                    else if (c == b && a - b != -b)
                    {
                        c = nextDown(b);
                    }
                    else
                    {
                        ok = false;
                    }
                }
            }
            if (!ok)
            {
                // DAC: Alefeld doesn't explain why the number of newton steps
                // should vary.
                c = newtonQuadratic(distinct ? 3 : 2);
                if (c.isNaN() || (c <= a) || (c >= b))
                {
                    // Failure, try a secant step:
                    c = secant_interpolate(a, b, fa, fb);
                }
            }
            ++itnum;
            e = d;
            fe = fd;
            bracket(c);
            if (done || ( b == nextUp(a)) || tolerance(a, b))
                break whileloop;
            if (itnum == 2)
                continue whileloop;
        }

        // Now we take a double-length secant step:
        T u;
        R fu;
        if (fabs(fa) < fabs(fb))
        {
            u = a;
            fu = fa;
        }
        else
        {
            u = b;
            fu = fb;
        }
        c = u - 2 * (fu / (fb - fa)) * (b - a);

        // DAC: If the secant predicts a value equal to an endpoint, it's
        // probably false.
        if (c == a || c == b || c.isNaN() || fabs(c - u) > (b - a) / 2)
        {
            if ((a-b) == a || (b-a) == b)
            {
                if ((a>0 && b<0) || (a<0 && b>0))
                    c = 0;
                else
                {
                    if (a == 0)
                        c = ieeeMean(copysign(T(0), b), b);
                    else if (b == 0)
                        c = ieeeMean(copysign(T(0), a), a);
                    else
                        c = ieeeMean(a, b);
                }
            }
            else
            {
                c = a + (b - a) / 2;
            }
        }
        e = d;
        fe = fd;
        bracket(c);
        if (done || (b == nextUp(a)) || tolerance(a, b))
            break;

        // IMPROVE THE WORST-CASE PERFORMANCE
        // We must ensure that the bounds reduce by a factor of 2
        // in binary space! every iteration. If we haven't achieved this
        // yet, or if we don't yet know what the exponent is,
        // perform a binary chop.

        if ((a == 0 || b == 0 ||
            (fabs(a) >= T(0.5) * fabs(b) && fabs(b) >= T(0.5) * fabs(a)))
            &&  (b - a) < T(0.25) * (b0 - a0))
        {
            baditer = 1;
            continue;
        }

        // DAC: If this happens on consecutive iterations, we probably have a
        // pathological function. Perform a number of bisections equal to the
        // total number of consecutive bad iterations.

        if ((b - a) < T(0.25) * (b0 - a0))
            baditer = 1;
        foreach (int QQ; 0 .. baditer)
        {
            e = d;
            fe = fd;

            T w;
            if ((a>0 && b<0) || (a<0 && b>0))
                w = 0;
            else
            {
                T usea = a;
                T useb = b;
                if (a == 0)
                    usea = copysign(T(0), b);
                else if (b == 0)
                    useb = copysign(T(0), a);
                w = ieeeMean(usea, useb);
            }
            bracket(w);
        }
        ++baditer;
    }
    return Tuple!(T, T, R, R)(a, b, fa, fb);
}
1463181254a7Smrg
///ditto
Tuple!(T, T, R, R) findRoot(T, R, DF)(scope DF f,
    const T ax, const T bx, const R fax, const R fbx)
{
    // No caller-supplied stopping rule: forward with a predicate that always
    // answers `false`, so the tolerance check never ends the search early.
    return findRoot(f, ax, bx, fax, fbx, (T lo, T hi) { return false; });
}
1470181254a7Smrg
///ditto
T findRoot(T, R)(scope R delegate(T) f, const T a, const T b,
    scope bool delegate(T lo, T hi) tolerance = (T a, T b) => false)
{
    // Name the delegate types once and pass them as explicit template
    // arguments to the findRoot instantiation that does the actual work.
    alias Func = R delegate(T);
    alias StopRule = bool delegate(T lo, T hi);
    return findRoot!(T, Func, StopRule)(f, a, b, tolerance);
}
1477181254a7Smrg
@safe nothrow unittest
{
    int numProblems = 0;
    int numCalls;

    // Runs findRoot on `f` over [x1, x2] with a tolerance predicate that never
    // stops early, then checks that the returned bracket still has opposite
    // signs at its ends (unless f is exactly zero at the lower end).
    void testFindRoot(real delegate(real) @nogc @safe nothrow pure f , real x1, real x2) @nogc @safe nothrow pure
    {
        //numCalls=0;
        //++numProblems;
        assert(!(x1.isNaN() || x2.isNaN()));
        assert(signbit(f(x1)) != signbit(f(x2)));
        auto bracket = findRoot(f, x1, x2, f(x1), f(x2),
            (real lo, real hi) => false);

        real atLow = f(bracket[0]);
        real atHigh = f(bracket[1]);
        assert(atLow == 0 || oppositeSigns(atLow, atHigh));
    }

    // Test functions
    real cubicfn(real x) @nogc @safe nothrow pure
    {
        //++numCalls;
        // Clamp the argument into the finite float range.
        x = x > float.max ? float.max : x;
        x = x < -float.max ? -float.max : x;
        // This has a single real root at -59.286543284815
        return 0.386*x*x*x + 23*x*x + 15.7*x + 525.2;
    }
    // Test a function with more than one root.
    real multisine(real x)
    {
        ++numCalls;
        return sin(x);
    }
    testFindRoot(&multisine, 6, 90);
    testFindRoot(&cubicfn, -100, 100);
    testFindRoot(&cubicfn, -double.max, real.max);


    /* Tests from the paper:
     * "On Enclosing Simple Roots of Nonlinear Equations", G. Alefeld, F.A. Potra,
     * Yixun Shi, Mathematics of Computation 61, pp733-744 (1993).
     */
    // Parameters common to many alefeld tests.
    int n;
    real ale_a, ale_b;

    // For every order in `orders`: store it in `n`, then bracket a root of
    // `f` on [lo, hi].
    void sweep(real delegate(real) @nogc @safe nothrow pure f, int[] orders, real lo, real hi)
    {
        foreach (order; orders)
        {
            n = order;
            testFindRoot(f, lo, hi);
        }
    }

    int powercalls = 0;

    real power(real x)
    {
        ++powercalls;
        ++numCalls;
        return pow(x, n) + double.min_normal;
    }
    int [] power_nvals = [3, 5, 7, 9, 19, 25];
    // Alefeld paper states that pow(x,n) is a very poor case, where bisection
    // outperforms his method, and gives total numcalls =
    // 921 for bisection (2.4 calls per bit), 1830 for Alefeld (4.76/bit),
    // 2624 for brent (6.8/bit)
    // ... but that is for double, not real80.
    // This poor performance seems mainly due to catastrophic cancellation,
    // which is avoided here by the use of ieeeMean().
    // I get: 231 (0.48/bit).
    // IE this is 10X faster in Alefeld's worst case
    numProblems = 0;
    sweep(&power, power_nvals, -1, 10);

    int powerProblems = numProblems;

    // Tests from Alefeld paper

    int [9] alefeldSums;
    real alefeld0(real x)
    {
        ++alefeldSums[0];
        ++numCalls;
        real q = sin(x) - x/2;
        foreach (int i; 1 .. 20)
            q+=(2*i-5.0)*(2*i-5.0)/((x-i*i)*(x-i*i)*(x-i*i));
        return q;
    }
    real alefeld1(real x)
    {
        ++numCalls;
        ++alefeldSums[1];
        return ale_a*x + exp(ale_b * x);
    }
    real alefeld2(real x)
    {
        ++numCalls;
        ++alefeldSums[2];
        return pow(x, n) - ale_a;
    }
    real alefeld3(real x)
    {
        ++numCalls;
        ++alefeldSums[3];
        return (1.0 +pow(1.0L-n, 2))*x - pow(1.0L-n*x, 2);
    }
    real alefeld4(real x)
    {
        ++numCalls;
        ++alefeldSums[4];
        return x*x - pow(1-x, n);
    }
    real alefeld5(real x)
    {
        ++numCalls;
        ++alefeldSums[5];
        return (1+pow(1.0L-n, 4))*x - pow(1.0L-n*x, 4);
    }
    real alefeld6(real x)
    {
        ++numCalls;
        ++alefeldSums[6];
        return exp(-n*x)*(x-1.01L) + pow(x, n);
    }
    real alefeld7(real x)
    {
        ++numCalls;
        ++alefeldSums[7];
        return (n*x-1)/((n-1)*x);
    }

    numProblems = 0;
    testFindRoot(&alefeld0, PI_2, PI);
    for (n = 1; n <= 10; ++n)
    {
        testFindRoot(&alefeld0, n*n+1e-9L, (n+1)*(n+1)-1e-9L);
    }

    // alefeld1 with three (ale_a, ale_b) parameter pairs on the same interval.
    ale_a = -40;
    ale_b = -1;
    testFindRoot(&alefeld1, -9, 31);
    ale_a = -100;
    ale_b = -2;
    testFindRoot(&alefeld1, -9, 31);
    ale_a = -200;
    ale_b = -3;
    testFindRoot(&alefeld1, -9, 31);

    int [] nvals_3 = [1, 2, 5, 10, 15, 20];
    int [] nvals_5 = [1, 2, 4, 5, 8, 15, 20];
    int [] nvals_6 = [1, 5, 10, 15, 20];
    int [] nvals_7 = [2, 5, 15, 20];

    // alefeld2 for the even orders 4, 6, 8 and 10.
    for (int order = 4; order < 12; order += 2)
    {
        n = order;
        ale_a = 0.2;
        testFindRoot(&alefeld2, 0, 5);
        ale_a = 1;
        testFindRoot(&alefeld2, 0.95, 4.05);
        testFindRoot(&alefeld2, 0, 1.5);
    }
    sweep(&alefeld3, nvals_3, 0, 1);
    sweep(&alefeld4, nvals_3, 0, 1);
    sweep(&alefeld5, nvals_5, 0, 1);
    sweep(&alefeld6, nvals_6, 0, 1);
    sweep(&alefeld7, nvals_7, 0.01L, 1);

    // Step function over the whole finite range of real.
    real worstcase(real x)
    {
        ++numCalls;
        return x<0.3*real.max? -0.999e-3 : 1.0;
    }
    testFindRoot(&worstcase, -real.max, real.max);

    // just check that the double + float cases compile
    findRoot((double x){ return 0.0; }, -double.max, double.max);
    findRoot((float x){ return 0.0f; }, -float.max, float.max);

/*
    int grandtotal=0;
    foreach (calls; alefeldSums)
    {
        grandtotal+=calls;
    }
    grandtotal-=2*numProblems;
    printf("\nALEFELD TOTAL = %d avg = %f (alefeld avg=19.3 for double)\n",
    grandtotal, (1.0*grandtotal)/numProblems);
    powercalls -= 2*powerProblems;
    printf("POWER TOTAL = %d avg = %f ", powercalls,
        (1.0*powercalls)/powerProblems);
*/
    // https://issues.dlang.org/show_bug.cgi?id=14231
    auto xp = findRoot((float x) => x, 0f, 1f);
    auto xn = findRoot((float x) => x, -1f, -0f);
}
1687181254a7Smrg
//regression control
@system unittest
{
    // @system due to the case in the 2nd line

    // float arguments, callback returning real
    enum bool floatToReal =
        __traits(compiles, findRoot((float x)=>cast(real) x, float.init, float.init));
    // explicit `real` instantiation with an untyped lambda parameter
    enum bool explicitReal =
        __traits(compiles, findRoot!real((x)=>cast(double) x, real.init, real.init));
    // real arguments, callback returning double
    enum bool realToDouble =
        __traits(compiles, findRoot((real x)=>cast(double) x, real.init, real.init));

    static assert(floatToReal);
    static assert(explicitReal);
    static assert(realToDouble);
}
1696181254a7Smrg
/++
Find a real minimum of a real function `f(x)` via bracketing.
Given a function `f` and a range `(ax .. bx)`,
returns the value of `x` in the range which is closest to a minimum of `f(x)`.
`f` is never evaluated at the endpoints of `ax` and `bx`.
If `f(x)` has more than one minimum in the range, one will be chosen arbitrarily.
If `f(x)` returns NaN or -Infinity, `(x, f(x), NaN)` will be returned;
otherwise, this algorithm is guaranteed to succeed.

Params:
    f = Function to be analyzed
    ax = Left bound of initial range of f known to contain the minimum.
    bx = Right bound of initial range of f known to contain the minimum.
    relTolerance = Relative tolerance.
    absTolerance = Absolute tolerance.

Note:
    The step tolerance used at each iteration is computed as
    `absTolerance * fabs(x) + relTolerance`, i.e. `absTolerance` is the factor
    scaled by the current abscissa and `relTolerance` is the constant term.

Preconditions:
    `ax` and `bx` shall be finite reals. $(BR)
    `relTolerance` shall be normal positive real. $(BR)
    `absTolerance` shall be normal positive real no less than `T.epsilon*2`.

Returns:
    A tuple consisting of `x`, `y = f(x)` and `error = 3 * (absTolerance * fabs(x) + relTolerance)`.
    If the bounds are swapped (`ax > bx`) they are reordered internally.

    The method used is a combination of golden section search and
successive parabolic interpolation. Convergence is never much slower
than that for a Fibonacci search.

References:
    "Algorithms for Minimization without Derivatives", Richard Brent, Prentice-Hall, Inc. (1973)

See_Also: $(LREF findRoot), $(REF isNormal, std,math)
+/
Tuple!(T, "x", Unqual!(ReturnType!DF), "y", T, "error")
findLocalMin(T, DF)(
        scope DF f,
        const T ax,
        const T bx,
        const T relTolerance = sqrt(T.epsilon),
        const T absTolerance = sqrt(T.epsilon),
        )
if (isFloatingPoint!T
    && __traits(compiles, {T _ = DF.init(T.init);}))
in
{
    assert(isFinite(ax), "ax is not finite");
    assert(isFinite(bx), "bx is not finite");
    assert(isNormal(relTolerance), "relTolerance is not normal floating point number");
    assert(isNormal(absTolerance), "absTolerance is not normal floating point number");
    // This check is about relTolerance, so the message must name it.
    assert(relTolerance >= 0, "relTolerance is not positive");
    assert(absTolerance >= T.epsilon*2, "absTolerance is less than `2*T.epsilon`");
}
out (result)
{
    assert(isFinite(result.x));
}
do
{
    alias R = Unqual!(CommonType!(ReturnType!DF, T));
    // c is the squared inverse of the golden ratio
    // (3 - sqrt(5))/2
    // Value obtained from Wolfram Alpha.
    enum T c = 0x0.61c8864680b583ea0c633f9fa31237p+0L;
    enum T cm1 = 0x0.9e3779b97f4a7c15f39cc0605cedc8p+0L;
    R tolerance;
    // Order the bounds so that a <= b regardless of argument order.
    T a = ax > bx ? bx : ax;
    T b = ax > bx ? ax : bx;
    // sequence of declarations suitable for SIMD instructions
    // First probe: the golden-section point of [a, b].
    T v = a * cm1 + b * c;
    assert(isFinite(v));
    R fv = f(v);
    if (isNaN(fv) || fv == -T.infinity)
    {
        // Nothing to minimize: report the offending point with a NaN error.
        return typeof(return)(v, fv, T.init);
    }
    // x: best point so far, w: second best, v: previous value of w.
    T w = v;
    R fw = fv;
    T x = v;
    R fx = fv;
    // d: step taken in the current iteration, e: step of the one before last.
    for (R d = 0, e = 0;;)
    {
        T m = (a + b) / 2;
        // This fix is not part of the original algorithm
        if (!isFinite(m)) // fix infinity loop. Issue can be reproduced in R.
        {
            m = a / 2 + b / 2;
            if (!isFinite(m)) // fast-math compiler switch is enabled
            {
                //SIMD instructions can be used by compiler, do not reduce declarations
                int a_exp = void;
                int b_exp = void;
                immutable an = frexp(a, a_exp);
                immutable bn = frexp(b, b_exp);
                immutable am = ldexp(an, a_exp-1);
                immutable bm = ldexp(bn, b_exp-1);
                m = am + bm;
                if (!isFinite(m)) // wrong input: constraints are disabled in release mode
                {
                    return typeof(return).init;
                }
            }
        }
        tolerance = absTolerance * fabs(x) + relTolerance;
        immutable t2 = tolerance * 2;
        // check stopping criterion
        if (!(fabs(x - m) > t2 - (b - a) / 2))
        {
            break;
        }
        R p = 0;
        R q = 0;
        R r = 0;
        // fit parabola
        if (fabs(e) > tolerance)
        {
            immutable  xw =  x -  w;
            immutable fxw = fx - fw;
            immutable  xv =  x -  v;
            immutable fxv = fx - fv;
            immutable xwfxv = xw * fxv;
            immutable xvfxw = xv * fxw;
            p = xv * xvfxw - xw * xwfxv;
            q = (xvfxw - xwfxv) * 2;
            if (q > 0)
                p = -p;
            else
                q = -q;
            r = e;
            e = d;
        }
        T u;
        // a parabolic-interpolation step
        if (fabs(p) < fabs(q * r / 2) && p > q * (a - x) && p < q * (b - x))
        {
            d = p / q;
            u = x + d;
            // f must not be evaluated too close to a or b
            if (u - a < t2 || b - u < t2)
                d = x < m ? tolerance : -tolerance;
        }
        // a golden-section step
        else
        {
            e = (x < m ? b : a) - x;
            d = c * e;
        }
        // f must not be evaluated too close to x
        u = x + (fabs(d) >= tolerance ? d : d > 0 ? tolerance : -tolerance);
        immutable fu = f(u);
        if (isNaN(fu) || fu == -T.infinity)
        {
            return typeof(return)(u, fu, T.init);
        }
        //  update  a, b, v, w, and x
        if (fu <= fx)
        {
            // u is the new best point; the old best becomes a bracket end.
            (u < x ? b : a) = x;
            v = w; fv = fw;
            w = x; fw = fx;
            x = u; fx = fu;
        }
        else
        {
            // u is worse than x; it only tightens the bracket.
            (u < x ? a : b) = u;
            if (fu <= fw || w == x)
            {
                v = w; fv = fw;
                w = u; fw = fu;
            }
            else if (fu <= fv || v == x || v == w)
            { // do not remove this braces
                v = u; fv = fu;
            }
        }
    }
    return typeof(return)(x, fx, tolerance * 3);
}
1876181254a7Smrg
///
@safe unittest
{
    import std.math.operations : isClose;

    // A parabola whose minimum value 0 is reached at x = 4.
    auto parabola = (double x) => (x-4)^^2;
    auto found = findLocalMin(parabola, -1e7, 1e7);
    assert(isClose(found.x, 4.0));
    assert(isClose(found.y, 0.0, 0.0, 1e-10));
}
1886181254a7Smrg
@safe unittest
{
    import std.meta : AliasSeq;
    static foreach (T; AliasSeq!(double, float, real))
    {
        // Parabola with default tolerances.
        {
            auto parabola = findLocalMin!T((T x) => (x-4)^^2, T.min_normal, 1e7);
            assert(isClose(parabola.x, T(4)));
            assert(isClose(parabola.y, T(0), 0.0, T.epsilon));
        }
        // |x - 1| over a huge interval with the tightest allowed tolerances.
        {
            auto kink = findLocalMin!T((T x) => fabs(x-1), -T.max/4, T.max/4, T.min_normal, 2*T.epsilon);
            assert(kink.x.isClose(T(1)));
            assert(kink.y.isClose(T(0), 0.0, T.epsilon));
            assert(kink.error <= 10 * T.epsilon);
        }
        // A function returning NaN: x stays a number, y and error are NaN.
        {
            auto nanCase = findLocalMin!T((T x) => T.init, 0, 1, T.min_normal, 2*T.epsilon);
            assert(!isNaN(nanCase.x));
            assert(isNaN(nanCase.y));
            assert(isNaN(nanCase.error));
        }
        // log(x) on [0, 1]: the result sits within `error` of the left bound.
        {
            auto logUnit = findLocalMin!T((T x) => log(x), 0, 1, T.min_normal, 2*T.epsilon);
            assert(logUnit.error < 3.00001 * ((2*T.epsilon)*fabs(logUnit.x)+ T.min_normal));
            assert(logUnit.x >= 0 && logUnit.x <= logUnit.error);
        }
        // log(x) on [0, T.max].
        {
            auto logWide = findLocalMin!T((T x) => log(x), 0, T.max, T.min_normal, 2*T.epsilon);
            assert(logWide.y < -18);
            assert(logWide.error < 5e-08);
            assert(logWide.x >= 0 && logWide.x <= logWide.error);
        }
        // -|x| on [-1, 1]: the minimum lies at one of the two ends.
        {
            auto tent = findLocalMin!T((T x) => -fabs(x), -1, 1, T.min_normal, 2*T.epsilon);
            assert(isClose(fabs(tent.x), T(1)));
            assert(isClose(fabs(tent.y), T(1)));
            assert(isClose(tent.error, T(0), 0.0, 100*T.epsilon));
        }
    }
}
1928181254a7Smrg
/**
Computes $(LINK2 https://en.wikipedia.org/wiki/Euclidean_distance,
Euclidean distance) between input ranges `a` and `b`, i.e. the square root
of the sum of squared element-wise differences. The two ranges must have
the same length: when both define `length` this is asserted up front,
otherwise it is asserted after the traversal that `b` is exhausted too.
The three-parameter version stops computation as soon as the distance is
greater than or equal to `limit` (this is useful to save computation if a
small distance is sought).
*/
CommonType!(ElementType!(Range1), ElementType!(Range2))
euclideanDistance(Range1, Range2)(Range1 a, Range2 b)
if (isInputRange!(Range1) && isInputRange!(Range2))
{
    enum bool bothHaveLength = hasLength!(Range1) && hasLength!(Range2);
    static if (bothHaveLength) assert(a.length == b.length);

    Unqual!(typeof(return)) sumOfSquares = 0;
    while (!a.empty)
    {
        immutable diff = a.front - b.front;
        sumOfSquares += diff * diff;
        a.popFront();
        b.popFront();
    }

    static if (!bothHaveLength) assert(b.empty);
    return sqrt(sumOfSquares);
}
1952181254a7Smrg
/// Ditto
CommonType!(ElementType!(Range1), ElementType!(Range2))
euclideanDistance(Range1, Range2, F)(Range1 a, Range2 b, F limit)
if (isInputRange!(Range1) && isInputRange!(Range2))
{
    // Compare squared quantities so no square root is needed inside the loop.
    limit *= limit;
    enum bool bothHaveLength = hasLength!(Range1) && hasLength!(Range2);
    static if (bothHaveLength) assert(a.length == b.length);

    Unqual!(typeof(return)) sumOfSquares = 0;
    while (!a.empty)
    {
        immutable diff = a.front - b.front;
        sumOfSquares += diff * diff;
        // Early exit: the partial distance already reached the limit.
        if (sumOfSquares >= limit)
            return sqrt(sumOfSquares);
        a.popFront();
        b.popFront();
    }

    // `a` ran out without hitting the limit; `b` must be exhausted as well.
    static if (!bothHaveLength) assert(b.empty);
    return sqrt(sumOfSquares);
}
1975181254a7Smrg
@safe unittest
{
    import std.meta : AliasSeq;
    static foreach (T; AliasSeq!(double, const double, immutable double))
    {{
        T[] p = [ 1.0, 2.0, ];
        T[] q = [ 4.0, 6.0, ];
        // Differences are (3, 4), so the full distance is 5.
        assert(euclideanDistance(p, q) == 5);
        // Limits of 4 and above are only reached after both terms are summed.
        assert(euclideanDistance(p, q, 6) == 5);
        assert(euclideanDistance(p, q, 5) == 5);
        assert(euclideanDistance(p, q, 4) == 5);
        // A limit of 2 is already met after the first term (3 * 3 >= 2 * 2).
        assert(euclideanDistance(p, q, 2) == 3);
    }}
}
1990181254a7Smrg
/**
Computes the $(LINK2 https://en.wikipedia.org/wiki/Dot_product,
dot product) of input ranges `a` and `b`. The two ranges must have the
same length. If both ranges define `length`, equality is asserted once
before the traversal; otherwise it is asserted after the traversal that
`b` has been exhausted together with `a`.
*/
CommonType!(ElementType!(Range1), ElementType!(Range2))
dotProduct(Range1, Range2)(Range1 a, Range2 b)
if (isInputRange!(Range1) && isInputRange!(Range2) &&
    !(isArray!(Range1) && isArray!(Range2)))
{
    enum bool bothHaveLength = hasLength!(Range1) && hasLength!(Range2);
    static if (bothHaveLength) assert(a.length == b.length);

    Unqual!(typeof(return)) total = 0;
    while (!a.empty)
    {
        total += a.front * b.front;
        a.popFront();
        b.popFront();
    }

    static if (!bothHaveLength) assert(b.empty);
    return total;
}
2013181254a7Smrg
/// Ditto
CommonType!(F1, F2)
dotProduct(F1, F2)(in F1[] avector, in F2[] bvector)
{
    immutable count = avector.length;
    assert(count == bvector.length);
    auto lhs = avector.ptr, rhs = bvector.ptr;
    // Two accumulators: one collects even offsets, the other odd offsets.
    Unqual!(typeof(return)) evenSum = 0, oddSum = 0;

    const stopAll = lhs + count;
    const stopBy4 = lhs + (count - count % 4);
    const stopBy16 = lhs + (count - count % 16);

    // Blocks of 16 elements, fully unrolled.
    while (lhs != stopBy16)
    {
        static foreach (pair; 0 .. 8)
        {
            evenSum += lhs[2 * pair] * rhs[2 * pair];
            oddSum += lhs[2 * pair + 1] * rhs[2 * pair + 1];
        }
        lhs += 16;
        rhs += 16;
    }

    // Remaining blocks of 4 elements.
    while (lhs != stopBy4)
    {
        static foreach (pair; 0 .. 2)
        {
            evenSum += lhs[2 * pair] * rhs[2 * pair];
            oddSum += lhs[2 * pair + 1] * rhs[2 * pair + 1];
        }
        lhs += 4;
        rhs += 4;
    }

    evenSum += oddSum;

    /* Do trailing portion in naive loop. */
    for (; lhs != stopAll; ++lhs, ++rhs)
    {
        evenSum += *lhs * *rhs;
    }

    return evenSum;
}
2067181254a7Smrg
/// ditto
F dotProduct(F, uint N)(const ref scope F[N] a, const ref scope F[N] b)
if (N <= 16)
{
    // Products at even indices go to one accumulator and products at odd
    // indices to the other; the loop is unrolled at compile time.
    F evens = 0;
    F odds = 0;
    static foreach (idx; 0 .. N)
    {
        static if (idx % 2 == 0)
            evens += a[idx] * b[idx];
        else
            odds += a[idx] * b[idx];
    }
    return evens + odds;
}
2085*b1e83836Smrg
@system unittest
{
    // @system due to dotProduct and assertCTFEable
    import std.exception : assertCTFEable;
    import std.meta : AliasSeq;
    static foreach (T; AliasSeq!(double, const double, immutable double))
    {{
        // Dynamic arrays.
        T[] dynA = [ 1.0, 2.0, ];
        T[] dynB = [ 4.0, 6.0, ];
        assert(dotProduct(dynA, dynB) == 16);
        assert(dotProduct([1, 3, -5], [4, -2, -1]) == 3);
        // Test with fixed-length arrays.
        T[2] fixA2 = [ 1.0, 2.0, ];
        T[2] fixB2 = [ 4.0, 6.0, ];
        assert(dotProduct(fixA2, fixB2) == 16);
        T[3] fixA3 = [1, 3, -5];
        T[3] fixB3 = [4, -2, -1];
        assert(dotProduct(fixA3, fixB3) == 3);
    }}

    // Make sure the unrolled loop codepath gets tested.
    static const x =
        [1.0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22];
    static const y =
        [2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23];
    assertCTFEable!({ assert(dotProduct(x, y) == 4048); });
}
2113181254a7Smrg
/**
Computes the $(LINK2 https://en.wikipedia.org/wiki/Cosine_similarity,
cosine similarity) of input ranges `a` and `b`. The two ranges must have
the same length. If both ranges define `length`, equality is asserted once
before the traversal; otherwise it is asserted after the traversal that
`b` has been exhausted together with `a`. If either range has all-zero
elements, return 0.
*/
CommonType!(ElementType!(Range1), ElementType!(Range2))
cosineSimilarity(Range1, Range2)(Range1 a, Range2 b)
if (isInputRange!(Range1) && isInputRange!(Range2))
{
    enum bool bothHaveLength = hasLength!(Range1) && hasLength!(Range2);
    static if (bothHaveLength) assert(a.length == b.length);

    // Squared norms of both inputs and their dot product, gathered in one pass.
    Unqual!(typeof(return)) sqNormA = 0, sqNormB = 0, dot = 0;
    while (!a.empty)
    {
        immutable fromA = a.front, fromB = b.front;
        sqNormA += fromA * fromA;
        sqNormB += fromB * fromB;
        dot += fromA * fromB;
        a.popFront();
        b.popFront();
    }

    static if (!bothHaveLength) assert(b.empty);
    if (sqNormA != 0 && sqNormB != 0)
        return dot / sqrt(sqNormA * sqNormB);
    return 0;
}
2139181254a7Smrg
@safe unittest
{
    import std.meta : AliasSeq;
    static foreach (T; AliasSeq!(double, const double, immutable double))
    {{
        T[] p = [ 1.0, 2.0, ];
        T[] q = [ 4.0, 3.0, ];
        // dot = 10, squared norms are 5 and 25.
        immutable expected = 10.0 / sqrt(5.0 * 25);
        assert(isClose(cosineSimilarity(p, q), expected, 0.01));
    }}
}
2152181254a7Smrg
/**
Normalizes values in `range` by multiplying each element with a
number chosen such that values sum up to `sum`. If elements in
`range` sum to zero, assigns `sum / length` (where `length` is the number
of elements) to all. Normalization makes sense only if all elements in
`range` are positive. `normalize` assumes that is the case without
checking it (apart from asserting that a non-zero total is not negative).

Params:
    range = Forward range with assignable elements; modified in place.
    sum = Value the elements should add up to afterwards.

Returns: `true` if normalization completed normally, `false` if
all elements in `range` were zero or if `range` is empty.
*/
bool normalize(R)(R range, ElementType!(R) sum = 1)
if (isForwardRange!(R))
{
    ElementType!(R) s = 0;
    // Step 1: Compute sum and length of the range.
    // Iterate over a saved copy so that a forward range with reference
    // semantics is not exhausted before the write pass below.
    static if (hasLength!(R))
    {
        const length = range.length;
        foreach (e; range.save)
        {
            s += e;
        }
    }
    else
    {
        uint length = 0;
        foreach (e; range.save)
        {
            s += e;
            ++length;
        }
    }
    // Step 2: perform normalization
    if (s == 0)
    {
        if (length)
        {
            // Use the count gathered above: `range.length` does not exist
            // for forward ranges that do not define `length`.
            immutable f = sum / length;
            foreach (ref e; range) e = f;
        }
        return false;
    }
    // The path most traveled
    assert(s >= 0);
    immutable f = sum / s;
    foreach (ref e; range)
        e *= f;
    return true;
}
2202181254a7Smrg
///
@safe unittest
{
    // An empty range cannot be normalized.
    double[] values = [];
    assert(!normalize(values));

    // Default target sum of 1.
    values = [ 1.0, 3.0 ];
    assert(normalize(values));
    assert(values == [ 0.25, 0.75 ]);

    // Explicit target sum.
    assert(normalize!(double[])(values, 50)); // values = [12.5, 37.5]

    // All zeros: every element becomes sum / length and `false` is returned.
    values = [ 0.0, 0.0 ];
    assert(!normalize(values));
    assert(values == [ 0.5, 0.5 ]);
}
2216181254a7Smrg
/**
Compute the sum of binary logarithms of the input range `r`.
The error of this method is much smaller than with a naive sum of log2.

The elements are multiplied together as a fraction plus a separately
accumulated integer exponent, so only one `log2` call is made at the end.
Returns NaN as soon as an element compares less than zero.
*/
ElementType!Range sumOfLog2s(Range)(Range r)
if (isInputRange!Range && isFloatingPoint!(ElementType!Range))
{
    long exponentSum = 0;
    Unqual!(typeof(return)) fraction = 1;
    foreach (value; r)
    {
        if (value < 0)
            return typeof(return).nan;

        int valueExp = void;
        fraction *= frexp(value, valueExp);
        exponentSum += valueExp;

        // Rescale the running fraction when it drops below 0.5.
        if (fraction < 0.5)
        {
            fraction *= 2;
            exponentSum--;
        }
    }
    return exponentSum + log2(fraction);
}
2241181254a7Smrg
///
@safe unittest
{
    import std.math.traits : isNaN;

    // Empty input sums to zero.
    assert(sumOfLog2s(new double[0]) == 0);

    // Zeros of either sign yield negative infinity.
    assert(sumOfLog2s([0.0L]) == -real.infinity);
    assert(sumOfLog2s([-0.0L]) == -real.infinity);

    // Ordinary values.
    assert(sumOfLog2s([2.0L]) == 1);
    assert(sumOfLog2s([ 0.25, 0.25, 0.25, 0.125 ]) == -9);
    assert(sumOfLog2s([real.infinity]) == real.infinity);

    // Negative values and NaNs produce NaN.
    assert(isNaN(sumOfLog2s([-2.0L])));
    assert(isNaN(sumOfLog2s([real.nan])));
    assert(isNaN(sumOfLog2s([-real.nan])));
    assert(isNaN(sumOfLog2s([-real.infinity])));
}
2258181254a7Smrg
/**
Computes $(LINK2 https://en.wikipedia.org/wiki/Entropy_(information_theory),
_entropy) of input range `r` in bits. This function assumes (without
checking) that the values in `r` are all in $(D [0, 1]). For the entropy to
be meaningful, often `r` should be normalized too (i.e., its values should
sum to 1). The two-parameter version stops evaluating as soon as the
intermediate result is greater than or equal to `max`.

Params:
    r = input range of values; zero elements are skipped

Returns:
    The negated sum of $(D x * log2(x)) over the nonzero elements of `r`.
*/
ElementType!Range entropy(Range)(Range r)
if (isInputRange!Range)
{
    Unqual!(typeof(return)) bits = 0.0;
    while (!r.empty)
    {
        const p = r.front;
        // A zero element contributes nothing (and log2(0) is not finite).
        if (p)
            bits -= p * log2(p);
        r.popFront;
    }
    return bits;
}
2279181254a7Smrg
/// Ditto
ElementType!Range entropy(Range, F)(Range r, F max)
if (isInputRange!Range &&
    !is(CommonType!(ElementType!Range, F) == void))
{
    Unqual!(typeof(return)) bits = 0.0;
    while (!r.empty)
    {
        const p = r.front;
        // Zero elements are skipped; they can neither change the running
        // total nor trigger the cutoff.
        if (p)
        {
            bits -= p * log2(p);
            // Stop early once the running total has reached the cutoff.
            if (bits >= max)
                break;
        }
        r.popFront;
    }
    return bits;
}
2294181254a7Smrg
@safe unittest
{
    import std.meta : AliasSeq;

    // entropy must accept mutable, const and immutable element types alike.
    static foreach (Elem; AliasSeq!(double, const double, immutable double))
    {{
        // A distribution concentrated on one outcome carries no information.
        Elem[] dist = [ 0.0, 0, 0, 1 ];
        assert(entropy(dist) == 0);

        // A uniform distribution over four outcomes carries two bits; the
        // two-parameter overload stops once the cutoff of 1 is reached.
        dist = [ 0.25, 0.25, 0.25, 0.25 ];
        assert(entropy(dist) == 2);
        assert(entropy(dist, 1) == 1);
    }}
}
2307181254a7Smrg
/**
Computes the $(LINK2 https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence,
Kullback-Leibler divergence) between input ranges `a` and `b`, which is the
sum $(D ai * log(ai / bi)). The base of logarithm is 2. The ranges are
assumed to contain elements in $(D [0, 1]). Usually the ranges are
normalized probability distributions, but this is not required or checked by
`kullbackLeiblerDivergence`. If any element `bi` is zero and the
corresponding element `ai` nonzero, returns infinity. (Otherwise, if
$(D ai == 0 && bi == 0), the term $(D ai * log(ai / bi)) is considered
zero.) If the inputs are normalized, the result is non-negative.

Params:
    a = first input range
    b = second input range; expected to have as many elements as `a`
        (checked with `assert` only)

Returns:
    The divergence in bits, or infinity as described above.
*/
CommonType!(ElementType!Range1, ElementType!Range2)
kullbackLeiblerDivergence(Range1, Range2)(Range1 a, Range2 b)
if (isInputRange!(Range1) && isInputRange!(Range2))
{
    enum bool bothHaveLength = hasLength!(Range1) && hasLength!(Range2);
    static if (bothHaveLength)
        assert(a.length == b.length);

    Unqual!(typeof(return)) divergence = 0;
    while (!a.empty)
    {
        immutable ai = a.front;
        // A zero `ai` contributes nothing, whatever `bi` is.
        if (ai != 0)
        {
            immutable bi = b.front;
            if (bi == 0)
                return divergence.infinity;
            assert(ai > 0 && bi > 0);
            divergence += ai * log2(ai / bi);
        }
        a.popFront();
        b.popFront();
    }

    // Without lengths the size mismatch can only be detected at the end.
    static if (!bothHaveLength)
        assert(b.empty);
    return divergence;
}
2340181254a7Smrg
///
@safe unittest
{
    import std.math.operations : isClose;

    double[] point = [ 0.0, 0, 0, 1 ];
    double[] uniform = [ 0.25, 0.25, 0.25, 0.25 ];
    double[] skewed = [ 0.2, 0.2, 0.2, 0.4 ];

    // The divergence of a distribution from itself is zero.
    assert(kullbackLeiblerDivergence(point, point) == 0);
    assert(kullbackLeiblerDivergence(uniform, uniform) == 0);

    // The measure is not symmetric: a zero in the second range paired with
    // a nonzero in the first yields infinity.
    assert(kullbackLeiblerDivergence(point, uniform) == 2);
    assert(kullbackLeiblerDivergence(uniform, point) == double.infinity);

    assert(isClose(kullbackLeiblerDivergence(uniform, skewed), 0.0719281, 1e-5));
    assert(isClose(kullbackLeiblerDivergence(skewed, uniform), 0.0780719, 1e-5));
}
2356181254a7Smrg
/**
Computes the $(LINK2 https://en.wikipedia.org/wiki/Jensen%E2%80%93Shannon_divergence,
Jensen-Shannon divergence) between `a` and `b`, which is the sum
$(D (ai * log(2 * ai / (ai + bi)) + bi * log(2 * bi / (ai + bi))) / 2).
The base of logarithm is 2. The ranges are assumed to contain elements in
$(D [0, 1]). Usually the ranges are normalized probability distributions,
but this is not required or checked by `jensenShannonDivergence`. If the
inputs are normalized, the result is bounded within $(D [0, 1]). The
three-parameter version stops evaluations as soon as the intermediate result
is greater than or equal to `limit`.

Params:
    a = first input range
    b = second input range; expected to have as many elements as `a`
        (checked with `assert` only)

Returns:
    The divergence in bits.
*/
CommonType!(ElementType!Range1, ElementType!Range2)
jensenShannonDivergence(Range1, Range2)(Range1 a, Range2 b)
if (isInputRange!Range1 && isInputRange!Range2 &&
    is(CommonType!(ElementType!Range1, ElementType!Range2)))
{
    enum bool bothHaveLength = hasLength!(Range1) && hasLength!(Range2);
    static if (bothHaveLength)
        assert(a.length == b.length);

    // Accumulates twice the divergence; the halving is done once on return.
    Unqual!(typeof(return)) doubled = 0;
    while (!a.empty)
    {
        immutable ai = a.front;
        immutable bi = b.front;
        immutable mid = (ai + bi) / 2;
        // Zero elements contribute nothing to the sum.
        if (ai != 0)
            doubled += ai * log2(ai / mid);
        if (bi != 0)
            doubled += bi * log2(bi / mid);
        a.popFront();
        b.popFront();
    }

    static if (!bothHaveLength)
        assert(b.empty);
    return doubled / 2;
}
2394181254a7Smrg
/// Ditto
CommonType!(ElementType!Range1, ElementType!Range2)
jensenShannonDivergence(Range1, Range2, F)(Range1 a, Range2 b, F limit)
if (isInputRange!Range1 && isInputRange!Range2 &&
    is(typeof(CommonType!(ElementType!Range1, ElementType!Range2).init
    >= F.init) : bool))
{
    enum bool haveLen = hasLength!(Range1) && hasLength!(Range2);
    static if (haveLen) assert(a.length == b.length);

    // The loop accumulates twice the divergence (the halving happens once on
    // return), so the cutoff is doubled to match. It is kept in a separate
    // local instead of modifying `limit` in place, so that `const` and
    // `immutable` limits are accepted and narrow integral types do not wrap.
    immutable doubledLimit = limit * 2;

    Unqual!(typeof(return)) result = 0;
    for (; !a.empty; a.popFront(), b.popFront())
    {
        immutable t1 = a.front;
        immutable t2 = b.front;
        immutable avg = (t1 + t2) / 2;
        // Zero elements contribute nothing to the sum.
        if (t1 != 0)
        {
            result += t1 * log2(t1 / avg);
        }
        if (t2 != 0)
        {
            result += t2 * log2(t2 / avg);
        }
        // Early exit: the caller only needs to know the limit was reached.
        if (result >= doubledLimit) break;
    }
    // `b` may legitimately still hold elements after an early exit, so this
    // check is only meaningful when the loop ran to completion; it mirrors
    // the original placement.
    static if (!haveLen) assert(b.empty);
    return result / 2;
}
2424181254a7Smrg
///
@safe unittest
{
    import std.math.operations : isClose;

    double[] point = [ 0.0, 0, 0, 1 ];
    double[] uniform = [ 0.25, 0.25, 0.25, 0.25 ];
    double[] skewed = [ 0.2, 0.2, 0.2, 0.4 ];

    // The divergence of a distribution from itself is zero.
    assert(jensenShannonDivergence(point, point) == 0);
    assert(jensenShannonDivergence(uniform, uniform) == 0);

    assert(isClose(jensenShannonDivergence(uniform, point), 0.548795, 1e-5));

    // The measure is symmetric in its two arguments.
    assert(isClose(jensenShannonDivergence(uniform, skewed), 0.0186218, 1e-5));
    assert(isClose(jensenShannonDivergence(skewed, uniform), 0.0186218, 1e-5));

    // With a limit, evaluation stops as soon as the limit is reached.
    assert(isClose(jensenShannonDivergence(skewed, uniform, 0.005), 0.00602366, 1e-5));
}
2440181254a7Smrg
2441181254a7Smrg /**
2442181254a7Smrg The so-called "all-lengths gap-weighted string kernel" computes a
2443*b1e83836Smrg similarity measure between `s` and `t` based on all of their
2444181254a7Smrg common subsequences of all lengths. Gapped subsequences are also
2445181254a7Smrg included.
2446181254a7Smrg
2447181254a7Smrg To understand what $(D gapWeightedSimilarity(s, t, lambda)) computes,
2448181254a7Smrg consider first the case $(D lambda = 1) and the strings $(D s =
2449181254a7Smrg ["Hello", "brave", "new", "world"]) and $(D t = ["Hello", "new",
2450*b1e83836Smrg "world"]). In that case, `gapWeightedSimilarity` counts the
2451181254a7Smrg following matches:
2452181254a7Smrg
2453*b1e83836Smrg $(OL $(LI three matches of length 1, namely `"Hello"`, `"new"`,
2454*b1e83836Smrg and `"world"`;) $(LI three matches of length 2, namely ($(D
2455181254a7Smrg "Hello", "new")), ($(D "Hello", "world")), and ($(D "new", "world"));)
2456181254a7Smrg $(LI one match of length 3, namely ($(D "Hello", "new", "world")).))
2457181254a7Smrg
2458181254a7Smrg The call $(D gapWeightedSimilarity(s, t, 1)) simply counts all of
2459181254a7Smrg these matches and adds them up, returning 7.
2460181254a7Smrg
2461181254a7Smrg ----
2462181254a7Smrg string[] s = ["Hello", "brave", "new", "world"];
2463181254a7Smrg string[] t = ["Hello", "new", "world"];
2464181254a7Smrg assert(gapWeightedSimilarity(s, t, 1) == 7);
2465181254a7Smrg ----
2466181254a7Smrg
2467181254a7Smrg Note how the gaps in matching are simply ignored, for example ($(D
2468181254a7Smrg "Hello", "new")) is deemed as good a match as ($(D "new",
2469181254a7Smrg "world")). This may be too permissive for some applications. To
2470181254a7Smrg eliminate gapped matches entirely, use $(D lambda = 0):
2471181254a7Smrg
2472181254a7Smrg ----
2473181254a7Smrg string[] s = ["Hello", "brave", "new", "world"];
2474181254a7Smrg string[] t = ["Hello", "new", "world"];
2475181254a7Smrg assert(gapWeightedSimilarity(s, t, 0) == 4);
2476181254a7Smrg ----
2477181254a7Smrg
2478181254a7Smrg The call above eliminated the gapped matches ($(D "Hello", "new")),
2479181254a7Smrg ($(D "Hello", "world")), and ($(D "Hello", "new", "world")) from the
2480181254a7Smrg tally. That leaves only 4 matches.
2481181254a7Smrg
2482181254a7Smrg The most interesting case is when gapped matches still participate in
2483181254a7Smrg the result, but not as strongly as ungapped matches. The result will
2484181254a7Smrg be a smooth, fine-grained similarity measure between the input
2485*b1e83836Smrg strings. This is where values of `lambda` between 0 and 1 enter
2486181254a7Smrg into play: gapped matches are $(I exponentially penalized with the
2487*b1e83836Smrg number of gaps) with base `lambda`. This means that an ungapped
2488181254a7Smrg match adds 1 to the return value; a match with one gap in either
2489*b1e83836Smrg string adds `lambda` to the return value; ...; a match with a total
2490*b1e83836Smrg of `n` gaps in both strings adds $(D pow(lambda, n)) to the return
2491181254a7Smrg value. In the example above, we have 4 matches without gaps, 2 matches
2492181254a7Smrg with one gap, and 1 match with three gaps. The latter match is ($(D
2493181254a7Smrg "Hello", "world")), which has two gaps in the first string and one gap
2494181254a7Smrg in the second string, totaling to three gaps. Summing these up we get
2495181254a7Smrg $(D 4 + 2 * lambda + pow(lambda, 3)).
2496181254a7Smrg
2497181254a7Smrg ----
2498181254a7Smrg string[] s = ["Hello", "brave", "new", "world"];
2499181254a7Smrg string[] t = ["Hello", "new", "world"];
2500181254a7Smrg assert(gapWeightedSimilarity(s, t, 0.5) == 4 + 0.5 * 2 + 0.125);
2501181254a7Smrg ----
2502181254a7Smrg
2503*b1e83836Smrg `gapWeightedSimilarity` is useful wherever a smooth similarity
2504181254a7Smrg measure between sequences allowing for approximate matches is
2505181254a7Smrg needed. The examples above are given with words, but any sequences
2506181254a7Smrg with elements comparable for equality are allowed, e.g. characters or
2507*b1e83836Smrg numbers. `gapWeightedSimilarity` uses a highly optimized dynamic
2508181254a7Smrg programming implementation that needs $(D 16 * min(s.length,
2509181254a7Smrg t.length)) extra bytes of memory and $(BIGOH s.length * t.length) time
2510181254a7Smrg to complete.
2511181254a7Smrg */
F gapWeightedSimilarity(alias comp = "a == b", R1, R2, F)(R1 s, R2 t, F lambda)
if (isRandomAccessRange!(R1) && hasLength!(R1) &&
    isRandomAccessRange!(R2) && hasLength!(R2))
{
    import core.exception : onOutOfMemoryError;
    import core.stdc.stdlib : malloc, free;
    import std.algorithm.mutation : swap;
    import std.functional : binaryFun;

    // The scratch buffer is sized by the second range, so make sure that is
    // the shorter one. The predicate must be forwarded explicitly; otherwise
    // the swapped call silently falls back to the default "a == b". Note
    // that the swapped call passes elements to `comp` in the opposite order,
    // so `comp` is expected to be symmetric.
    if (s.length < t.length) return gapWeightedSimilarity!(comp)(t, s, lambda);
    if (!t.length) return 0;

    // One allocation holds both dynamic-programming rows back to back.
    auto dpvi = cast(F*) malloc(F.sizeof * 2 * t.length);
    if (!dpvi)
        onOutOfMemoryError();

    auto dpvi1 = dpvi + t.length;
    // The two row pointers get swapped after every outer iteration; the
    // lower of the two is always the address malloc returned.
    scope(exit) free(dpvi < dpvi1 ? dpvi : dpvi1);
    dpvi[0 .. t.length] = 0;
    dpvi1[0] = 0;
    immutable lambda2 = lambda * lambda;

    F result = 0;
    foreach (i; 0 .. s.length)
    {
        const si = s[i];
        // The exit test sits in the middle of the body: the last column
        // still has to be matched, but has no successor cell to fill in.
        for (size_t j = 0;;)
        {
            F dpsij = void;
            if (binaryFun!(comp)(si, t[j]))
            {
                dpsij = 1 + dpvi[j];
                result += dpsij;
            }
            else
            {
                dpsij = 0;
            }
            immutable j1 = j + 1;
            if (j1 == t.length) break;
            dpvi1[j1] = dpsij + lambda * (dpvi1[j] + dpvi[j1]) -
                lambda2 * dpvi[j];
            j = j1;
        }
        swap(dpvi, dpvi1);
    }
    return result;
}
2560181254a7Smrg
@system unittest
{
    string[] longer = ["Hello", "brave", "new", "world"];
    string[] shorter = ["Hello", "new", "world"];

    // lambda = 1: every common subsequence counts fully.
    assert(gapWeightedSimilarity(longer, shorter, 1) == 7);
    // lambda = 0: gapped matches are discarded.
    assert(gapWeightedSimilarity(longer, shorter, 0) == 4);
    // 0 < lambda < 1: gapped matches are penalized per gap.
    assert(gapWeightedSimilarity(longer, shorter, 0.5) == 4 + 2 * 0.5 + 0.125);
}
2569181254a7Smrg
/**
The similarity per `gapWeightedSimilarity` has an issue in that it grows
with the lengths of the two strings, even though the strings are not
actually very similar. For example, the range $(D ["Hello", "world"]) is
increasingly similar with the range $(D ["Hello", "world", "world",
"world",...]) as more instances of `"world"` are appended. To prevent that,
`gapWeightedSimilarityNormalized` computes a normalized version of the
similarity that is computed as $(D gapWeightedSimilarity(s, t, lambda) /
sqrt(gapWeightedSimilarity(s, s, lambda) * gapWeightedSimilarity(t, t,
lambda))). The function `gapWeightedSimilarityNormalized` (a so-called
normalized kernel) is bounded in $(D [0, 1]), reaches `0` only for ranges
that don't match in any position, and `1` only for identical ranges.

The optional parameters `sSelfSim` and `tSelfSim` are meant for avoiding
duplicate computation. Many applications may have already computed
$(D gapWeightedSimilarity(s, s, lambda)) and/or
$(D gapWeightedSimilarity(t, t, lambda)). In that case, they can be passed
as `sSelfSim` and `tSelfSim`, respectively.

Params:
    s = first range
    t = second range
    lambda = gap penalty, forwarded to `gapWeightedSimilarity`
    sSelfSim = precomputed self-similarity of `s`; `F.init` means
        "not computed yet"
    tSelfSim = precomputed self-similarity of `t`; `F.init` means
        "not computed yet"

Returns:
    The normalized similarity, or `0` if either self-similarity is `0`.
*/
Select!(isFloatingPoint!(F), F, double)
gapWeightedSimilarityNormalized(alias comp = "a == b", R1, R2, F)
        (R1 s, R2 t, F lambda, F sSelfSim = F.init, F tSelfSim = F.init)
if (isRandomAccessRange!(R1) && hasLength!(R1) &&
    isRandomAccessRange!(R2) && hasLength!(R2))
{
    // `F.init` is the "not supplied" marker: NaN for floating-point types,
    // which has to be detected with isNaN, and the plain initial value for
    // every other type.
    static bool isMissing(F value)
    {
        static if (isFloatingPoint!(F))
            return isNaN(value);
        else
            return value == value.init;
    }

    if (isMissing(sSelfSim))
        sSelfSim = gapWeightedSimilarity!(comp)(s, s, lambda);
    if (sSelfSim == 0)
        return 0;

    if (isMissing(tSelfSim))
        tSelfSim = gapWeightedSimilarity!(comp)(t, t, lambda);
    if (tSelfSim == 0)
        return 0;

    alias Result = typeof(return);
    immutable cross = gapWeightedSimilarity!(comp)(s, t, lambda);
    return cross / sqrt(cast(Result) sSelfSim * tSelfSim);
}
2614181254a7Smrg
///
@system unittest
{
    import std.math.operations : isClose;
    import std.math.algebraic : sqrt;

    string[] longer = ["Hello", "brave", "new", "world"];
    string[] shorter = ["Hello", "new", "world"];

    // Self-similarities and cross-similarity with no gap penalty.
    assert(gapWeightedSimilarity(longer, longer, 1) == 15);
    assert(gapWeightedSimilarity(shorter, shorter, 1) == 7);
    assert(gapWeightedSimilarity(longer, shorter, 1) == 7);

    // The normalized value divides the cross-similarity by the geometric
    // mean of the two self-similarities.
    assert(isClose(gapWeightedSimilarityNormalized(longer, shorter, 1),
                    7.0 / sqrt(15.0 * 7), 0.01));
}
2629181254a7Smrg
/**
Similar to `gapWeightedSimilarity`, just works in an incremental manner by
first revealing the matches of length 1, then gapped matches of length 2,
and so on. The memory requirement is $(BIGOH s.length * t.length). The time
complexity is $(BIGOH s.length * t.length) time for computing each step.

The implementation is based on the pseudocode in Fig. 4 of the paper
$(HTTP jmlr.csail.mit.edu/papers/volume6/rousu05a/rousu05a.pdf,
"Efficient Computation of Gapped Substring Kernels on Large Alphabets")
by Rousu et al., with additional algorithmic and systems-level
optimizations.

Note: the object owns a `malloc`ed buffer that is released only when
`empty` observes that the iteration is finished. The struct declares no
destructor or postblit, so copies share that buffer.
*/
struct GapWeightedSimilarityIncremental(Range, F = double)
if (isRandomAccessRange!(Range) && hasLength!(Range))
{
    import core.stdc.stdlib : malloc, realloc, alloca, free;

private:
    // The two inputs, trimmed by the constructor to the smallest slices that
    // still contain every matching position.
    Range s, t;
    // Similarity at the current match length; zero means "no more matches".
    F currentValue = 0;
    // Row-major s.length * t.length matrix allocated with malloc; released
    // by `empty` once the iteration is over.
    F* kl;
    // Number of popFront steps performed so far.
    size_t gram = void;
    F lambda = void, lambda2 = void;

public:
/**
Constructs an object given two ranges `s` and `t` and a penalty
`lambda`. Constructor completes in $(BIGOH s.length * t.length)
time and computes all matches of length 1.
 */
    this(Range s, Range t, F lambda)
    {
        import core.exception : onOutOfMemoryError;

        assert(lambda > 0);
        this.gram = 0;
        this.lambda = lambda;
        this.lambda2 = lambda * lambda; // for efficiency only

        size_t iMin = size_t.max, jMin = size_t.max,
            iMax = 0, jMax = 0;

        /* Collect the (i, j) positions of all matches of length 1. */
        alias Match = Tuple!(size_t, size_t);
        Match* k0;
        size_t k0len;
        scope(exit) free(k0);
        currentValue = 0;
        // Plain index loops: an indexed `foreach (i, si; s)` only compiles
        // for built-in arrays, not for random-access ranges in general.
        foreach (i; 0 .. s.length)
        {
            const si = s[i];
            foreach (j; 0 .. t.length)
            {
                if (si != t[j]) continue;
                // Grow the match list by one entry. The result is checked
                // before it replaces k0, so a failed reallocation is
                // reported instead of dereferencing null.
                auto grown = cast(Match*) realloc(k0, (k0len + 1) * Match.sizeof);
                if (!grown)
                    onOutOfMemoryError();
                k0 = grown;
                k0[k0len][0] = i;
                k0[k0len][1] = j;
                ++k0len;
                // Maintain the minimum and maximum i and j
                if (iMin > i) iMin = i;
                if (iMax < i) iMax = i;
                if (jMin > j) jMin = j;
                if (jMax < j) jMax = j;
            }
        }

        // No match at all: leave the object in its empty state.
        if (iMin > iMax) return;
        assert(k0len);

        currentValue = k0len;
        // Chop strings down to the useful sizes
        s = s[iMin .. iMax + 1];
        t = t[jMin .. jMax + 1];
        this.s = s;
        this.t = t;

        kl = cast(F*) malloc(s.length * t.length * F.sizeof);
        if (!kl)
            onOutOfMemoryError();

        // Seed the matrix: lambda2 at every matching cell, zero elsewhere.
        kl[0 .. s.length * t.length] = 0;
        foreach (pos; 0 .. k0len)
        {
            kl[(k0[pos][0] - iMin) * t.length + k0[pos][1] - jMin] = lambda2;
        }
    }

    /**
    Returns: `this`.
     */
    ref GapWeightedSimilarityIncremental opSlice()
    {
        return this;
    }

    /**
    Computes the match of the popFront length. Completes in $(BIGOH s.length *
    t.length) time.
     */
    void popFront()
    {
        // This is a large source of optimization: if similarity at
        // the gram-1 level was 0, then we can safely assume
        // similarity at the gram level is 0 as well.
        if (empty) return;

        // Now attempt to match gapped substrings of length `gram'
        ++gram;
        currentValue = 0;

        // Scratch row, one entry per element of t.
        auto Si = cast(F*) alloca(t.length * F.sizeof);
        Si[0 .. t.length] = 0;
        foreach (i; 0 .. s.length)
        {
            const si = s[i];
            F Sij_1 = 0;
            F Si_1j_1 = 0;
            auto kli = kl + i * t.length;
            for (size_t j = 0;;)
            {
                const klij = kli[j];
                const Si_1j = Si[j];
                const tmp = klij + lambda * (Si_1j + Sij_1) - lambda2 * Si_1j_1;
                // now update kl and currentValue
                if (si == t[j])
                    currentValue += kli[j] = lambda2 * Si_1j_1;
                else
                    kli[j] = 0;
                // commit to Si
                Si[j] = tmp;
                if (++j == t.length) break;
                // get ready for the popFront step; virtually increment j,
                // so essentially stuffj_1 <-- stuffj
                Si_1j_1 = Si_1j;
                Sij_1 = tmp;
            }
        }
        currentValue /= pow(lambda, 2 * (gram + 1));
    }

    /**
    Returns: The gapped similarity at the current match length (initially
    1, grows with each call to `popFront`).
     */
    @property F front() { return currentValue; }

    /**
    Returns: Whether there are more matches.

    As a side effect, the internal buffer is released the first time this
    property observes that the iteration is finished.
     */
    @property bool empty()
    {
        if (currentValue) return false;
        if (kl)
        {
            free(kl);
            kl = null;
        }
        return true;
    }
}
2832181254a7Smrg
/**
Ditto
*/
GapWeightedSimilarityIncremental!(R, F) gapWeightedSimilarityIncremental(R, F)
        (R r1, R r2, F penalty)
{
    // Convenience factory: lets both template arguments be deduced from the
    // call instead of being spelled out at the construction site.
    alias Result = GapWeightedSimilarityIncremental!(R, F);
    return Result(r1, r2, penalty);
}
2841181254a7Smrg
///
@system unittest
{
    string[] longer = ["Hello", "brave", "new", "world"];
    string[] shorter = ["Hello", "new", "world"];
    auto steps = gapWeightedSimilarityIncremental(longer, shorter, 1.0);

    // Each popFront moves on to matches that are one element longer.
    assert(steps.front == 3); // three 1-length matches
    steps.popFront();
    assert(steps.front == 3); // three 2-length matches
    steps.popFront();
    assert(steps.front == 1); // one 3-length match
    steps.popFront();
    assert(steps.empty);      // no more match
}
2856181254a7Smrg
@system unittest
{
    import std.conv : text;

    // Reference example without gap penalty.
    string[] first = ["Hello", "brave", "new", "world"];
    string[] second = ["Hello", "new", "world"];
    auto steps = gapWeightedSimilarityIncremental(first, second, 1.0);
    //foreach (e; steps) writeln(e);
    assert(steps.front == 3); // three 1-length matches
    steps.popFront();
    assert(steps.front == 3, text(steps.front)); // three 2-length matches
    steps.popFront();
    assert(steps.front == 1); // one 3-length matches
    steps.popFront();
    assert(steps.empty);  // no more match

    // Nothing in common: empty right away.
    first = ["Hello"];
    second = ["bye"];
    steps = gapWeightedSimilarityIncremental(first, second, 0.5);
    assert(steps.empty);

    // Identical single-element ranges.
    first = ["Hello"];
    second = ["Hello"];
    steps = gapWeightedSimilarityIncremental(first, second, 0.5);
    assert(steps.front == 1); // one match
    steps.popFront();
    assert(steps.empty);

    // Ranges of different lengths sharing one element.
    first = ["Hello", "world"];
    second = ["Hello"];
    steps = gapWeightedSimilarityIncremental(first, second, 0.5);
    assert(steps.front == 1); // one match
    steps.popFront();
    assert(steps.empty);

    // A gapped 2-gram is penalized by lambda.
    first = ["Hello", "world"];
    second = ["Hello", "yah", "world"];
    steps = gapWeightedSimilarityIncremental(first, second, 0.5);
    assert(steps.front == 2); // two 1-gram matches
    steps.popFront();
    assert(steps.front == 0.5, text(steps.front)); // one 2-gram match, 1 gap
}
2898181254a7Smrg
@system unittest
{
    alias Sim = GapWeightedSimilarityIncremental!(string[]);

    // Iterate with foreach and compare each step against the expected list.
    Sim sim = Sim(
        ["nyuk", "I", "have", "no", "chocolate", "giba"],
        ["wyda", "I", "have", "I", "have", "have", "I", "have", "hehe"],
        0.5);
    double[] expected = [ 7.0, 4.03125, 0, 0 ];
    foreach (value; sim)
    {
        //writeln(value);
        assert(value == expected.front);
        expected.popFront();
    }

    expected = [ 3.0, 1.3125, 0.25 ];
    sim = Sim(
        ["I", "have", "no", "chocolate"],
        ["I", "have", "some", "chocolate"],
        0.5);
    foreach (value; sim)
    {
        //writeln(value);
        assert(value == expected.front);
        expected.popFront();
    }
    // Here every expected value must have been consumed.
    assert(expected.empty);
}
2926181254a7Smrg
/**
Computes the greatest common divisor of `a` and `b` by using
an efficient algorithm such as $(HTTPS en.wikipedia.org/wiki/Euclidean_algorithm, Euclid's)
or $(HTTPS en.wikipedia.org/wiki/Binary_GCD_algorithm, Stein's) algorithm.

Negative arguments are replaced by their magnitude before the computation,
and a zero argument yields the magnitude of the other argument.

Params:
    a = Integer value of any numerical type that supports the modulo operator `%`.
        If bit-shifting `<<` and `>>` are also supported, Stein's algorithm will
        be used; otherwise, Euclid's algorithm is used as _a fallback.
    b = Integer value of any equivalent numerical type.

Returns:
    The greatest common divisor of the given arguments.
*/
typeof(Unqual!(T).init % Unqual!(U).init) gcd(T, U)(T a, U b)
if (isIntegral!T && isIntegral!U)
{
    // Every computation below is carried out in the unsigned flavour of the
    // type both arguments have in common.
    alias UCT = Unsigned!(CommonType!(Unqual!T, Unqual!U));

    // Take magnitudes by hand: `std.math.abs` rejects unsigned integers and
    // negating `T.min` inside `T` itself cannot produce a valid result.
    // `byte`/`short` values are widened to `int` before being negated.
    static if (is(T : immutable short) || is(T : immutable byte))
        UCT magA = (isUnsigned!T || a >= 0) ? a : cast(UCT) -int(a);
    else
        UCT magA = (isUnsigned!T || a >= 0) ? a : -UCT(a);

    static if (is(U : immutable short) || is(U : immutable byte))
        UCT magB = (isUnsigned!U || b >= 0) ? b : cast(UCT) -int(b);
    else
        UCT magB = (isUnsigned!U || b >= 0) ? b : -UCT(b);

    // gcd(0, x) == x and gcd(x, 0) == x; the helper must never see a zero.
    if (magA == 0 || magB == 0)
        return magA == 0 ? magB : magA;

    return gcdImpl(magA, magB);
}
2966*b1e83836Smrg
// Binary (Stein's) GCD for built-in integers.
// Both arguments are passed to `bsf`, so callers must supply non-zero values;
// the public `gcd`/`lcm` wrappers filter out zeros before calling this.
private typeof(T.init % T.init) gcdImpl(T)(T a, T b)
if (isIntegral!T)
{
    pragma(inline, true);
    import core.bitop : bsf;
    import std.algorithm.mutation : swap;

    // Number of trailing zero bits shared by both operands, i.e. the power
    // of two that divides the final result.
    immutable uint commonTwos = bsf(a | b);

    // Make `a` odd; it stays odd for the rest of the algorithm.
    a >>= bsf(a);
    for (;;)
    {
        // Strip the factors of two from `b`, keep the smaller value in `a`
        // and replace `b` by the (even or zero) difference.
        b >>= bsf(b);
        if (a > b)
            swap(a, b);
        b -= a;
        if (!b)
            break;
    }

    return a << commonTwos;
}
2986181254a7Smrg
///
@safe unittest
{
    // Literal arguments: the shared prime factors are 5 and 7.
    assert(gcd(2 * 5 * 7 * 7, 5 * 7 * 11) == 5 * 7);

    // Qualified (const) arguments are accepted too.
    const int a = 5 * 13 * 23 * 23, b = 13 * 59;
    assert(gcd(a, b) == 13);
}
2994181254a7Smrg
@safe unittest
{
    import std.meta : AliasSeq;

    // Every left-hand type is combined with every right-hand type, covering
    // all widths, both signednesses and mixed qualifiers.
    alias LeftTypes = AliasSeq!(byte, ubyte, short, ushort, int, uint, long, ulong,
                                const byte, const short, const int, const long,
                                immutable ubyte, immutable ushort, immutable uint, immutable ulong);
    alias RightTypes = AliasSeq!(byte, ubyte, short, ushort, int, uint, long, ulong,
                                 const ubyte, const ushort, const uint, const ulong,
                                 immutable byte, immutable short, immutable int, immutable long);

    static foreach (T; LeftTypes)
    {
        static foreach (U; RightTypes)
        {
            // Values that fit in every participating type.
            assert(gcd(T(0), U(13)) == 13);
            assert(gcd(T(29), U(0)) == 29);
            assert(gcd(T(0), U(0)) == 0);
            assert(gcd(T(1), U(2)) == 1);
            assert(gcd(T(9), U(6)) == 3);
            assert(gcd(T(3), U(4)) == 1);
            assert(gcd(T(32), U(24)) == 8);
            assert(gcd(T(5), U(6)) == 1);
            assert(gcd(T(54), U(36)) == 18);

            // Values that need more than a signed or unsigned byte.
            static if (T.max > byte.max && U.max > byte.max)
                assert(gcd(T(200), U(200)) == 200);
            static if (T.max > ubyte.max)
            {
                assert(gcd(T(2000), U(20)) == 20);
                assert(gcd(T(2011), U(17)) == 1);
            }
            static if (T.max > ubyte.max && U.max > ubyte.max)
                assert(gcd(T(1071), U(462)) == 21);

            // Values that need int or long.
            static if (T.max > short.max && U.max > short.max)
                assert(gcd(T(46391), U(62527)) == 2017);
            static if (T.max > ushort.max && U.max > ushort.max)
                assert(gcd(T(63245986), U(39088169)) == 1);
            static if (T.max > uint.max && U.max > uint.max)
            {
                assert(gcd(T(77160074263), U(47687519812)) == 1);
                assert(gcd(T(77160074264), U(47687519812)) == 4);
            }

            // Negative operands, only where the type is signed.
            static if (T.min < 0)
            {
                assert(gcd(T(-21), U(28)) == 7);
                assert(gcd(T(-3), U(4)) == 1);
            }
            static if (U.min < 0)
            {
                assert(gcd(T(1), U(-2)) == 1);
                assert(gcd(T(33), U(-44)) == 11);
            }
            static if (T.min < 0 && U.min < 0)
            {
                assert(gcd(T(-5), U(-6)) == 1);
                assert(gcd(T(-50), U(-60)) == 10);
            }
        }
    }
}
3057*b1e83836Smrg
// https://issues.dlang.org/show_bug.cgi?id=21834
@safe unittest
{
    // Mixed signed/unsigned operands.
    assert(gcd(-120, 10U) == 10);
    assert(gcd(120U, -10) == 10);

    // The magnitude of int.min does not fit in int; it must come back intact.
    enum long intMinMagnitude = 1L + int.max;
    assert(gcd(int.min, 0L) == intMinMagnitude);
    assert(gcd(0L, int.min) == intMinMagnitude);
    assert(gcd(int.min, 0L + int.min) == intMinMagnitude);
    assert(gcd(int.min, 1L + int.max) == intMinMagnitude);

    // Same situation for the narrower short.
    assert(gcd(short.min, 1U + short.max) == 1U + short.max);
}
3069*b1e83836Smrg
// Overload for numerical types that are not built in, such as BigInt or
// user-defined types.
/// ditto
auto gcd(T)(T a, T b)
if (!isIntegral!T &&
        is(typeof(T.init % T.init)) &&
        is(typeof(T.init == 0 || T.init > 0)))
{
    static if (is(T == Unqual!T))
    {
        // Work on magnitudes only.
        a = a >= 0 ? a : -a;
        b = b >= 0 ? b : -b;

        // A zero operand makes the other operand the answer; the helper is
        // only called with two non-zero values.
        if (a == 0)
            return b;
        if (b == 0)
            return a;

        return gcdImpl(a, b);
    }
    else
    {
        // Qualified parameters cannot be reassigned, so hand the arguments
        // over to the instantiation for the unqualified type.
        return gcd!(Unqual!T)(a, b);
    }
}
3097*b1e83836Smrg
// GCD core for non-builtin numerical types. Uses the binary (Stein's)
// algorithm when `T` offers the operations it needs, and Euclid's algorithm
// otherwise. The public wrappers pass non-zero magnitudes only.
private auto gcdImpl(T)(T a, T b)
if (!isIntegral!T)
{
    pragma(inline, true);
    import std.algorithm.mutation : swap;

    // Probe for everything the binary algorithm relies on: in-place shifts,
    // in-place subtraction, a low-bit test and swappability.
    enum bool hasBinaryGcdOps = is(typeof(() {
        T t, u;
        t <<= 1;
        t >>= 1;
        t -= u;
        bool b = (t & 1) == 0;
        swap(t, u);
    }));

    static if (!hasBinaryGcdOps)
    {
        // Only `%` is available: classic Euclidean remainder loop.
        while (b != 0)
        {
            auto remainder = a % b;
            a = b;
            b = remainder;
        }
        return a;
    }
    else
    {
        // Pull out the power of two common to both operands.
        uint shift = 0;
        for (; (a & 1) == 0 && (b & 1) == 0; shift++)
        {
            a >>= 1;
            b >>= 1;
        }

        // From here on `a` must be the odd one.
        if ((a & 1) == 0) swap(a, b);

        do
        {
            assert((a & 1) != 0);
            // Drop factors of two from `b`; they cannot be part of the gcd.
            for (; (b & 1) == 0; )
                b >>= 1;
            if (a > b)
                swap(a, b);
            b -= a;
        } while (b);

        return a << shift;
    }
}
3148181254a7Smrg
// https://issues.dlang.org/show_bug.cgi?id=7102
@system pure unittest
{
    import std.bigint : BigInt;

    // Operands far beyond the range of any built-in integer.
    auto lhs = BigInt("71_000_000_000_000_000_000");
    auto rhs = BigInt("31_000_000_000_000_000_000");
    assert(gcd(lhs, rhs) == BigInt("1_000_000_000_000_000_000"));

    // A zero operand yields the other operand.
    auto value = BigInt(1234567);
    assert(gcd(BigInt(0), BigInt(1234567)) == value);
    assert(gcd(BigInt(1234567), BigInt(0)) == value);
}
3160181254a7Smrg
@safe pure nothrow unittest
{
    // Minimal numeric wrapper offering `%`, unary `-` and comparisons but no
    // shifts, which forces gcd onto the Euclidean code path.
    struct CrippledInt
    {
        int impl;

        CrippledInt opUnary(string op : "-")()
        {
            return CrippledInt(-impl);
        }

        CrippledInt opBinary(string op : "%")(CrippledInt i)
        {
            return CrippledInt(impl % i.impl);
        }

        int opEquals(int i) { return impl == i; }
        int opEquals(CrippledInt i) { return impl == i.impl; }

        int opCmp(int i)
        {
            if (impl < i)
                return -1;
            return (impl > i) ? 1 : 0;
        }
    }

    assert(gcd(CrippledInt(2310), CrippledInt(1309)) == CrippledInt(77));

    // Negative operands are reduced to their magnitude first.
    assert(gcd(CrippledInt(-120), CrippledInt(10U)) == CrippledInt(10));
    assert(gcd(CrippledInt(120U), CrippledInt(-10)) == CrippledInt(10));
}
3184*b1e83836Smrg
// https://issues.dlang.org/show_bug.cgi?id=19514
@system pure unittest
{
    import std.bigint : BigInt;

    // Even first operand, odd second operand.
    auto even = BigInt(2);
    auto odd = BigInt(1);
    assert(gcd(even, odd) == BigInt(1));
}
3191*b1e83836Smrg
// https://issues.dlang.org/show_bug.cgi?id=20924
@safe unittest
{
    import std.bigint : BigInt;

    // gcd must be instantiable with const-qualified BigInt arguments.
    const lhs = BigInt("123143238472389492934020");
    const rhs = BigInt("902380489324729338420924");
    assert(__traits(compiles, gcd(lhs, rhs)));
}
3200*b1e83836Smrg
// https://issues.dlang.org/show_bug.cgi?id=21834
@safe unittest
{
    import std.bigint : BigInt;

    // Mixed-sign operands.
    assert(gcd(BigInt(-120), BigInt(10U)) == BigInt(10));
    assert(gcd(BigInt(120U), BigInt(-10)) == BigInt(10));

    // Operands built from int.min; the result is its magnitude.
    auto intMinMagnitude = BigInt(1L + int.max);
    assert(gcd(BigInt(int.min), BigInt(0L)) == intMinMagnitude);
    assert(gcd(BigInt(0L), BigInt(int.min)) == intMinMagnitude);
    assert(gcd(BigInt(int.min), BigInt(0L + int.min)) == intMinMagnitude);
    assert(gcd(BigInt(int.min), BigInt(1L + int.max)) == intMinMagnitude);

    // Same with short.min.
    assert(gcd(BigInt(short.min), BigInt(1U + short.max)) == BigInt(1U + short.max));
}
3213*b1e83836Smrg
3214*b1e83836Smrg
/**
Computes the least common multiple of `a` and `b`.
Arguments are the same as $(MYREF gcd).

Negative arguments are replaced by their magnitude before the computation,
and the result is zero as soon as one argument is zero.

Returns:
    The least common multiple of the given arguments.
*/
typeof(Unqual!(T).init % Unqual!(U).init) lcm(T, U)(T a, U b)
if (isIntegral!T && isIntegral!U)
{
    // Every computation below is carried out in the unsigned flavour of the
    // type both arguments have in common.
    alias UCT = Unsigned!(CommonType!(Unqual!T, Unqual!U));

    // Take magnitudes by hand: `std.math.abs` rejects unsigned integers and
    // negating `T.min` inside `T` itself cannot produce a valid result.
    // `byte`/`short` values are widened to `int` before being negated.
    static if (is(T : immutable short) || is(T : immutable byte))
        UCT magA = (isUnsigned!T || a >= 0) ? a : cast(UCT) -int(a);
    else
        UCT magA = (isUnsigned!T || a >= 0) ? a : -UCT(a);

    static if (is(U : immutable short) || is(U : immutable byte))
        UCT magB = (isUnsigned!U || b >= 0) ? b : cast(UCT) -int(b);
    else
        UCT magB = (isUnsigned!U || b >= 0) ? b : -UCT(b);

    // lcm with zero is zero; this also keeps zeros away from the gcd helper.
    if (magA == 0 || magB == 0)
        return 0;

    // Divide before multiplying to keep the intermediate value small.
    immutable divisor = gcdImpl(magA, magB);
    return (magA / divisor) * magB;
}
3247*b1e83836Smrg
///
@safe unittest
{
    // Coprime operands: the result is simply their product.
    assert(lcm(1, 2) == 2);
    assert(lcm(3, 4) == 12);
    assert(lcm(5, 6) == 30);
}
3255*b1e83836Smrg
@safe unittest
{
    import std.meta : AliasSeq;

    // Every left-hand type is combined with every right-hand type, covering
    // all widths, both signednesses and mixed qualifiers.
    alias LeftTypes = AliasSeq!(byte, ubyte, short, ushort, int, uint, long, ulong,
                                const byte, const short, const int, const long,
                                immutable ubyte, immutable ushort, immutable uint, immutable ulong);
    alias RightTypes = AliasSeq!(byte, ubyte, short, ushort, int, uint, long, ulong,
                                 const ubyte, const ushort, const uint, const ulong,
                                 immutable byte, immutable short, immutable int, immutable long);

    static foreach (T; LeftTypes)
    {
        static foreach (U; RightTypes)
        {
            // A zero operand always gives zero.
            assert(lcm(T(41), U(0)) == 0);
            assert(lcm(T(0), U(7)) == 0);
            assert(lcm(T(0), U(0)) == 0);

            // Ordinary positive operands.
            assert(lcm(T(21), U(6)) == 42);
            assert(lcm(T(1U), U(2)) == 2);
            assert(lcm(T(3), U(4U)) == 12);
            assert(lcm(T(5U), U(6U)) == 30);

            // Negative operand, only where the left type is signed.
            static if (T.min < 0)
                assert(lcm(T(-42), U(21U)) == 42);
        }
    }
}
3279*b1e83836Smrg
/// ditto
auto lcm(T)(T a, T b)
if (!isIntegral!T &&
        is(typeof(T.init % T.init)) &&
        is(typeof(T.init == 0 || T.init > 0)))
{
    static if (!is(T == Unqual!T))
    {
        // The body below reassigns its parameters, which is impossible when
        // `T` is `const` or `immutable` (e.g. `const BigInt`). Forward to the
        // instantiation for the unqualified type, exactly as the matching
        // `gcd` overload does.
        return lcm!(Unqual!T)(a, b);
    }
    else
    {
        // Ensure arguments are unsigned.
        a = a >= 0 ? a : -a;
        b = b >= 0 ? b : -b;

        // Special cases: the least common multiple with zero is zero. This
        // also guarantees that the gcd helper only sees non-zero values.
        if (a == 0)
            return a;
        if (b == 0)
            return b;

        // Divide before multiplying to keep the intermediate value small.
        return (a / gcdImpl(a, b)) * b;
    }
}
3298*b1e83836Smrg
@safe unittest
{
    import std.bigint : BigInt;

    // A zero operand always gives zero.
    auto zero = BigInt(0);
    assert(lcm(BigInt(41), BigInt(0)) == zero);
    assert(lcm(BigInt(0), BigInt(7)) == zero);
    assert(lcm(BigInt(0), BigInt(0)) == zero);

    // Ordinary positive operands.
    assert(lcm(BigInt(21), BigInt(6)) == BigInt(42));
    assert(lcm(BigInt(1U), BigInt(2)) == BigInt(2));
    assert(lcm(BigInt(3), BigInt(4U)) == BigInt(12));
    assert(lcm(BigInt(5U), BigInt(6U)) == BigInt(30));

    // Negative operand: the result is still positive.
    assert(lcm(BigInt(-42), BigInt(21U)) == BigInt(42));
}
3311181254a7Smrg
3312181254a7Smrg // This is to make tweaking the speed/size vs. accuracy tradeoff easy,
3313181254a7Smrg // though floats seem accurate enough for all practical purposes, since
3314*b1e83836Smrg // they pass the "isClose(inverseFft(fft(arr)), arr)" test even for
3315181254a7Smrg // size 2 ^^ 22.
3316181254a7Smrg private alias lookup_t = float;
3317181254a7Smrg
3318181254a7Smrg /**A class for performing fast Fourier transforms of power of two sizes.
3319181254a7Smrg * This class encapsulates a large amount of state that is reusable when
3320181254a7Smrg * performing multiple FFTs of sizes smaller than or equal to that specified
3321181254a7Smrg * in the constructor. This results in substantial speedups when performing
3322181254a7Smrg * multiple FFTs with a known maximum size. However,
3323181254a7Smrg * a free function API is provided for convenience if you need to perform a
3324181254a7Smrg * one-off FFT.
3325181254a7Smrg *
3326181254a7Smrg * References:
3327181254a7Smrg * $(HTTP en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm)
3328181254a7Smrg */
3329181254a7Smrg final class Fft
3330181254a7Smrg {
3331181254a7Smrg import core.bitop : bsf;
3332181254a7Smrg import std.algorithm.iteration : map;
3333181254a7Smrg import std.array : uninitializedArray;
3334181254a7Smrg
3335181254a7Smrg private:
3336181254a7Smrg immutable lookup_t[][] negSinLookup;
3337181254a7Smrg
enforceSize(R)3338181254a7Smrg void enforceSize(R)(R range) const
3339181254a7Smrg {
3340181254a7Smrg import std.conv : text;
3341181254a7Smrg assert(range.length <= size, text(
3342181254a7Smrg "FFT size mismatch. Expected ", size, ", got ", range.length));
3343181254a7Smrg }
3344181254a7Smrg
fftImpl(Ret,R)3345181254a7Smrg void fftImpl(Ret, R)(Stride!R range, Ret buf) const
3346181254a7Smrg in
3347181254a7Smrg {
3348181254a7Smrg assert(range.length >= 4);
3349181254a7Smrg assert(isPowerOf2(range.length));
3350181254a7Smrg }
3351*b1e83836Smrg do
3352181254a7Smrg {
3353181254a7Smrg auto recurseRange = range;
3354181254a7Smrg recurseRange.doubleSteps();
3355181254a7Smrg
3356181254a7Smrg if (buf.length > 4)
3357181254a7Smrg {
3358181254a7Smrg fftImpl(recurseRange, buf[0..$ / 2]);
3359181254a7Smrg recurseRange.popHalf();
3360181254a7Smrg fftImpl(recurseRange, buf[$ / 2..$]);
3361181254a7Smrg }
3362181254a7Smrg else
3363181254a7Smrg {
3364181254a7Smrg // Do this here instead of in another recursion to save on
3365181254a7Smrg // recursion overhead.
3366181254a7Smrg slowFourier2(recurseRange, buf[0..$ / 2]);
3367181254a7Smrg recurseRange.popHalf();
3368181254a7Smrg slowFourier2(recurseRange, buf[$ / 2..$]);
3369181254a7Smrg }
3370181254a7Smrg
3371181254a7Smrg butterfly(buf);
3372181254a7Smrg }
3373181254a7Smrg
// Transform of purely real input. The even- and odd-indexed samples are
// transformed together with the "two for the price of one" method described
// at http://www.engineeringproductivitytools.com/stuff/T0001/PT10.HTM#Head521:
// the odd samples become the imaginary parts of a complex signal of half the
// length, one complex FFT is run on it, and symmetry is then used to pull the
// two real transforms apart again before the final butterfly.
void fftImplPureReal(Ret, R)(R range, Ret buf) const
in
{
    assert(range.length >= 4);
    assert(isPowerOf2(range.length));
}
do
{
    alias E = ElementType!R;

    // Build a half-length complex view of `range`: even indices supply the
    // real parts, odd indices the imaginary parts.
    static if (isArray!R && isFloatingPoint!E)
    {
        // For arrays of floating point values the memory layout of complex
        // numbers lets a plain cast do the conversion. This is a common case.
        auto oddsImag = cast(Complex!E[]) range;
    }
    else
    {
        // Generic case: wrap the source in a range adaptor. The source length
        // is a power of two and therefore even.
        static struct OddToImaginary
        {
            R source;
            alias C = Complex!(CommonType!(E, typeof(buf[0].re)));

            C opIndex(size_t index)
            {
                return C(source[index * 2], source[index * 2 + 1]);
            }

            typeof(this) opSlice(size_t lower, size_t upper)
            {
                return typeof(this)(source[lower * 2 .. upper * 2]);
            }

            // Each pop consumes one (real, imaginary) pair of the source.
            void popFront()
            {
                source.popFront();
                source.popFront();
            }

            void popBack()
            {
                source.popBack();
                source.popBack();
            }

            @property
            {
                bool empty()
                {
                    return source.empty;
                }

                size_t length()
                {
                    return source.length / 2;
                }

                C front()
                {
                    return C(source[0], source[1]);
                }

                C back()
                {
                    immutable n = source.length;
                    return C(source[n - 2], source[n - 1]);
                }

                typeof(this) save()
                {
                    return typeof(this)(source.save);
                }
            }
        }

        auto oddsImag = OddToImaginary(range);
    }

    // One complex FFT of half the size, written into the lower half of buf.
    auto evenFft = buf[0 .. $ / 2];
    auto oddFft = buf[$ / 2 .. $];
    fft(oddsImag, evenFft);
    immutable halfN = evenFft.length;

    // Bin 0 separates directly: real part belongs to the even transform,
    // imaginary part to the odd one. evenFft[0].re needs no write because it
    // is the same storage as buf[0].re.
    oddFft[0].re = buf[0].im;
    oddFft[0].im = 0;
    evenFft[0].im = 0;

    foreach (k; 1 .. halfN / 2 + 1)
    {
        immutable mirror = halfN - k;

        // Copy both bins before they are overwritten below.
        immutable binK = buf[k];
        immutable binMirror = buf[mirror];

        evenFft[k].re = 0.5 * (binK.re + binMirror.re);
        evenFft[mirror].re = evenFft[k].re;
        evenFft[k].im = 0.5 * (binK.im - binMirror.im);
        evenFft[mirror].im = -evenFft[k].im;

        oddFft[k].re = 0.5 * (binK.im + binMirror.im);
        oddFft[mirror].re = oddFft[k].re;
        oddFft[k].im = 0.5 * (binMirror.re - binK.re);
        oddFft[mirror].im = -oddFft[k].im;
    }

    butterfly(buf);
}
3487181254a7Smrg
butterfly(R)3488181254a7Smrg void butterfly(R)(R buf) const
3489181254a7Smrg in
3490181254a7Smrg {
3491181254a7Smrg assert(isPowerOf2(buf.length));
3492181254a7Smrg }
3493*b1e83836Smrg do
3494181254a7Smrg {
3495181254a7Smrg immutable n = buf.length;
3496181254a7Smrg immutable localLookup = negSinLookup[bsf(n)];
3497181254a7Smrg assert(localLookup.length == n);
3498181254a7Smrg
3499181254a7Smrg immutable cosMask = n - 1;
3500181254a7Smrg immutable cosAdd = n / 4 * 3;
3501181254a7Smrg
negSinFromLookup(size_t index)3502181254a7Smrg lookup_t negSinFromLookup(size_t index) pure nothrow
3503181254a7Smrg {
3504181254a7Smrg return localLookup[index];
3505181254a7Smrg }
3506181254a7Smrg
cosFromLookup(size_t index)3507181254a7Smrg lookup_t cosFromLookup(size_t index) pure nothrow
3508181254a7Smrg {
3509181254a7Smrg // cos is just -sin shifted by PI * 3 / 2.
3510181254a7Smrg return localLookup[(index + cosAdd) & cosMask];
3511181254a7Smrg }
3512181254a7Smrg
3513181254a7Smrg immutable halfLen = n / 2;
3514181254a7Smrg
3515181254a7Smrg // This loop is unrolled and the two iterations are interleaved
3516181254a7Smrg // relative to the textbook FFT to increase ILP. This gives roughly 5%
3517181254a7Smrg // speedups on DMD.
3518181254a7Smrg for (size_t k = 0; k < halfLen; k += 2)
3519181254a7Smrg {
3520181254a7Smrg immutable cosTwiddle1 = cosFromLookup(k);
3521181254a7Smrg immutable sinTwiddle1 = negSinFromLookup(k);
3522181254a7Smrg immutable cosTwiddle2 = cosFromLookup(k + 1);
3523181254a7Smrg immutable sinTwiddle2 = negSinFromLookup(k + 1);
3524181254a7Smrg
3525181254a7Smrg immutable realLower1 = buf[k].re;
3526181254a7Smrg immutable imagLower1 = buf[k].im;
3527181254a7Smrg immutable realLower2 = buf[k + 1].re;
3528181254a7Smrg immutable imagLower2 = buf[k + 1].im;
3529181254a7Smrg
3530181254a7Smrg immutable upperIndex1 = k + halfLen;
3531181254a7Smrg immutable upperIndex2 = upperIndex1 + 1;
3532181254a7Smrg immutable realUpper1 = buf[upperIndex1].re;
3533181254a7Smrg immutable imagUpper1 = buf[upperIndex1].im;
3534181254a7Smrg immutable realUpper2 = buf[upperIndex2].re;
3535181254a7Smrg immutable imagUpper2 = buf[upperIndex2].im;
3536181254a7Smrg
3537181254a7Smrg immutable realAdd1 = cosTwiddle1 * realUpper1
3538181254a7Smrg - sinTwiddle1 * imagUpper1;
3539181254a7Smrg immutable imagAdd1 = sinTwiddle1 * realUpper1
3540181254a7Smrg + cosTwiddle1 * imagUpper1;
3541181254a7Smrg immutable realAdd2 = cosTwiddle2 * realUpper2
3542181254a7Smrg - sinTwiddle2 * imagUpper2;
3543181254a7Smrg immutable imagAdd2 = sinTwiddle2 * realUpper2
3544181254a7Smrg + cosTwiddle2 * imagUpper2;
3545181254a7Smrg
3546181254a7Smrg buf[k].re += realAdd1;
3547181254a7Smrg buf[k].im += imagAdd1;
3548181254a7Smrg buf[k + 1].re += realAdd2;
3549181254a7Smrg buf[k + 1].im += imagAdd2;
3550181254a7Smrg
3551181254a7Smrg buf[upperIndex1].re = realLower1 - realAdd1;
3552181254a7Smrg buf[upperIndex1].im = imagLower1 - imagAdd1;
3553181254a7Smrg buf[upperIndex2].re = realLower2 - realAdd2;
3554181254a7Smrg buf[upperIndex2].im = imagLower2 - imagAdd2;
3555181254a7Smrg }
3556181254a7Smrg }
3557181254a7Smrg
// Constructor used inside this module so that the lookup buffers can live
// somewhere other than the GC heap. It is definitely **NOT** part of the
// public API and definitely **IS** subject to change.
//
// It is also unsafe: the memSpace buffer ends up being cast to immutable.
//
// Public b/c of https://issues.dlang.org/show_bug.cgi?id=4636.
public this(lookup_t[] memSpace)
{
    // Half of the supplied space is the FFT size.
    immutable n = memSpace.length / 2;

    if (n == 0)
    {
        return;
    }

    assert(isPowerOf2(n),
        "Can only do FFTs on ranges with a size that is a power of two.");

    /* Build a table of negative sine values at resolution n and at every
     * smaller power of two resolution. This may seem inefficient, but
     * having all the lookups be next to each other in memory at every level
     * of iteration is a huge win performance-wise.
     */
    auto rows = new lookup_t[][bsf(n) + 1];

    // The full-resolution row takes the last n elements of the buffer; the
    // smaller rows are carved out of the first n elements further down.
    auto fullRow = memSpace[$ - n .. $];
    rows[$ - 1] = fullRow;
    memSpace = memSpace[0 .. n];

    fullRow[0] = 0;  // -sin(0) == 0.
    foreach (ptrdiff_t i; 1 .. n)
    {
        // Exact values are stored at the quarter points for improved
        // accuracy and to prevent annoying non-zeroness when stuff should
        // be zero.
        if (i == n / 4)
            fullRow[i] = -1;  // -sin(pi / 2) == -1.
        else if (i == n / 2)
            fullRow[i] = 0;   // -sin(pi) == 0.
        else if (i == n * 3 / 4)
            fullRow[i] = 1;   // -sin(pi * 3 / 2) == 1
        else
            fullRow[i] = -sin(i * 2.0L * PI / n);
    }

    // Every other row (except row 0, which is left untouched) is a strided
    // copy of the full-resolution row.
    foreach (level; 1 .. rows.length - 1)
    {
        immutable strideLength = n / (2 ^^ level);
        auto strided = Stride!(lookup_t[])(fullRow, strideLength);

        // Take the row's storage from the end of what is left of the buffer.
        auto row = memSpace[$ - strided.length .. $];
        memSpace = memSpace[0 .. $ - strided.length];
        rows[level] = row;

        size_t writeIndex;
        foreach (value; strided)
        {
            row[writeIndex++] = value;
        }
    }

    negSinLookup = cast(immutable) rows;
}
3622181254a7Smrg
3623181254a7Smrg public:
/**Create an `Fft` object for computing fast Fourier transforms of
 * power of two sizes of `size` or smaller. `size` must be a
 * power of two.
 */
this(size_t size)
{
    // All twiddle factor tables come out of a single contiguous block of
    // 2 * size elements so that, when one is done being used, the next one
    // is next in cache.
    this(uninitializedArray!(lookup_t[])(2 * size));
}
3635181254a7Smrg
// Largest transform length this object supports: the length of the last
// lookup row, or 0 when no lookup table was built.
@property size_t size() const
{
    if (negSinLookup is null)
        return 0;
    return negSinLookup[$ - 1].length;
}
3640181254a7Smrg
/**Compute the Fourier transform of range using the $(BIGOH N log N)
 * Cooley-Tukey Algorithm. `range` must be a random-access range with
 * slicing and a length no greater than `size` as provided at the
 * construction of this object. The contents of range can be either
 * numeric types, which will be interpreted as pure real values, or complex
 * types with properties or members `.re` and `.im` that can be read.
 *
 * Note: Pure real FFTs are automatically detected and the relevant
 *       optimizations are performed.
 *
 * Returns: An array of complex numbers representing the transformed data in
 *          the frequency domain. An empty range yields an empty array.
 *
 * Conventions: The exponent is negative and the factor is one,
 *              i.e., output[j] := sum[ exp(-2 PI i j k / N) input[k] ].
 */
Complex!F[] fft(F = double, R)(R range) const
if (isFloatingPoint!F && isRandomAccessRange!R)
{
    enforceSize(range);

    // Nothing to transform: hand back an empty result without allocating.
    if (range.length == 0)
    {
        return null;
    }

    // Every element is written by the transform, so the result buffer does
    // not need to be initialized first.
    auto result = uninitializedArray!(Complex!F[])(range.length);
    fft(range, result);
    return result;
}
3673181254a7Smrg
/**Same as the overload, but allows for the results to be stored in a user-
 * provided buffer. The buffer must be of the same length as range, must be
 * a random-access range, must have slicing, and must contain elements that are
 * complex-like. This means that they must have a .re and a .im member or
 * property that can be both read and written and are floating point numbers.
 */
void fft(Ret, R)(R range, Ret buf) const
if (isRandomAccessRange!Ret && isComplexLike!(ElementType!Ret) && hasSlicing!Ret)
{
    assert(buf.length == range.length);
    enforceSize(range);

    // Lengths 0, 1 and 2 are handled directly; the general machinery
    // requires at least four elements.
    if (range.length == 0)
    {
        return;
    }
    if (range.length == 1)
    {
        buf[0] = range[0];
        return;
    }
    if (range.length == 2)
    {
        slowFourier2(range, buf);
        return;
    }

    alias E = ElementType!R;
    static if (is(E : real))
    {
        // Real-valued input takes the specialized implementation.
        fftImplPureReal(range, buf);
    }
    else static if (is(R : Stride!R))
    {
        fftImpl(range, buf);
    }
    else
    {
        // Wrap the input in a unit-step Stride, which fftImpl expects.
        fftImpl(Stride!R(range, 1), buf);
    }
}
3716181254a7Smrg
/**
 * Computes the inverse Fourier transform of a range. The range must be a
 * random access range with slicing, have a length equal to the size
 * provided at construction of this object, and contain elements that are
 * either of type std.complex.Complex or have essentially
 * the same compile-time interface.
 *
 * Returns: The time-domain signal. An empty `range` yields a `null` array.
 *
 * Conventions: The exponent is positive and the factor is 1/N, i.e.,
 *              output[j] := (1 / N) sum[ exp(+2 PI i j k / N) input[k] ].
 */
Complex!F[] inverseFft(F = double, R)(R range) const
if (isRandomAccessRange!R && isComplexLike!(ElementType!R) && isFloatingPoint!F)
{
    enforceSize(range);

    // Nothing to transform: hand back an empty (null) array.
    if (range.length == 0)
        return null;

    // Don't waste time initializing the memory for the result.
    auto signal = uninitializedArray!(Complex!F[])(range.length);

    // Delegate the actual work to the buffer-taking overload.
    inverseFft(range, signal);
    return signal;
}
3745181254a7Smrg
/**
 * Inverse FFT that allows a user-supplied buffer to be provided. The buffer
 * must be a random access range with slicing, and its elements
 * must be some complex-like type.
 *
 * The inverse is obtained from the forward transform: the real and
 * imaginary parts of every input element are swapped, the forward FFT is
 * run into `buf`, and then each output element has its parts swapped back
 * while being scaled by `1 / buf.length`.
 */
void inverseFft(Ret, R)(R range, Ret buf) const
if (isRandomAccessRange!Ret && isComplexLike!(ElementType!Ret) && hasSlicing!Ret)
{
    enforceSize(range);

    // Forward transform of the input with re/im exchanged.
    fft(map!swapRealImag(range), buf);

    // Exchange re/im again on the way out and apply the 1/N factor.
    immutable scale = 1.0 / buf.length;
    foreach (ref value; buf)
    {
        immutable scaledRe = value.re * scale;
        value.re = value.im * scale;
        value.im = scaledRe;
    }
}
3767181254a7Smrg }
3768181254a7Smrg
// This mixin creates an Fft object in the scope it's mixed into such that all
// memory owned by the object is deterministically destroyed at the end of that
// scope.
//
// The enclosing scope must provide a variable named `range`; the mixin
// introduces `lookupBuf` (a malloc'ed lookup_t buffer of 2 * range.length
// elements, freed on scope exit) and `fftObj` (the scoped Fft built on it).
//
// malloc(0) is allowed to return null, so a null pointer is only reported as
// an out-of-memory condition when a non-empty buffer was actually requested.
// For an empty `range` the buffer is an empty slice either way, and free()
// accepts a null pointer.
private enum string MakeLocalFft = q{
    import core.stdc.stdlib;
    import core.exception : onOutOfMemoryError;

    auto lookupBuf = (cast(lookup_t*) malloc(range.length * 2 * lookup_t.sizeof))
                     [0 .. 2 * range.length];
    if (!lookupBuf.ptr && range.length != 0)
        onOutOfMemoryError();

    scope(exit) free(cast(void*) lookupBuf.ptr);
    auto fftObj = scoped!Fft(lookupBuf);
};
3784181254a7Smrg
/**
 * Convenience functions that create an `Fft` object, run the FFT or inverse
 * FFT and return the result. Useful for one-off FFTs.
 *
 * Note: In addition to convenience, these functions are slightly more
 *       efficient than manually creating an Fft object for a single use,
 *       as the Fft object is deterministically destroyed before these
 *       functions return.
 */
Complex!F[] fft(F = double, R)(R range)
{
    // Declares the scope-local `fftObj` used below.
    mixin(MakeLocalFft);

    auto spectrum = fftObj.fft!(F, R)(range);
    return spectrum;
}
3798181254a7Smrg
/// ditto
void fft(Ret, R)(R range, Ret buf)
{
    // Declares the scope-local `fftObj` used below.
    mixin(MakeLocalFft);

    fftObj.fft!(Ret, R)(range, buf);
}
3805181254a7Smrg
/// ditto
Complex!F[] inverseFft(F = double, R)(R range)
{
    // Declares the scope-local `fftObj` used below.
    mixin(MakeLocalFft);

    auto signal = fftObj.inverseFft!(F, R)(range);
    return signal;
}
3812181254a7Smrg
/// ditto
void inverseFft(Ret, R)(R range, Ret buf)
{
    // Declares the scope-local `fftObj` used below.
    mixin(MakeLocalFft);

    fftObj.inverseFft!(Ret, R)(range, buf);
}
3819181254a7Smrg
@system unittest
{
    import std.algorithm;
    import std.conv;
    import std.range;

    // Reference values were obtained from R and Octave.

    // Forward transform of purely real integer input.
    auto realIn = [1,2,3,4,5,6,7,8];
    auto realOut = fft(realIn);
    assert(isClose(map!"a.re"(realOut),
        [36.0, -4, -4, -4, -4, -4, -4, -4], 1e-4));
    assert(isClose(map!"a.im"(realOut),
        [0, 9.6568, 4, 1.6568, 0, -1.6568, -4, -9.6568], 1e-4));

    // The same data fed through a reversed (non-array) range.
    auto reversedOut = fft(retro(realIn));
    assert(isClose(map!"a.re"(reversedOut),
        [36.0, 4, 4, 4, 4, 4, 4, 4], 1e-4));
    assert(isClose(map!"a.im"(reversedOut),
        [0, -9.6568, -4, -1.6568, 0, 1.6568, 4, 9.6568], 1e-4));

    // float input must agree with the integer input.
    auto floatOut = fft(to!(float[])(realIn));
    assert(isClose(map!"a.re"(realOut), map!"a.re"(floatOut)));
    assert(isClose(map!"a.im"(realOut), map!"a.im"(floatOut)));

    // Forward transform of complex input.
    alias Cf = Complex!float;
    auto complexIn = [Cf(1,2), Cf(3,4), Cf(5,6), Cf(7,8), Cf(9,10),
        Cf(11,12), Cf(13,14), Cf(15,16)];
    auto complexOut = fft(complexIn);
    assert(isClose(map!"a.re"(complexOut),
        [64.0, -27.3137, -16, -11.3137, -8, -4.6862, 0, 11.3137], 1e-4));
    assert(isClose(map!"a.im"(complexOut),
        [72, 11.3137, 0, -4.686, -8, -11.3137, -16, -27.3137], 1e-4));

    // Round trips through the inverse transform.
    auto realBack = inverseFft(realOut);
    assert(isClose(map!"a.re"(realBack), realIn, 1e-6));
    assert(reduce!max(map!"a.im"(realBack)) < 1e-10);

    auto complexBack = inverseFft(complexOut);
    assert(isClose(map!"a.re"(complexBack), map!"a.re"(complexIn)));
    assert(isClose(map!"a.im"(complexBack), map!"a.im"(complexIn)));

    // FFTs of size 0, 1 and 2 are handled as special cases.  Test them here.
    ushort[] noData;
    assert(fft(noData) == null);
    assert(inverseFft(fft(noData)) == null);

    real[] single = [4.5L];
    auto singleOut = fft(single);
    assert(singleOut.length == 1);
    assert(singleOut[0].re == 4.5L);
    assert(singleOut[0].im == 0);

    auto singleBack = inverseFft(singleOut);
    assert(singleBack.length == 1);
    assert(isClose(singleBack[0].re, 4.5));
    assert(isClose(singleBack[0].im, 0, 0.0, 1e-10));

    long[2] pair = [8, 4];
    auto pairOut = fft(pair[]);
    assert(pairOut.length == 2);
    assert(isClose(pairOut[0].re, 12));
    assert(isClose(pairOut[0].im, 0, 0.0, 1e-10));
    assert(isClose(pairOut[1].re, 4));
    assert(isClose(pairOut[1].im, 0, 0.0, 1e-10));

    auto pairBack = inverseFft(pairOut);
    assert(isClose(pairBack[0].re, 8));
    assert(isClose(pairBack[0].im, 0, 0.0, 1e-10));
    assert(isClose(pairBack[1].re, 4));
    assert(isClose(pairBack[1].im, 0, 0.0, 1e-10));
}
3889181254a7Smrg
// Builds a new C whose first constructor argument is input.im and whose
// second is input.re, i.e. the real and imaginary parts exchanged.  This is
// useful for inverse FFTs.
C swapRealImag(C)(C input)
{
    auto exchanged = C(input.im, input.re);
    return exchanged;
}
3896181254a7Smrg
/** This function transforms `decimal` value into a value in the factorial number
system stored in `fac`.

The digits are written most significant first. With `n` being the returned
digit count, the stored number reads as:
$(D fac[0] * (n-1)! + fac[1] * (n-2)! + ... + fac[n-1] * 0!)

Elements of `fac` at index `n` and beyond are left untouched.

Params:
    decimal = The decimal value to convert into the factorial number system.
    fac = The array to store the factorial number. The array is of size 21 as
          `ulong.max` requires 21 digits in the factorial number system.
Returns:
    The number of digits of the factorial number stored in `fac`
    (at least 1; zero is stored as the single digit 0).
*/
size_t decimalToFactorial(ulong decimal, ref ubyte[21] fac)
        @safe pure nothrow @nogc
{
    import std.algorithm.mutation : reverse;

    // Peel off digits least significant first: the digit for k! is the
    // remainder of dividing by (k + 1), so the divisor starts at 1.
    size_t nDigits = 0;
    ulong divisor = 1;
    while (decimal != 0)
    {
        fac[nDigits] = cast(ubyte)(decimal % divisor);
        decimal /= divisor;
        ++nDigits;
        ++divisor;
    }

    // Zero never enters the loop; represent it as one zero digit.
    if (nDigits == 0)
    {
        fac[0] = 0;
        nDigits = 1;
    }

    // Put the most significant digit first.
    reverse(fac[0 .. nDigits]);

    // The 0! digit (now last) is a remainder modulo 1, hence always zero.
    assert(fac[nDigits - 1] == 0);

    return nDigits;
}
3936*b1e83836Smrg
///
@safe pure @nogc unittest
{
    ubyte[21] fac;
    size_t idx = decimalToFactorial(2982, fac);

    // 2982 = 4 * 6! + 0 * 5! + 4 * 4! + 1 * 3! + 0 * 2! + 0 * 1! + 0 * 0!
    static immutable ubyte[7] expected = [4, 0, 4, 1, 0, 0, 0];
    foreach (pos, digit; expected)
    {
        assert(fac[pos] == digit);
    }
}
3951*b1e83836Smrg
@safe pure unittest
{
    ubyte[21] digits;

    // Zero is represented by a single zero digit.
    size_t count = decimalToFactorial(0UL, digits);
    assert(count == 1);
    assert(digits[0] == 0);

    // The largest input fills all 21 digits.
    digits[] = 0;
    count = decimalToFactorial(ulong.max, digits);
    assert(count == 21);
    auto expected = [7, 11, 12, 4, 3, 15, 3, 5, 3, 5, 0, 8, 3, 5, 0, 0, 0, 2, 1, 1, 0];
    foreach (pos, digit; digits[0 .. 21])
    {
        assert(digit == expected[pos]);
    }

    // A mid-sized value only uses a prefix of the buffer.
    digits[] = 0;
    count = decimalToFactorial(2982, digits);

    assert(count == 7);
    expected = [4, 0, 4, 1, 0, 0, 0];
    foreach (pos, digit; digits[0 .. count])
    {
        assert(digit == expected[pos]);
    }
}
3979*b1e83836Smrg
3980181254a7Smrg private:
// Strided view over a sliceable random-access range: element i of the view is
// element i * nSteps of the wrapped range.
//
// A stock strided range couldn't be used here because its stride length isn't
// modifiable on the fly, and because this range has grown some performance
// hacks for powers of 2.
struct Stride(R)
{
    import core.bitop : bsf;
    Unqual!R range;
    size_t _nSteps;
    size_t _length;
    alias E = ElementType!(R);

    // Wraps `range`, visiting every nStepsIn-th element.
    this(R range, size_t nStepsIn)
    {
        this.range = range;
        _nSteps = nStepsIn;

        // Number of strided elements, rounded up.
        _length = (range.length + _nSteps - 1) / _nSteps;
    }

    // Number of elements visible through the stride.
    size_t length() const @property
    {
        return _length;
    }

    bool empty() const @property
    {
        return _length == 0;
    }

    E front() @property
    {
        return range[0];
    }

    // Advances by one full stride; if fewer than a stride's worth of
    // elements remain, the view becomes empty.
    void popFront()
    {
        if (range.length < _nSteps)
        {
            range = range[0 .. 0];
            _length = 0;
            return;
        }

        range = range[_nSteps .. range.length];
        _length--;
    }

    typeof(this) save() @property
    {
        auto copy = this;
        copy.range = copy.range.save;
        return copy;
    }

    E opIndex(size_t index)
    {
        return range[index * _nSteps];
    }

    // Pops half the range's stride.  The cached length is left as it is.
    void popHalf()
    {
        range = range[_nSteps / 2 .. range.length];
    }

    // Current stride.
    size_t nSteps() const @property
    {
        return _nSteps;
    }

    // Sets a new stride and recomputes the length.
    size_t nSteps(size_t newVal) @property
    {
        _nSteps = newVal;

        // Using >> bsf(nSteps) is a few cycles faster than / nSteps.
        // The shift equals the division only when the stride is a power of 2.
        _length = (range.length + _nSteps - 1) >> bsf(_nSteps);
        return newVal;
    }

    // Doubles the stride and halves the cached length.
    void doubleSteps()
    {
        _nSteps *= 2;
        _length /= 2;
    }
}
4066181254a7Smrg
// Hard-coded base case for FFT of size 2: the outputs are the sum and the
// difference of the two inputs.  This is actually a TON faster than using a
// generic slow DFT and seems to be the best base case.  (Size 1 can be coded
// inline as buf[0] = range[0]).
void slowFourier2(Ret, R)(R range, Ret buf)
{
    assert(range.length == 2);
    assert(buf.length == 2);

    auto first = range[0];
    auto second = range[1];

    buf[0] = first + second;
    buf[1] = first - second;
}
4077181254a7Smrg
// Hard-coded base case for FFT of size 4, written out term by term with the
// imaginary unit taken from the output element type.  Doesn't work as well as
// the size 2 case.
void slowFourier4(Ret, R)(R range, Ret buf)
{
    alias C = ElementType!Ret;

    assert(range.length == 4);
    assert(buf.length == 4);

    // The imaginary unit in the output's element type.
    auto unitI = C(0, 1);

    buf[0] = range[0] + range[1] + range[2] + range[3];
    buf[1] = range[0] - range[1] * unitI - range[2] + range[3] * unitI;
    buf[2] = range[0] - range[1] + range[2] - range[3];
    buf[3] = range[0] + range[1] * unitI - range[2] - range[3] * unitI;
}
4091181254a7Smrg
// Keeps only the highest set bit of num, which for a positive value is the
// largest power of two that does not exceed it.  Zero has no set bit and is
// returned unchanged.
N roundDownToPowerOf2(N)(N num)
if (isScalarType!N && !isFloatingPoint!N)
{
    import core.bitop : bsr;

    // bsr's result is undefined for a zero argument, which would make the
    // shift count below undefined as well, so handle zero up front.
    if (num == 0)
        return num;

    return num & (cast(N) 1 << bsr(num));
}

@safe unittest
{
    assert(roundDownToPowerOf2(7) == 4);
    assert(roundDownToPowerOf2(4) == 4);
    assert(roundDownToPowerOf2(1) == 1);
    assert(roundDownToPowerOf2(0) == 0);
}
4104181254a7Smrg
// True when expressions T.init.re and T.init.im both have a type, i.e. when T
// exposes `re` and `im` members or properties.
enum bool isComplexLike(T) = is(typeof(T.init.re)) && is(typeof(T.init.im));

@safe unittest
{
    // A library complex number qualifies; a plain integer does not.
    static assert( isComplexLike!(Complex!double));
    static assert(!isComplexLike!(uint));
}
4116