-
Notifications
You must be signed in to change notification settings - Fork 227
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
enable llava on torchchat #1183
base: main
Are you sure you want to change the base?
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1183
Note: Links to docs will display an error until the docs builds have been completed. ❌ 2 New FailuresAs of commit 937e7ed with merge base 2cf4016 (): NEW FAILURES - The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reminder to test that Flamingo and 3.1 still work as expected
Also reminder that to test convert_hf_checkpoint you need to delete your download/conversion and rerun
@@ -21,9 +24,176 @@ | |||
|
|||
from torchchat.model import ModelArgs | |||
|
|||
def remap_llava_checkpoint(llava_ckpt): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Was this written inhouse?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not pretty following your question.
This function is consumed by convert_llava_checkpoint
to get remapped checkpoint.
I made this as an individual function to simply the logic
@@ -21,9 +24,176 @@ | |||
|
|||
from torchchat.model import ModelArgs | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
""" | |
Llava Conversion Code | |
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code comment blocks to help us move things around later
|
||
tokenizer_path = model_dir / "tokenizer.model" | ||
shutil.copy(tokenizer_files[0], tokenizer_path) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
""" | |
Text-Only Conversion Code | |
""" |
if batch and self.model.config.model_type == ModelType.Llava: | ||
context_len, next_token = next_token | ||
else: | ||
context_len, next_token = T, next_token |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
context_len, next_token = T, next_token | |
context_len = T |
encoded = batch["tokens"] | ||
elif self.model.config.model_type == ModelType.Llava: | ||
#TODO: double check the tokenizer. | ||
def find_subtensor(tensor, target): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typehints
"""Applies Rotary Position Embedding to the query and key tensors. | ||
|
||
Args: | ||
q (`torch.Tensor`): The query tensor. | ||
k (`torch.Tensor`): The key tensor. | ||
cos (`torch.Tensor`): The cosine part of the rotary embedding. | ||
sin (`torch.Tensor`): The sine part of the rotary embedding. | ||
unsqueeze_dim (`int`, *optional*, defaults to 1): | ||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and | ||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note | ||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and | ||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes | ||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have | ||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. | ||
Returns: | ||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Outdated comment?
@@ -0,0 +1,80 @@ | |||
import torch | |||
import torchvision as tv |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lintrunner ordering
padding with median RGB value to make a square, scaling, and normalizing. | ||
|
||
Args: | ||
img_address (str): Address of the local image file will be forwarded to the model. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Autogen'd comment?
@@ -919,6 +937,58 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: | |||
return x_out2.type_as(x) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Can move apply_rotary_emb
so that it is sequentially after hf_apply_rotary_emb
?
Mainly for keeping concepts together
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd like to keep the current structure, with all HF rotary embedding functions grouped together and all previous embedding functions in a separate section.
encoded = batch["tokens"] | ||
assert len(images) == 1, "Only one image prompt is supported for now" | ||
|
||
#TODO: updated encoded variable for multi-modality models to include image tokens. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you explain this to me?
This PR enable llava1.5 on torchchat, which is the first multi-modality model on torchchat.
How to play?
You can use
--prompt
as the flag for text input, and--image-prompt
as image input.e.g.
It can also handle input without image input: