fix: set n_ctx=512 for TinyStories models#1162
fix: set n_ctx=512 for TinyStories models#1162jlarson4 merged 2 commits intoTransformerLensOrg:devfrom
Conversation
TinyStories models were trained with sequence length 512, but HuggingFace config claims n_ctx=2048. This causes performance degradation for sequences >512 tokens. Added warning to alert users of this limitation. Note: We cannot change n_ctx in the config because the pretrained weights have positional embeddings for 2048 positions. Changing n_ctx would break weight loading. Fixes TransformerLensOrg#492
aeca4d1 to
4d74ac0
Compare
|
Hey @puranikyashaswin! This bug is resolvable, but this error log doesn't quite get there. We should be able to trim down the weights during the weight conversion for tiny stories models once they've been loaded into TransformLens from hugging face in the |
|
Thanks for the feedback @jlarson4! That makes sense trimming the positional embedding weights in the weight_conversion functions would be the proper fix. Would you like me to update this PR with that approach? |
|
@puranikyashaswin please feel free to tackle it at your leisure. I'll take a look once you've updated and let you know if I have any further thoughts or questions. Thank you! |
Trim pos_embed from [2048, d_model] to [n_ctx, d_model] in convert_neo_weights and convert_gpt2_weights when the pretrained weights have more positions than n_ctx. Override n_ctx to 512 for TinyStories models in convert_hf_model_config since these models were trained with seq_len=512 despite the HF config reporting max_position_embeddings=2048. Fixes TransformerLensOrg#492
|
Thanks @jlarson4! I've updated the PR to trim the positional embedding weights to the correct size during weight conversion. Ready for your review whenever you get time! |
|
Thank you! This looks great |
Description
TinyStories models were trained with sequence length 512, but HuggingFace config incorrectly claims
n_ctx=2048. This causes severe performance degradation for sequences >512 tokens.This fix adds a post-processing override to correct the
n_ctxvalue when loading any TinyStories model.Fixes #492
Type of change
Checklist: