CreateDocuments.py
import os
import numpy as np
import pandas as pd
from pickle import Pickler, Unpickler
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from sentence_transformers import SentenceTransformer
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
RAW_TRAIN_DATA_PATH = 'rag-dataset-12000/data/train-00000-of-00001-9df3a936e1f63191.parquet'
RAW_TEST_DATA_PATH = 'rag-dataset-12000/data/test-00000-of-00001-af2a9f454ad1b8a3.parquet'
def get_original_contexts(dataset_path):
    """
    Loads the dataset and returns the 'context' column as an array.
    The 'context' column contains the context on which the query/responses are based.

    Args:
        dataset_path: The path to the dataset.

    Returns:
        An array containing the 'context' column.
    """
    df = pd.read_parquet(dataset_path)
    # print(df.shape)
    # df = df.iloc[[7937, 7952, 7771]]  # For testing purposes
    return df['context'].array
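# Illustrative sketch (not part of the original pipeline): a quick look at the raw
# parquet file before extracting contexts. Everything other than the 'context'
# column is an assumption about the rag-dataset-12000 layout.
def _demo_inspect_dataset(dataset_path=RAW_TRAIN_DATA_PATH):
    df = pd.read_parquet(dataset_path)
    print(df.shape)                      # (rows, columns)
    print(df.columns.tolist())           # 'context' is the column used below
    print(df['context'].iloc[0][:200])   # first 200 characters of the first context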
def split_into_chunks(dataset_path, chunk_size, chunk_overlap):
    """
    Splits the contexts into documents using the RecursiveCharacterTextSplitter.
    The RecursiveCharacterTextSplitter strategy maximizes the length of each document.
    See: https://python.langchain.com/docs/modules/data_connection/document_transformers/recursive_text_splitter/

    Args:
        dataset_path (string): The path to the dataset.
        chunk_size (int): The size of each chunk.
        chunk_overlap (int): The overlap between chunks.

    Returns:
        A flat list of Document chunks covering all contexts.
    """
    original_contexts = get_original_contexts(dataset_path)
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    chunks = text_splitter.create_documents(original_contexts)
    return chunks
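# Minimal sketch of the splitter's behaviour on a toy text; the sample string and the
# tiny chunk_size/chunk_overlap values are illustrative only.
def _demo_splitter_overlap():
    splitter = RecursiveCharacterTextSplitter(chunk_size=50, chunk_overlap=10)
    docs = splitter.create_documents(['word ' * 40])
    for doc in docs:
        print(repr(doc.page_content))    # chunks of <= 50 characters, overlapping by up to 10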
def save_chunks(chunks, dir, dataset_name, chunk_size, chunk_overlap):
    """
    Saves the documents to a pickle file. The filename is based on the dataset, chunk_size, and chunk_overlap.

    Args:
        chunks (list): The list of chunks.
        dir (string): The directory where the documents will be saved.
        dataset_name (string): The name of the dataset (train or test).
        chunk_size (int): The size of each chunk.
        chunk_overlap (int): The overlap between chunks.

    Returns:
        The path to the saved file.
    """
    if not os.path.exists(dir):
        os.makedirs(dir)
    file_path = f'{dir}/{dataset_name}_size_{chunk_size}_overlap_{chunk_overlap}.pkl'
    with open(file_path, 'wb') as file:
        Pickler(file).dump(chunks)
        file.flush()
    print('saved chunks to:', file_path)
    return file_path
def load_chunks(file_path):
    """
    Loads the chunks from a pickle file.

    Args:
        file_path (string): The path to the file.

    Returns:
        The list of chunks.
    """
    with open(file_path, 'rb') as file:
        return Unpickler(file).load()
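# Illustrative round-trip for the pickle helpers, using two hand-made Documents so it
# does not depend on the dataset; the 'chunks_demo' directory name is arbitrary.
def _demo_pickle_roundtrip():
    from langchain_core.documents import Document
    sample = [Document(page_content='alpha'), Document(page_content='beta')]
    path = save_chunks(sample, 'chunks_demo', 'demo', 0, 0)
    reloaded = load_chunks(path)
    assert [doc.page_content for doc in reloaded] == ['alpha', 'beta']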
def create_chunks(chunk_size=1000, chunk_overlap=200, dataset_path=RAW_TRAIN_DATA_PATH):
    """
    Creates documents from the dataset and saves them to a pickle file.

    Args:
        chunk_size (int): The size of each chunk.
        chunk_overlap (int): The overlap between chunks.
        dataset_path (string): The path to the dataset.

    Returns:
        The path to the saved pickle file.
    """
    documents = split_into_chunks(dataset_path, chunk_size, chunk_overlap)
    dataset_name = 'train' if 'train' in dataset_path else 'test'
    file_path = save_chunks(documents, 'chunks', dataset_name, chunk_size, chunk_overlap)
    return file_path
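# End-to-end sketch: create chunks from the raw train split and reload them.
# Assumes the parquet file at RAW_TRAIN_DATA_PATH is present locally.
def _demo_create_and_load_chunks():
    path = create_chunks(chunk_size=1000, chunk_overlap=100, dataset_path=RAW_TRAIN_DATA_PATH)
    chunks = load_chunks(path)
    print(len(chunks), 'chunks; first chunk starts with:', chunks[0].page_content[:80])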
def create_embeddings(chunks, model_name):
    """
    Creates embeddings for the chunks using the HuggingFaceEmbeddings from LangChain.

    Args:
        chunks (list): The list of documents.
        model_name (string): The name of the model to use for the embeddings.

    Returns:
        A list of embeddings (list of lists).
    """
    embedding_model = HuggingFaceEmbeddings(model_name=model_name, show_progress=True)
    docs = [chunk.page_content for chunk in chunks]
    doc_embeddings = embedding_model.embed_documents(docs)
    return doc_embeddings
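# Sketch of embedding a few chunks via the LangChain wrapper. The pickle path is an
# assumed artefact of an earlier create_chunks run; adjust it to whatever exists locally.
def _demo_langchain_embeddings():
    chunks = load_chunks('chunks/train_size_1000_overlap_100.pkl')
    vectors = create_embeddings(chunks[:8], 'sentence-transformers/all-MiniLM-L6-v2')
    print(len(vectors), 'embeddings of dimension', len(vectors[0]))  # 384 for all-MiniLM-L6-v2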
def create_embeddings_from_sentence_transformer(chunks, model_name='sentence-transformers/all-MiniLM-L6-v2'):
    """
    Creates embeddings for the chunks using the SentenceTransformer from HuggingFace (no LangChain).

    Args:
        chunks (list): The list of chunks.
        model_name (string): The name of the model to use for the embeddings.

    Returns:
        A tensor of embeddings (one row per chunk).
    """
    model = SentenceTransformer(model_name)
    embeddings = model.encode([doc.page_content for doc in chunks], convert_to_tensor=True)
    return embeddings
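# Sketch of the direct SentenceTransformer path plus a cosine-similarity check between
# two chunks; the pickle path is an assumed artefact of an earlier create_chunks run.
def _demo_sentence_transformer_similarity():
    from sentence_transformers import util
    chunks = load_chunks('chunks/train_size_1000_overlap_100.pkl')
    embeddings = create_embeddings_from_sentence_transformer(chunks[:8])
    print(embeddings.shape)                            # (8, 384) for all-MiniLM-L6-v2
    print(util.cos_sim(embeddings[0], embeddings[1]))  # similarity of the first two chunks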
if __name__ == '__main__':
    chunk_size = 1000
    chunk_overlap = 100
    for dataset_path in [RAW_TRAIN_DATA_PATH]:
        dataset_name = 'train' if 'train' in dataset_path else 'test'
        chunks = split_into_chunks(dataset_path, chunk_size, chunk_overlap)
        save_chunks(chunks=chunks, dir='chunks', dataset_name=dataset_name, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        # save_chunks(chunks, 'documents', dataset_name, chunk_size, chunk_overlap)
        # embeddings = create_embeddings(chunks, 'all-MiniLM-L6-v2')
        # Alternatively, embed directly with SentenceTransformer (returns a tensor):
        # embeddings2 = create_embeddings_from_sentence_transformer(chunks, 'sentence-transformers/all-MiniLM-L6-v2')