
Commit 6a1272c

Implement extractor subclassing LLMTool; provide extractors as input to the doc splitter

lyliyu committed Dec 2, 2023 · 1 parent b079aa2 · commit 6a1272c
Showing 3 changed files with 95 additions and 23 deletions.
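
In effect, LlamaIndexDocSplitter no longer takes a single `llm` argument and hard-wires a TitleExtractor around it; callers now build extractor tools themselves and pass them in as a list. A minimal before/after sketch of the call site, using the MPT7b LLM definition from the notebook below:

# Before this commit: the splitter took the LLM directly and built a
# TitleExtractor internally with a fixed nodes=5.
doc_splitter = LlamaIndexDocSplitter(chunk_size=1024, chunk_overlap=32, llm=MPT7b())

# After this commit: extractors are constructed outside and passed in as a list,
# so any LLMTool-based metadata extractor can be plugged in.
title_extractor = LlamaIndexTitleExtractor(llm_def=MPT7b(), nodes=5)
doc_splitter = LlamaIndexDocSplitter(chunk_size=1024, chunk_overlap=32, extractors=[title_extractor])

If extractors is omitted, the splitter skips metadata extraction and produces plain text chunks only.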
52 changes: 36 additions & 16 deletions framework/feature_factory/llm_tools.py
@@ -4,7 +4,8 @@
from llama_index.node_parser import SimpleNodeParser
from llama_index.node_parser.extractors import (
    MetadataExtractor,
-    TitleExtractor
+    TitleExtractor,
+    MetadataFeatureExtractor
)
from llama_index.text_splitter import TokenTextSplitter
from llama_index.schema import MetadataMode, Document as Document
@@ -24,6 +25,7 @@ class LLMTool(ABC):
"""
    def __init__(self) -> None:
        self._initialized = False
+        self._instance = None

    def _require_init(self) -> bool:
        if self._initialized:
@@ -40,6 +42,9 @@ def apply(self):
    def create(self):
        ...

+    def get_instance(self):
+        return self._instance
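
Every tool now shares one lifecycle: create() builds the underlying object once (the _require_init() guard makes repeated calls no-ops), stores it on self._instance, and get_instance() hands it back. A minimal hypothetical subclass sketching that contract (DummyTool and its stub instance are illustrative only, not part of the framework):

class DummyTool(LLMTool):
    def create(self):
        if super()._require_init():
            # Build the (potentially expensive) resource once and cache it.
            self._instance = object()  # stand-in for a real LLM or extractor

    def apply(self):
        ...

tool = DummyTool()
tool.create()
tool.create()  # second call is a no-op thanks to the _require_init() guard
resource = tool.get_instance()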


class DocReader(LLMTool):
    """ Generic class for doc reader.
@@ -165,16 +170,31 @@ def apply(self, filename: str) -> str:



-class LLMDef(LLMTool):
-    """ A generic class to define LLM instance e.g. using HuggingFace APIs.
-    An example can be found at notebooks/feature_factory_llms.py
-    """
-    def __init__(self) -> None:
-        self._instance = None
+# class LLMDef(LLMTool):
+#     """ A generic class to define LLM instance e.g. using HuggingFace APIs.
+#     An example can be found at notebooks/feature_factory_llms.py
+#     """
+#     def __init__(self) -> None:
+#         self._instance = None

-    def get_instance(self):
-        return self._instance
+#     def get_instance(self):
+#         return self._instance

+class LlamaIndexTitleExtractor(LLMTool):
+
+    def __init__(self, llm_def, nodes) -> None:
+        super().__init__()
+        self.llm_def = llm_def
+        self.nodes = nodes
+
+    def create(self):
+        if super()._require_init():
+            self.llm_def.create()
+            self._instance = TitleExtractor(nodes=self.nodes, llm=self.llm_def.get_instance())
+
+    def apply(self):
+        self.create()
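
Note that LlamaIndexTitleExtractor is lazy: the constructor only records llm_def and nodes, and the underlying llama_index TitleExtractor (plus the LLM behind it) is not built until create() runs. A short sketch of the expected behavior, again assuming the notebook's MPT7b definition:

title_extractor = LlamaIndexTitleExtractor(llm_def=MPT7b(), nodes=5)
assert title_extractor.get_instance() is None  # nothing built yet

title_extractor.create()                    # builds the LLM, then wraps it in a TitleExtractor
extractor = title_extractor.get_instance()  # a llama_index TitleExtractor, ready for use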



class LlamaIndexDocSplitter(DocSplitter):
@@ -183,23 +203,23 @@ class LlamaIndexDocSplitter(DocSplitter):
    `chunk_size` and `chunk_overlap` are the parameters to tweak for better responses from LLMs.
    `extractors` are the LLM tools used for metadata extraction. If not provided, the splitter will generate text chunks only.
    """
-    def __init__(self, chunk_size:int=1024, chunk_overlap:int=64, llm:LLMDef=None) -> None:
+    def __init__(self, chunk_size:int=1024, chunk_overlap:int=64, extractors:List[LLMTool]=None) -> None:
        super().__init__()
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
-        self.llm = llm
+        self.extractors = extractors

    def create(self):
        if super()._require_init():
            text_splitter = TokenTextSplitter(
                separator=" ", chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
            )
-            if self.llm:
-                self.llm.create()
+            if self.extractors:
+                for extractor in self.extractors:
+                    extractor.create()
+                extractor_instances = [e.get_instance() for e in self.extractors]
                metadata_extractor = MetadataExtractor(
-                    extractors=[
-                        TitleExtractor(nodes=5, llm=self.llm.get_instance())
-                    ],
+                    extractors=extractor_instances,
                    in_place=False,
                )
            else:
35 changes: 28 additions & 7 deletions notebooks/feature_factory_llms.py
@@ -1,5 +1,5 @@
# Databricks notebook source
-# MAGIC %pip install llama-index pypdf
+# MAGIC %pip install llama-index==0.8.61 pypdf

# COMMAND ----------

@@ -11,7 +11,11 @@

# COMMAND ----------

-from framework.feature_factory.llm_tools import LLMFeature, LlamaIndexDocReader, LlamaIndexDocSplitter, LLMDef
+# MAGIC %pip list
+
+# COMMAND ----------
+
+from framework.feature_factory.llm_tools import LLMFeature, LlamaIndexDocReader, LlamaIndexDocSplitter, LLMTool, LangChainRecursiveCharacterTextSplitter, LlamaIndexTitleExtractor
from framework.feature_factory import Feature_Factory
import torch
from llama_index.llms import HuggingFaceLLM
@@ -22,7 +26,7 @@

# COMMAND ----------

-class MPT7b(LLMDef):
+class MPT7b(LLMTool):
    def create(self):
        torch.cuda.empty_cache()
        generate_params = {
@@ -35,7 +39,7 @@ def create(self):
"pad_token_id": 0
        }

-        self._instance = HuggingFaceLLM(
+        llm = HuggingFaceLLM(
            max_new_tokens=256,
            generate_kwargs=generate_params,
            # system_prompt=system_prompt,
@@ -46,21 +50,33 @@
            tokenizer_kwargs={"max_length": 1024},
            model_kwargs={"torch_dtype": torch.float16, "trust_remote_code": True}
        )
-        return None
+        self._instance = llm
+        return llm
+
+    def apply(self):
+        ...

# COMMAND ----------

+title_extractor = LlamaIndexTitleExtractor(nodes=5, llm_def = MPT7b())
+
+# COMMAND ----------
+
doc_splitter = LlamaIndexDocSplitter(
    chunk_size = 1024,
    chunk_overlap = 32,
-    llm = MPT7b()
+    extractors = [title_extractor]
)

# COMMAND ----------

+# doc_splitter = LangChainRecursiveCharacterTextSplitter(
+#     chunk_size = 1024,
+#     chunk_overlap = 32
+# )

# COMMAND ----------

llm_feature = LLMFeature(
    name = "chunks",
    reader = LlamaIndexDocReader(),
@@ -73,12 +89,17 @@ def apply(self):

# COMMAND ----------

-df = ff.assemble_llm_feature(spark, srcDirectory= "/dbfs/tmp/li_yu/va_llms/pdf", llmFeature=llm_feature, partitionNum=partition_num)
+df = ff.assemble_llm_feature(spark, srcDirectory= "your source document directory", llmFeature=llm_feature, partitionNum=partition_num)

# COMMAND ----------

display(df)

# COMMAND ----------

+df.write.mode("overwrite").saveAsTable("<catalog>.<schema>.<table>")

# COMMAND ----------



31 changes: 31 additions & 0 deletions test/test_chunking.py
@@ -10,6 +10,8 @@
from framework.feature_factory.catalog import LLMCatalogBase
from enum import IntEnum
from framework.feature_factory.llm_tools import *
+from llama_index.llms import HuggingFaceLLM
+import torch


class TestLLMTools(unittest.TestCase):
@@ -19,6 +21,35 @@ def test_llamaindex_reader(self):
        doc_reader.create()
        docs = doc_reader.apply("test/data/sample.pdf")
        assert len(docs) == 2

+    def test_metadata_extractor(self):
+        class MPT7b(LLMTool):
+            def create(self):
+                generate_params = {
+                    "temperature": 1.0,
+                    "top_p": 1.0,
+                    "top_k": 50,
+                    "use_cache": True,
+                    "do_sample": True,
+                    "eos_token_id": 0,
+                    "pad_token_id": 0
+                }
+
+                self._instance = HuggingFaceLLM(
+                    max_new_tokens=256,
+                    generate_kwargs=generate_params,
+                    tokenizer_name="mosaicml/mpt-7b-instruct",
+                    model_name="mosaicml/mpt-7b-instruct",
+                    device_map="auto",
+                    tokenizer_kwargs={"max_length": 1024},
+                    model_kwargs={"torch_dtype": torch.float16, "trust_remote_code": True}
+                )
+                return None
+
+            def apply(self):
+                ...
+
+        title_extractor = LlamaIndexTitleExtractor(nodes=5, llm_def=MPT7b())
+        assert title_extractor.nodes == 5 and isinstance(title_extractor.llm_def, MPT7b)

    def test_llamaindex_splitter(self):
        doc_reader = LlamaIndexDocReader()
