Hi all,
Sorry for the many questions I have been asking lately, and thanks for your time and effort; it is much appreciated!
I have been running into trouble when trying to load a trained approximator that uses a custom summary network. My setup is as follows.
I define my custom summary network MPL, built around a custom Keras layer SensorMessagePassing, in a separate script:
import os

# The backend must be set before keras is imported for it to take effect
if "KERAS_BACKEND" not in os.environ:
    os.environ["KERAS_BACKEND"] = "jax"

import keras
import tensorflow as tf
from keras import layers
import bayesflow as bf
import seaborn as sns

from keras.saving import register_keras_serializable as serializable
from bayesflow.types import Tensor
from bayesflow.networks.summary_network import SummaryNetwork
class SensorMessagePassing(keras.layers.Layer):
    def __init__(self, connectivity_dict, output_dim, activation="relu", **kwargs):
        super().__init__(**kwargs)
        self.connectivity_dict = connectivity_dict
        self.output_dim = output_dim
        self.activation = keras.activations.get(activation)

    def build(self, input_shape):
        feature_dim = input_shape[-1]
        self.W_msg = self.add_weight(
            shape=(feature_dim, self.output_dim),
            initializer="he_normal",
            trainable=True,
            name="W_msg",
        )
        self.W_self = self.add_weight(
            shape=(feature_dim, self.output_dim),
            initializer="he_normal",
            trainable=True,
            name="W_self",
        )
        self.b = self.add_weight(
            shape=(self.output_dim,),
            initializer="zeros",
            trainable=True,
            name="b",
        )

    def call(self, x):
        # x: (batch, sensors, features)
        x = tf.convert_to_tensor(x)
        # Reorder input for now: node sensors first, cable sensors last
        node_indices = list(range(4, 26))
        cable_indices = list(range(4))
        x = tf.concat([
            tf.gather(x, node_indices, axis=1),
            tf.gather(x, cable_indices, axis=1),
        ], axis=1)
        batch_size = tf.shape(x)[0]
        num_sensors = x.shape[1]
        updated_sensors = []
        for i in range(num_sensors):
            neighbors = self.connectivity_dict.get(i, [])
            if neighbors:
                neighbor_feats = tf.gather(x, neighbors, axis=1)  # (batch, num_neighbors, features)
                neighbor_msgs = tf.einsum("bnf,fd->bnd", neighbor_feats, self.W_msg)  # (batch, num_neighbors, output_dim)
                agg_msg = tf.reduce_mean(neighbor_msgs, axis=1)  # (batch, output_dim)
            else:
                agg_msg = tf.zeros((batch_size, self.output_dim), dtype=x.dtype)
            self_msg = tf.matmul(x[:, i], self.W_self)  # (batch, output_dim)
            updated = self.activation(self_msg + agg_msg + self.b)  # (batch, output_dim)
            updated_sensors.append(updated)
        return tf.stack(updated_sensors, axis=1)  # (batch, sensors, output_dim)
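For context, the layer forward pass itself works fine. A minimal smoke test (the ring-shaped connectivity_dict here is a placeholder, not my real sensor topology):

import tensorflow as tf

dummy_connectivity = {i: [(i + 1) % 26] for i in range(26)}  # placeholder ring topology
layer = SensorMessagePassing(connectivity_dict=dummy_connectivity, output_dim=50)

dummy_x = tf.random.normal((8, 26, 100))  # (batch, sensors, features)
out = layer(dummy_x)
print(out.shape)  # (8, 26, 50)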
@serializable(package="bayesflow.networks")
class MPL(SummaryNetwork):
    def __init__(
        self,
        summary_dim: int = 10,
        activation: str = "relu",
        kernel_initializer: str = "he_normal",
        padding: str = "same",
        neighbors_dict: dict = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.summary_dim = summary_dim
        self.activation = activation
        self.kernel_initializer = kernel_initializer
        self.padding = padding
        if neighbors_dict is None:
            raise ValueError("neighbors_dict must be provided to define sensor connectivity.")
        self.neighbors_dict = neighbors_dict

        self.model = keras.models.Sequential(name="SummaryNetwork")
        # GNN layer
        gnn_layer = SensorMessagePassing(
            connectivity_dict=neighbors_dict,
            output_dim=50,
            activation="relu",
            name="GNNLayer",
        )
        self.model.add(gnn_layer)
        # Flatten layer to convert the output to a 2D tensor
        flatten_layer = keras.layers.Flatten(name="FlattenLayer")
        self.model.add(flatten_layer)
        # Dense layer to compress into a fixed-size summary vector
        dense_layer = keras.layers.Dense(self.summary_dim, activation="relu", name="DenseSummary")
        self.model.add(dense_layer)

    def build(self, input_shape):
        super().build(input_shape)
        self.call(keras.ops.zeros(input_shape))

    def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor:
        """Forward pass of the summary network."""
        return self.model(x)

    def show_summary(self):
        """Show a summary of the underlying Keras model."""
        self.model.summary()
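As a sanity check, this is how I would inspect what actually gets serialized for this network, in isolation from the rest of the workflow (the neighbors_dict here is a placeholder, not my real topology):

dummy_neighbors = {i: [(i + 1) % 26] for i in range(26)}  # placeholder topology
net = MPL(summary_dim=10, neighbors_dict=dummy_neighbors)
net.build((1, 26, 100))

# serialize_keras_object captures whatever get_config() exposes; if
# neighbors_dict is missing from the printed config, load_model cannot
# rebuild the network faithfully either
cfg = keras.saving.serialize_keras_object(net)
print(cfg)
clone = keras.saving.deserialize_keras_object(cfg)  # may raise if required args are missing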
I then use this class in my base working notebook as follows:
import matplotlib.pyplot as plt

inference_network = bf.networks.CouplingFlow(depth=12)
summary_network = MPL(
    summary_dim=10,
    activation="relu",
    kernel_initializer="he_normal",
    padding="same",
    neighbors_dict=neighbors_dict,
)
summary_network.build((1, 26, 100))  # (batch, sensors, features)

adapter = (
    bf.Adapter()
    .to_array()
    .convert_dtype("float64", "float32")
    .standardize("E_bar", axis=0)
    .standardize("E_cable", axis=0)
    .standardize("damages", axis=0)
    .standardize("pred_vector", axis=(0, 2))
    .concatenate(["E_bar", "E_cable", "damages"], into="inference_variables")
    .concatenate(["pred_vector"], into="summary_variables")
)

workflow = bf.BasicWorkflow(
    simulator=simulator,
    adapter=adapter,
    summary_network=summary_network,
    inference_network=inference_network,
)

filepath = os.path.join(model_folder_path, "model.keras")

# Check if the model folder already exists
if not os.path.exists(model_folder_path):
    # If not, train and save a new model
    os.makedirs(model_folder_path)
    history = workflow.fit_offline(
        training_data,
        epochs=epochs,
        num_batches=num_batches,
        batch_size=batch_size,
        validation_data=validation_data,
        callbacks=[StepHistory()],
    )
    f = bf.diagnostics.plots.loss(history)
    save_loc = os.path.join(model_folder_path, f"{n_training_samples}_sims_{epochs}_epochs_loss_plot.png")
    plt.savefig(save_loc)
    workflow.approximator.save(filepath=filepath)
else:
    # If the model folder already exists, load the existing model
    # (note: load_model returns the approximator, not a BasicWorkflow)
    workflow = keras.saving.load_model(filepath)
    loss_plot_path = os.path.join(model_folder_path, f"{n_training_samples}_sims_{epochs}_epochs_loss_plot.png")
    plt.imshow(plt.imread(loss_plot_path))
    plt.axis("off")
    plt.show()

post_draws = workflow.sample(conditions=ground_truth_observations_dict, num_samples=1000)
When I train a model from scratch, everything works and yields nice posterior results. However, when I load that same model after training and saving, the posterior results become much worse. The problem does not occur when I run the same script with a built-in summary network from the BayesFlow package. It also does not occur with another custom summary network I built, inspired by the SequenceNetwork() from the stable-legacy version (shown below).
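To pin this down, a check I plan to run is whether the weights themselves survive the save/load cycle. A sketch, assuming the freshly trained workflow is still in memory alongside the reloaded approximator:

import numpy as np

trained = workflow.approximator
loaded = keras.saving.load_model(filepath)

# If any of these prints False, the weights are not restored faithfully
for w_old, w_new in zip(trained.get_weights(), loaded.get_weights()):
    print(np.allclose(w_old, w_new))

For reference, here is the SequenceNetwork-inspired network that saves and loads without problems: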
import tensorflow as tf
import keras
from keras.saving import register_keras_serializable as serializable
from bayesflow.networks.summary_network import SummaryNetwork
from summary_networks.MultiConv1D import MultiConv1D


@serializable(package="bayesflow.networks")
class SequenceNetwork(SummaryNetwork):
    def __init__(
        self,
        summary_dim: int = 10,
        num_conv_layers: int = 2,
        lstm_units: int = 128,
        bidirectional: bool = False,
        conv_settings: dict = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if conv_settings is None:
            conv_settings = {
                "layer_args": {"activation": "relu", "filters": 32, "strides": 1, "padding": "causal"},
                "min_kernel_size": 1,
                "max_kernel_size": 3,
            }
        self.net = keras.Sequential([MultiConv1D(conv_settings) for _ in range(num_conv_layers)])
        self.lstm = (
            keras.layers.Bidirectional(keras.layers.LSTM(lstm_units))
            if bidirectional
            else keras.layers.LSTM(lstm_units, return_sequences=False)
        )
        self.out_layer = keras.layers.Dense(summary_dim, activation="linear")
        self.summary_dim = summary_dim

    def call(self, x, **kwargs):
        x = self.net(x)
        x = self.lstm(x)
        x = self.out_layer(x)
        return x
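One difference I keep coming back to: SensorMessagePassing does not override get_config(), so I am not sure its constructor arguments (in particular connectivity_dict) are stored with the saved model at all. A sketch of what I considered adding; this is my assumption and I have not verified it is the right fix:

class SensorMessagePassing(keras.layers.Layer):
    # ... __init__ / build / call unchanged from above ...

    def get_config(self):
        # Expose the constructor arguments so load_model can rebuild the
        # layer; the activation was resolved via keras.activations.get,
        # so serialize it back into a config-friendly form
        config = super().get_config()
        config.update({
            "connectivity_dict": self.connectivity_dict,
            "output_dim": self.output_dim,
            "activation": keras.activations.serialize(self.activation),
        })
        return config

with an analogous override on MPL for summary_dim and neighbors_dict.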
Any advice on what to change in my summary network, or in how I save and load models, would be much appreciated!