First time here. I'm aiming to use BayesFlow to speed up some existing MCMC-based workflows for cognitive models.
I already have a series of models written in JAX NumPy and NumPyro that can be used for both likelihood evaluation and simulation. Ideally, I would simply adapt or wrap these models for use with BayesFlow, but they require explicit RNG management (passing JAX PRNG keys around).
So I would ideally be able to define my functions as
Or something like that. However, it seems that, despite BayesFlow using JAX as a backend, there isn't really a way to interface with its RNG management in the way required to use JAX as a modelling language?
Any ideas? Perhaps I could create a wrapper that manages the key? My only concern is that I'm not sure whether that would affect BayesFlow's assumptions about independence across batch dimensions, etc.
Conventions regarding RNG across packages are unfortunately not very standardized, but we are trying to improve this.
BayesFlow is very permissive in terms of simulator randomness:
- Most straightforwardly, you can always generate simulations up front with jax.random.PRNGKey, as you mention in your message, and then train with fit_offline(data=sims, ...).
- If you want to use fit_online, you need to pass a callable simulator object that takes a batch_size/shape and returns a dict. You are free to use any RNG you like inside that callable.
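A minimal sketch of such a callable, assuming a hypothetical toy model (theta ~ N(0, 1), ten observations x ~ N(theta, 1)) and a hypothetical class name KeyedSimulator. It owns a JAX key and splits off a fresh subkey on every call, so consecutive batches never reuse randomness, while jax.random.normal with a batched shape draws every entry of the batch independently. Converting outputs to NumPy arrays is a choice made here, not a stated BayesFlow requirement.

```python
import jax
import numpy as np


class KeyedSimulator:
    """Hypothetical wrapper that hides explicit JAX key management behind a callable."""

    def __init__(self, seed: int = 0):
        self._key = jax.random.PRNGKey(seed)

    def _next_key(self):
        # Advance the internal key and hand out a fresh, never-reused subkey.
        self._key, subkey = jax.random.split(self._key)
        return subkey

    def __call__(self, batch_shape):
        # Accept either an int batch size or a tuple batch shape.
        if isinstance(batch_shape, int):
            batch_shape = (batch_shape,)
        k_theta, k_x = jax.random.split(self._next_key())
        # Toy model: one parameter per dataset, ten noisy observations of it.
        theta = jax.random.normal(k_theta, (*batch_shape, 1))
        x = theta + jax.random.normal(k_x, (*batch_shape, 10))
        return {"theta": np.asarray(theta), "x": np.asarray(x)}


simulator = KeyedSimulator(seed=0)
offline_sims = simulator(1000)   # one large batch, e.g. for fit_offline(data=offline_sims, ...)
online_batch = simulator((32,))  # the kind of batch an online training loop would request
```

The same object covers both options: call it once with a large batch size to build an offline dataset, or hand it over as the callable simulator for fit_online. Because the key state lives inside the object, the caller never has to thread keys through by hand.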
Are these applicable? If not, could you please clarify what your workflow requires in terms of RNG management?
Further:
The next BayesFlow release will expose a seed argument on all network sample methods, which can be either a fixed seed or a backend-agnostic Keras seed generator. This is already on the dev branch.
Hey, just letting you know that I was able to work around this issue by splitting keys in pure Python first and only then calling jax.jit-compiled functions, passing a vector of keys to the compiled function. It is probably a bit slower than the absolute best-case scenario, but it requires minimal rewriting.
I also suspect that treating the key as a parameter in the parameter dictionary and then dropping it with an adapter (.drop), so that it isn't learned, may work.
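The workaround of splitting keys outside jit and passing them in as a vector might look roughly like this sketch. The toy simulator simulate_one and the use of jax.vmap to map it over the key vector are assumptions for illustration; the post only states that keys are split in plain Python and then passed as a vector to a jit-compiled function.

```python
import jax


def simulate_one(key):
    # Hypothetical single-dataset simulator written with explicit JAX keys.
    k_theta, k_x = jax.random.split(key)
    theta = jax.random.normal(k_theta, (1,))
    x = theta + jax.random.normal(k_x, (10,))
    return {"theta": theta, "x": x}


# Batch over a leading axis of keys, then compile. The keys are an ordinary
# input, so the compiled function stays pure and each call can get fresh keys.
simulate_batch = jax.jit(jax.vmap(simulate_one))

master_key = jax.random.PRNGKey(42)
master_key, batch_key = jax.random.split(master_key)
keys = jax.random.split(batch_key, 64)   # one key per dataset, split outside jit
batch = simulate_batch(keys)             # batch["theta"]: (64, 1), batch["x"]: (64, 10)
```

Since each dataset gets its own key, datasets in the batch are independent, which addresses the concern about independence across batch dimensions.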
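A sketch of the key-as-parameter idea, again with a hypothetical toy simulator: each simulated dataset carries the key that produced it under a "key" entry, which makes any individual dataset exactly reproducible later. The dict comprehension at the end is only a plain-Python stand-in for the adapter drop step mentioned above; BayesFlow's actual adapter call is not reproduced here.

```python
import jax
import numpy as np


def simulate_one(key):
    # Hypothetical simulator that returns its own key alongside the draws.
    k_theta, k_x = jax.random.split(key)
    theta = jax.random.normal(k_theta, (1,))
    x = theta + jax.random.normal(k_x, (10,))
    return {"key": key, "theta": theta, "x": x}


keys = jax.random.split(jax.random.PRNGKey(0), 16)
sims = jax.vmap(simulate_one)(keys)

# Any single dataset can be regenerated from its stored key.
replay = simulate_one(sims["key"][3])

# Stand-in for dropping the key before training, so it is never treated as a target.
training_data = {name: np.asarray(v) for name, v in sims.items() if name != "key"}
```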