Skip to content

Commit

Permalink
Add sparse embed workflow (#291)
Browse files Browse the repository at this point in the history
* Add sparse embed func

* Add sparse embed func

* Add sparse embed func

* Add sparse embed func

* Fix BGEM3SparseEmbeddingFunction model_dir

* Fix BGEM3SparseEmbeddingFunction model_dir

* Fix BGEM3SparseEmbeddingFunction model_dir

* Fix ModelScopeDownloader

* Fix ModelScopeDownloader
  • Loading branch information
wwxxzz authored Dec 2, 2024
1 parent 18eac82 commit 9775485
Show file tree
Hide file tree
Showing 11 changed files with 79 additions and 22 deletions.
1 change: 1 addition & 0 deletions src/pai_rag/config/settings.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type = "SimpleDirectoryReader"
[rag.embedding]
source = "DashScope"
embed_batch_size = 10
enable_sparse = false

[rag.index]
persist_path = "localdata/storage"
Expand Down
15 changes: 14 additions & 1 deletion src/pai_rag/integrations/embeddings/pai/embedding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pai_rag.integrations.embeddings.clip.cnclip_embedding import CnClipEmbedding
import os
from loguru import logger
from pai_rag.utils.download_models import ModelScopeDownloader


def create_embedding(embed_config: PaiBaseEmbeddingConfig):
Expand All @@ -36,8 +37,20 @@ def create_embedding(embed_config: PaiBaseEmbeddingConfig):
)
elif isinstance(embed_config, HuggingFaceEmbeddingConfig):
pai_model_dir = os.getenv("PAI_RAG_MODEL_DIR", "./model_repository")
pai_model_name = os.path.join(pai_model_dir, embed_config.model)
if not os.path.exists(pai_model_name):
logger.info(
f"Embedding model {embed_config.model} not found in {pai_model_dir}, try download it."
)
download_models = ModelScopeDownloader(
fetch_config=True, download_directory_path=pai_model_dir
)
download_models.load_models(model=embed_config.model)
logger.info(
f"Embedding model {embed_config.model} downloaded to {pai_model_name}."
)
embed_model = HuggingFaceEmbedding(
model_name=os.path.join(pai_model_dir, embed_config.model),
model_name=pai_model_name,
embed_batch_size=embed_config.embed_batch_size,
trust_remote_code=True,
callback_manager=Settings.callback_manager,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class PaiBaseEmbeddingConfig(BaseModel):
source: SupportedEmbedType
model: str
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE
enable_sparse: bool = False

class Config:
frozen = True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@


class BGEM3SparseEmbeddingFunction:
def __init__(self) -> None:
def __init__(self, model_name_or_path: str = None) -> None:
try:
from FlagEmbedding import BGEM3FlagModel

self.model = BGEM3FlagModel(
model_name_or_path=os.path.join(DEFAULT_MODEL_DIR, MODEL_NAME),
model_name_or_path=os.path.join(
model_name_or_path or DEFAULT_MODEL_DIR, MODEL_NAME
),
use_fp16=False,
)
except Exception:
Expand Down
14 changes: 9 additions & 5 deletions src/pai_rag/tools/data_process/embed_workflow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
from loguru import logger
import ray
import time
from ray.data.datasource.filename_provider import _DefaultFilenameProvider
from pai_rag.tools.data_process.tasks.embed_node import embed_node_task
from pai_rag.tools.data_process.utils.ray_init import init_ray_env, get_num_workers
Expand All @@ -17,22 +18,25 @@ def main(args):
concurrency=num_workers,
)
logger.info("Embedding nodes completed.")
logger.info(f"Write to {args.output_dir}")
timestamp = time.strftime("%Y%m%d-%H%M%S")
ds = ds.repartition(1)
logger.info(f"Write to {args.output_path}")
ds.write_json(
args.output_path,
filename_provider=_DefaultFilenameProvider(file_format="jsonl"),
args.output_dir,
filename_provider=_DefaultFilenameProvider(
dataset_uuid=timestamp, file_format="jsonl"
),
force_ascii=False,
)
logger.info("Write completed.")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--working_dir", type=str, default=None)
parser.add_argument("--config_file", type=str, default=None)
parser.add_argument("--data_path", type=str, default=None)
parser.add_argument("--output_path", type=str, default=None)
parser.add_argument("--working_dir", type=str, default=None)
parser.add_argument("--output_dir", type=str, default=None)
args = parser.parse_args()

print(f"Init: args: {args}")
Expand Down
14 changes: 10 additions & 4 deletions src/pai_rag/tools/data_process/parse_workflow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import ray
import argparse
import json
import time
from loguru import logger
from pai_rag.integrations.readers.pai.pai_data_reader import get_input_files
from pai_rag.tools.data_process.utils.ray_init import init_ray_env
Expand Down Expand Up @@ -32,16 +34,20 @@ def main(args):
]
results = ray.get(run_tasks)
logger.info("Master node completed processing files.")
write_to_file.remote(results, args.output_path)
logger.info(f"Results written to {args.output_path} asynchronously.")
os.makedirs(args.output_dir, exist_ok=True)
timestamp = time.strftime("%Y%m%d-%H%M%S")
save_file = os.path.join(args.output_dir, f"{timestamp}.jsonl")
write_to_file.remote(results, save_file)
logger.info(f"Results written to {save_file} asynchronously.")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--working_dir", type=str, default=None)
parser.add_argument("--config_file", type=str, default=None)
parser.add_argument("--data_path", type=str, default=None)
parser.add_argument("--output_path", type=str, default=None)
parser.add_argument("--working_dir", type=str, default=None)
parser.add_argument("--output_dir", type=str, default=None)

