What does `epoch` mean when using an online data generating process?

Disclaimer: I am just getting started with learning BayesFlow and I am following the Linear regression example here.

The example uses BasicWorkflow.fit_online with epochs=50, batch_size=64, and num_batches_per_epoch=200. Previously I have only encountered “epoch” when fitting a model on a fixed dataset, typically sampling batches without replacement until the whole dataset is “exhausted” and a single epoch is completed.

In that sense epoch does not seem to be a meaningful concept when fitting using an online data-generating process. Is there any specific meaning for epoch for fit_online or is it just there to match the signature for methods fit_offline and fit_disk?

In other words, could I achieve an identical model fit by using epochs=1, num_batches_per_epoch=50 * 200 when I am using an online data-generating process?

Also the loss for that example becomes smaller than zero for later epochs. So the loss is not the KL-divergence but some related function without the non-negativity guarantee?

As I am generating new samples for each batch, I do not have to worry about overfitting – at least that’s my current understanding of “online training”. How many epochs would it take for the example above until the loss no longer decreases significantly, i.e. “the model has converged”? What are practical approaches to determine when to stop training a BayesFlow model?

Thanks a lot for the great documentation and examples. Looking forward to your reply.

Hi and welcome Benjamin!

In that sense epoch does not seem to be a meaningful concept when fitting using an online data-generating process. Is there any specific meaning for epoch for fit_online or is it just there to match the signature for methods fit_offline and fit_disk?

Great question, and you’re totally right. The concept of an “epoch” isn’t strictly required in online learning, as opposed to offline training, where an epoch marks one full pass over the dataset. However, epochs can still act as a natural unit of measurement in online training, for example when it comes to computing validation metrics (e.g., negative log-likelihood on a held-out validation set).
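For your concrete question: here is a hypothetical sketch (`workflow` stands in for the BasicWorkflow from the linear regression example, and I’m only using the keyword arguments you already mentioned). Both calls see the same total number of simulated batches, but the first one triggers per-epoch bookkeeping – validation metrics, callbacks, checkpoints – 50 times instead of once:

```python
# Hypothetical sketch: `workflow` is the BasicWorkflow from the linear
# regression example. Both calls simulate 50 * 200 = 10,000 online batches,
# but the first runs per-epoch bookkeeping (validation metrics, callbacks,
# checkpoints) 50 times instead of once.
history_a = workflow.fit_online(epochs=50, num_batches_per_epoch=200, batch_size=64)
history_b = workflow.fit_online(epochs=1, num_batches_per_epoch=50 * 200, batch_size=64)
```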

Also the loss for that example becomes smaller than zero for later epochs. So the loss is not the KL-divergence but some related function without the non-negativity guarantee?

The loss for a normalizing flow is the negative log-likelihood, which can be negative. That loss function is derived from the KL divergence, but in the derivation some terms can be dropped for optimization purposes because they don’t depend on the neural network weights. So no worries, it’s totally normal that the normalizing flow loss can be negative :+1:
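To make that concrete, here is the standard derivation (nothing BayesFlow-specific). For a flow $q_\phi$ approximating the posterior $p(\theta \mid x)$,

$$
\mathrm{KL}\big(p(\theta \mid x)\,\|\,q_\phi(\theta \mid x)\big)
= \underbrace{\mathbb{E}_{p}\big[\log p(\theta \mid x)\big]}_{\text{does not depend on }\phi}
- \mathbb{E}_{p}\big[\log q_\phi(\theta \mid x)\big],
$$

so minimizing the KL is equivalent to minimizing the negative log-likelihood $-\,\mathbb{E}_{p}\big[\log q_\phi(\theta \mid x)\big]$. With the change of variables through the flow $f_\phi$,

$$
\log q_\phi(\theta \mid x)
= \log \mathcal{N}\!\big(f_\phi(\theta; x)\,\big|\,0, I\big)
+ \log \left|\det \frac{\partial f_\phi(\theta; x)}{\partial \theta}\right|,
$$

and since this is a log-density rather than a log-probability, it can exceed zero, so the loss can drop below zero.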

What are practical approaches to determine when to stop training a BayeFlow model?

You can use early stopping criteria (manually or with a Keras callback – e.g., checked after each epoch haha), for example one that detects when the validation loss isn’t decreasing by a meaningful amount anymore. @KLDivergence iirc you implemented something like this a while ago, maybe you could chime in here?
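As a minimal sketch of the callback route (assuming your BayesFlow version forwards Keras callbacks and validation data through fit_online, as the underlying Keras fit loop does – the exact keyword names may differ, so please check the docs):

```python
import keras

# Stop when the validation loss has not improved by at least `min_delta`
# for `patience` consecutive epochs, and restore the best weights.
early_stop = keras.callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=0.01,   # what counts as a "meaningful amount" is problem-specific
    patience=5,
    restore_best_weights=True,
)

# Hypothetical call: assumes fit_online accepts `validation_data` and
# `callbacks`; check the BasicWorkflow documentation for the exact names.
history = workflow.fit_online(
    epochs=50,
    num_batches_per_epoch=200,
    batch_size=64,
    validation_data=validation_data,
    callbacks=[early_stop],
)
```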

Cheers!

Marvin

Thanks @marvinschmitt for your quick reply. Your explanations are very helpful.

You mention “validation metrics” and “held-out validation set”. In the online learning setting, why would one want to use a hold-out validation set? I assume this set would also be sampled from the same data-generating process that is used to drive the “training loop”, correct? Couldn’t I just monitor the training loss and stop “when it isn’t decreasing by a meaningful amount anymore”?

You only mention the normalising flow (network) when talking about the loss (log-likelihood). What about the summary network, doesn’t it also contribute to the loss-function via additional terms, or is it “part of” the log-likelihood loss of the normalising flow?

Finally, I’d be interested to understand how to best handle the situation (in the linear regression example) when the signal-to-noise ratio is low and the number of simulation samples is “large”, e.g. my intercept and slope are on the order of 0.1 but sigma is on the order of one (scale=0.1 and shape=10 in the gamma prior), and N in [100, 1000] (or more). If I use a batch size of 64, I observe quite a large memory footprint when running the example, and individual batches take a “long” time. I can always reduce the batch size, e.g. from 64 to 8, to counteract the effect of a larger N, but this has obvious problems: if N becomes even larger I cannot counteract the effect any further, and the smaller the batch size, the noisier my (stochastic) gradients become. Intuitively I would expect a larger N to somewhat counteract the noise in the gradients, but there is still variation induced by the prior. My working “model” is that samples within a single batch “average out the variation induced by the prior”, but I am not sure if that makes sense at all?

I am not sure whether the increase in runtime is due to the simulation overhead (i.e. repeatedly calling the numpy sampling functions for different parameters sampled from the prior), due to overhead in the adapter that transforms the simulation data, or simply because computing gradients gets more expensive since every “sample” contains more values to aggregate.

As mentioned, I am just getting started and lack quite a lot of context and literature exposure, so I apologize if these questions are already clarified elsewhere. If so, I’d be glad to look at relevant references/literature.

Thanks again for your help @marvinschmitt, I appreciate you taking the time to answer questions here.

You mention “validation metrics” and “held-out validation set”. In the online learning setting, why would one want to use a hold-out validation set? I assume this set would also be sampled from the same data-generating process that is used to drive the “training loop”, correct? Couldn’t I just monitor the training loss and stop “when it isn’t decreasing by a meaningful amount anymore”?

Correct! As you said, the training loss curve is a valid criterion for judging convergence in online training.
However, there are two reasons why a held-out validation set is still useful.

  1. First, networks behave differently in training mode than in validation/inference mode.
    In BayesFlow (and Keras more generally), the call methods of networks can take a keyword argument, training, which defaults to False. In the training loop (model.fit()), this flag is set automatically so that training and validation samples are processed differently (see the short Keras sketch after this list).
  • For example, the default CouplingFlow employs mild dropout regularization during training to improve robustness (see the Keras docs on the Dropout layer). This means that during training, each activation has a random chance of being temporarily set to zero at a certain rate. The default rate for CouplingFlow is 0.05; other inference and summary networks have similar defaults.
    However, after training, we do not want to lose any information, so dropout is disabled.
  • L1/L2 regularization only influences the training objective, not how the model behaves during inference.
    The penalty term (\lambda_1 \|w\|_1 + \lambda_2 \|w\|_2^2) is added to the training loss; your validation loss/metric is computed without that penalty. So if you only watch the training loss, you can’t tell whether you’re over- or under-regularizing.
  2. Second, when you have added a custom metric to some layer or network, it will be computed on the validation set only.
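A tiny Keras-only illustration of the first point (the 0.05 rate mirrors the CouplingFlow default mentioned above):

```python
import numpy as np
import keras

x = np.ones((1, 10), dtype="float32")
dropout = keras.layers.Dropout(rate=0.05)

# Training mode: each activation is zeroed with probability 0.05
# (and the surviving ones are scaled up to keep the expected sum constant).
print(dropout(x, training=True))

# Inference/validation mode: dropout is a no-op, the input passes through unchanged.
print(dropout(x, training=False))
```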

You only mention the normalising flow (network) when talking about the loss (log-likelihood). What about the summary network, doesn’t it also contribute to the loss-function via additional terms, or is it “part of” the log-likelihood loss of the normalising flow?

The summary net by default does not have its own separate loss. However, additional summary losses can be used for specific goals, like incentivizing a Gaussian distribution of the summaries during training for out-of-distribution/misspecification detection [1].
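Conceptually (this is just a NumPy sketch of the idea from [1], not the BayesFlow API), such a summary loss penalizes the mismatch between the summary outputs and a standard Gaussian, e.g. via the maximum mean discrepancy (MMD), and is added to the flow’s negative log-likelihood with some weight:

```python
import numpy as np

def mmd_rbf(x, y, bandwidth=1.0):
    """Plug-in estimate of the squared MMD between samples x and y with an RBF kernel."""
    def k(a, b):
        sq_dists = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
        return np.exp(-sq_dists / (2.0 * bandwidth**2))
    return k(x, x).mean() + k(y, y).mean() - 2.0 * k(x, y).mean()

# `summaries`: output of the summary network for one batch, shape (batch, summary_dim).
# Here both are simulated; a small penalty means the summaries look roughly standard Gaussian.
summaries = np.random.default_rng(1).normal(size=(64, 10))
reference = np.random.default_rng(2).standard_normal(size=(64, 10))
print(mmd_rbf(summaries, reference))
```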

Finally, I’d be interested to understand how to best handle the situation (in the linear regression example) when the signal-to-noise ratio is low and the number of simulation samples is “large” …

I will respond to this concrete case later. Not having tried it myself, I’d guess that the gradient computation dominates the other contributions to the runtime increase.


EDIT: The computation in the summary network needs to be done for all (100 or 1000) set elements. Playing with the hyperparameters of the summary network should allow you to speed it up considerably. Replacing SetTransformer with DeepSet on a consumer laptop speeds it up 10x:

  • SetTransformer (default hyperparameters) takes 20 seconds per batch;
  • DeepSet (default hyperparameters) takes 2 seconds per batch for the N=1000 case.

Further, you can reduce some size-related hyperparameters: by default you are using mlp_widths_equivariant=(64, 64), which is overkill for this simple problem. Going with DeepSet(summary_dim=10, mlp_widths_equivariant=(16,)) instead improves training speed further, to about 1 second per batch, etc.
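A sketch of the swap (module and constructor names assumed from the current BayesFlow tutorials; check your installed version for the exact arguments):

```python
import bayesflow as bf

# Attention-based set summary used in the example: accurate, but slow for
# large N on a consumer laptop.
summary_net_slow = bf.networks.SetTransformer()

# Lighter alternative for this simple problem: a small DeepSet with a
# narrower equivariant MLP and a 10-dimensional summary vector.
summary_net_fast = bf.networks.DeepSet(
    summary_dim=10,
    mlp_widths_equivariant=(16,),
)

# The chosen summary network is then passed to the workflow as before,
# e.g. BasicWorkflow(..., summary_network=summary_net_fast).
```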

Let me know if that clears everything up, and feel free to open a new question as you progress through the tutorials or move on to an application you have in mind!


[1] Schmitt, M., Bürkner, P.-C., Köthe, U., & Radev, S. T. (2022). Detecting Model Misspecification in Amortized Bayesian Inference with Neural Networks. arXiv:2112.08866.
