-
Notifications
You must be signed in to change notification settings - Fork 77
/
splitter.py
117 lines (96 loc) · 3.95 KB
/
splitter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import streamlit as st
from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter, Language
import code_snippets as code_snippets
import tiktoken
# Streamlit UI: page title followed by a short guide to the controls below.
# The guide text is kept in a named constant so the layout calls stay compact.
_INTRO_TEXT = """Split a text into chunks using a **Text Splitter**. Parameters include:
- `chunk_size`: Max size of the resulting chunks (in either characters or tokens, as selected)
- `chunk_overlap`: Overlap between the resulting chunks (in either characters or tokens, as selected)
- `length_function`: How to measure lengths of chunks, examples are included for either characters or tokens
- The type of the text splitter, this largely controls the separators used to split on
"""

st.title("Text Splitter Playground")
st.info(_INTRO_TEXT)
# Parameter widgets laid out in one row; the splitter picker gets the widest column.
col1, col2, col3, col4 = st.columns([1, 1, 1, 2])

with col1:
    chunk_size = st.number_input(min_value=1, label="Chunk Size", value=1000)

with col2:
    # Overlap is bounded above by chunk_size - 1 so it always stays strictly
    # smaller than the chunk. The lower bound is 0 (no overlap is valid): with
    # the previous min_value=1, any chunk_size below 5 produced a default of
    # int(chunk_size * 0.2) == 0, which is below min_value and is rejected by
    # st.number_input. For chunk_size >= 1 the default never exceeds
    # chunk_size - 1, so it is always within [0, max_value].
    chunk_overlap = st.number_input(
        min_value=0,
        max_value=chunk_size - 1,
        label="Chunk Overlap",
        value=int(chunk_size * 0.2),
    )
    # Safety net: max_value above should already prevent this.
    if chunk_overlap >= chunk_size:
        st.warning("Chunk Overlap should be less than Chunk Size!")

with col3:
    # Holds the selected label ("Characters" or "Tokens"); later code in this
    # script rebinds the same name to the corresponding callable.
    length_function = st.selectbox(
        "Length Function", ["Characters", "Tokens"]
    )

# Generic splitters first, then one entry per langchain Language member
# (rendered via str(), e.g. "Language.PYTHON").
splitter_choices = ["RecursiveCharacter", "Character"] + [str(v) for v in Language]

with col4:
    splitter_choice = st.selectbox(
        "Select a Text Splitter", splitter_choices
    )
# Turn the selected length-measure label into the callable handed to the
# splitter, plus the code snippet that reproduces it in the generated example.
if length_function == "Tokens":
    enc = tiktoken.get_encoding("cl100k_base")

    def length_function(text: str) -> int:
        """Return how many cl100k_base tokens *text* encodes to."""
        return len(enc.encode(text))

    length_function_str = code_snippets.TOKEN_LENGTH
elif length_function == "Characters":
    length_function = len
    length_function_str = code_snippets.CHARACTER_LENGTH
else:
    # Unreachable with the selectbox options defined above.
    raise ValueError
# Render the example code for the current settings. Every template takes the
# same three fields; only the language template additionally needs `language`.
snippet_args = {
    "chunk_size": chunk_size,
    "chunk_overlap": chunk_overlap,
    "length_function": length_function_str,
}
if splitter_choice == "Character":
    snippet_template = code_snippets.CHARACTER
elif splitter_choice == "RecursiveCharacter":
    snippet_template = code_snippets.RECURSIVE_CHARACTER
elif "Language." in splitter_choice:
    snippet_template = code_snippets.LANGUAGE
    snippet_args["language"] = splitter_choice
else:
    # Unreachable with the selectbox options defined above.
    raise ValueError
import_text = snippet_template.format(**snippet_args)
st.info(import_text)
# Box for pasting text
doc = st.text_area("Paste your text here:")

# Split text button: build the selected splitter, split `doc`, show each chunk.
if st.button("Split Text"):
    # Choose splitter
    if splitter_choice == "Character":
        splitter = CharacterTextSplitter(
            separator="\n\n",
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=length_function,
        )
    elif splitter_choice == "RecursiveCharacter":
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=length_function,
        )
    elif "Language." in splitter_choice:
        # The option label is str() of a Language member ("Language.<NAME>").
        # Look the member up by name so from_language() receives a Language
        # enum, instead of relying on the lowercased name equalling its value.
        language = Language[splitter_choice.split(".", 1)[1]]
        splitter = RecursiveCharacterTextSplitter.from_language(
            language,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=length_function,
        )
    else:
        raise ValueError(f"Unknown text splitter choice: {splitter_choice!r}")

    # Split the text
    splits = splitter.split_text(doc)

    # Display the splits. Empty input yields no chunks, so say so explicitly
    # rather than rendering nothing.
    if not splits:
        st.warning("No chunks produced - paste some text to split first.")
    for idx, split in enumerate(splits, start=1):
        st.text_area(f"Split {idx}", split, height=200)