Thanks for reaching out. I have not encountered this error yet. As we do not do anything especially low-level in JAX, the issue probably lies lower in the stack (i.e., in JAX or Python itself). Do the other backends (TensorFlow and PyTorch) work with the GPU for you?
Regarding debugging, you can take a look at this JAX issue, which also dealt with a segmentation fault and gives some instructions on debugging. There it turned out to be a problem with Python itself. Could you try a different Python version (e.g. via conda) and see if the problem persists across versions?
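In case it helps narrow things down, here is a minimal sketch of what I would run first. It is not taken from the linked issue. It uses the standard-library `faulthandler` module so that a segmentation fault prints a Python traceback, and it collects the Python version, JAX version and visible devices. The helper names `environment_report` and `enable_crash_tracebacks` are made up for illustration.

```python
import faulthandler
import platform


def environment_report():
    """Collect basic facts worth posting alongside a segfault report."""
    report = {"python": platform.python_version(), "jax": None, "devices": []}
    try:
        import jax
    except ImportError:
        # JAX is not installed in this environment; report Python only.
        return report
    report["jax"] = jax.__version__
    # jax.devices() lists the devices of the default backend (GPU if usable).
    report["devices"] = [str(d) for d in jax.devices()]
    return report


def enable_crash_tracebacks():
    """Ask Python to dump tracebacks on SIGSEGV; return True if enabled."""
    try:
        faulthandler.enable()
    except (AttributeError, OSError, RuntimeError, ValueError):
        # stderr may lack a real file descriptor (e.g. in some notebooks).
        return False
    return faulthandler.is_enabled()


enable_crash_tracebacks()
print(environment_report())
```

Running the failing script with `python -X faulthandler script.py` should have the same effect as calling `faulthandler.enable()` at the top of the script.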
I switched to Python 3.10 and everything works fine. Thanks for pointing out the JAX issue. The problem definitely came from my configuration, even though plain JAX code worked well (for example, the code from the JAX issue ran on the GPU with no problem).
I don’t have time to investigate this bug and find its source, so for the moment I will stay with Python 3.10!