This repository has been archived by the owner on Sep 24, 2024. It is now read-only.

Commit

simplifying vllm config
veekaybee committed Jan 30, 2024
1 parent 0c048fe commit f5668a1
Showing 5 changed files with 5 additions and 52 deletions.
2 changes: 0 additions & 2 deletions src/flamingo/integrations/vllm/__init__.py
@@ -1,4 +1,2 @@
# ruff: noqa: I001
from flamingo.integrations.vllm.model_config import *
from flamingo.integrations.vllm.path_config import *

15 changes: 2 additions & 13 deletions src/flamingo/integrations/vllm/model_config.py
@@ -1,18 +1,7 @@
from pydantic import validator

from flamingo.integrations.wandb import WandbArtifactConfig
from flamingo.types import BaseFlamingoConfig, TorchDtypeString
from flamingo.integrations.vllm import LocalServerConfig
from flamingo.types import BaseFlamingoConfig


class InferenceServerConfig(BaseFlamingoConfig):
"""Inference Server URL endpoint path"""

load_from: LocalServerConfig | WandbArtifactConfig

trust_remote_code: bool = False
torch_dtype: TorchDtypeString | None = None

_validate_load_from_string = validator("load_from", pre=True, allow_reuse=True)(
convert_string_to_repo_config
)
base_url: str
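
Read together, this hunk strips the artifact-loading fields down to a single endpoint field. A sketch of how the simplified `model_config.py` plausibly reads after this commit, assuming the remaining lines shown above are the kept and added ones:

```python
from flamingo.types import BaseFlamingoConfig


class InferenceServerConfig(BaseFlamingoConfig):
    """Inference Server URL endpoint path"""

    # The artifact/model-loading fields (load_from, trust_remote_code,
    # torch_dtype) are gone; the server is identified only by its URL.
    base_url: str
```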
25 changes: 0 additions & 25 deletions src/flamingo/integrations/vllm/path_config.py

This file was deleted.

12 changes: 1 addition & 11 deletions src/flamingo/jobs/lm_harness/config.py
@@ -2,8 +2,8 @@

from pydantic import Field, conlist, validator

Check failure on line 3 in src/flamingo/jobs/lm_harness/config.py (GitHub Actions / pytest_ruff): Ruff F401 `pydantic.validator` imported but unused

from flamingo.integrations.vllm import LocalServerConfig
from flamingo.integrations.huggingface import AutoModelConfig, QuantizationConfig
from flamingo.integrations.vllm import InferenceServerConfig
from flamingo.integrations.wandb import WandbRunConfig
from flamingo.types import BaseFlamingoConfig

@@ -25,7 +25,6 @@ class LMHarnessEvaluatorConfig(BaseFlamingoConfig):
limit: int | float | None = None



class LMHarnessJobConfig(BaseFlamingoConfig):
"""Configuration to run an lm-evaluation-harness evaluation job."""

@@ -34,12 +33,3 @@ class LMHarnessJobConfig(BaseFlamingoConfig):
quantization: QuantizationConfig | None = None
tracking: WandbRunConfig | None = None
ray: LMHarnessRayConfig = Field(default_factory=LMHarnessRayConfig)

@validator("model", pre=True, always=True)
def validate_model_arg(cls, x):
"""Allow for passing a path string as the model argument."""
if "v1/completions" in x:
return InferenceServerConfig(load_from=x)
else:
return AutoModelConfig(load_from=x)
return x
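
With the string-coercing validator removed, a caller presumably constructs the model config explicitly instead of passing a bare path or URL string. A hypothetical usage sketch (the example values are illustrative, not from the repository):

```python
from flamingo.integrations.huggingface import AutoModelConfig
from flamingo.integrations.vllm import InferenceServerConfig

# Load a model directly from Hugging Face (hypothetical repo id).
local_model = AutoModelConfig(load_from="distilgpt2")

# Point evaluation at a vLLM-compatible completions endpoint instead
# (hypothetical URL).
remote_model = InferenceServerConfig(base_url="http://localhost:8000/v1/completions")
```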
3 changes: 2 additions & 1 deletion src/flamingo/jobs/lm_harness/entrypoint.py
@@ -34,7 +34,8 @@ def log_evaluation_artifact(run_name: str, results: dict[str, dict[str, Any]]) -

def load_harness_model(config: LMHarnessJobConfig) -> HFLM | OpenaiCompletionsLM:
# Helper method to return lm-harness model wrapper
def loader(model: str | None, tokenizer: str, peft: str | None):
def _loader(model: str | None, tokenizer: str, peft: str | None):

"""Load model directly from HF if HF path, otherwise from an inference server URL"""

if isinstance(config.model, AutoModelConfig):
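
The renamed `_loader` helper still has to branch on the model config type. A minimal sketch of how that dispatch might look after this change, assuming the `isinstance` check reads as above and that lm-eval-harness v0.4 module paths apply; the wrapper constructor keyword arguments are assumptions, not taken from the repository:

```python
from lm_eval.models.huggingface import HFLM
from lm_eval.models.openai_completions import OpenaiCompletionsLM

from flamingo.integrations.huggingface import AutoModelConfig
from flamingo.integrations.vllm import InferenceServerConfig
from flamingo.jobs.lm_harness.config import LMHarnessJobConfig


def load_harness_model(config: LMHarnessJobConfig) -> HFLM | OpenaiCompletionsLM:
    if isinstance(config.model, AutoModelConfig):
        # Local path: let lm-harness load the weights from Hugging Face.
        # (`pretrained` is a real HFLM kwarg; treating load_from as the
        # repo id is an assumption.)
        return HFLM(pretrained=config.model.load_from)
    if isinstance(config.model, InferenceServerConfig):
        # Remote path: hand lm-harness the vLLM completions endpoint.
        # (`base_url` as a constructor kwarg is an assumption.)
        return OpenaiCompletionsLM(base_url=config.model.base_url)
    raise ValueError(f"Unexpected model config type: {type(config.model)}")
```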
