Utility.py

from torch.utils.data import Dataset
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import pandas as pd


class REDataset_entities(Dataset):
    """Relation-extraction dataset that appends the entity mentions to each sentence."""

    def __init__(self, dataframe, tokenizer, max_seq_length):
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data.iloc[idx]['sentence']
        entities = self.data.iloc[idx]['entities']
        relation = self.data.iloc[idx]['relation_id']
        # Concatenate the sentence and its entity mentions, separated by [SEP] tokens
        entities_text = " [SEP] ".join(entities)
        text = text + " [SEP] " + entities_text
        inputs = self.tokenizer(text, padding='max_length', max_length=self.max_seq_length,
                                return_tensors='pt', truncation=True)
        input_ids = inputs['input_ids'].squeeze()
        attention_mask = inputs['attention_mask'].squeeze()
        return input_ids, attention_mask, relation


class REDataset(Dataset):
    """Relation-extraction dataset that tokenizes the sentence only (no entity markers)."""

    def __init__(self, dataframe, tokenizer, max_seq_length):
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data.iloc[idx]['sentence']
        relation = self.data.iloc[idx]['relation_id']
        inputs = self.tokenizer(text, padding='max_length', max_length=self.max_seq_length,
                                return_tensors='pt', truncation=True)
        input_ids = inputs['input_ids'].squeeze()
        attention_mask = inputs['attention_mask'].squeeze()
        return input_ids, attention_mask, relation


class REModelWithAttention(nn.Module):
    """BERT encoder followed by dropout and a linear classification head."""

    def __init__(self, tokenizer, num_classes):
        super(REModelWithAttention, self).__init__()
        self.tokenizer = tokenizer
        self.bert = AutoModel.from_pretrained("bert-base-uncased")  # You can choose a different transformer model
        self.dropout = nn.Dropout(0.1)
        self.fc = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        pooled_output = self.dropout(pooled_output)
        logits = self.fc(pooled_output)
        return logits


def train(model, train_loader, valid_loader, criterion, optimizer, device, patience=5, num_epochs=20):
    """Train with per-epoch validation and early stopping on the validation loss."""
    best_loss = float('inf')
    current_patience = 0
    val_loss = []
    trn_loss = []
    for epoch in range(num_epochs):
        # Training
        model.train()
        total_loss = 0.0
        for input_ids, attention_mask, targets in train_loader:
            input_ids, attention_mask, targets = input_ids.to(device), attention_mask.to(device), targets.to(device)
            optimizer.zero_grad()
            logits = model(input_ids, attention_mask)
            loss = criterion(logits, targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_train_loss = total_loss / len(train_loader)

        # Evaluation
        model.eval()
        with torch.no_grad():
            total_loss = 0.0
            for input_ids, attention_mask, targets in valid_loader:
                input_ids, attention_mask, targets = input_ids.to(device), attention_mask.to(device), targets.to(device)
                logits = model(input_ids, attention_mask)
                loss = criterion(logits, targets)
                total_loss += loss.item()
            average_loss = total_loss / len(valid_loader)
        print(f'Epoch [{epoch+1}/{num_epochs}] - Training Loss: {avg_train_loss:.4f} - Validation Loss: {average_loss:.4f}')
        val_loss.append(average_loss)
        trn_loss.append(avg_train_loss)

        # Check for early stopping
        if average_loss < best_loss:
            best_loss = average_loss
            current_patience = 0
        else:
            current_patience += 1
            if current_patience >= patience:
                print(f'Early stopping after {epoch+1} epochs without improvement.')
                break
    return trn_loss, val_loss


def evaluate_model(model, data_loader, criterion, device):
    """Report loss, accuracy, and macro-averaged precision/recall/F1 on a data loader."""
    model.eval()
    total_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    all_predictions = []
    all_targets = []
    with torch.no_grad():
        for input_ids, attention_mask, targets in data_loader:
            input_ids, attention_mask, targets = input_ids.to(device), attention_mask.to(device), targets.to(device)
            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, targets)
            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct_predictions += (predicted == targets).sum().item()
            total_samples += targets.size(0)
            all_predictions.extend(predicted.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())
    average_loss = total_loss / len(data_loader)
    accuracy = accuracy_score(all_targets, all_predictions)
    precision = precision_score(all_targets, all_predictions, average='macro')
    recall = recall_score(all_targets, all_predictions, average='macro')
    f1 = f1_score(all_targets, all_predictions, average='macro')
    print(f"Loss: {average_loss:.4f}")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")


def sampleData(df, group_by, n_samples=21):
    """Build a class-balanced DataFrame by sampling a fixed number of rows per group."""
    balanced_data = []
    for relation, group in df.groupby(group_by):
        # Sample the same number of rows for each relation
        sampled_group = group.sample(n_samples)
        balanced_data.append(sampled_group)
    # Combine the sampled data into a new DataFrame
    new_df = pd.concat(balanced_data, ignore_index=True)
    return new_df
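

# A minimal, hypothetical end-to-end sketch of how these utilities fit together.
# It is not part of the original module: the toy DataFrame, column values, batch size,
# learning rate, and epoch count below are illustrative assumptions, and the same tiny
# loader is reused for training and validation only to keep the sketch self-contained.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    # Toy data with the columns the datasets expect: 'sentence', 'entities', 'relation_id'
    toy_df = pd.DataFrame({
        "sentence": ["Paris is the capital of France.", "Alan Turing was born in London."],
        "entities": [["Paris", "France"], ["Alan Turing", "London"]],
        "relation_id": [0, 1],
    })

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    dataset = REDataset_entities(toy_df, tokenizer, max_seq_length=64)
    loader = DataLoader(dataset, batch_size=2, shuffle=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = REModelWithAttention(tokenizer, num_classes=2).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

    # Reuse the toy loader as a stand-in validation set for this sketch only
    trn_loss, val_loss = train(model, loader, loader, criterion, optimizer, device, num_epochs=1)
    evaluate_model(model, loader, criterion, device)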