Long Output After Finetuning #47

Open · TonyJiang17 opened this issue Oct 18, 2024 · 52 comments

@TonyJiang17

Has anyone ever run into the issue where, after finetuning, the output doesn't know when to end and only stops when max new tokens is reached? Does it have to do with the tokenizer not adding an eos token at the end?

I am specifically finetuning llava-next-video...

@zjysteven (Owner)

Yeah, not having the eos token is probably the cause. Let me update it real soon. Currently I'm just using the chat template from huggingface, which does not append the eos token, probably because it's designed for inference.

@zjysteven (Owner)

@TonyJiang17 Would you try again and let me know if it helps on llava-next-video?

@TonyJiang17 (Author)

TonyJiang17 commented Oct 18, 2024

I already made the following change locally and am training now. Does this look correct to you? I'll let you know if anything changes. I made the change in the loader file for llava-next-video:

processor = LlavaNextVideoProcessor.from_pretrained(self.model_hf_path, add_eos_token=True)

@zjysteven (Owner)

One more change is needed, which requires a monkey patch of huggingface's apply_chat_template. They hard-coded add_special_tokens=False, which still won't add bos and eos tokens.
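For intuition, the end result should look something like the following minimal sketch (a hypothetical helper, not the exact patch in this repo): make sure the tokenized sequence ends with eos and that eos is supervised in the labels.

import torch

def append_eos(input_ids: torch.Tensor, labels: torch.Tensor, eos_token_id: int):
    # Minimal sketch of the idea, not the repo's actual fix: ensure the
    # sequence ends with eos and that eos is not masked out in labels.
    if input_ids[-1].item() != eos_token_id:
        eos = torch.tensor([eos_token_id], dtype=input_ids.dtype)
        input_ids = torch.cat([input_ids, eos])
        labels = torch.cat([labels, eos.clone()])  # supervise eos so the model learns when to stop
    return input_ids, labels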

@TonyJiang17 (Author)

Oh shoot, is there nothing we can do locally to change it? Unless, I guess, we add them manually?

@zjysteven (Owner)

I've pushed a fix here c95ea1a. It's not that many changes, and I have confirmed from the outputs of the collator that it now includes bos and eos. If you just pull and train again that would be great.

@TonyJiang17 (Author)

Thanks! Will try and let you know; I may not be able to give an update until tomorrow though.

@zjysteven (Owner)

zjysteven commented Oct 18, 2024

No worries. I have a local file for checking the output of collator, in case it helps.

import json
import os
from tqdm import tqdm

import torch
torch.set_printoptions(profile="full", linewidth=240)
from torch.utils.data import DataLoader
from transformers import AutoProcessor, AutoTokenizer

from datasets import LazySupervisedDataset
from collators import COLLATORS
from loaders import LOADERS
from supported_models import MODEL_HF_PATH

model_id = "llava-next-video-7b"
model_family_id = "llava-next-video"

# build the example video dataset
dataset = LazySupervisedDataset(
    data_path='./example_data/video.json',
    image_folder='./example_data/images',
    video_folder='./example_data/videos',
    model_family_id=model_family_id,
)

# load only the tokenizer/processor/config (skip the model weights)
_, tokenizer, processor, config = LOADERS[model_family_id](
    model_hf_path=MODEL_HF_PATH[model_id],
    model_local_path=MODEL_HF_PATH[model_id],
    compute_dtype=torch.float16,
).load(load_model=False)
tokenizer.model_max_length = 256
collator = COLLATORS[model_family_id](
    config=config,
    processor=processor,
    tokenizer=tokenizer
)

dataloader = DataLoader(dataset, batch_size=2, collate_fn=collator)

# inspect the first batch produced by the collator
batch = next(iter(dataloader))
print(batch["input_ids"])
print()
print(batch["labels"])
print()
# decoded inputs including special tokens (check that bos/eos are present)
print(tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=False))
# decode only the supervised (non -100) label tokens of the second sample
print(tokenizer.decode(
    batch["labels"][1][torch.where(batch["labels"][1] != -100)[0]]
))

@zjysteven (Owner)

Oh wait. Seems like the eos token is not included in labels. One sec.

@zjysteven (Owner)

Now it works. Again please pull the latest code.

whycantfindaname pushed a commit to whycantfindaname/lmms-finetune that referenced this issue Oct 20, 2024
@TonyJiang17 (Author)

It seems the fix on Friday made things better. I will test again tomorrow. @zjysteven

@jackyangcv

jackyangcv commented Oct 21, 2024

Now it works. Again please pull the latest code.

After updating to the new version of the code, I found that the model's output is empty when using my previous inference code. Does the inference code need to be modified? @zjysteven

@zjysteven (Owner)

What exactly does your inference code look like, and which model is it? This codebase is in principle decoupled from inference, so it shouldn't have any effect.

@jackyangcv

jackyangcv commented Oct 21, 2024

What exactly does your inference code look like, and which model is it? This codebase is in principle decoupled from inference, so it shouldn't have any effect.

I followed the inference instructions here: https://github.com/zjysteven/lmms-finetune/blob/main/docs/inference.md, and the model is llava_next_video_7B.
Inference with the pre-fix code was fine, but after updating to the fixed code the model output became an empty string. @zjysteven

@zjysteven (Owner)

That's quite strange; inference doesn't use any of lmms-finetune's internal code. Did you retrain, or...?

@jackyangcv

jackyangcv commented Oct 21, 2024

That's quite strange; inference doesn't use any of lmms-finetune's internal code. Did you retrain, or...?

Yes, I retrained. Have you tried testing with the original inference code after retraining? Does it work normally? @zjysteven

@TonyJiang17 (Author)

You can check whether the processor adds the eos token at inference time; inference shouldn't need it.
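For example, a quick check along these lines (a hedged sketch; it assumes the standard HF processor and that the exact eos behavior depends on your tokenizer config):

from transformers import LlavaNextVideoProcessor

processor = LlavaNextVideoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")
ids = processor.tokenizer("hello").input_ids
# For inference the sequence should NOT end with eos; if it does, the
# processor/tokenizer was likely loaded with add_eos_token=True.
print(ids[-1] == processor.tokenizer.eos_token_id)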

@zjysteven (Owner)

@jackyangcv I ran example_video.sh and then used the following inference code; the output was fine.

import av
import torch
from transformers import LlavaNextVideoProcessor, LlavaNextVideoForConditionalGeneration
import numpy as np

# model_id = "llava-hf/LLaVA-NeXT-Video-7B-hf"
model_id = "../checkpoints/llava-next-video-7b_lora-True_qlora-False"

model = LlavaNextVideoForConditionalGeneration.from_pretrained(
    model_id, 
    torch_dtype=torch.float16, 
    low_cpu_mem_usage=True, 
).to(0)

processor = LlavaNextVideoProcessor.from_pretrained(
    "llava-hf/LLaVA-NeXT-Video-7B-hf"
)

def read_video_pyav(container, indices):
    '''
    Decode the video with PyAV decoder.
    Args:
        container (`av.container.input.InputContainer`): PyAV container.
        indices (`List[int]`): List of frame indices to decode.
    Returns:
        result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
    '''
    frames = []
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            frames.append(frame)
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])


# define a chat history and use `apply_chat_template` to get a correctly formatted prompt
# Each value in "content" has to be a list of dicts with types ("text", "image", "video") 
conversation = [
    {

        "role": "user",
        "content": [
            {"type": "text", "text": "Please provide a detailed description of the video."},
            {"type": "video"},
        ],
    },
]

prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

video_path = "../example_data/videos/ego4d/1e85d8b5-5ca8-4bbf-be51-21741ac8a694.mp4"
container = av.open(video_path)

# sample uniformly 8 frames from the video, can sample more for longer videos
total_frames = container.streams.video[0].frames
indices = np.arange(0, total_frames, total_frames / 8).astype(int)
clip = read_video_pyav(container, indices)
inputs_video = processor(text=prompt, videos=clip, padding=True, return_tensors="pt").to(model.device)

output = model.generate(**inputs_video, max_new_tokens=512, do_sample=True)
print(processor.decode(output[0], skip_special_tokens=True))

The output is:

Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:03<00:00,  1.03s/it]
Expanding inputs for image/video tokens in LLaVa-NeXT-Video should be done in processing. Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. Using processors without these attributes in the config is deprecated and will throw an error in v4.47.
Expanding inputs for image/video tokens in LLaVa-NeXT-Video should be done in processing. Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly with `processor.patch_size = {{patch_size}}` and `processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. Using processors without these attributes in the config is deprecated and will throw an error in v4.47.
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
USER:
Please provide a detailed description of the video. ASSISTANT: The video showcases a person seated in a room, engaged in a creative activity of sewing or altering something on a pattern. The person appears to be well-equipped with various sewing and crafting materials and tools like a sewing machine, fabrics, threads, buttons, and possibly a pattern guide. The room is cluttered with other items adding to the sense of a dedicated workspace and creative atmosphere, including a laptop screen displaying patterns, a phone on a pink stand, and a pattern book resting on a table. The person's focus is evident as they work with sharp concentration and precision, carefully examining the stitches and cutting pieces of material. The camera captures the action from an aerial view, allowing us to appreciate the entire setup and the person's workspace.

You can check whether the processor adds the eos token at inference time; inference shouldn't need it.

The processor used at inference time is the unmodified HF processor, exactly the same as normal inference with llava-next-video.

@TonyJiang17 (Author)

Yeah, it's just that when I tested inference, I accidentally had
processor = LlavaNextVideoProcessor.from_pretrained(self.model_hf_path, add_eos_token=True)

That caused some issues, but after I removed it, it seems to work fine.

@jackyangcv

Yeah, it's just that when I tested inference, I accidentally had processor = LlavaNextVideoProcessor.from_pretrained(self.model_hf_path, add_eos_token=True)

That caused some issues, but after I removed it, it seems to work fine.

Do you mean that at inference time we should use processor = LlavaNextVideoProcessor.from_pretrained(self.model_hf_path)?

@zjysteven (Owner)


@jackyangcv See my earlier reply above, where I reran example_video.sh and the inference output was fine.

@TonyJiang17 (Author)

Yeah, it's just that when I tested inference, I accidentally had processor = LlavaNextVideoProcessor.from_pretrained(self.model_hf_path, add_eos_token=True)
That caused some issues, but after I removed it, it seems to work fine.

Do you mean that at inference time we should use processor = LlavaNextVideoProcessor.from_pretrained(self.model_hf_path)?

Right, for inference don't add add_eos_token=True.

@jackyangcv

Solved, thanks @zjysteven @TonyJiang17

@TonyJiang17 (Author)

@zjysteven the repeating and never-ending response issue still exists, though it's mitigated compared to before.

The problem shows up most when I set the mask-questions option to True in the training arguments. @zjysteven do you have any idea why that's the case? Also, just curious, when would you suggest masking the question and when not? I assume it's more for continuous-pretraining purposes? Thanks.

@TonyJiang17 (Author)

When doing instruction tuning, shouldn't one include the question in the training? I'd assume we can only mask the question during pretraining, where the question doesn't really matter?

@zjysteven (Owner)

zjysteven commented Oct 24, 2024

When doing instruction tuning, shouldn't one include the question in the training? I'd assume we can only mask the question during pretraining, where the question doesn't really matter?

Well my understanding is that essentially what you want from the model is to do conditional generation conditioned on the question. This does not necessarily require the model to be able to predict the tokens in the questions. For example, given the input "USER: Which team is Messi playing for right now? ASSISTANT: ", we want the model to generate "Inter Miami CF". Whether it can generate "is Messi playing..." given "USER: Which team " is much less important. Again you can check https://github.com/haotian-liu/LLaVA/blob/c121f0432da27facab705978f83c4ada465e46fd/llava/train/train.py#L453-L492 and https://github.com/QwenLM/Qwen-VL/blob/aa00ed04091eea5fcdd32985e7915f1c53e7d599/finetune.py#L155-L156. These two are official implementations of LLaVA and Qwen.
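To make the masking concrete, here is a minimal sketch of the idea (hypothetical helper and variable names, not the exact collator code in this repo or in LLaVA/Qwen-VL): everything before the assistant response gets label -100 so it is ignored by the loss.

import torch

IGNORE_INDEX = -100  # positions with this label are ignored by the cross-entropy loss

def mask_question_tokens(input_ids: torch.Tensor, response_start: int) -> torch.Tensor:
    # Supervise only the assistant response: labels are a copy of input_ids with
    # every token before `response_start` (system prompt + user question) masked out.
    labels = input_ids.clone()
    labels[:response_start] = IGNORE_INDEX
    return labels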

For the repeating issue, it seems somewhat common in practice, for example see https://www.reddit.com/r/LocalLLaMA/comments/1ap8mxh/what_causes_llms_to_fall_into_repetitions_while/. The post mentions potential causes and fixes.

Again, what is the sampling strategy you are using here? Also, does increasing the temperature (adding more randomness) or using repetition_penalty help? See https://huggingface.co/transformers/v2.11.0/main_classes/model.html#transformers.PreTrainedModel.generate
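For example, a hedged sketch reusing the inference script above (the specific values are illustrative, not recommendations):

# assumes `model`, `processor`, and `inputs_video` from the inference script earlier in this thread
output = model.generate(
    **inputs_video,
    max_new_tokens=512,
    do_sample=True,          # sample instead of greedy decoding
    temperature=0.8,         # more randomness; illustrative value
    top_p=0.9,               # nucleus sampling; illustrative value
    repetition_penalty=1.1,  # penalize repeated tokens; illustrative value
)
print(processor.decode(output[0], skip_special_tokens=True))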

@TonyJiang17 (Author)

TonyJiang17 commented Oct 24, 2024

Thanks, gotcha, this is helpful. Yeah, I am using greedy sampling (will try random sampling tomorrow); repetition penalty helps but it seems to hurt performance too...

But adding eos seems to have helped! Though I was curious why it wasn't included before in other projects? Do you have any idea why?

@zjysteven (Owner)

I was curious why it wasn't included before in other projects

I believe they must have included it (it may just not be obvious), otherwise the trained model wouldn't work lol? The reason we have to include it so explicitly here is that we are using the chat template functionality of transformers, which is inherently designed/biased toward inference rather than training. Therefore we have to do many tweaks to make it align with training-time behavior (and including the eos token is one of them).

@TonyJiang17 (Author)

@zjysteven I thought I'd just ask an unrelated follow-up question here. Now when I train, I keep getting the warning message shown below. Do you know if it means we can now decide our own patch size? I thought it was fixed given how the model was originally trained? Also, does this affect llava-next-video for video processing, or is it only for images?

Expanding inputs for image/video tokens in LLaVa-NeXT-Video should be done in processing. Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly with `processor.patch_size = {{patch_size}}` and `processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. Using processors without these attributes in the config is deprecated and will throw an error in v4.47.


Also, do you think training on single images of basic soccer concepts like corner kicks, direct kicks, etc. first could improve the model's ability to recognize those activities when processing a video clip?

Thanks in advance!

@zjysteven (Owner)

For the warning, it's because while transformers updated its preprocessing to count vision tokens toward model_max_length, the config hasn't been updated (see #43). I don't think the patch size can be adjusted. The warning doesn't really affect much though (as discussed in #43, it's just a matter of whether vision tokens are counted toward model_max_length).
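If you want to silence it, the warning itself suggests setting the two attributes on the processor. The values below are an assumption (what I'd expect for the CLIP ViT-L/14 vision tower used by llava-next-video), so double-check your checkpoint's config:

# illustrative values; verify against the checkpoint's vision/processing config
processor.patch_size = 14
processor.vision_feature_select_strategy = "default"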

For the second question, yes, I think that makes sense and might be helpful. Otherwise I imagine it would be difficult for the model to recognize those concepts among video frames if it isn't familiar with them in the first place (but tbh I don't have much experience with this, so it's just a very general guess).

@TonyJiang17 (Author)

Hi @zjysteven, sorry for another probably dumb question, but I'd love some clarification from you.

You know how on huggingface there are two versions of the llava-next models: one set hosted by huggingface under llava-hf/ and another set under lmms-lab. I understand the models under llava-hf, which are the ones this package finetunes, are incorporated into the transformers library and leverage many of the library's existing functions. Aside from that, what are the main differences between the lmms-lab and llava-hf models?

How hard would it be to directly finetune the models hosted by lmms-lab? I realize there is a lag between when lmms-lab releases a model and when llava-hf supports it. Could your package be altered in some way by myself to directly finetune lmms-lab's models?

Would really appreciate an answer. Thanks.

@zjysteven (Owner)

zjysteven commented Oct 29, 2024

This is a good question.

Aside from that, what are the main differences between the lmms-lab and llava-hf models?

In general I think there aren't any other major/significant differences. From what I see, the llava-hf maintainers try to replicate the models as closely as possible, such that inference should be exactly the same. There could again be some caveats where the training-time behavior does not exactly follow the original implementation from lmms-lab, e.g., see this comment from a huggingface staff member who implemented basically all the llava models in transformers: #43 (comment)

How hard would it be to directly finetune the models hosted by lmms-lab? Could your package be altered in some way by myself to directly finetune lmms-lab's models?

There wouldn't be any easy tweak to use this codebase for lmms-lab's models. The reason is that you would need the very specific implementation (model class definition) from lmms-lab's code to use and finetune their models (you can see how this would be a bit cumbersome if you want to try multiple models; for example, to finetune llava-1.5 and llava-onevision you would have to clone their two separate repos... which is exactly why I chose llava-hf's unified implementations). That said, the good news is that lmms-lab has released finetuning scripts for many of their models. I would suggest adapting their scripts for your needs, which shouldn't be difficult because after all we are all using transformers' Trainer, so things should be very similar.

@TonyJiang17 (Author)

TonyJiang17 commented Oct 29, 2024

Also @zjysteven, is it possible to finetune llava-next-video with a training set that contains a mix of image and video samples (e.g., some training samples are image-only while some are video)?

I tried it; it seems to work if I just run it on one GPU, but once I do distributed training it doesn't seem to work... would appreciate any help. Thanks!

@zjysteven (Owner)

What's the error message?

@TonyJiang17 (Author)

TonyJiang17 commented Oct 29, 2024

It rolled off the terminal now... but it was actually a timeout error from distributed training.

Actually, never mind; even on a single GPU I ran into the following error:

[rank0]: RuntimeError: still have inflight params [{'id': 0, 'status': 'AVAILABLE', 'numel': 602112, 'ds_numel': 602112, 'shape': (1024, 3, 14, 14), 'ds_shape': (1024, 3, 14, 14), 'requires_grad': False, 'grad_shape': None, 'persist': False, 'active_sub_modules': set(), 'ds_tensor.shape': torch.Size([602112])}, {'id': 1, 'status': 'AVAILABLE', 'numel': 590848, 'ds_numel': 590848, 'shape': (577, 1024), 'ds_shape': (577, 1024), 'requires_grad': False, 'grad_shape': None, 'persist': False, 'active_sub_modules': set(), 'ds_tensor.shape': torch.Size([590848])}, {'id': 7, 'status': 'AVAILABLE', 'numel': 1048576, 'ds_numel': 1048576, 'shape': (1024, 1024), 'ds_shape': (1024, 1024), 'requires_grad': False, 'grad_shape': None, 'persist': False, 'active_sub_modules': set(), 'ds_tensor.shape': torch.Size([1048576])}, {'id': 11, 'status': 'AVAILABLE', 'numel': 1048576, 'ds_numel': 1048576, 'shape': (1024, 1024), 'ds_shape': (1024, 1024), 'requires_grad': False, 'grad_shape': None, 'persist': False, 'active_sub_modules': set(), 'ds_tensor.shape': torch.Size([1048576])}, {'id': 15, 'status': 'AVAILABLE', 'numel': 4194304, 'ds_numel': 4194304, 'shape': (4096, 1024), 'ds_shape': (4096, 1024), 'requires_grad': False, 'grad_shape': None, 'persist': False, 'active_sub_modules': set(), 'ds_tensor.shape': torch.Size([4194304])}, {'id': 9, 'status': 'AVAILABLE', 'numel': 1048576, 'ds_numel': 1048576, 'shape': (1024, 1024), 'ds_shape': (1024, 1024), 'requires_grad': False, 'grad_shape': None, 'persist': False, 'active_sub_modules': set(), 'ds_tensor.shape': torch.Size([1048576])}, {'id': 5, 'status': 'AVAILABLE', 'numel': 1048576, 'ds_numel': 1048576, 'shape': (1024, 1024), 'ds_shape': (1024, 1024), 'requires_grad': False, 'grad_shape': None, 'persist': False, 'active_sub_modules': set(), 'ds_tensor.shape': torch.Size([1048576])}]


I wonder if you could replicate the issue by mixing the example image train json and video train json together into one json...

@zjysteven (Owner)

zjysteven commented Oct 29, 2024

Yes, I can reproduce it. Will look into it ASAP.

@TonyJiang17 I somehow have a very unstable connection to the GPU server today, so this will have to wait until tomorrow. The error is very strange because, to my knowledge, "inflight params" occur when some of the trainable model parameters do not contribute to the computation of the loss, which shouldn't happen with llava-next-video.

@TonyJiang17 (Author)

TonyJiang17 commented Oct 29, 2024

No worries @zjysteven, whenever you get a chance would be great. Really appreciate you looking into it already. Thanks.

Just theoretically it should work, right? Each training sample, regardless of whether it is an image or a video, should be padded to the same dimensions?

@zjysteven (Owner)

zjysteven commented Oct 31, 2024

@TonyJiang17 Yes. I still cannot identify the cause after investigating. However, it seems that changing the deepspeed stage from zero3 to zero2 works. Can you try it and let me know if it works for your case?

@TonyJiang17 (Author)

Thanks @zjysteven, changing to zero2 seems to work for me as well. A little weird...

Btw, I'm not too familiar with the difference between zero3 and zero2, but I believe they are different memory optimization schemes? Would using zero2 instead of zero3 have any effect on the model's capability? Thanks.

@zjysteven (Owner)

Yes it is only about the memory optimization and should not affect anything else. I wouldn't worry about using zero2 for training.

@TonyJiang17 (Author)

Hi @zjysteven, quick question: the current way I am doing multi-stage training is to create a separate LoRA each time, merge it into the original model, and then start a new LoRA (see the merge sketch below).

Curious, is there a way to keep using the same LoRA matrices and just pause and continue training? Is there something like 'resume_from_checkpoint' that we could use together with your trainer? Thanks as always.
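For context, the merge step usually looks roughly like this with peft (a hedged sketch; the paths are hypothetical and any merging utility the repo ships may differ):

import torch
from peft import PeftModel
from transformers import LlavaNextVideoForConditionalGeneration

base = LlavaNextVideoForConditionalGeneration.from_pretrained(
    "llava-hf/LLaVA-NeXT-Video-7B-hf", torch_dtype=torch.float16
)
# load the stage-1 LoRA adapter (hypothetical path) and fold it into the base weights
merged = PeftModel.from_pretrained(base, "checkpoints/stage1-lora").merge_and_unload()
merged.save_pretrained("checkpoints/stage1-merged")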

@zjysteven (Owner)

I'm not exactly sure (I don't have experience with this specific case), but I think it should be possible. I recommend checking the transformers Trainer documentation for this; a minimal sketch is below.
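For reference, the Trainer API does expose resume_from_checkpoint (a minimal sketch, assuming `trainer` is whatever Trainer instance the training script builds):

# resume model/optimizer/scheduler state from the latest checkpoint in output_dir
trainer.train(resume_from_checkpoint=True)

# or point at a specific checkpoint directory (hypothetical path)
trainer.train(resume_from_checkpoint="checkpoints/llava-next-video-7b/checkpoint-500")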

That said, my intuition is that continuing the training with the same LoRA matrix across different stages might lead to some forgetting given the relatively small capacity of LoRA.

@TonyJiang17 (Author)

I see, thanks @zjysteven, that's a good point.

I think for different training stages, e.g., pretraining and then instruction tuning, I would want to use different LoRAs. However, the reason I'm asking about reusing the same LoRA is that I want to run ablations on different training-set sizes within each stage (say pretraining). For pretraining, for example, instead of training from scratch for every dataset size, and since the smaller dataset is a subset of the larger one, I could train iteratively on top of the previously trained LoRA: train on a 10k-sample dataset first, and for the 20k ablation simply continue training the LoRA that has already seen the initial 10k samples, assuming I set the LoRA size to handle the larger dataset.

I guess another benefit is that some of the training samples are synthetically generated, so I could train on a subset of the dataset while the rest is still being generated.

Do you see any issues or risks with this approach? Thanks!

@zjysteven (Owner)

Yeah I think it makes sense!

@TonyJiang17 (Author)

If I set a LoRA matrix size that's sized for 20k samples and also use it for the 10k run, will it hurt the 10k version's results, or will it not have much of an impact?

@zjysteven (Owner)

I would guess "won't have too much of an impact"

@TonyJiang17 (Author)

Hey @zjysteven, sorry for a quick question. I'm planning to test the 32k version of llava-next-video since it uses a different LLM backbone. Would our update on adding the eos token be compatible with this? I believe that LLM backbone is Mistral. Thanks.

@zjysteven (Owner)

Noted. I'm on a super tight schedule though, so I won't be able to update it until the week of Thanksgiving.

@TonyJiang17 (Author)

Hi @zjysteven, understood. I really want to test it quickly, so maybe I can make some changes in my local code first? Do you mind sharing what I'd need to do? Maybe it would be similar to the changes you made before? So I assume the previous code change won't work with that version of llava-next-video? I see it's one of the supported models too.

@zjysteven (Owner)

Oh, I forgot it's already supported. Then I think it should work as well. I suggest using the collator-checking script from my earlier comment in this thread to verify that eos is included.


danielwusg pushed a commit to sunfanyunn/lmms-finetune that referenced this issue Nov 18, 2024