1 /**
2 * Transcendental bonus functions.
3 *
4 * Copyright: Copyright Guillaumr Piolat 2016-2020.
5 *            Copyright (C) 2007  Julien Pommier
6 * License:   $(LINK2 http://www.boost.org/LICENSE_1_0.txt, Boost License 1.0)
7 */
8 module inteli.math;
9 
10 /* Copyright (C) 2007  Julien Pommier
11 
12   This software is provided 'as-is', without any express or implied
13   warranty.  In no event will the authors be held liable for any damages
14   arising from the use of this software.
15 
16   Permission is granted to anyone to use this software for any purpose,
17   including commercial applications, and to alter it and redistribute it
18   freely, subject to the following restrictions:
19 
20   1. The origin of this software must not be misrepresented; you must not
21      claim that you wrote the original software. If you use this software
22      in a product, an acknowledgment in the product documentation would be
23      appreciated but is not required.
24   2. Altered source versions must be plainly marked as such, and must not be
25      misrepresented as being the original software.
26   3. This notice may not be removed or altered from any source distribution.
27 
28   (this is the zlib license)
29 */
30 import inteli.emmintrin;
31 import inteli.internals;
32 
33 nothrow @nogc:
34 
35 /// Natural `log` computed for a single 32-bit float.
36 /// This is an approximation, valid up to approximately -119dB of accuracy, on the range -inf..50
37 /// IMPORTANT: NaN, zero, or infinity input not supported properly. x must be > 0 and finite.
38 // #BONUS
39 float _mm_log_ss(float v) pure @safe
40 {
41     __m128 r = _mm_log_ps(_mm_set1_ps(v));
42     return r.array[0];
43 }
44 
45 /// Natural logarithm computed for 4 simultaneous float.
46 /// This is an approximation, valid up to approximately -119dB of accuracy, on the range -inf..50
47 /// IMPORTANT: NaN, zero, or infinity input not supported properly. x must be > 0 and finite.
48 // #BONUS
49 __m128 _mm_log_ps(__m128 x) pure @safe
50 {
51     static immutable __m128i _psi_inv_mant_mask = [~0x7f800000, ~0x7f800000, ~0x7f800000, ~0x7f800000];
52     static immutable __m128 _ps_cephes_SQRTHF = [0.707106781186547524, 0.707106781186547524, 0.707106781186547524, 0.707106781186547524];
53     static immutable __m128 _ps_cephes_log_p0 = [7.0376836292E-2, 7.0376836292E-2, 7.0376836292E-2, 7.0376836292E-2];
54     static immutable __m128 _ps_cephes_log_p1 = [- 1.1514610310E-1, - 1.1514610310E-1, - 1.1514610310E-1, - 1.1514610310E-1];
55     static immutable __m128 _ps_cephes_log_p2 = [1.1676998740E-1, 1.1676998740E-1, 1.1676998740E-1, 1.1676998740E-1];
56     static immutable __m128 _ps_cephes_log_p3 = [- 1.2420140846E-1, - 1.2420140846E-1, - 1.2420140846E-1, - 1.2420140846E-1];
57     static immutable __m128 _ps_cephes_log_p4 = [+ 1.4249322787E-1, + 1.4249322787E-1, + 1.4249322787E-1, + 1.4249322787E-1];
58     static immutable __m128 _ps_cephes_log_p5 = [- 1.6668057665E-1, - 1.6668057665E-1, - 1.6668057665E-1, - 1.6668057665E-1];
59     static immutable __m128 _ps_cephes_log_p6 = [+ 2.0000714765E-1, + 2.0000714765E-1, + 2.0000714765E-1, + 2.0000714765E-1];
60     static immutable __m128 _ps_cephes_log_p7 = [- 2.4999993993E-1, - 2.4999993993E-1, - 2.4999993993E-1, - 2.4999993993E-1];
61     static immutable __m128 _ps_cephes_log_p8 = [+ 3.3333331174E-1, + 3.3333331174E-1, + 3.3333331174E-1, + 3.3333331174E-1];
62     static immutable __m128 _ps_cephes_log_q1 = [-2.12194440e-4, -2.12194440e-4, -2.12194440e-4, -2.12194440e-4];
63     static immutable __m128 _ps_cephes_log_q2 = [0.693359375, 0.693359375, 0.693359375, 0.693359375];
64 
65     /* the smallest non denormalized float number */
66     static immutable __m128i _psi_min_norm_pos  = [0x00800000,   0x00800000,   0x00800000, 0x00800000];
67 
68     __m128i emm0;
69     __m128 one = _ps_1;
70     __m128 invalid_mask = _mm_cmple_ps(x, _mm_setzero_ps());
71     x = _mm_max_ps(x, cast(__m128)_psi_min_norm_pos);  /* cut off denormalized stuff */
72     emm0 = _mm_srli_epi32(cast(__m128i)x, 23);
73 
74     /* keep only the fractional part */
75     x = _mm_and_ps(x, cast(__m128)_psi_inv_mant_mask);
76     x = _mm_or_ps(x, _ps_0p5);
77 
78     emm0 = _mm_sub_epi32(emm0, _pi32_0x7f);
79     __m128 e = _mm_cvtepi32_ps(emm0);
80     e += one;
81     __m128 mask = _mm_cmplt_ps(x, _ps_cephes_SQRTHF);
82     __m128 tmp = _mm_and_ps(x, mask);
83     x -= one;
84     e -= _mm_and_ps(one, mask);
85     x += tmp;
86     __m128 z = x * x;
87     __m128 y = _ps_cephes_log_p0;
88     y = _mm_mul_ps(y, x);
89     y = _mm_add_ps(y, _ps_cephes_log_p1);
90     y = _mm_mul_ps(y, x);
91     y = _mm_add_ps(y, _ps_cephes_log_p2);
92     y = _mm_mul_ps(y, x);
93     y = _mm_add_ps(y, _ps_cephes_log_p3);
94     y = _mm_mul_ps(y, x);
95     y = _mm_add_ps(y, _ps_cephes_log_p4);
96     y = _mm_mul_ps(y, x);
97     y = _mm_add_ps(y, _ps_cephes_log_p5);
98     y = _mm_mul_ps(y, x);
99     y = _mm_add_ps(y, _ps_cephes_log_p6);
100     y = _mm_mul_ps(y, x);
101     y = _mm_add_ps(y, _ps_cephes_log_p7);
102     y = _mm_mul_ps(y, x);
103     y = _mm_add_ps(y, _ps_cephes_log_p8);
104     y = _mm_mul_ps(y, x);
105     y = _mm_mul_ps(y, z);
106     tmp = _mm_mul_ps(e, _ps_cephes_log_q1);
107     y = _mm_add_ps(y, tmp);
108     tmp = _mm_mul_ps(z, _ps_0p5);
109     y = _mm_sub_ps(y, tmp);
110     tmp = _mm_mul_ps(e, _ps_cephes_log_q2);
111     x = _mm_add_ps(x, y);
112     x = _mm_add_ps(x, tmp);
113     x = _mm_or_ps(x, invalid_mask); // negative arg will be NAN
114     return x;
115 }
116 
117 /// Natural `exp` computed for a single float.
118 /// This is an approximation, valid up to approximately -109dB of accuracy
119 /// IMPORTANT: NaN input not supported.
120 // #BONUS
121 float _mm_exp_ss(float v) pure @safe
122 {
123     __m128 r = _mm_exp_ps(_mm_set1_ps(v));
124     return r.array[0];
125 }
126 
127 /// Natural `exp` computed for 4 simultaneous float in `x`.
128 /// This is an approximation, valid up to approximately -109dB of accuracy
129 /// IMPORTANT: NaN input not supported.
130 // #BONUS
131 __m128 _mm_exp_ps(__m128 x) pure @safe
132 {
133     static immutable __m128 _ps_exp_hi         = [88.3762626647949f, 88.3762626647949f, 88.3762626647949f, 88.3762626647949f];
134     static immutable __m128 _ps_exp_lo         = [-88.3762626647949f, -88.3762626647949f, -88.3762626647949f, -88.3762626647949f];
135     static immutable __m128 _ps_cephes_LOG2EF  = [1.44269504088896341, 1.44269504088896341, 1.44269504088896341, 1.44269504088896341];
136     static immutable __m128 _ps_cephes_exp_C1  = [0.693359375, 0.693359375, 0.693359375, 0.693359375];
137     static immutable __m128 _ps_cephes_exp_C2  = [-2.12194440e-4, -2.12194440e-4, -2.12194440e-4, -2.12194440e-4];
138     static immutable __m128 _ps_cephes_exp_p0  = [1.9875691500E-4, 1.9875691500E-4, 1.9875691500E-4, 1.9875691500E-4];
139     static immutable __m128 _ps_cephes_exp_p1  = [1.3981999507E-3, 1.3981999507E-3, 1.3981999507E-3, 1.3981999507E-3];
140     static immutable __m128 _ps_cephes_exp_p2  = [8.3334519073E-3, 8.3334519073E-3, 8.3334519073E-3, 8.3334519073E-3];
141     static immutable __m128 _ps_cephes_exp_p3  = [4.1665795894E-2, 4.1665795894E-2, 4.1665795894E-2, 4.1665795894E-2];
142     static immutable __m128 _ps_cephes_exp_p4  = [1.6666665459E-1, 1.6666665459E-1, 1.6666665459E-1, 1.6666665459E-1];
143     static immutable __m128 _ps_cephes_exp_p5  = [5.0000001201E-1, 5.0000001201E-1, 5.0000001201E-1, 5.0000001201E-1];
144 
145     __m128 tmp = _mm_setzero_ps(), fx;
146     __m128i emm0;
147     __m128 one = _ps_1;
148 
149     x = _mm_min_ps(x, _ps_exp_hi);
150     x = _mm_max_ps(x, _ps_exp_lo);
151 
152     /* express exp(x) as exp(g + n*log(2)) */
153     fx = _mm_mul_ps(x, _ps_cephes_LOG2EF);
154     fx = _mm_add_ps(fx, _ps_0p5);
155 
156     /* how to perform a floorf with SSE: just below */
157     emm0 = _mm_cvttps_epi32(fx);
158     tmp  = _mm_cvtepi32_ps(emm0);
159 
160     /* if greater, substract 1 */
161     __m128 mask = _mm_cmpgt_ps(tmp, fx);
162     mask = _mm_and_ps(mask, one);
163     fx = tmp - mask;
164 
165     tmp = _mm_mul_ps(fx, _ps_cephes_exp_C1);
166     __m128 z = _mm_mul_ps(fx, _ps_cephes_exp_C2);
167     x -= tmp;
168     x -= z;
169 
170     z = x * x;
171 
172     __m128 y = _ps_cephes_exp_p0;
173     y = _mm_mul_ps(y, x);
174     y = _mm_add_ps(y, _ps_cephes_exp_p1);
175     y = _mm_mul_ps(y, x);
176     y = _mm_add_ps(y, _ps_cephes_exp_p2);
177     y = _mm_mul_ps(y, x);
178     y = _mm_add_ps(y, _ps_cephes_exp_p3);
179     y = _mm_mul_ps(y, x);
180     y = _mm_add_ps(y, _ps_cephes_exp_p4);
181     y = _mm_mul_ps(y, x);
182     y = _mm_add_ps(y, _ps_cephes_exp_p5);
183     y = _mm_mul_ps(y, z);
184     y = _mm_add_ps(y, x);
185     y += one;
186 
187     /* build 2^n */
188     emm0 = _mm_cvttps_epi32(fx);
189 
190     emm0 = _mm_add_epi32(emm0, _pi32_0x7f);
191     emm0 = _mm_slli_epi32(emm0, 23);
192     __m128 pow2n = cast(__m128)emm0;
193     y *= pow2n;
194     return y;
195 }
196 
197 /// Computes `base^exponent` for a single 32-bit float.
198 /// This is an approximation, valid up to approximately -100dB of accuracy
199 /// IMPORTANT: NaN, zero, or infinity input not supported properly. x must be > 0 and finite.
200 // #BONUS
201 float _mm_pow_ss(float base, float exponent) pure @safe
202 {
203     __m128 r = _mm_pow_ps(_mm_set1_ps(base), _mm_set1_ps(exponent));
204     return r.array[0];
205 }
206 
207 /// Computes `base^exponent`, for 4 floats at once.
208 /// This is an approximation, valid up to approximately -100dB of accuracy
209 /// IMPORTANT: NaN, zero, or infinity input not supported properly. x must be > 0 and finite.
210 // #BONUS
211 __m128 _mm_pow_ps(__m128 base, __m128 exponents) pure @safe
212 {
213     return _mm_exp_ps(exponents * _mm_log_ps(base));
214 }
215 
216 /// Computes `base^exponent`, for 4 floats at once.
217 /// This is an approximation, valid up to approximately -100dB of accuracy
218 /// IMPORTANT: NaN, zero, or infinity input not supported properly. x must be > 0 and finite.
219 // #BONUS
220 __m128 _mm_pow_ps(__m128 base, float exponent) pure @safe
221 {
222     return _mm_exp_ps(_mm_set1_ps(exponent) * _mm_log_ps(base));
223 }
224 
225 unittest
226 {
227     import std.math;
228 
229     bool approxEquals(double groundTruth, double approx, double epsilon) pure @trusted @nogc nothrow
230     {
231         if (!isFinite(groundTruth))
232             return true; // no need to approximate where this is NaN or infinite
233 
234         if (groundTruth == 0) // the approximaton should produce zero too if needed
235         {
236             return approx == 0;
237         }
238 
239         if (approx == 0)
240         {
241             // If the approximation produces zero, the error should be below 140 dB
242             return ( abs(groundTruth) < 1e-7 );
243         }
244 
245         if ( ( abs(groundTruth / approx) - 1 ) >= epsilon)
246         {
247             import core.stdc.stdio;
248             debug printf("approxEquals (%g, %g, %g) failed\n", groundTruth, approx, epsilon);
249             debug printf("ratio is %f\n", abs(groundTruth / approx) - 1);
250         }
251 
252         return ( abs(groundTruth / approx) - 1 ) < epsilon;
253     }
254 
255     // test _mm_log_ps
256     for (double mantissa = 0.1; mantissa < 1.0; mantissa += 0.05)
257     {
258         foreach (exponent; -23..23)
259         {
260             double x = mantissa * 2.0 ^^ exponent;
261             double phobosValue = log(x);
262             __m128 v = _mm_log_ps(_mm_set1_ps(x));
263             foreach(i; 0..4)
264                 assert(approxEquals(phobosValue, v.array[i], 1.1e-6));
265         }
266     }
267 
268     // test _mm_exp_ps    
269     for (double mantissa = -1.0; mantissa < 1.0; mantissa += 0.1)
270     {
271         foreach (exponent; -23..23)
272         {
273             double x = mantissa * 2.0 ^^ exponent;
274 
275             // don't test too high numbers because they saturate FP precision pretty fast
276             if (x > 50) continue;
277 
278             double phobosValue = exp(x);
279             __m128 v = _mm_exp_ps(_mm_set1_ps(x));
280             foreach(i; 0..4)
281             {
282                 if (!approxEquals(phobosValue, v.array[i], 3.4e-6))
283                 {
284                     import core.stdc.stdio;
285                     printf("x = %f   truth = %f vs estimate = %fn", x, phobosValue, v.array[i]);
286                     assert(false);
287                 }
288             }
289         }
290     }
291 
292     // test than exp(-inf) is 0
293     {
294         __m128 R = _mm_exp_ps(_mm_set1_ps(-float.infinity));
295         float[4] correct = [0.0f, 0.0f, 0.0f, 0.0f];
296         assert(R.array == correct);
297     }
298 
299     // test log baheviour with NaN and infinities
300     // the only guarantee for now is that _mm_log_ps(negative) yield a NaN
301     {
302         __m128 R = _mm_log_ps(_mm_setr_ps(+0.0f, -0.0f, -1.0f, float.nan));
303       // DOESN'T PASS
304       //  assert(isInfinity(R[0]) && R[0] < 0); // log(+0.0f) = -infinity
305       // DOESN'T PASS
306       //  assert(isInfinity(R[1]) && R[1] < 0); // log(-0.0f) = -infinity
307         assert(isNaN(R.array[2])); // log(negative number) = NaN
308 
309         // DOESN'T PASS
310         //assert(isNaN(R[3])); // log(NaN) = NaN
311     }
312 
313 
314     // test _mm_pow_ps
315     for (double mantissa = -1.0; mantissa < 1.0; mantissa += 0.1)
316     {
317         foreach (exponent; -8..4)
318         {
319             double powExponent = mantissa * 2.0 ^^ exponent;
320 
321             for (double mantissa2 = 0.1; mantissa2 < 1.0; mantissa2 += 0.1)
322             {
323                 foreach (exponent2; -4..4)
324                 {
325                     double powBase = mantissa2 * 2.0 ^^ exponent2;
326                     double phobosValue = pow(powBase, powExponent);
327                     float fPhobos = phobosValue;
328                     if (!isFinite(fPhobos)) continue;
329                      __m128 v = _mm_pow_ps(_mm_set1_ps(powBase), _mm_set1_ps(powExponent));
330 
331                     foreach(i; 0..4)
332                     {
333                         if (!approxEquals(phobosValue, v.array[i], 1e-5))
334                         {
335                             printf("%g ^^ %g\n", powBase, powExponent);
336                             assert(false);
337                         }
338                     }
339                 }
340             }
341         }
342     }
343 }
344 
345 private:
346 
347 static immutable __m128 _ps_1   = [1.0f, 1.0f, 1.0f, 1.0f];
348 static immutable __m128 _ps_0p5 = [0.5f, 0.5f, 0.5f, 0.5f];
349 static immutable __m128i _pi32_0x7f = [0x7f, 0x7f, 0x7f, 0x7f];