diff --git a/.gitignore b/.gitignore index 83fd0216..1c65f5f5 100644 --- a/.gitignore +++ b/.gitignore @@ -164,3 +164,6 @@ cython_debug/ # Poetry poetry.lock + +# Ignore requirements since we only use for local builds +requirements.txt diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..3199668d --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,25 @@ +repos: + - repo: https://github.com/Yelp/detect-secrets + rev: v1.2.0 + hooks: + - id: detect-secrets + exclude: tests/integration/.+ + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v2.3.0 + hooks: + - id: check-merge-conflict + - id: trailing-whitespace + - id: end-of-file-fixer + - id: requirements-txt-fixer + - repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.1.7 + hooks: + - id: ruff + args: [--exit-non-zero-on-fix] + - repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.1.7 + hooks: + - id: ruff-format diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3d3c3d37..6d55960a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -13,7 +13,7 @@ Ruff will pick up the configuration defined in the `pyproject.toml` file automat `flamingo` is intended to be installed as a pip requirement in the runtime environment of a Ray job. However, it is often desirable to test local branches on Ray before publishing a new version of the library. -This is possible submitting a Ray job with a runtime environment that points to your +This is possible submitting a Ray job with a runtime environment that points to your development branch of the `flamingo` repo. To do so, follow the steps: @@ -24,7 +24,7 @@ To do so, follow the steps: poetry export --without-hashes --with finetuning,evaluation -o requirements.txt ``` - The following command will create a `requirements.txt` file in the repository + The following command will create a `requirements.txt` file in the repository that contains the dependencies for the `finetuning` and `evaluation` job groups: 2. When submitting a job to cluster, specify in the Ray runtime environment the following: @@ -42,4 +42,3 @@ To do so, follow the steps: but does not install its entrypoint in the environment path. An example of this workflow can be found in the `examples/dev_workflow.ipynb` notebook. - diff --git a/README.md b/README.md index d2e6c189..1e8b4b5c 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ This will install an editable version of the package along with all of its depen Poetry should recognize your active virtual environment during installation If you have an active Conda environment, Poetry should recognize it during installation -and install the package dependencies there. +and install the package dependencies there. This hasn't been explicitly tested with other virtual python environments, but will likely work. Alternatively, you can use poetry's own environment by running @@ -44,7 +44,7 @@ poetry install where `python3.10` is your python interpreter. The `pyproject.toml` file defines dependency groups for the logical job types in the package. -Individual dependency groups can be installed by running +Individual dependency groups can be installed by running `poetry install --with ,` or `poetry install --only `. See the [contributing](CONTRIBUTING.md) guide for more information on development workflows. @@ -52,7 +52,7 @@ See the [contributing](CONTRIBUTING.md) guide for more information on developmen ### Usage `flamingo` exposes a simple CLI with a few commands, one for each Ray job type. -Jobs are expected to take as input a YAML configuration file +Jobs are expected to take as input a YAML configuration file that contains all necessary parameters/settings for the work. See the `examples/configs` folder for examples of the configuration structure. diff --git a/examples/configs/lm_harness.yaml b/examples/configs/lm_harness_hf_config.yaml similarity index 100% rename from examples/configs/lm_harness.yaml rename to examples/configs/lm_harness_hf_config.yaml diff --git a/examples/configs/lm_harness_inference_server_config.yaml b/examples/configs/lm_harness_inference_server_config.yaml new file mode 100644 index 00000000..b14ce9ae --- /dev/null +++ b/examples/configs/lm_harness_inference_server_config.yaml @@ -0,0 +1,19 @@ +# Model to evaluate, specified as a W&B artifact +model: + base_url: "1.2.3.4:8000/v1/completions" + tokenizer: "mistralai/Mistral-7B-v0.1" + +# Settings specific to lm_harness.evaluate +evaluator: + tasks: ["gsm8k"] + num_fewshot: 5 + limit: 10 + +tracking: + name: "mistral-finetune" + project: "mistral-finetune" + entity: "mozilla-ai" + +ray: + num_cpus: 1 + timeout: 3600 diff --git a/pyproject.toml b/pyproject.toml index 1daea82d..77e4666a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ ruff = "0.1.7" pytest = "7.4.3" pytest-cov = "4.1.0" jupyter = "1.0.0" +pre-commit = "3.6.0" [tool.poetry.group.finetuning.dependencies] datasets = "2.16.1" @@ -38,7 +39,7 @@ trl = "0.7.10" bitsandbytes = "0.42.0" [tool.poetry.group.evaluation.dependencies] -lm-eval = "0.4.0" +lm-eval = "0.4.1" einops = "0.7.0" [tool.poetry.scripts] diff --git a/src/flamingo/integrations/vllm/__init__.py b/src/flamingo/integrations/vllm/__init__.py new file mode 100644 index 00000000..5e648365 --- /dev/null +++ b/src/flamingo/integrations/vllm/__init__.py @@ -0,0 +1 @@ +from flamingo.integrations.vllm.model_config import * diff --git a/src/flamingo/integrations/vllm/model_config.py b/src/flamingo/integrations/vllm/model_config.py new file mode 100644 index 00000000..3288be04 --- /dev/null +++ b/src/flamingo/integrations/vllm/model_config.py @@ -0,0 +1,8 @@ +from flamingo.types import BaseFlamingoConfig + + +class InferenceServerConfig(BaseFlamingoConfig): + """Inference Server URL endpoint path""" + + base_url: str + tokenizer: str diff --git a/src/flamingo/jobs/lm_harness/config.py b/src/flamingo/jobs/lm_harness/config.py index f5d33b86..5fbfe897 100644 --- a/src/flamingo/jobs/lm_harness/config.py +++ b/src/flamingo/jobs/lm_harness/config.py @@ -1,8 +1,9 @@ import datetime -from pydantic import Field, conlist, validator +from pydantic import Field, conlist from flamingo.integrations.huggingface import AutoModelConfig, QuantizationConfig +from flamingo.integrations.vllm import InferenceServerConfig from flamingo.integrations.wandb import WandbRunConfig from flamingo.types import BaseFlamingoConfig @@ -27,15 +28,8 @@ class LMHarnessEvaluatorConfig(BaseFlamingoConfig): class LMHarnessJobConfig(BaseFlamingoConfig): """Configuration to run an lm-evaluation-harness evaluation job.""" - model: AutoModelConfig + model: AutoModelConfig | InferenceServerConfig evaluator: LMHarnessEvaluatorConfig quantization: QuantizationConfig | None = None tracking: WandbRunConfig | None = None ray: LMHarnessRayConfig = Field(default_factory=LMHarnessRayConfig) - - @validator("model", pre=True, always=True) - def validate_model_arg(cls, x): - """Allow for passing just a path string as the model argument.""" - if isinstance(x, str): - return AutoModelConfig(load_from=x) - return x diff --git a/src/flamingo/jobs/lm_harness/entrypoint.py b/src/flamingo/jobs/lm_harness/entrypoint.py index ec2efadf..184c79ef 100644 --- a/src/flamingo/jobs/lm_harness/entrypoint.py +++ b/src/flamingo/jobs/lm_harness/entrypoint.py @@ -4,9 +4,11 @@ import ray import wandb from lm_eval.models.huggingface import HFLM +from lm_eval.models.openai_completions import OpenaiCompletionsLM from peft import PeftConfig -from flamingo.integrations.huggingface import resolve_loadable_path +from flamingo.integrations.huggingface import AutoModelConfig, resolve_loadable_path +from flamingo.integrations.vllm import InferenceServerConfig from flamingo.integrations.wandb import ( ArtifactType, WandbResumeMode, @@ -30,36 +32,46 @@ def log_evaluation_artifact(run_name: str, results: dict[str, dict[str, Any]]) - return wandb.log_artifact(artifact) -def load_harness_model(config: LMHarnessJobConfig) -> HFLM: - # Helper method to return lm-harness model wrapper - def loader(pretrained: str, tokenizer: str, peft: str | None): +def load_harness_model(config: LMHarnessJobConfig) -> HFLM | OpenaiCompletionsLM: + if isinstance(config.model, AutoModelConfig): + # We don't know if the checkpoint is adapter weights or merged model weights + # Try to load as an adapter and fall back to the checkpoint containing the full model + path, revision = resolve_loadable_path(config.model.load_from) + try: + peft_config = PeftConfig.from_pretrained(path, revision=revision) + peft_path = path + pretrained_model_path = peft_config.base_model_name_or_path + except ValueError as e: + print( + f"Unable to load model as adapter: {e}. " + "This is expected if the checkpoint does not contain adapter weights." + ) + peft_path = None + pretrained_model_path = path + + # Return the lm-harness version of a HuggingFace LLM quantization_kwargs = config.quantization.dict() if config.quantization else {} return HFLM( - pretrained=pretrained, - tokenizer=tokenizer, - peft=peft, + pretrained=pretrained_model_path, + tokenizer=pretrained_model_path, + peft=peft_path, + revision=revision, device="cuda" if config.ray.num_gpus > 0 else None, trust_remote_code=config.model.trust_remote_code, dtype=config.model.torch_dtype if config.model.torch_dtype else "auto", **quantization_kwargs, ) - # We don't know if the checkpoint is adapter weights or merged model weights - # Try to load as an adapter and fall back to the checkpoint containing the full model - load_path, revision = resolve_loadable_path(config.model.load_from) - try: - peft_config = PeftConfig.from_pretrained(load_path, revision=revision) - return loader( - pretrained=peft_config.base_model_name_or_path, - tokenizer=peft_config.base_model_name_or_path, - peft=load_path, - ) - except ValueError as e: - print( - f"Unable to load model as adapter: {e}. " - "This is expected if the checkpoint does not contain adapter weights." + elif isinstance(config.model, InferenceServerConfig): + # Return the lm-harness version of a model endpoint + return OpenaiCompletionsLM( + model="vllm-model", + tokenizer=config.model.tokenizer, + base_url=config.model.base_url, ) - return loader(pretrained=load_path, tokenizer=load_path, peft=None) + + else: + raise ValueError(f"Unexpected model config type: {type(config.model)}") def load_and_evaluate(config: LMHarnessJobConfig) -> dict[str, Any]: diff --git a/tests/conftest.py b/tests/conftest.py index 89edd4f1..0cac579d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,5 @@ """ -Tests for the Flamingo. - -This file is used to provide fixtures for the test session that are accessible to all submodules. +This file is used to provide fixtures for the test session accessible to all Flamingo submodules. """ from pathlib import Path diff --git a/tests/resources/README.md b/tests/resources/README.md index cd6305c8..2b970012 100644 --- a/tests/resources/README.md +++ b/tests/resources/README.md @@ -4,5 +4,5 @@ Collection of resources to load/parse during tests. These resources should be kept as small as possible to minimize the git repo size. -When applicable, helper scripts for re-generating the resources +When applicable, helper scripts for re-generating the resources can be added to the appropriate subfolders. diff --git a/tests/resources/datasets/xyz.hf/dataset_info.json b/tests/resources/datasets/xyz.hf/dataset_info.json index 872c12a3..63102527 100644 --- a/tests/resources/datasets/xyz.hf/dataset_info.json +++ b/tests/resources/datasets/xyz.hf/dataset_info.json @@ -17,4 +17,4 @@ }, "homepage": "", "license": "" -} \ No newline at end of file +} diff --git a/tests/resources/datasets/xyz.hf/state.json b/tests/resources/datasets/xyz.hf/state.json index 93fe3d1d..5bfaa1c6 100644 --- a/tests/resources/datasets/xyz.hf/state.json +++ b/tests/resources/datasets/xyz.hf/state.json @@ -10,4 +10,4 @@ "_format_type": null, "_output_all_columns": false, "_split": null -} \ No newline at end of file +} diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 934e5f1c..b4ac2b07 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -8,6 +8,7 @@ QuantizationConfig, TextDatasetConfig, ) +from flamingo.integrations.vllm import InferenceServerConfig from flamingo.integrations.wandb import WandbArtifactConfig, WandbRunConfig @@ -22,6 +23,13 @@ def model_config_with_artifact(): return AutoModelConfig(load_from=artifact, trust_remote_code=True) +@pytest.fixture +def inference_server_config(): + return InferenceServerConfig( + base_url="1.2.3.4:8000/v1/completions", tokenizer="mistralai/Mistral-7B-v0.1" + ) + + @pytest.fixture def tokenizer_config_with_repo_id(): return AutoTokenizerConfig(load_from="mistral-ai/mistral-7", trust_remote_code=True) diff --git a/tests/unit/jobs/test_lm_harness_config.py b/tests/unit/jobs/test_lm_harness_config.py index a2a07bdc..b1a65483 100644 --- a/tests/unit/jobs/test_lm_harness_config.py +++ b/tests/unit/jobs/test_lm_harness_config.py @@ -1,7 +1,5 @@ import pytest -from pydantic import ValidationError -from flamingo.integrations.huggingface import HuggingFaceRepoConfig from flamingo.jobs.lm_harness import ( LMHarnessEvaluatorConfig, LMHarnessJobConfig, @@ -28,47 +26,56 @@ def lm_harness_ray_config(): @pytest.fixture def lm_harness_job_config( + request, model_config_with_artifact, + inference_server_config, quantization_config, wandb_run_config, lm_harness_evaluator_config, lm_harness_ray_config, ): - return LMHarnessJobConfig( - model=model_config_with_artifact, - evaluator=lm_harness_evaluator_config, - ray=lm_harness_ray_config, - tracking=wandb_run_config, - quantization=quantization_config, - ) - - + if request.param == "model_config_with_artifact": + return LMHarnessJobConfig( + model=model_config_with_artifact, + evaluator=lm_harness_evaluator_config, + ray=lm_harness_ray_config, + tracking=wandb_run_config, + quantization=quantization_config, + ) + elif request.param == "inference_server_config": + return LMHarnessJobConfig( + model=inference_server_config, + evaluator=lm_harness_evaluator_config, + ray=lm_harness_ray_config, + tracking=wandb_run_config, + quantization=quantization_config, + ) + + +@pytest.mark.parametrize( + "lm_harness_job_config", + ["model_config_with_artifact", "inference_server_config"], + indirect=True, +) def test_serde_round_trip(lm_harness_job_config): assert LMHarnessJobConfig.parse_raw(lm_harness_job_config.json()) == lm_harness_job_config +@pytest.mark.parametrize( + "lm_harness_job_config", + ["model_config_with_artifact", "inference_server_config"], + indirect=True, +) def test_parse_yaml_file(lm_harness_job_config): with lm_harness_job_config.to_tempfile() as config_path: assert lm_harness_job_config == LMHarnessJobConfig.from_yaml_file(config_path) -def test_load_example_config(examples_dir): +@pytest.mark.parametrize( + "file_suffix", ["lm_harness_hf_config.yaml", "lm_harness_inference_server_config.yaml"] +) +def test_load_example_config(examples_dir, file_suffix): """Load the example configs to make sure they stay up to date.""" - config_file = examples_dir / "configs" / "lm_harness.yaml" + config_file = examples_dir / "configs" / file_suffix config = LMHarnessJobConfig.from_yaml_file(config_file) assert LMHarnessJobConfig.parse_raw(config.json()) == config - - -def test_model_validation(lm_harness_evaluator_config): - model_repo = HuggingFaceRepoConfig(repo_id="model_repo") - allowed_config = LMHarnessJobConfig( - model=model_repo.repo_id, - evaluator=lm_harness_evaluator_config, - ) - assert allowed_config.model.load_from == model_repo - - with pytest.raises(ValidationError): - LMHarnessJobConfig(model="invalid...hf..repo", evaluator=lm_harness_evaluator_config) - - with pytest.raises(ValidationError): - LMHarnessJobConfig(model=12345, evaluator=lm_harness_evaluator_config)