Skip to content

Commit

Permalink
Merge pull request #55 from ProjectUnifree/bug/huggingface_llm_input
Browse files Browse the repository at this point in the history
Fixes to HuggingfaceLLM Input Bugs
  • Loading branch information
bshikin authored Sep 23, 2023
2 parents 5bfddb0 + 76bb6a4 commit 13c421b
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 12 deletions.
6 changes: 5 additions & 1 deletion configs/godot-huggingface.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ prompts:
system: |
You are professional Unity engineer who is migrating a large project from Unity platform to Godot 4.1. Migrate code to GDScript, which you are an expert in. Follow the following rules:
1. Output code only, put explanations as comments.
2. Output code must be surrounded with ``` (three ticks)
2. Wrap your code with ``` (three ticks)
3. Do not skip any logic.
4. Preserve all comments without changing.
5. If migration is impossible leave "TODO [Migrate]" comment.
Expand Down Expand Up @@ -31,6 +31,10 @@ llm:
context_length: 4096
model_type: llama
gpu_layers: 50
prompt_template: |
[INST]:
${PROMPT}
[/INST]
source:
ignore_locations: # These locations will not be included in migration
Expand Down
25 changes: 15 additions & 10 deletions unifree/llms/huggingface_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Optional, List, Dict

from unifree import LLM, QueryHistoryItem, log
from unifree.utils import get_or_create_global_instance


# Copyright (c) Unifree
Expand Down Expand Up @@ -31,17 +32,17 @@ def __init__(self, config: Dict) -> None:
def query(self, user: str, system: Optional[str] = None, history: Optional[List[QueryHistoryItem]] = None) -> str:
prompt = ''
if system:
prompt += "> user: Remember these rules: \n" + system + "\n\n"
prompt += "> assistant: Certainly, I will remember and follow these rules. \n\n"
prompt += self._to_user_prompt(f"Remember these rules:\n{system}\n")
prompt += "\nCertainly, I will remember and follow these rules.\n"

if history:
history_str = [f"> {h.role}: {h.content}" for h in history]
history_str = "\n\n".join(history_str)
for item in history:
if item.role == "user":
prompt += self._to_user_prompt(f"\n{item.content}\n")
else:
prompt += f"\n{item.content}\n"

prompt += history_str + "\n\n"

prompt += "> user: " + user + "\n\n"
prompt += "> assistant: "
prompt += self._to_user_prompt(f"\n{user}\n")

log.debug(f"\n==== LLM REQUEST ====\n{prompt}\n")

Expand All @@ -57,15 +58,19 @@ def initialize(self) -> None:
llm_config = self.config["config"]
checkpoint = llm_config["checkpoint"]

self._model = AutoModelForCausalLM.from_pretrained(
self._model = get_or_create_global_instance(checkpoint, lambda: AutoModelForCausalLM.from_pretrained(
checkpoint,
model_type=llm_config["model_type"],
gpu_layers=llm_config["gpu_layers"],
context_length=llm_config["context_length"],
)
))

def fits_in_one_prompt(self, token_count: int) -> bool:
    """Check whether `token_count` tokens fit inside the configured context window."""
    context_length = self.config["config"]["context_length"]
    return token_count < context_length

def count_tokens(self, source_text: str) -> int:
    """Count tokens in `source_text` using the loaded model's own tokenizer."""
    tokens = self._model.tokenize(source_text)
    return len(tokens)

def _to_user_prompt(self, user: str) -> str:
prompt_template = self.config["config"]["prompt_template"]
return prompt_template.replace("${PROMPT}", user)
2 changes: 2 additions & 0 deletions unifree/project_migration_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This code is licensed under MIT license (see LICENSE.txt for details)

import os.path
import traceback
from abc import ABC
from typing import List, Union, Dict, Optional, Iterable

Expand Down Expand Up @@ -120,6 +121,7 @@ def _map_file_path_to_migration(self, file_path: str) -> Union[MigrationStrategy
return self._create_migration_strategy(strategy_name, spec)
return None
except Exception as e:
traceback.print_exc()
return f"'{file_path}' failed to create strategy: {e}"

def _create_migration_strategy(self, strategy_name: str, spec: FileMigrationSpec) -> MigrationStrategy:
Expand Down
21 changes: 20 additions & 1 deletion unifree/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
import importlib
import os
import re
import threading
from collections import defaultdict
from typing import Type, Dict, Any
from typing import Type, Dict, Any, TypeVar, Callable

import yaml

Expand Down Expand Up @@ -76,3 +77,21 @@ def to_default_dict(d):

def _return_none():
    """Return None.

    NOTE(review): presumably kept as a named module-level factory (rather than
    a lambda) so containers built with it stay picklable — confirm against
    to_default_dict, which appears to be its caller.
    """
    return None


InstanceType = TypeVar('InstanceType')

# Process-wide registry of lazily created singletons, keyed by name.
_global_instances: Dict[str, Any] = {}
# Guards first-time creation when multiple threads race on the same name.
_global_instances_lock: threading.Lock = threading.Lock()


def get_or_create_global_instance(name: str, new_instance_creator: Callable[[], InstanceType]) -> InstanceType:
    """Return the singleton registered under `name`, creating it on first use.

    Double-checked locking: the unlocked membership test keeps the common
    (already-created) path lock-free, and the re-check inside the lock ensures
    `new_instance_creator` runs at most once even when threads race.

    :param name: Unique key identifying the shared instance.
    :param new_instance_creator: Zero-argument factory invoked only when no
                                 instance is registered under `name` yet.
    :return: The cached (or newly created) instance.
    """
    # No `global` declarations needed: the function only mutates the dict and
    # uses the lock — it never rebinds either module-level name.
    if name not in _global_instances:
        with _global_instances_lock:
            if name not in _global_instances:
                _global_instances[name] = new_instance_creator()

    return _global_instances[name]

0 comments on commit 13c421b

Please sign in to comment.