From 76bb6a48fe548b5f94941424b9e4e155dbf97c31 Mon Sep 17 00:00:00 2001 From: Basil Shikin Date: Sat, 23 Sep 2023 14:44:30 -0700 Subject: [PATCH] Fixes to HuggingfaceLLM Input Bugs 1. Using a [INST] ... [/INST] prompt template for HuggingfaceLLM 2. Making sure there is only one instance of the LLM --- configs/godot-huggingface.yaml | 6 +++++- unifree/llms/huggingface_llm.py | 25 +++++++++++++++---------- unifree/project_migration_strategies.py | 2 ++ unifree/utils.py | 21 ++++++++++++++++++++- 4 files changed, 42 insertions(+), 12 deletions(-) diff --git a/configs/godot-huggingface.yaml b/configs/godot-huggingface.yaml index c48f48b..3eeb71f 100644 --- a/configs/godot-huggingface.yaml +++ b/configs/godot-huggingface.yaml @@ -2,7 +2,7 @@ prompts: system: | You are professional Unity engineer who is migrating a large project from Unity platform to Godot 4.1. Migrate code to GDScript, which you are an expert in. Follow the following rules: 1. Output code only, put explanations as comments. - 2. Output code must be surrounded with ``` (three ticks) + 2. Wrap your code with ``` (three ticks) 3. Do not skip any logic. 4. Preserve all comments without changing. 5. If migration is impossible leave "TODO [Migrate]" comment. 
@@ -31,6 +31,10 @@ llm: context_length: 4096 model_type: llama gpu_layers: 50 + prompt_template: | + [INST] + ${PROMPT} + [/INST] source: ignore_locations: # These locations will not be included in migration diff --git a/unifree/llms/huggingface_llm.py b/unifree/llms/huggingface_llm.py index 62fb0f9..02e11cd 100644 --- a/unifree/llms/huggingface_llm.py +++ b/unifree/llms/huggingface_llm.py @@ -2,6 +2,7 @@ from typing import Optional, List, Dict from unifree import LLM, QueryHistoryItem, log +from unifree.utils import get_or_create_global_instance # Copyright (c) Unifree @@ -31,17 +32,17 @@ def __init__(self, config: Dict) -> None: def query(self, user: str, system: Optional[str] = None, history: Optional[List[QueryHistoryItem]] = None) -> str: prompt = '' if system: - prompt += "> user: Remember these rules: \n" + system + "\n\n" - prompt += "> assistant: Certainly, I will remember and follow these rules. \n\n" + prompt += self._to_user_prompt(f"Remember these rules:\n{system}\n") + prompt += "\nCertainly, I will remember and follow these rules.\n" if history: - history_str = [f"> {h.role}: {h.content}" for h in history] - history_str = "\n\n".join(history_str) + for item in history: + if item.role == "user": + prompt += self._to_user_prompt(f"\n{item.content}\n") + else: + prompt += f"\n{item.content}\n" - prompt += history_str + "\n\n" - - prompt += "> user: " + user + "\n\n" - prompt += "> assistant: " + prompt += self._to_user_prompt(f"\n{user}\n") log.debug(f"\n==== LLM REQUEST ====\n{prompt}\n") @@ -57,15 +58,19 @@ def initialize(self) -> None: llm_config = self.config["config"] checkpoint = llm_config["checkpoint"] - self._model = AutoModelForCausalLM.from_pretrained( + self._model = get_or_create_global_instance(checkpoint, lambda: AutoModelForCausalLM.from_pretrained( checkpoint, model_type=llm_config["model_type"], gpu_layers=llm_config["gpu_layers"], context_length=llm_config["context_length"], - ) + )) def fits_in_one_prompt(self, token_count: int) -> 
bool: return token_count < self.config["config"]["context_length"] def count_tokens(self, source_text: str) -> int: return len(self._model.tokenize(source_text)) + + def _to_user_prompt(self, user: str) -> str: + prompt_template = self.config["config"]["prompt_template"] + return prompt_template.replace("${PROMPT}", user) diff --git a/unifree/project_migration_strategies.py b/unifree/project_migration_strategies.py index b22e149..02b5fb4 100644 --- a/unifree/project_migration_strategies.py +++ b/unifree/project_migration_strategies.py @@ -4,6 +4,7 @@ # This code is licensed under MIT license (see LICENSE.txt for details) import os.path +import traceback from abc import ABC from typing import List, Union, Dict, Optional, Iterable @@ -120,6 +121,7 @@ def _map_file_path_to_migration(self, file_path: str) -> Union[MigrationStrategy return self._create_migration_strategy(strategy_name, spec) return None except Exception as e: + traceback.print_exc() return f"'{file_path}' failed to create strategy: {e}" def _create_migration_strategy(self, strategy_name: str, spec: FileMigrationSpec) -> MigrationStrategy: diff --git a/unifree/utils.py b/unifree/utils.py index 2df68d2..d6e8eb8 100644 --- a/unifree/utils.py +++ b/unifree/utils.py @@ -6,8 +6,9 @@ import importlib import os import re +import threading from collections import defaultdict -from typing import Type, Dict, Any +from typing import Type, Dict, Any, TypeVar, Callable import yaml @@ -76,3 +77,21 @@ def to_default_dict(d): def _return_none(): return None + + +InstanceType = TypeVar('InstanceType') + +_global_instances: Dict[str, Any] = {} +_global_instances_lock: threading.Lock = threading.Lock() + + +def get_or_create_global_instance(name: str, new_instance_creator: Callable[[], InstanceType]) -> InstanceType: + global _global_instances + global _global_instances_lock + + if name not in _global_instances: + with _global_instances_lock: + if name not in _global_instances: + _global_instances[name] = 
new_instance_creator() + + return _global_instances[name]