Recurse Center - Batch 3 - Cycle 20250112-20250114 - Decoding
Decoding
This cycle I finished up coding the speech transformer end to end :)
The code for the speech transformer looks something like this now:
My encoder looks like this:
class Encoder(nn.Module):
    """Encoder

    Input: [batch time_steps freq_bins]
    Output: [batch seq_length d_model]

    This class processes a batch of spectrograms by:
    - Expanding the input tensor by one dimension for channel
    - Passing our tensor through a Conv2d block (with ReLU)
    - Passing our tensor through another Conv2d block (with ReLU)
    - Reshaping our tensor so it has three dimensions instead of four
    - Passing through a linear layer so we get an output shape of d_model for the last dimension
    - Passing our tensor through a positional encoder
    - Passing our query_input through n encoder blocks
    - Passing our tensor through a layer norm
    - Outputting our encoded output to be used by the decoder (for cross attention)
    """
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        # encoder components
        self.repeat = Repeat(self.cfg)
        self.conv2d_block_one = Conv2DBlock(self.cfg, self.cfg.n_channels, self.cfg.n_out_channels, self.cfg.conv2d_kernel_size, self.cfg.conv2d_stride, self.cfg.conv2d_padding)
        self.conv2d_block_two = Conv2DBlock(self.cfg, self.cfg.n_out_channels, self.cfg.n_out_channels, self.cfg.conv2d_kernel_size, self.cfg.conv2d_stride, self.cfg.conv2d_padding)
        self.reshape = Reshape(self.cfg, "b c ts fb -> b ts (c fb)")
        self.linear = Linear(self.cfg)
        self.positional_encoder = PositionalEncoder(self.cfg)
        self.encoder_blocks = nn.Sequential(
            *[EncoderBlock(self.cfg) for _ in range(self.cfg.n_encoder_layers)]
        )
        self.layer_norm = LayerNorm(self.cfg)
        # encoder as an nn.Sequential
        self.sequential = nn.Sequential(
            self.repeat,
            self.conv2d_block_one,
            self.conv2d_block_two,
            self.reshape,
            self.linear,
            self.positional_encoder,
            self.encoder_blocks,
            self.layer_norm
        )

    def forward(self, x: Float[t.Tensor, "batch time_steps freq_bins"]) -> Float[t.Tensor, "batch seq_length d_model"]:  # type: ignore
        # take in our input spectrograms, encode, and generate encoded outputs to be used by the decoder for cross-attention
        assert x.ndim == 3, f"Expected 3 dimensions, got {x.ndim}"
        assert x.shape[2] == self.cfg.n_freq_bins
        return self.sequential(x)
And my decoder looks like this:
class Decoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.character_embedding = CharacterEmbedding(self.cfg)
        self.decoder_positional_encoder = PositionalEncoder(self.cfg)
        self.decoder_blocks = nn.ModuleList(
            [DecoderBlock(self.cfg) for _ in range(self.cfg.n_decoder_layers)]
        )
        self.layer_norm = LayerNorm(self.cfg)
        self.linear = nn.Linear(self.cfg.d_model, self.cfg.vocab_size)
        self.soft_max = nn.Softmax(dim=-1)  # Apply softmax along the vocabulary dimension

    def forward(self, x,
                key_input: Optional[Float[t.Tensor, "batch posn d_model"]] = None,  # type: ignore
                value_input: Optional[Float[t.Tensor, "batch posn d_model"]] = None  # type: ignore
                ):
        # take in our input tokens, encode, and generate character embeddings
        embeddings = self.character_embedding(x)
        encoded_sequence = self.decoder_positional_encoder(embeddings)
        assert encoded_sequence.ndim == 3
        assert encoded_sequence.shape[1] <= self.cfg.max_seq_length
        assert encoded_sequence.shape[2] == self.cfg.d_model
        # pass positional encoding into decoder blocks
        x = encoded_sequence
        for block in self.decoder_blocks:
            x = block(x, key_input, value_input)
        x = self.layer_norm(x)
        x = self.linear(x)
        probabilities = self.soft_max(x)
        return probabilities
More on the full SpeechTransformer and its classes can be found here:
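In the meantime, here's a minimal sketch of how the two halves compose at the top level. This is a sketch only: it assumes the Encoder and Decoder shown above, and the wiring in my actual SpeechTransformer class may differ in its details.

class SpeechTransformer(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.encoder = Encoder(cfg)
        self.decoder = Decoder(cfg)

    def forward(self, spectrograms, tokens):
        # encode the spectrograms once...
        encoded = self.encoder(spectrograms)  # [batch, seq_length, d_model]
        # ...then hand the encoded sequence to the decoder, where it supplies
        # the keys and values for cross-attention
        return self.decoder(tokens, key_input=encoded, value_input=encoded)

The key idea is that the encoder output is passed to the decoder as both key_input and value_input, while the decoder's own token stream supplies the queries.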
Day 1
Today I started refactoring the encoder into its own class. Once that was done, I did some refactoring of the decoder as I wrapped up both sides of the speech transformer diagram.
I learned a bit about static type checkers in IDEs.
I developed a deeper understanding of the significance of our sequence dimension, and how it relates to vocab_size and d_model.
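To make those shapes concrete, here's a toy walkthrough with made-up numbers (and plain nn.Embedding / nn.Linear standing in for my CharacterEmbedding and final linear layer):

import torch as t
import torch.nn as nn

# Toy numbers just to illustrate the shapes; not my real config values.
batch, seq_length, d_model, vocab_size = 2, 16, 64, 30

tokens = t.randint(0, vocab_size, (batch, seq_length))  # [batch, seq_length]: one character id per position
embed = nn.Embedding(vocab_size, d_model)
unembed = nn.Linear(d_model, vocab_size)

x = embed(tokens)     # [batch, seq_length, d_model]: each position becomes a d_model-sized vector
logits = unembed(x)   # [batch, seq_length, vocab_size]: one score per character, per position

The sequence dimension survives the whole way through; only the last dimension changes, from d_model inside the model to vocab_size at the output.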
In implementing cross-attention in the speech transformer's decoder, we want to use the same Attention class for both self-attention (with just query_input) and cross-attention (with query_input, plus the key_input and value_input that come from the output of our encoder).
To accomplish this, we use a lambda function that "closes over" the extra parameters and passes them along to the attention call when they're provided.
If key_input and value_input are empty, we default to using query_input for calculating Q, K, and V in our attention function. If we pass in key_input and value_input from our encoder's output, they are used for their respective projections instead. This way we can support self-attention and cross-attention in the same class/function.
Here's what that looks like in code:
In our decoder:
class DecoderBlock(TransformerBlock):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.layer_norm_one = nn.LayerNorm(self.cfg.d_model)
        self.attention_masked = Attention(self.cfg, apply_mask=True)
        self.layer_norm_two = nn.LayerNorm(self.cfg.d_model)
        self.attention_unmasked = Attention(self.cfg, apply_mask=False)
        self.layer_norm_three = nn.LayerNorm(self.cfg.d_model)
        self.feed_forward_network = FFN(self.cfg)

    def forward(
        self,
        x: Float[t.Tensor, "batch posn d_model"],  # type: ignore
        key_input: Optional[Float[t.Tensor, "batch posn d_model"]] = None,  # type: ignore
        value_input: Optional[Float[t.Tensor, "batch posn d_model"]] = None  # type: ignore
    ) -> Float[t.Tensor, "batch posn d_model"]:  # type: ignore
        x = self.add_to_residual_stream(x, self.layer_norm_one, self.attention_masked)
        # this layer needs to use encoder outputs as its inputs for keys and values,
        # and use queries from previous sub-block outputs
        x = self.add_to_residual_stream(x, self.layer_norm_two, lambda norm_x: self.attention_unmasked(norm_x, key_input, value_input))
        x = self.add_to_residual_stream(x, self.layer_norm_three, self.feed_forward_network)
        return x
And in our Attention class:
def forward(
    self,
    query_input: Float[t.Tensor, "batch posn d_model"],  # type: ignore
    key_input: Optional[Float[t.Tensor, "batch posn d_model"]] = None,  # type: ignore
    value_input: Optional[Float[t.Tensor, "batch posn d_model"]] = None  # type: ignore
) -> Float[t.Tensor, "batch posn d_model"]:  # type: ignore
    # linear map
    Q = einops.einsum(query_input,
                      self.W_Q,
                      "b s e, n e h -> b s n h") + self.b_Q
    K = einops.einsum(query_input if key_input is None else key_input,
                      self.W_K,
                      "b s e, n e h -> b s n h") + self.b_K
    V = einops.einsum(query_input if value_input is None else value_input,
                      self.W_V,
                      "b s e, n e h -> b s n h") + self.b_V
    # ...
And as I wrapped up the decoder, I got a better understanding of how our output ends up as a probability distribution over the vocabulary at each position, predicting which character should come next. This was a bit of an a-ha moment for me!
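As a toy illustration of that idea (with a made-up four-character vocabulary, not my actual CharacterVocabulary), greedy decoding just takes the argmax over the vocabulary dimension at each position:

import torch as t

# Pretend decoder output: batch of 1, 5 positions, 4-character vocab.
probabilities = t.softmax(t.randn(1, 5, 4), dim=-1)  # [batch, seq_length, vocab_size]

# Greedy decoding: pick the most likely character index at each position.
predicted_ids = probabilities.argmax(dim=-1)          # [batch, seq_length]

# Hypothetical vocabulary purely for illustration; the real mapping lives in CharacterVocabulary.
toy_vocab = ["a", "b", "c", " "]
predicted_text = "".join(toy_vocab[i] for i in predicted_ids[0].tolist())
print(predicted_text)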
Day 2
Today I refactored and tested the full encoder-decoder speech transformer, and updated some of the older tests where I had modified the implementation since yesterday.
def test_speech_transformer(model, cfg):
    cv = CharacterVocabulary(cfg)
    batches = 2
    time_steps = 100
    freq_bins = model.cfg.n_freq_bins
    # encode a text prompt into character ids
    text_input = "listening is a practice of freedom"
    encoded_text_input = cv.encode(text_input)
    assert len(encoded_text_input) == model.cfg.max_seq_length
    encoded_text_tensor = t.tensor(encoded_text_input, dtype=t.long)
    encoded_text_tensor = einops.repeat(encoded_text_tensor, "seq_length -> b seq_length", b=batches).to(cfg.device)
    # random spectrograms stand in for real audio features
    spectrograms = t.randn(batches, time_steps, freq_bins).to(cfg.device)
    output = model(spectrograms, encoded_text_tensor)
    # the model should emit one distribution over the vocabulary per output position
    assert output.shape == (batches, model.cfg.max_seq_length, model.cfg.vocab_size)
Day 3
Today was mostly spent writing this post and refactoring some code in my tests to use pytest fixtures.
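The fixtures end up looking roughly like this (a sketch; how Config is actually constructed in my repo may differ):

import pytest
# plus whatever imports bring in Config and SpeechTransformer from the project

@pytest.fixture
def cfg():
    # Stand-in: build the Config however the project normally does.
    return Config()

@pytest.fixture
def model(cfg):
    # One fresh, untrained model shared by any test that asks for it.
    return SpeechTransformer(cfg).to(cfg.device)

With these in place, a test like test_speech_transformer(model, cfg) just names the fixtures it needs as parameters and pytest wires them up.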
Things for next cycle
Next cycle will be training, so I'm looking forward to setting up a data processing pipeline and doing my first training run with this model! I also have a few things that are less of a priority, like moving some of the asserts in the classes into tests.