SetTransformer Dimensions

Hello,

Thank you for the amazing library! It is a wonderful tool. I am trying to implement a hierarchical model. After passing everything through the configurator, the simulated data has shape (num_batches, num_groups, num_obs, input_dim); now when I set up a SetTransformer via

summary_net = bf.summary_networks.HierarchicalNetwork([
    bf.networks.SetTransformer(input_dim=input_dim, summary_dim=32, name="local_summary"),
    bf.networks.SetTransformer(input_dim=input_dim, summary_dim=64, name="global_summary")
])

I get the following error:

ConfigurationError: Could not carry out computations of generative_model ->configurator -> amortizer -> loss! Error trace:
 Exception encountered when calling layer 'multi_head_attention_30' (type MultiHeadAttention).

{{function_node __wrapped__Einsum_N_2_device_/job:localhost/replica:0/task:0/device:CPU:0}} Expected input 1 to have rank 5 but got: 4 [Op:Einsum] name: 

Call arguments received by layer 'multi_head_attention_30' (type MultiHeadAttention):
  • query=tf.Tensor(shape=(2, 32, 4), dtype=float32)
  • value=tf.Tensor(shape=(2, 75, 155, 3), dtype=float32)
  • key=tf.Tensor(shape=(2, 75, 155, 3), dtype=float32)
  • attention_mask=None
  • return_attention_scores=False
  • training=None
  • use_causal_mask=False

However, when I use a DeepSet via

summary_net = bf.summary_networks.HierarchicalNetwork([
    bf.networks.DeepSet(summary_dim=32, name="local_summary"), 
    bf.networks.DeepSet(summary_dim=64, name="global_summary")
])

The model fitting begins. I have a couple of questions: (1) if my data is permutation-invariant, is there some decision criterion for choosing the SetTransformer over the DeepSet? and (2) how do I fix the SetTransformer issue? I'm not sure how to interpret the error output.

Again, thanks for an amazing library!

Hi Wei, welcome to the forum!

The error occurs because the SetTransformer does not natively operate on 4D inputs of shape (batch, groups, observations, features), which is exactly what hierarchical simulations produce; its attention layers expect 3D sets of shape (batch, set_size, features). That is what your trace shows: MultiHeadAttention receives a 3D query (likely the transformer's learned inducing points) together with your 4D key/value tensors, hence the rank mismatch.
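To make the mismatch concrete, here is a minimal sketch (the shapes are illustrative, taken from your error trace, and I am assuming input_dim=3 based on the last axis of your value tensor):

import numpy as np
import bayesflow as bf

net = bf.networks.SetTransformer(input_dim=3, summary_dim=32)

# 3D set input (batch, set_size, input_dim) is what the attention blocks expect:
ok = net(np.random.normal(size=(2, 155, 3)).astype(np.float32))  # -> (2, 32)

# 4D hierarchical input (batch, groups, obs, input_dim) reproduces the
# rank error from MultiHeadAttention in your trace:
bad = net(np.random.normal(size=(2, 75, 155, 3)).astype(np.float32))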

One workaround is to specify the architecture like so:

import tensorflow as tf  # needed for TimeDistributed

summary_net = bf.summary_networks.HierarchicalNetwork([
    # Apply the local SetTransformer to each group separately
    # (TimeDistributed maps over axis 1, the group dimension)
    tf.keras.layers.TimeDistributed(
        bf.networks.SetTransformer(input_dim=input_dim, summary_dim=32, name="local_summary")),
    # Pool the per-group 32-dim summaries into a single 64-dim global summary
    bf.networks.SetTransformer(input_dim=32, summary_dim=64, name="global_summary")
])

This broadcasts the SetTransformer's operations across the group dimension, so each group's observations are summarized independently before the global SetTransformer pools the per-group summaries. However, I believe @Daniel encountered some issues with that approach and had a workaround.
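As a quick sanity check of the shape flow (a sketch, again using the dimensions from your trace; the variable names are mine):

import numpy as np
import tensorflow as tf
import bayesflow as bf

x = np.random.normal(size=(2, 75, 155, 3)).astype(np.float32)  # (batch, groups, obs, input_dim)

local_net = tf.keras.layers.TimeDistributed(
    bf.networks.SetTransformer(input_dim=3, summary_dim=32))
local_out = local_net(x)  # (2, 75, 32): one 32-dim summary per group

global_net = bf.networks.SetTransformer(input_dim=32, summary_dim=64)
global_out = global_net(local_out)  # (2, 64): one 64-dim summary per data set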

As for your first question: the DeepSet is the simpler architecture and tends to underperform relative to the SetTransformer, especially when learning interrelationships between the individual data points is important for posterior estimation. I suggest starting with the DeepSet and switching to the SetTransformer only if performance needs improvement.
