Skip to content

Commit

Permalink
Resolved pre-commit failures
Browse files Browse the repository at this point in the history
  • Loading branch information
Harikaraja committed Sep 23, 2024
1 parent 7c24956 commit 99a5092
Showing 1 changed file with 21 additions and 24 deletions.
45 changes: 21 additions & 24 deletions seed_embeddings/OpenKE/generate_embedding_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ray.train import RunConfig, CheckpointConfig
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.optuna import OptunaSearch

os.environ["CUDA_VISIBLE_DEVICES"] = "0"


Expand Down Expand Up @@ -207,7 +208,7 @@ def findRep(src, dest, index_dir, src_type="json"):
"neg_rel": tune.randint(1, 30),
"bern": tune.randint(0, 2),
"opt_method": tune.choice(["SGD", "Adam"]),
#"opt_method": tune.choice(["SGD", "Adagrad", "Adam", "Adadelta"]),
# "opt_method": tune.choice(["SGD", "Adagrad", "Adam", "Adadelta"]),
"is_analogy": arg_conf.is_analogy,
"link_pred": arg_conf.link_pred,
"use_gpu": arg_conf.use_gpu,
Expand All @@ -224,7 +225,7 @@ def findRep(src, dest, index_dir, src_type="json"):
scheduler = ASHAScheduler(
time_attr="training_iteration",
max_t=arg_conf.epoch,
grace_period=15,
grace_period=15,
reduction_factor=3,
metric="AnalogiesScore",
mode="max",
Expand All @@ -233,7 +234,7 @@ def findRep(src, dest, index_dir, src_type="json"):
scheduler = ASHAScheduler(
time_attr="training_iteration",
max_t=arg_conf.epoch,
grace_period= 15,
grace_period=15,
reduction_factor=3,
metric="hit1",
mode="max",
Expand All @@ -242,26 +243,26 @@ def findRep(src, dest, index_dir, src_type="json"):
scheduler = ASHAScheduler(
time_attr="training_iteration",
max_t=arg_conf.epoch,
grace_period= 15,
reduction_factor=3,
grace_period=15,
reduction_factor=3,
metric="loss",
mode="min",
)
if arg_conf.use_gpu:
train_with_resources = tune.with_resources(
train, resources={"cpu": 0, "gpu": 0.0625}
train, resources={"cpu": 0, "gpu": 0.0625}
)
else:
train_with_resources = tune.with_resources(
train, resources={"cpu": 10, "gpu": 0}
)

tuner = tune.Tuner(
train_with_resources,
param_space=search_space,
tune_config=TuneConfig(
search_alg=OptunaSearch(metric="loss",mode="min"),
max_concurrent_trials=16,
search_alg=OptunaSearch(metric="loss", mode="min"),
max_concurrent_trials=16,
scheduler=scheduler,
num_samples=512,
),
Expand All @@ -275,7 +276,7 @@ def findRep(src, dest, index_dir, src_type="json"):
),
)
results = tuner.fit()

# Write the best result to a file, best_result.txt
best_result = None
if arg_conf.is_analogy:
Expand All @@ -293,24 +294,20 @@ def findRep(src, dest, index_dir, src_type="json"):
best_result = results.get_best_result(metric="AnalogiesScore", mode="max")

elif arg_conf.link_pred:

with open(os.path.join(search_space["index_dir"], "best_result.txt"), "a") as f:
f.write(
"\n" + str(results.get_best_result(metric="hit1", mode="max"))
)

f.write("\n" + str(results.get_best_result(metric="hit1", mode="max")))

print(
"Best Config Based on Hit1 : ",
results.get_best_result(metric="hit1", mode="max"),
)
best_result = results.get_best_result(metric="hit1", mode="max")
else:

with open(os.path.join(search_space["index_dir"], "best_result.txt"), "a") as f:
f.write(
"\n" + str(results.get_best_result(metric="loss", mode="min"))
)

f.write("\n" + str(results.get_best_result(metric="loss", mode="min")))

print(
"Best Config Based on Loss : ",
results.get_best_result(metric="loss", mode="min"),
Expand Down Expand Up @@ -338,9 +335,9 @@ def findRep(src, dest, index_dir, src_type="json"):
),
)
best_checkpoint_path = best_result.checkpoint.path
print("best_checkpoint_path is: ",best_checkpoint_path)
print("best_checkpoint_path is: ", best_checkpoint_path)
file_name = os.listdir(best_checkpoint_path)[0]
print("file_name is: ",file_name)
print("file_name is: ", file_name)
if file_name.endswith(".ckpt"):
# Construct full file path
source_file = os.path.join(best_checkpoint_path, file_name)
Expand All @@ -357,7 +354,7 @@ def findRep(src, dest, index_dir, src_type="json"):
margin,
),
)
print("embeddings_path: ",embeddings_path)
print("embeddings_path: ", embeddings_path)
findRep(outfile, embeddings_path, index_dir, src_type="ckpt")
else:
print("No .ckpt file found in the source directory.")
Expand All @@ -366,4 +363,4 @@ def findRep(src, dest, index_dir, src_type="json"):
print(result)
del results

print("Training finished...")
print("Training finished...")

0 comments on commit 99a5092

Please sign in to comment.