Edge behavior in jax.scipy.special.betainc
#21900
Hi @mdhaber
JAX typically uses single-precision floating-point numbers for calculations, while SciPy defaults to double precision. This difference in precision can lead to slightly different results, especially when working with very small numbers. If double precision is enabled in JAX, then JAX yields results consistent with SciPy even for very small numbers:

```python
import jax
jax.config.update('jax_enable_x64', True)

import matplotlib.pyplot as plt
import numpy as np
from scipy.special import betainc as betainc_scipy
import jax.numpy as xp
from jax.scipy.special import betainc as betainc_jax

a = np.logspace(-40, -1, 300)
b = 1
x = 0.25

plt.loglog(a, betainc_scipy(a, b, x), label='scipy')
plt.loglog(a, betainc_jax(xp.asarray(a), b, x), label='jax')
plt.xlabel('a')
plt.ylabel('betainc(a, 1, 0.25)')
plt.legend()
```

Please find the gist for reference. Thank you.
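For `b = 1` the regularized incomplete beta function has the closed form `I_x(a, 1) = x**a`, which gives an independent analytic reference for the curves in the plot. A minimal check against that identity (a sketch using only NumPy/SciPy; the tolerance is an assumption, not a documented accuracy guarantee):

```python
import numpy as np
from scipy.special import betainc

# Closed form for b == 1: I_x(a, 1) = x**a
a = np.logspace(-40, -1, 300)
x = 0.25
ref = x**a                  # analytic reference
out = betainc(a, 1.0, x)    # SciPy's evaluation (double precision internally)

# SciPy tracks the closed form closely across the whole range of a
print(np.max(np.abs(out - ref)))
```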
Thanks! Although this actually came up in the context of 32-bit calculations. The definitions should have been:

```python
a = np.logspace(-40, -1, 300, dtype=np.float32)
b = np.float32(1.)
x = np.float32(0.25)
```

and the plot looks the same. SciPy's outputs are … [plot omitted]. To zoom in:

```python
import matplotlib.pyplot as plt
import numpy as np
from scipy.special import betainc as betainc_scipy
import jax.numpy as xp
from jax.scipy.special import betainc as betainc_jax

a0 = np.finfo(np.float32).smallest_normal
b = np.float32(1.)
x = np.float32(0.25)
factor = np.float32(10)
a = np.logspace(np.log10(a0), np.log10(a0*factor), 300, dtype=np.float32)

plt.loglog(a, betainc_scipy(a, b, x), label='scipy')
plt.loglog(a, betainc_jax(xp.asarray(a), b, x), label='jax')
plt.xlabel('a')
plt.ylabel('betainc(a, 1, 0.25)')
plt.legend()
```
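A sanity check on the zoomed-in region (a sketch, assuming only that SciPy's double-precision result is trustworthy): for `a` near float32's smallest normal, the true value is `1 - a*ln(1/x)`, which is within ~1e-38 of 1, so the correctly rounded answer in both float32 and float64 is indistinguishable from 1.0.

```python
import numpy as np
from scipy.special import betainc

a0 = np.float64(np.finfo(np.float32).smallest_normal)  # ~1.1755e-38
val = betainc(a0, 1.0, 0.25)  # true value is 1 - a0*ln(4) ~= 1 - 1.63e-38
print(val)                    # indistinguishable from 1.0 in double precision
```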
Hi @mdhaber
IIUC, according to scipy/scipy#8495 (comment), SciPy does all the internal (low-level C) calculations in float64. Thank you.
Whatever the reasons for SciPy to use float64 internally (one practical reason could be that no float32 implementation is available to SciPy, for instance), correctly evaluating a function in float32 requires an algorithm that properly handles overflows, underflows, and cancellations. Using higher precision internally is a typical cheap trick that avoids paying attention to these floating-point errors and keeps the function's algorithm simple.
JAX's implementation is here, and mentions that it's based on http://dlmf.nist.gov/8.17.E23 (lines 182 to 190 at commit 2b728d5):
But that comment is about …. I confirmed with @steppi that SciPy now uses Boost's …. Here is where …, here is where …, and Boost's ….
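A quick way to see the effect of SciPy's double-precision internals on float32 inputs (a sketch; that SciPy upcasts internally is based on the discussion above, not on any documented guarantee): feeding float32 arguments still yields the correctly rounded answer near float32's smallest normal.

```python
import numpy as np
from scipy.special import betainc

a = np.float32(2e-38)  # a normal float32, just above the smallest normal
out = betainc(a, np.float32(1.0), np.float32(0.25))

# The exact value is 1 - a*ln(4) ~= 1 - 2.8e-38, which rounds to 1.0
# in both float32 and float64; SciPy returns the correctly rounded result.
print(out)
```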
Description
`jax.scipy.special.betainc` seems to have trouble with very small values of the parameter `a`, at least for certain values of `b` and `x`. I know that it is difficult to guarantee accuracy to machine precision for all possible combinations of inputs : ) Just thought I'd point out this problem spot since it came up in SciPy testing (scipy/scipy#20963).
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.26
jaxlib: 0.4.26
numpy: 1.25.2
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='e901fac133dc', release='6.1.85+', version='#1 SMP PREEMPT_DYNAMIC Sun Apr 28 14:29:16 UTC 2024', machine='x86_64')