Population Level Parameter Recovery Problem

Dear BayesFlow Team,

thanks for developing such an interesting tool, and I really enjoy working with it!

I was trying to train a network for a hierarchical two-choice drift diffusion model. It has 5 parameters at the individual level (two drift rates (b0 & b1), boundary separation (\alpha), response bias (\beta), and non-decision time (\tau)) and 10 parameters at the population level (the means \mu and standard deviations \sigma of the five distributions from which the individual parameters are drawn).
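For reference, the generative structure I simulate looks roughly like the sketch below (the prior locations and scales here are placeholders for illustration, not my actual priors):

import numpy as np

rng = np.random.default_rng(2024)

# Hyperprior sketch: means and standard deviations of the five
# individual-level parameters (b0, b1, alpha, beta, tau).
# The concrete distributions below are placeholders.
def draw_hyperparameters():
    mu = rng.normal(loc=[0.0, 0.0, 1.5, 0.5, 0.3], scale=0.5)
    sigma = rng.uniform(0.1, 0.5, size=5)
    return mu, sigma

# Individual-level parameters drawn from the population distributions.
def draw_individual_parameters(mu, sigma, num_subjects):
    return rng.normal(loc=mu, scale=sigma, size=(num_subjects, 5))

mu, sigma = draw_hyperparameters()
theta = draw_individual_parameters(mu, sigma, num_subjects=40)
# theta has shape (40, 5): one row of DDM parameters per subject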

The network settings are:

summary_net = bf.summary_networks.HierarchicalNetwork([
    # from the bottom level to the top level
    bf.summary_networks.DeepSet(),
    bf.summary_networks.DeepSet()
])

local_inference_net = bf.networks.InvertibleNetwork(
    num_params=5,  # b0, b1, alpha, beta, tau
    coupling_settings={"dense_args": dict(kernel_regularizer=None), "dropout": False}
)

hyper_inference_net = bf.networks.InvertibleNetwork(
    num_params=10,  # mu and sigma for each of the five parameters
    coupling_settings={"dense_args": dict(kernel_regularizer=None), "dropout": False}
)

The training went fine, and the individual-level parameters are nicely recovered after this amount of training:

history = trainer.train_online(epochs=300, iterations_per_epoch=300, batch_size=10)

However, recovery of the population-level parameters is problematic for one parameter (in this case \sigma_{b0}):

From retraining the model several times, I found that the problematic recovery only appears in the estimation of a population-level standard deviation, but not always the same one: it can happen for \sigma_{b0}, \sigma_{b1}, or \sigma_\tau. The plot here is the output of just one trained network.
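To make "problematic recovery" concrete, I quantify it with a small helper I wrote (this is my own code, not a BayesFlow function): the squared correlation between ground-truth hyperparameters and their posterior means, computed per parameter:

import numpy as np

# Per-parameter R^2 between true values and posterior means.
# true_params, post_means: arrays of shape (num_datasets, num_params).
def recovery_r2(true_params, post_means):
    num_params = true_params.shape[1]
    return np.array([
        np.corrcoef(true_params[:, j], post_means[:, j])[0, 1] ** 2
        for j in range(num_params)
    ])

# Toy check: near-perfect estimates for 3 parameters, noisy for a 4th.
rng = np.random.default_rng(1)
true = rng.normal(size=(500, 4))
est = true.copy()
est[:, 3] += rng.normal(scale=2.0, size=500)  # one badly recovered parameter
r2 = recovery_r2(true, est)
# r2 for the first three parameters is ~1.0, the fourth is clearly lower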

I am quite confused by this problem and would appreciate any help. Please let me know if anything is unclear.

Thanks a lot!


Some additional information:

I trained these networks on a GPU, and when I tried to investigate the latent space z with:

f = bf.diagnostics.plot_latent_space_2d(z_samples[1])

the z samples (of the population level, I guess) have 32 dimensions.
But the tutorial says that the dimensionality of the latent space equals num_params in the inference network, so that was a bit confusing as well.
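For what it's worth, my understanding of the tutorial's statement is that an invertible network is a bijection, so its latent dimensionality has to match num_params exactly. Here is a toy numpy illustration (not BayesFlow code) of why a bijective map cannot change dimensionality:

import numpy as np

# A toy affine "flow": forward and inverse maps, exactly invertible.
def forward(theta, scale=2.0, shift=1.0):
    return theta * scale + shift

def inverse(z, scale=2.0, shift=1.0):
    return (z - shift) / scale

# 10 parameters in, 10 latent dimensions out -- a bijection cannot
# produce a z with a different dimensionality than its input.
theta = np.random.default_rng(0).normal(size=(100, 10))
z = forward(theta)
assert z.shape == theta.shape             # latent dim == num_params
assert np.allclose(inverse(z), theta)     # exact invertibility

So a 32-dimensional z makes me suspect I was looking at the wrong array (e.g. a summary-network output) rather than the inference network's latent samples.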


I would start by dropping the DeepSet in favor of a SetTransformer summary network.