Hello authors,
Thanks for the great work and for open-sourcing the code.
While reviewing the implementation, I noticed what appears to be a discrepancy between the paper and the codebase regarding how the continuous prediction (x_pred) and discrete decoding (s_pred) are formulated.
- In the paper, the predictions are described as direct/linear projections from the shared network output:
x_pred = net(z, t)
s_pred = x_pred @ unembed_kernel
- In the codebase, if I am understanding it correctly, the implementation branches earlier and introduces additional normalization and non-linearities (omitted *_bias for brevity):
x_pred = linear @ RMSNorm(net(z, t))
s_pred = gelu(net(z, t) @ proj_kernel) @ unembed_kernel
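To make the comparison concrete, here is a minimal JAX sketch of the two formulations as I read them. It is not the repository's code. The helper names (`rms_norm`, `paper_heads`, `code_heads`), the toy shapes, the random weights, and the gain-free RMSNorm are all my own assumptions for illustration, and biases are omitted as above. `h` stands in for `net(z, t)`.

```python
import jax
import jax.numpy as jnp


def rms_norm(h, eps=1e-6):
    # Gain-free RMSNorm over the last axis (assumption: no learned scale).
    return h * jax.lax.rsqrt(jnp.mean(h * h, axis=-1, keepdims=True) + eps)


def paper_heads(net_out, unembed_kernel):
    # Paper, as I read it: x_pred is the network output itself,
    # and s_pred is a linear map of x_pred.
    x_pred = net_out
    s_pred = x_pred @ unembed_kernel
    return x_pred, s_pred


def code_heads(h, linear, proj_kernel, unembed_kernel):
    # Codebase, as I read it: both heads branch off the same hidden state h,
    # so s_pred is not computed from x_pred.
    x_pred = rms_norm(h) @ linear
    s_pred = jax.nn.gelu(h @ proj_kernel) @ unembed_kernel
    return x_pred, s_pred


# Hypothetical toy sizes: batch, sequence, hidden, bottleneck, vocab.
B, T, H, D, V = 2, 4, 16, 8, 32
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
h = jax.random.normal(k1, (B, T, H))  # stands in for net(z, t)
linear = 0.1 * jax.random.normal(k2, (H, D))
proj_kernel = 0.1 * jax.random.normal(k3, (H, D))
unembed_kernel = 0.1 * jax.random.normal(k4, (D, V))

x_code, s_code = code_heads(h, linear, proj_kernel, unembed_kernel)

# In the paper form, s_pred is fully determined by x_pred.
x_paper, s_paper = paper_heads(x_code, unembed_kernel)

# In the code form, rescaling h leaves x_pred (nearly) unchanged because of
# the normalization, while s_pred changes because the GELU branch sees raw h.
x_scaled, s_scaled = code_heads(2.0 * h, linear, proj_kernel, unembed_kernel)
print("x_pred unchanged:", bool(jnp.allclose(x_code, x_scaled, atol=1e-4)))
print("s_pred unchanged:", bool(jnp.allclose(s_code, s_scaled, atol=1e-4)))
```

Under these assumptions, the same `x_pred` can correspond to different `s_pred` values in the code form, which is not possible when `s_pred = x_pred @ unembed_kernel`.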
Could you clarify whether these were empirical design choices added to stabilize training, or whether I have missed something? Thanks for your time!
# Factored decoder unembedding: hidden -> text_encoder_dim -> vocab
decoder_logits = None
bn = self.text_encoder_dim
proj_kernel = self.param('proj_kernel', DEFAULT_KERNEL_INIT, (self.hidden_size, bn))
proj_bias = self.param('proj_bias', DEFAULT_BIAS_INIT, (bn,))
unembed_kernel = self.param('unembed_kernel', DEFAULT_KERNEL_INIT, (bn, self.vocab_size))
unembed_bias = self.param('unembed_bias', DEFAULT_BIAS_INIT, (self.vocab_size,))
if decoder_step_active is not None:
    decoder_logits = jax.lax.cond(
        decoder_step_active,
        lambda xi: jax.nn.gelu(xi @ proj_kernel + proj_bias) @ unembed_kernel + unembed_bias,
        lambda xi: jnp.zeros((*xi.shape[:2], self.vocab_size), dtype=xi.dtype),
        x,
    )

output = FinalLayer(self.hidden_size, patch_size, self.text_encoder_dim, name='final_layer')(x)
return output, decoder_logits

class FinalLayer(nn.Module):
    """The final layer of ELF."""
    hidden_size: int
    patch_size: int
    out_channels: int

    @nn.compact
    def __call__(self, x):
        x = RMSNorm(self.hidden_size, name='norm_final')(x)
        return nn.Dense(
            self.patch_size * self.patch_size * self.out_channels, use_bias=True,
            kernel_init=ZERO_INIT, bias_init=ZERO_INIT, name='linear',
        )(x)
ELF/src/modules/model.py
Lines 141 to 157 in 1f38c80
ELF/src/modules/layers.py
Lines 204 to 216 in 1f38c80