Hi all,
Is there any way to use offline training with non-batchable context?
Pre-simulated data seems to have a tensor shape of (n_simulations, n_observations, n_dim), so the number of observations appears to be fixed across all simulations before training even starts.
presimulated_data = generative_model(1000)
losses = trainer.train_offline(simulations_dict=presimulated_data,
                               epochs=num_epochs,
                               batch_size=32,
                               iterations_per_epoch=1000,
                               validation_sims=val_sims)
Cheers,
Amin
Hi Amin,
yes, you can vary the number of observations during offline training.
I suggest you simulate your training data with the maximal number of observations. Then you can pass a configurator to your trainer that randomly removes some observations from each batch. This could look as follows:
import numpy as np

def configure_input(forward_dict):
    out_dict = {}
    data = forward_dict["sim_data"]
    max_obs = data.shape[1]
    min_obs = ...  # smallest number of observations you want to allow

    # draw a random number of observations for this batch
    num_obs = np.random.randint(low=min_obs, high=max_obs + 1)
    vec_num_obs = np.full((data.shape[0], 1), num_obs)

    # randomly select which observation indices to keep
    idx = np.random.choice(
        np.arange(max_obs), num_obs, replace=False
    )

    # only keep the randomly selected observations
    out_dict["summary_conditions"] = data[:, idx, :].astype(np.float32)

    # use the non-batchable context (sqrt-scaled number of observations)
    # as direct conditions
    out_dict["direct_conditions"] = np.sqrt(vec_num_obs).astype(np.float32)

    # forward the prior draws as the parameters to be inferred
    out_dict["parameters"] = forward_dict["prior_draws"].astype(np.float32)

    return out_dict
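For reference, here is a minimal sketch of how such a configurator could be wired into training, assuming a BayesFlow-style Trainer that accepts a configurator argument; the amortizer variable is a placeholder for your own network setup:

from bayesflow.trainers import Trainer

# hypothetical setup; replace amortizer with your own amortized posterior
trainer = Trainer(amortizer=amortizer,
                  configurator=configure_input)

# the configurator is applied to every batch, so the number of
# retained observations varies from batch to batch during training
losses = trainer.train_offline(simulations_dict=presimulated_data,
                               epochs=num_epochs,
                               batch_size=32,
                               validation_sims=val_sims)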
I hope this works out for you.
Best,
Lukas
Hi Lukas,
Sounds great! Thank you for your response.
Cheers,
Amin