args = parser.parse_args()

logger.info(f"Init: args: {args}")
Expand Down
15 changes: 10 additions & 5 deletions src/pai_rag/tools/data_process/split_workflow.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import ray
import time
from loguru import logger
from ray.data.datasource.filename_provider import _DefaultFilenameProvider
from pai_rag.tools.data_process.tasks.split_node import split_node_task
Expand All @@ -17,21 +18,25 @@ def main(args):
concurrency=num_workers,
)
logger.info("Splitting nodes completed.")
logger.info(f"Write to {args.output_dir}")
timestamp = time.strftime("%Y%m%d-%H%M%S")
ds = ds.repartition(1)
logger.info(f"Write to {args.output_path}")
ds.write_json(
args.output_path,
filename_provider=_DefaultFilenameProvider(file_format="jsonl"),
args.output_dir,
filename_provider=_DefaultFilenameProvider(
dataset_uuid=timestamp, file_format="jsonl"
),
force_ascii=False,
)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--working_dir", type=str, default=None)
parser.add_argument("--config_file", type=str, default=None)
parser.add_argument("--data_path", type=str, default=None)
parser.add_argument("--output_path", type=str, default=None)
parser.add_argument("--working_dir", type=str, default=None)
parser.add_argument("--output_dir", type=str, default=None)

args = parser.parse_args()

logger.info(f"Init: args: {args}")
Expand Down
16 changes: 14 additions & 2 deletions src/pai_rag/tools/data_process/tasks/embed_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,29 @@
dict_to_text_node,
)
from pai_rag.utils.download_models import ModelScopeDownloader
from pai_rag.integrations.index.pai.utils.sparse_embed_function import (
BGEM3SparseEmbeddingFunction,
)

RAY_ENV_MODEL_DIR = "/PAI-RAG/pai_rag_model_repository"
os.environ["PAI_RAG_MODEL_DIR"] = RAY_ENV_MODEL_DIR


def embed_node_task(node, config_file):
config = RagConfigManager.from_file(config_file).get_value()
ModelScopeDownloader(download_directory_path=RAY_ENV_MODEL_DIR).load_rag_models()
download_models = ModelScopeDownloader(
fetch_config=True, download_directory_path=RAY_ENV_MODEL_DIR
)

embed_model = resolve(cls=PaiEmbedding, embed_config=config.embedding)
format_node = dict_to_text_node(node)
embed_nodes = embed_model([format_node])
nodes_dict = text_node_to_dict(embed_nodes[0])
sparse_embedding = None
if config.embedding.enable_sparse:
download_models.load_models(model="bge-m3")
sparse_embed_model = BGEM3SparseEmbeddingFunction(
model_name_or_path=RAY_ENV_MODEL_DIR
)
sparse_embedding = sparse_embed_model.encode_documents([embed_nodes[0].text])[0]
nodes_dict = text_node_to_dict(embed_nodes[0], sparse_embedding)
return nodes_dict
10 changes: 10 additions & 0 deletions src/pai_rag/tools/data_process/tasks/load_and_parse_doc.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
import os
from pai_rag.integrations.readers.pai.pai_data_reader import PaiDataReader
from pai_rag.utils.oss_client import OssClient
from pai_rag.core.rag_module import resolve
from pai_rag.core.rag_config_manager import RagConfigManager
from pai_rag.tools.data_process.utils.format_document import document_to_dict
from pai_rag.utils.download_models import ModelScopeDownloader

RAY_ENV_MODEL_DIR = "/PAI-RAG/pai_rag_model_repository"
os.environ["PAI_RAG_MODEL_DIR"] = RAY_ENV_MODEL_DIR


def load_and_parse_doc_task(config_file, input_file):
config = RagConfigManager.from_file(config_file).get_value()
download_models = ModelScopeDownloader(
fetch_config=True, download_directory_path=RAY_ENV_MODEL_DIR
)
download_models.load_mineru_config()

data_reader_config = config.data_reader
oss_store = None
if config.oss_store.bucket:
Expand Down
3 changes: 2 additions & 1 deletion src/pai_rag/tools/data_process/utils/format_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
from llama_index.core.schema import TextNode


def text_node_to_dict(node):
def text_node_to_dict(node, sparse_embedding=None):
return {
"id": node.id_,
"embedding": node.embedding,
"sparse_embedding": sparse_embedding,
"metadata": {
k: str(v) if isinstance(v, datetime) else v
for k, v in node.metadata.items()
Expand Down
6 changes: 4 additions & 2 deletions src/pai_rag/utils/download_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ def load_model(self, model):
else:
raise ValueError(f"{model} is not a valid model name.")
temp_model_dir = snapshot_download(model_id, cache_dir=temp_dir)

logger.info(
f"Downloaded model {model} to {temp_model_dir} and move to {model_path}."
)
shutil.move(temp_model_dir, model_path)
end_time = time.time()
duration = end_time - start_time
Expand Down Expand Up @@ -91,7 +93,7 @@ def load_mineru_config(self):
"Copy magic-pdf.template.json to ~/magic-pdf.json and modify models-dir to model path."
)

def load_models(self, model):
def load_models(self, model=None):
if model is None:
models = [model for model in self.model_info["basic_models"].keys()] + [
model for model in self.model_info["extra_models"].keys()
Expand Down

0 comments on commit 9775485

Please sign in to comment.