Use of JAX-based generative models - explicit RNG management

Hi all,

First time posting here - I'm aiming to use BayesFlow to speed up some existing MCMC-based workflows for cognitive models.

I already have a series of models written with `jax.numpy` and NumPyro that can be used for both likelihood evaluation and simulation. Ideally, I would simply adapt/wrap these models to use them with BayesFlow, but they require explicit RNG management.

So I would ideally like to be able to define my functions as:

```python
import jax
import numpyro.distributions as dist

def model_for_simulation(rng_key, simulate=True):
    # split the incoming key so the caller can keep threading it
    rng_key, subkey = jax.random.split(rng_key)
    sims = dict(a=dist.Normal(0, 1).sample(subkey))
    return sims, rng_key

key = jax.random.PRNGKey(3)
sims, key = model_for_simulation(key)
```

Or something like that. It seems, however, that despite using JAX as a backend, there isn't really a way to interface with RNG management in the way required to use it as a modelling language?

Any ideas? Perhaps I could create a wrapper that manages the key? My only concern is that I'm not sure whether that would affect BayesFlow's assumptions about independence across batch dimensions, etc.
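For concreteness, the kind of wrapper I have in mind would be something like this (just a sketch of the idea - the function names and the toy Gaussian draw are placeholders, not anything from BayesFlow):

```python
import jax

def make_wrapped_simulator(seed=0):
    """Hypothetical wrapper: closes over a PRNG key and splits it on each call,
    so the caller never has to thread keys explicitly."""
    key = jax.random.PRNGKey(seed)

    def simulator(batch_size):
        nonlocal key
        key, subkey = jax.random.split(key)
        # one standard-normal draw per batch element
        return dict(a=jax.random.normal(subkey, (batch_size,)))

    return simulator
```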

Cheers :slight_smile:

Hi and welcome!

Conventions regarding RNG across packages are unfortunately not very standardized, but we are trying to improve this.

BayesFlow is very permissive in terms of simulator randomness:

  • most straightforwardly, you can always generate simulations yourself with jax.random.PRNGKey, as in your snippet, and then train with fit_offline(data=sims, ...).
  • if you want to use fit_online, you need to pass a callable simulator object that takes a batch_size/shape and returns a dict. You are free to use any RNG scheme inside that callable.
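To illustrate the second option, here is a minimal sketch of such a callable simulator that owns its own JAX key and splits it on every call (the class name and the toy Gaussian draw are placeholders, not BayesFlow API - check the docs for the exact signature fit_online expects):

```python
import jax
import jax.numpy as jnp

class StatefulSimulator:
    """Toy callable simulator that manages its own PRNG key.

    Each call splits the stored key, so successive batches
    (and samples within a batch) are independent draws.
    """

    def __init__(self, seed=0):
        self.key = jax.random.PRNGKey(seed)

    def __call__(self, batch_size):
        self.key, subkey = jax.random.split(self.key)
        # one standard-normal parameter per batch element
        a = jax.random.normal(subkey, (batch_size,))
        return dict(a=a)

sim = StatefulSimulator(seed=3)
batch = sim(batch_size=8)
# batch["a"] has shape (8,)
```

Because the key is split per call, repeated calls yield independent batches, which is exactly the independence across batch dimensions you were worried about.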

Are these applicable? If not, could you please clarify what your workflow requires in terms of RNG management?

Further:

I'd be interested to hear how your efforts to speed up MCMC for your models go - feel free to reach out anytime!