Add truncated llama style model init via reset_parameters() (#54) #530

Closed

Conversation

@awgu (Contributor) commented on Aug 19, 2024

Stack from ghstack (oldest at bottom):

  • (to be filled)

This PR adds the following:
1 - Via reset_parameters(), a full layerwise init for the llama models under /llama. This uses the total model depth as part of the init via:
self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5
(a sketch of this and item 2 follows this list)

2 - The final output FFN (head) is initialized using the square root of the model dim itself and a slightly wider cutoff factor of 3.

3 - Tangential change: updates run_llama_train.sh with MODEL and MODEL_CONF params to allow direct model control via the shell script. (There was already a MODEL param, but it was incorrectly being used in place of MODEL_CONF; we should revisit this, as it is not intuitive.)

4 - Made the debugmodel default to 2 layers as an improved debug check.

5 - Added 1B and 40B configs for additional testing. I can't currently run 70B on my H100 due to OOM, but I can run 40B.
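For illustration, here is a minimal sketch of what this depth-scaled, truncated-normal init could look like in PyTorch. The module names (TransformerBlock, OutputHead), the specific sublayers, and the 2-sigma cutoff for the per-layer weights are assumptions for the example, not the PR's exact code; only the depth-scaled std formula and the head's cutoff factor of 3 come from the description above.

```python
import torch.nn as nn


class TransformerBlock(nn.Module):
    """Illustrative block; the real /llama blocks contain attention + FFN sublayers."""

    def __init__(self, model_dim: int, num_layers: int):
        super().__init__()
        self.wq = nn.Linear(model_dim, model_dim, bias=False)
        self.wo = nn.Linear(model_dim, model_dim, bias=False)
        # Item 1: std scaled by the total model depth.
        self.weight_init_std = 0.02 / (2 * num_layers) ** 0.5

    def reset_parameters(self):
        for linear in (self.wq, self.wo):
            # Truncated normal; the 2-sigma cutoff here is an assumption.
            nn.init.trunc_normal_(
                linear.weight,
                mean=0.0,
                std=self.weight_init_std,
                a=-2 * self.weight_init_std,
                b=2 * self.weight_init_std,
            )


class OutputHead(nn.Module):
    """Final projection back to the vocabulary."""

    def __init__(self, model_dim: int, vocab_size: int):
        super().__init__()
        self.proj = nn.Linear(model_dim, vocab_size, bias=False)
        self.model_dim = model_dim

    def reset_parameters(self):
        # Item 2: std tied to sqrt of the model dim, with the wider cutoff factor of 3.
        cutoff_factor = 3
        std = self.model_dim ** -0.5
        nn.init.trunc_normal_(
            self.proj.weight,
            mean=0.0,
            std=std,
            a=-cutoff_factor * std,
            b=cutoff_factor * std,
        )
```

A top-level model would then walk its submodules and call reset_parameters() on each, which is the usual pattern when materializing parameters (e.g. from a meta device) before training.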

Testing:
Verified proper init and training with 7B, 13B and ~40B:

<img width="1085" alt="Screenshot 2024-02-11 at 10 39 12 PM" src="https://github.com/pytorch-labs/torchtrain/assets/46302957/049037ed-63a4-4ab0-bebc-f297857aab72">

[ghstack-poisoned]
@facebook-github-bot added the CLA Signed label on Aug 19, 2024
@awgu closed this on Aug 19, 2024