Some questions about the split network

  1. Can a split network use a hierarchical network? This sounds weird, but it’s a very real use case of mine. I have data of the shape (1, subjects, time, observations). The observations come from two different sources, so they require two different types of networks, but they share similar hierarchies. Would it be possible to split the data into (1, subjects, time, observations_1) and (1, subjects, time, observations_2)?

  2. Can a split network use different network types for different splits? From the documentation, this looks like a no.

Hi Chris, it may be easier to achieve what you want with a custom summary network that combines existing components. Which two network types would you use? I can provide sample code to build upon.

Thanks. Let’s say the data size is (batch, subject, L+D), where L is orderless and D is sequential. I read in another post that arbitrary summary networks are supported, so I think it might indeed be easier to just write a custom model.

Here is a self-contained example of such a network:

import tensorflow as tf
import bayesflow as bf

class DualSplitNetwork(tf.keras.Model):
    
    def __init__(self, network1, network2, split_idx, **kwargs):
        """
        Initializes a composite summary network which will split input
        data batches across the last dimension according to split_idx.
        """
        super().__init__(**kwargs)
        
        self.network1 = network1
        self.network2 = network2
        self.split_idx = split_idx
        
    def call(self, x, **kwargs):
        out1 = self.network1(x[..., :self.split_idx], **kwargs)
        out2 = self.network2(x[..., self.split_idx:], **kwargs)
        out = tf.concat((out1, out2), axis=-1)
        return out

We can now use it as follows:

# Simulated test input
x = tf.random.normal((32, 111, 5))

# Create custom network; will split the last axis into portions 0:2 / 2:5
summary_net = DualSplitNetwork(
    network1=bf.networks.SetTransformer(input_dim=2, summary_dim=4),
    network2=bf.networks.TimeSeriesTransformer(input_dim=3, summary_dim=6),
    split_idx=2
)

# Forward pass
summary_out = summary_net(x)
assert summary_out.shape == (32, 10)

The possibilities for playing around with such a network are endless. Hope this gets you started.


Just to clarify: should the input dim be 1 in this case, since 2 and 3 are the numbers of sequences here?

The input dims for the individual transformers are 2 and 3, respectively, because after the split, the input shapes for my random input tensor are:

(batch_size, num_observations, 2) # sequence 1, goes into network1
(batch_size, num_observations, 3) # sequence 2, goes into network2

This is because the transformers need to know the size of the last axis beforehand.
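To make these shapes concrete, here is the split illustrated with NumPy (the slicing semantics are the same as in TensorFlow, so this is just an illustration of what `split_idx=2` does to the last axis):

```python
import numpy as np

# Same shape as the simulated test input in the example above
x = np.random.normal(size=(32, 111, 5))

# Split along the last axis at index 2
part1 = x[..., :2]  # goes into network1 (input_dim=2)
part2 = x[..., 2:]  # goes into network2 (input_dim=3)

assert part1.shape == (32, 111, 2)
assert part2.shape == (32, 111, 3)
```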

I see. That makes sense. So if the input is instead of shape (batch, subject, L+D, num_feats), and I’d like to make the model hierarchical, does the following look about right?

import tensorflow as tf
import bayesflow as bf
from bayesflow.summary_networks import HierarchicalNetwork

class DualSplitNetwork(tf.keras.Model):
    
    def __init__(self, network1, network2, split_idx, **kwargs):
        """
        Initializes a composite summary network which will split input
        data batches across the last dimension according to split_idx.
        """
        super().__init__(**kwargs)
        
        self.network1 = network1
        self.network2 = network2
        self.split_idx = split_idx
        
    def call(self, x, **kwargs):
        out1 = self.network1(x[..., :self.split_idx, :], **kwargs)
        out2 = self.network2(x[..., self.split_idx:, :], **kwargs)
        out = tf.concat((out1, out2), axis=-1)
        return out

# Simulated test input
x = tf.random.normal((32, 111, 5, 8))

# Create custom network; will split the third axis into portions 0:2 / 2:5
num_feats = 8  # size of the last axis of x
summary_net = bf.networks.HierarchicalNetwork([
    DualSplitNetwork(
        network1=bf.networks.SetTransformer(input_dim=num_feats, summary_dim=4),
        network2=bf.networks.TimeSeriesTransformer(input_dim=num_feats, summary_dim=6),
        split_idx=2
    ),
    bf.networks.SetTransformer(input_dim=10, summary_dim=15)
])

# Forward pass
summary_out = summary_net(x)
assert summary_out.shape == (32, 15)

Yes, that looks right, but are you sure you want to split along the third axis and not the last? This code does something slightly different from what you described in the initial question regarding data of shape (1, subjects, time, observations).

The transformer subnetworks for the DualSplitNetwork may need to be wrapped within tf.keras.layers.TimeDistributed in order to work with 4D inputs.
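Conceptually, TimeDistributed applies the wrapped layer independently to every slice along axis 1 and stacks the results. A NumPy sketch of that idea, with a simple stand-in function in place of an actual summary network:

```python
import numpy as np

def subnet(batch):
    # Stand-in for a summary network: maps (b, n, f) -> (b, f)
    return batch.mean(axis=1)

x = np.random.normal(size=(32, 111, 5, 8))

# What TimeDistributed does conceptually: apply subnet to each of the
# 111 "subject" slices along axis 1, then stack the per-slice outputs
out = np.stack([subnet(x[:, i]) for i in range(x.shape[1])], axis=1)

assert out.shape == (32, 111, 8)
```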

Sorry, I think my original description was a bit misleading. The third dimension really just overloads two different modalities of data. The first L entries are sequential (subject responses across time steps), and the last D are questionnaire answers that have no inherent order. They don’t actually have the same num_feats, but for convenience I will just pad the questionnaire answers.
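The padding step could look like this (a hypothetical sketch in NumPy with made-up sizes L=3, D=2 and feature widths 8 and 5; the real shapes would come from your data):

```python
import numpy as np

seq = np.random.normal(size=(32, 3, 8))    # L=3 sequential steps, 8 features
quest = np.random.normal(size=(32, 2, 5))  # D=2 questionnaire items, 5 features

# Zero-pad the questionnaire block from 5 to 8 features along the last axis
quest_padded = np.pad(quest, ((0, 0), (0, 0), (0, 3)))

# Stack both modalities along the third axis: (batch, L+D, num_feats)
x = np.concatenate([seq, quest_padded], axis=1)

assert x.shape == (32, 5, 8)
```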

I also just realized that I can pack all the data into a flat vector of shape (batch, everything) and reshape it in the call method of the model. That is more space efficient, since no padding is wasted. So I could do that too.
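A minimal sketch of the pack-and-reshape idea, again with hypothetical sizes (L=3 steps of 8 features, D=2 items of 5 features); inside the model's call method, the two slices would go to the respective subnetworks:

```python
import numpy as np

L, f1 = 3, 8  # sequential block: 3 time steps x 8 features
D, f2 = 2, 5  # orderless block: 2 questionnaire items x 5 features
batch = 32

# Flat packing: no padding needed, so no space is wasted
flat = np.random.normal(size=(batch, L * f1 + D * f2))  # (32, 34)

# Recover both modalities inside the call method by slicing + reshaping
seq = flat[:, :L * f1].reshape(batch, L, f1)    # -> sequential network
quest = flat[:, L * f1:].reshape(batch, D, f2)  # -> set network

assert seq.shape == (32, 3, 8)
assert quest.shape == (32, 2, 5)
```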

Yes, slicing the data and reshaping it inside the summary network would also work. Care needs to be taken with this approach if the data dimensions vary.

@marvinschmitt This really falls into the data fusion category.