Skip to content

Commit

Permalink
better backoff default + chat template in hf-local
Browse files Browse the repository at this point in the history
  • Loading branch information
jogonba2 committed May 2, 2024
1 parent b6025d9 commit caa4c30
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 4 deletions.
8 changes: 8 additions & 0 deletions text_machina/src/models/hf_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ def generate_completion(self, prompt: str, generation_config: Dict) -> str:
This method is not used, since generations are done
with batches using `generate_completions`.
"""

if self.model_config.api_type == CompletionType.CHAT:
prompt = self.tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
add_generation_prompt=True,
tokenize=False,
)

tokenized = self.tokenizer(
prompt, truncation=True, padding=True, return_tensors="pt"
)
Expand Down
4 changes: 1 addition & 3 deletions text_machina/src/models/hf_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@ def __init__(self, model_config: ModelConfig):
retry_adapter = HTTPAdapter(
max_retries=Retry(
total=getattr(self.model_config, "max_retries", 5),
backoff_factor=getattr(
self.model_config, "backoff_factor", 2
),
backoff_factor=getattr(self.model_config, "backoff_factor", 2),
status_forcelist=[
code for code in requests.status_codes._codes if code != 200
],
Expand Down
2 changes: 1 addition & 1 deletion text_machina/version.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
_MAJOR = "0"
_MINOR = "2"
_REVISION = "10"
_REVISION = "11"

VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR)
VERSION = "{0}.{1}.{2}".format(_MAJOR, _MINOR, _REVISION)
Expand Down

0 comments on commit caa4c30

Please sign in to comment.