Offline training with different observations

Hi all,

Is there any way to use offline training with non-batchable context?

Pre-simulated data seems to have a tensor-like shape (n_simulations, n_observations, n_dim), so the number of observations is already fixed across all simulations.

presimulated_data = generative_model(1000)

losses = trainer.train_offline(
    simulations_dict=presimulated_data,
    epochs=num_epochs,
    batch_size=32,
    iterations_per_epoch=1000,
    validation_sims=val_sims,
)

Cheers,
Amin

Hi Amin,

Yes, you can vary the number of observations during offline training.

I suggest you simulate your training data with the maximal number of observations. Then you can pass a configurator to your trainer that randomly removes some observations in a given batch. This could look as follows:

import numpy as np

def configure_input(forward_dict):
    out_dict = {}
    data = forward_dict["sim_data"]
    max_obs = data.shape[1]
    min_obs = ...  # set your minimal number of observations
    # draw a random number of observations for this batch
    num_obs = np.random.randint(low=min_obs, high=max_obs + 1)
    vec_num_obs = np.full((data.shape[0], 1), num_obs)
    # randomly select which observations to keep
    idx = np.random.choice(max_obs, num_obs, replace=False)
    # only keep the randomly selected observations
    out_dict["summary_conditions"] = data[:, idx, :].astype(np.float32)
    # use the (variable) number of observations as a direct condition
    out_dict["direct_conditions"] = np.sqrt(vec_num_obs).astype(np.float32)
    return out_dict
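For reference, here is a minimal, self-contained sketch of just the subsampling step on synthetic data (the batch size, number of observations, data dimension, and `min_obs` value are all made up for illustration):

```python
import numpy as np

# Synthetic stand-in for forward_dict["sim_data"]:
# 16 simulations, up to 10 observations, 3 data dimensions.
rng = np.random.default_rng(42)
data = rng.normal(size=(16, 10, 3)).astype(np.float32)

min_obs, max_obs = 2, data.shape[1]

# draw a random number of observations for this batch
num_obs = int(rng.integers(low=min_obs, high=max_obs + 1))
# randomly select which observations to keep
idx = rng.choice(max_obs, size=num_obs, replace=False)

summary_conditions = data[:, idx, :]
direct_conditions = np.sqrt(
    np.full((data.shape[0], 1), num_obs, dtype=np.float32)
)

print(summary_conditions.shape)  # (16, num_obs, 3)
print(direct_conditions.shape)
```

Each training batch then conditions on a different, randomly chosen subset of observations, which is what lets the amortizer handle a variable number of observations at inference time.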

I hope this works out for you.

Best,
Lukas


Hi Lukas,

Sounds great! Thank you for your response.

Cheers,
Amin