Hi all,
First time here, and I'm aiming to use BayesFlow to speed up some existing MCMC-based workflows for cognitive models.
I already have a series of models written with `jax.numpy` and NumPyro that can be used for both likelihood evaluation and simulation. Ideally, I would simply adapt or wrap these models to use them with BayesFlow, but they require explicit RNG (PRNG key) management.
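For context, "explicit RNG management" in JAX means keys are ordinary values passed into every random call. There is no hidden global state, so reusing a key reproduces the same draw, and fresh randomness comes from splitting. A minimal illustration using only `jax.random`:

```python
import jax

key = jax.random.PRNGKey(0)

# Reusing the same key reproduces the same draw: JAX has no hidden global RNG state.
a = jax.random.normal(key)
b = jax.random.normal(key)

# Splitting yields new keys; after splitting, the old key should not be reused.
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey)
```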
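So, ideally, I would be able to define my functions like this:

```python
import jax
import numpyro.distributions as dist

key = jax.random.PRNGKey(3)

def model_for_simulation(rng_key=key, simulate=True):
    # Split off a fresh subkey for this draw and return the advanced key to the caller
    rng_key, subkey = jax.random.split(rng_key)
    sims = dict(a=dist.Normal(0, 1).sample(subkey))
    return sims, rng_key
```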
Or something like that. However, it seems that even with JAX as the backend, there isn't really a way to hook into RNG management in the way required to use JAX as a modelling language. Is that correct?
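On the pure-JAX side, a keyed model like the one above can already produce a batch of independent simulations: split one key into `batch_size` subkeys and `jax.vmap` the single-draw model over them. This is a minimal sketch under stated assumptions: `jax.random.normal` stands in for the NumPyro `sample` call, and `simulate_one`, `simulate_batch`, and the toy observation model are hypothetical.

```python
import jax

def simulate_one(key):
    """Hypothetical toy model: a ~ Normal(0, 1), then 10 observations x_i ~ Normal(a, 1)."""
    key_a, key_x = jax.random.split(key)
    a = jax.random.normal(key_a)
    x = a + jax.random.normal(key_x, shape=(10,))
    return {"a": a, "x": x}

def simulate_batch(key, batch_size):
    """Draw `batch_size` independent simulations by giving each one its own subkey."""
    keys = jax.random.split(key, batch_size)  # one subkey per batch element
    return jax.vmap(simulate_one)(keys)       # dict of arrays with a leading batch axis

sims = simulate_batch(jax.random.PRNGKey(3), batch_size=4)
```

Because every batch element receives its own subkey, draws along the batch axis are independent, which is the property the batch dimension is meant to carry.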
Any ideas? Perhaps I could create a wrapper that manages the key? My only concern is that I'm not sure whether that would affect BayesFlow's assumptions about independence across batch dimensions, etc.
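One way such a key-managing wrapper could look is sketched below. It is hypothetical and library-agnostic: `KeyedSimulator` and the toy `simulate_batch` are illustrative names. It also assumes the consumer simply calls a function with a batch size and expects a dict of NumPy arrays back. That calling convention is an assumption, not BayesFlow's documented interface. The wrapper owns a JAX key and advances it on every call, so no key is ever reused.

```python
import jax
import numpy as np

def simulate_batch(key, batch_size):
    # Hypothetical toy keyed model: one a ~ Normal(0, 1) per batch element
    keys = jax.random.split(key, batch_size)
    return {"a": jax.vmap(jax.random.normal)(keys)}

class KeyedSimulator:
    """Hypothetical wrapper that stores a JAX key and hands out a fresh subkey per call."""

    def __init__(self, simulate_batch, seed=0):
        self._simulate_batch = simulate_batch
        self._key = jax.random.PRNGKey(seed)

    def __call__(self, batch_size):
        # Advance the stored key first, so consecutive calls never share a key
        self._key, subkey = jax.random.split(self._key)
        out = self._simulate_batch(subkey, batch_size)
        # Convert to NumPy, assuming the caller expects NumPy arrays
        return {name: np.asarray(value) for name, value in out.items()}

sim = KeyedSimulator(simulate_batch, seed=3)
batch1 = sim(5)
batch2 = sim(5)
```

Since each call splits the stored key and each batch element gets its own subkey, draws are independent both within a batch and across calls. If the wrapper is ever copied (for example into parallel worker processes), every copy carries the same key, so copies would need different seeds. An alternative design is to draw an integer seed from a NumPy `Generator` on each call and build the key with `jax.random.PRNGKey(seed)`, letting NumPy-side seeding drive the JAX model. Whether that matches how BayesFlow seeds its simulators would need checking against its documentation.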
Cheers