From c0a7b4296e5dacdbe1d30e3514d7ada970c51195 Mon Sep 17 00:00:00 2001
From: TAERANG <60762935+taerangi@users.noreply.github.com>
Date: Wed, 20 Oct 2021 14:19:40 +0900
Subject: [PATCH] Update main.py

---
 baseline/main.py | 36 +++++++++++++++++++++++++-----------
 1 file changed, 25 insertions(+), 11 deletions(-)

diff --git a/baseline/main.py b/baseline/main.py
index 25270d8..b9bf629 100644
--- a/baseline/main.py
+++ b/baseline/main.py
@@ -14,6 +14,9 @@ from tensorflow.keras.models import Model
 from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
 import pickle
 
+import warnings
+warnings.filterwarnings(action='ignore')
+
 from tensorflow.keras.layers import Layer
 from tensorflow.keras import backend as K
 
@@ -241,11 +244,6 @@ def make_model_input(dataset, is_valid=False, is_test = False):
 
         return encoder_input, decoder_input, decoder_output
 
-def decode_sequence(model, input_sequences):
-
-    decoded_sequnences = model.predict(input_sequences, batch_size=100)
-
-    return decoded_sequnences
 
 def seq2text(input_seq):
     temp=''
@@ -275,6 +273,12 @@ def train(model) :
               validation_data = ([encoder_input_val, decoder_input_val], decoder_target_val),
               batch_size = 100, callbacks=[es], epochs = 1)
 
+def decode_idx2word(sequence, tar_index_to_word):
+    sentence = [tar_index_to_word[idx] for idx in sequence]
+    sentence = ' '.join(sentence)
+    return sentence
+
+
 def bind_model(model, parser):
     # A function that saves the trained model.
     def save(dir_name, *parser):
@@ -301,9 +305,12 @@ def load(dir_name, *parser):
 
     def infer(test_path, **kwparser):
 
+        tar_index_to_word = dict_for_infer['tar_index_to_word']
         src_tokenizer = dict_for_infer['src_tokenizer']
         tar_tokenizer = dict_for_infer['tar_tokenizer']
 
+        print('tar_index_to_word: \n', tar_index_to_word)
+
         preprocessor = Preprocess()
 
         test_json_path = os.path.join(test_path, 'test_data', '*')
@@ -314,7 +321,7 @@ def infer(test_path, **kwparser):
 
         test_json_list = preprocessor.make_dataset_list(test_path_list)
         test_data = preprocessor.make_set_as_df(test_json_list)
-        
+        print(f'test_data:\n{test_data["dialogue"]}')
 
         encoder_input_test, decoder_input_test = preprocessor.make_model_input(test_data, is_test= True)
 
@@ -330,15 +337,19 @@ def infer(test_path, **kwparser):
         for i in range(0, total_data, batch):
             if i == 0:
                 summary = model.predict([encoder_input_test[i:i+batch], decoder_input_test[i:i + batch]]).argmax(2)
+                summary = [decode_idx2word(batch_idx, tar_index_to_word) for batch_idx in summary]
             else:
-                predict = model.predict([encoder_input_test[i:i+batch], decoder_input_test[i:i+batch]]).argmax(2)
-                summary = np.concatenate([summary,predict])
-
+                predict = model.predict([encoder_input_test[i:i+batch], decoder_input_test[i:i + batch]]).argmax(2)
+                predict = [decode_idx2word(batch_idx, tar_index_to_word) for batch_idx in predict]
+                summary.extend(predict)
+
+
         test_id = test_data['dialogueID']
 
         # DONOTCHANGE: They are reserved for nsml
         # Results must be returned as [(id, summary)] pairs to be posted to the leaderboard.
         # ex) [(' efe21026-0715-5ca4-99fe-46d0ecfba147', 'Cheolsu ate a meal.'), ...]
+        return list(zip(test_id, summary))
 
     # DONOTCHANGE: They are reserved for nsml
 
@@ -346,6 +357,7 @@ def infer(test_path, **kwparser):
     nsml.bind(save=save, load=load, infer=infer)
 
 
+
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='nia_test')
     parser.add_argument('--mode', type=str, default='train')
@@ -413,6 +425,7 @@ def infer(test_path, **kwparser):
     src_index_to_word = src_tokenizer.index_word
     tar_word_to_index = tar_tokenizer.word_index
     tar_index_to_word = tar_tokenizer.index_word
+    tar_index_to_word[0] = 'unk'
 
     dict_for_infer = {
         'tar_tokenizer' : tar_tokenizer,
@@ -422,7 +435,9 @@ def infer(test_path, **kwparser):
         'tar_vocab' : tar_vocab,
         'hidden_size' : hidden_size,
         'text_max_len' : text_max_len,
-        'summary_max_len' : summary_max_len
+        'summary_max_len' : summary_max_len,
+
+        'tar_index_to_word' : tar_index_to_word
     }
 
     for epoch in range(args.epochs):
@@ -431,4 +446,3 @@ def infer(test_path, **kwparser):
 
     # DONOTCHANGE (You can decide how often you want to save the model)
     nsml.save(epoch)
-
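
Note (not part of the patch): below is a minimal, self-contained sketch of the
decoding path this commit introduces. decode_idx2word and the
tar_index_to_word[0] = 'unk' padding entry are taken from the diff above; the
toy vocabulary and the fake batch of argmax outputs are hypothetical stand-ins
for tar_tokenizer.index_word and model.predict(...).argmax(2).

    # Sketch only: decode_idx2word exactly as added by this patch.
    def decode_idx2word(sequence, tar_index_to_word):
        sentence = [tar_index_to_word[idx] for idx in sequence]
        sentence = ' '.join(sentence)
        return sentence

    # Hypothetical toy vocabulary; in main.py this comes from
    # tar_tokenizer.index_word. Index 0 is the padding index, mapped to
    # 'unk' (as in the patch) so padded positions do not raise KeyError.
    tar_index_to_word = {1: 'the', 2: 'meeting', 3: 'was', 4: 'postponed'}
    tar_index_to_word[0] = 'unk'

    # Fake argmax output for one zero-padded summary of length 6.
    predicted = [[1, 2, 3, 4, 0, 0]]

    summaries = [decode_idx2word(seq, tar_index_to_word) for seq in predicted]
    print(summaries)  # ['the meeting was postponed unk unk']

Once each batch is decoded to strings, summary is a plain Python list of str
rather than an ndarray, which is why the loop now accumulates subsequent
batches with list.extend instead of np.concatenate.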