# multi-gpu-llama2.py
import torch
import torch.distributed as dist
from transformers import LlamaTokenizer, LlamaForCausalLM
from datasets import load_dataset
import os
import json
from tqdm import tqdm
import argparse
def setup_distributed():
    """
    Initialize the distributed environment.
    """
    try:
        rank = int(os.environ['RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        local_rank = int(os.environ['LOCAL_RANK'])
        dist_backend = 'nccl'
        # Initialize the process group
        dist.init_process_group(backend=dist_backend, rank=rank, world_size=world_size)
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")
        print(f"Rank {rank}: Distributed process group initialized on device {device}.")
        return rank, world_size, local_rank, device
    except KeyError as e:
        print(f"Rank {rank if 'rank' in locals() else 'Unknown'}: Environment variable {e} not set.")
        raise
    except Exception as e:
        print(f"Rank {rank if 'rank' in locals() else 'Unknown'}: Failed to initialize distributed environment - {str(e)}")
        raise
def cleanup_distributed():
    """
    Clean up the distributed environment.
    """
    try:
        dist.destroy_process_group()
        print("Distributed process group destroyed.")
    except Exception as e:
        print(f"Error during distributed cleanup: {str(e)}")
def shard_dataset(dataset, rank, world_size):
    """
    Shard the dataset for each process.
    """
    try:
        total_samples = len(dataset)
        per_gpu = total_samples // world_size
        start = rank * per_gpu
        # Ensure the last GPU gets any remaining samples
        end = start + per_gpu if rank != world_size - 1 else total_samples
        dataset_shard = dataset.select(range(start, end))
        print(f"Rank {rank}: Sharded samples from index {start} to {end-1}. Shard size: {len(dataset_shard)}")
        return dataset_shard
    except Exception as e:
        print(f"Rank {rank}: Failed to shard dataset - {str(e)}")
        raise
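# Worked example of the sharding above (illustrative numbers only): with 102 samples
# and world_size = 4, per_gpu = 25, so ranks 0-2 take indices 0-24, 25-49, 50-74,
# and the last rank takes 75-101 (27 samples, absorbing the remainder).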
def batch_loader(dataset, batch_size):
    """
    Generator to yield batches from the dataset.
    Each batch is a list of dictionaries representing individual samples.
    """
    for i in range(0, len(dataset), batch_size):
        # Select the batch range from the dataset
        batch = dataset[i : i + batch_size]
        # Convert the batch to a list of dictionaries directly from the dataset
        # Assuming 'selected_caption' is a field in the dataset
        batch_dicts = [{key: batch[key][j] for key in batch.keys()} for j in range(len(batch['selected_caption']))]
        indices = list(range(i, min(i + batch_size, len(dataset))))
        yield batch_dicts, indices
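# Note: slicing a Hugging Face Dataset returns a dict of column lists, e.g.
# {'selected_caption': ['a', 'b'], 'image_id': [1, 2]}; the comprehension above
# transposes that into row dicts: [{'selected_caption': 'a', 'image_id': 1}, ...].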
def inference(rank, device, model_name, dataset_shard, batch_size, output_dir):
    """
    Inference function for each process.
    """
    try:
        # Load model and tokenizer
        model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
        model.to(device)
        model.eval()
        print(f"Rank {rank}: Model loaded on {device}")
        tokenizer = LlamaTokenizer.from_pretrained(model_name)
        print(f"Rank {rank}: Tokenizer loaded")
        # Size of this rank's shard (sharding itself happens in main)
        shard_size = len(dataset_shard)
        print(f"Rank {rank}: Dataset sharded, size: {shard_size}")
        # Sanity check: print the first sample
        if shard_size > 0:
            first_sample = dataset_shard[0]
            print(f"Rank {rank}: First sample keys: {first_sample.keys()}")
            print(f"Rank {rank}: First sample: {first_sample}")
        results = []
        # Calculate the number of batches
        batch_count = (shard_size + batch_size - 1) // batch_size
        for batch_idx, (batch, indices) in enumerate(tqdm(batch_loader(dataset_shard, batch_size),
                                                          total=batch_count,
                                                          desc=f"Rank {rank} Processing")):
            print(f"Rank {rank}: Processing batch {batch_idx + 1}/{batch_count}")
            for sample, idx in zip(batch, indices):
                try:
                    selected_caption = sample['selected_caption']
                except KeyError:
                    print(f"Rank {rank}: 'selected_caption' not found in sample {idx}. Skipping.")
                    continue
                input_text = f"The following caption seems weird: '{selected_caption}'. Explain why it feels unusual."
                inputs = tokenizer(input_text, return_tensors="pt").to(device)
                with torch.no_grad():
                    # Pass both input_ids and attention_mask to generate
                    outputs = model.generate(**inputs, max_new_tokens=50)
                generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
                # Create a unique_id by combining rank and original id to ensure uniqueness
                unique_id = f"{rank}_{sample['image_id']}"
                results.append({
                    "unique_id": unique_id,
                    "input_prompt": input_text,
                    "caption": selected_caption,
                    "generated_response": generated_text
                })
        # Save results to a JSON file specific to this rank as a list
        output_file = os.path.join(output_dir, f"results_rank_{rank}.json")
        with open(output_file, 'w') as f:
            json.dump(results, f, indent=4)
        print(f"Rank {rank}: Results written to {output_file}. Number of entries: {len(results)}")
    except Exception as e:
        print(f"Rank {rank}: Error during inference - {str(e)}")
        raise
def main(args):
    """
    Main function to perform distributed inference.
    """
    # Initialize distributed environment
    try:
        rank, world_size, local_rank, device = setup_distributed()
    except Exception as e:
        print(f"Rank {rank if 'rank' in locals() else 'Unknown'}: Initialization failed - {str(e)}")
        return
    # Each rank loads the dataset independently
    dataset_name = 'nlphuji/whoops'
    try:
        dataset = load_dataset(dataset_name)['test']
        print(f"Rank {rank}: Dataset loaded, size: {len(dataset)}")
    except Exception as e:
        print(f"Rank {rank}: Failed to load dataset - {str(e)}")
        cleanup_distributed()
        return
    # Shard the dataset
    try:
        dataset_shard = shard_dataset(dataset, rank, world_size)
        print(f"Rank {rank}: Dataset sharded, size: {len(dataset_shard)}")
    except Exception as e:
        print(f"Rank {rank}: Failed to shard dataset - {str(e)}")
        cleanup_distributed()
        return
    # Define model and batch size
    # model_name = "meta-llama/Llama-2-7b-hf"
    model_name = args.model_name
    # batch_size = 2
    batch_size = args.batch_size
    # Create output directory
    # output_dir = 'results'
    output_dir = args.output_dir
    try:
        if rank == 0:
            os.makedirs(output_dir, exist_ok=True)
            print(f"Rank {rank}: Created output directory '{output_dir}'.")
    except Exception as e:
        print(f"Rank {rank}: Failed to create output directory - {str(e)}")
        cleanup_distributed()
        return
    # Ensure all ranks wait until directory is created
    try:
        dist.barrier()
        print(f"Rank {rank}: Passed the output directory barrier.")
    except Exception as e:
        print(f"Rank {rank}: Barrier synchronization failed - {str(e)}")
        cleanup_distributed()
        return
    # Perform inference
    try:
        inference(rank, device, model_name, dataset_shard, batch_size, output_dir)
    except Exception as e:
        print(f"Rank {rank}: Inference failed - {str(e)}")
        cleanup_distributed()
        return
    # Ensure all processes have finished inference before aggregating
    try:
        dist.barrier()
        print(f"Rank {rank}: Passed the inference barrier.")
    except Exception as e:
        print(f"Rank {rank}: Barrier synchronization failed - {str(e)}")
        cleanup_distributed()
        return
    # Only rank 0 will aggregate results
    if rank == 0:
        combined_results = []
        for r in range(world_size):
            output_file = os.path.join(output_dir, f"results_rank_{r}.json")
            if os.path.exists(output_file):
                try:
                    with open(output_file, 'r') as f:
                        rank_results = json.load(f)
                    combined_results.extend(rank_results)
                    print(f"Rank {rank}: Loaded results from Rank {r}. Number of entries: {len(rank_results)}")
                except Exception as e:
                    print(f"Rank {rank}: Failed to load results from Rank {r} - {str(e)}")
            else:
                print(f"Rank {rank}: Result file {output_file} not found.")
        # Write the combined results to a single JSON file
        # final_output_file = 'results_combined.json'
        final_output_file = args.final_output_file
        try:
            with open(final_output_file, 'w') as f:
                json.dump(combined_results, f, indent=4)
            print(f"Rank {rank}: Combined results written to {final_output_file}. Total entries: {len(combined_results)}")
        except Exception as e:
            print(f"Rank {rank}: Failed to write combined results - {str(e)}")
    # Cleanup distributed environment after all operations
    cleanup_distributed()
if __name__ == "__main__":
    # Optional: Enable NCCL debugging for troubleshooting
    os.environ['NCCL_DEBUG'] = 'INFO'
    os.environ['NCCL_DEBUG_SUBSYS'] = 'ALL'
    os.environ['PYTHONWARNINGS'] = 'ignore:semaphore_tracker'
    # Take arguments from command line
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default="results", help="Output directory for results")
    parser.add_argument("--batch_size", type=int, default=2, help="Batch size for inference")
    parser.add_argument("--final_output_file", type=str, default="results_combined.json", help="Final output file for combined results")
    parser.add_argument("--model_name", type=str, default="meta-llama/Llama-2-7b-hf", help="Model name for inference")
    args = parser.parse_args()
    main(args)
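# Example launch (a sketch; adjust --nproc_per_node to the number of GPUs on the node):
#   torchrun --nproc_per_node=4 multi-gpu-llama2.py \
#       --model_name meta-llama/Llama-2-7b-hf --batch_size 2 --output_dir results
# torchrun sets the RANK, WORLD_SIZE, and LOCAL_RANK environment variables that
# setup_distributed() reads, so the script is not meant to be started with plain `python`.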