Commit
add comments to llms api and classes
lyliyu committed Nov 13, 2023
1 parent cff8eda commit 0ed788b
Showing 2 changed files with 58 additions and 10 deletions.
10 changes: 9 additions & 1 deletion framework/feature_factory/__init__.py
@@ -87,7 +87,15 @@ def append_catalog(self, df: DataFrame, groupBy_cols, catalog_cls, feature_names,
return self.append_features(df, groupBy_cols, [fs], withTrendsForFeatures, granularityEnum)

def assemble_llm_feature(self, spark: SparkSession, srcDirectory: str, llmFeature: LLMFeature, partitionNum: int):

"""
Creates a dataframe which contains only one column named as llmFeature.name.
The method will distribute the files under srcDirectory to the partitions determined by the partitionNum.
Each file will be parsed and chunked using the reader and splitter in the llmFeature object.
:param spark: a spark session instance
:param srcDirectory: the directory containing documents to parse
:llmFeature: the LLM feature instance
:partitionNum: the number of partitions the src documents will be distributed onto.
"""
all_files = self.helpers.list_files_recursively(srcDirectory)
src_rdd = spark.sparkContext.parallelize(all_files, partitionNum)

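A minimal end-to-end usage sketch of the documented method, assuming the import paths mirror the file layout above and that the factory class exposing assemble_llm_feature is named Feature_Factory (an assumption not confirmed by this diff); the source directory is illustrative:

from pyspark.sql import SparkSession
from framework.feature_factory import Feature_Factory  # class name assumed
from framework.feature_factory.llm_tools import (
    LLMFeature,
    LlamaIndexDocReader,
    LangChainRecursiveCharacterTextSplitter,
)

spark = SparkSession.builder.getOrCreate()
ff = Feature_Factory()

# One reader plus one splitter bundled into an LLMFeature; "chunks" becomes the column name.
llm_feature = LLMFeature(
    reader=LlamaIndexDocReader(),
    splitter=LangChainRecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64),
    name="chunks",
)

# Files under the (illustrative) directory are spread over 8 partitions,
# parsed, chunked, and returned as a single-column dataframe.
df = ff.assemble_llm_feature(
    spark,
    srcDirectory="/dbfs/llm_docs",
    llmFeature=llm_feature,
    partitionNum=8,
)
df.show(truncate=80)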
58 changes: 49 additions & 9 deletions framework/feature_factory/llm_tools.py
@@ -16,7 +16,12 @@


class LLMTool(ABC):

"""Generic interface for LLMs tools.
apply and create methods need to be implemented in the children classes.
create method creates resources for the tool and apply method makes inference using the resources.
If the resources are not created before calling apply(), create() will be invoked in the beginning of the apply().
Having a separate create() will make it more efficient to initalize/create all required resouces only once per partition.
"""
def __init__(self) -> None:
self._initialized = False
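As a sketch of that contract, a hypothetical subclass might look like the following (illustrative only; the class name and module path are assumptions based on the file layout above):

from framework.feature_factory.llm_tools import LLMTool

class UppercaseTool(LLMTool):
    def create(self):
        # Imagine loading a model or tokenizer here; this runs once per partition.
        self._resource = str.upper
        self._initialized = True

    def apply(self, text: str) -> str:
        # Lazily create the resources on first use, as described in the docstring above.
        if not self._initialized:
            self.create()
        return self._resource(text)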

@@ -37,7 +42,8 @@ def create(self):


class DocReader(LLMTool):

""" Generic class for doc reader.
"""
def create(self):
...

@@ -46,7 +52,8 @@ def apply(self, filename: str) -> Union[str, List[Document]]:


class DocSplitter(LLMTool):

""" Generic class for doc splitter.
"""
def __init__(self) -> None:
super().__init__()

@@ -114,7 +121,9 @@ def apply(self, docs: Union[str, List[Document]]) -> List[str]:


class LlamaIndexDocReader(DocReader):

"""A wrapper class for SimpleDirectoryReader of LlamaIndex.
For more details, refer to https://gpt-index.readthedocs.io/en/latest/examples/data_connectors/simple_directory_reader.html
"""
def __init__(self) -> None:
super().__init__()

@@ -124,6 +133,9 @@ def apply(self, filename: str) -> List[Document]:


class UnstructuredDocReader(DocReader):
"""
A doc reader class using Unstructured API. Only allowed categories will be included in the final parsed text.
"""

def __init__(self, allowedCategories: Tuple[str]=('NarrativeText', 'ListItem')) -> None:
super().__init__()
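A hedged usage sketch; the category filter below narrows the output to narrative text only, and the file path is hypothetical:

from framework.feature_factory.llm_tools import UnstructuredDocReader

reader = UnstructuredDocReader(allowedCategories=("NarrativeText",))
text = reader.apply("/dbfs/llm_docs/report.pdf")  # returns the filtered text as a single string
print(text[:500])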
@@ -143,7 +155,9 @@ def apply(self, filename: str) -> str:


class LLMDef(LLMTool):

""" A generic class to define LLM instance e.g. using HuggingFace APIs.
An example can be found at notebooks/feature_factory_llms.py
"""
def __init__(self) -> None:
self._instance = None

@@ -153,7 +167,11 @@ def get_instance(self):


class LlamaIndexDocSplitter(DocSplitter):

"""A class to split documents using LlamaIndex SimpleNodeParser.
TokenTextSplitter and TitleExtractor are used to generate text chunks and metadata for each chunk.
`chunk_size`, `chunk_overlap` are the super parameters to tweak for better response from LLMs.
`llm` is the LLM instance used for metadata extraction. If not provided, the splitter will generate text chunks only.
"""
def __init__(self, chunk_size:int=1024, chunk_overlap:int=64, llm:LLMDef=None) -> None:
super().__init__()
self.chunk_size = chunk_size
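A hedged chunking sketch; per the docstring, passing llm=None skips the metadata (TitleExtractor) step, and the file path is hypothetical:

from framework.feature_factory.llm_tools import LlamaIndexDocReader, LlamaIndexDocSplitter

reader = LlamaIndexDocReader()
splitter = LlamaIndexDocSplitter(chunk_size=1024, chunk_overlap=64, llm=None)

docs = reader.apply("/dbfs/llm_docs/handbook.pdf")
chunks = splitter.apply(docs)  # text chunks for this file, per the DocSplitter contract
print(len(chunks))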
@@ -191,7 +209,10 @@ def apply(self, docs: List[Document]):


class LangChainRecursiveCharacterTextSplitter(DocSplitter):

""" A splitter class to utilize Langchain RecursiveCharacterTextSplitter to generate text chunks.
If `pretrained_model_path` is provided, the `chunk_size` and `chunk_overlap` will be measured in tokens.
If `pretrained_model_path` is not provided, the `chunk_size` and `chunk_overlap` will be measured in characters.
"""
def __init__(self, chunk_size=1024, chunk_overlap=64, pretrained_model_path: str=None) -> None:
super().__init__()
self.chunk_size = chunk_size
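A hedged sketch of the two modes described above; the model path "gpt2" is only illustrative, and the raw-string input assumes the DocSplitter contract that accepts either text or documents:

from framework.feature_factory.llm_tools import LangChainRecursiveCharacterTextSplitter

# chunk_size / chunk_overlap counted in characters
char_splitter = LangChainRecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)

# chunk_size / chunk_overlap counted in tokens of the given pretrained model
token_splitter = LangChainRecursiveCharacterTextSplitter(
    chunk_size=1024, chunk_overlap=64, pretrained_model_path="gpt2"
)

chunks = char_splitter.apply("Some long document body ... " * 500)
print(len(chunks))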
@@ -219,7 +240,9 @@ def apply(self, docs):


class TokenizerTextSpliter(DocSplitter):

""" A text splitter which uses LLM defined by `pretrained_tokenizer_path` to encode the input text.
The splitting will be applied to the tokens instead of characters.
"""
def __init__(self, chunk_size=1024, chunk_overlap=64, pretrained_tokenizer_path: str=None) -> None:
super().__init__()
self.chunk_size = chunk_size
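A hedged usage sketch; the tokenizer path below is illustrative only:

from framework.feature_factory.llm_tools import TokenizerTextSpliter

splitter = TokenizerTextSpliter(
    chunk_size=1024,
    chunk_overlap=64,
    pretrained_tokenizer_path="gpt2",  # illustrative tokenizer path
)
chunks = splitter.apply("A very long document body ... " * 1000)
print(len(chunks))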
@@ -248,7 +271,23 @@ def apply(self, text: Union[str, List[Document]]) -> List[str]:


class LLMFeature(LLMTool):
""" A container class to hold all required reader and splitter instances.
The name is the column name for text chunks in the generated spark dataframe.
If the name is not provided, it will take the variable name in the LLM catalog as the name.
e.g.
class TestCatalog(LLMCatalogBase):
# define a reader for the documents
doc_reader = LlamaIndexDocReader()
# define a text splitter
doc_splitter = LangChainRecursiveCharacterTextSplitter()
# define a LLM feature, the name is the column name in the result dataframe
chunk_col_name = LLMFeature(reader=doc_reader, splitter=doc_splitter)
The name of output dataframe will be `chunk_col_name`.
"""
def __init__(self, reader: DocReader, splitter: DocSplitter, name: str = "chunks") -> None:
super().__init__()
self.name = name
@@ -267,7 +306,8 @@ def apply(self, filename: str):


class LLMUtils:

""" Util class to define generic split and process methods invoked from spark.
"""
@classmethod
def split_docs(cls, fileName: str, llmFeat: LLMFeature):
print(fileName)
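A hedged sketch of a direct call; inside assemble_llm_feature these utilities are presumably applied per file on each partition (an assumption, not shown in this diff), and the file path is hypothetical:

from framework.feature_factory.llm_tools import (
    LLMFeature,
    LLMUtils,
    LlamaIndexDocReader,
    LangChainRecursiveCharacterTextSplitter,
)

feat = LLMFeature(
    reader=LlamaIndexDocReader(),
    splitter=LangChainRecursiveCharacterTextSplitter(),
    name="chunks",
)
# Parse and split a single file; the exact return shape is not shown in this diff.
chunks = LLMUtils.split_docs("/dbfs/llm_docs/handbook.pdf", feat)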
