Recurse Center - Batch 3 - Cycle 2025012-20250114 - Decoding


This cycle I finished coding the speech transformer end to end :)

The code for the speech transformer looks something like this now:

My encoder looks like this:

class Encoder(nn.Module):

    """
    Input: [batch time_steps freq_bins]
    Output: [batch seq_length d_model]

    This class processes a batch of spectrograms by:

    - Expanding the input tensor by one dimension for the channel
    - Passing our tensor through a Conv2d block (with ReLU)
    - Passing our tensor through another Conv2d block (with ReLU)
    - Reshaping our tensor so it has three dimensions instead of four
    - Passing through a linear layer so we get an output shape of d_model for the last dimension
    - Passing our tensor through a positional encoder
    - Passing our query_input through n encoder blocks
    - Passing our tensor through a layer norm
    - Outputting our encoded output to be used by the decoder (for cross-attention)
    """

    def __init__(self, cfg: Config):
        super().__init__()  # must run before submodules are assigned on an nn.Module
        self.cfg = cfg

        # encoder components
        self.repeat = Repeat(self.cfg)
        self.conv2d_block_one = Conv2DBlock(self.cfg, self.cfg.n_channels, self.cfg.n_out_channels, self.cfg.conv2d_kernel_size, self.cfg.conv2d_stride, self.cfg.conv2d_padding)
        self.conv2d_block_two = Conv2DBlock(self.cfg, self.cfg.n_out_channels, self.cfg.n_out_channels, self.cfg.conv2d_kernel_size, self.cfg.conv2d_stride, self.cfg.conv2d_padding)
        self.reshape = Reshape(self.cfg, "b c ts fb -> b ts (c fb)")
        self.linear = Linear(self.cfg)
        self.positional_encoder = PositionalEncoder(self.cfg)
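        self.encoder_blocks = nn.Sequential(
            *[EncoderBlock(self.cfg) for _ in range(self.cfg.n_encoder_layers)]
        )
        self.layer_norm = LayerNorm(self.cfg)

        # encoder as an nn.Sequential, in the order described in the docstring
        self.sequential = nn.Sequential(
            self.repeat,
            self.conv2d_block_one,
            self.conv2d_block_two,
            self.reshape,
            self.linear,
            self.positional_encoder,
            self.encoder_blocks,
            self.layer_norm,
        )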

    def forward(self, x: Float[t.Tensor, "batch time_steps freq_bins"]) -> Float[t.Tensor, "batch seq_length d_model"] :  # type: ignore
        # take in our input spectrograms, encode, and generate encoded outputs to be used by the decoder for cross-attention
        assert x.ndim == 3, f"Expected 3 dimensions, got {x.ndim}"
        assert x.shape[2] == self.cfg.n_freq_bins
        return self.sequential(x)
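
To see how the input shape turns into [batch seq_length d_model], here is a minimal sketch of the shape arithmetic through the two Conv2d blocks and the reshape. It uses the standard Conv2d output-size formula (with dilation 1). The kernel size, stride, padding, channel count, and input sizes are hypothetical values picked for illustration, not the ones in my Config.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of one spatial dim after a Conv2d (dilation = 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def encoder_shapes(time_steps, freq_bins, n_out_channels, kernel, stride, padding):
    """Trace (time, freq) through two conv blocks, then the reshape."""
    ts, fb = time_steps, freq_bins
    for _ in range(2):  # conv2d_block_one, then conv2d_block_two
        ts = conv_out(ts, kernel, stride, padding)
        fb = conv_out(fb, kernel, stride, padding)
    seq_length = ts                  # "b c ts fb -> b ts (c fb)" keeps ts as the sequence axis
    flattened = n_out_channels * fb  # width that the linear layer maps to d_model
    return seq_length, flattened


# Hypothetical values: 100 time steps, 80 freq bins, 64 channels, 3x3 kernel, stride 2, padding 1
seq_length, flattened = encoder_shapes(100, 80, 64, kernel=3, stride=2, padding=1)
print(seq_length, flattened)
```

With these made-up values, each stride-2 block halves the time axis (100, then 50, then 25), so the decoder would cross-attend over 25 encoder positions.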

And my decoder looks like this:

class Decoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()  # must run before submodules are assigned on an nn.Module
        self.cfg = cfg
        self.character_embedding = CharacterEmbedding(self.cfg)
        self.decoder_positional_encoder = PositionalEncoder(self.cfg)
        self.decoder_blocks = nn.ModuleList(
            [DecoderBlock(self.cfg) for _ in range(self.cfg.n_decoder_layers)]
        )
        self.layer_norm = LayerNorm(self.cfg)
        self.linear = nn.Linear(self.cfg.d_model, self.cfg.vocab_size)
        self.soft_max = nn.Softmax(dim=-1)  # Apply softmax along the vocabulary dimension

    def forward(self, x,
            key_input: Optional[Float[t.Tensor, "batch posn d_model"]] = None, # type: ignore
            value_input: Optional[Float[t.Tensor, "batch posn d_model"]] = None # type: ignore
            ):
        # take in our input tokens, encode, and generate character embeddings
        embeddings = self.character_embedding(x)

        encoded_sequence = self.decoder_positional_encoder(embeddings)

        assert encoded_sequence.ndim == 3
        assert encoded_sequence.shape[1] <= self.cfg.max_seq_length
        assert encoded_sequence.shape[2] == self.cfg.d_model

        # pass positional encoding into decoder blocks

        x = encoded_sequence
        for block in self.decoder_blocks:
            x = block(x, key_input, value_input)

        x = self.layer_norm(x)

        x = self.linear(x)

        probabilities = self.soft_max(x)

        return probabilities

More on the full SpeechTransformer and its classes can be found here:


Day 1

Today I started refactoring the encoder into its own class. Once that was done, I did some refactoring of the decoder as I wrapped up both sides of the speech transformer diagram.

I learned a bit about static type checkers in IDEs.

I developed a deeper understanding of the significance of our sequence dimension, and how it relates to vocab_size and d_model.
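
One way to picture this is to follow the shapes for a single sequence. In this toy sketch, with made-up sizes and plain nested lists instead of tensors, each of the seq_length positions starts as a token id, becomes a vector of width d_model, and ends as a vector of width vocab_size. The embedding and projection values are arbitrary placeholders.

```python
vocab_size, d_model = 5, 3   # hypothetical sizes
token_ids = [4, 0, 2, 2]     # one sequence, seq_length = 4

# Embedding table: one d_model-wide row per vocabulary entry.
embedding = [[float(v + d) for d in range(d_model)] for v in range(vocab_size)]
# Output projection: maps d_model back to vocab_size.
projection = [[0.1 * (d + 1) for _ in range(vocab_size)] for d in range(d_model)]

# Look up one row per token: [seq_length, d_model]
embedded = [embedding[tok] for tok in token_ids]

# Matrix-multiply each position by the projection: [seq_length, vocab_size]
logits = [
    [sum(vec[d] * projection[d][v] for d in range(d_model)) for v in range(vocab_size)]
    for vec in embedded
]

print(len(embedded), len(embedded[0]))  # seq_length, d_model
print(len(logits), len(logits[0]))      # seq_length, vocab_size
```

The sequence dimension stays the same length throughout; only the width of each position changes.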

In implementing cross-attention in the speech transformer's decoder, we want to use the same Attention class for both self-attention (with just query_input) and cross-attention (with query_input, plus a key_input and value_input that come from the output of our encoder).

To accomplish this, we use a lambda function that "closes over" key_input and value_input, capturing these optional parameters so they reach the attention layer when they are used.

If key_input and value_input are None, we default to using query_input for calculating Q, K, and V in our attention function. If we pass in key_input and value_input from our encoder's output, they are used to calculate K and V respectively. This way we can support self-attention and cross-attention in the same class/function.

Here's what that looks like in code:

In our decoder:

class DecoderBlock(TransformerBlock):
    def __init__(self, cfg):
        super().__init__()  # must run before submodules are assigned on an nn.Module
        self.cfg = cfg
        self.layer_norm_one = nn.LayerNorm(self.cfg.d_model)
        self.attention_masked = Attention(self.cfg, apply_mask=True)
        self.layer_norm_two = nn.LayerNorm(self.cfg.d_model)
        self.attention_unmasked = Attention(self.cfg, apply_mask=False)
        self.layer_norm_three = nn.LayerNorm(self.cfg.d_model)
        self.feed_forward_network = FFN(self.cfg)

    def forward(
            self,
            x: Float[t.Tensor, "batch posn d_model"], # type: ignore
            key_input: Optional[Float[t.Tensor, "batch posn d_model"]] = None, # type: ignore
            value_input: Optional[Float[t.Tensor, "batch posn d_model"]] = None # type: ignore
            ) -> Float[t.Tensor, "batch posn d_model"]: # type: ignore

        x = self.add_to_residual_stream(x, self.layer_norm_one, self.attention_masked)
        x = self.add_to_residual_stream(x, self.layer_norm_two, lambda norm_x: self.attention_unmasked(norm_x, key_input, value_input)) # this layer needs to use encoder outputs as its inputs for keys and values, and use queries from previous sub-block outputs
        x = self.add_to_residual_stream(x, self.layer_norm_three, self.feed_forward_network)

        return x

And in our Attention class:

    def forward(
            self,
            query_input: Float[t.Tensor, "batch posn d_model"], # type: ignore
            key_input: Optional[Float[t.Tensor, "batch posn d_model"]] = None, # type: ignore
            value_input: Optional[Float[t.Tensor, "batch posn d_model"]] = None # type: ignore
            ) -> Float[t.Tensor, "batch posn d_model"]: # type: ignore
        # linear map

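        # the pattern takes two operands: the input and a per-head weight of shape (n, e, h)
        # (W_Q / W_K / W_V are assumed names, following the b_Q / b_K / b_V biases)
        Q = einops.einsum(query_input, self.W_Q,
                          "b s e, n e h -> b s n h") + self.b_Q

        K = einops.einsum(query_input if key_input is None else key_input, self.W_K,
                          "b s e, n e h -> b s n h") + self.b_K

        V = einops.einsum(query_input if value_input is None else value_input, self.W_V,
                          "b s e, n e h -> b s n h") + self.b_V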

        # ...
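
The closure trick can also be seen in isolation with plain Python. This is a sketch, not my actual TransformerBlock: add_to_residual_stream here is an assumed stand-in that applies a norm, then a sub-layer, then adds the result back, and the "tensors" are just numbers.

```python
from typing import Callable, Optional


def add_to_residual_stream(x: float, norm: Callable[[float], float],
                           sublayer: Callable[[float], float]) -> float:
    """Assumed stand-in: the sub-layer is always called with ONE argument."""
    return x + sublayer(norm(x))


def attention(query_input: float, key_input: Optional[float] = None,
              value_input: Optional[float] = None) -> float:
    """Toy attention: falls back to query_input when no key/value is given."""
    k = query_input if key_input is None else key_input
    v = query_input if value_input is None else value_input
    return query_input * k * v


def identity(x: float) -> float:
    """Placeholder for the layer norm."""
    return x


encoder_output = 3.0  # stands in for the encoder's output

# Self-attention: the function can be passed directly.
self_attn = add_to_residual_stream(2.0, identity, attention)

# Cross-attention: the lambda closes over encoder_output, so the extra
# arguments travel along even though only norm_x is passed in.
cross_attn = add_to_residual_stream(
    2.0, identity, lambda norm_x: attention(norm_x, encoder_output, encoder_output)
)

print(self_attn, cross_attn)
```

The point of the design is that add_to_residual_stream never needs to know whether a sub-layer takes extra inputs; the lambda adapts the three-argument call to the one-argument interface.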

And as I wrapped up the decoder, I got a better understanding of how our output ends up with vocab_size as its last dimension: a probability for every character in the vocabulary, predicting which character should come next. This was a bit of an a-ha moment for me!
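
As a small sketch of that idea, here is a softmax over the vocabulary dimension for a single position, followed by a greedy pick of the most likely character. The four-character vocabulary and the logits are made up for illustration; the real model produces one such distribution per position in the sequence.

```python
import math


def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(score - peak) for score in logits]
    total = sum(exps)
    return [e / total for e in exps]


vocab = ["a", "b", "c", " "]    # hypothetical vocabulary, vocab_size = 4
logits = [0.5, 2.0, -1.0, 0.1]  # hypothetical output of the final linear layer

probabilities = softmax(logits)

# Greedy decoding for one step: take the index with the highest probability.
next_index = max(range(len(vocab)), key=lambda i: probabilities[i])
next_char = vocab[next_index]

print(len(probabilities), next_char)
```

Greedy picking is only one way to read off the distribution; it is used here because it is the simplest to show.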

Day 2

Today I refactored and tested the full encoder-decoder speech transformer, and updated some of the older tests for code whose implementation I had modified since yesterday.

def test_speech_transformer(model, cfg):

    cv = CharacterVocabulary(cfg)

    batches = 2
    time_steps = 100
    freq_bins = model.cfg.n_freq_bins

    text_input = "listening is a practice of freedom"
    encoded_text_input = cv.encode(text_input)

    assert len(encoded_text_input) == model.cfg.max_seq_length

    encoded_text_tensor = t.tensor(encoded_text_input, dtype=t.long)

    encoded_text_tensor = einops.repeat(encoded_text_tensor, "seq_length -> b seq_length", b=batches).to(cfg.device)

    spectrograms = t.randn(batches, time_steps, freq_bins).to(cfg.device)

    output = model(spectrograms, encoded_text_tensor)

    assert output.shape == (batches, model.cfg.max_seq_length, model.cfg.vocab_size)

Day 3

Today was mostly spent writing this post and refactoring some code in my tests to use pytest fixtures.

Things for next cycle

Next cycle will be training, so I'm looking forward to setting up a data processing pipeline and doing my first training run with this model! I also have a few lower-priority things to do, like moving some of the asserts in the classes into tests.