forked from fofr/cog-stickers
-
Notifications
You must be signed in to change notification settings - Fork 0
/
weights_manifest.py
91 lines (76 loc) · 2.97 KB
/
weights_manifest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import subprocess
import time
import os
import json
from helpers.ComfyUI_BRIA_AI_RMBG import ComfyUI_BRIA_AI_RMBG
# Remote root of the weight tarballs and local root they are placed under.
BASE_URL = "https://weights.replicate.delivery/default/comfy-ui"
BASE_PATH = "ComfyUI/models"

# Local manifest shipped with the project, and the on-disk copy of the
# remote manifest that gets merged into it.
WEIGHTS_MANIFEST_PATH = "weights.json"
UPDATED_WEIGHTS_MANIFEST_PATH = "updated_weights.json"

# The timestamp is computed once, when this module is imported; presumably
# the `cache_bypass` query parameter exists to defeat HTTP/CDN caching.
UPDATED_WEIGHTS_MANIFEST_URL = (
    f"{BASE_URL}/weights.json?cache_bypass={int(time.time())}"
)
class WeightsManifest:
    """Determine which model weights are allowed and where each one comes from.

    On construction this:
      1. downloads the remote manifest to UPDATED_WEIGHTS_MANIFEST_PATH with
         ``pget`` (skipped if that file already exists),
      2. merges it into the local manifest at WEIGHTS_MANIFEST_PATH, and
      3. flattens the result into ``weights_map``, a dict mapping each weight
         name to ``{"url": <tarball URL>, "dest": <destination directory>}``.
    """

    def __init__(self):
        # Merged manifest dict (see _merge_manifests).
        self.weights_manifest = self._load_weights_manifest()
        # weight name -> {"url": ..., "dest": ...}
        self.weights_map = self._initialize_weights_map()

    def _load_weights_manifest(self):
        """Ensure the updated manifest is on disk, then return the merged manifest."""
        self._download_updated_weights_manifest()
        return self._merge_manifests()

    def _download_updated_weights_manifest(self):
        """Fetch the remote manifest with ``pget`` unless a local copy already exists.

        An existing file at UPDATED_WEIGHTS_MANIFEST_PATH is reused as-is; it
        is never re-downloaded by this method.

        Raises:
            subprocess.CalledProcessError: if ``pget`` exits with a non-zero status.
        """
        if os.path.exists(UPDATED_WEIGHTS_MANIFEST_PATH):
            print("Updated weights manifest file already exists")
            return

        print(
            f"Downloading updated weights manifest from {UPDATED_WEIGHTS_MANIFEST_URL}"
        )
        start = time.time()
        subprocess.check_call(
            [
                "pget",
                "--log-level",
                "warn",
                "-f",
                UPDATED_WEIGHTS_MANIFEST_URL,
                UPDATED_WEIGHTS_MANIFEST_PATH,
            ],
            close_fds=False,
        )
        print(
            f"Downloading {UPDATED_WEIGHTS_MANIFEST_URL} took: {(time.time() - start):.2f}s"
        )

    def _merge_manifests(self):
        """Return the local manifest extended with entries from the updated one.

        - A missing local manifest file is treated as an empty dict.
        - For a key present in both manifests, each item of the updated list
          that is not already in the local list is appended (and logged);
          existing order is preserved.
        - A key present only in the updated manifest is copied over as-is.

        Returns:
            dict: the merged manifest.
        """
        if os.path.exists(WEIGHTS_MANIFEST_PATH):
            # JSON text is UTF-8 by spec; don't depend on the locale encoding.
            with open(WEIGHTS_MANIFEST_PATH, "r", encoding="utf-8") as f:
                original_manifest = json.load(f)
        else:
            original_manifest = {}

        with open(UPDATED_WEIGHTS_MANIFEST_PATH, "r", encoding="utf-8") as f:
            updated_manifest = json.load(f)

        for key, updated_items in updated_manifest.items():
            if key not in original_manifest:
                original_manifest[key] = updated_items
                continue
            # assumes values for shared keys are lists (they are appended to)
            existing = original_manifest[key]
            for item in updated_items:
                # Linear list membership (not a set) keeps this working even for
                # unhashable items; presumably the lists are small — TODO confirm.
                if item not in existing:
                    print(f"Adding {item} to {key}")
                    existing.append(item)
        return original_manifest

    def _generate_weights_map(self, keys, dest):
        """Map each weight name in ``keys`` to its download URL and destination.

        ``dest`` is used as the subdirectory both under BASE_URL (giving the
        ``<name>.tar`` URL) and under BASE_PATH (giving the destination dir).
        """
        return {
            key: {
                "url": f"{BASE_URL}/{dest}/{key}.tar",
                "dest": f"{BASE_PATH}/{dest}",
            }
            for key in keys
        }

    def _initialize_weights_map(self):
        """Build and log the flat weight-name -> {"url", "dest"} map.

        Only upper-case top-level manifest keys are treated as weight
        categories; the lower-cased key names the subdirectory. Entries from
        ``ComfyUI_BRIA_AI_RMBG.weights_map`` are merged last, so they override
        any same-named entry coming from the manifest.
        """
        weights_map = {}
        for key, names in self.weights_manifest.items():
            # Keys that are not all upper-case are skipped here; their purpose
            # isn't visible in this file.
            if key.isupper():
                weights_map.update(self._generate_weights_map(names, key.lower()))
        weights_map.update(ComfyUI_BRIA_AI_RMBG.weights_map(BASE_URL))

        print("Allowed weights:")
        for weight in weights_map:
            print(weight)
        return weights_map