Segmentation fault with jax

I tried running the "Two Moons: Tackling Bimodal Posteriors" example with jax + gpu, but it failed when calling `fit_offline` and raised a segmentation fault.

I then reimported bayesflow like this:

os.environ[“JAX_PLATFORMS”] = “cpu”
os.environ[“CUDA_VISIBLE_DEVICES”] = “-1”
os.environ[“KERAS_BACKEND”] = “jax”
import bayesflow as bf

to force CPU usage, and everything works well. I then tried to check whether my JAX installation was working on the GPU:

```python
import jax

gpu_device = jax.devices("gpu")[0]
cpu_device = jax.devices("cpu")[0]

def my_function(x):
    return x.sum()

x = jax.numpy.arange(10)

x_gpu = jax.jit(my_function, device=gpu_device)(x)
print(x_gpu.device)
# gpu:0

x_cpu = jax.jit(my_function, device=cpu_device)(x)
print(x_cpu.device)
# TFRT_CPU_0
```

And everything works well.

Did anybody experience the same problem? How can I better debug the segfault?

Thanks for reaching out. I have not encountered this error yet. As we do not do anything especially low-level in JAX, the issue probably lies somewhere downstream (i.e., in JAX or Python itself). Do the other backends (TensorFlow and PyTorch) work with the GPU for you?
Regarding debugging, you can take a look at this JAX issue, which also dealt with a segmentation fault; it gives some instructions on debugging. There it turned out to be a problem with Python itself. Could you try a different Python version (e.g., via conda) and see if the problem persists across versions?
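As a general-purpose debugging aid (not specific to BayesFlow or this thread), Python's standard-library `faulthandler` module can be enabled before any heavy imports so that a segfault at least dumps a Python-level traceback, which helps narrow down which call triggered the crash:

```python
# Enable faulthandler as early as possible, before importing jax/bayesflow.
# On a segmentation fault, Python prints the traceback of every thread to
# stderr, pointing at the Python frame that was active when the crash hit.
import faulthandler

faulthandler.enable()

# Subsequent imports and training code would follow here, e.g.:
# import bayesflow as bf
```

Equivalently, you can run an unmodified script with `python -X faulthandler script.py`; the traceback only identifies the Python frame, so a crash inside a native library still needs tools like `gdb` to pinpoint.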

I switched to Python 3.10 and everything works fine. Thanks for pointing out the JAX issue. It definitely came from my configuration, even though standalone JAX code worked well (for example, the code from the JAX issue ran on the GPU with no problem).

I don't have time to investigate this bug and find the source, so for the moment I will stay with Python 3.10!

Thanks for replying!

Best
