# The MIT License (MIT)
# © 2024 Chakana.tech
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
# the Software.
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
import os
import io
import sys
import uuid
import time
import fcntl
import torch
import uvloop
import hashlib
import asyncio
import logging
import tempfile
import aiofiles
import numpy as np
import aiobotocore
import bittensor as bt
import botocore.config
from typing import List, Dict
from dotenv import dotenv_values
from types import SimpleNamespace
from rich.logging import RichHandler
from filelock import FileLock, Timeout
from aiobotocore.session import get_session
from rich.highlighter import NullHighlighter
# Configure the standard logger to use Rich for console output
FORMAT = "%(message)s"
logging.basicConfig(
level=logging.INFO,
format=FORMAT,
datefmt="[%X]",
handlers=[
RichHandler(
markup=True,
rich_tracebacks=True,
highlighter=NullHighlighter(),
show_level=False,
show_time=False,
show_path=False
)
]
)
logger = logging.getLogger("rich")
logger.setLevel(logging.INFO)
# Register a TRACE level (below DEBUG) so the logger.trace(...) calls below work.
TRACE = 5
logging.addLevelName(TRACE, "TRACE")
logging.Logger.trace = lambda self, msg, *args: self._log(TRACE, msg, args) if self.isEnabledFor(TRACE) else None
def debug():
    logger.setLevel(logging.DEBUG)
def trace():
    logger.setLevel(TRACE)
# Log helper.
def T(): return time.time()
def P( w, d ): return f"[steel_blue]{w}[/steel_blue] ([grey63]{d:.2f}s[/grey63])"
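# Example: timing a block of work and logging it with the helpers above (a sketch;
# the tag string is illustrative):
#
#   start = T()
#   ...  # do some work
#   logger.info(f"Applied slices {P('apply', T() - start)}")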
# Load environment variables
env_config = {**dotenv_values(".env"), **os.environ}
AWS_ACCESS_KEY_ID = env_config.get('AWS_ACCESS_KEY_ID')
AWS_SECRET_ACCESS_KEY = env_config.get('AWS_SECRET_ACCESS_KEY')
# Configure the S3 client
client_config = botocore.config.Config(
max_pool_connections=256,
)
# Set uvloop as the event loop policy
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
# Define a semaphore to limit concurrent downloads (adjust as needed)
semaphore = asyncio.Semaphore(1000)
async def get_slices( filename:str, device:str ) -> Dict[str, torch.Tensor]:
    # Acquire the file lock (5 second timeout) so the slice is not read while another
    # process is still writing it, then load it onto the requested device.
    lock: FileLock = FileLock(f"{filename}.lock")
    with lock.acquire(timeout=5):
        return torch.load(
            filename,
            map_location=torch.device(device),
            weights_only=True,
        )
async def apply_slices_to_model(model: torch.nn.Module, window: int, seed: str, compression: int, key:str = 'slice') -> List[str]:
"""
Applies slices from a specific window to the given model.
Args:
model (torch.nn.Module): The PyTorch model to which the slices will be applied.
window (int): The window identifier.
seed (str): The seed used for generating indices.
compression (int): The compression factor.
Returns:
List[str]: A list of all the slice files that were applied.
"""
# First get the indices associated with the window given the model.
indices_dict = await get_indices_for_window(model, seed, compression)
# Load all the slices associated with this window.
slice_files = await load_files_for_window(window=window, key = key)
# Dictionary to keep track of the number of slices applied per parameter.
slices_per_param = {name: 0 for name, _ in model.named_parameters()}
# Dictionary to accumulate the sum of values for each parameter.
param_sums = {name: torch.zeros_like(param.data) for name, param in model.named_parameters()}
# Iterate over each slice file and compute the sum of values.
for file_i in slice_files:
        # get_slices acquires a file lock on the slice internally; skip the file if the lock times out.
try:
slice_i = await get_slices(file_i, model.device)
for name, param in model.named_parameters():
if name not in indices_dict or name not in slice_i:
continue
values = slice_i[name].to(param.data.device)
param_indices = indices_dict[name].to(param.data.device)
param_sums[name].view(-1)[param_indices] += values
slices_per_param[name] += 1
del values
del slice_i
except Timeout:
# The lock could not be acquired within the timeout.
logger.error(f"Timeout occurred while trying to acquire lock on {file_i}")
continue
except Exception as e:
logger.exception(f"Error applying slice from {file_i}: {e}")
# Apply the average to the parameters.
for name, param in model.named_parameters():
if name not in slices_per_param or name not in indices_dict or slices_per_param[name] == 0:
continue
param_indices = indices_dict[name].to(param.data.device)
avg_param = param_sums[name].view(-1)[param_indices] / slices_per_param[name]
avg_param = avg_param.to(param.data.dtype)
avg_param = avg_param.to(param.data.device)
param.data.view(-1)[param_indices] = avg_param.clone()
# Return the list of the files applied.
return slice_files
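# Usage sketch for apply_slices_to_model, assuming slice files for the window have
# already been downloaded into the temporary directory (see
# download_slices_for_buckets_and_windows below); the window, seed and compression
# values are illustrative:
#
#   applied_files = await apply_slices_to_model(model, window=42, seed="42", compression=100)
#   # `applied_files` lists the local slice files whose values were averaged into `model`.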
async def upload_slice_for_window(bucket: str, model: torch.nn.Module, window: int, seed: str, wallet: 'bt.wallet', compression: int, key:str = 'slice'):
"""
Uploads a compressed slice of a PyTorch model to an S3 bucket.
Args:
        bucket (str): Name of the S3 bucket.
        model (torch.nn.Module): The PyTorch model to be sliced and uploaded.
        window (int): The window identifier.
        seed (str): The seed used for generating indices.
        wallet (bt.wallet): The wallet object containing the hotkey.
        compression (int): The compression factor.
    """
filename = f'{key}-{window}-{wallet.hotkey.ss58_address}.pt'
logger.debug(f"Uploading slice to S3: {filename}")
model_state_dict = model.state_dict()
indices = await get_indices_for_window(model, seed, compression)
    # Extract the sliced values from the model parameters
for name, param in model.named_parameters():
model_state_dict[name] = param.data.view(-1)[indices[name].to(model.device)].cpu()
    # Create a temporary file and write the sliced model state dictionary to it
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
torch.save(model_state_dict, temp_file)
temp_file_name = temp_file.name # Store the temporary file name
# Upload the file to S3
session = get_session()
async with session.create_client(
's3',
region_name='us-east-1',
config=client_config,
aws_access_key_id=AWS_ACCESS_KEY_ID,
aws_secret_access_key=AWS_SECRET_ACCESS_KEY
) as s3_client:
try:
with open(temp_file_name, 'rb') as f:
await s3_client.put_object(Bucket=bucket, Key=filename, Body=f)
# Set the object ACL to public-read
await s3_client.put_object_acl(
Bucket=bucket,
Key=filename,
ACL='public-read'
)
logger.debug(f"Successfully uploaded slice to S3: {filename}")
except Exception:
logger.exception(f"Failed to upload slice {filename} to S3")
finally:
# Clean up the temporary file
os.remove(temp_file_name)
logger.debug(f"Temporary file {temp_file_name} removed")
async def upload_master(bucket: str, model: torch.nn.Module, wallet: 'bt.wallet'):
"""
Uploads the master PyTorch model to an S3 bucket.
Args:
bucket (str): Name of the S3 bucket.
model (torch.nn.Module): The PyTorch model to be uploaded.
wallet (bt.wallet): The wallet object containing the hotkey.
"""
upload_filename = f'master-{wallet.hotkey.ss58_address}.pt'
logger.debug(f"Uploading master model to S3: {upload_filename}")
session = get_session()
async with session.create_client(
's3',
region_name='us-east-1',
config=client_config,
aws_access_key_id=AWS_ACCESS_KEY_ID,
aws_secret_access_key=AWS_SECRET_ACCESS_KEY
) as s3_client:
try:
# Create a temporary file and write the model state dictionary to it
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
torch.save(model.state_dict(), temp_file)
temp_file_name = temp_file.name
# Upload the file to S3
with open(temp_file_name, 'rb') as f:
await s3_client.put_object(Bucket=bucket, Key=upload_filename, Body=f)
# Set the object ACL to public-read
await s3_client.put_object_acl(
Bucket=bucket,
Key=upload_filename,
ACL='public-read'
)
logger.debug(f"Successfully uploaded master model to S3: {upload_filename}")
except Exception:
logger.exception(f"Failed to upload master model {upload_filename} to S3")
finally:
# Clean up the temporary file
os.remove(temp_file_name)
logger.debug(f"Temporary file {temp_file_name} removed")
async def get_indices_for_window(model: torch.nn.Module, seed: str, compression: int) -> Dict[str, torch.LongTensor]:
"""
Computes the indices for the given window and compression factor.
Args:
model (torch.nn.Module): The PyTorch model.
seed (str): The window seed identifier.
compression (int): The compression factor.
Returns:
Dict[str, torch.LongTensor]: A dictionary mapping parameter names to index tensors.
"""
logger.debug(f"Computing indices for window seed {seed} with compression {compression}")
result = {}
# Seed the random number generator with the seed
seed = int(hashlib.md5(str(seed).encode('utf-8')).hexdigest(), 16) % (2**32)
rng = np.random.default_rng(seed)
for name, param in model.named_parameters():
# Randomly select indices based on the compression factor
num_indices = max(1, int(param.numel() // compression))
indices = rng.choice(param.numel(), size=num_indices, replace=False)
result[name] = torch.from_numpy(indices).long().cpu()
return result
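# Example: with compression=100, roughly 1% of each parameter's elements are selected.
# A sketch for a parameter with 1,000,000 elements (values illustrative):
#
#   indices = await get_indices_for_window(model, seed="42", compression=100)
#   # indices[name].numel() == max(1, 1_000_000 // 100) == 10_000 for that parameter,
#   # and the same seed yields the same indices on every node, since the RNG is
#   # seeded from a hash of the seed string.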
async def download_file(s3_client, bucket: str, filename: str) -> str:
"""
    Downloads a file from S3 to the temporary directory, streaming it in 1 MB chunks.
Args:
s3_client: The S3 client.
bucket (str): Name of the S3 bucket.
filename (str): The S3 object key (filename).
Returns:
str: The path to the downloaded file in the temporary directory.
"""
async with semaphore:
temp_file = os.path.join(tempfile.gettempdir(), filename)
# Check if the file exists.
if os.path.exists(temp_file):
logger.debug(f"File {temp_file} already exists, skipping download.")
return temp_file
lock_file = f"{temp_file}.lock"
lock = FileLock(lock_file)
try:
            # Try to acquire the file lock with a timeout
with lock.acquire(timeout=1):
# Proceed to download the file
logger.debug(f"Downloading file {filename} to {temp_file}")
                head_response = await s3_client.head_object(Bucket=bucket, Key=filename)
                object_size = head_response['ContentLength']
                logger.debug(f"Remote object {filename} is {object_size} bytes")
CHUNK_SIZE = 1 * 1024 * 1024 # 1 MB
response = await s3_client.get_object(Bucket=bucket, Key=filename)
async with aiofiles.open(temp_file, 'wb') as outfile:
while True:
chunk = await response['Body'].read(CHUNK_SIZE)
if not chunk:
break
await outfile.write(chunk)
logger.debug(f"Successfully downloaded file {filename} to {temp_file}")
return temp_file
except Timeout:
logger.error(f"Timeout occurred while trying to acquire lock on {lock_file}")
return None
except Exception as e:
logger.exception(f"Failed to download file {filename} from bucket {bucket}: {e}")
return None
finally:
# The lock is automatically released when exiting the 'with' block
pass
async def handle_file(s3_client, bucket: str, filename: str, hotkey: str, window: int):
"""
Handles downloading a single file from S3.
Args:
s3_client: The S3 client.
bucket (str): Name of the S3 bucket.
filename (str): The S3 object key (filename).
hotkey (str): The hotkey identifier.
window (int): The window identifier.
Returns:
SimpleNamespace: An object containing file metadata and the path to the downloaded file.
"""
logger.debug(f"Handling file {filename} for window {window} and hotkey {hotkey}")
temp_file = await download_file(s3_client, bucket, filename)
if temp_file:
return SimpleNamespace(bucket=bucket, hotkey=hotkey, filename=filename, window=window, temp_file=temp_file)
return None
async def process_bucket(s3_client, bucket: str, windows: List[int], key:str = 'slice'):
"""
Processes an S3 bucket to download files matching the given windows.
Args:
s3_client: The S3 client.
bucket (str): Name of the S3 bucket.
windows (List[int]): A list of window identifiers.
Returns:
List[SimpleNamespace]: A list of file metadata and paths for downloaded files.
"""
logger.debug(f"Processing bucket {bucket} for window {windows}")
files = []
paginator = s3_client.get_paginator('list_objects_v2')
for window in windows:
prefix = f'{key}-{window}'
logger.debug(f"Listing objects with prefix {prefix}")
async for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
logger.trace(f"Processing page for prefix {prefix}")
if 'Contents' not in page:
logger.trace(f"No contents found for prefix {prefix}")
continue
download_tasks = []
for obj in page.get('Contents', []):
filename = obj['Key']
logger.trace(f"Processing object with key {filename}")
try:
parts = filename.split('-')
slice_window = int(parts[1])
slice_hotkey = parts[2].split('.')[0]
logger.trace(f"Parsed filename {filename} into window {slice_window} and hotkey {slice_hotkey}")
if slice_window == window:
download_tasks.append(handle_file(s3_client, bucket, filename, slice_hotkey, slice_window))
except Exception:
logger.exception(f"Error processing filename {filename}")
continue
# Download the files concurrently
results = await asyncio.gather(*download_tasks)
files.extend([res for res in results if res])
logger.trace(f"Completed processing page for prefix {prefix}")
logger.trace(f"Completed processing bucket {bucket} for windows {windows}")
return files
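# The object keys parsed above follow the naming used by upload_slice_for_window,
# f"{key}-{window}-{hotkey}.pt". For example (hypothetical hotkey):
#
#   "slice-42-5F3sa2TJAWMqDhXG6jhV4N8ko9SxwGy8TpaNS1repo5EYjQX.pt"
#   # -> slice_window == 42, slice_hotkey == "5F3sa2TJ...", matched by the prefix "slice-42".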
async def download_slices_for_buckets_and_windows(buckets: List[str], windows: List[int], key:str = 'slice') -> Dict[int, List[SimpleNamespace]]:
"""
Downloads files from multiple S3 buckets for the given windows.
Args:
buckets (List[str]): A list of S3 bucket names.
windows (List[int]): A list of window identifiers.
Returns:
Dict[int, List[SimpleNamespace]]: A dictionary mapping windows to lists of file metadata and paths.
"""
logger.debug(f"Downloading files for buckets {set(buckets)} and windows {windows}")
session = get_session()
async with session.create_client(
's3',
region_name='us-east-1',
config=client_config,
aws_access_key_id=AWS_ACCESS_KEY_ID,
aws_secret_access_key=AWS_SECRET_ACCESS_KEY
) as s3_client:
tasks = []
for bucket in set(buckets):
if not bucket:
continue
tasks.append(process_bucket(s3_client, bucket, windows, key))
results = await asyncio.gather(*tasks)
# Flatten the list of lists
files = [item for sublist in results for item in sublist]
# Create a dictionary with windows as keys and list of files as values
windows_dict = {}
for file in files:
window = file.window
if window not in windows_dict:
windows_dict[window] = []
windows_dict[window].append(file)
logger.debug(f"Downloaded all files grouped by windows: {windows}")
return windows_dict
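# End-to-end sketch: download the slices published for a set of windows and fold them
# into the local model (bucket names, window numbers and compression are illustrative):
#
#   windows_dict = await download_slices_for_buckets_and_windows(buckets, windows=[41, 42])
#   for window in sorted(windows_dict):
#       await apply_slices_to_model(model, window=window, seed=str(window), compression=100)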
async def load_files_for_window(window: int, key: str = 'slice') -> List[str]:
"""
Retrieves the paths to downloaded window files from the temporary directory.
Args:
window (int): The window identifier.
Returns:
List[str]: A list of file paths corresponding to the window.
"""
logger.debug(f"Retrieving files for window {window} from temporary directory")
temp_dir = tempfile.gettempdir()
window_files = []
for filename in os.listdir(temp_dir):
if filename.startswith(f"{key}-{window}-") and filename.endswith(".pt"):
window_files.append(os.path.join(temp_dir, filename))
logger.debug(f"Found file {filename} for window {window}")
return window_files
async def delete_files_before_window(window_max: int, key:str = 'slice'):
"""
Deletes all files on the local machine which have a window id before a specific value window_max.
Args:
window_max (int): The maximum window id. Files with window ids less than this value will be deleted.
"""
logger.debug(f"Deleting files with window id before {window_max}")
temp_dir = tempfile.gettempdir()
for filename in os.listdir(temp_dir):
if filename.startswith(f"{key}-") and ( filename.endswith(".pt") or filename.endswith(".lock") ):
try:
parts = filename.split('-')
window_id = int(parts[1])
                if window_id < window_max:
                    # os.listdir returns bare names; join with the temp dir before deleting.
                    file_path = os.path.join(temp_dir, filename)
                    if os.path.exists(file_path):
                        os.remove(file_path)
                        logger.debug(f"Deleted file {file_path}")
except Exception as e:
logger.error(f"Error deleting file {filename}: {e}")
async def delete_files_from_bucket_before_window(bucket: str, window_max: int, key: str = 'slice'):
"""
Deletes all files in the specified S3 bucket which have a window id before a specific value window_max.
Args:
bucket (str): The name of the S3 bucket.
window_max (int): The maximum window id. Files with window ids less than this value will be deleted.
"""
logger.debug(f"Deleting files in bucket {bucket} with window id before {window_max}")
session = get_session()
async with session.create_client(
's3',
region_name='us-east-1',
config=client_config,
aws_access_key_id=AWS_ACCESS_KEY_ID,
aws_secret_access_key=AWS_SECRET_ACCESS_KEY
) as s3_client:
try:
response = await s3_client.list_objects_v2(Bucket=bucket)
if 'Contents' in response:
for obj in response['Contents']:
filename = obj['Key']
if filename.startswith(f"{key}-") and filename.endswith(".pt"):
try:
parts = filename.split('-')
window_id = int(parts[1])
if window_id < window_max:
await s3_client.delete_object(Bucket=bucket, Key=filename)
logger.debug(f"Deleted file {filename} from bucket {bucket}")
except Exception as e:
logger.error(f"Error deleting file {filename} from bucket {bucket}: {e}")
except Exception as e:
logger.error(f"Error listing objects in bucket {bucket}: {e}")