Fixing minor bugs in vocab gen
svkeerthy committed Oct 5, 2024
1 parent 0615c8e commit a379155
Showing 2 changed files with 126 additions and 84 deletions.
seed_embeddings/OpenKE/config/Trainer.py (4 changes: 2 additions & 2 deletions)
@@ -34,7 +34,7 @@ def __init__(
         save_steps=None,
         checkpoint_dir=None,
         index_dir=None,
-        out_path=None,
+        analogy_file="analogies.txt",
     ):

         self.work_threads = 8
@@ -54,7 +54,7 @@ def __init__(
         # self.out_path = out_path

         self.entity_names = self.load_entity_names(index_dir)
-        self.analogies = analogy.AnalogyScorer(analogy_file="/lfs1/usrscratch/staff/nvk1tb/IR2Vec/seed_embeddings/OpenKE/analogies.txt")
+        self.analogies = analogy.AnalogyScorer(analogy_file=analogy_file)

     def load_entity_names(self, index_dir):
         with open(os.path.join(index_dir, "entity2id.txt")) as fEntity:
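Note on the change above: the analogy file is now injected through the constructor instead of being hard-coded to an absolute /lfs1 path. A minimal sketch of a call site under the new signature; the argument values here are illustrative, not from this commit:

    # Hypothetical caller: model, train_dataloader, and the index directory
    # are assumed to exist as in generate_embedding_ray.py below.
    trainer = Trainer(
        model=model,
        data_loader=train_dataloader,
        train_times=1000,
        alpha=0.01,
        index_dir="./index",
        use_gpu=False,
        analogy_file="./analogies.txt",  # falls back to "analogies.txt" if omitted
    )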
seed_embeddings/OpenKE/generate_embedding_ray.py (206 changes: 124 additions & 82 deletions)
@@ -27,7 +27,6 @@

 os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

-
 def test_files(index_dir):
     entities = os.path.join(index_dir, "entity2id.txt")
     relations = os.path.join(index_dir, "relation2id.txt")
@@ -41,67 +40,56 @@ def test_files(index_dir):
     if not os.path.exists(train):
         raise Exception("train2id.txt not found")

-
-# TODO :: alpha, lmda, bern, opt_method
-def train(arg_conf):
-
-    try:
-        test_files(arg_conf["index_dir"])
-        print("Files are OK")
-    except:
-        print(arg_conf)
-        print("Error in files")
-        raise Exception("Error in files")
-
+def train(config, args=None):
     # dataloader for training
-
     train_dataloader = TrainDataLoader(
-        in_path=arg_conf["index_dir"],
-        nbatches=arg_conf["nbatches"],
+        in_path=args.index_dir,
+        batch_size=config["batch_size"],
         threads=4,
         sampling_mode="normal",
-        bern_flag=arg_conf["bern"],
+        # bern_flag=config["bern"],
         filter_flag=1,
-        neg_ent=arg_conf["neg_ent"],
-        neg_rel=arg_conf["neg_rel"],
+        # neg_ent=config["neg_ent"],
+        # neg_rel=config["neg_rel"],
     )
     # dataloader for test (link prediction)
-    if arg_conf["link_pred"]:
-        test_dataloader = TestDataLoader(arg_conf["index_dir"], "link")
+    if args.link_pred:
+        test_dataloader = TestDataLoader(args.index_dir, "link")
     else:
         test_dataloader = None

     transe = TransE(
         ent_tot=train_dataloader.get_ent_tot(),
         rel_tot=train_dataloader.get_rel_tot(),
-        dim=arg_conf["dim"],
+        dim=args.dim,
         p_norm=1,
         norm_flag=True,
     )
     # define the loss function
     model = NegativeSampling(
         model=transe,
-        loss=MarginLoss(margin=arg_conf["margin"]),
+        loss=MarginLoss(margin=config["margin"]),
         batch_size=train_dataloader.get_batch_size(),
     )
     # train the model
     trainer = Trainer(
         model=model,
         data_loader=train_dataloader,
-        train_times=arg_conf["epoch"],
-        alpha=arg_conf["alpha"],
-        index_dir=arg_conf["index_dir"],
-        use_gpu=arg_conf["use_gpu"],
+        train_times=args.epoch,
+        alpha=config["alpha"],
+        index_dir=args.index_dir,
+        use_gpu=args.use_gpu,
+        analogy_file=args.analogy_file,
     )
     trainer.run(
-        link_prediction=arg_conf["link_pred"],
+        link_prediction=args.link_pred,
         test_dataloader=test_dataloader,
         model=transe,
-        is_analogy=arg_conf["is_analogy"],
+        is_analogy=args.is_analogy,
     )


-def findRep(src, dest, index_dir, src_type="json"):
+def findRep(src, index_dir, src_type="json"):
     rep = None
     if src_type == "json":
         with open(src) as fSource:
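The rewritten train() separates its two kinds of inputs: config carries the hyperparameters Ray Tune samples per trial (batch_size, margin, alpha), while args carries the fixed command-line settings (index_dir, epoch, dim, use_gpu, analogy_file). A minimal sketch of how the two meet, assuming the tune.with_parameters wrapping used later in this file:

    from ray import tune

    # args (the parsed argparse namespace) is bound once; Tune then invokes
    # the trainable with a freshly sampled config dict for every trial, e.g.
    #   train({"batch_size": 1024, "margin": 4.0, "alpha": 0.01}, args=arg_conf)
    trainable = tune.with_parameters(train, args=arg_conf)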
@@ -116,22 +104,52 @@ def findRep(src, dest, index_dir, src_type="json"):
     with open(os.path.join(index_dir, "entity2id.txt")) as fEntity:
         content = fEntity.read()

-    with open(dest, "w") as fDest:
-        entities = content.split("\n")
-        toTxt = ""
+    entities = content.split("\n")
+    toTxt = ""
+    for i in range(1, int(entities[0])):
+        toTxt += entities[i].split("\t")[0] + ":" + str(rep[i - 1]) + ",\n"
+    toTxt += (
+        entities[int(entities[0])].split("\t")[0]
+        + ":"
+        + str(rep[int(entities[0]) - 1])
+    )
+    return toTxt

-        for i in range(1, int(entities[0])):
-            toTxt += entities[i].split("\t")[0] + ":" + str(rep[i - 1]) + ",\n"
-        toTxt += (
-            entities[int(entities[0])].split("\t")[0]
-            + ":"
-            + str(rep[int(entities[0]) - 1])
-        )
-        fDest.write(toTxt)
+
+def reformat_embeddings(data):
+    result = []
+    current_line = ""
+    is_inside_brackets = False
+
+    for line in data.splitlines():
+        line = line.strip()
+
+        if '[' in line:
+            is_inside_brackets = True
+
+            opening_bracket_index = line.index('[')
+            current_line += line[:opening_bracket_index + 1] + line[opening_bracket_index + 1:].lstrip()
+
+            numbers_part = current_line.split('[')[1].strip()
+            if numbers_part:
+                numbers = numbers_part.split()
+                current_line = current_line.split('[')[0] + '[' + ', '.join(numbers)
+
+        elif ']' in line:
+
+            numbers = line.split(']')[0].split()
+            current_line += ', '.join(numbers) + ']'
+            is_inside_brackets = False
+            result.append(current_line)
+            current_line = ""
+
+        elif is_inside_brackets:
+            numbers = line.split()
+            current_line += ', ' + ', '.join(numbers)
+
+    return '\n'.join(result)

-if __name__ == "__main__":

+if __name__ == "__main__":
     ray.init()
     parser = argparse.ArgumentParser()
     parser.add_argument(
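The new reformat_embeddings() collapses numpy-style, multi-line array dumps produced by findRep() into one comma-separated vector per entity. Note that the closing-bracket branch appends its numbers without a leading separator, so it effectively assumes the final line of each vector holds only "]". A small before/after illustration under that assumption; the sample values are made up, and the function is assumed importable from this script:

    raw = """add:[0.1 0.2
    0.3
    ]
    sub:[0.4 0.5
    ]"""

    print(reformat_embeddings(raw))
    # add:[0.1, 0.2, 0.3]
    # sub:[0.4, 0.5]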
@@ -145,7 +163,6 @@
     parser.add_argument(
         "--epoch", dest="epoch", help="Epochs", required=False, type=int, default=1000
     )
-
     parser.add_argument(
         "--is_analogy",
         dest="is_analogy",
@@ -170,10 +187,9 @@
         type=int,
         default=300,
     )
-
     parser.add_argument(
-        "--nbatches",
-        dest="nbatches",
+        "--batch_size",
+        dest="batch_size",
         help="Number of batches",
         required=False,
         type=int,
@@ -195,27 +211,46 @@
         type=bool,
         default=False,
     )
+    parser.add_argument(
+        "--storage_path",
+        dest="storage_path",
+        help="Path to store the ray results",
+        required=False,
+        type=str,
+        default=os.path.join(os.path.expanduser("~"), "ray_results"),
+    )
+    parser.add_argument(
+        "--analogy_file",
+        dest="analogy_file",
+        help="Path to the analogy file",
+        required=False,
+        type=str,
+        default="./analogies.txt",
+    )

     arg_conf = parser.parse_args()
+    arg_conf.index_dir = arg_conf.index_dir + "/"
+
+    try:
+        test_files(arg_conf.index_dir)
+        print("Files are OK")
+    except Exception as e:
+        print("Exception: ", e)
+        print("Error in files")
+        raise Exception("Error in files")

     search_space = {
-        "epoch": arg_conf.epoch,
-        "dim": arg_conf.dim,
-        "index_dir": arg_conf.index_dir,
-        "nbatches": tune.sample_from(lambda spec: 2 ** np.random.randint(8, 12)),
-        "margin": tune.quniform(3, 6, 0.5),
+        "batch_size": tune.sample_from(lambda spec: 2 ** np.random.randint(8, 12)),
+        "margin": tune.quniform(0.2, 6, 0.2),
         "alpha": tune.loguniform(1e-4, 1e-1),
-        "neg_ent": tune.randint(1, 30),
-        "neg_rel": tune.randint(1, 30),
-        "bern": tune.randint(0, 2),
-        "opt_method": tune.choice(["SGD", "Adam"]),
-        # "opt_method": tune.choice(["SGD", "Adagrad", "Adam", "Adadelta"]),
-        "is_analogy": arg_conf.is_analogy,
-        "link_pred": arg_conf.link_pred,
-        "use_gpu": arg_conf.use_gpu,
+        # "neg_ent": tune.randint(1, 30),
+        # "neg_rel": tune.randint(1, 30),
+        # "bern": tune.randint(0, 2),
+        "opt_method": "Adam", #tune.choice(["SGD", "Adam"]),
     }

     try:
-        test_files(search_space["index_dir"])
+        test_files(arg_conf.index_dir)
         print("Files are OK")
     except:
         print("Error in files")
@@ -234,7 +269,7 @@
     scheduler = ASHAScheduler(
         time_attr="training_iteration",
         max_t=arg_conf.epoch,
-        grace_period=15,
+        grace_period=10,
         reduction_factor=3,
         metric=metric,
         mode=mode,
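For context on the grace_period change: ASHA promotes trials through rungs at grace_period * reduction_factor^k iterations, so lowering it moves the first stopping decision from iteration 15 to iteration 10. A quick sketch of the resulting rung schedule; this is standard ASHA arithmetic, not code from this commit:

    max_t, grace, rf = 1000, 10, 3
    rungs = []
    r = grace
    while r < max_t:
        rungs.append(r)
        r *= rf
    print(rungs)  # [10, 30, 90, 270, 810]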
@@ … @@

     if arg_conf.use_gpu:
         train_with_resources = tune.with_resources(
-            train, resources={"cpu": 8, "gpu": 0.15}
+            tune.with_parameters(train, args = arg_conf),
+            resources={"cpu": 8, "gpu": 0.15}
         )
     else:
         train_with_resources = tune.with_resources(
-            train, resources={"cpu": 10, "gpu": 0}
+            tune.with_parameters(train, args = arg_conf),
+            resources={"cpu": 10, "gpu": 0}
         )

     tuner = tune.Tuner(
@@ … @@
             search_alg=optuna,
             max_concurrent_trials=12,
             scheduler=scheduler,
-            num_samples=512,
+            num_samples=128,
         ),
         run_config=RunConfig(
-            storage_path="/lfs1/usrscratch/staff/nvk1tb/ray_results/",
+            storage_path=arg_conf.storage_path,
             checkpoint_config=CheckpointConfig(
                 num_to_keep=1,
                 # *Best* checkpoints are determined by these params:
@@ … @@

     # Write the best result to a file, best_result.txt
     fin_res = results.get_best_result(metric=metric, mode=mode)
-    with open(os.path.join(search_space["index_dir"], "best_result.txt"), "a") as f:
+    with open(os.path.join(arg_conf.index_dir, "best_result.txt"), "a") as f:
         f.write(
             "\n" + str(fin_res)
         )
@@ … @@

     # Get the best configuration
     best_config = fin_res.config

+    print("best_config: ", best_config)
     # Extract the values for constructing the file name
-    epoch = best_config["epoch"]
-    dim = best_config["dim"]
-    nbatches = best_config["nbatches"]
+    dim = arg_conf.dim
+    batch_size = best_config["batch_size"]
     margin = best_config["margin"]
-    index_dir = best_config["index_dir"]
+    alpha = best_config["alpha"]
+    index_dir = arg_conf.index_dir

     # Construct the output file name using the best hyperparameters
     outfile = os.path.join(
         index_dir,
-        "seedEmbedding_{}E_{}D_{}batches_{}margin.ckpt".format(
-            epoch,
+        "seedEmbedding_{}_{}Dim_{}Alpha_{}batchsize_{}margin.ckpt".format(
+            metric,
             dim,
-            nbatches,
+            alpha,
+            batch_size,
             margin,
         ),
     )
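The checkpoint name now records the tuning metric, dimension, alpha, batch size, and margin rather than the epoch count. A quick illustration of the template with made-up values; the actual metric string is whatever the script's metric variable holds:

    name = "seedEmbedding_{}_{}Dim_{}Alpha_{}batchsize_{}margin.ckpt".format(
        "AnalogiesScore", 300, 0.0123, 1024, 4.5  # illustrative values only
    )
    print(name)
    # seedEmbedding_AnalogiesScore_300Dim_0.0123Alpha_1024batchsize_4.5margin.ckpt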
@@ … @@

         embeddings_path = os.path.join(
             index_dir,
-            "embeddings/seedEmbedding_{}E_{}D_{}batches{}margin.txt".format(
-                epoch,
+            "seedEmbedding_{}_{}Dim_{}Alpha_{}batchsize_{}margin.ckpt".format(
+                metric,
                 dim,
-                nbatches,
+                alpha,
+                batch_size,
                 margin,
             ),
         )
+
+        data = findRep(outfile, index_dir, src_type="ckpt")
+        formatted_data = reformat_embeddings(data)

         # Write the embeddings to outfile
+        embeddings_path = embeddings_path.replace(".ckpt", ".txt")
         print("embeddings_path: ", embeddings_path)
-        findRep(outfile, embeddings_path, index_dir, src_type="ckpt")
+        with open(embeddings_path, "w") as f:
+            f.write(formatted_data)
     else:
         print("No .ckpt file found in the source directory.")

     for result in results:
         print(result)
+    del results

     print("Training finished...")