# utils.py
import os
import os.path as op
import pandas as pd
import glob
import random
import matplotlib.pyplot as plt
import seaborn as sns
import re
import json
import gzip
from transformers import pipeline
from googleapiclient.discovery import build
API_KEY = pd.read_json(op.join(".", "config.json"))["api_key"][0]  # local config file with the personal YouTube API key
# ______________________________________________________________________________________________________________________
# Data extraction
# ______________________________________________________________________________________________________________________
def filter_jsonl(input_path, category, batch_size, save_path, verbose=False):
    """Unzip the input jsonl data, extract the rows with the given category and save them in batches.
    Args:
        input_path (str): path to yt_metadata_en.jsonl.gz (including the file name)
        category (str): one of the categories in the channel metadata
        batch_size (int): number of videos per batch
        save_path (str): path to the folder where the batch dataframes should be saved
        verbose (bool, optional): print progress information. Defaults to False.
    """
    batch_index = -1  # so that the first batch gets index 0
    line_counter = 0
    category_counter = 0
    renew_list = True  # guard flag, because 0 % batch_size == 0 before the first match
    with gzip.open(input_path, "rt", encoding="utf-8") as f:
        for line in f:
            entry = json.loads(line)
            line_counter += 1
            # create a new batch list
            if category_counter % batch_size == 0 and renew_list:
                renew_list = False
                filtered_data = []
                batch_index += 1
                if verbose:
                    print(
                        f"======== Batch {batch_index} - started at {line_counter} ========"
                    )
            if entry.get("categories") == category:
                category_counter += 1
                filtered_data.append(entry)
                if verbose:
                    if category_counter != 0 and category_counter % 100000 == 0:
                        print(
                            f"Filtered {category_counter} {category} videos out of {line_counter} so far"
                        )
                if len(filtered_data) == batch_size:  # save the full batch
                    df_filtered = pd.DataFrame(filtered_data)
                    df_filtered.to_csv(
                        os.path.join(save_path, f"{category}_videos_{batch_index}.csv")
                    )
                    renew_list = True
    # save the last (possibly partial) batch and report the total
    df_filtered = pd.DataFrame(filtered_data)
    df_filtered.to_csv(
        os.path.join(save_path, f"{category}_videos_{batch_index}.csv")
    )
    print(
        f"We filtered a total of {category_counter} videos in the {category} category!"
    )
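# Hedged usage sketch for filter_jsonl: the paths and batch size below are illustrative
# assumptions, not values defined elsewhere in this repository.
# filter_jsonl(
#     input_path="data/yt_metadata_en.jsonl.gz",
#     category="Education",
#     batch_size=100000,
#     save_path="data/education_batches",
#     verbose=True,
# )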
# ______________________________________________________________________________________________________________________
# BART classification functions
# ______________________________________________________________________________________________________________________
def load_metadata_videos(file_path):
    """
    Load the metadata of the videos from file_path
    """
    return pd.read_csv(file_path).drop(columns="Unnamed: 0").dropna()
def bart_classification(text, candidate_labels, multi_label=True, plot=False, title=""):
    """
    Perform zero-shot classification using the BART model
    Parameters:
    - text: the text to classify
    - candidate_labels: the list of labels to classify the text into
    - multi_label: whether to allow multiple labels or not
    - plot: whether to plot the scores or not
    - title: the title of the plot
    Returns:
    - a list of labels
    """
    classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
    result = classifier(text, candidate_labels, multi_label=multi_label)
    scores, labels = result["scores"], result["labels"]
    sorted_pairs = sorted(zip(scores, labels), key=lambda x: x[0], reverse=True)
    scores, labels = zip(*sorted_pairs)
    max_score = scores[0]
    threshold = max_score * 0.9
    top_count = len([i for i, score in enumerate(scores) if score >= threshold])
    if plot:
        plot_scores_BART(scores, labels, top_count, title)
    if max_score < 0.3:
        return ["misc"]
    elif top_count == 1:
        return [labels[0]]
    elif top_count == 2 and multi_label:
        return [labels[0], labels[1]]
    elif top_count == 3 and multi_label:
        return [labels[0], labels[1], labels[2]]
    else:
        return ["uncertain"]
def plot_scores_BART(scores, labels, top_count, title):
    """
    Plot the scores of the labels of the BART classification
    Parameters:
    - scores: the scores of the labels
    - labels: the labels
    - top_count: the number of top labels
    - title: the title of the plot
    """
    # Define colors: green for the top scores, grey for the others
    colors = ["green" if i < top_count else "grey" for i in range(len(labels))]
    # Map x-axis labels to integers from 1 to len(labels)
    x_positions = range(1, len(labels) + 1)
    # Create a figure with two subplots
    fig, (ax_main, ax_legend) = plt.subplots(
        1, 2, gridspec_kw={"width_ratios": [4, 1]}, figsize=(12, 4)
    )
    # Plot the main bar chart on the first subplot
    bars = ax_main.bar(x_positions, scores, color=colors)
    # Add score labels above each bar
    for bar, score in zip(bars, scores):
        height = bar.get_height()
        ax_main.text(
            bar.get_x() + bar.get_width() / 2,
            height + 0.01,  # slightly above the bar
            f"{score:.2f}",
            ha="center",
            va="bottom",
            fontsize=10,
        )
    # Customize the main subplot
    ax_main.set_title(f"Probability of Each Label for the Video:\n{title}", fontsize=14)
    ax_main.set_xlabel("Label Numbers", fontsize=12)
    ax_main.set_ylabel("Probability", fontsize=12)
    ax_main.set_ylim(0, max(scores) + 0.1)  # add some space on top for the labels
    ax_main.set_xticks(x_positions)  # use integers on the x-axis
    ax_main.grid(axis="y", linestyle="--", alpha=0.7)
    # Set up the legend subplot with small points at (0, 0) for each label
    for i, label in enumerate(labels):
        ax_legend.plot(
            0, 0, "o", color="white", label=f"{i + 1}: {label}"
        )  # white dot as a placeholder
    # Hide the legend subplot axes and only show the legend
    ax_legend.legend(loc="center", fontsize=9)
    ax_legend.axis("off")
    # Display the plot with a tight layout
    plt.tight_layout()
    plt.show()
# ______________________________________________________________________________________________________________________
# Functions to extract countries from the education channels - YouTube API
# ______________________________________________________________________________________________________________________
def extract_channels_edu(path_edu, N_BATCHES, verbose=False):
    """Collect the unique channel ids across all education video batch files."""
    channels = []
    for i in range(N_BATCHES):
        if verbose:
            print(f"Processing file : {path_edu.format(i)}", end="")
        edu = pd.read_csv(path_edu.format(i), index_col=0)
        ch = list(pd.unique(edu["channel_id"]))
        if verbose:
            print(f" --> Found {len(ch)} channels")
        channels.extend(ch)
    channels = list(set(channels))  # keep only the unique channels of the junction
    if verbose:
        print("Total number of unique channels :", len(channels))
    return channels
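# Hedged usage sketch for extract_channels_edu: the path template and batch count are
# placeholders, not repository facts.
# edu_channels = extract_channels_edu("data/Education_videos_{}.csv", N_BATCHES=10, verbose=True)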
def agglomerate_countries(x, val_counts, filter=10):
    """Group rare countries (fewer than `filter` occurrences) under 'Other' and map deleted or missing values to '?'."""
    if isinstance(x, str) and val_counts[x] < filter:
        return "Other"
    elif isinstance(x, str) and x == "deleted":  # assign deleted to 'unknown'
        return "?"
    elif isinstance(x, float):  # assign NaN to 'unknown'
        return "?"
    else:
        return x
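# Hedged usage sketch for agglomerate_countries: channels_df and its "country" column
# are assumed names used only for illustration.
# counts = channels_df["country"].value_counts()
# channels_df["country"] = channels_df["country"].apply(
#     lambda c: agglomerate_countries(c, counts, filter=10)
# )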
def youtube_country_scraper(channel_ids, verbose=False):
    """Query the YouTube Data API for the country of each channel in channel_ids."""
    # Disable OAuthlib's HTTPS verification when running locally.
    # *DO NOT* leave this option enabled in production.
    os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1"
    youtube = build("youtube", "v3", developerKey=API_KEY)
    ids_string = ",".join(channel_ids)
    request = youtube.channels().list(part="snippet", id=ids_string)
    items = request.execute()
    countries = {ch: "Redo" for ch in channel_ids}
    if "items" in items:  # for when you redo the request with single channels
        for item in items.get("items", []):
            if "snippet" in item:
                id = item.get("id")
                country = item.get("snippet").get("country")
                if id in channel_ids:
                    countries[id] = country
                else:
                    # the channel now has a different id and needs to be redone
                    countries[id] = None
    else:
        countries[list(countries)[0]] = "deleted"  # channel info is not available anymore
    if verbose:
        print(items)
        print(countries)
    return countries
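# Hedged usage sketch for youtube_country_scraper: the channels.list call takes a
# comma-separated list of ids, so the scraper is typically called on small chunks
# (50 ids per request is an assumption about the API limit; adjust if requests fail).
# edu_channels is the hypothetical list returned by extract_channels_edu above.
# all_countries = {}
# for start in range(0, len(edu_channels), 50):
#     all_countries.update(youtube_country_scraper(edu_channels[start:start + 50]))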
# ______________________________________________________________________________________________________________________
# Functions FRED
# ______________________________________________________________________________________________________________________
def remove_nan(df: pd.DataFrame):
    df = df.dropna()
    return df
def clean_non_ascii(text):
    return re.sub(r"[^\x00-\x7F]+", " ", text)
def replace_non_ascii_in_dataframe(df, columns=["title", "tags", "description"]):
    for column in columns:
        df.loc[:, column] = df[column].apply(
            lambda x: clean_non_ascii(x) if isinstance(x, str) else x
        )
    return df
def remove_rows_with_empty_strings(df, columns=["title", "tags", "description"]):
    # Filter out rows where any specified column has an empty string
    df_filtered = df[
        ~df[columns].apply(lambda row: any(cell == "" for cell in row), axis=1)
    ]
    return df_filtered
def clean_data(df: pd.DataFrame):
    out = remove_nan(df)
    out = replace_non_ascii_in_dataframe(out)
    out = remove_rows_with_empty_strings(out)
    return out
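# Hedged usage sketch for the cleaning helpers: the batch file path is an illustrative assumption.
# df_raw = load_metadata_videos("data/Education_videos_0.csv")
# df_clean = clean_data(df_raw)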
def random_sample_from_csv_files(directory_path, total_sample_size):
    # Find all CSV files that start with "Education_videos_"
    file_paths = glob.glob(f"{directory_path}/Education_videos_*.csv")
    # Initialize an empty list for the reservoir
    reservoir = []
    row_index = 0  # track the index of the current row across all files
    # Process the files one by one
    for file_path in file_paths:
        # Read the current file into a dataframe
        df = pd.read_csv(file_path)
        # Iterate over each row in the current file
        for _, row in df.iterrows():
            if len(reservoir) < total_sample_size:
                # If the reservoir is not full, add the row directly
                reservoir.append(row)
            else:
                # If the reservoir is full, replace an element with decreasing probability
                replace_index = random.randint(0, row_index)
                if replace_index < total_sample_size:
                    reservoir[replace_index] = row
            # Increment the global row index
            row_index += 1
    # Convert the reservoir (list of rows) back to a DataFrame
    sampled_df = pd.DataFrame(reservoir)
    return sampled_df
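# Hedged usage sketch for random_sample_from_csv_files: the directory and sample size
# are illustrative assumptions.
# sampled = random_sample_from_csv_files("data", total_sample_size=10000)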
def create_channels_cat(df_cat: pd.DataFrame):
    # Count the number of videos in each category for each channel id
    channels_cat = (
        df_cat.groupby("channel_id")["broad_category"]
        .value_counts()
        .reset_index(name="count")
    )
    # Assign a weight to each category for each channel id (count / total count)
    channels_cat["weights"] = channels_cat.groupby("channel_id")["count"].transform(
        lambda x: x / x.sum()
    )
    # Aggregate the categories and weights into lists
    result = (
        channels_cat.groupby("channel_id")
        .agg(categories=("broad_category", list), weights=("weights", list))
        .reset_index()
    )
    return result
def create_channel_cat_single(df_cat: pd.DataFrame):
    # Assign each channel to the category with the highest number of videos
    channel_cat = df_cat.groupby(["channel_id", "broad_category"]).size()
    channel_cat = (
        channel_cat.groupby("channel_id")
        .idxmax()
        .apply(lambda x: x[1])
        .reset_index(name="dominant_category")
    )
    return channel_cat
def category_filter(df: pd.DataFrame, category: str):
    # Return the dataframe filtered to the chosen category, with the weight of that category
    filtered_df = df[df["categories"].apply(lambda x: category in x)].copy()
    filtered_df["category_weight"] = filtered_df.apply(
        lambda row: row["weights"][row["categories"].index(category)], axis=1
    )
    return filtered_df
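# Hedged usage sketch chaining create_channels_cat and category_filter: df_videos and
# the "Health" category are assumed names for illustration; df_videos is expected to
# have "channel_id" and "broad_category" columns.
# channels_weighted = create_channels_cat(df_videos)
# health_channels = category_filter(channels_weighted, "Health")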
def add_datetime_info(df, column):
    df[column] = pd.to_datetime(df[column])
    df["month"] = df[column].dt.month
    df["year"] = df[column].dt.year
    df["day"] = df[column].dt.day
    return df
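# Hedged usage sketch for add_datetime_info: "upload_date" is an assumed column name.
# df_videos = add_datetime_info(df_videos, "upload_date")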