
feat: support for robust benchmarking of fms-hf-tuning #142

Open
VassilisVassiliadis opened this issue Apr 25, 2024 · 9 comments

Comments

@VassilisVassiliadis
Contributor

Is your feature request related to a problem? Please describe.

When running benchmarks using fms-hf-tuning we need to:

  1. record exceptions such as GPU Out Of Memory (OOM), transient RuntimeError, and NCCL errors
  2. ensure robust capture of system metrics
  3. annotate the AIM run with additional benchmarking information so that the AIM data is searchable via benchmarking labels (e.g. benchmark experiment name, benchmark identifier)

However, fms-hf-tuning does not support these features, which forces us to maintain external forks of fms-hf-tuning that provide them. Specifically:

  1. fms-hf-tuning does not surface the various exceptions in a way that is easily identifiable by an automated process, i.e. it currently requires grepping stderr.
  2. system metrics are sent to an AIM server. However, we have observed stability issues with storing and retrieving data from AIM, leading to lost data. In general, benchmarking needs a highly robust way to ensure the metrics are captured.
  3. there is no way to supply external metadata for insertion into AIM for an fms-hf-tuning run.

Describe the solution you'd like

Here is a simplified view of the current state of the sft_trainer.py script:

def train(
        model_args,
        data_args,
        training_args,
        tune_config,
        trainer_controller_args,
    ):
    callbacks = [SomeCallback()]

    if aim_available:
        callbacks.append(get_aimstack_callback())

    callbacks.extend(other_callbacks)

    sft_trainer.train(callbacks=callbacks)

def main():
    (
        model_args,
        data_args,
        training_args,
        tune_config,
        trainer_controller_args,
    ) = parser.parse_args_into_dataclasses(return_remaining_strings=True)

    train(
        model_args,
        data_args,
        training_args,
        tune_config,
        trainer_controller_args,
    )

We propose the following:

class BenchmarkingFriendlyAIMCallback(AimCallback):
    def on_train_begin(self, args, state, control, model=None, **kwargs):
        super().on_train_begin(args, state, control, model, **kwargs)
        # annotate self.run with custom metadata so that we can search for this run on AIM
    
    def on_train_end(self, args, state, control, **kwargs):
        if state.is_world_process_zero:
            # Dump aim metrics to a file such that we can collect this data without actually spinning up an AIM server
            ...
        super().on_train_end(args=args, state=state, control=control, **kwargs)

def parse_args():
    return (
        model_args,
        data_args,
        training_args,
        tune_config,
        trainer_controller_args,
        benchmarking_args,
    )

def train(
        model_args,
        data_args,
        training_args,
        tune_config,
        trainer_controller_args,
        callbacks,
    ):

    callbacks.extend(other_callbacks)
    sft_trainer.train(callbacks=callbacks)
    

def main():    
    try:   
        (
            model_args,
            data_args,
            training_args,
            tune_config,
            trainer_controller_args,
            benchmarking_args,
        ) = parse_args()

        aim_metadata = read_metadata(benchmarking_args.aim_metadata_path)
     
        callbacks = []

        if aim_is_available():
            callbacks.append(
                # dumps AIM data to disk plus annotates AIM run with custom metadata
                BenchmarkingFriendlyAIMCallback(
                    custom_metadata=aim_metadata,
                    output_file=benchmarking_args.aim_output_file,
                )
            )
        train(
            model_args,
            data_args,
            training_args,
            tune_config,
            trainer_controller_args, 
            callbacks=callbacks, # new
        )
    # Catch GPU OOM exceptions, NCCL exceptions, etc., and store them in a file
    # for easy processing (see the sketch of the reporting helpers below)
    except GPUOutOfMemoryError as exc:
        report_gpu_oom(benchmarking_args.aim_output_file, exc)
        raise
    except Exception as exc:
        report_unknown_error(benchmarking_args.aim_output_file, exc)
        raise

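For illustration, here is a minimal sketch of what the hypothetical report_gpu_oom()/report_unknown_error() helpers used above could look like, assuming the error record is written as JSON lines next to the AIM output file (the file naming and field names are placeholders, not part of the proposal):

import json
import traceback
from datetime import datetime, timezone

def _write_error_record(output_file, error_type, exc):
    # Append a machine-readable error record so automated tooling can detect
    # failures without grepping stderr.
    record = {
        "error_type": error_type,
        "message": str(exc),
        "traceback": traceback.format_exc(),
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    with open(f"{output_file}.errors.jsonl", "a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")

def report_gpu_oom(output_file, exc):
    _write_error_record(output_file, "gpu_out_of_memory", exc)

def report_unknown_error(output_file, exc):
    _write_error_record(output_file, "unknown_error", exc)
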
Pros:

  1. other people/services wishing to robustly detect GPU OOM errors and extract system metrics can use the above features

Cons:

  1. need to update the code in sft_trainer.py

As a plus, the proposed changes to train() enable us to experiment with new implementations of BenchmarkingFriendlyAIMCallback without requiring updates to fms-hf-tuning (e.g. by using a wrapper script).

Describe alternatives you've considered

Make minimal changes to train() to assist in developing third-party wrapper scripts

We could slightly refactor the code so that the logic for activating the current AIMCallback (via a call to get_aimstack_callback()) lives in the main() method instead of the train() method.

This would enable us to use a wrapper script that implements our earlier proposed design of sft_trainer.py. The wrapper script would have a main() function similar to the above proposal and would invoke tuning.train(); a sketch of such a wrapper follows the code below.

sft_trainer.py in fms-hf-tuning would look like this:

def parse_args():
    return (
        model_args,
        data_args,
        training_args,
        tune_config,
        trainer_controller_args,
    )

def train(
        model_args,
        data_args,
        training_args,
        tune_config,
        trainer_controller_args,
        callbacks,
    ):

    callbacks.extend(other_callbacks)
    sft_trainer.train(callbacks=callbacks)
    

def main():    
    (
        model_args,
        data_args,
        training_args,
        tune_config,
        trainer_controller_args,
    ) = parse_args()

    # Move the insertion of the AIM Callback from train() to main()
    callbacks = []

    if aim_available:
        callbacks.append(get_aimstack_callback())
    
    train(
        model_args,
        data_args,
        training_args,
        tune_config,
        trainer_controller_args, 
        callbacks=callbacks, # new
    )

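For reference, a third-party wrapper built on top of this refactoring might look roughly like the sketch below. The import path, the benchmarking argument parser, and the read_metadata()/BenchmarkingFriendlyAIMCallback/report_unknown_error helpers are assumptions carried over from the proposal above, not existing fms-hf-tuning APIs:

# benchmark_wrapper.py -- hypothetical wrapper script, lives outside fms-hf-tuning
from tuning import sft_trainer  # assumed import path for the refactored module

def main():
    (
        model_args,
        data_args,
        training_args,
        tune_config,
        trainer_controller_args,
    ) = sft_trainer.parse_args()
    benchmarking_args = parse_benchmarking_args()  # wrapper-specific, hypothetical

    callbacks = [
        BenchmarkingFriendlyAIMCallback(
            custom_metadata=read_metadata(benchmarking_args.aim_metadata_path),
            output_file=benchmarking_args.aim_output_file,
        )
    ]

    try:
        # delegate the actual tuning to fms-hf-tuning
        sft_trainer.train(
            model_args,
            data_args,
            training_args,
            tune_config,
            trainer_controller_args,
            callbacks=callbacks,
        )
    except Exception as exc:
        report_unknown_error(benchmarking_args.aim_output_file, exc)
        raise
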
Pros:

  1. changes to fms-hf-tuning are a handful of lines of code

Cons:

  1. Other services/people wishing to robustly collect metrics and/or exceptions will have to implement these features themselves (e.g. using a wrapper script)
@michael-johnston

@Ssukriti Addressing the above will make it significantly easier for the benchmarking effort to use the latest fms_hf_tuning changes, as well as for others to execute benchmarking runs, so please let us know how we can help; for example, we can contribute a PR with our proposed changes.

@ashokponkumar
Collaborator

@dushyantbehl can you please work with @VassilisVassiliadis to understand the issues with AIM and whether the requested changes are already addressed by your pending PRs?

@VassilisVassiliadis
Contributor Author

These are the changes we're proposing:

  1. catch GPU OOM exceptions and other exceptions and store them in a file
  2. tweak the AIM code a bit:
    1. move the auto-insertion of the AIMCallback from the train() method to the main() method
    2. update the AIMCallback class so that it can optionally dump the metrics that it collects to a file
    3. update the AIMCallback class so that it also tags the AIM Experiment with metadata

If these seem reasonable to you, I can either contribute bullet 2 directly to the PR that @dushyantbehl is working on, or open a new PR after @dushyantbehl's PR is merged into main. We can then discuss how to handle bullet point 1.
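
For concreteness, a minimal sketch of 2.2 and 2.3 could look like the following. It assumes the stock AimCallback from aim.hugging_face and that the underlying Aim Run is reachable via the callback's experiment property; the constructor arguments and the JSON-lines output format are illustrative only:

import json

from aim.hugging_face import AimCallback

class BenchmarkingFriendlyAIMCallback(AimCallback):
    def __init__(self, custom_metadata=None, output_file=None, **kwargs):
        super().__init__(**kwargs)
        self._custom_metadata = custom_metadata or {}
        self._output_file = output_file

    def on_train_begin(self, args, state, control, **kwargs):
        super().on_train_begin(args, state, control, **kwargs)
        # 2.3: tag the AIM run so benchmarking runs are searchable
        for key, value in self._custom_metadata.items():
            self.experiment[key] = value

    def on_log(self, args, state, control, logs=None, **kwargs):
        super().on_log(args, state, control, logs=logs, **kwargs)
        # 2.2: also persist the logged metrics locally so they survive
        # even if the AIM server loses data
        if self._output_file and state.is_world_process_zero and logs:
            with open(self._output_file, "a", encoding="utf-8") as f:
                f.write(json.dumps({"step": state.global_step, **logs}) + "\n")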

@dushyantbehl
Contributor

Thanks, @VassilisVassiliadis, great. This is on my plate already.
The patch already has 1 and 3 as working examples, and for 2 I already have code on my system from another thread, which I can push to the repo too. I will take these bullets, no worries!

Let me know if we need to connect on this for any other details.

@dushyantbehl
Contributor

Please follow PR #89 for bullets 1 and 3. We have a separate PR for 2.

@dushyantbehl
Contributor

@VassilisVassiliadis #89 is merged.

@VassilisVassiliadis
Contributor Author

Thank you @dushyantbehl I'll take a look.

Since you're already working on 2.2, we still need to decide how we're going to handle 1 (i.e. reporting errors). I noticed that @kellyaa has already started working on something related to this in #149.

@Ssukriti would it make sense to design a unified way of reporting errors that all scripts (sft_trainer.py, accelerate_launch.py, etc.) can use?

@kellyaa
Member

kellyaa commented May 10, 2024

The importance of #149 is to ensure that, whatever errors are thrown:

  • A concise summary of the error is written to /dev/termination-log
  • A well-defined exit code is returned. For "user fault" errors, it must be less than 128; for unknown or system-level causes, it must be greater than 128 and less than 256.

This is so that when tuning jobs are being executed via Kubernetes CRs, the pod can accurately reflect the end state of the job without the user having to scrape the logs.

I'm good with whatever unified mechanisms you want to create, as long as we are able to accomplish the above.
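
For illustration only, that convention could be implemented along the lines of the sketch below; the exit-code values and the helper name are assumptions, not what #149 actually ships:

import sys

USER_FAULT_EXIT_CODE = 1        # "user fault" errors must be < 128
INTERNAL_ERROR_EXIT_CODE = 203  # unknown/system-level causes must be > 128 and < 256

def fail(message, exit_code):
    # Write a concise summary where Kubernetes picks up the pod's termination
    # reason, then exit with a well-defined code.
    try:
        with open("/dev/termination-log", "w", encoding="utf-8") as f:
            f.write(message)
    except OSError:
        pass  # not running inside a pod; the exit code still carries the signal
    sys.exit(exit_code)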

@VassilisVassiliadis
Contributor Author

I think this is now done. I just found a small bug and opened a PR for it here: #199
