enable gpu usage (#208)
christinaexyou authored May 14, 2024
1 parent 1b8159d commit 8516461
Showing 1 changed file with 45 additions and 11 deletions.
56 changes: 45 additions & 11 deletions src/trustyai/language/detoxify/tmarco.py
@@ -51,6 +51,7 @@ def __init__(
         tokenizer=None,
         max_length=150,
         model_type: str = "causal_lm",
+        device=None,
     ):
         if expert_weights is None:
             expert_weights = [-0.5, 0.5]
@@ -94,6 +95,13 @@ def __init__(
             )
         self.content_feature = "comment_text"
 
+        if isinstance(device, str):
+            self.device = torch.device(device)
+        else:
+            self.device = torch.device(
+                "cuda" if torch.cuda.is_available() else "cpu"
+            )
+
     def load_models(self, experts: list, expert_weights: list = None):
         """Load expert models."""
         if expert_weights is not None:
@@ -102,7 +110,9 @@ def load_models(self, experts: list, expert_weights: list = None):
         for expert in experts:
             if isinstance(expert, str):
                 expert = BartForConditionalGeneration.from_pretrained(
-                    expert, forced_bos_token_id=self.tokenizer.bos_token_id
+                    expert,
+                    forced_bos_token_id=self.tokenizer.bos_token_id,
+                    device_map="auto",
                 )
             expert_models.append(expert)
         self.experts = expert_models
@@ -200,15 +210,21 @@ def train_models(
 
         if model_type is None:
             gminus = BartForConditionalGeneration.from_pretrained(
-                base_model, forced_bos_token_id=self.tokenizer.bos_token_id
+                base_model,
+                forced_bos_token_id=self.tokenizer.bos_token_id,
+                device_map="auto",
             )
         elif model_type == "causal_lm":
             gminus = AutoModelForCausalLM.from_pretrained(
-                base_model, forced_bos_token_id=self.tokenizer.bos_token_id
+                base_model,
+                forced_bos_token_id=self.tokenizer.bos_token_id,
+                device_map="auto",
             )
         elif model_type == "seq2seq_lm":
             gminus = AutoModelForSeq2SeqLM.from_pretrained(
-                base_model, forced_bos_token_id=self.tokenizer.bos_token_id
+                base_model,
+                forced_bos_token_id=self.tokenizer.bos_token_id,
+                device_map="auto",
             )
         else:
             raise Exception(f"unsupported model type {model_type}")
@@ -254,15 +270,21 @@ def train_models(
 
         if model_type is None:
             gplus = BartForConditionalGeneration.from_pretrained(
-                base_model, forced_bos_token_id=self.tokenizer.bos_token_id
+                base_model,
+                forced_bos_token_id=self.tokenizer.bos_token_id,
+                device_map="auto",
             )
         elif model_type == "causal_lm":
             gplus = AutoModelForCausalLM.from_pretrained(
-                base_model, forced_bos_token_id=self.tokenizer.bos_token_id
+                base_model,
+                forced_bos_token_id=self.tokenizer.bos_token_id,
+                device_map="auto",
             )
         elif model_type == "seq2seq_lm":
             gplus = AutoModelForSeq2SeqLM.from_pretrained(
-                base_model, forced_bos_token_id=self.tokenizer.bos_token_id
+                base_model,
+                forced_bos_token_id=self.tokenizer.bos_token_id,
+                device_map="auto",
             )
         else:
             raise Exception(f"unsupported model type {model_type}")
@@ -380,6 +402,7 @@ def rephrase(
                     model=expert,
                     tokenizer=self.tokenizer,
                     top_k=self.tokenizer.vocab_size,
+                    device=self.device,
                 )
             )
         for idx in range(len(masked_sentence_tokens)):
@@ -477,9 +500,10 @@ def compute_mask_logits(
         self, model, sequence, verbose: bool = False, mask: bool = True
     ):
         """Compute mask logits."""
+        model.to(self.device)
         if verbose:
             print(f"input sequence: {sequence}")
-        subseq_ids = self.tokenizer(sequence, return_tensors="pt")
+        subseq_ids = self.tokenizer(sequence, return_tensors="pt").to(self.device)
         if verbose:
             raw_outputs = model.generate(**subseq_ids)
             print(sequence)
@@ -502,9 +526,12 @@ def compute_mask_logits_multiple(
         self, model, sequences, verbose: bool = False, mask: bool = True
     ):
         """Compute mask logits multiple."""
+        model.to(self.device)
         if verbose:
             print(f"input sequences: {sequences}")
-        subseq_ids = self.tokenizer(sequences, return_tensors="pt", padding=True)
+        subseq_ids = self.tokenizer(
+            sequences, return_tensors="pt", padding=True
+        ).to(self.device)
         if verbose:
             raw_outputs = model.generate(**subseq_ids)
             print(sequences)
@@ -554,6 +581,7 @@ def score(
                 model=model,
                 tokenizer=self.tokenizer,
                 top_k=10,
+                device=self.device,
             )
             for masked_sentence in masked_sentences:
                 # approximated probabilities for top_k tokens
@@ -567,7 +595,9 @@
         js_distances = []
         for distr_pair in distr_pairs:
             js_distance = jensenshannon(
-                distr_pair[0], distr_pair[1], axis=1
+                distr_pair[0].cpu().clone().numpy(),
+                distr_pair[1].cpu().clone().numpy(),
+                axis=1,
             )
             if normalize:
                 js_distance = js_distance / np.average(js_distance)
@@ -653,7 +683,10 @@ def reflect(
             chat_tokenizer.chat_template = chat_template
 
         converse_pipeline = pipeline(
-            "conversational", model=chat_model, tokenizer=chat_tokenizer
+            "conversational",
+            model=chat_model,
+            tokenizer=chat_tokenizer,
+            device=self.device,
        )
 
         for text_id in range(len(texts)):
@@ -729,6 +762,7 @@ def reflect(
             conversation_output = converse_pipeline(
                 formatted_messages,
                 pad_token_id=converse_pipeline.tokenizer.eos_token_id,
+                device=self.device,
             )
             if verbose:
                 print(f"chat conversation:\n{conversation_output}")
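The new device plumbing is driven entirely from the constructor. A minimal usage sketch, assuming the class defined in tmarco.py is TMaRCo and that it is importable from the package path implied by the file location (both are assumptions, not confirmed by this page):

    # Hypothetical usage of the device argument added in this commit.
    # Import path is inferred from src/trustyai/language/detoxify/tmarco.py.
    from trustyai.language.detoxify.tmarco import TMaRCo

    # Pin to a specific device; the string is wrapped in torch.device(...)
    marco = TMaRCo(device="cuda:0")

    # Omit device to fall back to "cuda" if torch.cuda.is_available(), else "cpu"
    marco_auto = TMaRCo()

Note the split responsibility in the diff: models loaded with device_map="auto" are placed by Hugging Face's automatic dispatch, while the pipelines, tokenized inputs, and the Jensen-Shannon inputs (moved via .cpu() before conversion to NumPy) are positioned explicitly against self.device.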
