-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'origin/dev' into feature/retrain_for_ne…
…w_users
- Loading branch information
Showing
93 changed files
with
449,108 additions
and
637 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
557 changes: 557 additions & 0 deletions
557
AI/Notebooks/.ipynb_checkpoints/Koja-CR-BERT ---- v2-checkpoint.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
506 changes: 506 additions & 0 deletions
506
AI/Notebooks/.ipynb_checkpoints/Koja-CR-BERT ---- v3-checkpoint.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
173 changes: 173 additions & 0 deletions
173
AI/Notebooks/.ipynb_checkpoints/Untitled-checkpoint.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"id": "e423f1c8-2d49-4d7b-95a7-c17fe21c5ad4", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']\n", | ||
"- This IS expected if you are initializing TFBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n", | ||
"- This IS NOT expected if you are initializing TFBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).\n", | ||
"All the weights of TFBertModel were initialized from the PyTorch model.\n", | ||
"If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions without further training.\n", | ||
"[nltk_data] Downloading package punkt to\n", | ||
"[nltk_data] C:\\Users\\vanzy\\AppData\\Roaming\\nltk_data...\n", | ||
"[nltk_data] Package punkt is already up-to-date!\n", | ||
"[nltk_data] Downloading package stopwords to\n", | ||
"[nltk_data] C:\\Users\\vanzy\\AppData\\Roaming\\nltk_data...\n", | ||
"[nltk_data] Package stopwords is already up-to-date!\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# preprocessing.py\n", | ||
"\n", | ||
"from transformers import BertTokenizer, TFBertModel\n", | ||
"from joblib import load\n", | ||
"import nltk\n", | ||
"from nltk.corpus import stopwords\n", | ||
"from nltk.tokenize import word_tokenize\n", | ||
"from nltk.stem import PorterStemmer\n", | ||
"import tensorflow as tf\n", | ||
"\n", | ||
"# Initialize the BERT tokenizer and model\n", | ||
"tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n", | ||
"model = TFBertModel.from_pretrained('bert-base-uncased')\n", | ||
"\n", | ||
"nltk.download('punkt')\n", | ||
"nltk.download('stopwords')\n", | ||
"\n", | ||
"def bert_feature_extraction(text):\n", | ||
" with tf.device('/GPU:0'):\n", | ||
" inputs = tokenizer.encode_plus(text, return_tensors='tf', add_special_tokens=True, max_length=50, truncation=True, padding='max_length')\n", | ||
" outputs = model(inputs)\n", | ||
" return outputs.last_hidden_state[:,0,:].numpy()\n", | ||
"\n", | ||
"def preprocess_text(text):\n", | ||
" if not isinstance(text, str):\n", | ||
" text = str(text)\n", | ||
"\n", | ||
" tokens = word_tokenize(text)\n", | ||
" \n", | ||
" stop_words = set(stopwords.words('english'))\n", | ||
" tokens = [word for word in tokens if word.lower() not in stop_words]\n", | ||
" \n", | ||
" ps = PorterStemmer()\n", | ||
" tokens = [ps.stem(word) for word in tokens]\n", | ||
" \n", | ||
" return ' '.join(tokens)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"id": "4e94ef96-9295-493f-9bac-3f5228b9039d", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"classifier = load('KOJA-CR.joblib')\n", | ||
"def classify_sentence(sentence):\n", | ||
" # Preprocess the sentence\n", | ||
" preprocessed_sentence = preprocess_text(sentence)\n", | ||
" \n", | ||
" # Extract features using BERT\n", | ||
" features = bert_feature_extraction(preprocessed_sentence)\n", | ||
" \n", | ||
" # Reshape the features to match the input shape\n", | ||
" features = features.reshape(1, -1)\n", | ||
" \n", | ||
" # Make the prediction\n", | ||
" return classifier.predict(features)[0]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 10, | ||
"id": "28fcdfdd-df37-41f6-8287-94de015db56c", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Nature\n", | ||
"Kids\n", | ||
"Nature\n", | ||
"Nature\n", | ||
"Arts\n", | ||
"Education\n", | ||
"Kids\n", | ||
"Arts\n", | ||
"Nature\n", | ||
"Education\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"result = classify_sentence(\"Meeting with the marketing team\")\n", | ||
"print(result)\n", | ||
"\n", | ||
"result = classify_sentence(\"Annual company picnic\")\n", | ||
"print(result)\n", | ||
"\n", | ||
"result = classify_sentence(\"Project deadline\")\n", | ||
"print(result)\n", | ||
"\n", | ||
"result = classify_sentence(\"Client presentation\")\n", | ||
"print(result)\n", | ||
"\n", | ||
"result = classify_sentence(\"Team building workshop\")\n", | ||
"print(result)\n", | ||
"\n", | ||
"result = classify_sentence(\"Quarterly business review\")\n", | ||
"print(result)\n", | ||
"\n", | ||
"result = classify_sentence(\"Product launch event\")\n", | ||
"print(result)\n", | ||
"\n", | ||
"result = classify_sentence(\"Board of directors meeting\")\n", | ||
"print(result)\n", | ||
"\n", | ||
"result = classify_sentence(\"Employee training session\")\n", | ||
"print(result)\n", | ||
"\n", | ||
"result = classify_sentence(\"Year-end party\")\n", | ||
"print(result)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "d315791c-8d0e-4335-a865-4846a0da05ee", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.0" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Oops, something went wrong.