Commit: Update main.py
vetaerang authored Oct 20, 2021
1 parent 58175ea commit c0a7b42
Showing 1 changed file with 25 additions and 11 deletions.
36 changes: 25 additions & 11 deletions baseline/main.py
@@ -14,6 +14,9 @@
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import pickle
+import warnings
+warnings.filterwarnings(action='ignore')


from tensorflow.keras.layers import Layer
from tensorflow.keras import backend as K
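Note: the newly added `warnings.filterwarnings(action='ignore')` silences every warning process-wide, including TensorFlow deprecation notices. If narrower suppression were ever wanted, the standard library offers a scoped form (a sketch of an alternative, not what this commit does):

    import warnings

    def noisy():
        warnings.warn('deprecated', DeprecationWarning)

    # Scoped alternative: warnings are suppressed only inside the with-block.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        noisy()   # silenced here
    noisy()       # warns again outside the block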
@@ -241,11 +244,6 @@ def make_model_input(dataset, is_valid=False, is_test = False):

    return encoder_input, decoder_input, decoder_output

-def decode_sequence(model, input_sequences):
-
-    decoded_sequnences = model.predict(input_sequences, batch_size=100)
-
-    return decoded_sequnences

def seq2text(input_seq):
    temp=''
@@ -275,6 +273,12 @@ def train(model) :
              validation_data = ([encoder_input_val, decoder_input_val], decoder_target_val),
              batch_size = 100, callbacks=[es], epochs = 1)
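The `es` callback passed to `model.fit` is defined outside the visible hunk. Given the `EarlyStopping` import at the top of the file, it is presumably constructed along these lines (hypothetical arguments; the diff does not show the real ones):

    from tensorflow.keras.callbacks import EarlyStopping

    # Assumed configuration; the actual monitor/patience values are collapsed out of the diff.
    es = EarlyStopping(monitor='val_loss', patience=3, verbose=1)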

+def decode_idx2word(sequence, tar_index_to_word):
+    sentence = [tar_index_to_word[idx] for idx in sequence]
+    sentence = ' '.join(sentence)
+    return sentence
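The new `decode_idx2word` helper looks each index up in an index-to-word dictionary and joins the results into one space-separated string. A toy usage sketch (the mapping below is made up; the real one comes from the target tokenizer):

    # Hypothetical stand-in for tar_tokenizer.index_word.
    toy_index_to_word = {0: 'unk', 1: 'the', 2: 'cat', 3: 'sat'}

    print(decode_idx2word([1, 2, 3, 0], toy_index_to_word))
    # -> 'the cat sat unk'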


def bind_model(model, parser):
    # Function that saves the trained model.
    def save(dir_name, *parser):
@@ -301,9 +305,12 @@ def load(dir_name, *parser):

    def infer(test_path, **kwparser):

+        tar_index_to_word = dict_for_infer['tar_index_to_word']
        src_tokenizer = dict_for_infer['src_tokenizer']
        tar_tokenizer = dict_for_infer['tar_tokenizer']

+        print('tar_index_to_word: \n', tar_index_to_word)

        preprocessor = Preprocess()

        test_json_path = os.path.join(test_path, 'test_data', '*')
@@ -314,7 +321,7 @@ def infer(test_path, **kwparser):

        test_json_list = preprocessor.make_dataset_list(test_path_list)
        test_data = preprocessor.make_set_as_df(test_json_list)
-
+        print(f'test_data:\n{test_data["dialogue"]}')
        encoder_input_test, decoder_input_test = preprocessor.make_model_input(test_data, is_test= True)

@@ -330,22 +337,27 @@ def infer(test_path, **kwparser):
        for i in range(0, total_data, batch):
            if i == 0:
                summary = model.predict([encoder_input_test[i:i+batch], decoder_input_test[i:i + batch]]).argmax(2)
+                summary = [decode_idx2word(batch_idx, tar_index_to_word) for batch_idx in summary]
            else:
-                predict = model.predict([encoder_input_test[i:i+batch], decoder_input_test[i:i+batch]]).argmax(2)
-                summary = np.concatenate([summary,predict])

+                predict = model.predict([encoder_input_test[i:i+batch], decoder_input_test[i:i + batch]]).argmax(2)
+                predict = [decode_idx2word(batch_idx, tar_index_to_word) for batch_idx in predict]
+                summary.extend(predict)
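In the rewritten loop, `model.predict` returns one probability distribution over the vocabulary per output position, `argmax(2)` greedily picks the most likely token index at each position, and `decode_idx2word` turns each row of indices into text. Because `summary` is now a Python list of strings rather than a NumPy index array, batches are merged with `extend` instead of `np.concatenate`. A shape-level sketch with dummy data:

    import numpy as np

    # Dummy 'predictions': 2 samples, 4 timesteps, vocabulary of 5 tokens.
    probs = np.random.rand(2, 4, 5)
    ids = probs.argmax(2)             # shape (2, 4): one token index per position

    toy_index_to_word = {i: f'w{i}' for i in range(5)}   # hypothetical vocabulary
    decoded = [' '.join(toy_index_to_word[i] for i in row) for row in ids]
    print(decoded)                    # e.g. ['w3 w0 w4 w1', 'w2 w2 w0 w4']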


        test_id = test_data['dialogueID']

        # DONOTCHANGE: They are reserved for nsml
        # The result must be returned as a list of (id, summary) pairs to be scored on the leaderboard.
        # ex)[(' efe21026-0715-5ca4-99fe-46d0ecfba147', '철수는 밥을 먹었다.'), ...]

        return list(zip(test_id, summary))

    # DONOTCHANGE: They are reserved for nsml
    # Lets nsml access the functions designated above.
    nsml.bind(save=save, load=load, infer=infer)



if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='nia_test')
    parser.add_argument('--mode', type=str, default='train')
@@ -413,6 +425,7 @@ def infer(test_path, **kwparser):
    src_index_to_word = src_tokenizer.index_word
    tar_word_to_index = tar_tokenizer.word_index
    tar_index_to_word = tar_tokenizer.index_word
+    tar_index_to_word[0] = 'unk'
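A fitted Keras `Tokenizer` numbers words from 1 and reserves 0 for padding, so `index_word` has no entry for key 0; without the added line above, `decode_idx2word` would raise a `KeyError` on every padded position of the model output. A minimal illustration:

    # index_word from a fitted Keras Tokenizer never contains key 0.
    index_word = {1: 'hello', 2: 'world'}   # hypothetical fitted vocabulary
    index_word[0] = 'unk'                   # same guard as the line above

    padded_ids = [1, 2, 0, 0]               # trailing 0s as produced by pad_sequences
    print(' '.join(index_word[i] for i in padded_ids))   # 'hello world unk unk'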

    dict_for_infer = {
        'tar_tokenizer' : tar_tokenizer,
@@ -422,7 +435,9 @@ def infer(test_path, **kwparser):
        'tar_vocab' : tar_vocab,
        'hidden_size' : hidden_size,
        'text_max_len' : text_max_len,
-        'summary_max_len' : summary_max_len
+        'summary_max_len' : summary_max_len,
+
+        'tar_index_to_word' : tar_index_to_word
    }
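Adding `tar_index_to_word` to `dict_for_infer` is what makes the mapping available to `infer` after a saved session is reloaded; judging by the `import pickle` at the top of the file, the collapsed `save`/`load` functions presumably serialize this dict. A round-trip sketch under that assumption (the file name is made up):

    import pickle

    with open('dict_for_infer.pkl', 'wb') as f:
        pickle.dump(dict_for_infer, f)
    with open('dict_for_infer.pkl', 'rb') as f:
        dict_for_infer = pickle.load(f)

    tar_index_to_word = dict_for_infer['tar_index_to_word']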

    for epoch in range(args.epochs):
@@ -431,4 +446,3 @@

        # DONOTCHANGE (You can decide how often you want to save the model)
        nsml.save(epoch)
