Skip to content
This repository has been archived by the owner on Sep 24, 2024. It is now read-only.

Commit

Permalink
fix driver imports:
Browse files Browse the repository at this point in the history
  • Loading branch information
Sean Friedowitz committed Jan 18, 2024
1 parent 6b3248e commit 67e18a3
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 11 deletions.
8 changes: 4 additions & 4 deletions src/flamingo/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def run():
@run.command("simple")
@click.option("--config", type=str)
def run_simple(config: str) -> None:
from flamingo.jobs.drivers import run_simple
from flamingo.jobs.drivers.simple import run_simple

config = SimpleJobConfig.from_yaml_file(config)
run_simple(config)
Expand All @@ -25,7 +25,7 @@ def run_simple(config: str) -> None:
@run.command("finetuning")
@click.option("--config", type=str)
def run_finetuning(config: str) -> None:
from flamingo.jobs.drivers import run_finetuning
from flamingo.jobs.drivers.finetuning import run_finetuning

config = FinetuningJobConfig.from_yaml_file(config)
run_finetuning(config)
Expand All @@ -35,15 +35,15 @@ def run_finetuning(config: str) -> None:
@click.option("--config", type=str)
@click.option("--dataset", type=str)
def run_ludwig(config: str, dataset: str) -> None:
from flamingo.jobs.drivers import run_ludwig
from flamingo.jobs.drivers.ludwig import run_ludwig

run_ludwig(config, dataset)


@run.command("lm-harness")
@click.option("--config", type=str)
def run_lm_harness(config: str) -> None:
from flamingo.jobs.drivers import run_lm_harness
from flamingo.jobs.drivers.lm_harness import run_lm_harness

config = LMHarnessJobConfig.from_yaml_file(config)
run_lm_harness(config)
Expand Down
6 changes: 0 additions & 6 deletions src/flamingo/jobs/drivers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +0,0 @@
from .finetuning import run_finetuning
from .lm_harness import run_lm_harness
from .ludwig import run_ludwig
from .simple import run_simple

__all__ = ["run_finetuning", "run_lm_harness", "run_ludwig", "run_simple"]
2 changes: 1 addition & 1 deletion src/flamingo/jobs/drivers/lm_harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def load_harness_model(config: LMHarnessJobConfig, loader: WandbArtifactLoader)
)


@ray.remote(num_cpus=)
@ray.remote
def evaluation_task(config: LMHarnessJobConfig, model_to_load: str) -> None:
print("Initializing lm-harness tasks...")
lm_eval.tasks.initialize_tasks()
Expand Down

0 comments on commit 67e18a3

Please sign in to comment.