You mention “validation metrics” and “held-out validation set”. In the online learning setting, why would one want to use a hold-out validation set? I assume this set would also be sampled from the same data-generating process that is used to drive the “training loop”, correct? Couldn’t I just monitor the training loss and stop “when it isn’t decreasing by a meaningful amount anymore”?
Correct! As you said, the training loss curve is a valid criterion for judging convergence in online training.
However, there are two reasons why a held-out validation set is still useful.
- First, networks behave differently in training mode than in validation/inference mode.
In BayesFlow (and Keras more generally), the call methods of networks can take a keyword argument, training, which defaults to False. In the training loop (model.fit()), this flag is set automatically, so training and validation samples are processed differently (a short plain-Keras sketch after this list illustrates the mechanism).
- For example, the default CouplingFlow employs mild dropout regularization during training to improve robustness (see the Keras docs on the Dropout layer). This means that during training, each unit's activation has a chance of being temporarily set to zero at a given rate. The default rate for CouplingFlow is 0.05; other inference and summary networks have similar defaults.
However, after training we do not want to lose any information, so dropout is disabled at inference time.
- L1/L2 regularization only influences the training objective, not how the model behaves during inference.
The penalty term (\lambda_1 \|w\|_1 + \lambda_2 \|w\|_2^2) is added to the training loss; the validation loss/metric is computed without that penalty. So if you only watch the training loss, you cannot tell whether you are over- or under-regularizing.
- Second, when you have added a custom metric to some layer or network, it will be computed on the validation set only.
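To make both points concrete, here is a minimal plain-Keras sketch (not BayesFlow-specific): the training flag changes what a Dropout layer does, and an L1/L2 kernel penalty shows up as an extra loss term that only enters the training objective.

```python
import numpy as np
import keras
from keras import layers

x = np.ones((1, 8), dtype="float32")

# 1) The same Dropout layer behaves differently depending on the training flag.
dropout = layers.Dropout(rate=0.5)
print(dropout(x, training=True))   # roughly half the entries zeroed, the rest rescaled
print(dropout(x, training=False))  # identity: the input passes through unchanged

# 2) An L1/L2 kernel penalty is collected as an extra loss term; it is added to the
#    training objective, while validation metrics are computed without it.
dense = layers.Dense(4, kernel_regularizer=keras.regularizers.L1L2(l1=1e-2, l2=1e-2))
_ = dense(x)          # build the layer so the kernel (and its penalty) exists
print(dense.losses)   # the regularization term that model.fit() adds to the training loss
```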
You only mention the normalising flow (network) when talking about the loss (log-likelihood). What about the summary network, doesn’t it also contribute to the loss-function via additional terms, or is it “part of” the log-likelihood loss of the normalising flow?
The summary net by default does not have its own separate loss. However, additional summary losses can be used for specific goals, like incentivizing a Gaussian distribution of the summaries during training for out-of-distribution/misspecification detection [1].
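If you want to experiment with such an additional summary loss yourself, the mechanism in Keras is add_loss. The following is only a hypothetical sketch with a simple moment-matching penalty (the class name, penalty, and hyperparameters are made up for illustration); [1] uses an MMD-based criterion instead.

```python
import keras
from keras import layers, ops

class GaussianizingSummary(keras.Model):
    """Hypothetical sketch (not the BayesFlow implementation): a summary network that
    adds an extra loss term nudging the learned summaries towards zero mean and unit
    variance during training, in the spirit of the misspecification-detection idea in [1]."""

    def __init__(self, summary_dim=10, penalty_weight=1.0, **kwargs):
        super().__init__(**kwargs)
        self.body = keras.Sequential(
            [layers.Dense(64, activation="gelu"), layers.Dense(summary_dim)]
        )
        self.penalty_weight = penalty_weight

    def call(self, x, training=False):
        summaries = self.body(x, training=training)
        if training:
            # crude moment-matching penalty: pull the summary mean towards 0 and the
            # per-dimension variance towards 1; added on top of the main flow loss
            mean_penalty = ops.mean(ops.square(ops.mean(summaries, axis=0)))
            var_penalty = ops.mean(ops.square(ops.var(summaries, axis=0) - 1.0))
            self.add_loss(self.penalty_weight * (mean_penalty + var_penalty))
        return summaries
```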
Finally, I’d be interested to understand how best to handle the situation (in the linear regression example) when the signal-to-noise ratio is low and the number of simulation samples is “large” …
I will respond later to this concrete case. Not having tried it myself, I’d guess that the gradient computation dominates the other reasons for runtime increase.
EDIT: The summary network’s part of the computation has to be done for every one of the (100 or 1000) set elements. Playing with the hyperparameters of the summary network should therefore speed it up considerably. Replacing SetTransformer → DeepSet on a consumer laptop gives roughly a 10x speedup:
- SetTransformer (default hyperparameters) takes 20 seconds per batch;
- DeepSet (default hyperparameters) takes 2 seconds per batch for the N=1000 case.
Further, you can gain more by reducing size-related hyperparameters: by default you are using mlp_widths_equivariant=(64,64), which is overkill for this simple problem. Going with DeepSet(summary_dim=10, mlp_widths_equivariant=(16,)) instead improves training speed to about 1 second per batch (see the sketch below), and so on.
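For concreteness, the change amounts to swapping out the summary network when you set up the approximator/workflow; everything else in the tutorial stays the same. The module path below (bf.networks) reflects the current tutorials, and the timings are the rough per-batch numbers quoted above; exact constructor arguments may differ between BayesFlow versions.

```python
import bayesflow as bf

# Default in the tutorial: attention-based set summary, ~20 s per batch for N=1000
# summary_net = bf.networks.SetTransformer()

# Roughly 10x faster with default hyperparameters, ~2 s per batch
# summary_net = bf.networks.DeepSet()

# Smaller DeepSet, plenty for the linear-regression example, ~1 s per batch
summary_net = bf.networks.DeepSet(summary_dim=10, mlp_widths_equivariant=(16,))
```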
Let me know if that clears up everything and feel free to open a new question if you are progressing through the tutorials or moving to some application you are considering!
[1] Schmitt, M., Bürkner, P.-C., Köthe, U., & Radev, S. T. (2022). Detecting Model Misspecification in Amortized Bayesian Inference with Neural Networks. arXiv:2112.08866.