Enabling multi‐GPU (2× GPUs) training in BayesFlow

Hi BayesFlow team,

I’m running BayesFlow v1’s train_online on a single GPU but would like to leverage two GPUs (for both the simulation-heavy phase and neural network training).

Does the Trainer offer built-in data-parallel or model-parallel multi-GPU support?

If not, what’s the recommended way to wrap the simulation loop and amortizer training to run across multiple GPUs?

Any example code for multi-GPU BayesFlow training would be greatly appreciated.

Hi Noura, I am not sure why the post was stuck and needed extra steps, but I have fixed it now. Sorry for the delay. Unfortunately, the Trainer in v1 does not have out-of-the-box support for multi-GPU training. Multi-GPU training is something we are currently working on for v2, so I am linking @LarsKue.


v2 should already support both model and data parallelism to some degree, using JAX. I will add an example once I find the time.
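Until an official example lands, the underlying mechanism can be illustrated with plain JAX. This is a generic data-parallel sketch, not BayesFlow-specific; `train_step` and its toy update rule are made up for illustration:

```python
import jax
import jax.numpy as jnp

# jax.pmap replicates the function across all visible devices (GPUs if
# present, otherwise the CPU) and maps over the leading batch axis.
@jax.pmap
def train_step(params, x):
    # toy "gradient step" standing in for a real network update
    return params - 0.1 * jnp.mean(x, axis=0)

n_devices = jax.local_device_count()
params = jnp.zeros((n_devices, 4))   # one parameter replica per device
x = jnp.ones((n_devices, 32, 4))     # one mini-batch shard per device
new_params = train_step(params, x)
print(new_params.shape)              # (n_devices, 4)
```

Each device runs `train_step` on its own shard; averaging the resulting replicas (e.g. via `jax.lax.pmean` inside the step) is the usual way to keep them synchronized.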


Thank you, Lars. I’m wondering whether v2 allows GPU-accelerated simulators written in JAX, and if so, whether there are any examples I can follow?

Hi Lars,

I am also interested in whether BayesFlow supports custom GPU-accelerated samplers written in JAX / TensorFlow / Keras / etc., and whether there is any example of managing JAX keys for pseudo-randomness.

Thanks,

Dan

Hi Noura and Dan,

we currently encourage using CPU-side simulators, since this frees up the GPU for training. The best way to implement a GPU-accelerated simulator at the moment is to run the simulations on the GPU using your deep learning backend of choice, then move the tensors to NumPy before returning them to BayesFlow, in the following fashion:

import torch

def gpu_simulator():
    # simulate on the GPU ...
    x = torch.randn((100,), device="cuda")
    # ... then move the result to the CPU and convert to NumPy
    # before returning it to BayesFlow
    return {"x": x.cpu().numpy()}

I understand that this is suboptimal, as it creates a memory-transfer bottleneck in the simulator. If this is a significant issue for you, we can look into better support for GPU-accelerated simulators.
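The same pattern carries over to a JAX simulator, which also answers the question about managing JAX keys for pseudo-randomness. A hedged sketch; the function names and model here are illustrative, not BayesFlow API:

```python
import jax
import numpy as np

def jax_simulator(key):
    # Split the key so each random quantity gets independent randomness,
    # and return a fresh key for the next call (never reuse a key).
    key, theta_key, x_key = jax.random.split(key, 3)
    # Toy batched model: 100 parameter draws and matching observations.
    theta = jax.random.normal(theta_key, (100, 2))
    x = theta.sum(axis=-1, keepdims=True) + jax.random.normal(x_key, (100, 1))
    # Convert to NumPy before handing the batch to BayesFlow.
    return key, {"theta": np.asarray(theta), "x": np.asarray(x)}

key = jax.random.PRNGKey(0)
key, batch = jax_simulator(key)
```

Threading the returned key into the next call keeps the stream of simulations reproducible for a given seed.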

Hi Lars,

Thank you for your reply. Can you please help me with an example of creating a BayesFlow simulator from prior and likelihood functions that have batched outputs? I need this because the simulation is quite computationally heavy. I suspect I need the LambdaSimulator class, but it’s not clear to me how I should wrap the two functions and pass them to BayesFlow. Thanks!

Hi Dan, for heavy simulators, the recommended approach is to pre-simulate a large number of cases and train with the .fit_disk method of a BasicWorkflow.
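As a starting point for the pre-simulation step, batched prior and likelihood functions can be plain NumPy and composed by hand before saving shards to disk. A hedged sketch under those assumptions; the function names, toy model, and file name are illustrative, not BayesFlow API:

```python
import numpy as np

def batched_prior(batch_size, rng):
    # one parameter draw per row: shape (batch_size, 2)
    return rng.normal(size=(batch_size, 2))

def batched_likelihood(theta, rng):
    # vectorized over the batch axis: shape (batch_size, 10)
    noise = rng.normal(size=(theta.shape[0], 10))
    return theta.sum(axis=-1, keepdims=True) + noise

rng = np.random.default_rng(42)
theta = batched_prior(5000, rng)
x = batched_likelihood(theta, rng)

# Save one shard; repeat for as many shards as needed, then point
# disk-based training at the directory containing them.
np.savez("shard_000.npz", theta=theta, x=x)
```

Because both functions operate on a whole batch at once, the heavy simulation cost is paid once up front rather than inside the training loop.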