Hi Noura, I am not sure why the post was stuck and needed extra steps, but it is fixed now. Sorry for the delay. Unfortunately, the Trainer in v1 does not have out-of-the-box support for multi-GPU training. Multi-GPU training is something we are currently working on for v2, so I am linking @LarsKue.
Thank you, Lars. I’m wondering whether v2 allows GPU-accelerated simulators written in JAX, and if so, whether there are any available examples I can follow?
I am also interested in whether BayesFlow supports custom GPU-accelerated samplers written in JAX / TensorFlow / Keras / etc., and whether there is any example of managing JAX keys for pseudo-randomness.
We currently encourage using CPU-side simulators, since this frees up the GPU for training. The best way to implement a GPU-accelerated simulator right now is to run the simulations on the GPU using your deep learning backend of choice, and then move the tensors to NumPy before returning them to BayesFlow, in the following fashion:
import torch

def gpu_simulator():
    # Simulate on the GPU, then move the result to host memory as a
    # NumPy array before handing it back to BayesFlow.
    x = torch.randn((100,), device="cuda")
    return {"x": x.cpu().numpy()}
I understand that this is suboptimal, as it creates a device-to-host memory-transfer bottleneck in the simulator. If this is a significant issue for you, we can look into better support for GPU-accelerated simulators.
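To address the JAX questions above: the same pattern works with a JAX simulator, and JAX's explicit PRNG keys can be threaded through the simulator by splitting a key on each call. A minimal sketch follows; `gpu_simulator_jax` and the key-threading scheme are illustrative, not part of the BayesFlow API.

```python
import numpy as np
import jax

def gpu_simulator_jax(key, batch_size=100):
    # Split the key so each call draws fresh randomness; never reuse a subkey.
    key, subkey = jax.random.split(key)
    # This runs on the GPU when one is visible to JAX.
    x = jax.random.normal(subkey, (batch_size,))
    # Move the result to host memory as a NumPy array, as in the torch example,
    # and return the advanced key so the caller can continue the key chain.
    return {"x": np.asarray(jax.device_get(x))}, key

key = jax.random.PRNGKey(0)
out1, key = gpu_simulator_jax(key)
out2, key = gpu_simulator_jax(key)
```

Because the key is split and returned on every call, `out1` and `out2` contain different draws, while rerunning the whole script with the same seed reproduces them exactly.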
Thank you for your reply. Could you help me with an example of creating a BayesFlow simulator from prior and likelihood functions that have batched outputs? This matters because the simulation is computationally heavy. I suspect I need the LambdaSimulator class, but it is not clear to me how I should wrap the two functions and pass them to BayesFlow. Thanks!
Hi Dan, for heavy simulators, the recommended approach is to pre-simulate a large number of cases and then train with the .fit_disk method of a BasicWorkflow.
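As a sketch of the pre-simulation step, one could run the heavy simulator in chunks and save each chunk of batched outputs to disk for later training. The `simulator` below is a stand-in for your own prior-plus-likelihood function, and the chunking and file layout are illustrative assumptions, not a format required by BayesFlow.

```python
import os
import numpy as np

def simulator(batch_size):
    # Placeholder for a heavy simulator with batched outputs:
    # a 2D parameter vector theta and a 3D observation x per case.
    theta = np.random.normal(size=(batch_size, 2))
    x = theta @ np.ones((2, 3)) + np.random.normal(size=(batch_size, 3))
    return {"theta": theta, "x": x}

# Pre-simulate 5 chunks of 1000 cases each and save them as .npz files.
os.makedirs("sim_data", exist_ok=True)
for i in range(5):
    batch = simulator(1000)
    np.savez(os.path.join("sim_data", f"chunk_{i:03d}.npz"), **batch)
```

The saved chunks can then be pointed at by the disk-based training method, so the expensive simulations run once rather than on every epoch.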