Skip to content

Commit

Permalink
Fixing minor bugs in seedemb training
Browse files Browse the repository at this point in the history
  • Loading branch information
svkeerthy committed Oct 2, 2024
1 parent 6bd3ddc commit 0615c8e
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 73 deletions.
45 changes: 31 additions & 14 deletions seed_embeddings/OpenKE/config/Trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(
# self.out_path = out_path

self.entity_names = self.load_entity_names(index_dir)
self.analogies = analogy.AnalogyScorer(analogy_file="analogies.txt")
self.analogies = analogy.AnalogyScorer(analogy_file="/lfs1/usrscratch/staff/nvk1tb/IR2Vec/seed_embeddings/OpenKE/analogies.txt")

def load_entity_names(self, index_dir):
with open(os.path.join(index_dir, "entity2id.txt")) as fEntity:
Expand Down Expand Up @@ -94,7 +94,7 @@ def getEntityDict(self, ent_embeddings):
"""
entity_dict = {}

for i, entity_name in enumerate(self.entity_dict):
for i, entity_name in enumerate(self.entity_names):
entity_dict[entity_name] = ent_embeddings[i].tolist()

return entity_dict
Expand Down Expand Up @@ -139,7 +139,7 @@ def run(
weight_decay=self.weight_decay,
)
print("Finish initializing...")

best_metric_val = 0.0
training_range = tqdm(range(self.train_times))
for epoch in training_range:
res = 0.0
Expand All @@ -148,6 +148,7 @@ def run(
res += loss
training_range.set_description("Epoch %d | loss: %f" % (epoch, res))
checkpoint = None
save_ckpt = False
if ray and epoch % freq == 0:
metrics = {"loss": res}
# Link Prediction
Expand All @@ -170,27 +171,43 @@ def run(
"hit1": hit1,
}
)
if best_metric_val <= hit1:
best_metric_val = hit1
save_ckpt = True
print("Link Prediction Scores Completed")

if is_analogy:
elif is_analogy:
# self.model => Negative Sampling object
# self.mode.model => Transe model

ent_embeddings = self.model.model.ent_embeddings.weight.data.numpy()
ent_embeddings = self.model.model.ent_embeddings.weight.data.cpu().numpy()
entity_dict = self.getEntityDict(ent_embeddings)
analogy_score = self.analogies.get_analogy_score(entity_dict)
metrics.update({"AnalogiesScore": analogy_score})
print("Analogy Score Completed")

print("Analogy Score completed")

del entity_dict

if best_metric_val <= analogy_score:
best_metric_val = analogy_score
save_ckpt = True

else: # loss
if best_metric_val >= res:
best_metric_val = res
save_ckpt = True

with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
# Save the checkpoint...
self.model.save_checkpoint(
os.path.join(
temp_checkpoint_dir,
"checkpoint" + "-" + str(epoch) + ".ckpt",
# Save the checkpoint...
checkpoint = None
if save_ckpt:
self.model.save_checkpoint(
os.path.join(
temp_checkpoint_dir,
"checkpoint" + "-" + str(epoch) + ".ckpt",
)
)
)
checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)

train.report(metrics, checkpoint=checkpoint)

Expand Down
95 changes: 36 additions & 59 deletions seed_embeddings/OpenKE/generate_embedding_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.optuna import OptunaSearch

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"


def test_files(index_dir):
Expand Down Expand Up @@ -199,7 +199,7 @@ def findRep(src, dest, index_dir, src_type="json"):

search_space = {
"epoch": arg_conf.epoch,
"dim": tune.sample_from(lambda spec: 100 * np.random.randint(1, 6)),
"dim": arg_conf.dim,
"index_dir": arg_conf.index_dir,
"nbatches": tune.sample_from(lambda spec: 2 ** np.random.randint(8, 12)),
"margin": tune.quniform(3, 6, 0.5),
Expand All @@ -222,38 +222,28 @@ def findRep(src, dest, index_dir, src_type="json"):
raise Exception("Error in files")

if arg_conf.is_analogy:
scheduler = ASHAScheduler(
time_attr="training_iteration",
max_t=arg_conf.epoch,
grace_period=15,
reduction_factor=3,
metric="AnalogiesScore",
mode="max",
)
optuna = OptunaSearch(metric="AnalogiesScore", mode="max")
metric = "AnalogiesScore"
mode = "max"
elif arg_conf.link_pred:
scheduler = ASHAScheduler(
time_attr="training_iteration",
max_t=arg_conf.epoch,
grace_period=15,
reduction_factor=3,
metric="hit1",
mode="max",
)
optuna = OptunaSearch(metric="hit1", mode="max")
metric = "hit1"
mode = "max"
else:
scheduler = ASHAScheduler(
time_attr="training_iteration",
max_t=arg_conf.epoch,
grace_period=15,
reduction_factor=3,
metric="loss",
mode="min",
)
optuna = OptunaSearch(metric="loss", mode="min")
metric = "loss"
mode = "min"

scheduler = ASHAScheduler(
time_attr="training_iteration",
max_t=arg_conf.epoch,
grace_period=15,
reduction_factor=3,
metric=metric,
mode=mode,
)
optuna = OptunaSearch(metric=metric, mode=mode)

if arg_conf.use_gpu:
train_with_resources = tune.with_resources(
train, resources={"cpu": 0, "gpu": 0.0625}
train, resources={"cpu": 8, "gpu": 0.15}
)
else:
train_with_resources = tune.with_resources(
Expand All @@ -265,60 +255,47 @@ def findRep(src, dest, index_dir, src_type="json"):
param_space=search_space,
tune_config=TuneConfig(
search_alg=optuna,
max_concurrent_trials=16,
max_concurrent_trials=12,
scheduler=scheduler,
num_samples=512,
),
run_config=RunConfig(
storage_path="/lfs1/usrscratch/staff/nvk1tb/ray_results/",
checkpoint_config=CheckpointConfig(
num_to_keep=1,
# *Best* checkpoints are determined by these params:
checkpoint_score_attribute="loss",
checkpoint_score_order="min",
checkpoint_score_attribute=metric,
checkpoint_score_order=mode,
)
),
)
results = tuner.fit()

# Write the best result to a file, best_result.txt
best_result = None
if arg_conf.is_analogy:

with open(os.path.join(search_space["index_dir"], "best_result.txt"), "a") as f:
f.write(
"\n" + str(results.get_best_result(metric="AnalogiesScore", mode="max"))
)
fin_res = results.get_best_result(metric=metric, mode=mode)
with open(os.path.join(search_space["index_dir"], "best_result.txt"), "a") as f:
f.write(
"\n" + str(fin_res)
)

if arg_conf.is_analogy:
print(
"Best Config Based on Analogy Score : ",
results.get_best_result(metric="AnalogiesScore", mode="max"),
fin_res,
)

best_result = results.get_best_result(metric="AnalogiesScore", mode="max")

elif arg_conf.link_pred:

with open(os.path.join(search_space["index_dir"], "best_result.txt"), "a") as f:
f.write("\n" + str(results.get_best_result(metric="hit1", mode="max")))

print(
"Best Config Based on Hit1 : ",
results.get_best_result(metric="hit1", mode="max"),
fin_res,
)
best_result = results.get_best_result(metric="hit1", mode="max")
else:

with open(os.path.join(search_space["index_dir"], "best_result.txt"), "a") as f:
f.write("\n" + str(results.get_best_result(metric="loss", mode="min")))

print(
"Best Config Based on Loss : ",
results.get_best_result(metric="loss", mode="min"),
fin_res,
)
best_result = results.get_best_result(metric="loss", mode="min")


# Get the best configuration
best_config = best_result.config
best_config = fin_res.config

# Extract the values for constructing the file name
epoch = best_config["epoch"]
Expand All @@ -337,7 +314,7 @@ def findRep(src, dest, index_dir, src_type="json"):
margin,
),
)
best_checkpoint_path = best_result.checkpoint.path
best_checkpoint_path = fin_res.checkpoint.path
print("best_checkpoint_path is: ", best_checkpoint_path)
file_name = os.listdir(best_checkpoint_path)[0]
print("file_name is: ", file_name)
Expand Down

0 comments on commit 0615c8e

Please sign in to comment.