
Edge behavior in jax.scipy.special.betainc #21900

Open · mdhaber opened this issue Jun 15, 2024 · 6 comments
Labels: bug (Something isn't working)


mdhaber commented Jun 15, 2024

Description

jax.scipy.special.betainc seems to have trouble with very small values of the parameter a, at least for certain values of b and x.

import matplotlib.pyplot as plt
import numpy as np
from scipy.special import betainc as betainc_scipy
import jax.numpy as xp
from jax.scipy.special import betainc as betainc_jax

a = np.logspace(-40, -1, 300)
b = 1
x = 0.25
plt.loglog(a, betainc_scipy(a, b, x), label='scipy')
plt.loglog(a, betainc_jax(xp.asarray(a), b, x), label='jax')
plt.xlabel('a')
plt.ylabel('betainc(a, 1, 0.25)')
plt.legend()
[Figure: log-log plot of betainc(a, 1, 0.25) for SciPy and JAX; the two curves disagree for small a]

I know that it is difficult to guarantee accuracy to machine precision for all possible combinations of inputs : ) Just thought I'd point out this problem spot since it came up in SciPy testing (scipy/scipy#20963).
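
Side note: for this b = 1 case there is a simple closed form to check against. Since B(x; a, 1) = x**a / a and B(a, 1) = 1/a, we have betainc(a, 1, x) = x**a, so the expected curve is easy to compute directly:

import numpy as np
from scipy.special import betainc

# For b = 1, betainc(a, 1, x) = x**a exactly:
# B(x; a, 1) = x**a / a and B(a, 1) = 1 / a.
a = np.logspace(-40, -1, 5)
print(betainc(a, 1, 0.25))  # SciPy
print(0.25 ** a)            # closed form; the two should agree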

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.26
jaxlib: 0.4.26
numpy: 1.25.2
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='e901fac133dc', release='6.1.85+', version='#1 SMP PREEMPT_DYNAMIC Sun Apr 28 14:29:16 UTC 2024', machine='x86_64')

rajasekharporeddy (Contributor) commented Jun 17, 2024

Hi @mdhaber

JAX typically uses single-precision floating-point numbers for calculations, while SciPy defaults to double precision. This difference in precision can lead to different results, especially for very small numbers. If double precision is enabled in JAX, it yields results consistent with SciPy even for very small inputs:

import jax

jax.config.update('jax_enable_x64', True)

import matplotlib.pyplot as plt
import numpy as np
from scipy.special import betainc as betainc_scipy
import jax.numpy as xp
from jax.scipy.special import betainc as betainc_jax

a = np.logspace(-40, -1, 300)
b = 1
x = 0.25
plt.loglog(a, betainc_scipy(a, b, x), label='scipy')
plt.loglog(a, betainc_jax(xp.asarray(a), b, x), label='jax')
plt.xlabel('a')
plt.ylabel('betainc(a, 1, 0.25)')
plt.legend()

[Figure: with x64 enabled, the SciPy and JAX curves coincide]

Please find the gist for reference.

Thank you.

mdhaber (Author) commented Jun 17, 2024

Thanks! This actually came up in the context of 32-bit calculations, though. The definitions should have been:

a = np.logspace(-40, -1, 300, dtype=np.float32)
b = np.float32(1.)
x = np.float32(0.25)

and the plot looks the same. SciPy's outputs are float32, so I assume that's being preserved internally, although perhaps it is converting back and forth.
In any case, I know the trouble area is toward the small end of normal numbers and extends into the subnormals, so I understand if it's not a priority. Feel free to close!


To zoom in:

import matplotlib.pyplot as plt
import numpy as np
from scipy.special import betainc as betainc_scipy
import jax.numpy as xp
from jax.scipy.special import betainc as betainc_jax

a0 = np.finfo(np.float32).smallest_normal
b = np.float32(1.)
x = np.float32(0.25)
factor = np.float32(10)
a = np.logspace(np.log10(a0), np.log10(a0*factor), 300, dtype=np.float32)
plt.loglog(a, betainc_scipy(a, b, x), label='scipy')
plt.loglog(a, betainc_jax(xp.asarray(a), b, x), label='jax')
plt.xlabel('a')
plt.ylabel('betainc(a, 1, 0.25)')
plt.legend()

[Figure: zoomed log-log plot of betainc(a, 1, 0.25) for a between the float32 smallest normal and 10× that value; SciPy and JAX still disagree over this range]

rajasekharporeddy (Contributor) commented Jun 24, 2024

Hi @mdhaber

IIUC, according to scipy/scipy#8495 (comment), SciPy does all the internal (low-level C) calculations in float64 even if the input is float32 or another type, but JAX does them in float32 itself. That might be causing this difference.
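
One hypothetical way to probe that claim: if SciPy upcast to float64 internally and rounded the result back, its float32 output should equal the float64 output cast to float32. A match would be consistent with (though not proof of) float64 internals, while a mismatch would rule it out:

import numpy as np
from scipy.special import betainc

a = np.float32(1.2e-38)  # near the float32 smallest normal
b = np.float32(1.0)
x = np.float32(0.25)

r32 = betainc(a, b, x)                                      # float32 inputs
r64 = betainc(np.float64(a), np.float64(b), np.float64(x))  # float64 inputs
print(r32, np.float32(r64), r32 == np.float32(r64))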

Thank you.

pearu (Collaborator) commented Jun 24, 2024

Whatever SciPy's reasons for using float64 internally (one practical reason could simply be that no float32 implementation is available), evaluating a function correctly in float32 requires an algorithm that properly handles overflows, underflows, and cancellations. Using higher precision is a typical cheap trick for keeping function algorithms simple without paying attention to these floating-point errors.

So I wonder where the jax.scipy.special.betainc implementation lives; it may explain the behavior observed in this issue.
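
As a toy illustration of the kind of cancellation involved (constructed for this b = 1 edge case, not taken from either library's code):

import numpy as np

# Complement of the b = 1 closed form: 1 - betainc(a, 1, x) = 1 - x**a.
a = np.float32(1e-38)
x = np.float32(0.25)
naive = np.float32(1.0) - x**a      # x**a rounds to 1.0, so this is 0.0
careful = -np.expm1(a * np.log(x))  # ~1.39e-38, correct in float32
print(naive, careful)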

jakevdp (Collaborator) commented Jun 24, 2024

JAX's implementation is here, and mentions that it's based on http://dlmf.nist.gov/8.17.E23:

jax/jax/_src/lax/special.py, lines 182–190 at 2b728d5:

def regularized_incomplete_beta_impl(a, b, x, *, dtype):
shape = a.shape
def nth_partial_betainc_numerator(iteration, a, b, x):
"""
The partial numerator for the incomplete beta function is given
here: http://dlmf.nist.gov/8.17.E23 Note that there is a special
case: the partial numerator for the first iteration is one.
"""

mdhaber (Author) commented Jun 24, 2024

> SciPy does all the internal (low-level C) calculations in float64 even if the input is float32 or another type

But that comment is about scipy.ndimage.affine_transform, not scipy.special.betainc.

I confirmed with @steppi that SciPy now uses Boost's ibeta for betainc, and the types seem to be preserved in the calculation.

Here is where betainc is defined in terms of ibeta.
https://github.com/scipy/scipy/blob/e36e728081475466d2faae65e1dfecfa2314c857/scipy/special/functions.json#L118-L123

Here is where ibeta is used for float and double instantiations of the function.
https://github.com/scipy/scipy/blob/e36e728081475466d2faae65e1dfecfa2314c857/scipy/special/boost_special_functions.h#L106-L116

and Boost's ibeta is templated:
https://beta.boost.org/doc/libs/1_68_0/libs/math/doc/html/math_toolkit/sf_beta/ibeta_function.html
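
A quick sanity check from Python that the float32 instantiation is the one being used (this only confirms the output dtype is preserved, not what happens internally):

import numpy as np
from scipy.special import betainc

out = betainc(np.float32(0.5), np.float32(1.0), np.float32(0.25))
print(out.dtype)  # float32, consistent with the float instantiation of ibeta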
