Discrepancy between paper and codebase regarding prediction heads #6

@cs-giung

Hello authors,

Thanks for the great work and for open-sourcing the code.

While reviewing the implementation, I noticed what appears to be a discrepancy between the paper and the codebase regarding how the continuous prediction (x_pred) and discrete decoding (s_pred) are formulated.

  • In the paper, the predictions are described as direct/linear projections from the shared network output:
    • x_pred = net(z, t)
    • s_pred = x_pred @ unembed_kernel
  • In the codebase, if I understand it correctly, the implementation branches earlier and introduces additional normalization and non-linearities (*_bias terms omitted for brevity):
    • x_pred = RMSNorm(net(z, t)) @ linear
    • s_pred = gelu(net(z, t) @ proj_kernel) @ unembed_kernel
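
To make the comparison concrete, here is a minimal JAX sketch of the two formulations as I read them. It is not the repository's code: the toy sizes, the helper names (`rms_norm`, `paper_heads`, `code_heads`), and the scale-free RMSNorm are my own assumptions for illustration, and biases are omitted as above.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy sizes, not the repository's configuration.
BATCH, SEQ, HIDDEN, EMBED, VOCAB = 2, 4, 16, 8, 32


def rms_norm(h, eps=1e-6):
    # RMSNorm without a learned scale (a simplification for this sketch).
    return h * jax.lax.rsqrt(jnp.mean(jnp.square(h), axis=-1, keepdims=True) + eps)


def paper_heads(net_out, unembed_kernel):
    # Paper, as I read it: x_pred is the network output itself,
    # and s_pred is a linear projection of x_pred.
    x_pred = net_out
    s_pred = x_pred @ unembed_kernel
    return x_pred, s_pred


def code_heads(net_out, linear, proj_kernel, unembed_kernel):
    # Codebase, as I read it: both heads branch from the same hidden state.
    x_pred = rms_norm(net_out) @ linear
    s_pred = jax.nn.gelu(net_out @ proj_kernel) @ unembed_kernel
    return x_pred, s_pred


k_h, k_e, k_lin, k_proj, k_un = jax.random.split(jax.random.PRNGKey(0), 5)
h_embed = jax.random.normal(k_e, (BATCH, SEQ, EMBED))    # paper: output in embedding space
h_hidden = jax.random.normal(k_h, (BATCH, SEQ, HIDDEN))  # code: output in hidden space
linear = jax.random.normal(k_lin, (HIDDEN, EMBED))
proj_kernel = jax.random.normal(k_proj, (HIDDEN, EMBED))
unembed_kernel = jax.random.normal(k_un, (EMBED, VOCAB))

x_paper, s_paper = paper_heads(h_embed, unembed_kernel)
x_code, s_code = code_heads(h_hidden, linear, proj_kernel, unembed_kernel)

# In the code path, s_pred does not depend on the x_pred head:
# zeroing `linear` (as ZERO_INIT does at initialization) changes x_pred only.
x_zero, s_zero = code_heads(h_hidden, jnp.zeros_like(linear), proj_kernel, unembed_kernel)
print(x_paper.shape, s_paper.shape, x_code.shape, s_code.shape)
print(bool(jnp.all(x_zero == 0)), bool(jnp.allclose(s_zero, s_code)))
```

In the paper-style path, s_pred is a function of x_pred, whereas in the code-style path the two heads only share the hidden state, which is the difference I am asking about.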

Could you clarify whether these were empirical design choices added to stabilize training, or let me know if I have missed something? Thanks for your time!

ELF/src/modules/model.py

Lines 141 to 157 in 1f38c80

# Factored decoder unembedding: hidden -> text_encoder_dim -> vocab
decoder_logits = None
bn = self.text_encoder_dim
proj_kernel = self.param('proj_kernel', DEFAULT_KERNEL_INIT, (self.hidden_size, bn))
proj_bias = self.param('proj_bias', DEFAULT_BIAS_INIT, (bn,))
unembed_kernel = self.param('unembed_kernel', DEFAULT_KERNEL_INIT, (bn, self.vocab_size))
unembed_bias = self.param('unembed_bias', DEFAULT_BIAS_INIT, (self.vocab_size,))
if decoder_step_active is not None:
    decoder_logits = jax.lax.cond(
        decoder_step_active,
        lambda xi: jax.nn.gelu(xi @ proj_kernel + proj_bias) @ unembed_kernel + unembed_bias,
        lambda xi: jnp.zeros((*xi.shape[:2], self.vocab_size), dtype=xi.dtype),
        x,
    )
output = FinalLayer(self.hidden_size, patch_size, self.text_encoder_dim, name='final_layer')(x)
return output, decoder_logits

ELF/src/modules/layers.py

Lines 204 to 216 in 1f38c80

class FinalLayer(nn.Module):
    """The final layer of ELF."""
    hidden_size: int
    patch_size: int
    out_channels: int

    @nn.compact
    def __call__(self, x):
        x = RMSNorm(self.hidden_size, name='norm_final')(x)
        return nn.Dense(
            self.patch_size * self.patch_size * self.out_channels, use_bias=True,
            kernel_init=ZERO_INIT, bias_init=ZERO_INIT, name='linear',
        )(x)
