How to load a trained approximator when using a custom summary network?

Hi all,

Sorry for the many questions I have had lately; thanks for your time and effort, it is much appreciated!

I have been running into trouble when trying to load a trained approximator that uses a custom summary network. My setup is as follows:

I define my custom summary network MLP, which uses a custom Keras layer SensorMessagePassing, in a separate script:

import os

# The backend must be set before Keras is imported for it to take effect;
# the tf.* ops used below also require the TensorFlow backend
if "KERAS_BACKEND" not in os.environ:
    os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
import tensorflow as tf
from keras.saving import register_keras_serializable as serializable

from bayesflow.networks.summary_network import SummaryNetwork
from bayesflow.types import Tensor


class SensorMessagePassing(keras.layers.Layer):
    def __init__(self, connectivity_dict, output_dim, activation="relu", **kwargs):
        super().__init__(**kwargs)
        self.connectivity_dict = connectivity_dict  
        self.output_dim = output_dim
        self.activation = keras.activations.get(activation)

    def build(self, input_shape):
        feature_dim = input_shape[-1]
        
        self.W_msg = self.add_weight(
            shape=(feature_dim, self.output_dim),
            initializer="he_normal",
            trainable=True,
            name="W_msg"
        )
        self.W_self = self.add_weight(
            shape=(feature_dim, self.output_dim),
            initializer="he_normal",
            trainable=True,
            name="W_self"
        )
        self.b = self.add_weight(
            shape=(self.output_dim,),
            initializer="zeros",
            trainable=True,
            name="b"
        )

    def call(self, x):
        # x: (batch, sensors, features)
        x = tf.convert_to_tensor(x)

        # Reorder the input for now: node sensors first, then cable sensors
        node_indices = list(range(4, 26))
        cable_indices = list(range(4))
        x = tf.concat([
            tf.gather(x, node_indices, axis=1),
            tf.gather(x, cable_indices, axis=1)
        ], axis=1)
        
        batch_size = tf.shape(x)[0]
        num_sensors = x.shape[1] 
        feature_dim = x.shape[2]

        updated_sensors = []

        for i in range(num_sensors):
            neighbors = self.connectivity_dict.get(i, [])

            if neighbors:
                neighbor_feats = tf.gather(x, neighbors, axis=1)  # (batch, num_neighbors, features)
                neighbor_msgs = tf.einsum("bnf,fd->bnd", neighbor_feats, self.W_msg)  # (batch, num_neighbors, output_dim)
                agg_msg = tf.reduce_mean(neighbor_msgs, axis=1)  # (batch, output_dim)
            else:
                agg_msg = tf.zeros((batch_size, self.output_dim), dtype=x.dtype)

            self_msg = tf.matmul(x[:, i], self.W_self)  # (batch, output_dim)
            updated = self.activation(self_msg + agg_msg + self.b)  # (batch, output_dim)
            updated_sensors.append(updated)

        return tf.stack(updated_sensors, axis=1)  # (batch, sensors, output_dim)
    

@serializable(package="bayesflow.networks")
class MLP(SummaryNetwork):

    def __init__(
        self,
        summary_dim: int = 10,
        activation: str = "relu",
        kernel_initializer: str = "he_normal",
        padding: str = "same",
        neighbors_dict: dict = None,
        **kwargs,
    ):
        
        super().__init__(**kwargs)
        self.summary_dim = summary_dim
        self.activation = activation
        self.kernel_initializer = kernel_initializer
        self.padding = padding
        if neighbors_dict is None:
            raise ValueError("neighbors_dict must be provided to define sensor connectivity.")
        self.neighbors_dict = neighbors_dict

        self.model = keras.models.Sequential(name="SummaryNetwork")
        
        # GNN layer
        gnn_layer = SensorMessagePassing(
            connectivity_dict=neighbors_dict,
            output_dim=50,
            activation='relu',
            name="GNNLayer"
        )
        self.model.add(gnn_layer)

        # Flatten layer to convert the output to a 2D tensor
        flatten_layer = keras.layers.Flatten(name="FlattenLayer")
        self.model.add(flatten_layer)

        # Dense layer to compress into a fixed-size summary vector
        dense_layer = keras.layers.Dense(self.summary_dim, activation="relu", name="DenseSummary")
        self.model.add(dense_layer)

    def build(self, input_shape):
        super().build(input_shape)
        self.call(keras.ops.zeros(input_shape))

    def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor:
        """Forward pass of the summary network."""
        return self.model(x, training=training)
    
    def show_summary(self):
        """Show the summary of the model."""
        self.model.summary()

I then use this class in my base working notebook as follows:


inference_network = bf.networks.CouplingFlow(depth=12)

summary_network = MLP(summary_dim=10,
                      activation='relu',
                      kernel_initializer='he_normal',
                      padding='same',
                      neighbors_dict=neighbors_dict)
summary_network.build((1, 26, 100))

adapter = (
    bf.Adapter()
    .to_array()
    .convert_dtype("float64", "float32")

    .standardize("E_bar", axis=0)
    .standardize("E_cable", axis=0)
    .standardize("damages", axis=0)
    .standardize("pred_vector", axis=(0, 2))
    .concatenate(["E_bar", "E_cable", "damages"], into="inference_variables")
    .concatenate(["pred_vector"], into="summary_variables")
)

workflow = bf.BasicWorkflow(
        simulator=simulator,
        adapter=adapter,
        summary_network=summary_network,
        inference_network=inference_network
    )

filepath = os.path.join(model_folder_path, 'model.keras')

# Check if the model folder already exists
if not os.path.exists(model_folder_path):
    # If not, create a new model
    os.makedirs(model_folder_path)
    history = workflow.fit_offline(
        training_data,
        epochs=epochs,
        num_batches=num_batches,
        batch_size=batch_size,
        validation_data=validation_data,
        callbacks=[StepHistory()],
    )
    f = bf.diagnostics.plots.loss(history)
    save_loc = os.path.join(model_folder_path, f'{n_training_samples}_sims_{epochs}_epochs_loss_plot.png')
    plt.savefig(save_loc)
    workflow.approximator.save(filepath=filepath)
    
else:
    # If the model folder already exists, load the existing model
    # (note: keras.saving.load_model returns the approximator, which also exposes sample())
    workflow = keras.saving.load_model(filepath)
    loss_plot_path = os.path.join(model_folder_path, f'{n_training_samples}_sims_{epochs}_epochs_loss_plot.png')
    plt.imshow(plt.imread(loss_plot_path))
    plt.axis('off')
    plt.show()


post_draws = workflow.sample(conditions=ground_truth_observations_dict, num_samples=1000)

When I train a model from scratch, it works perfectly and yields nice posterior results. However, when I load the same model after training and saving it, the posterior results become much worse. When I run the same script with a built-in summary network from the BayesFlow package, the problem does not occur. I have also built a custom summary network inspired by the SequenceNetwork from the stable-legacy version, and with that one I have no issues either:

import keras
from keras.saving import register_keras_serializable as serializable
from bayesflow.networks.summary_network import SummaryNetwork
from summary_networks.MultiConv1D import MultiConv1D

@serializable(package="bayesflow.networks")
class SequenceNetwork(SummaryNetwork):
    
    def __init__(
        self,
        summary_dim: int = 10,
        num_conv_layers: int = 2,
        lstm_units: int = 128,
        bidirectional: bool = False,
        conv_settings: dict = None,
        **kwargs,
    ):
        
        super().__init__(**kwargs)
        
        if conv_settings is None:
            conv_settings = {
                "layer_args": {"activation": "relu", "filters": 32, "strides": 1, "padding": "causal"},
                "min_kernel_size": 1,
                "max_kernel_size": 3,
            }

        self.net = keras.Sequential([MultiConv1D(conv_settings) for _ in range(num_conv_layers)])

        if bidirectional:
            self.lstm = keras.layers.Bidirectional(keras.layers.LSTM(lstm_units))
        else:
            self.lstm = keras.layers.LSTM(lstm_units, return_sequences=False)
        self.out_layer = keras.layers.Dense(summary_dim, activation="linear")
        self.summary_dim = summary_dim

    def call(self, x, **kwargs):
        x = self.net(x)
        x = self.lstm(x)
        x = self.out_layer(x)
        return x

Any advice on what to change in my summary network or how I save/load models?

Hi,
I have no time to check the complete code, but I would suggest two avenues for debugging:

  1. Check, for all networks, that the values passed to __init__ and build when the model is loaded are what you expect (you can use print statements or a debugger for this). If they are not, please reach out and we can help you figure out the details.
  2. Check if/which weights of the model are stored correctly by printing them before saving and after loading (a minimal sketch follows this list). This would help locate where the issue arises. I would first look at the self.model of MLP, as I’m not sure whether Keras’ automatic tracking works for models inside models (I have not checked yet, it might just work).
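For the second avenue, a rough sketch of such a check, reusing the workflow and filepath variables from your notebook (snapshot_weights is just a hypothetical debugging helper, and this assumes variable paths are unique within the model):

import numpy as np

# Hypothetical helper: snapshot all weights of a model, keyed by variable path
def snapshot_weights(model):
    return {v.path: keras.ops.convert_to_numpy(v) for v in model.weights}

before = snapshot_weights(workflow.approximator)
workflow.approximator.save(filepath=filepath)

loaded = keras.saving.load_model(filepath)
after = snapshot_weights(loaded)

# Report variables that disappeared or changed value across the round trip
for name, value in before.items():
    if name not in after:
        print(f"missing after load: {name}")
    elif not np.allclose(value, after[name]):
        print(f"value mismatch: {name}")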

You can also take a look at this section of the developer docs, which contains additional links regarding serialization that might help you.
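One thing worth checking against those docs: neither SensorMessagePassing nor MLP overrides get_config, so constructor arguments such as connectivity_dict and neighbors_dict may not be restored faithfully when the model is loaded. A hedged sketch of what that could look like for the custom layer (untested; the same pattern would apply to MLP and its neighbors_dict):

@serializable(package="custom")
class SensorMessagePassing(keras.layers.Layer):
    # ... __init__, build, and call as in the post above ...

    def get_config(self):
        config = super().get_config()
        config.update({
            # Careful: integer dict keys turn into strings when the config is
            # written to JSON, so they may need to be cast back to int when
            # the layer is reconstructed from its config
            "connectivity_dict": self.connectivity_dict,
            "output_dim": self.output_dim,
            "activation": keras.activations.serialize(self.activation),
        })
        return config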