Model misspecification: Diffusion Model for Conflict tasks

Dear all,

We are currently trying to apply BayesFlow to a reaction-time- and error-rate-based cognitive model (the Diffusion Model for Conflict tasks; Ulrich et al., 2015). Despite very good parameter recoveries (r > .85), we are experiencing a few difficulties with the fitting routine and with model misspecification.

Model misspecification
In line with the model misspecification workflow on the bayesflow.org site, we plotted the summary statistics derived from the summary network and calculated the maximum mean discrepancy (MMD). We used several prior distributions (beta, uniform, and normal) within approximately the same range. The observed data comprise three flanker data sets from different studies (A1, B1, C2). We split the observed data sets randomly into batches of 200 observations and applied the summary network to each of these batches (a short sketch of this batching step follows after the questions below). The number of simulated batches and the number of observations per simulated batch approximately match those of the observed data. The results (see attached figure) are somewhat unsatisfactory, since the model seems to produce far more extreme summary statistics than the observed data. However, the data space looks quite good. At the moment we are concerned with the following questions:
  • Is the procedure of batching the observed data legitimate? How should the observed data structure correspond to the simulated data in terms of the number of batches and the number of observations per batch?
  • How exactly should we interpret these plots? Can we derive any direction in which to adapt the priors, prior ranges, or model specification that is more promising than just randomly trying different things?
  • In the “Detecting Model Misspecification” workflow, subsection “Hypothesis test”, it says: “It is important that the number of simulated data sets to estimate the sampling distribution of the summary under the null hypothesis matches the number of observed data sets.” The code is as follows:
observed_data = trainer.configurator(trainer.generative_model(10))
MMD_sampling_distribution, MMD_observed = trainer.mmd_hypothesis_test(observed_data, num_reference_simulations=1000, num_null_samples=500, bootstrap=False)
_ = bf.diagnostics.plot_mmd_hypothesis_test(MMD_sampling_distribution, MMD_observed)
So does “match” mean that the ratio between the number of observed batches and the number of reference simulations has to be 1:100, or does it refer to another number of simulations?
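For reference, this is roughly how we batch the observed data before applying the summary network (a sketch; observed_trials and the column layout are placeholders):

import numpy as np

# observed_trials: np.array of shape (n_trials, 4) holding one study's trials;
# the 4 columns correspond to the summary network's input_dim.
rng = np.random.default_rng(seed=42)
shuffled = rng.permutation(observed_trials)

# Split into batches of 200 observations, dropping the remainder.
n_batches = shuffled.shape[0] // 200
observed_batches = shuffled[: n_batches * 200].reshape(n_batches, 200, -1)

# One 32-dimensional summary vector per batch of 200 observations.
observed_summaries = summary_net(observed_batches.astype(np.float32)).numpy()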

Training phase
We are currently running online training for 100 epochs with 1000 iterations each. Each iteration takes about 5-10 seconds (with a batch size of 16 and 200 to 1000 observations in each batch), resulting in a total time of about 280 hours. Here is our code:


summary_net = bf.networks.SetTransformer(
    input_dim=4,
    summary_dim=32,
    name="dmc_summary",
)

inference_net = bf.networks.InvertibleNetwork(
    num_params=len(prior.param_names),
    num_coupling_layers=12,
    coupling_settings={
        "dropout_prob": 0.1,
        "bins": 64,
    },
    name="dmc_inference",
)

amortizer = bf.amortizers.AmortizedPosterior(
    inference_net,
    summary_net,
    name="dmc_amortizer",
    summary_loss_fun="MMD",
)

trainer = bf.trainers.Trainer(
    generative_model=model,
    amortizer=amortizer,
    configurator=configurator,
    checkpoint_path=model_dir,
    memory=True,
)

h = trainer.train_online(
    epochs=100,
    iterations_per_epoch=1000,
    batch_size=16,
    save_checkpoint=True,
)

Do you have any suggestions on how to speed up the training phase? We are planning to try offline training, but we aren’t sure whether it is significantly more efficient or whether it might degrade the results.

Any thoughts or suggestions are highly appreciated, thank you in advance!

Best,

Simon Schaefer

2 Likes

Hi Simon,

Welcome to the BayesFlow Forums! Thanks for reaching out and for providing detailed context along with your questions.

Scatterplots

Before I can thoroughly answer your questions about the scatterplots, I need a bit of general clarification:

  • What exactly is a “batch” and an “observation” for you?
  • The scatterplots seem odd: Since you have 32 learned summary statistics (summary_dim=32), the scatterplot for each cell should be 32-dimensional (i.e., a large pair plot). Can you share the exact code that you use to create one of the scatterplot cells?

MMD hypothesis test

The MMD values follow a sampling distribution under the null hypothesis that your model is well-specified. We first need to sample from that sampling distribution by simulating data from the generative model and computing the MMD against other data that is also simulated from the model (and is therefore, by definition, well-specified). Then, we plug in the observed data and see where the MMD of observed vs. model-simulated data lies within this sampling distribution.

A detail to watch out for: the MMD of 10 vs. 1000 data sets follows a different distribution than the MMD of 50 vs. 1000. So if you have 10 observed data sets, you should also compute the sampling distribution of 10 vs. 1000. In Python terms, 10 is observed_data.shape[0] and 1000 is num_reference_simulations.

If I understand you correctly, you want to use 200 observed data sets and have those stored as observed_data. That means you’re good to go, and BayesFlow will automatically compute draws from the sampling distribution of 200 vs. 1000.
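In code, this is just the call from your first post, with the 200 observed batches already configured and stored in observed_data (a sketch):

# observed_data holds the 200 configured batches of real data,
# so observed_data["summary_conditions"].shape[0] == 200.
MMD_sampling_distribution, MMD_observed = trainer.mmd_hypothesis_test(
    observed_data,
    num_reference_simulations=1000,  # model-simulated reference data sets
    num_null_samples=500,            # draws from the null (200 vs. 1000) distribution
    bootstrap=False,
)
_ = bf.diagnostics.plot_mmd_hypothesis_test(MMD_sampling_distribution, MMD_observed)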

Speed-Up

Some questions and thoughts:

  • What does the loss history look like? Ideally also with a validation loss so we could spot overfitting. See 1. Quickstart: Amortized Posterior Estimation — BayesFlow: Amortized Bayesian Inference
  • 100k update steps (100 epochs * 1000 iterations) seems like a lot. Have you tried automated early stopping? (relates to the loss history inspection)
  • What happens if you increase the learning rate? (again, need to see the effect on the loss history)
  • What is your bottleneck? If it’s the simulation process of the diffusion model, offline training will help. If it’s the NN training, online training is equally fast. (See the timing sketch after this list.)
  • Have you tried using a DeepSet as summary net? Transformers have O(N^2) complexity, which doesn’t scale so well.
  • 12 coupling layers in the inference net looks like a lot; maybe fewer will do the trick.
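Regarding the bottleneck question, a quick and admittedly rough check is to time the simulator for a single training batch and compare it to your 5-10 seconds per iteration (a sketch):

import time

# Time simulation + configuration for one online-training batch of 16 data sets.
# If this accounts for most of the 5-10 s per iteration, offline training will help;
# otherwise, the bottleneck is the network update itself.
t0 = time.time()
_ = trainer.configurator(trainer.generative_model(16))
print(f"Simulation + configuration took {time.time() - t0:.2f} s per batch")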

Hope that helps a bit, happy to discuss more details with more info.

Cheers,
Marvin

3 Likes

Hey Simon,

Welcome to the forum from me as well!

To add to Marvin’s comment, I would be more worried if the observed data summaries lay outside the scope of the simulator rather than within it (the latter is roughly what you observe for the beta priors). The more serious issue threatening the approximation fidelity of the networks is the former, out-of-distribution (OOD) case. The latter case merely indicates that the simulator’s outputs are overdispersed relative to what you observe in the experiment, but it doesn’t provide any specific hints on how to modify the simulator; that relies on some understanding of the generative model itself. Long story short, I think that as long as the observations lie roughly within the scope of the simulator, you need to worry less about the validity of the estimates and more about the fact that you may be simulating too many implausible data sets.

Regarding online vs. offline learning:

  • When in a development phase, I would always go for offline training with a smallish pre-simulated set (e.g., 10-20k sims), because not having to simulate during training speeds things up a lot.
  • You don’t need dropout for online learning and the setting num_bins only has an effect if you are using coupling_design='spline'. I wouldn’t change the default num_bins, unless I had strong reasons to.
  • I would go with the following invertible net and stick to the SetTransformer as a summary net:
inference_net = bf.networks.InvertibleNetwork(
    num_params=len(prior.param_names),
    num_coupling_layers=6,
    coupling_design="spline",
    name="dmc_inference",
)

and proceed with offline training for as long as the networks don’t overfit (200-300 epochs?). You can determine if the networks overfit by providing a small subset of simulations to the train_offline method via the validation_sims keyword argument. The largest efficiency gains would come from doing offline learning and it will allow you to iterate quickly over different neural net configurations, if needed.
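A minimal sketch of that workflow (the simulation budget, number of epochs, and batch size below are placeholders you would tune):

# Pre-simulate a training set and a small held-out validation set.
train_sims = model(10000)   # e.g., 10-20k simulations, as suggested above
val_sims = model(300)       # used only for tracking the validation loss

h = trainer.train_offline(
    simulations_dict=train_sims,
    epochs=200,
    batch_size=32,
    validation_sims=val_sims,
    early_stopping=True,
    save_checkpoint=True,
)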

Hope that helps!

Cheers,
Stefan

Edit: I would avoid using bounded priors (e.g., uniforms / beta) for parameters which have no natural bounds.

3 Likes

Hey everyone,

Thank you for your responses, you have already helped me a lot!

Scatter plots
The scatter plots in my first post depict only the first two dimensions of the summary statistics. I agree that this is not really helpful if the actual summary space comprises 32 dimensions. I have updated the plot (see attachment). If I understood your paper (Schmitt et al., 2022) correctly, there is a trade-off between parameter recovery and model misspecification detection that depends on the number of summary dimensions. Is there any rule of thumb, or relation to the number of parameters (we estimate five), that can help find a suitable number of dimensions that is not overly sensitive to model misspecification? I read about using at least as many summary dimensions as estimated parameters, but I wonder whether five dimensions are sufficient or whether there is a sweet spot between 5 and 32 dimensions.

The loss history (attached) shows that a loss of 0 is reached after about 40k iterations, so I think early stopping is a good idea in this case. I assume the loss history suggests that our model is likely overfitted. Could this also contribute to significant model misspecification?

Offline training

Thank you for your advice regarding the training phase. I adapted my script and tried the trainer.train_offline() function:

simulations_dict = model(100)

h = trainer.train_offline(
    simulations_dict=simulations_dict,
    epochs=10,
    batch_size=10,
    validation_sims=200,
    save_checkpoint=True,
    early_stopping=True,
)

but received this error message:

Traceback (most recent call last):
  File "/Applications/DataSpell.app/Contents/plugins/python-ce/helpers/pydev/pydevconsole.py", line 364, in runcode
    coro = func()
  File "<input>", line 1, in <module>
  File "/Users/simonschaefer/anaconda3/envs/bf/lib/python3.10/site-packages/bayesflow/trainers.py", line 551, in train_offline
    data_set = SimulationDataset(simulations_dict, batch_size)
  File "/Users/simonschaefer/anaconda3/envs/bf/lib/python3.10/site-packages/bayesflow/helper_classes.py", line 67, in __init__
    self.data = tf.data.Dataset.from_tensor_slices(tuple(slices)).shuffle(buffer_size).batch(batch_size)
  File "/Users/simonschaefer/anaconda3/envs/bf/lib/python3.10/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 825, in from_tensor_slices
    return from_tensor_slices_op._from_tensor_slices(tensors, name)
  File "/Users/simonschaefer/anaconda3/envs/bf/lib/python3.10/site-packages/tensorflow/python/data/ops/from_tensor_slices_op.py", line 25, in _from_tensor_slices
    return _TensorSliceDataset(tensors, name=name)
  File "/Users/simonschaefer/anaconda3/envs/bf/lib/python3.10/site-packages/tensorflow/python/data/ops/from_tensor_slices_op.py", line 38, in __init__
    self._structure = nest.map_structure(
  File "/Users/simonschaefer/anaconda3/envs/bf/lib/python3.10/site-packages/tensorflow/python/data/util/nest.py", line 122, in map_structure
    return nest_util.map_structure(
  File "/Users/simonschaefer/anaconda3/envs/bf/lib/python3.10/site-packages/tensorflow/python/util/nest_util.py", line 1068, in map_structure
    return _tf_data_map_structure(func, *structure, **kwargs)
  File "/Users/simonschaefer/anaconda3/envs/bf/lib/python3.10/site-packages/tensorflow/python/util/nest_util.py", line 1135, in _tf_data_map_structure
    return _tf_data_pack_sequence_as(structure[0], [func(*x) for x in entries])
  File "/Users/simonschaefer/anaconda3/envs/bf/lib/python3.10/site-packages/tensorflow/python/util/nest_util.py", line 1135, in <listcomp>
    return _tf_data_pack_sequence_as(structure[0], [func(*x) for x in entries])
  File "/Users/simonschaefer/anaconda3/envs/bf/lib/python3.10/site-packages/tensorflow/python/data/ops/from_tensor_slices_op.py", line 39, in <lambda>
    lambda component_spec: component_spec._unbatch(), batched_spec)  # pylint: disable=protected-access
  File "/Users/simonschaefer/anaconda3/envs/bf/lib/python3.10/site-packages/tensorflow/python/framework/tensor.py", line 1199, in _unbatch
    raise ValueError("Unbatching a tensor is only supported for rank >= 1")
ValueError: Unbatching a tensor is only supported for rank >= 1

Our simulation data have a similar structure to the one in this BayesFlow script, capturing the number of observations in the non-batchable context and an np.array of congruency conditions in the batchable context. I tested the functions mentioned in the traceback and figured out that this function:

tf.data.Dataset.from_tensor_slices(tuple(slices)).shuffle(buffer_size).batch(batch_size)

cannot handle our non-batchable context (<class 'int'>) because it expects a shape corresponding to the other elements of slices. The problem does seem to be caused by the non-batchable context: changing it from an integer to an array of shape (number of simulations, 1) makes this call work. If I do so, however, I also have to change the configurator so that the 'direct_conditions' output matches the shape of the non-batchable context. Unfortunately, this leads to tensors of different ranks and to an error message when executing bf.trainers.Trainer():


Traceback (most recent call last):
  File "/Users/simonschaefer/anaconda3/envs/bf/lib/python3.11/site-packages/bayesflow/trainers.py", line 1314, in _check_consistency
    _ = self.amortizer.compute_loss(self.configurator(self.generative_model(_n_sim)))
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/simonschaefer/anaconda3/envs/bf/lib/python3.11/site-packages/bayesflow/amortizers.py", line 209, in compute_loss
    net_out, sum_out = self(input_dict, return_summary=True, **kwargs)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/simonschaefer/anaconda3/envs/bf/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/Users/simonschaefer/anaconda3/envs/bf/lib/python3.11/site-packages/bayesflow/amortizers.py", line 174, in call
    summary_out, full_cond = self._compute_summary_condition(
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/simonschaefer/anaconda3/envs/bf/lib/python3.11/site-packages/bayesflow/amortizers.py", line 410, in _compute_summary_condition
    full_cond = tf.concat([sum_condition, direct_conditions], axis=-1)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
tensorflow.python.framework.errors_impl.InvalidArgumentError: Exception encountered when calling layer 'dmc_amortizer_model21beta2wide_test_transf_offline' (type AmortizedPosterior).
{{function_node __wrapped__ConcatV2_N_2_device_/job:localhost/replica:0/task:0/device:CPU:0}} ConcatOp : Dimension 0 in both shapes must be equal: shape[0] = [2,32] vs. shape[1] = [699,1] [Op:ConcatV2] name: concat
Call arguments received by layer 'dmc_amortizer_model21beta2wide_test_transf_offline' (type AmortizedPosterior):
  • input_dict={'summary_conditions': 'tf.Tensor(shape=(2, 699, 4), dtype=float32)', 'direct_conditions': 'tf.Tensor(shape=(699, 1), dtype=float32)', 'parameters': 'tf.Tensor(shape=(2, 5), dtype=float32)'}
  • return_summary=True
  • kwargs={'training': 'None'}
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "/Applications/DataSpell.app/Contents/plugins/python-ce/helpers/pydev/pydevconsole.py", line 364, in runcode
    coro = func()
           ^^^^^^
  File "<input>", line 1, in <module>
  File "/Users/simonschaefer/anaconda3/envs/bf/lib/python3.11/site-packages/bayesflow/trainers.py", line 220, in __init__
    self._check_consistency()
  File "/Users/simonschaefer/anaconda3/envs/bf/lib/python3.11/site-packages/bayesflow/trainers.py", line 1317, in _check_consistency
    raise ConfigurationError(
bayesflow.exceptions.ConfigurationError: Could not carry out computations of generative_model ->configurator -> amortizer -> loss! Error trace:
 Exception encountered when calling layer 'dmc_amortizer_model21beta2wide_test_transf_offline' (type AmortizedPosterior).
{{function_node __wrapped__ConcatV2_N_2_device_/job:localhost/replica:0/task:0/device:CPU:0}} ConcatOp : Dimension 0 in both shapes must be equal: shape[0] = [2,32] vs. shape[1] = [699,1] [Op:ConcatV2] name: concat
Call arguments received by layer 'dmc_amortizer_model21beta2wide_test_transf_offline' (type AmortizedPosterior):
  • input_dict={'summary_conditions': 'tf.Tensor(shape=(2, 699, 4), dtype=float32)', 'direct_conditions': 'tf.Tensor(shape=(699, 1), dtype=float32)', 'parameters': 'tf.Tensor(shape=(2, 5), dtype=float32)'}
  • return_summary=True
  • kwargs={'training': 'None'}

Is there an easy way to adapt the simulation data (the non-batchable context in particular) so that it works with train_offline()?
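To make the first error concrete, here is a minimal reproduction of the unbatching problem as we understand it (the shapes are made up):

import numpy as np
import tensorflow as tf

# Stand-ins for our simulation output: 100 data sets of 699 trials with 4 variables each,
# plus the number of observations stored as a plain Python int (the non-batchable context).
sim_data = np.zeros((100, 699, 4), dtype=np.float32)
n_obs = 699

# Raises "Unbatching a tensor is only supported for rank >= 1",
# because the scalar context has no first axis to slice along.
dataset = tf.data.Dataset.from_tensor_slices((sim_data, n_obs))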

Thank you in advance, I appreciate any thoughts and remarks! 🙂

1 Like

Hey Simon, is it possible to save (e.g., pickle) and share a small number of simulations from your simulator so I can investigate the problem more thoroughly?

Regarding the number of summary statistics, we typically use a heuristic:

number_of_summary_statistics ~= 2 * number_of_parameters

With fewer summary statistics, there is less chance that the summary vector picks up irrelevant aspects of the data which may be impacted by contaminants. You can also apply any compression method (e.g., PCA) to the summary space and check if fewer statistics are necessary to explain most of the variance (see, e.g., Figure 10 in [2112.08866] Detecting Model Misspecification in Amortized Bayesian Inference with Neural Networks).
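A rough sketch of that check (assuming your trained summary_net and configurator; the batch of simulations is just for illustration):

import numpy as np
from sklearn.decomposition import PCA

# Compute learned summary statistics for a batch of simulated data sets.
sims = trainer.configurator(trainer.generative_model(1000))
summaries = summary_net(sims["summary_conditions"]).numpy()  # shape (1000, 32)

# How many principal components are needed to explain most of the variance?
pca = PCA().fit(summaries)
cumulative_variance = np.cumsum(pca.explained_variance_ratio_)
print(cumulative_variance[:10])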

The loss history does indeed show convergence. When doing offline learning, it is also a good idea to provide a small validation set to the function via the validation_sims keyword argument for tracking the loss on simulations not used for optimization.

1 Like

Thanks for the additional details! I agree with Stefan's advice regarding validation data for judging overfitting.

You wrote: “The loss history (attached) shows that a loss of 0 is reached after about 40k iterations, so I think early stopping is a good idea in this case.”

The maximum likelihood loss in NPE is based on minimizing the negative log-posterior, so negative values are absolutely fine.

1 Like