Question about ordered model output and context variables

Hi all,

Thanks in advance for any help. This may be long. My setup is quite similar to the one described in the post "Fitting a value-based decision model with trial-level context variables".

The model I am using predicts fixation paths for each trial. It has three free parameters, but the output also depends on two context variables: attributeValues (an array of 6 values per trial), which varies across trials, and attributeWeights (a vector of 3), which is subject-specific but the same across trials. My ultimate goal is to estimate the three free parameters per subject (50 subjects), from 420 trials per subject.

My meta_fn picks a random number of trials.

def determine_n_trials(batch_size):
    # n_trials: number of trials (observations) in a simulated dataset
    n_trials = np.random.randint(120, 450)  # upper bound is exclusive
    return dict(n_trials=n_trials)

and I have another function which samples attributeValues and attributeWeights (determine_attVal(n_trials)).

Together they make my simulator
simulator = bf.make_simulator([determine_attVal, prior, modelFunc], meta_fn=determine_n_trials)

Using out = simulator.sample(5), returns the following…

n_trials: no .shape attribute (type = <class 'int'>)
attValues: shape = (5, 292, 6)
attWeights: shape = (5, 3)
sn: shape = (5, 1)
threshInc: shape = (5, 1)
searchSense: shape = (5, 1)
choice: shape = (5, 292)
RT: shape = (5, 292)
allFix: shape = (5, 292, 100)

sn, threshInc and searchSense are the free parameters, while choice, RT and ‘allFix’ are the output of the model. For now, I am just using allFix as a summary_variable. The third dimension of allFix is the number of fixations. The model can produce a maximum of 100 fixations per trial.

From the discussion in the other post, I thought that attWeights should be an inference condition and attValues should be part of summary_variables.

adapter = (
    bf.adapters.Adapter()
    .broadcast("n_trials", to="RT")
    .convert_dtype("float64", "float32")
    .concatenate(["sn", "threshInc", "searchSense"], into="inference_variables")
    .concatenate(["attWeights", "n_trials"], into="inference_conditions")
    .concatenate(["allFix", "attValues"], into="summary_variables")
)

However, since ‘allFix’ contains fixation paths, I am using a TimeSeriesNetwork as the summary_network to ensure the order does not get lost. Not sure if this is the right choice. The order of attWeights and attValues is also important as they enter the model.

Online fitting runs (the loss is huge but still decreasing). My main question is about the adapter: is this the right way of setting things up in my case, and is a TimeSeriesNetwork appropriate here? Also, if I want to use RT and choice as summary_variables as well (along with allFix), is this still an appropriate network?

Thanks so much, I apologise in advance for the essay.

Hi,
the data structure in allFix is quite challenging, because it contains both a set (the trials) and a time series (the fixation path within each trial). So here the implied shape would be (5, 292, 100, 1). There are two ways to deal with this.

First, you can ignore that the fixations are a time series and just treat them as variables of a set. This is what your adapter currently does, but then you need to use a SetTransformer, and not a TimeSeriesTransformer. This is because the time series transformer expects input of the shape (batch_size, num_timepoints, num_variables), so here the trials would be treated as the time points, which is not what you want. The set transformer expects (batch_size, num_members, num_variables), so this would work but is not optimized for the time series data. However, this is doable with the default architectures that come with BayesFlow.

The second is more difficult, but would reflect the data structure better. You would first need to process all time series with an identical TimeSeriesTransformer to obtain a first summary with time_series_summary_dim variables, which would give you an output of shape (5, 292, time_series_summary_dim). Then, you could further process this with a SetTransformer or a DeepSet to take the exchangeability into account.

Unfortunately, as far as I can tell we don’t have a ready-to-use architecture for this setting. I think it is not so uncommon to have this kind of two-level structure, so maybe offering an interface for that (or extending the set networks to enable it) would be good. Tagging @KLDivergence for comment.

Regarding attValues (and optionally choice and RT), it would probably make sense to concatenate them to the summary outputs of the TimeSeriesTransformer, giving you the shape (5, 292, time_series_summary_dim + 6 + 1 + 1) as input for the set network.

Note that depending on your background, this will be challenging to implement, and I’m not sure how efficient this will be during training. If you need help with this, maybe we can ask around if someone from the developer team has capacity to support you.

I hope this helps, if you have questions, let us know.

Hi valentin,

Thanks so much for your fast reply, it is very helpful :slight_smile: . I will try the first option but would be very interested in Option 2.

I have very little experience using neural networks in this way, but would love to learn more. I think I understand the idea of Option 2 and can try to give it a shot. But of course, any professional help would be very much appreciated!

So just to summarise, for Option 2, I would essentially be creating a custom summary_network that consists of

  1. TimeSeriesTransformer -> reduces the dimensionality of the fixation paths alone
  2. Concatenate with RT, choice and attValues
  3. plug the final representation into a SetTransformer

Would the adapter then remain the same? (RT, choice, allFix and attValues are all summary_variables) but then I treat them differently once they get passed to the summary network?

Thank you again :slight_smile:

I think it would be easiest to use a setup similar to this tutorial on multimodal data, using the .group method of the adapter to pass the separate inputs to the network. For an example on how to process them, take a look at the FusionNetwork implementation. I think you would have to perform the following steps:

Initialization (in __init__)

  • create a time series network/transformer
  • create a set transformer

Data Processing in the call function

  • as written above, allFix implies the shape (5, 292, 100, 1). Unfortunately, the time series networks can only work with three dimensions, so you first have to reshape to (5 * 292, 100, 1). You have to use the keras.ops module instead of numpy, so the backpropagation still works
  • pass the reshaped data through the time series network, giving you an output of shape (5 * 292, time_series_summary_dim)
  • reshape the output to (5, 292, time_series_summary_dim)
  • run the reshaped output through the set transformer, giving you shape (5, set_transformer_summary_dim)
  • return that
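
The reshaping steps above can be checked on dummy arrays before any network is involved. The following is only a shape sketch under stated assumptions: it uses numpy for illustration (inside the summary network itself, the keras.ops equivalents are needed), the function stand_in_time_series_net is a hypothetical placeholder that is not part of BayesFlow, and ts_summary_dim = 2 is an arbitrary value chosen for the placeholder.

```python
import numpy as np

batch_size, n_trials, max_fix = 5, 292, 100
ts_summary_dim = 2  # hypothetical; the real value is set by the time series network

rng = np.random.default_rng(0)
all_fix = rng.normal(size=(batch_size, n_trials, max_fix)).astype("float32")
att_values = rng.normal(size=(batch_size, n_trials, 6)).astype("float32")
rt = rng.normal(size=(batch_size, n_trials)).astype("float32")
choice = rng.integers(0, 2, size=(batch_size, n_trials)).astype("float32")


def stand_in_time_series_net(x):
    """Hypothetical placeholder for the time series network: (N, T, 1) -> (N, 2)."""
    return np.concatenate([x.mean(axis=1), x.std(axis=1)], axis=-1)


# Merge the batch and trial axes so every fixation path is one sequence
flat = all_fix.reshape(batch_size * n_trials, max_fix, 1)

# One summary vector per fixation path
summaries = stand_in_time_series_net(flat)

# Split the merged axis back into (batch, trials)
per_trial = summaries.reshape(batch_size, n_trials, ts_summary_dim)

# Append the per-trial context and outcomes as extra features for the set network
set_input = np.concatenate(
    [per_trial, att_values, rt[..., None], choice[..., None]], axis=-1
)
print(flat.shape, summaries.shape, set_input.shape)
```

The reshape back to (batch, trials, ...) keeps each summary aligned with its original trial, because the flattening and unflattening use the same row-major order.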

Other functions

For the build, calculate the shapes that each network will receive for the given inputs_shape, same for the compute_output_shape, which is just (batch_size, set_transformer_summary_dim).

I think for developing this, it would be best if you take a simple example that follows the same basic structure as your application and share it here (maybe similar to the notebook I linked above). For us to help you with this, it would be best to share a fully reproducible example, i.e., just a single code block that we can execute and play around with.
If you get stuck, don’t hesitate to reach out early, as the error messages are sometimes not super meaningful and can require some experience to make sense of.

As a final note, I would recommend using the TensorFlow backend for this for now, as it is a bit more flexible in some situations when varying shapes are used.

Hi again, thanks again for the detailed reply. I am also working on a simple example which I will post here, but I tried implementing a custom summary network by following your steps and looking at the FusionNetwork. I have this… not sure if it is correct but stepping through it seems ok…

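# Imports assumed for BayesFlow 2.x with Keras 3; adjust to your installed version
from collections.abc import Mapping

import bayesflow as bf
from bayesflow.networks import SummaryNetwork
from bayesflow.types import Shape, Tensor
from keras import ops

class CustomNetwork(SummaryNetwork):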
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Time series for fixation paths
        self.time_series_net = bf.networks.TimeSeriesNetwork()

        # Set Transformer for the second part
        self.set_transformer = bf.networks.SetTransformer()
    

    def call(self, input: Mapping[str, Tensor]):
        # Get shape of allFix summary_variable
        batch_size, trials, maxFix = ops.shape(input["allFix"])
        feat_dim = 1
        # Reshaping...
        x = ops.reshape(input["allFix"], (batch_size * trials, maxFix, feat_dim))

        # Pass to time series
        x = self.time_series_net(x)  # should return (batch*trials, time_series_summary_dim)

        # Reshape back to (batch, trials, time_series_summary_dim)
        time_series_summary_dim = ops.shape(x)[-1]
        x = ops.reshape(x, (batch_size, trials, time_series_summary_dim))
        
        # Concatenate attValues, RT and choice with the time series summaries.
        # RT and choice are (batch, trials); add a trailing axis so they can be concatenated.
        # Local variables are used so the input mapping is not modified in place.
        rt = ops.expand_dims(input["RT"], axis=-1)
        choice = ops.expand_dims(input["choice"], axis=-1)
        # Result has shape (batch, trials, 6 + 1 + 1 + time_series_summary_dim)
        x = ops.concatenate([input["attValues"], rt, choice, x], axis=2)

        # Pass to the set transformer, which pools over trials
        out = self.set_transformer(x)  # shape is (batch, set_transformer_summary_dim)

        return out

    def build(self, input_shape: Mapping[str, Shape]):
        if self.built:
            return
        batch, trials, maxFix = input_shape["allFix"]
        feat_dim = 1  # implied shape is (batch, trials, maxFix, 1)

        # Build the time series network
        if not self.time_series_net.built:
            self.time_series_net.build((None, maxFix, feat_dim))
        ts_output_dim = self.time_series_net.compute_output_shape((None, maxFix, feat_dim))[-1]

        # Build the set transformer on the concatenated per-trial features
        set_input_dim = ts_output_dim + 6 + 1 + 1  # attValues + RT + choice
        if not self.set_transformer.built:
            self.set_transformer.build((None, trials, set_input_dim))

        self.built = True
        

    def compute_output_shape(self, input_shape: Mapping[str, Shape]):
        batch, trials, maxFix = input_shape["allFix"]
        feat_dim = 1  # implied extra dimension

        # Output dimension of the time series network
        ts_dim = self.time_series_net.compute_output_shape((None, maxFix, feat_dim))[-1]

        # Input dimension of the set transformer: summaries + attValues + RT + choice
        set_input_dim = ts_dim + 6 + 1 + 1

        # The set transformer determines the final output shape
        return self.set_transformer.compute_output_shape((batch, trials, set_input_dim))

My adapter is now

adapter = (
    bf.adapters.Adapter()
    .broadcast("n_trials", to="RT") # without this, we get an error from convert_dtype? about n_trials being int.
    #.sqrt("n_trials")
    .convert_dtype("float64", "float32")
    .concatenate(["sn","threshInc","searchSense"], into="inference_variables")
    .concatenate(["attWeights","n_trials"], into="inference_conditions")     
    .group(
       ["RT","choice","allFix",'attValues'], into="summary_variables")
)

Then, in my workflow, summary_network is just CustomNetwork.
Everything seems to have the right shape if I understood correctly (even if it’s wrong, it was fun to try putting it together :slight_smile: )
What do you think?

Edit: unfortunately online training crashed with an out of memory error. Not sure if it is because I have done something wrong, but I could reduce the length of the fixation paths and n_trials if that would help.

Thanks again!

Nice, I skimmed the code and it looks good to me. The out of memory error seems plausible, as batchsize * trials can become large quite quickly. So as you suggested it might be good to start with low batch size, n_trials and length of the fixation path to see if it works at all. If it does, you can increase them to see what is possible with the memory available in your setup.
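
To get a rough feel for how quickly batch size * trials grows, here is a small back-of-the-envelope sketch. The helper time_series_input_size and the batch sizes are hypothetical and only for illustration; 449 is the largest n_trials that np.random.randint(120, 450) can return. It counts only the float32 input tensor handed to the time series network. The intermediate activations and gradients kept for backpropagation are typically many times larger, so treat the result as a lower bound.

```python
def time_series_input_size(batch_size, n_trials, max_fix, bytes_per_value=4):
    """Number of sequences and raw input bytes after merging batch and trial axes."""
    n_sequences = batch_size * n_trials
    n_bytes = n_sequences * max_fix * bytes_per_value  # float32 uses 4 bytes per value
    return n_sequences, n_bytes


for batch_size in (4, 32):  # hypothetical batch sizes
    n_seq, n_bytes = time_series_input_size(batch_size, n_trials=449, max_fix=100)
    print(f"batch_size={batch_size}: {n_seq} sequences, {n_bytes / 2**20:.2f} MiB of input")
```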

Hi, thanks for checking it. I had a Cython wrapper in my simulation because the model can be quite slow. Taking it out seemed to fix the problem (not sure), but it runs now.

Something else, however: I accidentally closed down the program after training for two days! I can successfully load the model because I have a checkpoint, but is it possible to continue training once it has been interrupted?

You can continue training by setting the learning rate of the optimizer to (approximately) match the learning rate you had when you stopped training (e.g., by computing the cosine decayed lr at step t using the keras scheduler).
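
As a rough illustration of that calculation, the cosine-decayed learning rate at a given step can be computed by hand. This is a sketch that assumes a plain cosine decay without warmup, following the formula documented for keras.optimizers.schedules.CosineDecay. The function name cosine_decayed_lr and all numbers below are hypothetical; use the initial learning rate and total number of steps from your own training run.

```python
import math


def cosine_decayed_lr(step, initial_lr, decay_steps, alpha=0.0):
    """Learning rate after `step` optimizer steps of a cosine decay without warmup."""
    step = min(step, decay_steps)  # the rate stays at its final value after decay_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * step / decay_steps))
    return initial_lr * ((1.0 - alpha) * cosine + alpha)


# Hypothetical run: 100 epochs x 1000 batches planned, interrupted after 60 epochs
initial_lr = 5e-4
decay_steps = 100 * 1000
steps_done = 60 * 1000
print(cosine_decayed_lr(steps_done, initial_lr, decay_steps))
```

The resulting value can then serve as the starting learning rate when training is resumed.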

Nice that it runs now.
I had already nearly written this when Stefan responded, so here is a somewhat more detailed response. You can resume training, but you have to do a bit of manual work to do so.
First, you have to define the learning rate or the learning rate schedule.
We usually use CosineDecay, which means that in the BasicWorkflow (as specified here), the learning rate will already have decayed until training cutoff. If you are close to the end of training, I think the easiest would be eyeballing the learning rate of the cutoff and then continuing either with constant learning rate or with a new learning rate decay.

Then you can create an optimizer. We usually use Adam for online training and AdamW for offline training. Here are the two options with our defaults:

# online
optimizer = keras.optimizers.Adam(learning_rate, clipnorm=1.5)
# offline
optimizer = keras.optimizers.AdamW(learning_rate, weight_decay=5e-3, clipnorm=1.5)

Finally, you run approximator.compile(optimizer=optimizer), and then approximator.fit with the appropriate number of epochs and number of batches per epoch.

Brilliant, thank you both very much (sorry, am pretty new to all this :smiley: )

You are welcome! Just another minor clarification: if you are training using a Workflow object, there is no need to compile the approximator. Simply pass the optimizer class and the lr to the workflow like this:

my_workflow = bf.BasicWorkflow(
    ...,  # my other arguments
    optimizer=keras.optimizers.Adam(learning_rate, clipnorm=1.5),  # example
)

Just a slightly off-topic comment on losing connection to the remote host:

If you use tmux, you can just detach and attach back to a session on the remote host. The run will continue regardless of whether you’re connected or not :slight_smile:

Cheers!
