diff --git a/seed_embeddings/OpenKE/generate_embeddings.sh b/seed_embeddings/OpenKE/generate_embeddings.sh deleted file mode 100644 index eeb5a224..00000000 --- a/seed_embeddings/OpenKE/generate_embeddings.sh +++ /dev/null @@ -1,21 +0,0 @@ -#! /bin/bash - -# python3 generate_embedding.py --index_dir=../preprocessed/ --nbatches=250 --margin=1.5 -# python3 generate_embedding.py --index_dir=../preprocessed/ --nbatches=250 --margin=2.0 -# python3 generate_embedding.py --index_dir=../preprocessed/ --nbatches=250 --margin=2.5 -# python3 generate_embedding.py --index_dir=../preprocessed/ --nbatches=250 --margin=3.0 -# python3 generate_embedding.py --index_dir=../preprocessed/ --nbatches=250 --margin=3.5 -# python3 generate_embedding.py --index_dir=../preprocessed/ --nbatches=250 --margin=4.0 -# python3 generate_embedding.py --index_dir=../preprocessed/ --nbatches=250 --margin=4.5 -# python3 generate_embedding.py --index_dir=../preprocessed/ --nbatches=250 --margin=5.0 -# python3 generate_embedding.py --index_dir=../preprocessed/ --nbatches=300 --margin=5.0 -# python3 generate_embedding.py --index_dir=../preprocessed/ --nbatches=300 --margin=5.5 -# python3 generate_embedding.py --index_dir=../preprocessed/ --nbatches=300 --margin=6.0 - -# python3 generate_embedding.py --index_dir=../preprocessed/ --nbatches=250 --margin=5.5 -# python3 generate_embedding.py --index_dir=../preprocessed/ --nbatches=250 --margin=6.0 -# python3 generate_embedding.py --index_dir=../preprocessed/ --nbatches=250 --margin=6.5 - -# python3 generate_embedding.py --index_dir=../preprocessed/ --nbatches=275 --margin=6.0 -# python3 generate_embedding.py --index_dir=../preprocessed/ --nbatches=225 --margin=6.0 - diff --git a/seed_embeddings/OpenKE/preprocess.py b/seed_embeddings/OpenKE/preprocess.py index 1c2c1ac5..9d1d2964 100644 --- a/seed_embeddings/OpenKE/preprocess.py +++ b/seed_embeddings/OpenKE/preprocess.py @@ -57,7 +57,7 @@ def getRelationDict(config): return relationDict -def createTrain2ID(ed, rd, config): +def createTrain2ID(entityDict, relationDict, config): ip = open(str(config.tripletFile), "r") content = ip.read() sentences = content.split("\n") @@ -71,23 +71,42 @@ def createTrain2ID(ed, rd, config): l = len(s) if s[0] != "": if opc != "": - if s[0] not in ed: + if s[0] not in entityDict: print(sentence) print(s) print(l) print(str(sentences.index(sentence))) print(s[0] + " not found in ed") - if "Next" not in rd: - print("Next not found in rd") - toWrite += ed[opc] + "\t" + ed[s[0]] + "\t" + rd["Next"] + "\n" + if "Next" not in relationDict: + print("Next not found in relationDict") + toWrite += ( + entityDict[opc] + + "\t" + + entityDict[s[0]] + + "\t" + + relationDict["Next"] + + "\n" + ) nol += 1 opc = s[0] - toWrite += ed[opc] + "\t" + ed[s[1]] + "\t" + rd["Type"] + "\n" + toWrite += ( + entityDict[opc] + + "\t" + + entityDict[s[1]] + + "\t" + + relationDict["Type"] + + "\n" + ) nol += 1 i = 0 for arg in range(2, l): toWrite += ( - ed[opc] + "\t" + ed[s[arg]] + "\t" + rd["Arg" + str(i)] + "\n" + entityDict[opc] + + "\t" + + entityDict[s[arg]] + + "\t" + + relationDict["Arg" + str(i)] + + "\n" ) nol += 1 i += 1