From c4db6be8d98684c0637b0d6e738a0039b2e0a725 Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Sun, 3 Sep 2023 16:18:46 +1200 Subject: [PATCH 01/48] Added log-in request class --- src/login_request.py | 58 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 src/login_request.py diff --git a/src/login_request.py b/src/login_request.py new file mode 100644 index 0000000..e5ecd86 --- /dev/null +++ b/src/login_request.py @@ -0,0 +1,58 @@ +""" +Login request module. +Defines the LoginRequest class which is used to encode and decode login request packets. +""" +import logging + +from .message_type import MessageType +from .record import Record + + +class LoginRequest(Record): + """ + The LoginRequest class is used to encode and decode login request packets. + """ + def __init__(self, user_name: str): + """ + Create a login request packet + """ + self.user_name = user_name + self.record = bytearray(4) + + def to_bytes(self) -> bytes: + """ + Encode the login request packet into a byte array + """ + logging.info("Creating log-in request as %s", self.user_name) + + self.record[0] = Record.MAGIC_NUMBER >> 8 + self.record[1] = Record.MAGIC_NUMBER & 0xFF + self.record[2] = MessageType.LOGIN.value + self.record[3] = len(self.user_name.encode()) + self.record.extend(self.user_name.encode()) + + return bytes(self.record) + + @classmethod + def from_record(cls, record: bytes) -> "LoginRequest": + """ + Create a login request packet from a byte array + """ + magic_number = record[0] << 8 | record[1] + if magic_number != Record.MAGIC_NUMBER: + raise ValueError("Received message request with incorrect magic number") + + if record[2] != MessageType.LOGIN.value: + raise ValueError("Received log-in request with invalid ID") + + user_name_length = record[3] + user_name = record[4:4 + user_name_length].decode() + + return cls(user_name) + + def decode(self) -> tuple: + """ + Decode the individual fields of the login request packet + :return: A tuple containing the username + """ + return self.user_name, From c79c109a8238c65f15f59493cab88bd742e3a2a5 Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Sun, 3 Sep 2023 16:33:07 +1200 Subject: [PATCH 02/48] Added incomplete log-in response class --- src/login_response.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 src/login_response.py diff --git a/src/login_response.py b/src/login_response.py new file mode 100644 index 0000000..c8a01f5 --- /dev/null +++ b/src/login_response.py @@ -0,0 +1,38 @@ +""" +Login response module +Defines class for encoding and decoding login response packets. +""" +from .message_type import MessageType +from .record import Record + + +class LoginResponse(Record): + """ + The LoginResponse class is used to encode and decode login response packets. + """ + def __init__(self, is_registered: bool): + """ + Create a new login response packet. + """ + self.is_registered = is_registered + self.record = bytearray(3) + + def to_bytes(self) -> bytes: + """ + Encode a login response packet into a byte array. + """ + self.record[0] = Record.MAGIC_NUMBER >> 8 + self.record[1] = Record.MAGIC_NUMBER & 0xFF + self.record[2] = MessageType.LOGIN.value + self.record[3] = self.is_registered + # self.record.extend(self.encryption_key.product.to_bytes(4, "big")) + # self.record.extend(self.encryption_key.exponent.to_bytes(4, "big")) + + return bytes(self.record) + + @classmethod + def from_record(cls, record: bytes) -> "LoginResponse": + pass + + def decode(self) -> tuple: + pass From 0bf1f9bedfbc73453c6ca6b5b4af18338f676227 Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Sun, 3 Sep 2023 16:33:18 +1200 Subject: [PATCH 03/48] Added LOGIN message type --- src/message_type.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/message_type.py b/src/message_type.py index 5668ad0..39bdc8c 100644 --- a/src/message_type.py +++ b/src/message_type.py @@ -14,6 +14,7 @@ class MessageType(Enum): READ = 1 CREATE = 2 RESPONSE = 3 + LOGIN = 4 @staticmethod def from_str(string: str) -> "MessageType": From b649ef6a78ca1f2f5fc02d0d4e0267c1abaec868 Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Sun, 3 Sep 2023 16:33:52 +1200 Subject: [PATCH 04/48] Starting to change client to send login request --- client.py | 39 ++++++++++++++++++++++++++++++++------- 1 file changed, 32 insertions(+), 7 deletions(-) diff --git a/client.py b/client.py index 4fc0281..a737238 100644 --- a/client.py +++ b/client.py @@ -13,6 +13,8 @@ from src.command_line_application import CommandLineApplication from src.message_response import MessageResponse from src.message_request import MessageRequest +from src.login_response import LoginResponse +from src.login_request import LoginRequest from src.message_type import MessageType from src.port_number import PortNumber @@ -28,17 +30,14 @@ def __init__(self, arguments: list[str]): """ super().__init__(OrderedDict(host_name=self.parse_hostname, port_number=PortNumber, - user_name=self.parse_username, - message_type=MessageType.from_str)) + user_name=self.parse_username)) # pylint thinks that self.parse_arguments is only capable of returning an empty list # pylint: disable=unbalanced-tuple-unpacking - self.host_name, self.port_number, self.user_name, self.message_type =\ - self.parse_arguments(arguments) + self.host_name, self.port_number, self.user_name = self.parse_arguments(arguments) - logging.info("Client for %s port %s created by %s to send %s request", - self.host_name, self.port_number, self.user_name, - self.message_type.name.lower()) + logging.info("Client for %s port %s created by %s", + self.host_name, self.port_number, self.user_name) self.receiver_name = "" self.message = "" @@ -78,6 +77,32 @@ def parse_username(user_name: str) -> str: return user_name + def send_login_request(self) -> LoginResponse: + """ + Sends a login request record to the server + """ + request = LoginRequest(self.user_name) + record = request.to_bytes() + try: + with socket.socket() as connection_socket: + connection_socket.settimeout(1) + connection_socket.connect((self.host_name, self.port_number)) + connection_socket.send(record) + response = connection_socket.recv(4096) + response = LoginResponse.from_record(response) + except ConnectionRefusedError as error: + logging.error(error) + print("Connection refused, likely due to invalid port number") + raise SystemExit from error + except socket.timeout as error: + logging.error(error) + print("Connection timed out, likely due to invalid host name") + raise SystemExit from error + + logging.info("Received login response from server") + + return response + def send_message_request(self, request: MessageRequest) -> Optional[MessageResponse]: """ Sends a message request record to the server From c3c311268550c1645a0f5597307039d3e1d7cd21 Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Sun, 3 Sep 2023 16:36:14 +1200 Subject: [PATCH 05/48] Starting to change server to work with login request --- server.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/server.py b/server.py index 7876586..dedeba6 100644 --- a/server.py +++ b/server.py @@ -12,6 +12,8 @@ from src.command_line_application import CommandLineApplication from src.message_response import MessageResponse from src.message_request import MessageRequest +from src.login_response import LoginResponse +from src.login_request import LoginRequest from src.message_type import MessageType from src.port_number import PortNumber @@ -34,6 +36,7 @@ def __init__(self, arguments: list[str]): self.hostname = "192.168.68.75" self.messages: dict[str, list[tuple[str, bytes]]] = {} + self.users: dict[str, str] = {} def run(self): try: @@ -53,6 +56,14 @@ def run(self): print("Error binding socket on provided port") raise SystemExit from error + def login_user(self, login_request: LoginRequest) -> LoginResponse: + """ + Check log in request to ensure that the user is registered. + Otherwise, register the user. + """ + user_name, = login_request.decode() + return LoginResponse(user_name in self.users) + def run_server(self, welcoming_socket: socket.socket): """ Runs the server side of the program From 01e69e7ec15a8454a51b15dccf3f7d810d7882c9 Mon Sep 17 00:00:00 2001 From: Harrison Parkes <66008198+hazzery@users.noreply.github.com> Date: Sun, 3 Sep 2023 06:22:52 +0000 Subject: [PATCH 06/48] Added AffineCipher as submodule --- .gitmodules | 3 +++ libs/AffineCipher | 1 + 2 files changed, 4 insertions(+) create mode 100644 .gitmodules create mode 160000 libs/AffineCipher diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..feeb567 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "libs/AffineCipher"] + path = libs/AffineCipher + url = https://github.com/hazzery/AffineCipher diff --git a/libs/AffineCipher b/libs/AffineCipher new file mode 160000 index 0000000..e855ab9 --- /dev/null +++ b/libs/AffineCipher @@ -0,0 +1 @@ +Subproject commit e855ab94adbde54841194ce17da4db7314acfdfa From b4a7b66ff406f24cd9776817450b24c7adfff66d Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Sun, 3 Sep 2023 18:27:28 +1200 Subject: [PATCH 07/48] Small fix in server.py --- server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server.py b/server.py index dedeba6..675cdbc 100644 --- a/server.py +++ b/server.py @@ -34,7 +34,7 @@ def __init__(self, arguments: list[str]): # pylint: disable=unbalanced-tuple-unpacking self.port_number, = self.parse_arguments(arguments) - self.hostname = "192.168.68.75" + self.hostname = "localhost" self.messages: dict[str, list[tuple[str, bytes]]] = {} self.users: dict[str, str] = {} From 8fe8ab29a6115cba721dff1790c79702e0bb5594 Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Wed, 6 Sep 2023 14:27:11 +1200 Subject: [PATCH 08/48] no longer using AffineCipher as a git submodule, now as installable package --- .gitmodules | 3 --- libs/AffineCipher | 1 - 2 files changed, 4 deletions(-) delete mode 100644 .gitmodules delete mode 160000 libs/AffineCipher diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index feeb567..0000000 --- a/.gitmodules +++ /dev/null @@ -1,3 +0,0 @@ -[submodule "libs/AffineCipher"] - path = libs/AffineCipher - url = https://github.com/hazzery/AffineCipher diff --git a/libs/AffineCipher b/libs/AffineCipher deleted file mode 160000 index e855ab9..0000000 --- a/libs/AffineCipher +++ /dev/null @@ -1 +0,0 @@ -Subproject commit e855ab94adbde54841194ce17da4db7314acfdfa From d71d1025ddd4b35a7d5a3dd0727b6a64a911d3bf Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Thu, 7 Sep 2023 10:36:21 +1200 Subject: [PATCH 09/48] extracted checking of magic number and message type into Record class --- src/login_request.py | 7 ++----- src/message_request.py | 9 ++------- src/message_response.py | 10 ++-------- src/record.py | 16 ++++++++++++++++ 4 files changed, 22 insertions(+), 20 deletions(-) diff --git a/src/login_request.py b/src/login_request.py index e5ecd86..2e44271 100644 --- a/src/login_request.py +++ b/src/login_request.py @@ -38,11 +38,8 @@ def from_record(cls, record: bytes) -> "LoginRequest": """ Create a login request packet from a byte array """ - magic_number = record[0] << 8 | record[1] - if magic_number != Record.MAGIC_NUMBER: - raise ValueError("Received message request with incorrect magic number") - - if record[2] != MessageType.LOGIN.value: + message_type = Record.validate_record(record) + if message_type != MessageType.LOGIN: raise ValueError("Received log-in request with invalid ID") user_name_length = record[3] diff --git a/src/message_request.py b/src/message_request.py index 5a45062..f016f48 100644 --- a/src/message_request.py +++ b/src/message_request.py @@ -69,13 +69,8 @@ def from_record(cls, record: bytes) -> "MessageRequest": Decodes a message request packet :param record: An array of bytes containing the message request """ - magic_number = record[0] << 8 | record[1] - if magic_number != Record.MAGIC_NUMBER: - raise ValueError("Received message request with incorrect magic number") - - if 1 <= record[2] <= 2: - message_type = MessageType(record[2]) - else: + message_type = Record.validate_record(record) + if message_type not in (MessageType.READ, MessageType.CREATE): raise ValueError("Received message request with invalid ID") user_name_length = record[3] diff --git a/src/message_response.py b/src/message_response.py index 8aff7d6..cfac386 100644 --- a/src/message_response.py +++ b/src/message_response.py @@ -55,14 +55,8 @@ def from_record(cls, record: bytes) -> "MessageResponse": Decodes a message response packet :param record: The packet to be decoded """ - magic_number = record[0] << 8 | record[1] - if magic_number != Record.MAGIC_NUMBER: - raise ValueError("Invalid magic number when decoding message response") - - try: - message_type = MessageType(record[2]) - except ValueError as error: - raise ValueError("Invalid message type when decoding message response") from error + message_type = Record.validate_record(record) + if message_type != MessageType.RESPONSE: raise ValueError(f"Message type {message_type} found when decoding message response, " f"expected RESPONSE") diff --git a/src/record.py b/src/record.py index 6bd5934..5b2e196 100644 --- a/src/record.py +++ b/src/record.py @@ -4,6 +4,8 @@ import abc +from message_type import MessageType + class Record(metaclass=abc.ABCMeta): """ @@ -26,6 +28,20 @@ def to_bytes(self) -> bytes: """ raise NotImplementedError + @staticmethod + def validate_record(record) -> MessageType: + """ + Checks the magic number is correct + """ + magic_number = record[0] << 8 | record[1] + if magic_number != Record.MAGIC_NUMBER: + raise ValueError("Invalid magic number when decoding message response") + + try: + return MessageType(record[2]) + except ValueError as error: + raise ValueError("Invalid message type when decoding message response") from error + @classmethod @abc.abstractmethod def from_record(cls, record: bytes) -> "Record": From baa9b1543eb836d9dd4fc34e881ee3b12e733b8f Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Fri, 8 Sep 2023 21:45:37 +1200 Subject: [PATCH 10/48] server now has a public and private key for encrypting packets --- server.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/server.py b/server.py index 675cdbc..35095ce 100644 --- a/server.py +++ b/server.py @@ -9,6 +9,8 @@ import sys import os +from message_cipher.rsa_system import RSA + from src.command_line_application import CommandLineApplication from src.message_response import MessageResponse from src.message_request import MessageRequest @@ -33,6 +35,7 @@ def __init__(self, arguments: list[str]): # pylint thinks that self.parse_arguments is only capable of returning an empty list # pylint: disable=unbalanced-tuple-unpacking self.port_number, = self.parse_arguments(arguments) + self.public_key, self.private_key = RSA() self.hostname = "localhost" self.messages: dict[str, list[tuple[str, bytes]]] = {} From ec7ccaf0fd3aec47cbde8c6e1be084b341c6243f Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Fri, 8 Sep 2023 21:47:00 +1200 Subject: [PATCH 11/48] login response now gives clients the public key to encrypt packets --- server.py | 2 +- src/login_response.py | 32 ++++++++++++++++++++++++++------ 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/server.py b/server.py index 35095ce..f4059d6 100644 --- a/server.py +++ b/server.py @@ -65,7 +65,7 @@ def login_user(self, login_request: LoginRequest) -> LoginResponse: Otherwise, register the user. """ user_name, = login_request.decode() - return LoginResponse(user_name in self.users) + return LoginResponse(user_name in self.users, self.public_key) def run_server(self, welcoming_socket: socket.socket): """ diff --git a/src/login_response.py b/src/login_response.py index c8a01f5..437f2e0 100644 --- a/src/login_response.py +++ b/src/login_response.py @@ -2,6 +2,8 @@ Login response module Defines class for encoding and decoding login response packets. """ +from message_cipher.rsa_encrypter import RsaEncrypter + from .message_type import MessageType from .record import Record @@ -10,10 +12,11 @@ class LoginResponse(Record): """ The LoginResponse class is used to encode and decode login response packets. """ - def __init__(self, is_registered: bool): + def __init__(self, is_registered: bool, encryption_key: RsaEncrypter): """ Create a new login response packet. """ + self.encryption_key = encryption_key self.is_registered = is_registered self.record = bytearray(3) @@ -25,14 +28,31 @@ def to_bytes(self) -> bytes: self.record[1] = Record.MAGIC_NUMBER & 0xFF self.record[2] = MessageType.LOGIN.value self.record[3] = self.is_registered - # self.record.extend(self.encryption_key.product.to_bytes(4, "big")) - # self.record.extend(self.encryption_key.exponent.to_bytes(4, "big")) + self.record.extend(self.encryption_key.product.to_bytes(4, "big")) + self.record.extend(self.encryption_key.exponent.to_bytes(4, "big")) return bytes(self.record) @classmethod def from_record(cls, record: bytes) -> "LoginResponse": - pass + """ + Decode a login response packet from a byte array. + """ + message_type = Record.validate_record(record) + if message_type != MessageType.LOGIN: + raise ValueError(f"Message type {message_type} found when decoding login response, " + f"expected LOGIN") + + is_registered = bool(record[3]) + product = int.from_bytes(record[4:8], "big") + exponent = int.from_bytes(record[8:12], "big") - def decode(self) -> tuple: - pass + encryption_key = RsaEncrypter(product, exponent) + return cls(is_registered, encryption_key) + + def decode(self) -> tuple[bool, RsaEncrypter]: + """ + Decodes the login response packet + :return: A tuple containing the boolean value of is_registered and the encryption key + """ + return self.is_registered, self.encryption_key From 6eb4bfc788dfdf5f094cc16a185169b95a08e964 Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Tue, 28 Nov 2023 17:40:54 +1300 Subject: [PATCH 12/48] rewrote login_response.py for new Packet structure --- src/login_response.py | 66 ++++++++++++++++++++++++------------------- 1 file changed, 37 insertions(+), 29 deletions(-) diff --git a/src/login_response.py b/src/login_response.py index 437f2e0..30013c7 100644 --- a/src/login_response.py +++ b/src/login_response.py @@ -2,57 +2,65 @@ Login response module Defines class for encoding and decoding login response packets. """ +import logging +import struct + from message_cipher.rsa_encrypter import RsaEncrypter -from .message_type import MessageType -from .record import Record +from src.message_type import MessageType +from src.packets.packet import Packet -class LoginResponse(Record): +class LoginResponse(Packet, struct_format="!HB?QQ"): """ The LoginResponse class is used to encode and decode login response packets. """ + def __init__(self, is_registered: bool, encryption_key: RsaEncrypter): """ Create a new login response packet. """ self.encryption_key = encryption_key self.is_registered = is_registered - self.record = bytearray(3) + self.packet = bytes() def to_bytes(self) -> bytes: """ Encode a login response packet into a byte array. """ - self.record[0] = Record.MAGIC_NUMBER >> 8 - self.record[1] = Record.MAGIC_NUMBER & 0xFF - self.record[2] = MessageType.LOGIN.value - self.record[3] = self.is_registered - self.record.extend(self.encryption_key.product.to_bytes(4, "big")) - self.record.extend(self.encryption_key.exponent.to_bytes(4, "big")) + logging.info("Creating login response") + + self.packet = struct.pack( + self.struct_format, + Packet.MAGIC_NUMBER, + MessageType.LOGIN.value, + self.is_registered, + self.encryption_key.product, + self.encryption_key.exponent, + ) - return bytes(self.record) + return self.packet @classmethod - def from_record(cls, record: bytes) -> "LoginResponse": - """ - Decode a login response packet from a byte array. - """ - message_type = Record.validate_record(record) - if message_type != MessageType.LOGIN: - raise ValueError(f"Message type {message_type} found when decoding login response, " - f"expected LOGIN") + def decode_packet(cls, packet: bytes) -> tuple[bool, RsaEncrypter]: + header_fields, payload = cls.split_packet(packet) + magic_number, message_type, is_registered, product, exponent = header_fields + + if magic_number != Packet.MAGIC_NUMBER: + raise ValueError("Invalid magic number when decoding message response") - is_registered = bool(record[3]) - product = int.from_bytes(record[4:8], "big") - exponent = int.from_bytes(record[8:12], "big") + try: + message_type = MessageType(message_type) + except ValueError as error: + raise ValueError( + "Invalid message type when decoding message response" + ) from error + if message_type != MessageType.LOGIN: + raise ValueError( + f"Message type {message_type} found when decoding message response, " + "expected LOGIN" + ) encryption_key = RsaEncrypter(product, exponent) - return cls(is_registered, encryption_key) - def decode(self) -> tuple[bool, RsaEncrypter]: - """ - Decodes the login response packet - :return: A tuple containing the boolean value of is_registered and the encryption key - """ - return self.is_registered, self.encryption_key + return is_registered, encryption_key From f3056f2b02ce26c10b34c90d2e68795c49b51e03 Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Tue, 28 Nov 2023 17:49:47 +1300 Subject: [PATCH 13/48] rewrote login_request.py for new Packet structure --- src/login_request.py | 62 +++++++++++++++++++++++++++----------------- 1 file changed, 38 insertions(+), 24 deletions(-) diff --git a/src/login_request.py b/src/login_request.py index 2e44271..6a6d722 100644 --- a/src/login_request.py +++ b/src/login_request.py @@ -2,22 +2,25 @@ Login request module. Defines the LoginRequest class which is used to encode and decode login request packets. """ +from typing import Any import logging +import struct -from .message_type import MessageType -from .record import Record +from src.message_type import MessageType +from src.packets.packet import Packet -class LoginRequest(Record): +class LoginRequest(Packet, struct_format="!HBB"): """ The LoginRequest class is used to encode and decode login request packets. """ + def __init__(self, user_name: str): """ Create a login request packet """ self.user_name = user_name - self.record = bytearray(4) + self.packet = bytes() def to_bytes(self) -> bytes: """ @@ -25,31 +28,42 @@ def to_bytes(self) -> bytes: """ logging.info("Creating log-in request as %s", self.user_name) - self.record[0] = Record.MAGIC_NUMBER >> 8 - self.record[1] = Record.MAGIC_NUMBER & 0xFF - self.record[2] = MessageType.LOGIN.value - self.record[3] = len(self.user_name.encode()) - self.record.extend(self.user_name.encode()) + self.packet = struct.pack( + self.struct_format, + Packet.MAGIC_NUMBER, + MessageType.LOGIN.value, + len(self.user_name.encode()), + ) + + self.packet += self.user_name.encode() - return bytes(self.record) + return self.packet @classmethod - def from_record(cls, record: bytes) -> "LoginRequest": + def decode_packet(cls, packet: bytes) -> tuple[Any, ...]: """ - Create a login request packet from a byte array + Decode the login request packet into its individual components + :param packet: The packet to be decoded + :return: A tuple containing the username """ - message_type = Record.validate_record(record) - if message_type != MessageType.LOGIN: - raise ValueError("Received log-in request with invalid ID") + header_fields, payload = cls.split_packet(packet) + magic_number, message_type, user_name_length = header_fields - user_name_length = record[3] - user_name = record[4:4 + user_name_length].decode() + if magic_number != Packet.MAGIC_NUMBER: + raise ValueError("Invalid magic number when decoding message response") - return cls(user_name) + try: + message_type = MessageType(message_type) + except ValueError as error: + raise ValueError( + "Invalid message type when decoding message response" + ) from error + if message_type != MessageType.LOGIN: + raise ValueError( + f"Message type {message_type} found when decoding message response, " + "expected LOGIN" + ) - def decode(self) -> tuple: - """ - Decode the individual fields of the login request packet - :return: A tuple containing the username - """ - return self.user_name, + user_name = payload.decode() + + return (user_name,) From 2833f8defec905d7f806a7a2b1fdd28ee61bcd73 Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Tue, 28 Nov 2023 17:52:28 +1300 Subject: [PATCH 14/48] discarding unused variable --- src/login_response.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/login_response.py b/src/login_response.py index 30013c7..c79cee1 100644 --- a/src/login_response.py +++ b/src/login_response.py @@ -43,7 +43,7 @@ def to_bytes(self) -> bytes: @classmethod def decode_packet(cls, packet: bytes) -> tuple[bool, RsaEncrypter]: - header_fields, payload = cls.split_packet(packet) + header_fields, _ = cls.split_packet(packet) magic_number, message_type, is_registered, product, exponent = header_fields if magic_number != Packet.MAGIC_NUMBER: From b7c459faa46c53c5676954a2532abd5a7aa4a0ea Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Tue, 29 Oct 2024 21:06:53 +1300 Subject: [PATCH 15/48] removed redundant mypy.ini file --- mypy.ini | 2 -- 1 file changed, 2 deletions(-) delete mode 100644 mypy.ini diff --git a/mypy.ini b/mypy.ini deleted file mode 100644 index 010ed68..0000000 --- a/mypy.ini +++ /dev/null @@ -1,2 +0,0 @@ -[mypy] -python_version = 3.12 From 9f466c83472675bb921a7bd1a5bb2dbf86fdf9cf Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Thu, 31 Oct 2024 18:30:43 +1300 Subject: [PATCH 16/48] Set Ruff to use all rules (ignoring a few) and fixed all issues --- client/__main__.py | 1 + client/client.py | 46 +++++++++++++++++++-------------- logging_config.py | 15 ++++++----- pyproject.toml | 28 ++++++++++++-------- server/__main__.py | 2 +- server/server.py | 42 ++++++++++++++++++------------ src/command_line_application.py | 33 ++++++++++++----------- src/login_request.py | 37 ++++++++++++-------------- src/login_response.py | 35 +++++++++++++++---------- src/message_type.py | 5 ++-- src/packets/message.py | 13 +++++----- src/packets/message_request.py | 17 ++++++------ src/packets/message_response.py | 20 +++++++------- src/packets/packet.py | 10 ++++--- 14 files changed, 169 insertions(+), 135 deletions(-) diff --git a/client/__main__.py b/client/__main__.py index 3c3e005..56ce0f9 100644 --- a/client/__main__.py +++ b/client/__main__.py @@ -6,6 +6,7 @@ import sys from logging_config import configure_logging + from .client import Client diff --git a/client/client.py b/client/client.py index 1c81a45..faae33b 100644 --- a/client/client.py +++ b/client/client.py @@ -1,17 +1,15 @@ """The client module contains the Client class.""" -from collections import OrderedDict -from typing import Optional import logging import socket +from collections import OrderedDict from src.command_line_application import CommandLineApplication -from src.packets.message_response import MessageResponse -from src.packets.message_request import MessageRequest from src.message_type import MessageType +from src.packets.message_request import MessageRequest +from src.packets.message_response import MessageResponse from src.port_number import PortNumber - logger = logging.getLogger(__name__) @@ -20,7 +18,7 @@ class Client(CommandLineApplication): MAX_USERNAME_LENGTH = 255 - def __init__(self, arguments: list[str]): + def __init__(self, arguments: list[str]) -> None: """Initialise the client with specified arguments. :param arguments: A list containing the host name, port number @@ -32,7 +30,7 @@ def __init__(self, arguments: list[str]): port_number=PortNumber, user_name=self.parse_username, message_type=MessageType.from_str, - ) + ), ) # pylint thinks that self.parse_arguments is only capable @@ -67,11 +65,11 @@ def parse_hostname(host_name: str) -> str: try: socket.getaddrinfo(host_name, 1024) except socket.gaierror as error: - logger.error(error) - raise ValueError( - "Invalid host name, must be an IP address, domain name," - ' or "localhost"' - ) from error + message = ( + 'Invalid host name, must be an IP address, domain name, or "localhost"' + ) + logger.exception(message) + raise ValueError(message) from error return host_name @@ -93,7 +91,7 @@ def parse_username(user_name: str) -> str: return user_name - def send_message_request(self, request: MessageRequest) -> Optional[bytes]: + def send_message_request(self, request: MessageRequest) -> bytes | None: """Send a message request record to the server. :param request: The message request to be sent. @@ -110,16 +108,19 @@ def send_message_request(self, request: MessageRequest) -> Optional[bytes]: response = connection_socket.recv(4096) except ConnectionRefusedError as error: - logger.error(error) + logger.exception("Connection refused, likely due to invalid port number") print("Connection refused, likely due to invalid port number") raise SystemExit from error - except socket.timeout as error: - logger.error(error) - print("Connection timed out, likely due to invalid host name") + except TimeoutError as error: + message = "Connection timed out, likely due to invalid host name" + logger.exception(message) + print(message) raise SystemExit from error logger.info( - "%s record sent as %s", self.message_type.name.lower(), self.user_name + "%s record sent as %s", + self.message_type.name.lower(), + self.user_name, ) print(f"{self.message_type.name.lower()} record sent as {self.user_name}") @@ -150,11 +151,16 @@ def run(self) -> None: self.receiver_name = input("Enter the name of the receiver: ") self.message = input("Enter the message to be sent: ") logger.info( - 'User specified message to %s: "%s"', self.receiver_name, self.message + 'User specified message to %s: "%s"', + self.receiver_name, + self.message, ) request = MessageRequest( - self.message_type, self.user_name, self.receiver_name, self.message + self.message_type, + self.user_name, + self.receiver_name, + self.message, ) response = self.send_message_request(request) if self.message_type == MessageType.READ and response: diff --git a/logging_config.py b/logging_config.py index e40b03d..71bf5e0 100644 --- a/logging_config.py +++ b/logging_config.py @@ -4,10 +4,10 @@ for all module loggers. """ -from datetime import datetime +import datetime import logging +import pathlib import sys -import os # pylint: disable=too-few-public-methods @@ -32,21 +32,24 @@ def format(self, record: logging.LogRecord) -> str: def configure_logging(package_name: str) -> None: """Configure logging for the project.""" file_formatter = PathnameFormatter( - "%(asctime)s - %(levelname)-8s - %(pathname)-35s - %(message)s" + "%(asctime)s - %(levelname)-8s - %(pathname)-35s - %(message)s", ) file_formatter.datefmt = "%d-%m-%y - %H:%M:%S.%s" - file_name = datetime.now().strftime("%d-%m-%y %H:%M:%S") + # ruff: noqa: DTZ005 + file_name = datetime.datetime.now().strftime("%d-%m-%y %H:%M:%S") - os.makedirs(os.path.dirname(f"logs/{package_name}/"), exist_ok=True) + (pathlib.Path("logs") / package_name).parent.mkdir(parents=True) file_handler = logging.FileHandler(f"logs/{package_name}/{file_name}.log") file_handler.setLevel(logging.DEBUG) file_handler.setFormatter(file_formatter) console_formatter = PathnameFormatter( - "%(levelname)-8s - %(pathname)-35s - %(message)s" + "%(levelname)-8s - %(pathname)-35s - %(message)s", ) + # ruff: noqa: ERA001 + # To enable printing of logs to stdout, enable below code # stdout_handler = logging.StreamHandler(sys.stdout) # stdout_handler.setLevel(logging.DEBUG) # stdout_handler.addFilter(StdoutHandlerFilter()) diff --git a/pyproject.toml b/pyproject.toml index a852466..e471231 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,16 +1,15 @@ [project] name = "socket-programming" version = "1.0.0" -authors = [ - { name="Harry Parkes", email="harrydparkes@proton.me" }, -] +authors = [{ name = "Harry Parkes", email = "harrydparkes@proton.me" }] description = "Sever-Client program pair capable of delivering messages between clients" readme = "README.md" -requires-python = ">=3.9" +requires-python = ">=3.10" classifiers = [ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: GNU AGPL", - "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "License :: OSI Approved :: GNU Affero General Public License v3", + "Operating System :: OS Independent", + "Intended Audience :: Education", ] [project.urls] @@ -21,7 +20,14 @@ classifiers = [ exclude = ["docs/conf.py"] [tool.ruff.lint] -# Enable Pyflakes (`F`), a subset of the pycodestyle (`E`) codes, and pydocstyle (`D`) -select = ["E4", "E7", "E9", "F", "D", "PL"] -ignore = ["D203", "D213"] - +select = ["ALL"] +ignore = [ + "D203", + "D213", + "EM101", + "TRY003", + "T201", + "FBT001", + "ANN101", + "ANN102", +] diff --git a/server/__main__.py b/server/__main__.py index d50d396..570e181 100644 --- a/server/__main__.py +++ b/server/__main__.py @@ -7,8 +7,8 @@ import sys from logging_config import configure_logging -from .server import Server +from .server import Server logger = logging.getLogger(__name__) diff --git a/server/server.py b/server/server.py index d1f34d0..3eca439 100644 --- a/server/server.py +++ b/server/server.py @@ -1,13 +1,13 @@ """Home to the ``Server`` class.""" -from collections import OrderedDict import logging import socket +from collections import OrderedDict from src.command_line_application import CommandLineApplication -from src.packets.message_response import MessageResponse -from src.packets.message_request import MessageRequest from src.message_type import MessageType +from src.packets.message_request import MessageRequest +from src.packets.message_response import MessageResponse from src.port_number import PortNumber logger = logging.getLogger(__name__) @@ -19,7 +19,7 @@ class Server(CommandLineApplication): The server can be run with ``python3 -m server ``. """ - def __init__(self, arguments: list[str]): + def __init__(self, arguments: list[str]) -> None: """Initialise the server with a specified port number. :param arguments: The program arguments from the command line. @@ -46,7 +46,9 @@ def run(self) -> None: # A maximum, of five unprocessed connections are allowed welcoming_socket.listen(5) logger.info( - "Server started on %s port %s", self.hostname, self.port_number + "Server started on %s port %s", + self.hostname, + self.port_number, ) print(f"starting up on {self.hostname} port {self.port_number}") @@ -54,12 +56,15 @@ def run(self) -> None: self.run_server(welcoming_socket) except OSError as error: - logger.error(error) - print("Error binding socket on provided port") + message = "Error binding socket on provided port" + logger.exception(message) + print(message) raise SystemExit from error def process_read_request( - self, connection_socket: socket.socket, sender_name: str + self, + connection_socket: socket.socket, + sender_name: str, ) -> None: """Respond to read requests. @@ -79,7 +84,10 @@ def process_read_request( print(f"{response.num_messages} message(s) delivered to {sender_name}") def process_create_request( - self, sender_name: str, receiver_name: str, message: bytes + self, + sender_name: str, + receiver_name: str, + message: bytes, ) -> None: """Process `create` requests. @@ -99,7 +107,7 @@ def process_create_request( ) print( f"{sender_name} sends the message " - f'"{message.decode()}" to {receiver_name}' + f'"{message.decode()}" to {receiver_name}', ) def run_server(self, welcoming_socket: socket.socket) -> None: @@ -125,9 +133,11 @@ def run_server(self, welcoming_socket: socket.socket) -> None: elif message_type == MessageType.CREATE: self.process_create_request(sender_name, receiver_name, message) - except socket.timeout as error: - logger.error(error) - print("Timed out while waiting for message request") - except ValueError as error: - logger.error(error) - print("Message request discarded") + except TimeoutError: + error_message = "Timed out while waiting for message request" + logger.exception(error_message) + print(error_message) + except ValueError: + error_message = "Message request discarded" + logger.exception(error_message) + print(error_message) diff --git a/src/command_line_application.py b/src/command_line_application.py index c9eadcb..f6e25f9 100644 --- a/src/command_line_application.py +++ b/src/command_line_application.py @@ -1,10 +1,10 @@ """Home to the ``CommandLineApplication`` abstract class.""" -from collections import OrderedDict -from typing import Callable, Any -import logging import abc - +import logging +from collections import OrderedDict +from collections.abc import Callable +from typing import Any logger = logging.getLogger(__name__) @@ -17,7 +17,7 @@ class CommandLineApplication(metaclass=abc.ABCMeta): """ @abc.abstractmethod - def __init__(self, parameters: OrderedDict[str, Callable[[str], Any]]): + def __init__(self, parameters: OrderedDict[str, Callable[[str], Any]]) -> None: """Initialise the command line application. :param parameters: A dictionary containing the parameters for @@ -38,18 +38,21 @@ def parse_arguments(self, arguments: list[str]) -> list[Any]: :param arguments: The command line arguments. """ - parsed_arguments = [] + if len(arguments) != len(self.parameters): + print(self.usage_prompt) + print(f"Invalid number of arguments, must be {len(self.parameters)}") + try: - if len(arguments) != len(self.parameters): - raise ValueError( - f"Invalid number of arguments, must be {len(self.parameters)}" + parsed_arguments = [ + parser(argument) + for argument, parser in zip( + arguments, + self.parameters.values(), + strict=True, ) - - for argument, parser in zip(arguments, self.parameters.values()): - parsed_argument = parser(argument) - parsed_arguments.append(parsed_argument) - except (TypeError, ValueError) as error: - logger.error(error) + ] + except TypeError as error: + logger.exception("Incorrect arguments") print(self.usage_prompt) print(error) raise SystemExit from error diff --git a/src/login_request.py b/src/login_request.py index 6a6d722..5a06ff4 100644 --- a/src/login_request.py +++ b/src/login_request.py @@ -1,31 +1,26 @@ -""" -Login request module. +"""Login request module. + Defines the LoginRequest class which is used to encode and decode login request packets. """ -from typing import Any + import logging import struct +from typing import Any from src.message_type import MessageType from src.packets.packet import Packet class LoginRequest(Packet, struct_format="!HBB"): - """ - The LoginRequest class is used to encode and decode login request packets. - """ + """The LoginRequest class is used to encode and decode login request packets.""" - def __init__(self, user_name: str): - """ - Create a login request packet - """ + def __init__(self, user_name: str) -> None: + """Create a login request packet.""" self.user_name = user_name - self.packet = bytes() + self.packet = b"" def to_bytes(self) -> bytes: - """ - Encode the login request packet into a byte array - """ + """Encode the login request packet into a byte array.""" logging.info("Creating log-in request as %s", self.user_name) self.packet = struct.pack( @@ -41,8 +36,8 @@ def to_bytes(self) -> bytes: @classmethod def decode_packet(cls, packet: bytes) -> tuple[Any, ...]: - """ - Decode the login request packet into its individual components + """Decode the login request packet into its individual components. + :param packet: The packet to be decoded :return: A tuple containing the username """ @@ -56,14 +51,16 @@ def decode_packet(cls, packet: bytes) -> tuple[Any, ...]: message_type = MessageType(message_type) except ValueError as error: raise ValueError( - "Invalid message type when decoding message response" + "Invalid message type when decoding message response", ) from error if message_type != MessageType.LOGIN: - raise ValueError( - f"Message type {message_type} found when decoding message response, " - "expected LOGIN" + message = ( + f"Message type {message_type} found when decoding" + " message response, expected LOGIN" ) + raise ValueError(message) + user_name = payload.decode() return (user_name,) diff --git a/src/login_response.py b/src/login_response.py index c79cee1..c59e43f 100644 --- a/src/login_response.py +++ b/src/login_response.py @@ -1,7 +1,8 @@ -""" -Login response module +"""Login response module. + Defines class for encoding and decoding login response packets. """ + import logging import struct @@ -12,22 +13,21 @@ class LoginResponse(Packet, struct_format="!HB?QQ"): - """ - The LoginResponse class is used to encode and decode login response packets. - """ + """The LoginResponse class is used to encode and decode login response packets.""" - def __init__(self, is_registered: bool, encryption_key: RsaEncrypter): - """ - Create a new login response packet. + def __init__(self, is_registered: bool, encryption_key: RsaEncrypter) -> None: + """Create a new login response packet. + + :param is_registered: ``True`` if the requesting user was registered. + :param encryption_key: An RSA public key for encrypting messages. """ + """Create a new login response packet.""" self.encryption_key = encryption_key self.is_registered = is_registered - self.packet = bytes() + self.packet = b"" def to_bytes(self) -> bytes: - """ - Encode a login response packet into a byte array. - """ + """Encode a login response packet into a byte array.""" logging.info("Creating login response") self.packet = struct.pack( @@ -43,6 +43,12 @@ def to_bytes(self) -> bytes: @classmethod def decode_packet(cls, packet: bytes) -> tuple[bool, RsaEncrypter]: + """Decode a message response packet into its individual components. + + :param packet: The packet to be decoded. + :raises ValueError: If the packet is invalid. + :return: A tuple containing a boolean indicating if the + """ header_fields, _ = cls.split_packet(packet) magic_number, message_type, is_registered, product, exponent = header_fields @@ -53,13 +59,14 @@ def decode_packet(cls, packet: bytes) -> tuple[bool, RsaEncrypter]: message_type = MessageType(message_type) except ValueError as error: raise ValueError( - "Invalid message type when decoding message response" + "Invalid message type when decoding message response", ) from error if message_type != MessageType.LOGIN: - raise ValueError( + message = ( f"Message type {message_type} found when decoding message response, " "expected LOGIN" ) + raise ValueError(message) encryption_key = RsaEncrypter(product, exponent) diff --git a/src/message_type.py b/src/message_type.py index 8e9d9cf..b193a2c 100644 --- a/src/message_type.py +++ b/src/message_type.py @@ -21,6 +21,5 @@ def from_str(string: str) -> "MessageType": try: return MessageType[string.upper()] except KeyError as error: - raise ValueError( - f'Invalid message type: {string}, must be "read" or "create"' - ) from error + message = f'Invalid message type: {string}, must be "read" or "create"' + raise ValueError(message) from error diff --git a/src/packets/message.py b/src/packets/message.py index 9d87c80..83bffbd 100644 --- a/src/packets/message.py +++ b/src/packets/message.py @@ -1,7 +1,4 @@ -"""Home to the ``Message`` class. - -which is used to encode and decode -""" +"""Home to the ``Message`` class.""" import struct @@ -15,7 +12,7 @@ class Message(Packet, struct_format="!BH"): a MessageResponse packet. """ - def __init__(self, sender_name: str, message: bytes): + def __init__(self, sender_name: str, message: bytes) -> None: """Create the Message which can be encoded into a packet. :param sender_name: The name of the user sending this message. @@ -23,7 +20,7 @@ def __init__(self, sender_name: str, message: bytes): """ self.sender_name = sender_name self.message = message - self.packet = bytes() + self.packet = b"" def to_bytes(self) -> bytes: """Encode the message into bytes for transmission through a socket. @@ -31,7 +28,9 @@ def to_bytes(self) -> bytes: :return: A ``bytes`` object encoding the message. """ self.packet += struct.pack( - self.struct_format, len(self.sender_name.encode()), len(self.message) + self.struct_format, + len(self.sender_name.encode()), + len(self.message), ) self.packet += self.sender_name.encode() self.packet += self.message diff --git a/src/packets/message_request.py b/src/packets/message_request.py index 9e63021..68d80d4 100644 --- a/src/packets/message_request.py +++ b/src/packets/message_request.py @@ -4,8 +4,8 @@ import struct from src.message_type import MessageType -from .packet import Packet +from .packet import Packet logger = logging.getLogger(__name__) @@ -14,7 +14,7 @@ class MessageRequest(Packet, struct_format="!HBBBH"): """Encoding and decoding of message request packets. Usage: - message_request = MessageRequest(MessageType.CREATE, "sender_name", "receiver_name", "Hi") + message_request = MessageRequest(MessageType.CREATE, "sender", "receiver", "Hi") record = message_request.to_bytes() message_request = MessageRequest.from_record(record) @@ -27,7 +27,7 @@ def __init__( user_name: str, receiver_name: str, message: str, - ): + ) -> None: """Encode a message request packet. :param message_type: The type of the request (READ or CREATE) @@ -39,7 +39,7 @@ def __init__( self.user_name = user_name self.receiver_name = receiver_name self.message = message - self.packet = bytes() + self.packet = b"" def to_bytes(self) -> bytes: """Return the message request packet. @@ -71,6 +71,7 @@ def to_bytes(self) -> bytes: return self.packet + # ruff: noqa: C901 @classmethod def decode_packet(cls, packet: bytes) -> tuple[MessageType, str, str, bytes]: """Decode a message request packet. @@ -99,13 +100,13 @@ def decode_packet(cls, packet: bytes) -> tuple[MessageType, str, str, bytes]: if user_name_size < 1: raise ValueError( - "Received message request with insufficient user name length" + "Received message request with insufficient user name length", ) if message_type == MessageType.READ: if receiver_name_size != 0: raise ValueError( - "Received read request with non-zero receiver name length" + "Received read request with non-zero receiver name length", ) if message_size != 0: raise ValueError("Received read request with non-zero message length") @@ -113,11 +114,11 @@ def decode_packet(cls, packet: bytes) -> tuple[MessageType, str, str, bytes]: elif message_type == MessageType.CREATE: if receiver_name_size < 1: raise ValueError( - "Received create request with insufficient receiver name length" + "Received create request with insufficient receiver name length", ) if message_size < 1: raise ValueError( - "Received create request with insufficient message length" + "Received create request with insufficient message length", ) user_name = payload[:user_name_size].decode() diff --git a/src/packets/message_response.py b/src/packets/message_response.py index 8a3b1a0..0bb8265 100644 --- a/src/packets/message_response.py +++ b/src/packets/message_response.py @@ -3,11 +3,10 @@ import logging import struct -from src.packets.message import Message from src.message_type import MessageType +from src.packets.message import Message from src.packets.packet import Packet - logger = logging.getLogger(__name__) @@ -16,8 +15,8 @@ class MessageResponse(Packet, struct_format="!HBB?"): MAX_MESSAGE_LENGTH = 255 - def __init__(self, messages: list[tuple[str, bytes]]): - """Encode a structure containing all (up to 255) messages for the specified sender. + def __init__(self, messages: list[tuple[str, bytes]]) -> None: + """Encode a structure containing up to 255 messages for a specific sender. :param messages: A list of all the messages to be put in the structure. """ @@ -25,7 +24,7 @@ def __init__(self, messages: list[tuple[str, bytes]]): self.more_messages = len(messages) > MessageResponse.MAX_MESSAGE_LENGTH self.messages = messages[: self.num_messages] - self.packet = bytes() + self.packet = b"" def to_bytes(self) -> bytes: """Return the message response packet. @@ -66,19 +65,20 @@ def decode_packet(cls, packet: bytes) -> tuple[list[tuple[str, str]], bool]: message_type = MessageType(message_type) except ValueError as error: raise ValueError( - "Invalid message type when decoding message response" + "Invalid message type when decoding message response", ) from error if message_type != MessageType.RESPONSE: - raise ValueError( - f"Message type {message_type} found when decoding message response, " - "expected RESPONSE" + message = ( + f"Message type {message_type} found when decoding message" + " response, expected RESPONSE" ) + raise ValueError(message) messages: list[tuple[str, str]] = [] remaining_messages = payload for _ in range(num_messages): sender_name, message, remaining_messages = Message.decode_packet( - remaining_messages + remaining_messages, ) logger.info('Decoded message from %s: "%s"', sender_name, message) messages.append((sender_name, message)) diff --git a/src/packets/packet.py b/src/packets/packet.py index 9e1c247..b3a1909 100644 --- a/src/packets/packet.py +++ b/src/packets/packet.py @@ -1,8 +1,8 @@ """Home to the ``Packet`` abstract class.""" -from typing import Any -import struct import abc +import struct +from typing import Any class Packet(metaclass=abc.ABCMeta): @@ -23,7 +23,7 @@ class MyPacket(Packet, struct_format="!HBBH"): struct_format: str @abc.abstractmethod - def __init__(self, *args: tuple[Any, ...]): + def __init__(self, *args: tuple[Any, ...]) -> None: """Initialise the packet. :param args: All arguments needed to initialise the packet. @@ -66,7 +66,9 @@ def split_packet(cls, packet: bytes) -> tuple[tuple[Any, ...], bytes]: @classmethod def __init_subclass__( - cls, struct_format: str | None = None, **kwargs: tuple[Any, ...] + cls, + struct_format: str | None = None, + **kwargs: tuple[Any, ...], ) -> None: """Ensure ``struct_format`` attribute is present. From 31342e0c8372f9f5eb4ea9c221237f92d6766afa Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Sat, 2 Nov 2024 15:14:37 +1300 Subject: [PATCH 17/48] Fixed errors from new Ruff rules in tests directory. Also removed brute_test.py as it was not very assertive, A replacement test for the more_messages flag will replace it. With actual assertions, and without subprocess --- pyproject.toml | 1 + tests/__init__.py | 1 + tests/applications/test_client.py | 15 +-- .../test_command_line_application.py | 18 ++-- tests/applications/test_server.py | 10 +- tests/brute_test.py | 36 ------- tests/packets/test_message.py | 8 +- tests/packets/test_message_request.py | 101 ++++++++++++------ tests/packets/test_message_response.py | 16 +-- 9 files changed, 105 insertions(+), 101 deletions(-) create mode 100644 tests/__init__.py delete mode 100644 tests/brute_test.py diff --git a/pyproject.toml b/pyproject.toml index e471231..c6ca6f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,4 +30,5 @@ ignore = [ "FBT001", "ANN101", "ANN102", + "PT", ] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..b069b30 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Uncategorised test suites.""" diff --git a/tests/applications/test_client.py b/tests/applications/test_client.py index adaca98..fd1dc39 100644 --- a/tests/applications/test_client.py +++ b/tests/applications/test_client.py @@ -1,11 +1,11 @@ """Client class test suite.""" -import unittest import socket +import unittest -from src.packets.message_request import MessageRequest -from src.message_type import MessageType from client import Client +from src.message_type import MessageType +from src.packets.message_request import MessageRequest class TestClient(unittest.TestCase): @@ -19,7 +19,7 @@ def test_construction(self) -> None: Client([TestClient.hostname, str(TestClient.port_number), "Alice", "create"]) def test_construction_raise_error(self) -> None: - """Tests that a Client object cannot be constructed given an invalid arguments.""" + """Tests that a Client object cannot be constructed given invalid arguments.""" self.assertRaises( SystemExit, Client, @@ -29,7 +29,7 @@ def test_construction_raise_error(self) -> None: def test_send_message_request(self) -> None: """Tests that a Client object can send a message request.""" client = Client( - [TestClient.hostname, str(TestClient.port_number), "Alice", "create"] + [TestClient.hostname, str(TestClient.port_number), "Alice", "create"], ) user_name = "Alice" receiver_name = "John" @@ -42,7 +42,7 @@ def test_send_message_request(self) -> None: # Send message request from the client client.send_message_request( - MessageRequest(MessageType.CREATE, user_name, receiver_name, message) + MessageRequest(MessageType.CREATE, user_name, receiver_name, message), ) # Accept connection from the client @@ -56,5 +56,6 @@ def test_send_message_request(self) -> None: # Check that the packet is correct request = MessageRequest.decode_packet(packet) self.assertEqual( - (MessageType.CREATE, user_name, receiver_name, message.encode()), request + (MessageType.CREATE, user_name, receiver_name, message.encode()), + request, ) diff --git a/tests/applications/test_command_line_application.py b/tests/applications/test_command_line_application.py index abba5f9..467ad66 100644 --- a/tests/applications/test_command_line_application.py +++ b/tests/applications/test_command_line_application.py @@ -10,25 +10,25 @@ class TestClientParseArguments(unittest.TestCase): """Test suite for Client class.""" def test_fail_subclass(self) -> None: - """Test that we cannot subclass from CommandLineApplication without struct format.""" - try: + """Test Packet __init_subclass__ function. + + Ensure that we cannot subclass from CommandLineApplication + without specifying a struct format. + """ + with self.assertRaises(ValueError): class NoStructFormat(Packet): """No ``struct_formatt`` passed so class will not be created.""" - def __init__(self, *args: tuple[Any, ...]): + def __init__(self, *args: tuple[Any, ...]) -> None: pass def to_bytes(self) -> bytes: - return bytes() + return b"" + # ruff: noqa: ARG003 @classmethod def decode_packet(cls, packet: bytes) -> tuple[Any, ...]: return () NoStructFormat() - - except ValueError: - pass - else: - self.fail("Did not raise, was able to subclass from CommandLineApplication") diff --git a/tests/applications/test_server.py b/tests/applications/test_server.py index 08d637a..859c2b8 100644 --- a/tests/applications/test_server.py +++ b/tests/applications/test_server.py @@ -1,10 +1,10 @@ """Server class test suite.""" -import unittest import socket +import unittest -from src.packets.message_response import MessageResponse from server import Server +from src.packets.message_response import MessageResponse class TestServer(unittest.TestCase): @@ -18,9 +18,11 @@ def test_construction(self) -> None: Server([str(TestServer.port_number)]) def test_construction_raise_error(self) -> None: - """.Tests that a Server object cannot be constructed given an invalid arguments.""" + """Tests that a Server object cannot be constructed given invalid arguments.""" self.assertRaises( - SystemExit, Server, [str(TestServer.port_number), "Extra argument"] + SystemExit, + Server, + [str(TestServer.port_number), "Extra argument"], ) def test_process_read_request(self) -> None: diff --git a/tests/brute_test.py b/tests/brute_test.py deleted file mode 100644 index f7758dc..0000000 --- a/tests/brute_test.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Run many many clients to test message limit of 255. - -This file is used to test the server with a large number of requests. -It is not a unit test, but rather a script that executes a large number of -client requests to test the server's ability to correctly set the `more_messages` flag. -""" - -import subprocess as sp - -with open("tests/resources/names.txt", encoding="utf8") as names_file: - names = names_file.readlines() - -for name in names: - client_program = ["python3", "client.py", "localhost", "1024", name, "create"] - - with sp.Popen( - client_program, text=True, stdin=sp.PIPE, stdout=sp.PIPE, stderr=sp.PIPE - ) as process: - try: - output, errors = process.communicate(input="John\nHello") - except sp.TimeoutExpired: - output, errors = process.communicate() - print(output) - print(errors) - -with sp.Popen( - ["python3", "client.py", "localhost", "1024", "John", "read"], - text=True, - stdin=sp.PIPE, - stdout=sp.PIPE, - stderr=sp.PIPE, -) as process: - output, errors = process.communicate() - -print(output) -print(errors) diff --git a/tests/packets/test_message.py b/tests/packets/test_message.py index d73df7d..3ffe1a1 100644 --- a/tests/packets/test_message.py +++ b/tests/packets/test_message.py @@ -11,7 +11,7 @@ class TestMessage(unittest.TestCase): def test_sender_length_encoding(self) -> None: """Tests that the length of the sender's name is encoded correctly.""" sender_name = "John" - message_bytes = "Hello, World!".encode() + message_bytes = b"Hello, World!" packet = Message(sender_name, message_bytes).to_bytes() expected = len(sender_name.encode()) @@ -21,7 +21,7 @@ def test_sender_length_encoding(self) -> None: def test_message_length_encoding(self) -> None: """Tests that the length of the message is encoded correctly.""" sender_name = "Jack" - message_bytes = "Hello, World!".encode() + message_bytes = b"Hello, World!" packet = Message(sender_name, message_bytes).to_bytes() expected = len(message_bytes) @@ -31,7 +31,7 @@ def test_message_length_encoding(self) -> None: def test_sender_name_encoding(self) -> None: """Tests that the sender's name is encoded correctly.""" sender_name = "Jacob" - message_bytes = "Hello, World!".encode() + message_bytes = b"Hello, World!" packet = Message(sender_name, message_bytes).to_bytes() expected = sender_name @@ -41,7 +41,7 @@ def test_sender_name_encoding(self) -> None: def test_message_encoding(self) -> None: """Tests that the message is encoded correctly.""" sender_name = "James" - message_bytes = "Hello, World!".encode() + message_bytes = b"Hello, World!" packet = Message(sender_name, message_bytes).to_bytes() starting_index = 3 + len(sender_name.encode()) diff --git a/tests/packets/test_message_request.py b/tests/packets/test_message_request.py index 5b33fad..3239737 100644 --- a/tests/packets/test_message_request.py +++ b/tests/packets/test_message_request.py @@ -2,9 +2,9 @@ import unittest -from src.packets.packet import Packet from src.message_type import MessageType from src.packets.message_request import MessageRequest +from src.packets.packet import Packet class TestMessageRequestEncoding(unittest.TestCase): @@ -17,7 +17,10 @@ def test_magic_number_encoding(self) -> None: receiver_name = "Jonty" message = "Hello, World!" packet = MessageRequest( - message_type, user_name, receiver_name, message + message_type, + user_name, + receiver_name, + message, ).to_bytes() expected = Packet.MAGIC_NUMBER @@ -31,7 +34,10 @@ def test_message_type_encoding(self) -> None: receiver_name = "Jonty" message = "Hello, World!" packet = MessageRequest( - message_type, user_name, receiver_name, message + message_type, + user_name, + receiver_name, + message, ).to_bytes() expected = message_type.value @@ -45,7 +51,10 @@ def test_user_name_length_encoding(self) -> None: receiver_name = "Jarod" message = "Hello, World!" packet = MessageRequest( - message_type, user_name, receiver_name, message + message_type, + user_name, + receiver_name, + message, ).to_bytes() expected = len(user_name.encode()) @@ -59,7 +68,10 @@ def test_receiver_name_length_encoding(self) -> None: receiver_name = "Jake" message = "Hello, World!" packet = MessageRequest( - message_type, user_name, receiver_name, message + message_type, + user_name, + receiver_name, + message, ).to_bytes() expected = len(receiver_name.encode()) @@ -73,7 +85,10 @@ def test_message_length_encoding(self) -> None: receiver_name = "Jay" message = "Hello, World!" packet = MessageRequest( - message_type, user_name, receiver_name, message + message_type, + user_name, + receiver_name, + message, ).to_bytes() expected = len(message.encode()) @@ -87,7 +102,10 @@ def test_user_name_encoding(self) -> None: receiver_name = "Jay" message = "Hello, World!" packet = MessageRequest( - message_type, user_name, receiver_name, message + message_type, + user_name, + receiver_name, + message, ).to_bytes() expected = user_name @@ -101,7 +119,10 @@ def test_receiver_name_encoding(self) -> None: receiver_name = "Jesse" message = "Hello, World!" packet = MessageRequest( - message_type, user_name, receiver_name, message + message_type, + user_name, + receiver_name, + message, ).to_bytes() start_index = 7 + len(user_name.encode()) @@ -119,7 +140,10 @@ def test_message_encoding(self) -> None: receiver_name = "Jimmy" message = "Hello, World!" packet = MessageRequest( - message_type, user_name, receiver_name, message + message_type, + user_name, + receiver_name, + message, ).to_bytes() start_index = 7 + len(user_name.encode()) + len(receiver_name.encode()) @@ -139,7 +163,10 @@ def setUp(self) -> None: self.receiver_name = "Jonty" self.message = "Hello, World!" self.packet = MessageRequest( - self.message_type, self.user_name, self.receiver_name, self.message + self.message_type, + self.user_name, + self.receiver_name, + self.message, ).to_bytes() def test_message_type_decoding(self) -> None: @@ -172,29 +199,33 @@ def test_message_decoding(self) -> None: def test_incorrect_magic_number(self) -> None: """Tests that an exception is raised if the magic number is incorrect.""" - self.packet = bytearray(self.packet) - self.packet[0] = 0 + packet = bytearray(self.packet) + packet[0] = 0 + self.packet = bytes(packet) self.assertRaises(ValueError, MessageRequest.decode_packet, self.packet) def test_invalid_message_type(self) -> None: """Tests that an exception is raised if the message type is invalid.""" - self.packet = bytearray(self.packet) - self.packet[2] = 0 + packet = bytearray(self.packet) + packet[2] = 0 + self.packet = bytes(packet) self.assertRaises(ValueError, MessageRequest.decode_packet, self.packet) def test_response_message_type(self) -> None: """Tests that an exception is raised if the message type is RESPONSE.""" - self.packet = bytearray(self.packet) - self.packet[2] = 3 + packet = bytearray(self.packet) + packet[2] = 3 + self.packet = bytes(packet) self.assertRaises(ValueError, MessageRequest.decode_packet, self.packet) def test_insufficient_user_name_length(self) -> None: - """Tests that an exception is raised if the length of the user's name is zero.""" - self.packet = bytearray(self.packet) - self.packet[3] = 0 + """Tests that an exception is raised if the user's name has a length of zero.""" + packet = bytearray(self.packet) + packet[3] = 0 + self.packet = bytes(packet) self.assertRaises(ValueError, MessageRequest.decode_packet, self.packet) @@ -203,9 +234,10 @@ def test_non_zero_receiver_name_length_for_read(self) -> None: If the length of the receiver's name is non-zero for a read request. """ - self.packet = bytearray(self.packet) - self.packet[2] = MessageType.READ.value - self.packet[4] = 1 + packet = bytearray(self.packet) + packet[2] = MessageType.READ.value + packet[4] = 1 + self.packet = bytes(packet) self.assertRaises(ValueError, MessageRequest.decode_packet, self.packet) @@ -214,10 +246,11 @@ def test_non_zero_message_length_for_read(self) -> None: If the length of the message is non-zero for a read request. """ - self.packet = bytearray(self.packet) - self.packet[2] = MessageType.READ.value - self.packet[4] = 0 - self.packet[6] = 1 + packet = bytearray(self.packet) + packet[2] = MessageType.READ.value + packet[4] = 0 + packet[6] = 1 + self.packet = bytes(packet) self.assertRaises(ValueError, MessageRequest.decode_packet, self.packet) @@ -226,9 +259,10 @@ def test_insufficient_receiver_name_length_for_create(self) -> None: If the length of the receiver's name is zero for a create request. """ - self.packet = bytearray(self.packet) - self.packet[2] = MessageType.CREATE.value - self.packet[4] = 0 + packet = bytearray(self.packet) + packet[2] = MessageType.CREATE.value + packet[4] = 0 + self.packet = bytes(packet) self.assertRaises(ValueError, MessageRequest.decode_packet, self.packet) @@ -237,9 +271,10 @@ def test_insufficient_message_length_for_create(self) -> None: If the length of the message is zero for a create request. """ - self.packet = bytearray(self.packet) - self.packet[2] = MessageType.CREATE.value - self.packet[5] = 0 - self.packet[6] = 0 + packet = bytearray(self.packet) + packet[2] = MessageType.CREATE.value + packet[5] = 0 + packet[6] = 0 + self.packet = bytes(packet) self.assertRaises(ValueError, MessageRequest.decode_packet, self.packet) diff --git a/tests/packets/test_message_response.py b/tests/packets/test_message_response.py index 26204b8..ae28796 100644 --- a/tests/packets/test_message_response.py +++ b/tests/packets/test_message_response.py @@ -2,9 +2,9 @@ import unittest -from src.packets.packet import Packet from src.message_type import MessageType from src.packets.message_response import MessageResponse +from src.packets.packet import Packet class TestMessageResponseEncoding(unittest.TestCase): @@ -31,8 +31,8 @@ def test_message_type_encoding(self) -> None: def test_num_messages_encoding(self) -> None: """Tests that the number of messages is encoded correctly.""" messages = [ - ("Harry", "Hello John!".encode()), - ("John", "Hello Harry!".encode()), + ("Harry", b"Hello John!"), + ("John", b"Hello Harry!"), ] packet = MessageResponse(messages).to_bytes() @@ -56,8 +56,8 @@ class TestMessageResponseDecoding(unittest.TestCase): def test_messages_decoding(self) -> None: """Tests that the messages are decoded correctly.""" messages = [ - ("Harry", "Hello John!".encode()), - ("John", "Hello Harry!".encode()), + ("Harry", b"Hello John!"), + ("John", b"Hello Harry!"), ] packet = MessageResponse(messages).to_bytes() @@ -71,8 +71,8 @@ def test_messages_decoding(self) -> None: def test_more_messages_decoding_false(self) -> None: """Tests that the more messages flag is decoded correctly.""" messages = [ - ("Harry", "Hello John!".encode()), - ("John", "Hello Harry!".encode()), + ("Harry", b"Hello John!"), + ("John", b"Hello Harry!"), ] packet = MessageResponse(messages).to_bytes() @@ -82,7 +82,7 @@ def test_more_messages_decoding_false(self) -> None: def test_more_messages_decoding_true(self) -> None: """Tests that the more messages flag is decoded correctly.""" - messages = [("Harry", "Hello John!".encode())] * 256 + messages = [("Harry", b"Hello John!")] * 256 packet = MessageResponse(messages).to_bytes() expected = True From df0e124b9a441b34a261519243286a20b28f61d1 Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Sat, 2 Nov 2024 15:19:02 +1300 Subject: [PATCH 18/48] fixed errors from new Ruff rules in logging_config.py --- logging_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/logging_config.py b/logging_config.py index 71bf5e0..ec936a7 100644 --- a/logging_config.py +++ b/logging_config.py @@ -39,7 +39,7 @@ def configure_logging(package_name: str) -> None: # ruff: noqa: DTZ005 file_name = datetime.datetime.now().strftime("%d-%m-%y %H:%M:%S") - (pathlib.Path("logs") / package_name).parent.mkdir(parents=True) + (pathlib.Path("logs") / package_name).parent.mkdir(parents=True, exist_ok=True) file_handler = logging.FileHandler(f"logs/{package_name}/{file_name}.log") file_handler.setLevel(logging.DEBUG) file_handler.setFormatter(file_formatter) From 3672aa92166c5633a6ed0d646fdf8cf72ed2a804 Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Sat, 2 Nov 2024 15:39:37 +1300 Subject: [PATCH 19/48] updated unit test configuration in pre-commit config --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d049ca1..869898e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,7 +27,7 @@ repos: name: run unit tests pass_filenames: false entry: python3 -m unittest - args: ["discover", "--start-directory", "/home/harry/PycharmProjects/socket-programming/tests"] + args: ["discover", "."] language: system stages: [pre-push] types: [python] From 9a2cd22dfc805206a3213699c058b17e66b5dc8a Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Sat, 2 Nov 2024 16:09:46 +1300 Subject: [PATCH 20/48] Added ValueError to except block of CommaneLineApplication.parse_arguments New usage of strict=True in zip function throws ValueError when iterables are of different length. --- src/command_line_application.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/command_line_application.py b/src/command_line_application.py index f6e25f9..90aa20e 100644 --- a/src/command_line_application.py +++ b/src/command_line_application.py @@ -51,7 +51,7 @@ def parse_arguments(self, arguments: list[str]) -> list[Any]: strict=True, ) ] - except TypeError as error: + except (ValueError, TypeError) as error: logger.exception("Incorrect arguments") print(self.usage_prompt) print(error) From 5bda020c6ceb24bb8c967b59e4b78d9ff35d056b Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Sun, 3 Nov 2024 17:27:24 +1300 Subject: [PATCH 21/48] refactored client and server so they can be more easily tested --- client/client.py | 43 +++++++++++++++++++++++++++++++++++++------ server/server.py | 15 +++++++++++++-- 2 files changed, 50 insertions(+), 8 deletions(-) diff --git a/client/client.py b/client/client.py index faae33b..c91661d 100644 --- a/client/client.py +++ b/client/client.py @@ -53,6 +53,7 @@ def __init__(self, arguments: list[str]) -> None: self.receiver_name = "" self.message = "" + self.response: bytes | None = None @staticmethod def parse_hostname(host_name: str) -> str: @@ -145,11 +146,24 @@ def read_message_response(packet: bytes) -> None: logger.info("Server has more messages available for this user") print("More messages available, please send another request") - def run(self) -> None: + def run( + self, + *, + receiver_name: str | None = None, + message: str | None = None, + ) -> None: """Ask the user to input message and send request to server.""" if self.message_type == MessageType.CREATE: - self.receiver_name = input("Enter the name of the receiver: ") - self.message = input("Enter the message to be sent: ") + if receiver_name is None: + self.receiver_name = input("Enter the name of the receiver: ") + else: + self.receiver_name = receiver_name + + if message is None: + self.message = input("Enter the message to be sent: ") + else: + self.message = message + logger.info( 'User specified message to %s: "%s"', self.receiver_name, @@ -162,6 +176,23 @@ def run(self) -> None: self.receiver_name, self.message, ) - response = self.send_message_request(request) - if self.message_type == MessageType.READ and response: - self.read_message_response(response) + self.response = self.send_message_request(request) + if self.message_type == MessageType.READ and self.response: + self.read_message_response(self.response) + + @property + def result(self) -> bytes: + """Get the packet received from the server. + + This property must only be used after calling ``run()`` + otherwise no response will exist! + + :raises RuntimeError: When there was no response. + Will always occur if requested before call to ``run()``. + + :return: A bytes object of the server's response. + """ + if self.response is None: + raise RuntimeError("No response! Was result requested after call to run()?") + + return self.response diff --git a/server/server.py b/server/server.py index 3eca439..c94dff2 100644 --- a/server/server.py +++ b/server/server.py @@ -31,6 +31,7 @@ def __init__(self, arguments: list[str]) -> None: # pylint: disable=unbalanced-tuple-unpacking (self.port_number,) = self.parse_arguments(arguments) + self.running = True self.hostname = "localhost" self.messages: dict[str, list[tuple[str, bytes]]] = {} @@ -42,6 +43,7 @@ def run(self) -> None: try: # Create a TCP/IP socket with socket.socket() as welcoming_socket: + welcoming_socket.settimeout(1) welcoming_socket.bind((self.hostname, self.port_number)) # A maximum, of five unprocessed connections are allowed welcoming_socket.listen(5) @@ -52,7 +54,7 @@ def run(self) -> None: ) print(f"starting up on {self.hostname} port {self.port_number}") - while True: + while self.running: self.run_server(welcoming_socket) except OSError as error: @@ -115,7 +117,11 @@ def run_server(self, welcoming_socket: socket.socket) -> None: :param welcoming_socket: The welcoming socket to accept connections on """ - connection_socket, client_address = welcoming_socket.accept() + try: + connection_socket, client_address = welcoming_socket.accept() + except TimeoutError: + return + connection_socket.settimeout(1) logger.info("New client connection from %s", client_address) @@ -141,3 +147,8 @@ def run_server(self, welcoming_socket: socket.socket) -> None: error_message = "Message request discarded" logger.exception(error_message) print(error_message) + + def stop(self) -> None: + """Stop the server.""" + print("Stopping server.") + self.running = False From 54bb66e627adfaa3e94c8b35bc7fe8def3289bae Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Sun, 3 Nov 2024 17:29:24 +1300 Subject: [PATCH 22/48] created new integtration test to verify more messages functionality --- tests/integration/__init__.py | 1 + tests/integration/test_more_messages.py | 55 +++++++++++++++++++++++++ 2 files changed, 56 insertions(+) create mode 100644 tests/integration/__init__.py create mode 100644 tests/integration/test_more_messages.py diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..45be013 --- /dev/null +++ b/tests/integration/__init__.py @@ -0,0 +1 @@ +"""Integration test suites.""" diff --git a/tests/integration/test_more_messages.py b/tests/integration/test_more_messages.py new file mode 100644 index 0000000..da6441d --- /dev/null +++ b/tests/integration/test_more_messages.py @@ -0,0 +1,55 @@ +"""More messages functionality test suite.""" + +import pathlib +import sys +import threading +import unittest + +sys.path.insert(0, "../../") +import client +import server +from src.packets.message_response import MessageResponse + + +class TestMoreMessages(unittest.TestCase): + """Test suite for the more_messages functionality. + + When a client makes a read request to the server and has more than + 255 unread messages, the server should only send the first 255. The + server should set the ``more_messages`` flag in the read response so + that the client is informed they did not receive all of the + available messages. + """ + + HOST_NAME = "localhost" + PORT_NUMBER = "1024" + USER_NAME = "John" + + def test_more_messages(self) -> None: + """Tests that no more than 255 messages are sent in a read request.""" + names = pathlib.Path("tests/resources/names.txt").read_text().splitlines() + + server_object = server.Server([self.PORT_NUMBER]) + server_thread = threading.Thread(target=server_object.run) + server_thread.start() + + for name in names: + client.Client( + [self.HOST_NAME, self.PORT_NUMBER, name, "create"], + ).run(receiver_name=self.USER_NAME, message="Hello") + + final_client = client.Client( + [self.HOST_NAME, self.PORT_NUMBER, self.USER_NAME, "read"], + ) + final_client.run() + + server_object.stop() + server_thread.join() + + messages, more_messages = MessageResponse.decode_packet(final_client.result) + self.assertEqual(255, len(messages)) + self.assertTrue(more_messages) + + +if __name__ == "__main__": + unittest.main() From 6acc6aa656b1524cd8c0c4ada5038fdc0843b71d Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Mon, 11 Nov 2024 12:10:10 +1300 Subject: [PATCH 23/48] no longer using strict zip for argument parsing in command line application as it was redundant --- src/command_line_application.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/command_line_application.py b/src/command_line_application.py index 90aa20e..fbddabb 100644 --- a/src/command_line_application.py +++ b/src/command_line_application.py @@ -39,8 +39,10 @@ def parse_arguments(self, arguments: list[str]) -> list[Any]: :param arguments: The command line arguments. """ if len(arguments) != len(self.parameters): + message = f"Invalid number of arguments, must be {len(self.parameters)}" print(self.usage_prompt) - print(f"Invalid number of arguments, must be {len(self.parameters)}") + print(message) + raise SystemExit(message) try: parsed_arguments = [ @@ -48,10 +50,10 @@ def parse_arguments(self, arguments: list[str]) -> list[Any]: for argument, parser in zip( arguments, self.parameters.values(), - strict=True, + strict=False, ) ] - except (ValueError, TypeError) as error: + except TypeError as error: logger.exception("Incorrect arguments") print(self.usage_prompt) print(error) From f002d7458109b06591ae206392a4893568903df4 Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Mon, 11 Nov 2024 16:38:16 +1300 Subject: [PATCH 24/48] removed double docstring --- src/login_response.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/login_response.py b/src/login_response.py index c59e43f..d4cc98f 100644 --- a/src/login_response.py +++ b/src/login_response.py @@ -21,7 +21,6 @@ def __init__(self, is_registered: bool, encryption_key: RsaEncrypter) -> None: :param is_registered: ``True`` if the requesting user was registered. :param encryption_key: An RSA public key for encrypting messages. """ - """Create a new login response packet.""" self.encryption_key = encryption_key self.is_registered = is_registered self.packet = b"" From 329ff87a37f37817d9b83cfe9761a9c16015e48c Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Tue, 12 Nov 2024 16:23:28 +1300 Subject: [PATCH 25/48] Reengineered packets so that the server can get the type of any packet without any information about the packet. This resulted in packet classes being assigned a MessageType, and separating MessageRequest into ReadRequest and CreateRequest. --- client/client.py | 95 ++++++++----- server/server.py | 56 +++++--- src/message_type.py | 2 + src/packets/create_request.py | 102 ++++++++++++++ src/packets/message.py | 6 +- src/packets/message_request.py | 132 ------------------ src/packets/packet.py | 44 ++++-- src/packets/read_request.py | 70 ++++++++++ .../{message_response.py => read_response.py} | 30 +--- 9 files changed, 318 insertions(+), 219 deletions(-) create mode 100644 src/packets/create_request.py delete mode 100644 src/packets/message_request.py create mode 100644 src/packets/read_request.py rename src/packets/{message_response.py => read_response.py} (64%) diff --git a/client/client.py b/client/client.py index c91661d..96b867d 100644 --- a/client/client.py +++ b/client/client.py @@ -6,8 +6,10 @@ from src.command_line_application import CommandLineApplication from src.message_type import MessageType -from src.packets.message_request import MessageRequest -from src.packets.message_response import MessageResponse +from src.packets.create_request import CreateRequest +from src.packets.packet import Packet +from src.packets.read_request import ReadRequest +from src.packets.read_response import ReadResponse from src.port_number import PortNumber logger = logging.getLogger(__name__) @@ -92,21 +94,19 @@ def parse_username(user_name: str) -> str: return user_name - def send_message_request(self, request: MessageRequest) -> bytes | None: + def send_request(self, request: Packet) -> bytes: """Send a message request record to the server. :param request: The message request to be sent. :return: The server's response if applicable, otherwise ``None``. """ packet = request.to_bytes() - response = None try: with socket.socket() as connection_socket: connection_socket.settimeout(1) connection_socket.connect((self.host_name, self.port_number)) connection_socket.send(packet) - if self.message_type == MessageType.READ: - response = connection_socket.recv(4096) + response = connection_socket.recv(4096) except ConnectionRefusedError as error: logger.exception("Connection refused, likely due to invalid port number") @@ -133,7 +133,7 @@ def read_message_response(packet: bytes) -> None: :param packet: The message response from the server. """ - messages, more_messages = MessageResponse.decode_packet(packet) + messages, more_messages = ReadResponse.decode_packet(packet) for sender, message in messages: logger.info('Received %s\'s message "%s"', sender, message) @@ -146,39 +146,68 @@ def read_message_response(packet: bytes) -> None: logger.info("Server has more messages available for this user") print("More messages available, please send another request") - def run( + def send_read_request(self) -> None: + """Send a read request to the server.""" + request = ReadRequest(self.user_name) + self.response = self.send_request(request) + + self.read_message_response(self.response) + + def send_create_request( self, - *, - receiver_name: str | None = None, - message: str | None = None, + receiver_name: str | None, + message: str | None, ) -> None: - """Ask the user to input message and send request to server.""" - if self.message_type == MessageType.CREATE: - if receiver_name is None: - self.receiver_name = input("Enter the name of the receiver: ") - else: - self.receiver_name = receiver_name - - if message is None: - self.message = input("Enter the message to be sent: ") - else: - self.message = message - - logger.info( - 'User specified message to %s: "%s"', - self.receiver_name, - self.message, - ) + """Send a create request to the server. - request = MessageRequest( - self.message_type, + :param receiver_name: The name of the person to send the messag to. + :param message: The message to be sent. + """ + if receiver_name is None: + self.receiver_name = input("Enter the name of the receiver: ") + else: + self.receiver_name = receiver_name + + if message is None: + self.message = input("Enter the message to be sent: ") + else: + self.message = message + + logger.info( + 'User specified message to %s: "%s"', + self.receiver_name, + self.message, + ) + + request = CreateRequest( self.user_name, self.receiver_name, self.message, ) - self.response = self.send_message_request(request) - if self.message_type == MessageType.READ and self.response: - self.read_message_response(self.response) + self.send_request(request) + + def run( + self, + *, + receiver_name: str | None = None, + message: str | None = None, + ) -> None: + """Ask the user to input message and send request to server. + + :param receiver_name: The name of the user to send the message to. + Will request from ``stdin`` if not present. Defaults to ``None``. + :param message: The message to send. Will request from + ``stdin`` if not present. Defaults to ``None``. + """ + match self.message_type: + case MessageType.READ: + self.send_read_request() + + case MessageType.CREATE: + self.send_create_request(receiver_name, message) + + case _: + print("Oopsies, wrong message type!") @property def result(self) -> bytes: diff --git a/server/server.py b/server/server.py index c94dff2..c004689 100644 --- a/server/server.py +++ b/server/server.py @@ -6,8 +6,10 @@ from src.command_line_application import CommandLineApplication from src.message_type import MessageType -from src.packets.message_request import MessageRequest -from src.packets.message_response import MessageResponse +from src.packets.create_request import CreateRequest +from src.packets.packet import Packet +from src.packets.read_request import ReadRequest +from src.packets.read_response import ReadResponse from src.port_number import PortNumber logger = logging.getLogger(__name__) @@ -66,7 +68,7 @@ def run(self) -> None: def process_read_request( self, connection_socket: socket.socket, - sender_name: str, + packet: bytes, ) -> None: """Respond to read requests. @@ -74,7 +76,9 @@ def process_read_request( :param connection_socket: The connection socket to send the response on. :return: The response to the read request. """ - response = MessageResponse(self.messages.get(sender_name, [])) + (sender_name,) = ReadRequest.decode_packet(packet) + + response = ReadResponse(self.messages.get(sender_name, [])) record = response.to_bytes() connection_socket.send(record) del self.messages.get(sender_name, [])[: response.num_messages] @@ -87,16 +91,20 @@ def process_read_request( def process_create_request( self, - sender_name: str, - receiver_name: str, - message: bytes, + packet: bytes, ) -> None: - """Process `create` requests. + """Process create requests. - :param sender_name: The name of the user who sent the `create` request. + :param sender_name: The name of the user who sent the create request. :param receiver_name: The name of the user who will receive the message. :param message: The message to be sent. """ + ( + sender_name, + receiver_name, + message, + ) = CreateRequest.decode_packet(packet) + if receiver_name not in self.messages: self.messages[receiver_name] = [] @@ -112,6 +120,25 @@ def process_create_request( f'"{message.decode()}" to {receiver_name}', ) + def process_request(self, packet: bytes, connection_socket: socket.socket) -> None: + """Process an incoming client request. + + :param record: The packet received from a client. + :param connection_socket: The socket to use for responding to read requests. + """ + message_type: MessageType + message_type, packet = Packet.decode_packet(packet) + + match message_type: + case MessageType.READ: + self.process_read_request(connection_socket, packet) + + case MessageType.CREATE: + self.process_create_request(packet) + + case _: + logging.error("Message of incorrect type received!") + def run_server(self, welcoming_socket: socket.socket) -> None: """Run the server side of the program. @@ -120,6 +147,8 @@ def run_server(self, welcoming_socket: socket.socket) -> None: try: connection_socket, client_address = welcoming_socket.accept() except TimeoutError: + # Prevent the server from indefinitely waiting for new + # client requests, so that the ``stop`` function works return connection_socket.settimeout(1) @@ -130,14 +159,7 @@ def run_server(self, welcoming_socket: socket.socket) -> None: try: with connection_socket: record = connection_socket.recv(4096) - request_fields = MessageRequest.decode_packet(record) - message_type, sender_name, receiver_name, message = request_fields - - if message_type == MessageType.READ: - self.process_read_request(connection_socket, sender_name) - - elif message_type == MessageType.CREATE: - self.process_create_request(sender_name, receiver_name, message) + self.process_request(record, connection_socket) except TimeoutError: error_message = "Timed out while waiting for message request" diff --git a/src/message_type.py b/src/message_type.py index b193a2c..6b62f05 100644 --- a/src/message_type.py +++ b/src/message_type.py @@ -10,6 +10,8 @@ class MessageType(Enum): CREATE = 2 RESPONSE = 3 LOGIN = 4 + REGISTER = 5 + MESSAGE = 6 @staticmethod def from_str(string: str) -> "MessageType": diff --git a/src/packets/create_request.py b/src/packets/create_request.py new file mode 100644 index 0000000..9ff7b3a --- /dev/null +++ b/src/packets/create_request.py @@ -0,0 +1,102 @@ +"""Home to the ``CreateReqeust`` class.""" + +import logging +import struct + +from src.message_type import MessageType + +from .packet import Packet + +logger = logging.getLogger(__name__) + + +class CreateRequest(Packet, struct_format="!BBH", message_type=MessageType.CREATE): + """Encoding and decoding of create request packets. + + Usage: + create_request = CreateRequest("sender name", "receiver name", "Hi") + packet = create_request.to_bytes() + + sender_name, receiver_name, message = CreateRequest.decode_packet(record) + """ + + def __init__( + self, + user_name: str, + receiver_name: str, + message: str, + ) -> None: + """Encode a create request packet. + + :param user_name: The name of the user sending the request. + :param receiver_name: The name of the message recipient. + :param message: The string message to be sent. + """ + self.user_name = user_name + self.receiver_name = receiver_name + self.message = message + self.packet = b"" + + def to_bytes(self) -> bytes: + """Return the create request packet. + + :return: A byte array holding the create request. + """ + logger.info( + 'Creating CREATE request to send %s the message "%s" from %s', + self.receiver_name, + self.message, + self.user_name, + ) + + self.packet = super().to_bytes() + + self.packet += struct.pack( + self.struct_format, + len(self.user_name.encode()), + len(self.receiver_name.encode()), + len(self.message.encode()), + ) + + self.packet += self.user_name.encode() + self.packet += self.receiver_name.encode() + self.packet += self.message.encode() + + return self.packet + + @classmethod + def decode_packet(cls, packet: bytes) -> tuple[str, str, bytes]: + """Decode a message request packet. + + :param packet: An array of bytes containing the message request + """ + header_fields, payload = cls.split_packet(packet) + ( + user_name_size, + receiver_name_size, + message_size, + ) = header_fields + + if user_name_size < 1: + raise ValueError( + "Received message request with insufficient user name length", + ) + + if receiver_name_size < 1: + raise ValueError( + "Received create request with insufficient receiver name length", + ) + if message_size < 1: + raise ValueError( + "Received create request with insufficient message length", + ) + + user_name = payload[:user_name_size].decode() + index = user_name_size + + receiver_name = payload[index : index + receiver_name_size].decode() + index += receiver_name_size + + message = payload[index : index + message_size] + + return user_name, receiver_name, message diff --git a/src/packets/message.py b/src/packets/message.py index 83bffbd..b9db3b1 100644 --- a/src/packets/message.py +++ b/src/packets/message.py @@ -2,10 +2,12 @@ import struct +from src.message_type import MessageType + from .packet import Packet -class Message(Packet, struct_format="!BH"): +class Message(Packet, struct_format="!BH", message_type=MessageType.MESSAGE): """A class for encoding and decoding message packets. Message "packets" are the encoding of a single message from within @@ -16,7 +18,7 @@ def __init__(self, sender_name: str, message: bytes) -> None: """Create the Message which can be encoded into a packet. :param sender_name: The name of the user sending this message. - :param message: The message to be sent. + :param message: The message to be sent as an array of bytes. """ self.sender_name = sender_name self.message = message diff --git a/src/packets/message_request.py b/src/packets/message_request.py deleted file mode 100644 index 68d80d4..0000000 --- a/src/packets/message_request.py +++ /dev/null @@ -1,132 +0,0 @@ -"""Home to the ``MessageReqeust`` class.""" - -import logging -import struct - -from src.message_type import MessageType - -from .packet import Packet - -logger = logging.getLogger(__name__) - - -class MessageRequest(Packet, struct_format="!HBBBH"): - """Encoding and decoding of message request packets. - - Usage: - message_request = MessageRequest(MessageType.CREATE, "sender", "receiver", "Hi") - record = message_request.to_bytes() - - message_request = MessageRequest.from_record(record) - message_type, sender_name, receiver_name, message = message_request.decode() - """ - - def __init__( - self, - message_type: MessageType, - user_name: str, - receiver_name: str, - message: str, - ) -> None: - """Encode a message request packet. - - :param message_type: The type of the request (READ or CREATE) - :param user_name: The name of the user sending the request - :param receiver_name: The name of the message recipient - :param message: The string message to be sent - """ - self.message_type = message_type - self.user_name = user_name - self.receiver_name = receiver_name - self.message = message - self.packet = b"" - - def to_bytes(self) -> bytes: - """Return the message request packet. - - :return: A byte array holding the message request - """ - if self.message_type == MessageType.READ: - logger.info("Creating READ request from %s", self.user_name) - else: - logger.info( - 'Creating CREATE request to send %s the message "%s" from %s', - self.receiver_name, - self.message, - self.user_name, - ) - - self.packet = struct.pack( - self.struct_format, - Packet.MAGIC_NUMBER, - self.message_type.value, - len(self.user_name.encode()), - len(self.receiver_name.encode()), - len(self.message.encode()), - ) - - self.packet += self.user_name.encode() - self.packet += self.receiver_name.encode() - self.packet += self.message.encode() - - return self.packet - - # ruff: noqa: C901 - @classmethod - def decode_packet(cls, packet: bytes) -> tuple[MessageType, str, str, bytes]: - """Decode a message request packet. - - :param packet: An array of bytes containing the message request - """ - header_fields, payload = cls.split_packet(packet) - ( - magic_number, - message_type, - user_name_size, - receiver_name_size, - message_size, - ) = header_fields - - if magic_number != Packet.MAGIC_NUMBER: - raise ValueError("Received message request with incorrect magic number") - - try: - message_type = MessageType(message_type) - except ValueError as error: - raise ValueError("Received message request with invalid ID") from error - - if message_type == MessageType.RESPONSE: - raise ValueError("Recieved message request with disallowed type RESPONSE") - - if user_name_size < 1: - raise ValueError( - "Received message request with insufficient user name length", - ) - - if message_type == MessageType.READ: - if receiver_name_size != 0: - raise ValueError( - "Received read request with non-zero receiver name length", - ) - if message_size != 0: - raise ValueError("Received read request with non-zero message length") - - elif message_type == MessageType.CREATE: - if receiver_name_size < 1: - raise ValueError( - "Received create request with insufficient receiver name length", - ) - if message_size < 1: - raise ValueError( - "Received create request with insufficient message length", - ) - - user_name = payload[:user_name_size].decode() - index = user_name_size - - receiver_name = payload[index : index + receiver_name_size].decode() - index += receiver_name_size - - message = payload[index : index + message_size] - - return message_type, user_name, receiver_name, message diff --git a/src/packets/packet.py b/src/packets/packet.py index b3a1909..ef117be 100644 --- a/src/packets/packet.py +++ b/src/packets/packet.py @@ -4,6 +4,8 @@ import struct from typing import Any +from src.message_type import MessageType + class Packet(metaclass=abc.ABCMeta): """Abstract class for all packets. @@ -20,7 +22,9 @@ class MyPacket(Packet, struct_format="!HBBH"): MAGIC_NUMBER = 0xAE73 - struct_format: str + struct_format = "!HB" + + message_type: MessageType @abc.abstractmethod def __init__(self, *args: tuple[Any, ...]) -> None: @@ -34,19 +38,34 @@ def __init__(self, *args: tuple[Any, ...]) -> None: def to_bytes(self) -> bytes: """Convert the packet into a ``bytes`` object. - :return: A ``bytes`` object encoding individual fields of the packet. + :return: A ``bytes`` object encoding the packet's message type. """ - raise NotImplementedError + return struct.pack( + Packet.struct_format, + self.MAGIC_NUMBER, + self.message_type.value, + ) @classmethod @abc.abstractmethod def decode_packet(cls, packet: bytes) -> tuple[Any, ...]: - """Decode the packet into a tuple of values. + """Decode the packet into its message type and payload. :param packet: The packet to decode. - :return: A tuple of the decoded values extracted from the packet. + :return: A tuple of the decoded message type and the payload. """ - raise NotImplementedError + header_fields: tuple[int, MessageType] + header_fields, payload = Packet.split_packet(packet) + magic_number, message_type_number = header_fields + + if magic_number != cls.MAGIC_NUMBER: + raise ValueError("Incorrect magic number found in packet") + try: + message_type = MessageType(message_type_number) + except ValueError as error: + raise ValueError("Invalid message type ID number") from error + + return message_type, payload @classmethod def split_packet(cls, packet: bytes) -> tuple[tuple[Any, ...], bytes]: @@ -67,20 +86,23 @@ def split_packet(cls, packet: bytes) -> tuple[tuple[Any, ...], bytes]: @classmethod def __init_subclass__( cls, - struct_format: str | None = None, - **kwargs: tuple[Any, ...], + message_type: MessageType, + struct_format: str, ) -> None: """Ensure ``struct_format`` attribute is present. All subclasses of ``Packet`` must specify a ``struct_format`` - in their class attributes. This is used for packing and - unpacking the data into a minimal package. + and a ``message_type`` in their class attributes. This is used + for packing and unpacking the data into a minimal package, and + to communicate what data is stored inside the packet. + :param message_type: The type of message the packet will encode. :param struct_format: The format of the packet data for the ``struct`` module. :param kwargs: No additional kwargs will be accepted. """ if not struct_format: raise ValueError("Must specify struct format") - super().__init_subclass__(**kwargs) + super().__init_subclass__() cls.struct_format = struct_format + cls.message_type = message_type diff --git a/src/packets/read_request.py b/src/packets/read_request.py new file mode 100644 index 0000000..bf24366 --- /dev/null +++ b/src/packets/read_request.py @@ -0,0 +1,70 @@ +"""Home to the ``ReadReqeust`` class.""" + +import logging +import struct + +from src.message_type import MessageType + +from .packet import Packet + +logger = logging.getLogger(__name__) + + +class ReadRequest(Packet, struct_format="!B", message_type=MessageType.READ): + """Encoding and decoding of read request packets. + + Usage: + read_request = ReadRequest("Recipient name") + packet = read_request.to_bytes() + + read_request = ReadRequest.decode_packet(packet) + (recipient_name,) = read_request.decode() + """ + + def __init__( + self, + user_name: str, + ) -> None: + """Encode a read request packet. + + :param user_name: The name of the user sending the read request. + """ + self.user_name = user_name + self.packet: bytes + + def to_bytes(self) -> bytes: + """Return the read request packet. + + :return: An array of bytes holding the message request. + """ + logger.info("Creating READ request from %s", self.user_name) + + self.packet = super().to_bytes() + + self.packet += struct.pack( + self.struct_format, + len(self.user_name.encode()), + ) + + self.packet += self.user_name.encode() + + return self.packet + + @classmethod + def decode_packet(cls, packet: bytes) -> tuple[str]: + """Decode a read request packet. + + :param packet: An array of bytes containing the read request. + :return: A tuple containing the user requesting their messages. + """ + header_fields, payload = cls.split_packet(packet) + (user_name_size,) = header_fields + + if user_name_size < 1: + raise ValueError( + "Received message request with insufficient user name length", + ) + + user_name = payload.decode() + + return (user_name,) diff --git a/src/packets/message_response.py b/src/packets/read_response.py similarity index 64% rename from src/packets/message_response.py rename to src/packets/read_response.py index 0bb8265..cfd7030 100644 --- a/src/packets/message_response.py +++ b/src/packets/read_response.py @@ -1,4 +1,4 @@ -"""Home to the ``MessageResponse`` class.""" +"""Home to the ``CreateResponse`` class.""" import logging import struct @@ -10,7 +10,7 @@ logger = logging.getLogger(__name__) -class MessageResponse(Packet, struct_format="!HBB?"): +class ReadResponse(Packet, struct_format="!B?", message_type=MessageType.RESPONSE): """Enables encoding and decoding message response packets.""" MAX_MESSAGE_LENGTH = 255 @@ -20,8 +20,8 @@ def __init__(self, messages: list[tuple[str, bytes]]) -> None: :param messages: A list of all the messages to be put in the structure. """ - self.num_messages = min(len(messages), MessageResponse.MAX_MESSAGE_LENGTH) - self.more_messages = len(messages) > MessageResponse.MAX_MESSAGE_LENGTH + self.num_messages = min(len(messages), ReadResponse.MAX_MESSAGE_LENGTH) + self.more_messages = len(messages) > ReadResponse.MAX_MESSAGE_LENGTH self.messages = messages[: self.num_messages] self.packet = b"" @@ -35,8 +35,6 @@ def to_bytes(self) -> bytes: self.packet = struct.pack( self.struct_format, - Packet.MAGIC_NUMBER, - MessageType.RESPONSE.value, self.num_messages, self.more_messages, ) @@ -49,30 +47,14 @@ def to_bytes(self) -> bytes: @classmethod def decode_packet(cls, packet: bytes) -> tuple[list[tuple[str, str]], bool]: - """Decode a message response packet into its individual components. + """Decode a read response packet into its individual components. :param packet: The packet to be decoded. :return: A tuple containing a list of messages and a boolean indicating whether there are more messages to be received. """ header_fields, payload = cls.split_packet(packet) - magic_number, message_type, num_messages, more_messages = header_fields - - if magic_number != Packet.MAGIC_NUMBER: - raise ValueError("Invalid magic number when decoding message response") - - try: - message_type = MessageType(message_type) - except ValueError as error: - raise ValueError( - "Invalid message type when decoding message response", - ) from error - if message_type != MessageType.RESPONSE: - message = ( - f"Message type {message_type} found when decoding message" - " response, expected RESPONSE" - ) - raise ValueError(message) + num_messages, more_messages = header_fields messages: list[tuple[str, str]] = [] remaining_messages = payload From 5bbe4186dec5597d5283d950bfe93864a8af256c Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Tue, 12 Nov 2024 16:24:06 +1300 Subject: [PATCH 26/48] Corrected typing of CommandLineApplication.parse_arguments --- src/command_line_application.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/command_line_application.py b/src/command_line_application.py index fbddabb..f9c73b9 100644 --- a/src/command_line_application.py +++ b/src/command_line_application.py @@ -33,7 +33,7 @@ def usage_prompt(self) -> str: """ return f"Usage: python3 {' '.join(self.parameters)}" - def parse_arguments(self, arguments: list[str]) -> list[Any]: + def parse_arguments(self, arguments: list[str]) -> tuple[Any, ...]: """Parse the command line arguments, ensuring they are valid. :param arguments: The command line arguments. @@ -45,14 +45,14 @@ def parse_arguments(self, arguments: list[str]) -> list[Any]: raise SystemExit(message) try: - parsed_arguments = [ + parsed_arguments = tuple( parser(argument) for argument, parser in zip( arguments, self.parameters.values(), strict=False, ) - ] + ) except TypeError as error: logger.exception("Incorrect arguments") print(self.usage_prompt) From 7062c685aab0c5d2bf2568ffaf6b31f194b209e5 Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Tue, 12 Nov 2024 16:24:41 +1300 Subject: [PATCH 27/48] Corrected typing in Client.__init__ --- client/client.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/client/client.py b/client/client.py index 96b867d..57f569e 100644 --- a/client/client.py +++ b/client/client.py @@ -23,8 +23,8 @@ class Client(CommandLineApplication): def __init__(self, arguments: list[str]) -> None: """Initialise the client with specified arguments. - :param arguments: A list containing the host name, port number - , username, and message type. + :param arguments: A list containing the host name, port number, + username, and message type. """ super().__init__( OrderedDict( @@ -35,15 +35,12 @@ def __init__(self, arguments: list[str]) -> None: ), ) - # pylint thinks that self.parse_arguments is only capable - # of returning an empty list - # pylint: disable=unbalanced-tuple-unpacking - ( - self.host_name, - self.port_number, - self.user_name, - self.message_type, - ) = self.parse_arguments(arguments) + parsed_arguments: tuple[str, PortNumber, str, MessageType] + parsed_arguments = self.parse_arguments(arguments) + + self.host_name, self.port_number, self.user_name, self.message_type = ( + parsed_arguments + ) logger.info( "Client for %s port %s created by %s to send %s request", From d8c8045f347aba396fb65362f36372137e1fe417 Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Tue, 12 Nov 2024 16:25:12 +1300 Subject: [PATCH 28/48] Improved error handling in Client.send_request --- client/client.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/client/client.py b/client/client.py index 57f569e..2532b71 100644 --- a/client/client.py +++ b/client/client.py @@ -105,12 +105,12 @@ def send_request(self, request: Packet) -> bytes: connection_socket.send(packet) response = connection_socket.recv(4096) - except ConnectionRefusedError as error: - logger.exception("Connection refused, likely due to invalid port number") - print("Connection refused, likely due to invalid port number") - raise SystemExit from error - except TimeoutError as error: - message = "Connection timed out, likely due to invalid host name" + except (ConnectionRefusedError, TimeoutError) as error: + message = ( + "Connection refused, likely due to invalid port number" + if isinstance(error, ConnectionRefusedError) + else "Connection timed out, likely due to invalid host name" + ) logger.exception(message) print(message) raise SystemExit from error From a720f5508c52ecd0b848f26afc01f4974bfd7462 Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Wed, 13 Nov 2024 09:14:28 +1300 Subject: [PATCH 29/48] added registration and key passing --- client/client.py | 43 +++++++++++++++ server/server.py | 83 ++++++++++++++++++++++++++--- src/login_response.py | 72 ------------------------- src/message_type.py | 4 +- src/packets/key_request.py | 53 ++++++++++++++++++ src/packets/key_response.py | 70 ++++++++++++++++++++++++ src/{ => packets}/login_request.py | 28 ++-------- src/packets/registration_request.py | 76 ++++++++++++++++++++++++++ 8 files changed, 325 insertions(+), 104 deletions(-) delete mode 100644 src/login_response.py create mode 100644 src/packets/key_request.py create mode 100644 src/packets/key_response.py rename src/{ => packets}/login_request.py (53%) create mode 100644 src/packets/registration_request.py diff --git a/client/client.py b/client/client.py index 2532b71..21235cd 100644 --- a/client/client.py +++ b/client/client.py @@ -4,12 +4,17 @@ import socket from collections import OrderedDict +from message_cipher.rsa_system import RSA + from src.command_line_application import CommandLineApplication from src.message_type import MessageType from src.packets.create_request import CreateRequest +from src.packets.key_request import KeyRequest +from src.packets.key_response import KeyResponse from src.packets.packet import Packet from src.packets.read_request import ReadRequest from src.packets.read_response import ReadResponse +from src.packets.registration_request import RegistrationRequest from src.port_number import PortNumber logger = logging.getLogger(__name__) @@ -183,6 +188,35 @@ def send_create_request( ) self.send_request(request) + def send_login_request(self) -> None: + """Send a login request to the server.""" + print("logging in") + + def send_registration_request(self) -> None: + """Send a login request to the server.""" + public_key, private_key = RSA() + + print(f"Creatged product {public_key.product}") + print(f"Created exponent {public_key.exponent}") + + request = RegistrationRequest(self.user_name, public_key) + self.send_request(request) + + def send_key_request(self, receiver_name: str | None) -> None: + """Send a pubblic key request to the server.""" + if not receiver_name: + receiver_name = input("Who's key are we requesting?") + + request = KeyRequest(receiver_name) + response = self.send_request(request) + message_type, payload = Packet.decode_packet(response) + if message_type != MessageType.KEY_RESPONSE: + raise ValueError("Wrong response type!") + + (public_key,) = KeyResponse.decode_packet(payload) + print(f"Received product {public_key.product}") + print(f"Received exponent {public_key.exponent}") + def run( self, *, @@ -203,6 +237,15 @@ def run( case MessageType.CREATE: self.send_create_request(receiver_name, message) + case MessageType.LOGIN: + self.send_login_request() + + case MessageType.REGISTER: + self.send_registration_request() + + case MessageType.KEY_REQUEST: + self.send_key_request(receiver_name) + case _: print("Oopsies, wrong message type!") diff --git a/server/server.py b/server/server.py index c004689..4076fea 100644 --- a/server/server.py +++ b/server/server.py @@ -3,13 +3,21 @@ import logging import socket from collections import OrderedDict +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from message_cipher.rsa_encrypter import RsaEncrypter from src.command_line_application import CommandLineApplication from src.message_type import MessageType from src.packets.create_request import CreateRequest +from src.packets.key_request import KeyRequest +from src.packets.key_response import KeyResponse +from src.packets.login_request import LoginRequest from src.packets.packet import Packet from src.packets.read_request import ReadRequest from src.packets.read_response import ReadResponse +from src.packets.registration_request import RegistrationRequest from src.port_number import PortNumber logger = logging.getLogger(__name__) @@ -35,6 +43,7 @@ def __init__(self, arguments: list[str]) -> None: self.running = True self.hostname = "localhost" + self.users: dict[str, RsaEncrypter] = {} self.messages: dict[str, list[tuple[str, bytes]]] = {} def run(self) -> None: @@ -65,22 +74,31 @@ def run(self) -> None: print(message) raise SystemExit from error + @staticmethod + def send_response(response: Packet, connection_socket: socket.socket) -> None: + """Send a response to the client's request. + + :param response: The response object. + :param connection_socket: The socket to send the response over. + """ + record = response.to_bytes() + connection_socket.send(record) + def process_read_request( self, - connection_socket: socket.socket, packet: bytes, + connection_socket: socket.socket, ) -> None: """Respond to read requests. - :param sender_name: The name of the user who sent the read request. + :param packet: The read request packet to process. :param connection_socket: The connection socket to send the response on. - :return: The response to the read request. """ (sender_name,) = ReadRequest.decode_packet(packet) response = ReadResponse(self.messages.get(sender_name, [])) - record = response.to_bytes() - connection_socket.send(record) + self.send_response(response, connection_socket) + del self.messages.get(sender_name, [])[: response.num_messages] logger.info( "%s message(s) delivered to %s", @@ -120,10 +138,52 @@ def process_create_request( f'"{message.decode()}" to {receiver_name}', ) + def process_login_request(self, packet: bytes) -> None: + """Process a client requset to login. + + :param packet: A byte array containing the login request. + """ + (sender_name,) = LoginRequest.decode_packet(packet) + print("logged in", sender_name) + + def process_registration_request(self, packet: bytes) -> None: + """Process a client request to register a new name. + + :param packet: A byte array containing the registration request. + """ + sender_name, public_key = RegistrationRequest.decode_packet(packet) + if sender_name not in self.users: + self.users[sender_name] = public_key + print( + f"Registered {sender_name}", + f"with product {public_key.product}", + f"and exponent {public_key.exponent}", + ) + + else: + message = f"name {sender_name} already registered" + logger.error(message) + + def process_key_request( + self, + packet: bytes, + connection_socket: socket.socket, + ) -> None: + """Process a client request for a user's public key. + + :param packet: A byte array containing the key request. + :param connection_socket: The socket to send the response over. + """ + (requested_user,) = KeyRequest.decode_packet(packet) + print(f"Received request for {requested_user}'s key") + public_key = self.users[requested_user] + response = KeyResponse(public_key) + self.send_response(response, connection_socket) + def process_request(self, packet: bytes, connection_socket: socket.socket) -> None: """Process an incoming client request. - :param record: The packet received from a client. + :param packet: The packet received from a client. :param connection_socket: The socket to use for responding to read requests. """ message_type: MessageType @@ -131,11 +191,20 @@ def process_request(self, packet: bytes, connection_socket: socket.socket) -> No match message_type: case MessageType.READ: - self.process_read_request(connection_socket, packet) + self.process_read_request(packet, connection_socket) case MessageType.CREATE: self.process_create_request(packet) + case MessageType.LOGIN: + self.process_login_request(packet) + + case MessageType.REGISTER: + self.process_registration_request(packet) + + case MessageType.KEY_REQUEST: + self.process_key_request(packet, connection_socket) + case _: logging.error("Message of incorrect type received!") diff --git a/src/login_response.py b/src/login_response.py deleted file mode 100644 index d4cc98f..0000000 --- a/src/login_response.py +++ /dev/null @@ -1,72 +0,0 @@ -"""Login response module. - -Defines class for encoding and decoding login response packets. -""" - -import logging -import struct - -from message_cipher.rsa_encrypter import RsaEncrypter - -from src.message_type import MessageType -from src.packets.packet import Packet - - -class LoginResponse(Packet, struct_format="!HB?QQ"): - """The LoginResponse class is used to encode and decode login response packets.""" - - def __init__(self, is_registered: bool, encryption_key: RsaEncrypter) -> None: - """Create a new login response packet. - - :param is_registered: ``True`` if the requesting user was registered. - :param encryption_key: An RSA public key for encrypting messages. - """ - self.encryption_key = encryption_key - self.is_registered = is_registered - self.packet = b"" - - def to_bytes(self) -> bytes: - """Encode a login response packet into a byte array.""" - logging.info("Creating login response") - - self.packet = struct.pack( - self.struct_format, - Packet.MAGIC_NUMBER, - MessageType.LOGIN.value, - self.is_registered, - self.encryption_key.product, - self.encryption_key.exponent, - ) - - return self.packet - - @classmethod - def decode_packet(cls, packet: bytes) -> tuple[bool, RsaEncrypter]: - """Decode a message response packet into its individual components. - - :param packet: The packet to be decoded. - :raises ValueError: If the packet is invalid. - :return: A tuple containing a boolean indicating if the - """ - header_fields, _ = cls.split_packet(packet) - magic_number, message_type, is_registered, product, exponent = header_fields - - if magic_number != Packet.MAGIC_NUMBER: - raise ValueError("Invalid magic number when decoding message response") - - try: - message_type = MessageType(message_type) - except ValueError as error: - raise ValueError( - "Invalid message type when decoding message response", - ) from error - if message_type != MessageType.LOGIN: - message = ( - f"Message type {message_type} found when decoding message response, " - "expected LOGIN" - ) - raise ValueError(message) - - encryption_key = RsaEncrypter(product, exponent) - - return is_registered, encryption_key diff --git a/src/message_type.py b/src/message_type.py index 6b62f05..02db173 100644 --- a/src/message_type.py +++ b/src/message_type.py @@ -12,6 +12,8 @@ class MessageType(Enum): LOGIN = 4 REGISTER = 5 MESSAGE = 6 + KEY_REQUEST = 7 + KEY_RESPONSE = 8 @staticmethod def from_str(string: str) -> "MessageType": @@ -23,5 +25,5 @@ def from_str(string: str) -> "MessageType": try: return MessageType[string.upper()] except KeyError as error: - message = f'Invalid message type: {string}, must be "read" or "create"' + message = f"Invalid message type: {string}" raise ValueError(message) from error diff --git a/src/packets/key_request.py b/src/packets/key_request.py new file mode 100644 index 0000000..6f681b9 --- /dev/null +++ b/src/packets/key_request.py @@ -0,0 +1,53 @@ +"""Public key request module. + +Defines the KeyRequest class which is used to encode and decode +public key request packets. +""" + +import logging +import struct + +from src.message_type import MessageType +from src.packets.packet import Packet + + +class KeyRequest( + Packet, + struct_format="!H", + message_type=MessageType.KEY_REQUEST, +): + """Encode and decode public key request packets.""" + + def __init__(self, user_name: str) -> None: + """Create a key request packet.""" + self.user_name = user_name + self.packet: bytes + + def to_bytes(self) -> bytes: + """Encode the key request packet into a byte array.""" + logging.info("Creating key request for %s", self.user_name) + + self.packet = super().to_bytes() + + self.packet += struct.pack( + self.struct_format, + len(self.user_name.encode()), + ) + + self.packet += self.user_name.encode() + + return self.packet + + @classmethod + def decode_packet(cls, packet: bytes) -> tuple[str]: + """Decode the key request packet into its individual components. + + :param packet: The packet to be decoded. + :return: A tuple containing the username of the user who's key + is being requested. + """ + _, payload = cls.split_packet(packet) + + user_name = payload.decode() + + return (user_name,) diff --git a/src/packets/key_response.py b/src/packets/key_response.py new file mode 100644 index 0000000..68b356c --- /dev/null +++ b/src/packets/key_response.py @@ -0,0 +1,70 @@ +"""Public key response module. + +Defines the KeyResponse class which is used to encode and decode +public key response packets. +""" + +import logging +import struct + +from message_cipher.rsa_encrypter import RsaEncrypter + +from src.message_type import MessageType +from src.packets.packet import Packet + + +class KeyResponse( + Packet, + struct_format="!HH", + message_type=MessageType.KEY_RESPONSE, +): + """Encode and decode public key response packets.""" + + def __init__(self, public_key: RsaEncrypter) -> None: + """Create a key response packet.""" + self.public_key = public_key + self.packet: bytes + + def to_bytes(self) -> bytes: + """Encode the key response packet into a byte array.""" + logging.info("Creating key response") + + product = self.public_key.product.to_bytes( + (self.public_key.product.bit_length() + 7) // 8, + ) + + exponent = self.public_key.exponent.to_bytes( + (self.public_key.exponent.bit_length() + 7) // 8, + ) + + self.packet = super().to_bytes() + + self.packet += struct.pack( + self.struct_format, + len(product), + len(exponent), + ) + + self.packet += product + self.packet += exponent + + return self.packet + + @classmethod + def decode_packet(cls, packet: bytes) -> tuple[RsaEncrypter]: + """Decode the key response packet into its individual components. + + :param packet: The packet to be decoded. + :return: A tuple containing the public key of the requested user. + """ + header_fields, payload = cls.split_packet(packet) + product_length, exponent_length = header_fields + + product = int.from_bytes(payload[:product_length]) + index = product_length + + exponent = int.from_bytes(payload[index : index + exponent_length]) + + public_key = RsaEncrypter(product, exponent) + + return (public_key,) diff --git a/src/login_request.py b/src/packets/login_request.py similarity index 53% rename from src/login_request.py rename to src/packets/login_request.py index 5a06ff4..712982e 100644 --- a/src/login_request.py +++ b/src/packets/login_request.py @@ -11,7 +11,7 @@ from src.packets.packet import Packet -class LoginRequest(Packet, struct_format="!HBB"): +class LoginRequest(Packet, struct_format="!HBB", message_type=MessageType.LOGIN): """The LoginRequest class is used to encode and decode login request packets.""" def __init__(self, user_name: str) -> None: @@ -25,8 +25,6 @@ def to_bytes(self) -> bytes: self.packet = struct.pack( self.struct_format, - Packet.MAGIC_NUMBER, - MessageType.LOGIN.value, len(self.user_name.encode()), ) @@ -38,28 +36,10 @@ def to_bytes(self) -> bytes: def decode_packet(cls, packet: bytes) -> tuple[Any, ...]: """Decode the login request packet into its individual components. - :param packet: The packet to be decoded - :return: A tuple containing the username + :param packet: The packet to be decoded. + :return: A tuple containing the username. """ - header_fields, payload = cls.split_packet(packet) - magic_number, message_type, user_name_length = header_fields - - if magic_number != Packet.MAGIC_NUMBER: - raise ValueError("Invalid magic number when decoding message response") - - try: - message_type = MessageType(message_type) - except ValueError as error: - raise ValueError( - "Invalid message type when decoding message response", - ) from error - if message_type != MessageType.LOGIN: - message = ( - f"Message type {message_type} found when decoding" - " message response, expected LOGIN" - ) - - raise ValueError(message) + _, payload = cls.split_packet(packet) user_name = payload.decode() diff --git a/src/packets/registration_request.py b/src/packets/registration_request.py new file mode 100644 index 0000000..d980853 --- /dev/null +++ b/src/packets/registration_request.py @@ -0,0 +1,76 @@ +"""Registration request module. + +Defines the RegistrationRequest class which is used to encode and decode +registration request packets. +""" + +import logging +import struct + +from message_cipher.rsa_encrypter import RsaEncrypter + +from src.message_type import MessageType +from src.packets.packet import Packet + + +class RegistrationRequest( + Packet, + struct_format="!BHH", + message_type=MessageType.REGISTER, +): + """Encode and decode registration request packets.""" + + def __init__(self, user_name: str, public_key: RsaEncrypter) -> None: + """Create a login request packet.""" + self.user_name = user_name + self.public_key = public_key + self.packet: bytes + + def to_bytes(self) -> bytes: + """Encode the registration request packet into a byte array.""" + logging.info("Creating request to register as %s", self.user_name) + + user_name = self.user_name.encode() + + product = self.public_key.product.to_bytes( + (self.public_key.product.bit_length() + 7) // 8, + ) + + exponent = self.public_key.exponent.to_bytes( + (self.public_key.exponent.bit_length() + 7) // 8, + ) + + self.packet = super().to_bytes() + + self.packet += struct.pack( + self.struct_format, + len(user_name), + len(product), + len(exponent), + ) + + self.packet += user_name + self.packet += product + self.packet += exponent + + return self.packet + + @classmethod + def decode_packet(cls, packet: bytes) -> tuple[str, RsaEncrypter]: + """Decode the registration request packet into its individual components. + + :param packet: The packet to be decoded + :return: A tuple containing the username and public key + """ + header_fields, payload = cls.split_packet(packet) + user_name_length, product_length, exponent_length = header_fields + + user_name = payload[:user_name_length].decode() + index = user_name_length + + product = int.from_bytes(payload[index : index + product_length]) + index += product_length + + exponent = int.from_bytes(payload[index : index + exponent_length]) + + return user_name, RsaEncrypter(product, exponent) From d709929f56caba9657e61c9f79d277239b77b99a Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Wed, 13 Nov 2024 11:25:44 +1300 Subject: [PATCH 30/48] Catching ValueErrors when parsing commandline errors --- src/command_line_application.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/command_line_application.py b/src/command_line_application.py index f9c73b9..97ca1fc 100644 --- a/src/command_line_application.py +++ b/src/command_line_application.py @@ -53,7 +53,7 @@ def parse_arguments(self, arguments: list[str]) -> tuple[Any, ...]: strict=False, ) ) - except TypeError as error: + except (ValueError, TypeError) as error: logger.exception("Incorrect arguments") print(self.usage_prompt) print(error) From 7ea2108b70262c91bcdc3d3e0a3b2b9a1243fc5f Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Thu, 14 Nov 2024 07:31:45 +1300 Subject: [PATCH 31/48] Updated tests for new code and fixed all errors and failing tests --- client/client.py | 7 +- src/packets/login_request.py | 4 +- src/packets/packet.py | 10 +- src/packets/read_response.py | 4 +- tests/applications/test_client.py | 21 +- tests/applications/test_server.py | 14 +- tests/integration/test_more_messages.py | 8 +- tests/packets/test_create_request.py | 189 ++++++++++++ tests/packets/test_message_request.py | 280 ------------------ tests/packets/test_message_response.py | 120 -------- .../test_packet.py} | 15 +- tests/packets/test_read_request.py | 57 ++++ tests/packets/test_read_response.py | 78 +++++ 13 files changed, 381 insertions(+), 426 deletions(-) create mode 100644 tests/packets/test_create_request.py delete mode 100644 tests/packets/test_message_request.py delete mode 100644 tests/packets/test_message_response.py rename tests/{applications/test_command_line_application.py => packets/test_packet.py} (61%) create mode 100644 tests/packets/test_read_request.py create mode 100644 tests/packets/test_read_response.py diff --git a/client/client.py b/client/client.py index 21235cd..2353556 100644 --- a/client/client.py +++ b/client/client.py @@ -96,19 +96,21 @@ def parse_username(user_name: str) -> str: return user_name - def send_request(self, request: Packet) -> bytes: + def send_request(self, request: Packet, *, expect_response: bool = True) -> bytes: """Send a message request record to the server. :param request: The message request to be sent. :return: The server's response if applicable, otherwise ``None``. """ + response = b"" packet = request.to_bytes() try: with socket.socket() as connection_socket: connection_socket.settimeout(1) connection_socket.connect((self.host_name, self.port_number)) connection_socket.send(packet) - response = connection_socket.recv(4096) + if expect_response: + response = connection_socket.recv(4096) except (ConnectionRefusedError, TimeoutError) as error: message = ( @@ -135,6 +137,7 @@ def read_message_response(packet: bytes) -> None: :param packet: The message response from the server. """ + _, packet = Packet.decode_packet(packet) messages, more_messages = ReadResponse.decode_packet(packet) for sender, message in messages: diff --git a/src/packets/login_request.py b/src/packets/login_request.py index 712982e..24c2254 100644 --- a/src/packets/login_request.py +++ b/src/packets/login_request.py @@ -23,7 +23,9 @@ def to_bytes(self) -> bytes: """Encode the login request packet into a byte array.""" logging.info("Creating log-in request as %s", self.user_name) - self.packet = struct.pack( + self.packet = super().to_bytes() + + self.packet += struct.pack( self.struct_format, len(self.user_name.encode()), ) diff --git a/src/packets/packet.py b/src/packets/packet.py index ef117be..193a7f3 100644 --- a/src/packets/packet.py +++ b/src/packets/packet.py @@ -1,6 +1,7 @@ """Home to the ``Packet`` abstract class.""" import abc +import re import struct from typing import Any @@ -26,6 +27,8 @@ class MyPacket(Packet, struct_format="!HBBH"): message_type: MessageType + struct_format_regex = re.compile("^[@=<>!]?[xcbB?hHiIlLqQnNefdspP]+$") + @abc.abstractmethod def __init__(self, *args: tuple[Any, ...]) -> None: """Initialise the packet. @@ -86,6 +89,7 @@ def split_packet(cls, packet: bytes) -> tuple[tuple[Any, ...], bytes]: @classmethod def __init_subclass__( cls, + *, message_type: MessageType, struct_format: str, ) -> None: @@ -99,9 +103,11 @@ def __init_subclass__( :param message_type: The type of message the packet will encode. :param struct_format: The format of the packet data for the ``struct`` module. :param kwargs: No additional kwargs will be accepted. + + :raises ValueError: if the provided struct format is invalid. """ - if not struct_format: - raise ValueError("Must specify struct format") + if not re.match(Packet.struct_format_regex, struct_format): + raise ValueError("Invalid struct format") super().__init_subclass__() cls.struct_format = struct_format diff --git a/src/packets/read_response.py b/src/packets/read_response.py index cfd7030..ff36036 100644 --- a/src/packets/read_response.py +++ b/src/packets/read_response.py @@ -33,7 +33,9 @@ def to_bytes(self) -> bytes: """ logger.info("Creating message response for %s message(s)", self.num_messages) - self.packet = struct.pack( + self.packet = super().to_bytes() + + self.packet += struct.pack( self.struct_format, self.num_messages, self.more_messages, diff --git a/tests/applications/test_client.py b/tests/applications/test_client.py index fd1dc39..3879ede 100644 --- a/tests/applications/test_client.py +++ b/tests/applications/test_client.py @@ -5,7 +5,8 @@ from client import Client from src.message_type import MessageType -from src.packets.message_request import MessageRequest +from src.packets.create_request import CreateRequest +from src.packets.packet import Packet class TestClient(unittest.TestCase): @@ -41,8 +42,9 @@ def test_send_message_request(self) -> None: welcoming_socket.listen(1) # Send message request from the client - client.send_message_request( - MessageRequest(MessageType.CREATE, user_name, receiver_name, message), + client.send_request( + CreateRequest(user_name, receiver_name, message), + expect_response=False, ) # Accept connection from the client @@ -53,9 +55,10 @@ def test_send_message_request(self) -> None: with connection_socket: packet = connection_socket.recv(4096) - # Check that the packet is correct - request = MessageRequest.decode_packet(packet) - self.assertEqual( - (MessageType.CREATE, user_name, receiver_name, message.encode()), - request, - ) + message_type, packet = Packet.decode_packet(packet) + + self.assertEqual(MessageType.CREATE, message_type) + + expected = (user_name, receiver_name, message.encode()) + actual = CreateRequest.decode_packet(packet) + self.assertEqual(expected, actual) diff --git a/tests/applications/test_server.py b/tests/applications/test_server.py index 859c2b8..6e31bdc 100644 --- a/tests/applications/test_server.py +++ b/tests/applications/test_server.py @@ -4,7 +4,9 @@ import unittest from server import Server -from src.packets.message_response import MessageResponse +from src.packets.packet import Packet +from src.packets.read_request import ReadRequest +from src.packets.read_response import ReadResponse class TestServer(unittest.TestCase): @@ -41,12 +43,16 @@ def test_process_read_request(self) -> None: client_socket.connect((TestServer.hostname, TestServer.port_number)) server_connection_socket, _ = server_welcoming_socket.accept() + packet = ReadRequest(receiver_name).to_bytes() + _, packet = Packet.decode_packet(packet) + with server_connection_socket: - server.process_read_request(server_connection_socket, receiver_name) + server.process_read_request(packet, server_connection_socket) # Receive message from server - packet = client_socket.recv(1024) - response = MessageResponse.decode_packet(packet) + response_packet = client_socket.recv(1024) + _, response_packet = Packet.decode_packet(response_packet) + response = ReadResponse.decode_packet(response_packet) # Check that the message is correct self.assertEqual(([(sender_name, message.decode())], False), response) diff --git a/tests/integration/test_more_messages.py b/tests/integration/test_more_messages.py index da6441d..896e162 100644 --- a/tests/integration/test_more_messages.py +++ b/tests/integration/test_more_messages.py @@ -5,10 +5,12 @@ import threading import unittest +from src.packets.packet import Packet + sys.path.insert(0, "../../") import client import server -from src.packets.message_response import MessageResponse +from src.packets.read_response import ReadResponse class TestMoreMessages(unittest.TestCase): @@ -46,7 +48,9 @@ def test_more_messages(self) -> None: server_object.stop() server_thread.join() - messages, more_messages = MessageResponse.decode_packet(final_client.result) + _, packet = Packet.decode_packet(final_client.result) + + messages, more_messages = ReadResponse.decode_packet(packet) self.assertEqual(255, len(messages)) self.assertTrue(more_messages) diff --git a/tests/packets/test_create_request.py b/tests/packets/test_create_request.py new file mode 100644 index 0000000..0bbac63 --- /dev/null +++ b/tests/packets/test_create_request.py @@ -0,0 +1,189 @@ +"""``CreateRequest`` class test suite.""" + +import unittest + +from src.message_type import MessageType +from src.packets.create_request import CreateRequest +from src.packets.packet import Packet + + +class TestMessageRequestEncoding(unittest.TestCase): + """Test suite for encoding MessageRequest packets.""" + + def test_user_name_length_encoding(self) -> None: + """Tests that the length of the user's name is encoded correctly.""" + user_name = "Johnny" + receiver_name = "Jarod" + message = "Hello, World!" + packet = CreateRequest( + user_name, + receiver_name, + message, + ).to_bytes() + + _, payload = Packet.decode_packet(packet) + + expected = len(user_name.encode()) + actual = payload[0] + self.assertEqual(expected, actual) + + def test_receiver_name_length_encoding(self) -> None: + """Tests that the length of the receiver's name is encoded correctly.""" + user_name = "Jackson" + receiver_name = "Jake" + message = "Hello, World!" + packet = CreateRequest( + user_name, + receiver_name, + message, + ).to_bytes() + + _, payload = Packet.decode_packet(packet) + + expected = len(receiver_name.encode()) + actual = payload[1] + self.assertEqual(expected, actual) + + def test_message_length_encoding(self) -> None: + """Tests that the length of the message is encoded correctly.""" + user_name = "Jason" + receiver_name = "Jay" + message = "Hello, World!" + packet = CreateRequest( + user_name, + receiver_name, + message, + ).to_bytes() + + _, payload = Packet.decode_packet(packet) + + expected = len(message.encode()) + actual = (payload[2] << 8) | (payload[3] & 0xFF) + self.assertEqual(expected, actual) + + def test_user_name_encoding(self) -> None: + """Tests that the user's name is encoded correctly.""" + user_name = "Jason" + receiver_name = "Jay" + message = "Hello, World!" + packet = CreateRequest( + user_name, + receiver_name, + message, + ).to_bytes() + + _, payload = Packet.decode_packet(packet) + + expected = user_name + actual = payload[4 : 4 + len(user_name.encode())].decode() + self.assertEqual(expected, actual) + + def test_receiver_name_encoding(self) -> None: + """Tests that the receiver's name is encoded correctly.""" + user_name = "Jeff" + receiver_name = "Jesse" + message = "Hello, World!" + packet = CreateRequest( + user_name, + receiver_name, + message, + ).to_bytes() + + _, payload = Packet.decode_packet(packet) + + start_index = 4 + len(user_name.encode()) + + expected = receiver_name + actual = payload[ + start_index : start_index + len(receiver_name.encode()) + ].decode() + self.assertEqual(expected, actual) + + def test_message_encoding(self) -> None: + """Tests that the message is encoded correctly.""" + user_name = "Julian" + receiver_name = "Jimmy" + message = "Hello, World!" + packet = CreateRequest( + user_name, + receiver_name, + message, + ).to_bytes() + + _, payload = Packet.decode_packet(packet) + + start_index = 4 + len(user_name.encode()) + len(receiver_name.encode()) + + expected = message + actual = payload[start_index : start_index + len(message.encode())].decode() + self.assertEqual(expected, actual) + + +class TestCreateRequestDecoding(unittest.TestCase): + """Test suite for decoding ``CreateRequest`` packets.""" + + def setUp(self) -> None: + """Set up the testing environment.""" + self.message_type = MessageType.CREATE + self.user_name = "Jamie" + self.receiver_name = "Jonty" + self.message = "Hello, World!" + + packet = CreateRequest( + self.user_name, + self.receiver_name, + self.message, + ).to_bytes() + _, self.packet = Packet.decode_packet(packet) + + def test_user_name_decoding(self) -> None: + """Tests that the user's name is decoded correctly.""" + decoded_packet = CreateRequest.decode_packet(self.packet) + + expected = self.user_name + actual = decoded_packet[0] + self.assertEqual(expected, actual) + + def test_receiver_name_decoding(self) -> None: + """Tests that the receiver's name is decoded correctly.""" + decoded_packet = CreateRequest.decode_packet(self.packet) + + expected = self.receiver_name + actual = decoded_packet[1] + self.assertEqual(expected, actual) + + def test_message_decoding(self) -> None: + """Tests that the message is decoded correctly.""" + decoded_packet = CreateRequest.decode_packet(self.packet) + + expected = self.message + actual = decoded_packet[2].decode() + self.assertEqual(expected, actual) + + def test_insufficient_user_name_length(self) -> None: + """Tests that an exception is raised if the user's name has a length of zero.""" + packet = bytearray(self.packet) + packet[3] = 0 + + self.assertRaises(ValueError, CreateRequest.decode_packet, packet) + + def test_insufficient_receiver_name_length(self) -> None: + """Tests that an exception is raised. + + If the length of the receiver's name is zero. + """ + packet = bytearray(self.packet) + packet[1] = 0 + + self.assertRaises(ValueError, CreateRequest.decode_packet, packet) + + def test_insufficient_message_length(self) -> None: + """Tests that an exception is raised. + + If the length of the message is zero. + """ + packet = bytearray(self.packet) + packet[2] = 0 + packet[3] = 0 + + self.assertRaises(ValueError, CreateRequest.decode_packet, packet) diff --git a/tests/packets/test_message_request.py b/tests/packets/test_message_request.py deleted file mode 100644 index 3239737..0000000 --- a/tests/packets/test_message_request.py +++ /dev/null @@ -1,280 +0,0 @@ -"""``MessageRequest`` class test suite.""" - -import unittest - -from src.message_type import MessageType -from src.packets.message_request import MessageRequest -from src.packets.packet import Packet - - -class TestMessageRequestEncoding(unittest.TestCase): - """Test suite for encoding MessageRequest packets.""" - - def test_magic_number_encoding(self) -> None: - """Tests that the magic number is encoded correctly.""" - message_type = MessageType.READ - user_name = "Jamie" - receiver_name = "Jonty" - message = "Hello, World!" - packet = MessageRequest( - message_type, - user_name, - receiver_name, - message, - ).to_bytes() - - expected = Packet.MAGIC_NUMBER - actual = (packet[0] << 8) | (packet[1] & 0xFF) - self.assertEqual(expected, actual) - - def test_message_type_encoding(self) -> None: - """Tests that the message type is encoded correctly.""" - message_type = MessageType.CREATE - user_name = "Jamie" - receiver_name = "Jonty" - message = "Hello, World!" - packet = MessageRequest( - message_type, - user_name, - receiver_name, - message, - ).to_bytes() - - expected = message_type.value - actual = packet[2] - self.assertEqual(expected, actual) - - def test_user_name_length_encoding(self) -> None: - """Tests that the length of the user's name is encoded correctly.""" - message_type = MessageType.CREATE - user_name = "Johnny" - receiver_name = "Jarod" - message = "Hello, World!" - packet = MessageRequest( - message_type, - user_name, - receiver_name, - message, - ).to_bytes() - - expected = len(user_name.encode()) - actual = packet[3] - self.assertEqual(expected, actual) - - def test_receiver_name_length_encoding(self) -> None: - """Tests that the length of the receiver's name is encoded correctly.""" - message_type = MessageType.CREATE - user_name = "Jackson" - receiver_name = "Jake" - message = "Hello, World!" - packet = MessageRequest( - message_type, - user_name, - receiver_name, - message, - ).to_bytes() - - expected = len(receiver_name.encode()) - actual = packet[4] - self.assertEqual(expected, actual) - - def test_message_length_encoding(self) -> None: - """Tests that the length of the message is encoded correctly.""" - message_type = MessageType.CREATE - user_name = "Jason" - receiver_name = "Jay" - message = "Hello, World!" - packet = MessageRequest( - message_type, - user_name, - receiver_name, - message, - ).to_bytes() - - expected = len(message.encode()) - actual = (packet[5] << 8) | (packet[6] & 0xFF) - self.assertEqual(expected, actual) - - def test_user_name_encoding(self) -> None: - """Tests that the user's name is encoded correctly.""" - message_type = MessageType.CREATE - user_name = "Jason" - receiver_name = "Jay" - message = "Hello, World!" - packet = MessageRequest( - message_type, - user_name, - receiver_name, - message, - ).to_bytes() - - expected = user_name - actual = packet[7 : 7 + len(user_name.encode())].decode() - self.assertEqual(expected, actual) - - def test_receiver_name_encoding(self) -> None: - """Tests that the receiver's name is encoded correctly.""" - message_type = MessageType.CREATE - user_name = "Jeff" - receiver_name = "Jesse" - message = "Hello, World!" - packet = MessageRequest( - message_type, - user_name, - receiver_name, - message, - ).to_bytes() - - start_index = 7 + len(user_name.encode()) - - expected = receiver_name - actual = packet[ - start_index : start_index + len(receiver_name.encode()) - ].decode() - self.assertEqual(expected, actual) - - def test_message_encoding(self) -> None: - """Tests that the message is encoded correctly.""" - message_type = MessageType.CREATE - user_name = "Julian" - receiver_name = "Jimmy" - message = "Hello, World!" - packet = MessageRequest( - message_type, - user_name, - receiver_name, - message, - ).to_bytes() - - start_index = 7 + len(user_name.encode()) + len(receiver_name.encode()) - - expected = message - actual = packet[start_index : start_index + len(message.encode())].decode() - self.assertEqual(expected, actual) - - -class TestMessageRequestDecoding(unittest.TestCase): - """Test suite for decoding ``MessageRequest`` packets.""" - - def setUp(self) -> None: - """Set up the testing environment.""" - self.message_type = MessageType.CREATE - self.user_name = "Jamie" - self.receiver_name = "Jonty" - self.message = "Hello, World!" - self.packet = MessageRequest( - self.message_type, - self.user_name, - self.receiver_name, - self.message, - ).to_bytes() - - def test_message_type_decoding(self) -> None: - """Tests that the message type is decoded correctly.""" - decoded_packet = MessageRequest.decode_packet(self.packet) - expected = self.message_type - actual = decoded_packet[0] - self.assertEqual(expected, actual) - - def test_user_name_decoding(self) -> None: - """Tests that the user's name is decoded correctly.""" - decoded_packet = MessageRequest.decode_packet(self.packet) - expected = self.user_name - actual = decoded_packet[1] - self.assertEqual(expected, actual) - - def test_receiver_name_decoding(self) -> None: - """Tests that the receiver's name is decoded correctly.""" - decoded_packet = MessageRequest.decode_packet(self.packet) - expected = self.receiver_name - actual = decoded_packet[2] - self.assertEqual(expected, actual) - - def test_message_decoding(self) -> None: - """Tests that the message is decoded correctly.""" - decoded_packet = MessageRequest.decode_packet(self.packet) - expected = self.message - actual = decoded_packet[3].decode() - self.assertEqual(expected, actual) - - def test_incorrect_magic_number(self) -> None: - """Tests that an exception is raised if the magic number is incorrect.""" - packet = bytearray(self.packet) - packet[0] = 0 - self.packet = bytes(packet) - - self.assertRaises(ValueError, MessageRequest.decode_packet, self.packet) - - def test_invalid_message_type(self) -> None: - """Tests that an exception is raised if the message type is invalid.""" - packet = bytearray(self.packet) - packet[2] = 0 - self.packet = bytes(packet) - - self.assertRaises(ValueError, MessageRequest.decode_packet, self.packet) - - def test_response_message_type(self) -> None: - """Tests that an exception is raised if the message type is RESPONSE.""" - packet = bytearray(self.packet) - packet[2] = 3 - self.packet = bytes(packet) - - self.assertRaises(ValueError, MessageRequest.decode_packet, self.packet) - - def test_insufficient_user_name_length(self) -> None: - """Tests that an exception is raised if the user's name has a length of zero.""" - packet = bytearray(self.packet) - packet[3] = 0 - self.packet = bytes(packet) - - self.assertRaises(ValueError, MessageRequest.decode_packet, self.packet) - - def test_non_zero_receiver_name_length_for_read(self) -> None: - """Tests that an exception is raised. - - If the length of the receiver's name is non-zero for a read request. - """ - packet = bytearray(self.packet) - packet[2] = MessageType.READ.value - packet[4] = 1 - self.packet = bytes(packet) - - self.assertRaises(ValueError, MessageRequest.decode_packet, self.packet) - - def test_non_zero_message_length_for_read(self) -> None: - """Tests that an exception is raised. - - If the length of the message is non-zero for a read request. - """ - packet = bytearray(self.packet) - packet[2] = MessageType.READ.value - packet[4] = 0 - packet[6] = 1 - self.packet = bytes(packet) - - self.assertRaises(ValueError, MessageRequest.decode_packet, self.packet) - - def test_insufficient_receiver_name_length_for_create(self) -> None: - """Tests that an exception is raised. - - If the length of the receiver's name is zero for a create request. - """ - packet = bytearray(self.packet) - packet[2] = MessageType.CREATE.value - packet[4] = 0 - self.packet = bytes(packet) - - self.assertRaises(ValueError, MessageRequest.decode_packet, self.packet) - - def test_insufficient_message_length_for_create(self) -> None: - """Tests that an exception is raised. - - If the length of the message is zero for a create request. - """ - packet = bytearray(self.packet) - packet[2] = MessageType.CREATE.value - packet[5] = 0 - packet[6] = 0 - self.packet = bytes(packet) - - self.assertRaises(ValueError, MessageRequest.decode_packet, self.packet) diff --git a/tests/packets/test_message_response.py b/tests/packets/test_message_response.py deleted file mode 100644 index ae28796..0000000 --- a/tests/packets/test_message_response.py +++ /dev/null @@ -1,120 +0,0 @@ -"""MessageResponse class test suite.""" - -import unittest - -from src.message_type import MessageType -from src.packets.message_response import MessageResponse -from src.packets.packet import Packet - - -class TestMessageResponseEncoding(unittest.TestCase): - """Test suite for encoding MessageResponse packets.""" - - def test_magic_number_encoding(self) -> None: - """Tests that the magic number is encoded correctly.""" - messages: list[tuple[str, bytes]] = [] - packet = MessageResponse(messages).to_bytes() - - expected = Packet.MAGIC_NUMBER - actual = (packet[0] << 8) | (packet[1] & 0xFF) - self.assertEqual(expected, actual) - - def test_message_type_encoding(self) -> None: - """Tests that the message type is encoded correctly.""" - messages: list[tuple[str, bytes]] = [] - packet = MessageResponse(messages).to_bytes() - - expected = MessageType.RESPONSE.value - actual = packet[2] - self.assertEqual(expected, actual) - - def test_num_messages_encoding(self) -> None: - """Tests that the number of messages is encoded correctly.""" - messages = [ - ("Harry", b"Hello John!"), - ("John", b"Hello Harry!"), - ] - packet = MessageResponse(messages).to_bytes() - - expected = len(messages) - actual = packet[3] - self.assertEqual(expected, actual) - - def test_more_messages_encoding(self) -> None: - """Tests that the more messages flag is encoded correctly.""" - messages: list[tuple[str, bytes]] = [] - packet = MessageResponse(messages).to_bytes() - - expected = False - actual = packet[4] - self.assertEqual(expected, actual) - - -class TestMessageResponseDecoding(unittest.TestCase): - """Test suite for decoding MessageResponse packets.""" - - def test_messages_decoding(self) -> None: - """Tests that the messages are decoded correctly.""" - messages = [ - ("Harry", b"Hello John!"), - ("John", b"Hello Harry!"), - ] - packet = MessageResponse(messages).to_bytes() - - expected = [ - ("Harry", "Hello John!"), - ("John", "Hello Harry!"), - ] - actual = MessageResponse.decode_packet(packet)[0] - self.assertEqual(expected, actual) - - def test_more_messages_decoding_false(self) -> None: - """Tests that the more messages flag is decoded correctly.""" - messages = [ - ("Harry", b"Hello John!"), - ("John", b"Hello Harry!"), - ] - packet = MessageResponse(messages).to_bytes() - - expected = False - actual = MessageResponse.decode_packet(packet)[1] - self.assertEqual(expected, actual) - - def test_more_messages_decoding_true(self) -> None: - """Tests that the more messages flag is decoded correctly.""" - messages = [("Harry", b"Hello John!")] * 256 - packet = MessageResponse(messages).to_bytes() - - expected = True - actual = MessageResponse.decode_packet(packet)[1] - self.assertEqual(expected, actual) - - def test_incorrect_magic_number(self) -> None: - """Tests that a ``ValueError`` is raised when the magic number is incorrect.""" - messages: list[tuple[str, bytes]] = [] - packet = MessageResponse(messages).to_bytes() - - packet = bytearray(packet) - packet[0] = 0x00 - - self.assertRaises(ValueError, MessageResponse.decode_packet, bytes(packet)) - - def test_invalid_message_type(self) -> None: - """Tests that a ``ValueError`` is raised when the message type is invalid.""" - messages: list[tuple[str, bytes]] = [] - packet = MessageResponse(messages).to_bytes() - - packet = bytearray(packet) - packet[2] = 4 - - self.assertRaises(ValueError, MessageResponse.decode_packet, bytes(packet)) - - def test_incorrect_message_type(self) -> None: - """Tests that a ``ValueError`` is raised when the message type is incorrect.""" - messages: list[tuple[str, bytes]] = [] - packet = MessageResponse(messages).to_bytes() - - packet = bytearray(packet) - packet[2] = MessageType.CREATE.value - - self.assertRaises(ValueError, MessageResponse.decode_packet, bytes(packet)) diff --git a/tests/applications/test_command_line_application.py b/tests/packets/test_packet.py similarity index 61% rename from tests/applications/test_command_line_application.py rename to tests/packets/test_packet.py index 467ad66..849d415 100644 --- a/tests/applications/test_command_line_application.py +++ b/tests/packets/test_packet.py @@ -3,22 +3,27 @@ import unittest from typing import Any +from src.message_type import MessageType from src.packets.packet import Packet -class TestClientParseArguments(unittest.TestCase): - """Test suite for Client class.""" +class TestPacket(unittest.TestCase): + """Test suite for Packet class.""" def test_fail_subclass(self) -> None: """Test Packet __init_subclass__ function. Ensure that we cannot subclass from CommandLineApplication - without specifying a struct format. + without specifying a struct format and message type. """ with self.assertRaises(ValueError): - class NoStructFormat(Packet): - """No ``struct_formatt`` passed so class will not be created.""" + class NoStructFormat( + Packet, + struct_format="invalid format", + message_type=MessageType.LOGIN, + ): + """Invalid ``struct_format`` passed, so class will not be created.""" def __init__(self, *args: tuple[Any, ...]) -> None: pass diff --git a/tests/packets/test_read_request.py b/tests/packets/test_read_request.py new file mode 100644 index 0000000..49ee3fc --- /dev/null +++ b/tests/packets/test_read_request.py @@ -0,0 +1,57 @@ +"""``ReadRequest`` class test suite.""" + +import unittest + +from src.packets.packet import Packet +from src.packets.read_request import ReadRequest + + +class TestReadRequestEncoding(unittest.TestCase): + """Test suite for encoding ReadRequest packets.""" + + def test_user_name_length_encoding(self) -> None: + """Tests that the length of the user's name is encoded correctly.""" + user_name = "Johnny" + packet = ReadRequest(user_name).to_bytes() + _, payload = Packet.decode_packet(packet) + + expected = len(user_name.encode()) + actual = payload[0] + self.assertEqual(expected, actual) + + def test_user_name_encoding(self) -> None: + """Tests that the user's name is placed correctly in the packet.""" + user_name = "Johnny" + packet = ReadRequest(user_name).to_bytes() + _, payload = Packet.decode_packet(packet) + + expected = user_name + actual = payload[1:].decode() + self.assertEqual(expected, actual) + + +class TestReadRequestDecoding(unittest.TestCase): + """Test suite for decoding ``ReadRequest`` packets.""" + + def setUp(self) -> None: + """Set up the testing environment.""" + self.user_name = "Jamie" + + packet = ReadRequest(self.user_name).to_bytes() + _, self.packet = Packet.decode_packet(packet) + + def test_user_name_decoding(self) -> None: + """Tests that the user's name is decoded correctly.""" + decoded_packet = ReadRequest.decode_packet(self.packet) + + expected = self.user_name + actual = decoded_packet[0] + self.assertEqual(expected, actual) + + def test_insufficient_user_name_length(self) -> None: + """Tests that an exception is raised if the user's name has a length of zero.""" + packet = bytearray(self.packet) + packet[0] = 0 + self.packet = bytes(packet) + + self.assertRaises(ValueError, ReadRequest.decode_packet, self.packet) diff --git a/tests/packets/test_read_response.py b/tests/packets/test_read_response.py new file mode 100644 index 0000000..2d0ce5c --- /dev/null +++ b/tests/packets/test_read_response.py @@ -0,0 +1,78 @@ +"""MessageResponse class test suite.""" + +import unittest + +from src.packets.packet import Packet +from src.packets.read_response import ReadResponse + + +class TestMessageResponseEncoding(unittest.TestCase): + """Test suite for encoding MessageResponse packets.""" + + def test_num_messages_encoding(self) -> None: + """Tests that the number of messages is encoded correctly.""" + messages = [ + ("Harry", b"Hello John!"), + ("John", b"Hello Harry!"), + ] + packet = ReadResponse(messages).to_bytes() + _, packet = Packet.decode_packet(packet) + + expected = len(messages) + actual = packet[0] + self.assertEqual(expected, actual) + + def test_more_messages_encoding(self) -> None: + """Tests that the more messages flag is encoded correctly.""" + messages: list[tuple[str, bytes]] = [] + packet = ReadResponse(messages).to_bytes() + _, packet = Packet.decode_packet(packet) + + expected = False + actual = packet[1] + self.assertEqual(expected, actual) + + +class TestMessageResponseDecoding(unittest.TestCase): + """Test suite for decoding MessageResponse packets.""" + + def test_messages_decoding(self) -> None: + """Tests that the messages are decoded correctly.""" + messages = [ + ("Harry", b"Hello John!"), + ("John", b"Hello Harry!"), + ] + packet = ReadResponse(messages).to_bytes() + + _, packet = Packet.decode_packet(packet) + + expected = [ + ("Harry", "Hello John!"), + ("John", "Hello Harry!"), + ] + actual = ReadResponse.decode_packet(packet)[0] + self.assertEqual(expected, actual) + + def test_more_messages_decoding_false(self) -> None: + """Tests that the more messages flag is decoded correctly.""" + messages = [ + ("Harry", b"Hello John!"), + ("John", b"Hello Harry!"), + ] + packet = ReadResponse(messages).to_bytes() + _, packet = Packet.decode_packet(packet) + + expected = False + actual = ReadResponse.decode_packet(packet)[1] + self.assertEqual(expected, actual) + + def test_more_messages_decoding_true(self) -> None: + """Tests that the more messages flag is decoded correctly.""" + messages = [("Harry", b"Hello John!")] * 256 + packet = ReadResponse(messages).to_bytes() + + _, packet = Packet.decode_packet(packet) + + expected = True + actual = ReadResponse.decode_packet(packet)[1] + self.assertEqual(expected, actual) From c3f2cb03a7df9ee3003ffbe23aa88a83595b6ea7 Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Thu, 14 Nov 2024 09:32:14 +1300 Subject: [PATCH 32/48] updated logging configuration so all printing is done through the logging package, allowing for removal of all print statements --- client/client.py | 42 +++++++++++++++-------------- logging_config.py | 25 ++++++++--------- pyproject.toml | 1 - server/__main__.py | 3 +-- server/server.py | 27 +++++++------------ src/command_line_application.py | 9 +++---- src/packets/create_request.py | 2 +- src/packets/key_request.py | 2 +- src/packets/read_request.py | 2 +- src/packets/read_response.py | 4 +-- src/packets/registration_request.py | 2 +- 11 files changed, 55 insertions(+), 64 deletions(-) diff --git a/client/client.py b/client/client.py index 2353556..02a6ab2 100644 --- a/client/client.py +++ b/client/client.py @@ -47,7 +47,7 @@ def __init__(self, arguments: list[str]) -> None: parsed_arguments ) - logger.info( + logger.debug( "Client for %s port %s created by %s to send %s request", self.host_name, self.port_number, @@ -73,7 +73,7 @@ def parse_hostname(host_name: str) -> str: message = ( 'Invalid host name, must be an IP address, domain name, or "localhost"' ) - logger.exception(message) + logger.debug(message, exc_info=True) raise ValueError(message) from error return host_name @@ -119,15 +119,13 @@ def send_request(self, request: Packet, *, expect_response: bool = True) -> byte else "Connection timed out, likely due to invalid host name" ) logger.exception(message) - print(message) raise SystemExit from error logger.info( "%s record sent as %s", - self.message_type.name.lower(), + self.message_type.name.capitalize(), self.user_name, ) - print(f"{self.message_type.name.lower()} record sent as {self.user_name}") return response @@ -141,15 +139,12 @@ def read_message_response(packet: bytes) -> None: messages, more_messages = ReadResponse.decode_packet(packet) for sender, message in messages: - logger.info('Received %s\'s message "%s"', sender, message) - print(f"Message from {sender}:\n{message}\n") + logger.info("\nMessage from %s:\n%s", sender, message) if len(messages) == 0: - logger.info("Response contained no messages") - print("No messages available") + logger.info("No messages available") elif more_messages: - logger.info("Server has more messages available for this user") - print("More messages available, please send another request") + logger.info("More messages available, please send another request") def send_read_request(self) -> None: """Send a read request to the server.""" @@ -178,7 +173,7 @@ def send_create_request( else: self.message = message - logger.info( + logger.debug( 'User specified message to %s: "%s"', self.receiver_name, self.message, @@ -193,14 +188,17 @@ def send_create_request( def send_login_request(self) -> None: """Send a login request to the server.""" - print("logging in") + logger.info("Logging in user") def send_registration_request(self) -> None: """Send a login request to the server.""" - public_key, private_key = RSA() + public_key, _private_key = RSA() - print(f"Creatged product {public_key.product}") - print(f"Created exponent {public_key.exponent}") + logger.info( + "Created key %s for user %s", + (public_key.product, public_key.exponent), + self.user_name, + ) request = RegistrationRequest(self.user_name, public_key) self.send_request(request) @@ -208,7 +206,7 @@ def send_registration_request(self) -> None: def send_key_request(self, receiver_name: str | None) -> None: """Send a pubblic key request to the server.""" if not receiver_name: - receiver_name = input("Who's key are we requesting?") + receiver_name = input("Who's key are we requesting? ") request = KeyRequest(receiver_name) response = self.send_request(request) @@ -217,8 +215,12 @@ def send_key_request(self, receiver_name: str | None) -> None: raise ValueError("Wrong response type!") (public_key,) = KeyResponse.decode_packet(payload) - print(f"Received product {public_key.product}") - print(f"Received exponent {public_key.exponent}") + + logger.info( + "Received %s's key:\n%s", + receiver_name, + (public_key.product, public_key.exponent), + ) def run( self, @@ -250,7 +252,7 @@ def run( self.send_key_request(receiver_name) case _: - print("Oopsies, wrong message type!") + logger.error("Oopsies, wrong message type!") @property def result(self) -> bytes: diff --git a/logging_config.py b/logging_config.py index ec936a7..a11f3a9 100644 --- a/logging_config.py +++ b/logging_config.py @@ -26,6 +26,9 @@ class PathnameFormatter(logging.Formatter): def format(self, record: logging.LogRecord) -> str: """Make the filename clickable in PyCharm.""" record.pathname = record.name.replace(".", "/") + ".py:" + str(record.lineno) + + # Remove newlines from logs being sent to a file + record.msg = record.msg.lstrip("\n").replace("\n", " ") return super().format(record) @@ -39,21 +42,19 @@ def configure_logging(package_name: str) -> None: # ruff: noqa: DTZ005 file_name = datetime.datetime.now().strftime("%d-%m-%y %H:%M:%S") - (pathlib.Path("logs") / package_name).parent.mkdir(parents=True, exist_ok=True) - file_handler = logging.FileHandler(f"logs/{package_name}/{file_name}.log") + log_folder = pathlib.Path("logs") / package_name + log_folder.mkdir(parents=True, exist_ok=True) + + file_handler = logging.FileHandler(log_folder / (file_name + ".log")) file_handler.setLevel(logging.DEBUG) file_handler.setFormatter(file_formatter) - console_formatter = PathnameFormatter( - "%(levelname)-8s - %(pathname)-35s - %(message)s", - ) + console_formatter = logging.Formatter("%(message)s") - # ruff: noqa: ERA001 - # To enable printing of logs to stdout, enable below code - # stdout_handler = logging.StreamHandler(sys.stdout) - # stdout_handler.setLevel(logging.DEBUG) - # stdout_handler.addFilter(StdoutHandlerFilter()) - # stdout_handler.setFormatter(console_formatter) + stdout_handler = logging.StreamHandler(sys.stdout) + stdout_handler.setLevel(logging.INFO) + stdout_handler.addFilter(StdoutHandlerFilter()) + stdout_handler.setFormatter(console_formatter) stderr_handler = logging.StreamHandler(sys.stderr) stderr_handler.setLevel(logging.WARNING) @@ -61,5 +62,5 @@ def configure_logging(package_name: str) -> None: logging.basicConfig( level=logging.DEBUG, - handlers=[file_handler, stderr_handler], + handlers=[stdout_handler, stderr_handler, file_handler], ) diff --git a/pyproject.toml b/pyproject.toml index c6ca6f0..538114d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,6 @@ ignore = [ "D213", "EM101", "TRY003", - "T201", "FBT001", "ANN101", "ANN102", diff --git a/server/__main__.py b/server/__main__.py index 570e181..6216cd0 100644 --- a/server/__main__.py +++ b/server/__main__.py @@ -23,8 +23,7 @@ def main() -> None: except SystemExit: sys.exit(1) except KeyboardInterrupt: - logger.info("Server shut down due to keyboard interrupt") - print("\nServer shut down") + logger.info("Server shut down") sys.exit(0) diff --git a/server/server.py b/server/server.py index 4076fea..f60abed 100644 --- a/server/server.py +++ b/server/server.py @@ -63,7 +63,6 @@ def run(self) -> None: self.hostname, self.port_number, ) - print(f"starting up on {self.hostname} port {self.port_number}") while self.running: self.run_server(welcoming_socket) @@ -71,7 +70,6 @@ def run(self) -> None: except OSError as error: message = "Error binding socket on provided port" logger.exception(message) - print(message) raise SystemExit from error @staticmethod @@ -105,7 +103,6 @@ def process_read_request( response.num_messages, sender_name, ) - print(f"{response.num_messages} message(s) delivered to {sender_name}") def process_create_request( self, @@ -133,10 +130,6 @@ def process_create_request( receiver_name, message.decode(), ) - print( - f"{sender_name} sends the message " - f'"{message.decode()}" to {receiver_name}', - ) def process_login_request(self, packet: bytes) -> None: """Process a client requset to login. @@ -144,7 +137,7 @@ def process_login_request(self, packet: bytes) -> None: :param packet: A byte array containing the login request. """ (sender_name,) = LoginRequest.decode_packet(packet) - print("logged in", sender_name) + logger.info("logged in %s", sender_name) def process_registration_request(self, packet: bytes) -> None: """Process a client request to register a new name. @@ -154,10 +147,10 @@ def process_registration_request(self, packet: bytes) -> None: sender_name, public_key = RegistrationRequest.decode_packet(packet) if sender_name not in self.users: self.users[sender_name] = public_key - print( - f"Registered {sender_name}", - f"with product {public_key.product}", - f"and exponent {public_key.exponent}", + logger.info( + "Registered %s with key %s", + sender_name, + (public_key.product, public_key.exponent), ) else: @@ -175,7 +168,8 @@ def process_key_request( :param connection_socket: The socket to send the response over. """ (requested_user,) = KeyRequest.decode_packet(packet) - print(f"Received request for {requested_user}'s key") + logger.info("Received request for %s's key", requested_user) + public_key = self.users[requested_user] response = KeyResponse(public_key) self.send_response(response, connection_socket) @@ -222,8 +216,7 @@ def run_server(self, welcoming_socket: socket.socket) -> None: connection_socket.settimeout(1) - logger.info("New client connection from %s", client_address) - print("New client connection from", client_address) + logger.info("\nNew client connection from %s", client_address) try: with connection_socket: @@ -233,13 +226,11 @@ def run_server(self, welcoming_socket: socket.socket) -> None: except TimeoutError: error_message = "Timed out while waiting for message request" logger.exception(error_message) - print(error_message) except ValueError: error_message = "Message request discarded" logger.exception(error_message) - print(error_message) def stop(self) -> None: """Stop the server.""" - print("Stopping server.") + logger.info("Stopping server.") self.running = False diff --git a/src/command_line_application.py b/src/command_line_application.py index 97ca1fc..7124f0c 100644 --- a/src/command_line_application.py +++ b/src/command_line_application.py @@ -40,8 +40,8 @@ def parse_arguments(self, arguments: list[str]) -> tuple[Any, ...]: """ if len(arguments) != len(self.parameters): message = f"Invalid number of arguments, must be {len(self.parameters)}" - print(self.usage_prompt) - print(message) + logger.error(self.usage_prompt) + logger.error(message) raise SystemExit(message) try: @@ -54,9 +54,8 @@ def parse_arguments(self, arguments: list[str]) -> tuple[Any, ...]: ) ) except (ValueError, TypeError) as error: - logger.exception("Incorrect arguments") - print(self.usage_prompt) - print(error) + logger.log(logging.ERROR, "%s\n%s", self.usage_prompt, error) + logger.debug(error, exc_info=True) raise SystemExit from error return parsed_arguments diff --git a/src/packets/create_request.py b/src/packets/create_request.py index 9ff7b3a..1d42983 100644 --- a/src/packets/create_request.py +++ b/src/packets/create_request.py @@ -42,7 +42,7 @@ def to_bytes(self) -> bytes: :return: A byte array holding the create request. """ - logger.info( + logger.debug( 'Creating CREATE request to send %s the message "%s" from %s', self.receiver_name, self.message, diff --git a/src/packets/key_request.py b/src/packets/key_request.py index 6f681b9..737aa78 100644 --- a/src/packets/key_request.py +++ b/src/packets/key_request.py @@ -25,7 +25,7 @@ def __init__(self, user_name: str) -> None: def to_bytes(self) -> bytes: """Encode the key request packet into a byte array.""" - logging.info("Creating key request for %s", self.user_name) + logging.debug("Creating key request for %s", self.user_name) self.packet = super().to_bytes() diff --git a/src/packets/read_request.py b/src/packets/read_request.py index bf24366..4754988 100644 --- a/src/packets/read_request.py +++ b/src/packets/read_request.py @@ -37,7 +37,7 @@ def to_bytes(self) -> bytes: :return: An array of bytes holding the message request. """ - logger.info("Creating READ request from %s", self.user_name) + logger.debug("Creating READ request from %s", self.user_name) self.packet = super().to_bytes() diff --git a/src/packets/read_response.py b/src/packets/read_response.py index ff36036..fb843ce 100644 --- a/src/packets/read_response.py +++ b/src/packets/read_response.py @@ -31,7 +31,7 @@ def to_bytes(self) -> bytes: :return: A byte array holding the message response. """ - logger.info("Creating message response for %s message(s)", self.num_messages) + logger.debug("Creating message response for %s message(s)", self.num_messages) self.packet = super().to_bytes() @@ -64,7 +64,7 @@ def decode_packet(cls, packet: bytes) -> tuple[list[tuple[str, str]], bool]: sender_name, message, remaining_messages = Message.decode_packet( remaining_messages, ) - logger.info('Decoded message from %s: "%s"', sender_name, message) + logger.debug('Decoded message from %s: "%s"', sender_name, message) messages.append((sender_name, message)) return messages, more_messages diff --git a/src/packets/registration_request.py b/src/packets/registration_request.py index d980853..61a0ee5 100644 --- a/src/packets/registration_request.py +++ b/src/packets/registration_request.py @@ -28,7 +28,7 @@ def __init__(self, user_name: str, public_key: RsaEncrypter) -> None: def to_bytes(self) -> bytes: """Encode the registration request packet into a byte array.""" - logging.info("Creating request to register as %s", self.user_name) + logging.debug("Creating request to register as %s", self.user_name) user_name = self.user_name.encode() From ca9952dee219c4ee4bfaf3dbf9f66d00afbf1cda Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Fri, 15 Nov 2024 13:21:57 +1300 Subject: [PATCH 33/48] Added user registration and login (sort of) --- client/client.py | 50 ++++++++++++++++++++--------- server/server.py | 41 ++++++++++++++++++----- src/packets/key_response.py | 16 ++++----- src/packets/login_request.py | 6 ++-- src/packets/login_response.py | 46 ++++++++++++++++++++++++++ src/packets/registration_request.py | 16 ++++----- 6 files changed, 133 insertions(+), 42 deletions(-) create mode 100644 src/packets/login_response.py diff --git a/client/client.py b/client/client.py index 02a6ab2..e6a9a39 100644 --- a/client/client.py +++ b/client/client.py @@ -4,13 +4,15 @@ import socket from collections import OrderedDict -from message_cipher.rsa_system import RSA +import rsa from src.command_line_application import CommandLineApplication from src.message_type import MessageType from src.packets.create_request import CreateRequest from src.packets.key_request import KeyRequest from src.packets.key_response import KeyResponse +from src.packets.login_request import LoginRequest +from src.packets.login_response import LoginResponse from src.packets.packet import Packet from src.packets.read_request import ReadRequest from src.packets.read_response import ReadResponse @@ -55,9 +57,18 @@ def __init__(self, arguments: list[str]) -> None: self.message_type.name.lower(), ) + self.public_key, self.__private_key = rsa.newkeys(512) + + logger.debug( + "Created key %s for user %s", + (self.public_key.n, self.public_key.e), + self.user_name, + ) + self.receiver_name = "" self.message = "" self.response: bytes | None = None + self.session_token: bytes | None = None @staticmethod def parse_hostname(host_name: str) -> str: @@ -184,24 +195,32 @@ def send_create_request( self.receiver_name, self.message, ) - self.send_request(request) + self.send_request(request, expect_response=False) def send_login_request(self) -> None: """Send a login request to the server.""" - logger.info("Logging in user") + request = LoginRequest(self.user_name) + response = self.send_request(request) - def send_registration_request(self) -> None: - """Send a login request to the server.""" - public_key, _private_key = RSA() + message_type, packet = Packet.decode_packet(response) + if message_type != MessageType.LOGIN: + raise RuntimeError("Recieved incorrect type response from server") - logger.info( - "Created key %s for user %s", - (public_key.product, public_key.exponent), - self.user_name, - ) + (encrypted_session_token,) = LoginResponse.decode_packet(packet) + logger.debug("Received encrypted token bytes %s", encrypted_session_token) - request = RegistrationRequest(self.user_name, public_key) - self.send_request(request) + if len(encrypted_session_token) == 0: + logger.error("You are not registered! Please register before logging in") + raise SystemExit + + self.session_token = rsa.decrypt(encrypted_session_token, self.__private_key) + logger.debug("Storing provided session token %s", self.session_token) + logger.info("Now logged in as %s", self.user_name) + + def send_registration_request(self) -> None: + """Send a registration request to the server.""" + request = RegistrationRequest(self.user_name, self.public_key) + self.send_request(request, expect_response=False) def send_key_request(self, receiver_name: str | None) -> None: """Send a pubblic key request to the server.""" @@ -211,15 +230,16 @@ def send_key_request(self, receiver_name: str | None) -> None: request = KeyRequest(receiver_name) response = self.send_request(request) message_type, payload = Packet.decode_packet(response) + if message_type != MessageType.KEY_RESPONSE: - raise ValueError("Wrong response type!") + raise ValueError("Recieved incorrect type response from server") (public_key,) = KeyResponse.decode_packet(payload) logger.info( "Received %s's key:\n%s", receiver_name, - (public_key.product, public_key.exponent), + (public_key.n, public_key.e), ) def run( diff --git a/server/server.py b/server/server.py index f60abed..3f6d756 100644 --- a/server/server.py +++ b/server/server.py @@ -1,12 +1,11 @@ """Home to the ``Server`` class.""" import logging +import secrets import socket from collections import OrderedDict -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from message_cipher.rsa_encrypter import RsaEncrypter +import rsa from src.command_line_application import CommandLineApplication from src.message_type import MessageType @@ -14,6 +13,7 @@ from src.packets.key_request import KeyRequest from src.packets.key_response import KeyResponse from src.packets.login_request import LoginRequest +from src.packets.login_response import LoginResponse from src.packets.packet import Packet from src.packets.read_request import ReadRequest from src.packets.read_response import ReadResponse @@ -43,7 +43,8 @@ def __init__(self, arguments: list[str]) -> None: self.running = True self.hostname = "localhost" - self.users: dict[str, RsaEncrypter] = {} + self.users: dict[str, rsa.PublicKey] = {} + self.sessions: dict[str, bytes] = {} self.messages: dict[str, list[tuple[str, bytes]]] = {} def run(self) -> None: @@ -82,6 +83,11 @@ def send_response(response: Packet, connection_socket: socket.socket) -> None: record = response.to_bytes() connection_socket.send(record) + @staticmethod + def generate_session_token() -> bytes: + """Generate a random session token.""" + return secrets.token_bytes() + def process_read_request( self, packet: bytes, @@ -131,13 +137,32 @@ def process_create_request( message.decode(), ) - def process_login_request(self, packet: bytes) -> None: + def process_login_request( + self, + packet: bytes, + connection_socket: socket.socket, + ) -> None: """Process a client requset to login. :param packet: A byte array containing the login request. """ (sender_name,) = LoginRequest.decode_packet(packet) - logger.info("logged in %s", sender_name) + + if sender_name not in self.users: + response = LoginResponse(b"") + self.send_response(response, connection_socket) + logger.info("Unregistered user %s attempted to login", sender_name) + return + + session_token = self.generate_session_token() + self.sessions[sender_name] = session_token + logger.debug("Gave %s the token %s", sender_name, session_token) + + senders_public_key = self.users[sender_name] + encrypted_session_token = rsa.encrypt(session_token, senders_public_key) + logger.debug("Encrypted token to %s", encrypted_session_token) + response = LoginResponse(encrypted_session_token) + self.send_response(response, connection_socket) def process_registration_request(self, packet: bytes) -> None: """Process a client request to register a new name. @@ -150,7 +175,7 @@ def process_registration_request(self, packet: bytes) -> None: logger.info( "Registered %s with key %s", sender_name, - (public_key.product, public_key.exponent), + (public_key.n, public_key.e), ) else: @@ -191,7 +216,7 @@ def process_request(self, packet: bytes, connection_socket: socket.socket) -> No self.process_create_request(packet) case MessageType.LOGIN: - self.process_login_request(packet) + self.process_login_request(packet, connection_socket) case MessageType.REGISTER: self.process_registration_request(packet) diff --git a/src/packets/key_response.py b/src/packets/key_response.py index 68b356c..16ea35c 100644 --- a/src/packets/key_response.py +++ b/src/packets/key_response.py @@ -7,7 +7,7 @@ import logging import struct -from message_cipher.rsa_encrypter import RsaEncrypter +import rsa from src.message_type import MessageType from src.packets.packet import Packet @@ -20,7 +20,7 @@ class KeyResponse( ): """Encode and decode public key response packets.""" - def __init__(self, public_key: RsaEncrypter) -> None: + def __init__(self, public_key: rsa.PublicKey) -> None: """Create a key response packet.""" self.public_key = public_key self.packet: bytes @@ -29,12 +29,12 @@ def to_bytes(self) -> bytes: """Encode the key response packet into a byte array.""" logging.info("Creating key response") - product = self.public_key.product.to_bytes( - (self.public_key.product.bit_length() + 7) // 8, + product = self.public_key.n.to_bytes( + (self.public_key.n.bit_length() + 7) // 8, ) - exponent = self.public_key.exponent.to_bytes( - (self.public_key.exponent.bit_length() + 7) // 8, + exponent = self.public_key.e.to_bytes( + (self.public_key.e.bit_length() + 7) // 8, ) self.packet = super().to_bytes() @@ -51,7 +51,7 @@ def to_bytes(self) -> bytes: return self.packet @classmethod - def decode_packet(cls, packet: bytes) -> tuple[RsaEncrypter]: + def decode_packet(cls, packet: bytes) -> tuple[rsa.PublicKey]: """Decode the key response packet into its individual components. :param packet: The packet to be decoded. @@ -65,6 +65,6 @@ def decode_packet(cls, packet: bytes) -> tuple[RsaEncrypter]: exponent = int.from_bytes(payload[index : index + exponent_length]) - public_key = RsaEncrypter(product, exponent) + public_key = rsa.PublicKey(product, exponent) return (public_key,) diff --git a/src/packets/login_request.py b/src/packets/login_request.py index 24c2254..28f24ea 100644 --- a/src/packets/login_request.py +++ b/src/packets/login_request.py @@ -11,17 +11,17 @@ from src.packets.packet import Packet -class LoginRequest(Packet, struct_format="!HBB", message_type=MessageType.LOGIN): +class LoginRequest(Packet, struct_format="!B", message_type=MessageType.LOGIN): """The LoginRequest class is used to encode and decode login request packets.""" def __init__(self, user_name: str) -> None: """Create a login request packet.""" self.user_name = user_name - self.packet = b"" + self.packet: bytes def to_bytes(self) -> bytes: """Encode the login request packet into a byte array.""" - logging.info("Creating log-in request as %s", self.user_name) + logging.debug("Creating log-in request as %s", self.user_name) self.packet = super().to_bytes() diff --git a/src/packets/login_response.py b/src/packets/login_response.py new file mode 100644 index 0000000..88b6353 --- /dev/null +++ b/src/packets/login_response.py @@ -0,0 +1,46 @@ +"""Login response module. + +Defines class for encoding and decoding login response packets. +""" + +import logging +import struct + +from src.message_type import MessageType +from src.packets.packet import Packet + + +class LoginResponse(Packet, struct_format="!B", message_type=MessageType.LOGIN): + """The LoginResponse class is used to encode and decode login response packets.""" + + def __init__(self, encrypted_token_bytes: bytes) -> None: + """Create a new login response packet.""" + self.token = encrypted_token_bytes + self.packet: bytes + + def to_bytes(self) -> bytes: + """Encode a login response packet into a byte array.""" + logging.info("Creating login response") + + self.packet = super().to_bytes() + + self.packet += struct.pack( + self.struct_format, + len(self.token), + ) + + self.packet += self.token + + return self.packet + + @classmethod + def decode_packet(cls, packet: bytes) -> tuple[bytes]: + """Decode a message response packet into its individual components. + + :param packet: The packet to be decoded. + :raises ValueError: If the packet is invalid. + :return: A tuple containing an encrypted session token. + """ + _, payload = cls.split_packet(packet) + + return (payload,) diff --git a/src/packets/registration_request.py b/src/packets/registration_request.py index 61a0ee5..44396bc 100644 --- a/src/packets/registration_request.py +++ b/src/packets/registration_request.py @@ -7,7 +7,7 @@ import logging import struct -from message_cipher.rsa_encrypter import RsaEncrypter +import rsa from src.message_type import MessageType from src.packets.packet import Packet @@ -20,7 +20,7 @@ class RegistrationRequest( ): """Encode and decode registration request packets.""" - def __init__(self, user_name: str, public_key: RsaEncrypter) -> None: + def __init__(self, user_name: str, public_key: rsa.PublicKey) -> None: """Create a login request packet.""" self.user_name = user_name self.public_key = public_key @@ -32,12 +32,12 @@ def to_bytes(self) -> bytes: user_name = self.user_name.encode() - product = self.public_key.product.to_bytes( - (self.public_key.product.bit_length() + 7) // 8, + product = self.public_key.n.to_bytes( + (self.public_key.n.bit_length() + 7) // 8, ) - exponent = self.public_key.exponent.to_bytes( - (self.public_key.exponent.bit_length() + 7) // 8, + exponent = self.public_key.e.to_bytes( + (self.public_key.e.bit_length() + 7) // 8, ) self.packet = super().to_bytes() @@ -56,7 +56,7 @@ def to_bytes(self) -> bytes: return self.packet @classmethod - def decode_packet(cls, packet: bytes) -> tuple[str, RsaEncrypter]: + def decode_packet(cls, packet: bytes) -> tuple[str, rsa.PublicKey]: """Decode the registration request packet into its individual components. :param packet: The packet to be decoded @@ -73,4 +73,4 @@ def decode_packet(cls, packet: bytes) -> tuple[str, RsaEncrypter]: exponent = int.from_bytes(payload[index : index + exponent_length]) - return user_name, RsaEncrypter(product, exponent) + return user_name, rsa.PublicKey(product, exponent) From 2b641bfde3fc9da9cd440d1adaa1827ed0ef1806 Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Fri, 15 Nov 2024 17:05:27 +1300 Subject: [PATCH 34/48] moved parsing username and hostname functionality outside of the client class --- client/client.py | 45 ++++--------------------------------------- src/parse_hostname.py | 25 ++++++++++++++++++++++++ src/parse_username.py | 25 ++++++++++++++++++++++++ 3 files changed, 54 insertions(+), 41 deletions(-) create mode 100644 src/parse_hostname.py create mode 100644 src/parse_username.py diff --git a/client/client.py b/client/client.py index e6a9a39..7fdfb77 100644 --- a/client/client.py +++ b/client/client.py @@ -17,6 +17,8 @@ from src.packets.read_request import ReadRequest from src.packets.read_response import ReadResponse from src.packets.registration_request import RegistrationRequest +from src.parse_hostname import parse_hostname +from src.parse_username import parse_username from src.port_number import PortNumber logger = logging.getLogger(__name__) @@ -25,8 +27,6 @@ class Client(CommandLineApplication): """Send and receives messages to and from the server.""" - MAX_USERNAME_LENGTH = 255 - def __init__(self, arguments: list[str]) -> None: """Initialise the client with specified arguments. @@ -35,9 +35,9 @@ def __init__(self, arguments: list[str]) -> None: """ super().__init__( OrderedDict( - host_name=self.parse_hostname, + host_name=parse_hostname, port_number=PortNumber, - user_name=self.parse_username, + user_name=parse_username, message_type=MessageType.from_str, ), ) @@ -70,43 +70,6 @@ def __init__(self, arguments: list[str]) -> None: self.response: bytes | None = None self.session_token: bytes | None = None - @staticmethod - def parse_hostname(host_name: str) -> str: - """Parse the host name, ensuring it is valid. - - :param host_name: String representing the host name. - :return: String of the host name. - :raises ValueError: If the host name is invalid. - """ - try: - socket.getaddrinfo(host_name, 1024) - except socket.gaierror as error: - message = ( - 'Invalid host name, must be an IP address, domain name, or "localhost"' - ) - logger.debug(message, exc_info=True) - raise ValueError(message) from error - - return host_name - - @staticmethod - def parse_username(user_name: str) -> str: - """Parse the username, ensuring it is valid. - - :param user_name: String representing the username. - :return: String of the username. - :raises ValueError: If the username is invalid. - """ - if len(user_name) == 0: - logger.error("Username is empty") - raise ValueError("Username must not be empty") - - if len(user_name.encode()) > Client.MAX_USERNAME_LENGTH: - logger.error("Username consumes more than 255 bytes") - raise ValueError("Username must consume at most 255 bytes") - - return user_name - def send_request(self, request: Packet, *, expect_response: bool = True) -> bytes: """Send a message request record to the server. diff --git a/src/parse_hostname.py b/src/parse_hostname.py new file mode 100644 index 0000000..ed52279 --- /dev/null +++ b/src/parse_hostname.py @@ -0,0 +1,25 @@ +"""Host name parsing functionality to be used by server and client.""" + +import logging +import socket + +logger = logging.getLogger(__name__) + + +def parse_hostname(host_name: str) -> str: + """Parse the host name, ensuring it is valid. + + :param host_name: String representing the host name. + :return: String of the host name. + :raises ValueError: If the host name is invalid. + """ + try: + socket.getaddrinfo(host_name, 1024) + except socket.gaierror as error: + message = ( + 'Invalid host name, must be an IP address, domain name, or "localhost"' + ) + logger.debug(message, exc_info=True) + raise ValueError(message) from error + + return host_name diff --git a/src/parse_username.py b/src/parse_username.py new file mode 100644 index 0000000..97734c6 --- /dev/null +++ b/src/parse_username.py @@ -0,0 +1,25 @@ +"""username parsing functionality to be used by client.""" + +import logging + +logger = logging.getLogger(__name__) + +MAX_USERNAME_LENGTH = 255 + + +def parse_username(user_name: str) -> str: + """Parse the username, ensuring it is valid. + + :param user_name: String representing the username. + :return: String of the username. + :raises ValueError: If the username is invalid. + """ + if len(user_name) == 0: + logger.error("Username is empty") + raise ValueError("Username must not be empty") + + if len(user_name.encode()) > MAX_USERNAME_LENGTH: + logger.error("Username consumes more than 255 bytes") + raise ValueError("Username must consume at most 255 bytes") + + return user_name From 18c367814f8573d51b25afb48d457095dbc26c8f Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Fri, 15 Nov 2024 17:25:53 +1300 Subject: [PATCH 35/48] Reengineering to add session token check to read and create requests. Also removed some redundant code --- client/client.py | 118 +++++++++++++++------------- server/server.py | 45 ++++++----- src/message_type.py | 23 +++--- src/packets/create_request.py | 30 +++---- src/packets/key_request.py | 3 +- src/packets/key_response.py | 1 + src/packets/login_request.py | 1 + src/packets/login_response.py | 7 +- src/packets/packet.py | 56 +++++++++---- src/packets/read_request.py | 2 + src/packets/read_response.py | 2 +- src/packets/registration_request.py | 1 + tests/test_message_type.py | 10 ++- 13 files changed, 174 insertions(+), 125 deletions(-) diff --git a/client/client.py b/client/client.py index 7fdfb77..53b970f 100644 --- a/client/client.py +++ b/client/client.py @@ -65,9 +65,6 @@ def __init__(self, arguments: list[str]) -> None: self.user_name, ) - self.receiver_name = "" - self.message = "" - self.response: bytes | None = None self.session_token: bytes | None = None def send_request(self, request: Packet, *, expect_response: bool = True) -> bytes: @@ -109,7 +106,6 @@ def read_message_response(packet: bytes) -> None: :param packet: The message response from the server. """ - _, packet = Packet.decode_packet(packet) messages, more_messages = ReadResponse.decode_packet(packet) for sender, message in messages: @@ -120,12 +116,28 @@ def read_message_response(packet: bytes) -> None: elif more_messages: logger.info("More messages available, please send another request") - def send_read_request(self) -> None: - """Send a read request to the server.""" - request = ReadRequest(self.user_name) - self.response = self.send_request(request) + def send_read_request(self) -> bytes: + """Send a read request to the server. - self.read_message_response(self.response) + :raises RuntimeError: If the server sends an invalid response. + :return: The ReadResponse packet received from the server. + """ + if self.session_token is None: + logger.error("Please log in to request messages") + raise SystemExit + + request = ReadRequest(self.session_token, self.user_name) + response = self.send_request(request) + + message_type: MessageType + packet: bytes + message_type, packet = Packet.decode_packet(response) + + if message_type != MessageType.READ_RESPONSE: + raise RuntimeError("Incorrect type message recieved from the server.") + + self.read_message_response(packet) + return packet def send_create_request( self, @@ -137,36 +149,42 @@ def send_create_request( :param receiver_name: The name of the person to send the messag to. :param message: The message to be sent. """ + if self.session_token is None: + logger.error("Please log in before send messages") + raise SystemExit + if receiver_name is None: - self.receiver_name = input("Enter the name of the receiver: ") - else: - self.receiver_name = receiver_name + receiver_name = input("Enter the name of the receiver: ") if message is None: - self.message = input("Enter the message to be sent: ") - else: - self.message = message + message = input("Enter the message to be sent: ") logger.debug( 'User specified message to %s: "%s"', - self.receiver_name, - self.message, + receiver_name, + message, ) request = CreateRequest( - self.user_name, - self.receiver_name, - self.message, + self.session_token, + receiver_name, + message, ) self.send_request(request, expect_response=False) - def send_login_request(self) -> None: - """Send a login request to the server.""" + def send_login_request(self) -> bytes: + """Send a login request to the server. + + :raises RuntimeError: If the server sends an incorrect response. + :return: The LoginResponse packet from the server. + """ request = LoginRequest(self.user_name) response = self.send_request(request) - message_type, packet = Packet.decode_packet(response) - if message_type != MessageType.LOGIN: + message_type: MessageType + packet: bytes + message_type, _, packet = Packet.decode_packet(response) + if message_type != MessageType.LOGIN_RESPONSE: raise RuntimeError("Recieved incorrect type response from server") (encrypted_session_token,) = LoginResponse.decode_packet(packet) @@ -180,24 +198,34 @@ def send_login_request(self) -> None: logger.debug("Storing provided session token %s", self.session_token) logger.info("Now logged in as %s", self.user_name) + return packet + def send_registration_request(self) -> None: """Send a registration request to the server.""" request = RegistrationRequest(self.user_name, self.public_key) self.send_request(request, expect_response=False) - def send_key_request(self, receiver_name: str | None) -> None: - """Send a pubblic key request to the server.""" + def send_key_request(self, receiver_name: str | None) -> bytes: + """Send a pubblic key request to the server. + + :param receiver_name: The name of the user who's key should be requested. + :return: The KeyResponse packet from the server. + """ if not receiver_name: receiver_name = input("Who's key are we requesting? ") request = KeyRequest(receiver_name) response = self.send_request(request) - message_type, payload = Packet.decode_packet(response) + + message_type: MessageType + packet: bytes + message_type, _, packet = Packet.decode_packet(response) if message_type != MessageType.KEY_RESPONSE: - raise ValueError("Recieved incorrect type response from server") + logger.error("Recieved incorrect type response from server") + raise SystemExit - (public_key,) = KeyResponse.decode_packet(payload) + (public_key,) = KeyResponse.decode_packet(packet) logger.info( "Received %s's key:\n%s", @@ -205,12 +233,9 @@ def send_key_request(self, receiver_name: str | None) -> None: (public_key.n, public_key.e), ) - def run( - self, - *, - receiver_name: str | None = None, - message: str | None = None, - ) -> None: + return packet + + def run(self) -> None: """Ask the user to input message and send request to server. :param receiver_name: The name of the user to send the message to. @@ -223,7 +248,7 @@ def run( self.send_read_request() case MessageType.CREATE: - self.send_create_request(receiver_name, message) + self.send_create_request(None, None) case MessageType.LOGIN: self.send_login_request() @@ -231,25 +256,8 @@ def run( case MessageType.REGISTER: self.send_registration_request() - case MessageType.KEY_REQUEST: - self.send_key_request(receiver_name) + case MessageType.KEY: + self.send_key_request(None) case _: logger.error("Oopsies, wrong message type!") - - @property - def result(self) -> bytes: - """Get the packet received from the server. - - This property must only be used after calling ``run()`` - otherwise no response will exist! - - :raises RuntimeError: When there was no response. - Will always occur if requested before call to ``run()``. - - :return: A bytes object of the server's response. - """ - if self.response is None: - raise RuntimeError("No response! Was result requested after call to run()?") - - return self.response diff --git a/server/server.py b/server/server.py index 3f6d756..327c149 100644 --- a/server/server.py +++ b/server/server.py @@ -15,7 +15,6 @@ from src.packets.login_request import LoginRequest from src.packets.login_response import LoginResponse from src.packets.packet import Packet -from src.packets.read_request import ReadRequest from src.packets.read_response import ReadResponse from src.packets.registration_request import RegistrationRequest from src.port_number import PortNumber @@ -36,15 +35,13 @@ def __init__(self, arguments: list[str]) -> None: """ super().__init__(OrderedDict(port_number=PortNumber)) - # pylint thinks that self.parse_arguments is only - # capable of returning an empty list - # pylint: disable=unbalanced-tuple-unpacking + self.port_number: PortNumber (self.port_number,) = self.parse_arguments(arguments) self.running = True self.hostname = "localhost" self.users: dict[str, rsa.PublicKey] = {} - self.sessions: dict[str, bytes] = {} + self.sessions: dict[bytes, str] = {} self.messages: dict[str, list[tuple[str, bytes]]] = {} def run(self) -> None: @@ -90,7 +87,8 @@ def generate_session_token() -> bytes: def process_read_request( self, - packet: bytes, + session_token: bytes | None, + _packet: bytes, connection_socket: socket.socket, ) -> None: """Respond to read requests. @@ -98,7 +96,10 @@ def process_read_request( :param packet: The read request packet to process. :param connection_socket: The connection socket to send the response on. """ - (sender_name,) = ReadRequest.decode_packet(packet) + if session_token not in self.sessions: + return + + sender_name = self.sessions[session_token] response = ReadResponse(self.messages.get(sender_name, [])) self.send_response(response, connection_socket) @@ -112,19 +113,21 @@ def process_read_request( def process_create_request( self, + session_token: bytes | None, packet: bytes, ) -> None: """Process create requests. - :param sender_name: The name of the user who sent the create request. - :param receiver_name: The name of the user who will receive the message. - :param message: The message to be sent. + :param session_token: The token provided by the client. + :param packet: The packet provided by the client. """ - ( - sender_name, - receiver_name, - message, - ) = CreateRequest.decode_packet(packet) + if session_token not in self.sessions: + logger.info("Received unauthenticated create request, ignoreing") + return + + sender_name = self.sessions[session_token] + + receiver_name, message = CreateRequest.decode_packet(packet) if receiver_name not in self.messages: self.messages[receiver_name] = [] @@ -155,7 +158,7 @@ def process_login_request( return session_token = self.generate_session_token() - self.sessions[sender_name] = session_token + self.sessions[session_token] = sender_name logger.debug("Gave %s the token %s", sender_name, session_token) senders_public_key = self.users[sender_name] @@ -206,14 +209,16 @@ def process_request(self, packet: bytes, connection_socket: socket.socket) -> No :param connection_socket: The socket to use for responding to read requests. """ message_type: MessageType - message_type, packet = Packet.decode_packet(packet) + session_token: bytes | None + + message_type, session_token, packet = Packet.decode_packet(packet) match message_type: case MessageType.READ: - self.process_read_request(packet, connection_socket) + self.process_read_request(session_token, packet, connection_socket) case MessageType.CREATE: - self.process_create_request(packet) + self.process_create_request(session_token, packet) case MessageType.LOGIN: self.process_login_request(packet, connection_socket) @@ -221,7 +226,7 @@ def process_request(self, packet: bytes, connection_socket: socket.socket) -> No case MessageType.REGISTER: self.process_registration_request(packet) - case MessageType.KEY_REQUEST: + case MessageType.KEY: self.process_key_request(packet, connection_socket) case _: diff --git a/src/message_type.py b/src/message_type.py index 02db173..54e4e02 100644 --- a/src/message_type.py +++ b/src/message_type.py @@ -1,19 +1,22 @@ """Home to the ``MessageType``.""" -from enum import Enum +import enum -class MessageType(Enum): +class MessageType(enum.Enum): """An enum for message types.""" - READ = 1 - CREATE = 2 - RESPONSE = 3 - LOGIN = 4 - REGISTER = 5 - MESSAGE = 6 - KEY_REQUEST = 7 - KEY_RESPONSE = 8 + REGISTER = enum.auto() + LOGIN = enum.auto() + LOGIN_RESPONSE = enum.auto() + + KEY = enum.auto() + KEY_RESPONSE = enum.auto() + CREATE = enum.auto() + + READ = enum.auto() + READ_RESPONSE = enum.auto() + MESSAGE = enum.auto() @staticmethod def from_str(string: str) -> "MessageType": diff --git a/src/packets/create_request.py b/src/packets/create_request.py index 1d42983..9ae68bf 100644 --- a/src/packets/create_request.py +++ b/src/packets/create_request.py @@ -10,7 +10,7 @@ logger = logging.getLogger(__name__) -class CreateRequest(Packet, struct_format="!BBH", message_type=MessageType.CREATE): +class CreateRequest(Packet, struct_format="!BH", message_type=MessageType.CREATE): """Encoding and decoding of create request packets. Usage: @@ -22,7 +22,7 @@ class CreateRequest(Packet, struct_format="!BBH", message_type=MessageType.CREAT def __init__( self, - user_name: str, + session_token: bytes, receiver_name: str, message: str, ) -> None: @@ -32,7 +32,7 @@ def __init__( :param receiver_name: The name of the message recipient. :param message: The string message to be sent. """ - self.user_name = user_name + super().__init__(session_token=session_token) self.receiver_name = receiver_name self.message = message self.packet = b"" @@ -43,45 +43,36 @@ def to_bytes(self) -> bytes: :return: A byte array holding the create request. """ logger.debug( - 'Creating CREATE request to send %s the message "%s" from %s', + 'Creating create request to send %s the message "%s"', self.receiver_name, self.message, - self.user_name, ) self.packet = super().to_bytes() self.packet += struct.pack( self.struct_format, - len(self.user_name.encode()), len(self.receiver_name.encode()), len(self.message.encode()), ) - self.packet += self.user_name.encode() self.packet += self.receiver_name.encode() self.packet += self.message.encode() return self.packet @classmethod - def decode_packet(cls, packet: bytes) -> tuple[str, str, bytes]: + def decode_packet(cls, packet: bytes) -> tuple[str, bytes]: """Decode a message request packet. - :param packet: An array of bytes containing the message request + :param packet: An array of bytes containing the create request """ header_fields, payload = cls.split_packet(packet) ( - user_name_size, receiver_name_size, message_size, ) = header_fields - if user_name_size < 1: - raise ValueError( - "Received message request with insufficient user name length", - ) - if receiver_name_size < 1: raise ValueError( "Received create request with insufficient receiver name length", @@ -91,12 +82,9 @@ def decode_packet(cls, packet: bytes) -> tuple[str, str, bytes]: "Received create request with insufficient message length", ) - user_name = payload[:user_name_size].decode() - index = user_name_size - - receiver_name = payload[index : index + receiver_name_size].decode() - index += receiver_name_size + receiver_name = payload[:receiver_name_size].decode() + index = receiver_name_size message = payload[index : index + message_size] - return user_name, receiver_name, message + return receiver_name, message diff --git a/src/packets/key_request.py b/src/packets/key_request.py index 737aa78..07e1cd4 100644 --- a/src/packets/key_request.py +++ b/src/packets/key_request.py @@ -14,12 +14,13 @@ class KeyRequest( Packet, struct_format="!H", - message_type=MessageType.KEY_REQUEST, + message_type=MessageType.KEY, ): """Encode and decode public key request packets.""" def __init__(self, user_name: str) -> None: """Create a key request packet.""" + super().__init__() self.user_name = user_name self.packet: bytes diff --git a/src/packets/key_response.py b/src/packets/key_response.py index 16ea35c..8694d93 100644 --- a/src/packets/key_response.py +++ b/src/packets/key_response.py @@ -22,6 +22,7 @@ class KeyResponse( def __init__(self, public_key: rsa.PublicKey) -> None: """Create a key response packet.""" + super().__init__() self.public_key = public_key self.packet: bytes diff --git a/src/packets/login_request.py b/src/packets/login_request.py index 28f24ea..4d5b197 100644 --- a/src/packets/login_request.py +++ b/src/packets/login_request.py @@ -16,6 +16,7 @@ class LoginRequest(Packet, struct_format="!B", message_type=MessageType.LOGIN): def __init__(self, user_name: str) -> None: """Create a login request packet.""" + super().__init__() self.user_name = user_name self.packet: bytes diff --git a/src/packets/login_response.py b/src/packets/login_response.py index 88b6353..a5127fc 100644 --- a/src/packets/login_response.py +++ b/src/packets/login_response.py @@ -10,11 +10,16 @@ from src.packets.packet import Packet -class LoginResponse(Packet, struct_format="!B", message_type=MessageType.LOGIN): +class LoginResponse( + Packet, + struct_format="!B", + message_type=MessageType.LOGIN_RESPONSE, +): """The LoginResponse class is used to encode and decode login response packets.""" def __init__(self, encrypted_token_bytes: bytes) -> None: """Create a new login response packet.""" + super().__init__() self.token = encrypted_token_bytes self.packet: bytes diff --git a/src/packets/packet.py b/src/packets/packet.py index 193a7f3..6b6b22e 100644 --- a/src/packets/packet.py +++ b/src/packets/packet.py @@ -11,31 +11,45 @@ class Packet(metaclass=abc.ABCMeta): """Abstract class for all packets. - All classes inheriting ``Packet`` must specify ``struct_format`` - in their class attributes. The format of ``struct_format`` is as - described in https://docs.python.org/3/library/struct.html + All classes inheriting ``Packet`` must specify both + ``struct_format`` and ``message_type`` in their class attributes. + The format of ``struct_format`` is as described in + https://docs.python.org/3/library/struct.html. ``message_type`` must + be of type ``message_type.MessageType``. Example:: - class MyPacket(Packet, struct_format="!HBBH"): + class MyPacket(Packet, struct_format="!HBBH", message_type=MessageType.LOGIN): pass """ MAGIC_NUMBER = 0xAE73 - struct_format = "!HB" + STRUCT_FORMAT_REGEX = re.compile("^[@=<>!]?[xcbB?hHiIlLqQnNefdspP]+$") - message_type: MessageType + SESSION_TOKEN_LENGTH = 32 + + struct_format = "!HB?" - struct_format_regex = re.compile("^[@=<>!]?[xcbB?hHiIlLqQnNefdspP]+$") + message_type: MessageType @abc.abstractmethod - def __init__(self, *args: tuple[Any, ...]) -> None: + def __init__( + self, + *args: tuple[Any, ...], + session_token: bytes | None = None, + ) -> None: """Initialise the packet. :param args: All arguments needed to initialise the packet. """ - raise NotImplementedError + if ( + session_token is not None + and len(session_token) != Packet.SESSION_TOKEN_LENGTH + ): + raise ValueError("Session token is incorrect length") + + self.session_token = session_token @abc.abstractmethod def to_bytes(self) -> bytes: @@ -43,12 +57,18 @@ def to_bytes(self) -> bytes: :return: A ``bytes`` object encoding the packet's message type. """ - return struct.pack( + packet = struct.pack( Packet.struct_format, self.MAGIC_NUMBER, self.message_type.value, + self.session_token is not None, ) + if self.session_token is not None: + packet += self.session_token + + return packet + @classmethod @abc.abstractmethod def decode_packet(cls, packet: bytes) -> tuple[Any, ...]: @@ -57,9 +77,9 @@ def decode_packet(cls, packet: bytes) -> tuple[Any, ...]: :param packet: The packet to decode. :return: A tuple of the decoded message type and the payload. """ - header_fields: tuple[int, MessageType] + header_fields: tuple[int, MessageType, bool] header_fields, payload = Packet.split_packet(packet) - magic_number, message_type_number = header_fields + magic_number, message_type_number, has_token = header_fields if magic_number != cls.MAGIC_NUMBER: raise ValueError("Incorrect magic number found in packet") @@ -68,7 +88,15 @@ def decode_packet(cls, packet: bytes) -> tuple[Any, ...]: except ValueError as error: raise ValueError("Invalid message type ID number") from error - return message_type, payload + session_token = None + + if has_token: + session_token, payload = ( + payload[: Packet.SESSION_TOKEN_LENGTH], + payload[Packet.SESSION_TOKEN_LENGTH :], + ) + + return message_type, session_token, payload @classmethod def split_packet(cls, packet: bytes) -> tuple[tuple[Any, ...], bytes]: @@ -106,7 +134,7 @@ def __init_subclass__( :raises ValueError: if the provided struct format is invalid. """ - if not re.match(Packet.struct_format_regex, struct_format): + if not re.match(Packet.STRUCT_FORMAT_REGEX, struct_format): raise ValueError("Invalid struct format") super().__init_subclass__() diff --git a/src/packets/read_request.py b/src/packets/read_request.py index 4754988..2e8a5b8 100644 --- a/src/packets/read_request.py +++ b/src/packets/read_request.py @@ -23,12 +23,14 @@ class ReadRequest(Packet, struct_format="!B", message_type=MessageType.READ): def __init__( self, + session_token: bytes, user_name: str, ) -> None: """Encode a read request packet. :param user_name: The name of the user sending the read request. """ + super().__init__(session_token=session_token) self.user_name = user_name self.packet: bytes diff --git a/src/packets/read_response.py b/src/packets/read_response.py index fb843ce..faed922 100644 --- a/src/packets/read_response.py +++ b/src/packets/read_response.py @@ -10,7 +10,7 @@ logger = logging.getLogger(__name__) -class ReadResponse(Packet, struct_format="!B?", message_type=MessageType.RESPONSE): +class ReadResponse(Packet, struct_format="!B?", message_type=MessageType.READ_RESPONSE): """Enables encoding and decoding message response packets.""" MAX_MESSAGE_LENGTH = 255 diff --git a/src/packets/registration_request.py b/src/packets/registration_request.py index 44396bc..d598246 100644 --- a/src/packets/registration_request.py +++ b/src/packets/registration_request.py @@ -22,6 +22,7 @@ class RegistrationRequest( def __init__(self, user_name: str, public_key: rsa.PublicKey) -> None: """Create a login request packet.""" + super().__init__() self.user_name = user_name self.public_key = public_key self.packet: bytes diff --git a/tests/test_message_type.py b/tests/test_message_type.py index 1af3431..c03555c 100644 --- a/tests/test_message_type.py +++ b/tests/test_message_type.py @@ -26,11 +26,17 @@ def test_create_uppercase(self) -> None: def test_response_lowercase(self) -> None: """Test that response is parsed correctly.""" - self.assertEqual(MessageType.RESPONSE, MessageType.from_str("response")) + self.assertEqual( + MessageType.READ_RESPONSE, + MessageType.from_str("read_response"), + ) def test_response_uppercase(self) -> None: """Test that RESPONSE is parsed correctly.""" - self.assertEqual(MessageType.RESPONSE, MessageType.from_str("RESPONSE")) + self.assertEqual( + MessageType.READ_RESPONSE, + MessageType.from_str("READ_RESPONSE"), + ) def test_invalid(self) -> None: """Test that invalid input raises a ValueError.""" From a3696ee8959c01023099572184c27197a15ebdce Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Sat, 16 Nov 2024 07:30:10 +1300 Subject: [PATCH 36/48] Made the bytes packet object a local instead of an object attribute for all packet classes --- src/packets/create_request.py | 11 +++++------ src/packets/key_request.py | 9 ++++----- src/packets/key_response.py | 11 +++++------ src/packets/login_request.py | 9 ++++----- src/packets/login_response.py | 9 ++++----- src/packets/message.py | 9 ++++----- src/packets/read_request.py | 9 ++++----- src/packets/read_response.py | 9 ++++----- src/packets/registration_request.py | 13 ++++++------- 9 files changed, 40 insertions(+), 49 deletions(-) diff --git a/src/packets/create_request.py b/src/packets/create_request.py index 9ae68bf..ec8a987 100644 --- a/src/packets/create_request.py +++ b/src/packets/create_request.py @@ -35,7 +35,6 @@ def __init__( super().__init__(session_token=session_token) self.receiver_name = receiver_name self.message = message - self.packet = b"" def to_bytes(self) -> bytes: """Return the create request packet. @@ -48,18 +47,18 @@ def to_bytes(self) -> bytes: self.message, ) - self.packet = super().to_bytes() + packet = super().to_bytes() - self.packet += struct.pack( + packet += struct.pack( self.struct_format, len(self.receiver_name.encode()), len(self.message.encode()), ) - self.packet += self.receiver_name.encode() - self.packet += self.message.encode() + packet += self.receiver_name.encode() + packet += self.message.encode() - return self.packet + return packet @classmethod def decode_packet(cls, packet: bytes) -> tuple[str, bytes]: diff --git a/src/packets/key_request.py b/src/packets/key_request.py index 07e1cd4..561c310 100644 --- a/src/packets/key_request.py +++ b/src/packets/key_request.py @@ -22,22 +22,21 @@ def __init__(self, user_name: str) -> None: """Create a key request packet.""" super().__init__() self.user_name = user_name - self.packet: bytes def to_bytes(self) -> bytes: """Encode the key request packet into a byte array.""" logging.debug("Creating key request for %s", self.user_name) - self.packet = super().to_bytes() + packet = super().to_bytes() - self.packet += struct.pack( + packet += struct.pack( self.struct_format, len(self.user_name.encode()), ) - self.packet += self.user_name.encode() + packet += self.user_name.encode() - return self.packet + return packet @classmethod def decode_packet(cls, packet: bytes) -> tuple[str]: diff --git a/src/packets/key_response.py b/src/packets/key_response.py index 8694d93..6f37635 100644 --- a/src/packets/key_response.py +++ b/src/packets/key_response.py @@ -24,7 +24,6 @@ def __init__(self, public_key: rsa.PublicKey) -> None: """Create a key response packet.""" super().__init__() self.public_key = public_key - self.packet: bytes def to_bytes(self) -> bytes: """Encode the key response packet into a byte array.""" @@ -38,18 +37,18 @@ def to_bytes(self) -> bytes: (self.public_key.e.bit_length() + 7) // 8, ) - self.packet = super().to_bytes() + packet = super().to_bytes() - self.packet += struct.pack( + packet += struct.pack( self.struct_format, len(product), len(exponent), ) - self.packet += product - self.packet += exponent + packet += product + packet += exponent - return self.packet + return packet @classmethod def decode_packet(cls, packet: bytes) -> tuple[rsa.PublicKey]: diff --git a/src/packets/login_request.py b/src/packets/login_request.py index 4d5b197..74857a7 100644 --- a/src/packets/login_request.py +++ b/src/packets/login_request.py @@ -18,22 +18,21 @@ def __init__(self, user_name: str) -> None: """Create a login request packet.""" super().__init__() self.user_name = user_name - self.packet: bytes def to_bytes(self) -> bytes: """Encode the login request packet into a byte array.""" logging.debug("Creating log-in request as %s", self.user_name) - self.packet = super().to_bytes() + packet = super().to_bytes() - self.packet += struct.pack( + packet += struct.pack( self.struct_format, len(self.user_name.encode()), ) - self.packet += self.user_name.encode() + packet += self.user_name.encode() - return self.packet + return packet @classmethod def decode_packet(cls, packet: bytes) -> tuple[Any, ...]: diff --git a/src/packets/login_response.py b/src/packets/login_response.py index a5127fc..924e6b5 100644 --- a/src/packets/login_response.py +++ b/src/packets/login_response.py @@ -21,22 +21,21 @@ def __init__(self, encrypted_token_bytes: bytes) -> None: """Create a new login response packet.""" super().__init__() self.token = encrypted_token_bytes - self.packet: bytes def to_bytes(self) -> bytes: """Encode a login response packet into a byte array.""" logging.info("Creating login response") - self.packet = super().to_bytes() + packet = super().to_bytes() - self.packet += struct.pack( + packet += struct.pack( self.struct_format, len(self.token), ) - self.packet += self.token + packet += self.token - return self.packet + return packet @classmethod def decode_packet(cls, packet: bytes) -> tuple[bytes]: diff --git a/src/packets/message.py b/src/packets/message.py index b9db3b1..85aac89 100644 --- a/src/packets/message.py +++ b/src/packets/message.py @@ -22,22 +22,21 @@ def __init__(self, sender_name: str, message: bytes) -> None: """ self.sender_name = sender_name self.message = message - self.packet = b"" def to_bytes(self) -> bytes: """Encode the message into bytes for transmission through a socket. :return: A ``bytes`` object encoding the message. """ - self.packet += struct.pack( + packet = struct.pack( self.struct_format, len(self.sender_name.encode()), len(self.message), ) - self.packet += self.sender_name.encode() - self.packet += self.message + packet += self.sender_name.encode() + packet += self.message - return self.packet + return packet @classmethod def decode_packet(cls, packet: bytes) -> tuple[str, str, bytes]: diff --git a/src/packets/read_request.py b/src/packets/read_request.py index 2e8a5b8..16744e0 100644 --- a/src/packets/read_request.py +++ b/src/packets/read_request.py @@ -32,7 +32,6 @@ def __init__( """ super().__init__(session_token=session_token) self.user_name = user_name - self.packet: bytes def to_bytes(self) -> bytes: """Return the read request packet. @@ -41,16 +40,16 @@ def to_bytes(self) -> bytes: """ logger.debug("Creating READ request from %s", self.user_name) - self.packet = super().to_bytes() + packet = super().to_bytes() - self.packet += struct.pack( + packet += struct.pack( self.struct_format, len(self.user_name.encode()), ) - self.packet += self.user_name.encode() + packet += self.user_name.encode() - return self.packet + return packet @classmethod def decode_packet(cls, packet: bytes) -> tuple[str]: diff --git a/src/packets/read_response.py b/src/packets/read_response.py index faed922..74c94fd 100644 --- a/src/packets/read_response.py +++ b/src/packets/read_response.py @@ -24,7 +24,6 @@ def __init__(self, messages: list[tuple[str, bytes]]) -> None: self.more_messages = len(messages) > ReadResponse.MAX_MESSAGE_LENGTH self.messages = messages[: self.num_messages] - self.packet = b"" def to_bytes(self) -> bytes: """Return the message response packet. @@ -33,19 +32,19 @@ def to_bytes(self) -> bytes: """ logger.debug("Creating message response for %s message(s)", self.num_messages) - self.packet = super().to_bytes() + packet = super().to_bytes() - self.packet += struct.pack( + packet += struct.pack( self.struct_format, self.num_messages, self.more_messages, ) for sender, message in self.messages: - self.packet += Message(sender, message).to_bytes() + packet += Message(sender, message).to_bytes() logger.info('Encoded message from %s: "%s"', sender, message.decode()) - return self.packet + return packet @classmethod def decode_packet(cls, packet: bytes) -> tuple[list[tuple[str, str]], bool]: diff --git a/src/packets/registration_request.py b/src/packets/registration_request.py index d598246..62e9daa 100644 --- a/src/packets/registration_request.py +++ b/src/packets/registration_request.py @@ -25,7 +25,6 @@ def __init__(self, user_name: str, public_key: rsa.PublicKey) -> None: super().__init__() self.user_name = user_name self.public_key = public_key - self.packet: bytes def to_bytes(self) -> bytes: """Encode the registration request packet into a byte array.""" @@ -41,20 +40,20 @@ def to_bytes(self) -> bytes: (self.public_key.e.bit_length() + 7) // 8, ) - self.packet = super().to_bytes() + packet = super().to_bytes() - self.packet += struct.pack( + packet += struct.pack( self.struct_format, len(user_name), len(product), len(exponent), ) - self.packet += user_name - self.packet += product - self.packet += exponent + packet += user_name + packet += product + packet += exponent - return self.packet + return packet @classmethod def decode_packet(cls, packet: bytes) -> tuple[str, rsa.PublicKey]: From c6d3a7e82c02b15c0413a2a977405128e952c16b Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Sat, 16 Nov 2024 07:45:29 +1300 Subject: [PATCH 37/48] added missing call to super().__init() --- src/packets/read_response.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/packets/read_response.py b/src/packets/read_response.py index 74c94fd..2637797 100644 --- a/src/packets/read_response.py +++ b/src/packets/read_response.py @@ -20,6 +20,7 @@ def __init__(self, messages: list[tuple[str, bytes]]) -> None: :param messages: A list of all the messages to be put in the structure. """ + super().__init__() self.num_messages = min(len(messages), ReadResponse.MAX_MESSAGE_LENGTH) self.more_messages = len(messages) > ReadResponse.MAX_MESSAGE_LENGTH From da17a9ba76c18fe4318fea85bc6a6f5375c43668 Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Sun, 17 Nov 2024 10:12:01 +1300 Subject: [PATCH 38/48] refactored server.py so each process request function fits the same protocol --- server/server.py | 126 +++++++++++++++++++++++------------------------ 1 file changed, 63 insertions(+), 63 deletions(-) diff --git a/server/server.py b/server/server.py index 327c149..db3b5df 100644 --- a/server/server.py +++ b/server/server.py @@ -4,6 +4,8 @@ import secrets import socket from collections import OrderedDict +from collections.abc import Callable, Mapping +from typing import Final, TypeAlias import rsa @@ -70,16 +72,6 @@ def run(self) -> None: logger.exception(message) raise SystemExit from error - @staticmethod - def send_response(response: Packet, connection_socket: socket.socket) -> None: - """Send a response to the client's request. - - :param response: The response object. - :param connection_socket: The socket to send the response over. - """ - record = response.to_bytes() - connection_socket.send(record) - @staticmethod def generate_session_token() -> bytes: """Generate a random session token.""" @@ -87,33 +79,35 @@ def generate_session_token() -> bytes: def process_read_request( self, - session_token: bytes | None, + requestor_username: str | None, _packet: bytes, - connection_socket: socket.socket, - ) -> None: + ) -> bytes: """Respond to read requests. :param packet: The read request packet to process. :param connection_socket: The connection socket to send the response on. """ - if session_token not in self.sessions: - return - - sender_name = self.sessions[session_token] + if requestor_username is None: + logger.info( + "Received unauthenticated read request, responding without messages", + ) + return ReadResponse([]).to_bytes() - response = ReadResponse(self.messages.get(sender_name, [])) - self.send_response(response, connection_socket) + messages = self.messages.get(requestor_username, []).copy() + response = ReadResponse(messages) + del self.messages.get(requestor_username, [])[: response.num_messages] - del self.messages.get(sender_name, [])[: response.num_messages] logger.info( "%s message(s) delivered to %s", response.num_messages, - sender_name, + requestor_username, ) + return response.to_bytes() + def process_create_request( self, - session_token: bytes | None, + requestor_username: str | None, packet: bytes, ) -> None: """Process create requests. @@ -121,30 +115,28 @@ def process_create_request( :param session_token: The token provided by the client. :param packet: The packet provided by the client. """ - if session_token not in self.sessions: + if requestor_username is None: logger.info("Received unauthenticated create request, ignoreing") return - sender_name = self.sessions[session_token] - receiver_name, message = CreateRequest.decode_packet(packet) if receiver_name not in self.messages: self.messages[receiver_name] = [] - self.messages[receiver_name].append((sender_name, message)) + self.messages[receiver_name].append((requestor_username, message)) logger.info( 'Storing %s\'s message to %s: "%s"', - sender_name, + requestor_username, receiver_name, message.decode(), ) def process_login_request( self, + _requestor_username: str | None, packet: bytes, - connection_socket: socket.socket, - ) -> None: + ) -> bytes: """Process a client requset to login. :param packet: A byte array containing the login request. @@ -152,10 +144,8 @@ def process_login_request( (sender_name,) = LoginRequest.decode_packet(packet) if sender_name not in self.users: - response = LoginResponse(b"") - self.send_response(response, connection_socket) logger.info("Unregistered user %s attempted to login", sender_name) - return + return LoginResponse(b"").to_bytes() session_token = self.generate_session_token() self.sessions[session_token] = sender_name @@ -164,32 +154,37 @@ def process_login_request( senders_public_key = self.users[sender_name] encrypted_session_token = rsa.encrypt(session_token, senders_public_key) logger.debug("Encrypted token to %s", encrypted_session_token) - response = LoginResponse(encrypted_session_token) - self.send_response(response, connection_socket) - def process_registration_request(self, packet: bytes) -> None: + return LoginResponse(encrypted_session_token).to_bytes() + + def process_registration_request( + self, + _requestor_username: str | None, + packet: bytes, + ) -> None: """Process a client request to register a new name. :param packet: A byte array containing the registration request. """ sender_name, public_key = RegistrationRequest.decode_packet(packet) - if sender_name not in self.users: - self.users[sender_name] = public_key - logger.info( - "Registered %s with key %s", - sender_name, - (public_key.n, public_key.e), - ) - else: - message = f"name {sender_name} already registered" + if sender_name in self.users: + message = f"Name {sender_name} already registered" logger.error(message) + return + + self.users[sender_name] = public_key + logger.info( + "Registered %s with key %s", + sender_name, + (public_key.n, public_key.e), + ) def process_key_request( self, + _requestor_username: str | None, packet: bytes, - connection_socket: socket.socket, - ) -> None: + ) -> bytes: """Process a client request for a user's public key. :param packet: A byte array containing the key request. @@ -199,8 +194,7 @@ def process_key_request( logger.info("Received request for %s's key", requested_user) public_key = self.users[requested_user] - response = KeyResponse(public_key) - self.send_response(response, connection_socket) + return KeyResponse(public_key).to_bytes() def process_request(self, packet: bytes, connection_socket: socket.socket) -> None: """Process an incoming client request. @@ -213,24 +207,20 @@ def process_request(self, packet: bytes, connection_socket: socket.socket) -> No message_type, session_token, packet = Packet.decode_packet(packet) - match message_type: - case MessageType.READ: - self.process_read_request(session_token, packet, connection_socket) - - case MessageType.CREATE: - self.process_create_request(session_token, packet) - - case MessageType.LOGIN: - self.process_login_request(packet, connection_socket) + if session_token is not None: + requestor_username = self.sessions.get(session_token, None) + else: + requestor_username = None - case MessageType.REGISTER: - self.process_registration_request(packet) + if message_type not in PROCESS_REQUEST_MAPPING: + logging.error("Message of incorrect type received!") + return - case MessageType.KEY: - self.process_key_request(packet, connection_socket) + processor_function = PROCESS_REQUEST_MAPPING[message_type] + response = processor_function(self, requestor_username, packet) - case _: - logging.error("Message of incorrect type received!") + if response is not None: + connection_socket.send(response) def run_server(self, welcoming_socket: socket.socket) -> None: """Run the server side of the program. @@ -264,3 +254,13 @@ def stop(self) -> None: """Stop the server.""" logger.info("Stopping server.") self.running = False + + +ServerProcessFunction: TypeAlias = Callable[[Server, str | None, bytes], bytes | None] +PROCESS_REQUEST_MAPPING: Final[Mapping[MessageType, ServerProcessFunction]] = { + MessageType.REGISTER: Server.process_registration_request, + MessageType.LOGIN: Server.process_login_request, + MessageType.KEY: Server.process_key_request, + MessageType.CREATE: Server.process_create_request, + MessageType.READ: Server.process_read_request, +} From 85027802b7d6b5cb9579d6805e463a2f08a21e78 Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Sun, 17 Nov 2024 10:13:29 +1300 Subject: [PATCH 39/48] fixed packet decode error in client.py --- client/client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/client/client.py b/client/client.py index 53b970f..0477717 100644 --- a/client/client.py +++ b/client/client.py @@ -131,7 +131,7 @@ def send_read_request(self) -> bytes: message_type: MessageType packet: bytes - message_type, packet = Packet.decode_packet(response) + message_type, _, packet = Packet.decode_packet(response) if message_type != MessageType.READ_RESPONSE: raise RuntimeError("Incorrect type message recieved from the server.") @@ -150,7 +150,7 @@ def send_create_request( :param message: The message to be sent. """ if self.session_token is None: - logger.error("Please log in before send messages") + logger.error("Please log in before sending messages") raise SystemExit if receiver_name is None: From 3926aec4147ae6346c3b666a20e2c781c9f523bb Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Sun, 17 Nov 2024 10:14:29 +1300 Subject: [PATCH 40/48] Improved documentation in create_request.py --- src/packets/create_request.py | 5 ++++- src/packets/packet.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/packets/create_request.py b/src/packets/create_request.py index ec8a987..bae287a 100644 --- a/src/packets/create_request.py +++ b/src/packets/create_request.py @@ -64,7 +64,10 @@ def to_bytes(self) -> bytes: def decode_packet(cls, packet: bytes) -> tuple[str, bytes]: """Decode a message request packet. - :param packet: An array of bytes containing the create request + :param packet: An array of bytes containing the create request. + :raises ValueError: If the packet has incorrect values. + :return: A tuple containing the name of the message recipient + and the message itself. """ header_fields, payload = cls.split_packet(packet) ( diff --git a/src/packets/packet.py b/src/packets/packet.py index 6b6b22e..29694f8 100644 --- a/src/packets/packet.py +++ b/src/packets/packet.py @@ -83,6 +83,7 @@ def decode_packet(cls, packet: bytes) -> tuple[Any, ...]: if magic_number != cls.MAGIC_NUMBER: raise ValueError("Incorrect magic number found in packet") + try: message_type = MessageType(message_type_number) except ValueError as error: From e9a79a51129fb7b75d633ad23992c6bf8c87cc2d Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Sun, 17 Nov 2024 10:15:49 +1300 Subject: [PATCH 41/48] rewrote many tests to correctly test newly reengineered code --- tests/applications/test_client.py | 10 +- tests/applications/test_server.py | 35 +++---- tests/integration/test_more_messages.py | 28 +++--- tests/packets/test_create_request.py | 98 +++++-------------- tests/packets/test_read_request.py | 14 +-- tests/packets/test_read_response.py | 10 +- tests/resources/client_input | 2 - tests/resources/names.txt | 4 +- ...e_arguments.py => test_parse_arguments.py} | 21 ++-- 9 files changed, 87 insertions(+), 135 deletions(-) delete mode 100644 tests/resources/client_input rename tests/{applications/test_client_parse_arguments.py => test_parse_arguments.py} (69%) diff --git a/tests/applications/test_client.py b/tests/applications/test_client.py index 3879ede..e68a294 100644 --- a/tests/applications/test_client.py +++ b/tests/applications/test_client.py @@ -8,6 +8,8 @@ from src.packets.create_request import CreateRequest from src.packets.packet import Packet +DUMMY_SESSION_TOKEN = b"01234567890123456789012345678901" + class TestClient(unittest.TestCase): """Test suite for Client class.""" @@ -32,7 +34,6 @@ def test_send_message_request(self) -> None: client = Client( [TestClient.hostname, str(TestClient.port_number), "Alice", "create"], ) - user_name = "Alice" receiver_name = "John" message = "Hello John" @@ -43,7 +44,7 @@ def test_send_message_request(self) -> None: # Send message request from the client client.send_request( - CreateRequest(user_name, receiver_name, message), + CreateRequest(DUMMY_SESSION_TOKEN, receiver_name, message), expect_response=False, ) @@ -55,10 +56,11 @@ def test_send_message_request(self) -> None: with connection_socket: packet = connection_socket.recv(4096) - message_type, packet = Packet.decode_packet(packet) + message_type, session_token, packet = Packet.decode_packet(packet) self.assertEqual(MessageType.CREATE, message_type) + self.assertEqual(DUMMY_SESSION_TOKEN, session_token) - expected = (user_name, receiver_name, message.encode()) + expected = (receiver_name, message.encode()) actual = CreateRequest.decode_packet(packet) self.assertEqual(expected, actual) diff --git a/tests/applications/test_server.py b/tests/applications/test_server.py index 6e31bdc..49e3e4e 100644 --- a/tests/applications/test_server.py +++ b/tests/applications/test_server.py @@ -1,13 +1,15 @@ """Server class test suite.""" -import socket import unittest from server import Server +from src.message_type import MessageType from src.packets.packet import Packet from src.packets.read_request import ReadRequest from src.packets.read_response import ReadResponse +DUMMY_SESSION_TOKEN = b"01234567890123456789012345678901" + class TestServer(unittest.TestCase): """Test suite for Server class.""" @@ -30,29 +32,24 @@ def test_construction_raise_error(self) -> None: def test_process_read_request(self) -> None: """Tests that Server objects correctly responds to read requests.""" server = Server([str(TestServer.port_number)]) - receiver_name = "John" + sender_name = "Alice" + receiver_name = "John" message = b"Hello John" - server.messages[receiver_name] = [(sender_name, message)] - with socket.socket() as server_welcoming_socket: - server_welcoming_socket.bind((TestServer.hostname, TestServer.port_number)) - server_welcoming_socket.listen(1) + server.sessions = {DUMMY_SESSION_TOKEN: receiver_name} + server.messages[receiver_name] = [(sender_name, message)] - with socket.socket() as client_socket: - client_socket.connect((TestServer.hostname, TestServer.port_number)) - server_connection_socket, _ = server_welcoming_socket.accept() + request = ReadRequest(DUMMY_SESSION_TOKEN, receiver_name).to_bytes() - packet = ReadRequest(receiver_name).to_bytes() - _, packet = Packet.decode_packet(packet) + _, _, packet = Packet.decode_packet(request) - with server_connection_socket: - server.process_read_request(packet, server_connection_socket) + response = server.process_read_request(receiver_name, packet) + message_type, session_token, payload = Packet.decode_packet(response) - # Receive message from server - response_packet = client_socket.recv(1024) - _, response_packet = Packet.decode_packet(response_packet) - response = ReadResponse.decode_packet(response_packet) + messages, more_messages = ReadResponse.decode_packet(payload) - # Check that the message is correct - self.assertEqual(([(sender_name, message.decode())], False), response) + self.assertEqual(MessageType.READ_RESPONSE, message_type) + self.assertEqual(None, session_token) + self.assertEqual([(sender_name, message.decode())], messages) + self.assertEqual(False, more_messages) diff --git a/tests/integration/test_more_messages.py b/tests/integration/test_more_messages.py index 896e162..137f100 100644 --- a/tests/integration/test_more_messages.py +++ b/tests/integration/test_more_messages.py @@ -5,8 +5,6 @@ import threading import unittest -from src.packets.packet import Packet - sys.path.insert(0, "../../") import client import server @@ -25,7 +23,7 @@ class TestMoreMessages(unittest.TestCase): HOST_NAME = "localhost" PORT_NUMBER = "1024" - USER_NAME = "John" + RECIPIENT_NAME = "Recipient" def test_more_messages(self) -> None: """Tests that no more than 255 messages are sent in a read request.""" @@ -35,21 +33,25 @@ def test_more_messages(self) -> None: server_thread = threading.Thread(target=server_object.run) server_thread.start() - for name in names: - client.Client( - [self.HOST_NAME, self.PORT_NUMBER, name, "create"], - ).run(receiver_name=self.USER_NAME, message="Hello") - - final_client = client.Client( - [self.HOST_NAME, self.PORT_NUMBER, self.USER_NAME, "read"], + recipient_client = client.Client( + [self.HOST_NAME, self.PORT_NUMBER, self.RECIPIENT_NAME, "read"], ) - final_client.run() + recipient_client.send_registration_request() + + for sender_name in names: + sender_client = client.Client( + [self.HOST_NAME, self.PORT_NUMBER, sender_name, "create"], + ) + sender_client.send_registration_request() + sender_client.send_login_request() + sender_client.send_create_request(self.RECIPIENT_NAME, "Hello") + + recipient_client.send_login_request() + packet = recipient_client.send_read_request() server_object.stop() server_thread.join() - _, packet = Packet.decode_packet(final_client.result) - messages, more_messages = ReadResponse.decode_packet(packet) self.assertEqual(255, len(messages)) self.assertTrue(more_messages) diff --git a/tests/packets/test_create_request.py b/tests/packets/test_create_request.py index 0bbac63..093f0ee 100644 --- a/tests/packets/test_create_request.py +++ b/tests/packets/test_create_request.py @@ -1,118 +1,83 @@ """``CreateRequest`` class test suite.""" +import struct import unittest from src.message_type import MessageType from src.packets.create_request import CreateRequest from src.packets.packet import Packet +DUMMY_SESSION_TOKEN = b"01234567890123456789012345678901" +CREATE_PACKET_HEADER_SIZE = struct.calcsize(CreateRequest.struct_format) -class TestMessageRequestEncoding(unittest.TestCase): - """Test suite for encoding MessageRequest packets.""" - def test_user_name_length_encoding(self) -> None: - """Tests that the length of the user's name is encoded correctly.""" - user_name = "Johnny" - receiver_name = "Jarod" - message = "Hello, World!" - packet = CreateRequest( - user_name, - receiver_name, - message, - ).to_bytes() - - _, payload = Packet.decode_packet(packet) - - expected = len(user_name.encode()) - actual = payload[0] - self.assertEqual(expected, actual) +class TestCreateRequestEncoding(unittest.TestCase): + """Test suite for encoding MessageRequest packets.""" def test_receiver_name_length_encoding(self) -> None: """Tests that the length of the receiver's name is encoded correctly.""" - user_name = "Jackson" receiver_name = "Jake" message = "Hello, World!" packet = CreateRequest( - user_name, + DUMMY_SESSION_TOKEN, receiver_name, message, ).to_bytes() - _, payload = Packet.decode_packet(packet) + _, _, payload = Packet.decode_packet(packet) expected = len(receiver_name.encode()) - actual = payload[1] + actual = payload[0] self.assertEqual(expected, actual) def test_message_length_encoding(self) -> None: """Tests that the length of the message is encoded correctly.""" - user_name = "Jason" receiver_name = "Jay" message = "Hello, World!" packet = CreateRequest( - user_name, + DUMMY_SESSION_TOKEN, receiver_name, message, ).to_bytes() - _, payload = Packet.decode_packet(packet) + _, _, payload = Packet.decode_packet(packet) expected = len(message.encode()) - actual = (payload[2] << 8) | (payload[3] & 0xFF) - self.assertEqual(expected, actual) - - def test_user_name_encoding(self) -> None: - """Tests that the user's name is encoded correctly.""" - user_name = "Jason" - receiver_name = "Jay" - message = "Hello, World!" - packet = CreateRequest( - user_name, - receiver_name, - message, - ).to_bytes() - - _, payload = Packet.decode_packet(packet) - - expected = user_name - actual = payload[4 : 4 + len(user_name.encode())].decode() + actual = (payload[1] << 8) | (payload[2] & 0xFF) self.assertEqual(expected, actual) def test_receiver_name_encoding(self) -> None: """Tests that the receiver's name is encoded correctly.""" - user_name = "Jeff" receiver_name = "Jesse" message = "Hello, World!" packet = CreateRequest( - user_name, + DUMMY_SESSION_TOKEN, receiver_name, message, ).to_bytes() - _, payload = Packet.decode_packet(packet) - - start_index = 4 + len(user_name.encode()) + _, _, payload = Packet.decode_packet(packet) expected = receiver_name actual = payload[ - start_index : start_index + len(receiver_name.encode()) + CREATE_PACKET_HEADER_SIZE : CREATE_PACKET_HEADER_SIZE + + len(receiver_name.encode()) ].decode() self.assertEqual(expected, actual) def test_message_encoding(self) -> None: """Tests that the message is encoded correctly.""" - user_name = "Julian" receiver_name = "Jimmy" message = "Hello, World!" packet = CreateRequest( - user_name, + DUMMY_SESSION_TOKEN, receiver_name, message, ).to_bytes() - _, payload = Packet.decode_packet(packet) + _, _, payload = Packet.decode_packet(packet) - start_index = 4 + len(user_name.encode()) + len(receiver_name.encode()) + start_index = CREATE_PACKET_HEADER_SIZE + len(receiver_name.encode()) expected = message actual = payload[start_index : start_index + len(message.encode())].decode() @@ -130,26 +95,18 @@ def setUp(self) -> None: self.message = "Hello, World!" packet = CreateRequest( - self.user_name, + DUMMY_SESSION_TOKEN, self.receiver_name, self.message, ).to_bytes() - _, self.packet = Packet.decode_packet(packet) - - def test_user_name_decoding(self) -> None: - """Tests that the user's name is decoded correctly.""" - decoded_packet = CreateRequest.decode_packet(self.packet) - - expected = self.user_name - actual = decoded_packet[0] - self.assertEqual(expected, actual) + _, _, self.packet = Packet.decode_packet(packet) def test_receiver_name_decoding(self) -> None: """Tests that the receiver's name is decoded correctly.""" decoded_packet = CreateRequest.decode_packet(self.packet) expected = self.receiver_name - actual = decoded_packet[1] + actual = decoded_packet[0] self.assertEqual(expected, actual) def test_message_decoding(self) -> None: @@ -157,23 +114,16 @@ def test_message_decoding(self) -> None: decoded_packet = CreateRequest.decode_packet(self.packet) expected = self.message - actual = decoded_packet[2].decode() + actual = decoded_packet[1].decode() self.assertEqual(expected, actual) - def test_insufficient_user_name_length(self) -> None: - """Tests that an exception is raised if the user's name has a length of zero.""" - packet = bytearray(self.packet) - packet[3] = 0 - - self.assertRaises(ValueError, CreateRequest.decode_packet, packet) - def test_insufficient_receiver_name_length(self) -> None: """Tests that an exception is raised. If the length of the receiver's name is zero. """ packet = bytearray(self.packet) - packet[1] = 0 + packet[0] = 0 self.assertRaises(ValueError, CreateRequest.decode_packet, packet) @@ -183,7 +133,7 @@ def test_insufficient_message_length(self) -> None: If the length of the message is zero. """ packet = bytearray(self.packet) + packet[1] = 0 packet[2] = 0 - packet[3] = 0 self.assertRaises(ValueError, CreateRequest.decode_packet, packet) diff --git a/tests/packets/test_read_request.py b/tests/packets/test_read_request.py index 49ee3fc..73c5f3d 100644 --- a/tests/packets/test_read_request.py +++ b/tests/packets/test_read_request.py @@ -5,6 +5,8 @@ from src.packets.packet import Packet from src.packets.read_request import ReadRequest +DUMMY_SESSION_TOKEN = b"01234567890123456789012345678901" + class TestReadRequestEncoding(unittest.TestCase): """Test suite for encoding ReadRequest packets.""" @@ -12,8 +14,8 @@ class TestReadRequestEncoding(unittest.TestCase): def test_user_name_length_encoding(self) -> None: """Tests that the length of the user's name is encoded correctly.""" user_name = "Johnny" - packet = ReadRequest(user_name).to_bytes() - _, payload = Packet.decode_packet(packet) + packet = ReadRequest(DUMMY_SESSION_TOKEN, user_name).to_bytes() + _, _, payload = Packet.decode_packet(packet) expected = len(user_name.encode()) actual = payload[0] @@ -22,8 +24,8 @@ def test_user_name_length_encoding(self) -> None: def test_user_name_encoding(self) -> None: """Tests that the user's name is placed correctly in the packet.""" user_name = "Johnny" - packet = ReadRequest(user_name).to_bytes() - _, payload = Packet.decode_packet(packet) + packet = ReadRequest(DUMMY_SESSION_TOKEN, user_name).to_bytes() + _, _, payload = Packet.decode_packet(packet) expected = user_name actual = payload[1:].decode() @@ -37,8 +39,8 @@ def setUp(self) -> None: """Set up the testing environment.""" self.user_name = "Jamie" - packet = ReadRequest(self.user_name).to_bytes() - _, self.packet = Packet.decode_packet(packet) + packet = ReadRequest(DUMMY_SESSION_TOKEN, self.user_name).to_bytes() + _, _, self.packet = Packet.decode_packet(packet) def test_user_name_decoding(self) -> None: """Tests that the user's name is decoded correctly.""" diff --git a/tests/packets/test_read_response.py b/tests/packets/test_read_response.py index 2d0ce5c..b254fa1 100644 --- a/tests/packets/test_read_response.py +++ b/tests/packets/test_read_response.py @@ -16,7 +16,7 @@ def test_num_messages_encoding(self) -> None: ("John", b"Hello Harry!"), ] packet = ReadResponse(messages).to_bytes() - _, packet = Packet.decode_packet(packet) + _, _, packet = Packet.decode_packet(packet) expected = len(messages) actual = packet[0] @@ -26,7 +26,7 @@ def test_more_messages_encoding(self) -> None: """Tests that the more messages flag is encoded correctly.""" messages: list[tuple[str, bytes]] = [] packet = ReadResponse(messages).to_bytes() - _, packet = Packet.decode_packet(packet) + _, _, packet = Packet.decode_packet(packet) expected = False actual = packet[1] @@ -44,7 +44,7 @@ def test_messages_decoding(self) -> None: ] packet = ReadResponse(messages).to_bytes() - _, packet = Packet.decode_packet(packet) + _, _, packet = Packet.decode_packet(packet) expected = [ ("Harry", "Hello John!"), @@ -60,7 +60,7 @@ def test_more_messages_decoding_false(self) -> None: ("John", b"Hello Harry!"), ] packet = ReadResponse(messages).to_bytes() - _, packet = Packet.decode_packet(packet) + _, _, packet = Packet.decode_packet(packet) expected = False actual = ReadResponse.decode_packet(packet)[1] @@ -71,7 +71,7 @@ def test_more_messages_decoding_true(self) -> None: messages = [("Harry", b"Hello John!")] * 256 packet = ReadResponse(messages).to_bytes() - _, packet = Packet.decode_packet(packet) + _, _, packet = Packet.decode_packet(packet) expected = True actual = ReadResponse.decode_packet(packet)[1] diff --git a/tests/resources/client_input b/tests/resources/client_input deleted file mode 100644 index 7e5a168..0000000 --- a/tests/resources/client_input +++ /dev/null @@ -1,2 +0,0 @@ -John -Hello John \ No newline at end of file diff --git a/tests/resources/names.txt b/tests/resources/names.txt index d3d7af2..4013055 100644 --- a/tests/resources/names.txt +++ b/tests/resources/names.txt @@ -97,7 +97,7 @@ Kimberly Kingsley Kingston Kira -Kaleb +Kris Kurt Kyle Kylie @@ -253,4 +253,4 @@ Richard Riley River Robert -Robin \ No newline at end of file +Robin diff --git a/tests/applications/test_client_parse_arguments.py b/tests/test_parse_arguments.py similarity index 69% rename from tests/applications/test_client_parse_arguments.py rename to tests/test_parse_arguments.py index 1e5d3db..d4994f5 100644 --- a/tests/applications/test_client_parse_arguments.py +++ b/tests/test_parse_arguments.py @@ -2,7 +2,8 @@ import unittest -from client import Client +from src.parse_hostname import parse_hostname +from src.parse_username import parse_username class TestClientParseArguments(unittest.TestCase): @@ -10,36 +11,36 @@ class TestClientParseArguments(unittest.TestCase): def test_parse_host_name_localhost(self) -> None: """Test parsing localhost as hostname.""" - Client.parse_hostname("localhost") + parse_hostname("localhost") def test_parse_host_name_ip_address(self) -> None: """Test parsing a valid IP address as hostname.""" - Client.parse_hostname("1.1.1.1") + parse_hostname("1.1.1.1") def test_parse_host_name_domain_name(self) -> None: """Test parsing a domain name as hostname.""" - Client.parse_hostname("www.duckduckgo.com") + parse_hostname("www.duckduckgo.com") def test_parse_host_name_invalid(self) -> None: """Test parsing an invalid hostname.""" - self.assertRaises(ValueError, Client.parse_hostname, "invalid") + self.assertRaises(ValueError, parse_hostname, "invalid") def test_parse_host_name_invalid_ip(self) -> None: """Test parsing an invalid IP address as hostname.""" - self.assertRaises(ValueError, Client.parse_hostname, "256.0.0.1") + self.assertRaises(ValueError, parse_hostname, "256.0.0.1") def test_parse_username_min_length(self) -> None: """Test parsing a valid username with minimum length.""" - Client.parse_username("J") + parse_username("J") def test_parse_username_max_length(self) -> None: """Test parsing a valid username with maximum length.""" - Client.parse_username("J" * 255) + parse_username("J" * 255) def test_parse_username_empty(self) -> None: """Test that parsing an empty username raises a ValueError.""" - self.assertRaises(ValueError, Client.parse_username, "") + self.assertRaises(ValueError, parse_username, "") def test_parse_username_too_long(self) -> None: """Test that parsing a username that is too long raises a ValueError.""" - self.assertRaises(ValueError, Client.parse_username, "a" * 256) + self.assertRaises(ValueError, parse_username, "a" * 256) From fb2a744f5d167c55af700d7f4a22b92d3819e024 Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Mon, 18 Nov 2024 07:42:17 +1300 Subject: [PATCH 42/48] Added new tests for the server, and changed code to pass them --- client/client.py | 13 +- server/server.py | 3 + src/packets/key_response.py | 15 +- tests/applications/test_server.py | 180 +++++++++++++++++++++++- tests/integration/test_more_messages.py | 1 + 5 files changed, 197 insertions(+), 15 deletions(-) diff --git a/client/client.py b/client/client.py index 0477717..f2ce21b 100644 --- a/client/client.py +++ b/client/client.py @@ -227,11 +227,14 @@ def send_key_request(self, receiver_name: str | None) -> bytes: (public_key,) = KeyResponse.decode_packet(packet) - logger.info( - "Received %s's key:\n%s", - receiver_name, - (public_key.n, public_key.e), - ) + if public_key is None: + logger.warning("The requested user is not registered") + else: + logger.info( + "Received %s's key:\n%s", + receiver_name, + (public_key.n, public_key.e), + ) return packet diff --git a/server/server.py b/server/server.py index db3b5df..1041dfa 100644 --- a/server/server.py +++ b/server/server.py @@ -193,6 +193,9 @@ def process_key_request( (requested_user,) = KeyRequest.decode_packet(packet) logger.info("Received request for %s's key", requested_user) + if requested_user not in self.users: + return KeyResponse(None).to_bytes() + public_key = self.users[requested_user] return KeyResponse(public_key).to_bytes() diff --git a/src/packets/key_response.py b/src/packets/key_response.py index 6f37635..adeb14e 100644 --- a/src/packets/key_response.py +++ b/src/packets/key_response.py @@ -20,7 +20,7 @@ class KeyResponse( ): """Encode and decode public key response packets.""" - def __init__(self, public_key: rsa.PublicKey) -> None: + def __init__(self, public_key: rsa.PublicKey | None) -> None: """Create a key response packet.""" super().__init__() self.public_key = public_key @@ -29,6 +29,12 @@ def to_bytes(self) -> bytes: """Encode the key response packet into a byte array.""" logging.info("Creating key response") + packet = super().to_bytes() + + if self.public_key is None: + packet += struct.pack(self.struct_format, 0, 0) + return packet + product = self.public_key.n.to_bytes( (self.public_key.n.bit_length() + 7) // 8, ) @@ -37,8 +43,6 @@ def to_bytes(self) -> bytes: (self.public_key.e.bit_length() + 7) // 8, ) - packet = super().to_bytes() - packet += struct.pack( self.struct_format, len(product), @@ -51,7 +55,7 @@ def to_bytes(self) -> bytes: return packet @classmethod - def decode_packet(cls, packet: bytes) -> tuple[rsa.PublicKey]: + def decode_packet(cls, packet: bytes) -> tuple[rsa.PublicKey | None]: """Decode the key response packet into its individual components. :param packet: The packet to be decoded. @@ -60,6 +64,9 @@ def decode_packet(cls, packet: bytes) -> tuple[rsa.PublicKey]: header_fields, payload = cls.split_packet(packet) product_length, exponent_length = header_fields + if product_length == 0 or exponent_length == 0: + return (None,) + product = int.from_bytes(payload[:product_length]) index = product_length diff --git a/tests/applications/test_server.py b/tests/applications/test_server.py index 49e3e4e..dc7d836 100644 --- a/tests/applications/test_server.py +++ b/tests/applications/test_server.py @@ -2,11 +2,19 @@ import unittest +import rsa + from server import Server from src.message_type import MessageType +from src.packets.create_request import CreateRequest +from src.packets.key_request import KeyRequest +from src.packets.key_response import KeyResponse +from src.packets.login_request import LoginRequest +from src.packets.login_response import LoginResponse from src.packets.packet import Packet from src.packets.read_request import ReadRequest from src.packets.read_response import ReadResponse +from src.packets.registration_request import RegistrationRequest DUMMY_SESSION_TOKEN = b"01234567890123456789012345678901" @@ -29,20 +37,157 @@ def test_construction_raise_error(self) -> None: [str(TestServer.port_number), "Extra argument"], ) - def test_process_read_request(self) -> None: - """Tests that Server objects correctly responds to read requests.""" + def test_process_register_request_unused_name(self) -> None: + """Tests that the server correctly registers new users.""" + server = Server([str(TestServer.port_number)]) + + username = "John" + public_key, _ = rsa.newkeys(512) + + packet = RegistrationRequest(username, public_key).to_bytes() + _, _, packet = Packet.decode_packet(packet) + + server.process_registration_request(None, packet) + + self.assertEqual({username: public_key}, server.users) + + def test_process_register_request_used_name(self) -> None: + """Tests that the server correctly ignores re-registers.""" + server = Server([str(TestServer.port_number)]) + + username = "John" + existing_key, _ = rsa.newkeys(512) + public_key, _ = rsa.newkeys(512) + + server.users = {username: existing_key} + + packet = RegistrationRequest(username, public_key).to_bytes() + _, _, packet = Packet.decode_packet(packet) + + server.process_registration_request(None, packet) + + self.assertEqual({username: existing_key}, server.users) + + def test_process_login_request_registered_user(self) -> None: + """Tests that the server correctly responds to login requests.""" + server = Server([str(TestServer.port_number)]) + + username = "John" + public_key, private_key = rsa.newkeys(512) + + server.users = {username: public_key} + + packet = LoginRequest(username).to_bytes() + _, _, packet = Packet.decode_packet(packet) + + response = server.process_login_request(None, packet) + message_type, _, payload = Packet.decode_packet(response) + + (encrypted_session_token,) = LoginResponse.decode_packet(payload) + session_token = rsa.decrypt(encrypted_session_token, private_key) + + self.assertEqual(MessageType.LOGIN_RESPONSE, message_type) + self.assertEqual(Packet.SESSION_TOKEN_LENGTH, len(session_token)) + + def test_process_login_request_unknown_user(self) -> None: + """Tests the the server responds correctly to login requests.""" + server = Server([str(TestServer.port_number)]) + + username = "John" + + packet = LoginRequest(username).to_bytes() + _, _, packet = Packet.decode_packet(packet) + + response = server.process_login_request(None, packet) + message_type, session_token, payload = Packet.decode_packet(response) + + (encrypted_session_token,) = LoginResponse.decode_packet(payload) + + self.assertEqual(MessageType.LOGIN_RESPONSE, message_type) + self.assertEqual(None, session_token) + self.assertEqual(b"", encrypted_session_token) + + def test_process_key_request_registered_user(self) -> None: + """Tests that the server correctly responds to key request.""" + server = Server([str(TestServer.port_number)]) + + receiver_name = "John" + recipients_key, _ = rsa.newkeys(512) + + server.users = {receiver_name: recipients_key} + + packet = KeyRequest(receiver_name).to_bytes() + _, _, packet = Packet.decode_packet(packet) + + response = server.process_key_request(None, packet) + message_type, session_token, payload = Packet.decode_packet(response) + (public_key,) = KeyResponse.decode_packet(payload) + + self.assertEqual(MessageType.KEY_RESPONSE, message_type) + self.assertEqual(None, session_token) + self.assertEqual(recipients_key, public_key) + + def test_process_key_request_unknown_user(self) -> None: + """Tests that the server correctly responds to key request.""" + server = Server([str(TestServer.port_number)]) + + receiver_name = "John" + + packet = KeyRequest(receiver_name).to_bytes() + _, _, packet = Packet.decode_packet(packet) + + response = server.process_key_request(None, packet) + message_type, session_token, payload = Packet.decode_packet(response) + (public_key,) = KeyResponse.decode_packet(payload) + + self.assertEqual(MessageType.KEY_RESPONSE, message_type) + self.assertEqual(None, session_token) + self.assertEqual(None, public_key) + + def test_process_create_request_authorised(self) -> None: + """Tests that the server correctly stores messages in create requests.""" + server = Server([str(TestServer.port_number)]) + + sender_name = "Alice" + receiver_name = "John" + message = "Hello John" + + packet = CreateRequest(DUMMY_SESSION_TOKEN, receiver_name, message).to_bytes() + _, _, packet = Packet.decode_packet(packet) + + server.process_create_request(sender_name, packet) + + self.assertEqual( + [(sender_name, message.encode())], + server.messages[receiver_name], + ) + + def test_process_create_request_unauthorised(self) -> None: + """Tests that the server ignores messages in unauthorised create requests.""" + server = Server([str(TestServer.port_number)]) + + receiver_name = "John" + message = "Hello John" + + packet = CreateRequest(DUMMY_SESSION_TOKEN, receiver_name, message).to_bytes() + _, _, packet = Packet.decode_packet(packet) + + server.process_create_request(None, packet) + + self.assertEqual(None, server.messages.get(receiver_name, None)) + + def test_process_read_request_authorised(self) -> None: + """Tests that the server correctly responds to authorised read requests.""" server = Server([str(TestServer.port_number)]) sender_name = "Alice" receiver_name = "John" message = b"Hello John" - server.sessions = {DUMMY_SESSION_TOKEN: receiver_name} server.messages[receiver_name] = [(sender_name, message)] - request = ReadRequest(DUMMY_SESSION_TOKEN, receiver_name).to_bytes() - - _, _, packet = Packet.decode_packet(request) + packet = ReadRequest(DUMMY_SESSION_TOKEN, receiver_name).to_bytes() + _, _, packet = Packet.decode_packet(packet) response = server.process_read_request(receiver_name, packet) message_type, session_token, payload = Packet.decode_packet(response) @@ -53,3 +198,26 @@ def test_process_read_request(self) -> None: self.assertEqual(None, session_token) self.assertEqual([(sender_name, message.decode())], messages) self.assertEqual(False, more_messages) + + def test_process_read_request_unauthorised(self) -> None: + """Tests that the server responds correctly to unauthorised read reqeusts.""" + server = Server([str(TestServer.port_number)]) + + sender_name = "Alice" + receiver_name = "John" + message = b"Hello John" + + server.messages[receiver_name] = [(sender_name, message)] + + packet = ReadRequest(DUMMY_SESSION_TOKEN, receiver_name).to_bytes() + _, _, packet = Packet.decode_packet(packet) + + response = server.process_read_request(None, packet) + message_type, session_token, payload = Packet.decode_packet(response) + + messages, more_messages = ReadResponse.decode_packet(payload) + + self.assertEqual(MessageType.READ_RESPONSE, message_type) + self.assertEqual(None, session_token) + self.assertEqual([], messages) + self.assertEqual(False, more_messages) diff --git a/tests/integration/test_more_messages.py b/tests/integration/test_more_messages.py index 137f100..9a461fe 100644 --- a/tests/integration/test_more_messages.py +++ b/tests/integration/test_more_messages.py @@ -44,6 +44,7 @@ def test_more_messages(self) -> None: ) sender_client.send_registration_request() sender_client.send_login_request() + sender_client.send_key_request(self.RECIPIENT_NAME) sender_client.send_create_request(self.RECIPIENT_NAME, "Hello") recipient_client.send_login_request() From a43edc326d5e4adff95c453b26464011d747bb18 Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Mon, 18 Nov 2024 08:11:29 +1300 Subject: [PATCH 43/48] Rearanged client functions to keep order consistent and added function map --- client/client.py | 186 ++++++++++++++++++++++++----------------------- 1 file changed, 95 insertions(+), 91 deletions(-) diff --git a/client/client.py b/client/client.py index f2ce21b..11e8971 100644 --- a/client/client.py +++ b/client/client.py @@ -3,6 +3,8 @@ import logging import socket from collections import OrderedDict +from collections.abc import Callable, Mapping +from typing import Final, TypeAlias import rsa @@ -66,6 +68,7 @@ def __init__(self, arguments: list[str]) -> None: ) self.session_token: bytes | None = None + self.key_cache: dict[str, rsa.PublicKey] = {} def send_request(self, request: Packet, *, expect_response: bool = True) -> bytes: """Send a message request record to the server. @@ -100,49 +103,76 @@ def send_request(self, request: Packet, *, expect_response: bool = True) -> byte return response - @staticmethod - def read_message_response(packet: bytes) -> None: - """Read a message response from the server. + def send_registration_request(self) -> None: + """Send a registration request to the server.""" + request = RegistrationRequest(self.user_name, self.public_key) + self.send_request(request, expect_response=False) - :param packet: The message response from the server. + def send_login_request(self) -> bytes: + """Send a login request to the server. + + :raises RuntimeError: If the server sends an incorrect response. + :return: The LoginResponse packet from the server. """ - messages, more_messages = ReadResponse.decode_packet(packet) + request = LoginRequest(self.user_name) + response = self.send_request(request) - for sender, message in messages: - logger.info("\nMessage from %s:\n%s", sender, message) + message_type: MessageType + packet: bytes + message_type, _, packet = Packet.decode_packet(response) + if message_type != MessageType.LOGIN_RESPONSE: + raise RuntimeError("Recieved incorrect type response from server") - if len(messages) == 0: - logger.info("No messages available") - elif more_messages: - logger.info("More messages available, please send another request") + (encrypted_session_token,) = LoginResponse.decode_packet(packet) + logger.debug("Received encrypted token bytes %s", encrypted_session_token) - def send_read_request(self) -> bytes: - """Send a read request to the server. + if len(encrypted_session_token) == 0: + logger.error("You are not registered! Please register before logging in") + raise SystemExit - :raises RuntimeError: If the server sends an invalid response. - :return: The ReadResponse packet received from the server. + self.session_token = rsa.decrypt(encrypted_session_token, self.__private_key) + logger.debug("Storing provided session token %s", self.session_token) + logger.info("Now logged in as %s", self.user_name) + + return packet + + def send_key_request(self, receiver_name: str | None = None) -> bytes: + """Send a pubblic key request to the server. + + :param receiver_name: The name of the user who's key should be requested. + :return: The KeyResponse packet from the server. """ - if self.session_token is None: - logger.error("Please log in to request messages") - raise SystemExit + if not receiver_name: + receiver_name = input("Who's key are we requesting? ") - request = ReadRequest(self.session_token, self.user_name) + request = KeyRequest(receiver_name) response = self.send_request(request) message_type: MessageType packet: bytes message_type, _, packet = Packet.decode_packet(response) - if message_type != MessageType.READ_RESPONSE: - raise RuntimeError("Incorrect type message recieved from the server.") + if message_type != MessageType.KEY_RESPONSE: + logger.error("Recieved incorrect type response from server") + raise SystemExit + + (public_key,) = KeyResponse.decode_packet(packet) + + if public_key is None: + logger.warning("The requested user is not registered") + else: + logger.info( + "Received %s's key:\n%s", + receiver_name, + (public_key.n, public_key.e), + ) - self.read_message_response(packet) return packet def send_create_request( self, - receiver_name: str | None, - message: str | None, + receiver_name: str | None = None, + message: str | None = None, ) -> None: """Send a create request to the server. @@ -165,6 +195,9 @@ def send_create_request( message, ) + # encrypted_message = rsa.encrypt(message.encode(), + # self.key_cache[receiver_name]) + request = CreateRequest( self.session_token, receiver_name, @@ -172,70 +205,43 @@ def send_create_request( ) self.send_request(request, expect_response=False) - def send_login_request(self) -> bytes: - """Send a login request to the server. + @staticmethod + def read_message_response(packet: bytes) -> None: + """Read a message response from the server. - :raises RuntimeError: If the server sends an incorrect response. - :return: The LoginResponse packet from the server. + :param packet: The message response from the server. """ - request = LoginRequest(self.user_name) - response = self.send_request(request) - - message_type: MessageType - packet: bytes - message_type, _, packet = Packet.decode_packet(response) - if message_type != MessageType.LOGIN_RESPONSE: - raise RuntimeError("Recieved incorrect type response from server") - - (encrypted_session_token,) = LoginResponse.decode_packet(packet) - logger.debug("Received encrypted token bytes %s", encrypted_session_token) - - if len(encrypted_session_token) == 0: - logger.error("You are not registered! Please register before logging in") - raise SystemExit - - self.session_token = rsa.decrypt(encrypted_session_token, self.__private_key) - logger.debug("Storing provided session token %s", self.session_token) - logger.info("Now logged in as %s", self.user_name) + messages, more_messages = ReadResponse.decode_packet(packet) - return packet + for sender, message in messages: + logger.info("\nMessage from %s:\n%s", sender, message) - def send_registration_request(self) -> None: - """Send a registration request to the server.""" - request = RegistrationRequest(self.user_name, self.public_key) - self.send_request(request, expect_response=False) + if len(messages) == 0: + logger.info("No messages available") + elif more_messages: + logger.info("More messages available, please send another request") - def send_key_request(self, receiver_name: str | None) -> bytes: - """Send a pubblic key request to the server. + def send_read_request(self) -> bytes: + """Send a read request to the server. - :param receiver_name: The name of the user who's key should be requested. - :return: The KeyResponse packet from the server. + :raises RuntimeError: If the server sends an invalid response. + :return: The ReadResponse packet received from the server. """ - if not receiver_name: - receiver_name = input("Who's key are we requesting? ") + if self.session_token is None: + logger.error("Please log in to request messages") + raise SystemExit - request = KeyRequest(receiver_name) + request = ReadRequest(self.session_token, self.user_name) response = self.send_request(request) message_type: MessageType packet: bytes message_type, _, packet = Packet.decode_packet(response) - if message_type != MessageType.KEY_RESPONSE: - logger.error("Recieved incorrect type response from server") - raise SystemExit - - (public_key,) = KeyResponse.decode_packet(packet) - - if public_key is None: - logger.warning("The requested user is not registered") - else: - logger.info( - "Received %s's key:\n%s", - receiver_name, - (public_key.n, public_key.e), - ) + if message_type != MessageType.READ_RESPONSE: + raise RuntimeError("Incorrect type message recieved from the server.") + self.read_message_response(packet) return packet def run(self) -> None: @@ -246,21 +252,19 @@ def run(self) -> None: :param message: The message to send. Will request from ``stdin`` if not present. Defaults to ``None``. """ - match self.message_type: - case MessageType.READ: - self.send_read_request() - - case MessageType.CREATE: - self.send_create_request(None, None) - - case MessageType.LOGIN: - self.send_login_request() - - case MessageType.REGISTER: - self.send_registration_request() - - case MessageType.KEY: - self.send_key_request(None) - - case _: - logger.error("Oopsies, wrong message type!") + if self.message_type not in SEND_REQUEST_MAPPING: + logging.error("Given message type is not a valid request!") + return + + send_function = SEND_REQUEST_MAPPING[self.message_type] + send_function(self) + + +ClientSendFunction: TypeAlias = Callable[[Client], bytes | None] +SEND_REQUEST_MAPPING: Final[Mapping[MessageType, ClientSendFunction]] = { + MessageType.REGISTER: Client.send_registration_request, + MessageType.LOGIN: Client.send_login_request, + MessageType.KEY: Client.send_key_request, + MessageType.CREATE: Client.send_create_request, + MessageType.READ: Client.send_read_request, +} From 69620cd108a67f232a074bd9f3f00226f1ddb507 Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Tue, 19 Nov 2024 09:02:10 +1300 Subject: [PATCH 44/48] Moved session token code to SessionWrapper and magic number and message type code to TypeWrapper --- client/client.py | 69 +++++++++++++---------- server/server.py | 48 ++++++++++------ src/packets/create_request.py | 12 +--- src/packets/key_request.py | 12 +--- src/packets/key_response.py | 15 +---- src/packets/login_request.py | 8 +-- src/packets/login_response.py | 12 +--- src/packets/message.py | 4 +- src/packets/packet.py | 84 +++++----------------------- src/packets/read_request.py | 12 +--- src/packets/read_response.py | 8 +-- src/packets/registration_request.py | 12 +--- src/packets/session_wrapper.py | 63 +++++++++++++++++++++ src/packets/type_wrapper.py | 56 +++++++++++++++++++ tests/applications/test_client.py | 15 +++-- tests/applications/test_server.py | 64 ++++++--------------- tests/packets/test_create_request.py | 26 ++------- tests/packets/test_packet.py | 2 - tests/packets/test_read_request.py | 16 ++---- tests/packets/test_read_response.py | 8 --- 20 files changed, 259 insertions(+), 287 deletions(-) create mode 100644 src/packets/session_wrapper.py create mode 100644 src/packets/type_wrapper.py diff --git a/client/client.py b/client/client.py index 11e8971..1550c62 100644 --- a/client/client.py +++ b/client/client.py @@ -19,6 +19,8 @@ from src.packets.read_request import ReadRequest from src.packets.read_response import ReadResponse from src.packets.registration_request import RegistrationRequest +from src.packets.session_wrapper import SessionWrapper +from src.packets.type_wrapper import TypeWrapper from src.parse_hostname import parse_hostname from src.parse_username import parse_username from src.port_number import PortNumber @@ -70,21 +72,33 @@ def __init__(self, arguments: list[str]) -> None: self.session_token: bytes | None = None self.key_cache: dict[str, rsa.PublicKey] = {} - def send_request(self, request: Packet, *, expect_response: bool = True) -> bytes: + def send_request( + self, + request: Packet, + message_type: MessageType, + *, + expect_response: bool = True, + ) -> tuple[MessageType, bytes] | tuple[None, None]: """Send a message request record to the server. :param request: The message request to be sent. :return: The server's response if applicable, otherwise ``None``. """ - response = b"" - packet = request.to_bytes() + response: tuple[MessageType, bytes] | tuple[None, None] = None, None + + packet = TypeWrapper( + message_type, + SessionWrapper(self.session_token, request), + ).to_bytes() + try: with socket.socket() as connection_socket: connection_socket.settimeout(1) connection_socket.connect((self.host_name, self.port_number)) connection_socket.send(packet) if expect_response: - response = connection_socket.recv(4096) + response_packet = connection_socket.recv(4096) + response = TypeWrapper.decode_packet(response_packet) except (ConnectionRefusedError, TimeoutError) as error: message = ( @@ -106,7 +120,7 @@ def send_request(self, request: Packet, *, expect_response: bool = True) -> byte def send_registration_request(self) -> None: """Send a registration request to the server.""" request = RegistrationRequest(self.user_name, self.public_key) - self.send_request(request, expect_response=False) + self.send_request(request, MessageType.REGISTER, expect_response=False) def send_login_request(self) -> bytes: """Send a login request to the server. @@ -115,15 +129,15 @@ def send_login_request(self) -> bytes: :return: The LoginResponse packet from the server. """ request = LoginRequest(self.user_name) - response = self.send_request(request) + message_type, payload = self.send_request(request, MessageType.LOGIN) + + if payload is None: + raise RuntimeError("No response received from server") - message_type: MessageType - packet: bytes - message_type, _, packet = Packet.decode_packet(response) if message_type != MessageType.LOGIN_RESPONSE: raise RuntimeError("Recieved incorrect type response from server") - (encrypted_session_token,) = LoginResponse.decode_packet(packet) + (encrypted_session_token,) = LoginResponse.decode_packet(payload) logger.debug("Received encrypted token bytes %s", encrypted_session_token) if len(encrypted_session_token) == 0: @@ -134,7 +148,7 @@ def send_login_request(self) -> bytes: logger.debug("Storing provided session token %s", self.session_token) logger.info("Now logged in as %s", self.user_name) - return packet + return payload def send_key_request(self, receiver_name: str | None = None) -> bytes: """Send a pubblic key request to the server. @@ -146,17 +160,16 @@ def send_key_request(self, receiver_name: str | None = None) -> bytes: receiver_name = input("Who's key are we requesting? ") request = KeyRequest(receiver_name) - response = self.send_request(request) + message_type, payload = self.send_request(request, MessageType.KEY) - message_type: MessageType - packet: bytes - message_type, _, packet = Packet.decode_packet(response) + if payload is None: + raise RuntimeError("No response received from the server") if message_type != MessageType.KEY_RESPONSE: logger.error("Recieved incorrect type response from server") raise SystemExit - (public_key,) = KeyResponse.decode_packet(packet) + (public_key,) = KeyResponse.decode_packet(payload) if public_key is None: logger.warning("The requested user is not registered") @@ -167,7 +180,7 @@ def send_key_request(self, receiver_name: str | None = None) -> bytes: (public_key.n, public_key.e), ) - return packet + return payload def send_create_request( self, @@ -198,12 +211,8 @@ def send_create_request( # encrypted_message = rsa.encrypt(message.encode(), # self.key_cache[receiver_name]) - request = CreateRequest( - self.session_token, - receiver_name, - message, - ) - self.send_request(request, expect_response=False) + request = CreateRequest(receiver_name, message) + self.send_request(request, MessageType.CREATE, expect_response=False) @staticmethod def read_message_response(packet: bytes) -> None: @@ -231,18 +240,18 @@ def send_read_request(self) -> bytes: logger.error("Please log in to request messages") raise SystemExit - request = ReadRequest(self.session_token, self.user_name) - response = self.send_request(request) + request = ReadRequest(self.user_name) + message_type, payload = self.send_request(request, MessageType.READ) - message_type: MessageType - packet: bytes - message_type, _, packet = Packet.decode_packet(response) + if payload is None: + raise RuntimeError("No response received from the server") if message_type != MessageType.READ_RESPONSE: raise RuntimeError("Incorrect type message recieved from the server.") - self.read_message_response(packet) - return packet + self.read_message_response(payload) + + return payload def run(self) -> None: """Ask the user to input message and send request to server. diff --git a/server/server.py b/server/server.py index 1041dfa..848c85b 100644 --- a/server/server.py +++ b/server/server.py @@ -19,6 +19,8 @@ from src.packets.packet import Packet from src.packets.read_response import ReadResponse from src.packets.registration_request import RegistrationRequest +from src.packets.session_wrapper import SessionWrapper +from src.packets.type_wrapper import TypeWrapper from src.port_number import PortNumber logger = logging.getLogger(__name__) @@ -81,7 +83,7 @@ def process_read_request( self, requestor_username: str | None, _packet: bytes, - ) -> bytes: + ) -> ReadResponse: """Respond to read requests. :param packet: The read request packet to process. @@ -91,7 +93,7 @@ def process_read_request( logger.info( "Received unauthenticated read request, responding without messages", ) - return ReadResponse([]).to_bytes() + return ReadResponse([]) messages = self.messages.get(requestor_username, []).copy() response = ReadResponse(messages) @@ -103,7 +105,7 @@ def process_read_request( requestor_username, ) - return response.to_bytes() + return response def process_create_request( self, @@ -136,7 +138,7 @@ def process_login_request( self, _requestor_username: str | None, packet: bytes, - ) -> bytes: + ) -> LoginResponse: """Process a client requset to login. :param packet: A byte array containing the login request. @@ -145,7 +147,7 @@ def process_login_request( if sender_name not in self.users: logger.info("Unregistered user %s attempted to login", sender_name) - return LoginResponse(b"").to_bytes() + return LoginResponse(b"") session_token = self.generate_session_token() self.sessions[session_token] = sender_name @@ -155,7 +157,7 @@ def process_login_request( encrypted_session_token = rsa.encrypt(session_token, senders_public_key) logger.debug("Encrypted token to %s", encrypted_session_token) - return LoginResponse(encrypted_session_token).to_bytes() + return LoginResponse(encrypted_session_token) def process_registration_request( self, @@ -184,7 +186,7 @@ def process_key_request( self, _requestor_username: str | None, packet: bytes, - ) -> bytes: + ) -> KeyResponse: """Process a client request for a user's public key. :param packet: A byte array containing the key request. @@ -194,21 +196,23 @@ def process_key_request( logger.info("Received request for %s's key", requested_user) if requested_user not in self.users: - return KeyResponse(None).to_bytes() + return KeyResponse(None) public_key = self.users[requested_user] - return KeyResponse(public_key).to_bytes() + return KeyResponse(public_key) - def process_request(self, packet: bytes, connection_socket: socket.socket) -> None: + def process_request( + self, + packet: bytes, + connection_socket: socket.socket, + ) -> None: """Process an incoming client request. :param packet: The packet received from a client. :param connection_socket: The socket to use for responding to read requests. """ - message_type: MessageType - session_token: bytes | None - - message_type, session_token, packet = Packet.decode_packet(packet) + message_type, packet = TypeWrapper.decode_packet(packet) + session_token, packet = SessionWrapper.decode_packet(packet) if session_token is not None: requestor_username = self.sessions.get(session_token, None) @@ -223,7 +227,9 @@ def process_request(self, packet: bytes, connection_socket: socket.socket) -> No response = processor_function(self, requestor_username, packet) if response is not None: - connection_socket.send(response) + response_type = REQUEST_RESPONSE_MAPPING[message_type] + packet = TypeWrapper(response_type, response).to_bytes() + connection_socket.send(packet) def run_server(self, welcoming_socket: socket.socket) -> None: """Run the server side of the program. @@ -243,8 +249,8 @@ def run_server(self, welcoming_socket: socket.socket) -> None: try: with connection_socket: - record = connection_socket.recv(4096) - self.process_request(record, connection_socket) + packet = connection_socket.recv(4096) + self.process_request(packet, connection_socket) except TimeoutError: error_message = "Timed out while waiting for message request" @@ -259,7 +265,7 @@ def stop(self) -> None: self.running = False -ServerProcessFunction: TypeAlias = Callable[[Server, str | None, bytes], bytes | None] +ServerProcessFunction: TypeAlias = Callable[[Server, str | None, bytes], Packet | None] PROCESS_REQUEST_MAPPING: Final[Mapping[MessageType, ServerProcessFunction]] = { MessageType.REGISTER: Server.process_registration_request, MessageType.LOGIN: Server.process_login_request, @@ -267,3 +273,9 @@ def stop(self) -> None: MessageType.CREATE: Server.process_create_request, MessageType.READ: Server.process_read_request, } + +REQUEST_RESPONSE_MAPPING: Final[Mapping[MessageType, MessageType]] = { + MessageType.LOGIN: MessageType.LOGIN_RESPONSE, + MessageType.KEY: MessageType.KEY_RESPONSE, + MessageType.READ: MessageType.READ_RESPONSE, +} diff --git a/src/packets/create_request.py b/src/packets/create_request.py index bae287a..03e1f96 100644 --- a/src/packets/create_request.py +++ b/src/packets/create_request.py @@ -3,14 +3,12 @@ import logging import struct -from src.message_type import MessageType - -from .packet import Packet +from src.packets.packet import Packet logger = logging.getLogger(__name__) -class CreateRequest(Packet, struct_format="!BH", message_type=MessageType.CREATE): +class CreateRequest(Packet, struct_format="!BH"): """Encoding and decoding of create request packets. Usage: @@ -22,7 +20,6 @@ class CreateRequest(Packet, struct_format="!BH", message_type=MessageType.CREATE def __init__( self, - session_token: bytes, receiver_name: str, message: str, ) -> None: @@ -32,7 +29,6 @@ def __init__( :param receiver_name: The name of the message recipient. :param message: The string message to be sent. """ - super().__init__(session_token=session_token) self.receiver_name = receiver_name self.message = message @@ -47,9 +43,7 @@ def to_bytes(self) -> bytes: self.message, ) - packet = super().to_bytes() - - packet += struct.pack( + packet = struct.pack( self.struct_format, len(self.receiver_name.encode()), len(self.message.encode()), diff --git a/src/packets/key_request.py b/src/packets/key_request.py index 561c310..2b11ed8 100644 --- a/src/packets/key_request.py +++ b/src/packets/key_request.py @@ -7,29 +7,21 @@ import logging import struct -from src.message_type import MessageType from src.packets.packet import Packet -class KeyRequest( - Packet, - struct_format="!H", - message_type=MessageType.KEY, -): +class KeyRequest(Packet, struct_format="!H"): """Encode and decode public key request packets.""" def __init__(self, user_name: str) -> None: """Create a key request packet.""" - super().__init__() self.user_name = user_name def to_bytes(self) -> bytes: """Encode the key request packet into a byte array.""" logging.debug("Creating key request for %s", self.user_name) - packet = super().to_bytes() - - packet += struct.pack( + packet = struct.pack( self.struct_format, len(self.user_name.encode()), ) diff --git a/src/packets/key_response.py b/src/packets/key_response.py index adeb14e..2f4f0e5 100644 --- a/src/packets/key_response.py +++ b/src/packets/key_response.py @@ -9,31 +9,22 @@ import rsa -from src.message_type import MessageType from src.packets.packet import Packet -class KeyResponse( - Packet, - struct_format="!HH", - message_type=MessageType.KEY_RESPONSE, -): +class KeyResponse(Packet, struct_format="!HH"): """Encode and decode public key response packets.""" def __init__(self, public_key: rsa.PublicKey | None) -> None: """Create a key response packet.""" - super().__init__() self.public_key = public_key def to_bytes(self) -> bytes: """Encode the key response packet into a byte array.""" logging.info("Creating key response") - packet = super().to_bytes() - if self.public_key is None: - packet += struct.pack(self.struct_format, 0, 0) - return packet + return struct.pack(self.struct_format, 0, 0) product = self.public_key.n.to_bytes( (self.public_key.n.bit_length() + 7) // 8, @@ -43,7 +34,7 @@ def to_bytes(self) -> bytes: (self.public_key.e.bit_length() + 7) // 8, ) - packet += struct.pack( + packet = struct.pack( self.struct_format, len(product), len(exponent), diff --git a/src/packets/login_request.py b/src/packets/login_request.py index 74857a7..7ac63de 100644 --- a/src/packets/login_request.py +++ b/src/packets/login_request.py @@ -7,25 +7,21 @@ import struct from typing import Any -from src.message_type import MessageType from src.packets.packet import Packet -class LoginRequest(Packet, struct_format="!B", message_type=MessageType.LOGIN): +class LoginRequest(Packet, struct_format="!B"): """The LoginRequest class is used to encode and decode login request packets.""" def __init__(self, user_name: str) -> None: """Create a login request packet.""" - super().__init__() self.user_name = user_name def to_bytes(self) -> bytes: """Encode the login request packet into a byte array.""" logging.debug("Creating log-in request as %s", self.user_name) - packet = super().to_bytes() - - packet += struct.pack( + packet = struct.pack( self.struct_format, len(self.user_name.encode()), ) diff --git a/src/packets/login_response.py b/src/packets/login_response.py index 924e6b5..d596ca6 100644 --- a/src/packets/login_response.py +++ b/src/packets/login_response.py @@ -6,29 +6,21 @@ import logging import struct -from src.message_type import MessageType from src.packets.packet import Packet -class LoginResponse( - Packet, - struct_format="!B", - message_type=MessageType.LOGIN_RESPONSE, -): +class LoginResponse(Packet, struct_format="!B"): """The LoginResponse class is used to encode and decode login response packets.""" def __init__(self, encrypted_token_bytes: bytes) -> None: """Create a new login response packet.""" - super().__init__() self.token = encrypted_token_bytes def to_bytes(self) -> bytes: """Encode a login response packet into a byte array.""" logging.info("Creating login response") - packet = super().to_bytes() - - packet += struct.pack( + packet = struct.pack( self.struct_format, len(self.token), ) diff --git a/src/packets/message.py b/src/packets/message.py index 85aac89..54e41a4 100644 --- a/src/packets/message.py +++ b/src/packets/message.py @@ -2,12 +2,10 @@ import struct -from src.message_type import MessageType - from .packet import Packet -class Message(Packet, struct_format="!BH", message_type=MessageType.MESSAGE): +class Message(Packet, struct_format="!BH"): """A class for encoding and decoding message packets. Message "packets" are the encoding of a single message from within diff --git a/src/packets/packet.py b/src/packets/packet.py index 29694f8..2b19584 100644 --- a/src/packets/packet.py +++ b/src/packets/packet.py @@ -5,51 +5,31 @@ import struct from typing import Any -from src.message_type import MessageType - class Packet(metaclass=abc.ABCMeta): """Abstract class for all packets. - All classes inheriting ``Packet`` must specify both - ``struct_format`` and ``message_type`` in their class attributes. - The format of ``struct_format`` is as described in - https://docs.python.org/3/library/struct.html. ``message_type`` must - be of type ``message_type.MessageType``. + All classes inheriting ``Packet`` must specify ``struct_format`` in + their class attributes. The format of ``struct_format`` is as described in + https://docs.python.org/3/library/struct.html. Example:: - class MyPacket(Packet, struct_format="!HBBH", message_type=MessageType.LOGIN): + class MyPacket(Packet, struct_format="!HBBH"): pass """ - MAGIC_NUMBER = 0xAE73 - STRUCT_FORMAT_REGEX = re.compile("^[@=<>!]?[xcbB?hHiIlLqQnNefdspP]+$") - SESSION_TOKEN_LENGTH = 32 - - struct_format = "!HB?" - - message_type: MessageType + struct_format: str @abc.abstractmethod - def __init__( - self, - *args: tuple[Any, ...], - session_token: bytes | None = None, - ) -> None: + def __init__(self, *args: tuple[Any, ...]) -> None: """Initialise the packet. :param args: All arguments needed to initialise the packet. """ - if ( - session_token is not None - and len(session_token) != Packet.SESSION_TOKEN_LENGTH - ): - raise ValueError("Session token is incorrect length") - - self.session_token = session_token + raise NotImplementedError @abc.abstractmethod def to_bytes(self) -> bytes: @@ -57,47 +37,17 @@ def to_bytes(self) -> bytes: :return: A ``bytes`` object encoding the packet's message type. """ - packet = struct.pack( - Packet.struct_format, - self.MAGIC_NUMBER, - self.message_type.value, - self.session_token is not None, - ) - - if self.session_token is not None: - packet += self.session_token - - return packet + raise NotImplementedError @classmethod @abc.abstractmethod def decode_packet(cls, packet: bytes) -> tuple[Any, ...]: - """Decode the packet into its message type and payload. + """Decode the packet into its individual fields. :param packet: The packet to decode. :return: A tuple of the decoded message type and the payload. """ - header_fields: tuple[int, MessageType, bool] - header_fields, payload = Packet.split_packet(packet) - magic_number, message_type_number, has_token = header_fields - - if magic_number != cls.MAGIC_NUMBER: - raise ValueError("Incorrect magic number found in packet") - - try: - message_type = MessageType(message_type_number) - except ValueError as error: - raise ValueError("Invalid message type ID number") from error - - session_token = None - - if has_token: - session_token, payload = ( - payload[: Packet.SESSION_TOKEN_LENGTH], - payload[Packet.SESSION_TOKEN_LENGTH :], - ) - - return message_type, session_token, payload + raise NotImplementedError @classmethod def split_packet(cls, packet: bytes) -> tuple[tuple[Any, ...], bytes]: @@ -116,22 +66,15 @@ def split_packet(cls, packet: bytes) -> tuple[tuple[Any, ...], bytes]: return header_fields, payload @classmethod - def __init_subclass__( - cls, - *, - message_type: MessageType, - struct_format: str, - ) -> None: + def __init_subclass__(cls, *, struct_format: str) -> None: """Ensure ``struct_format`` attribute is present. All subclasses of ``Packet`` must specify a ``struct_format`` - and a ``message_type`` in their class attributes. This is used - for packing and unpacking the data into a minimal package, and - to communicate what data is stored inside the packet. + in their class attributes. This is used for packing and + unpacking the data into a minimal package. :param message_type: The type of message the packet will encode. :param struct_format: The format of the packet data for the ``struct`` module. - :param kwargs: No additional kwargs will be accepted. :raises ValueError: if the provided struct format is invalid. """ @@ -140,4 +83,3 @@ def __init_subclass__( super().__init_subclass__() cls.struct_format = struct_format - cls.message_type = message_type diff --git a/src/packets/read_request.py b/src/packets/read_request.py index 16744e0..2e80b56 100644 --- a/src/packets/read_request.py +++ b/src/packets/read_request.py @@ -3,14 +3,12 @@ import logging import struct -from src.message_type import MessageType - -from .packet import Packet +from src.packets.packet import Packet logger = logging.getLogger(__name__) -class ReadRequest(Packet, struct_format="!B", message_type=MessageType.READ): +class ReadRequest(Packet, struct_format="!B"): """Encoding and decoding of read request packets. Usage: @@ -23,14 +21,12 @@ class ReadRequest(Packet, struct_format="!B", message_type=MessageType.READ): def __init__( self, - session_token: bytes, user_name: str, ) -> None: """Encode a read request packet. :param user_name: The name of the user sending the read request. """ - super().__init__(session_token=session_token) self.user_name = user_name def to_bytes(self) -> bytes: @@ -40,9 +36,7 @@ def to_bytes(self) -> bytes: """ logger.debug("Creating READ request from %s", self.user_name) - packet = super().to_bytes() - - packet += struct.pack( + packet = struct.pack( self.struct_format, len(self.user_name.encode()), ) diff --git a/src/packets/read_response.py b/src/packets/read_response.py index 2637797..7dd4384 100644 --- a/src/packets/read_response.py +++ b/src/packets/read_response.py @@ -3,14 +3,13 @@ import logging import struct -from src.message_type import MessageType from src.packets.message import Message from src.packets.packet import Packet logger = logging.getLogger(__name__) -class ReadResponse(Packet, struct_format="!B?", message_type=MessageType.READ_RESPONSE): +class ReadResponse(Packet, struct_format="!B?"): """Enables encoding and decoding message response packets.""" MAX_MESSAGE_LENGTH = 255 @@ -20,7 +19,6 @@ def __init__(self, messages: list[tuple[str, bytes]]) -> None: :param messages: A list of all the messages to be put in the structure. """ - super().__init__() self.num_messages = min(len(messages), ReadResponse.MAX_MESSAGE_LENGTH) self.more_messages = len(messages) > ReadResponse.MAX_MESSAGE_LENGTH @@ -33,9 +31,7 @@ def to_bytes(self) -> bytes: """ logger.debug("Creating message response for %s message(s)", self.num_messages) - packet = super().to_bytes() - - packet += struct.pack( + packet = struct.pack( self.struct_format, self.num_messages, self.more_messages, diff --git a/src/packets/registration_request.py b/src/packets/registration_request.py index 62e9daa..155b39f 100644 --- a/src/packets/registration_request.py +++ b/src/packets/registration_request.py @@ -9,20 +9,14 @@ import rsa -from src.message_type import MessageType from src.packets.packet import Packet -class RegistrationRequest( - Packet, - struct_format="!BHH", - message_type=MessageType.REGISTER, -): +class RegistrationRequest(Packet, struct_format="!BHH"): """Encode and decode registration request packets.""" def __init__(self, user_name: str, public_key: rsa.PublicKey) -> None: """Create a login request packet.""" - super().__init__() self.user_name = user_name self.public_key = public_key @@ -40,9 +34,7 @@ def to_bytes(self) -> bytes: (self.public_key.e.bit_length() + 7) // 8, ) - packet = super().to_bytes() - - packet += struct.pack( + packet = struct.pack( self.struct_format, len(user_name), len(product), diff --git a/src/packets/session_wrapper.py b/src/packets/session_wrapper.py new file mode 100644 index 0000000..326fcf4 --- /dev/null +++ b/src/packets/session_wrapper.py @@ -0,0 +1,63 @@ +"""Home to the ``SessionWrapper`` class.""" + +import struct + +from src.packets.packet import Packet + + +class SessionWrapper(Packet, struct_format="!?"): + """Wrapper class to prepend the client session token to packets.""" + + SESSION_TOKEN_LENGTH = 32 + + def __init__(self, session_token: bytes | None, payload: Packet) -> None: + """Initialise the packet. + + :param session_token: The client's current session token. + """ + if ( + session_token is not None + and len(session_token) != SessionWrapper.SESSION_TOKEN_LENGTH + ): + raise ValueError("Session token is incorrect length") + + self.session_token = session_token + self.payload = payload + + def to_bytes(self) -> bytes: + """Convert the packet into a ``bytes`` object. + + :return: A ``bytes`` object encoding the packet and it's session. + """ + packet = struct.pack( + SessionWrapper.struct_format, + self.session_token is not None, + ) + + if self.session_token is not None: + packet += self.session_token + + packet += self.payload.to_bytes() + + return packet + + @classmethod + def decode_packet(cls, packet: bytes) -> tuple[bytes | None, bytes]: + """Decode the packet into its token and payload. + + :param packet: The packet to decode. + :return: A tuple of the session token and payload. + """ + header_fields: tuple[bool] + header_fields, payload = cls.split_packet(packet) + (has_token,) = header_fields + + session_token = None + + if has_token: + session_token, payload = ( + payload[: SessionWrapper.SESSION_TOKEN_LENGTH], + payload[SessionWrapper.SESSION_TOKEN_LENGTH :], + ) + + return session_token, payload diff --git a/src/packets/type_wrapper.py b/src/packets/type_wrapper.py new file mode 100644 index 0000000..125849f --- /dev/null +++ b/src/packets/type_wrapper.py @@ -0,0 +1,56 @@ +"""Home to the ``TypeWrapper`` class.""" + +import struct + +from src.message_type import MessageType +from src.packets.packet import Packet + + +class TypeWrapper(Packet, struct_format="!HB"): + """Wrapper class to prepend the message type to packets.""" + + MAGIC_NUMBER = 0xAE73 + + def __init__(self, message_type: MessageType, payload: Packet) -> None: + """Initialise the packet. + + :param message_type: The type of the payload packet. + """ + self.payload = payload + self.message_type = message_type + + def to_bytes(self) -> bytes: + """Convert the packet into a ``bytes`` object. + + :return: A ``bytes`` object encoding the packet and it's type. + """ + packet = struct.pack( + self.struct_format, + self.MAGIC_NUMBER, + self.message_type.value, + ) + + packet += self.payload.to_bytes() + + return packet + + @classmethod + def decode_packet(cls, packet: bytes) -> tuple[MessageType, bytes]: + """Decode the packet into its type and payload. + + :param packet: The packet to decode. + :return: A tuple of the message type and payload. + """ + header_fields: tuple[int, MessageType] + header_fields, payload = cls.split_packet(packet) + magic_number, message_type_number = header_fields + + if magic_number != TypeWrapper.MAGIC_NUMBER: + raise ValueError("Incorrect magic number in packet") + + try: + message_type = MessageType(message_type_number) + except ValueError as error: + raise ValueError("Invalid message type ID number") from error + + return message_type, payload diff --git a/tests/applications/test_client.py b/tests/applications/test_client.py index e68a294..1af47d8 100644 --- a/tests/applications/test_client.py +++ b/tests/applications/test_client.py @@ -6,7 +6,8 @@ from client import Client from src.message_type import MessageType from src.packets.create_request import CreateRequest -from src.packets.packet import Packet +from src.packets.session_wrapper import SessionWrapper +from src.packets.type_wrapper import TypeWrapper DUMMY_SESSION_TOKEN = b"01234567890123456789012345678901" @@ -29,8 +30,8 @@ def test_construction_raise_error(self) -> None: ["invalid", str(TestClient.port_number), "Alice", "create"], ) - def test_send_message_request(self) -> None: - """Tests that a Client object can send a message request.""" + def test_send_create_request(self) -> None: + """Tests that a Client object can send a create request.""" client = Client( [TestClient.hostname, str(TestClient.port_number), "Alice", "create"], ) @@ -42,9 +43,12 @@ def test_send_message_request(self) -> None: welcoming_socket.bind((TestClient.hostname, TestClient.port_number)) welcoming_socket.listen(1) + client.session_token = DUMMY_SESSION_TOKEN + # Send message request from the client client.send_request( - CreateRequest(DUMMY_SESSION_TOKEN, receiver_name, message), + CreateRequest(receiver_name, message), + MessageType.CREATE, expect_response=False, ) @@ -56,7 +60,8 @@ def test_send_message_request(self) -> None: with connection_socket: packet = connection_socket.recv(4096) - message_type, session_token, packet = Packet.decode_packet(packet) + message_type, packet = TypeWrapper.decode_packet(packet) + session_token, packet = SessionWrapper.decode_packet(packet) self.assertEqual(MessageType.CREATE, message_type) self.assertEqual(DUMMY_SESSION_TOKEN, session_token) diff --git a/tests/applications/test_server.py b/tests/applications/test_server.py index dc7d836..617a177 100644 --- a/tests/applications/test_server.py +++ b/tests/applications/test_server.py @@ -5,16 +5,15 @@ import rsa from server import Server -from src.message_type import MessageType from src.packets.create_request import CreateRequest from src.packets.key_request import KeyRequest from src.packets.key_response import KeyResponse from src.packets.login_request import LoginRequest from src.packets.login_response import LoginResponse -from src.packets.packet import Packet from src.packets.read_request import ReadRequest from src.packets.read_response import ReadResponse from src.packets.registration_request import RegistrationRequest +from src.packets.session_wrapper import SessionWrapper DUMMY_SESSION_TOKEN = b"01234567890123456789012345678901" @@ -45,7 +44,6 @@ def test_process_register_request_unused_name(self) -> None: public_key, _ = rsa.newkeys(512) packet = RegistrationRequest(username, public_key).to_bytes() - _, _, packet = Packet.decode_packet(packet) server.process_registration_request(None, packet) @@ -62,7 +60,6 @@ def test_process_register_request_used_name(self) -> None: server.users = {username: existing_key} packet = RegistrationRequest(username, public_key).to_bytes() - _, _, packet = Packet.decode_packet(packet) server.process_registration_request(None, packet) @@ -78,16 +75,13 @@ def test_process_login_request_registered_user(self) -> None: server.users = {username: public_key} packet = LoginRequest(username).to_bytes() - _, _, packet = Packet.decode_packet(packet) - response = server.process_login_request(None, packet) - message_type, _, payload = Packet.decode_packet(response) + response = server.process_login_request(None, packet).to_bytes() - (encrypted_session_token,) = LoginResponse.decode_packet(payload) + (encrypted_session_token,) = LoginResponse.decode_packet(response) session_token = rsa.decrypt(encrypted_session_token, private_key) - self.assertEqual(MessageType.LOGIN_RESPONSE, message_type) - self.assertEqual(Packet.SESSION_TOKEN_LENGTH, len(session_token)) + self.assertEqual(SessionWrapper.SESSION_TOKEN_LENGTH, len(session_token)) def test_process_login_request_unknown_user(self) -> None: """Tests the the server responds correctly to login requests.""" @@ -96,15 +90,11 @@ def test_process_login_request_unknown_user(self) -> None: username = "John" packet = LoginRequest(username).to_bytes() - _, _, packet = Packet.decode_packet(packet) - response = server.process_login_request(None, packet) - message_type, session_token, payload = Packet.decode_packet(response) + response_packet = server.process_login_request(None, packet).to_bytes() - (encrypted_session_token,) = LoginResponse.decode_packet(payload) + (encrypted_session_token,) = LoginResponse.decode_packet(response_packet) - self.assertEqual(MessageType.LOGIN_RESPONSE, message_type) - self.assertEqual(None, session_token) self.assertEqual(b"", encrypted_session_token) def test_process_key_request_registered_user(self) -> None: @@ -117,14 +107,10 @@ def test_process_key_request_registered_user(self) -> None: server.users = {receiver_name: recipients_key} packet = KeyRequest(receiver_name).to_bytes() - _, _, packet = Packet.decode_packet(packet) - response = server.process_key_request(None, packet) - message_type, session_token, payload = Packet.decode_packet(response) - (public_key,) = KeyResponse.decode_packet(payload) + response_packet = server.process_key_request(None, packet).to_bytes() + (public_key,) = KeyResponse.decode_packet(response_packet) - self.assertEqual(MessageType.KEY_RESPONSE, message_type) - self.assertEqual(None, session_token) self.assertEqual(recipients_key, public_key) def test_process_key_request_unknown_user(self) -> None: @@ -134,14 +120,10 @@ def test_process_key_request_unknown_user(self) -> None: receiver_name = "John" packet = KeyRequest(receiver_name).to_bytes() - _, _, packet = Packet.decode_packet(packet) - response = server.process_key_request(None, packet) - message_type, session_token, payload = Packet.decode_packet(response) - (public_key,) = KeyResponse.decode_packet(payload) + response_packet = server.process_key_request(None, packet).to_bytes() + (public_key,) = KeyResponse.decode_packet(response_packet) - self.assertEqual(MessageType.KEY_RESPONSE, message_type) - self.assertEqual(None, session_token) self.assertEqual(None, public_key) def test_process_create_request_authorised(self) -> None: @@ -152,8 +134,7 @@ def test_process_create_request_authorised(self) -> None: receiver_name = "John" message = "Hello John" - packet = CreateRequest(DUMMY_SESSION_TOKEN, receiver_name, message).to_bytes() - _, _, packet = Packet.decode_packet(packet) + packet = CreateRequest(receiver_name, message).to_bytes() server.process_create_request(sender_name, packet) @@ -169,8 +150,7 @@ def test_process_create_request_unauthorised(self) -> None: receiver_name = "John" message = "Hello John" - packet = CreateRequest(DUMMY_SESSION_TOKEN, receiver_name, message).to_bytes() - _, _, packet = Packet.decode_packet(packet) + packet = CreateRequest(receiver_name, message).to_bytes() server.process_create_request(None, packet) @@ -186,16 +166,12 @@ def test_process_read_request_authorised(self) -> None: server.messages[receiver_name] = [(sender_name, message)] - packet = ReadRequest(DUMMY_SESSION_TOKEN, receiver_name).to_bytes() - _, _, packet = Packet.decode_packet(packet) + packet = ReadRequest(receiver_name).to_bytes() - response = server.process_read_request(receiver_name, packet) - message_type, session_token, payload = Packet.decode_packet(response) + response_packet = server.process_read_request(receiver_name, packet).to_bytes() - messages, more_messages = ReadResponse.decode_packet(payload) + messages, more_messages = ReadResponse.decode_packet(response_packet) - self.assertEqual(MessageType.READ_RESPONSE, message_type) - self.assertEqual(None, session_token) self.assertEqual([(sender_name, message.decode())], messages) self.assertEqual(False, more_messages) @@ -209,15 +185,11 @@ def test_process_read_request_unauthorised(self) -> None: server.messages[receiver_name] = [(sender_name, message)] - packet = ReadRequest(DUMMY_SESSION_TOKEN, receiver_name).to_bytes() - _, _, packet = Packet.decode_packet(packet) + packet = ReadRequest(receiver_name).to_bytes() - response = server.process_read_request(None, packet) - message_type, session_token, payload = Packet.decode_packet(response) + response_packet = server.process_read_request(None, packet).to_bytes() - messages, more_messages = ReadResponse.decode_packet(payload) + messages, more_messages = ReadResponse.decode_packet(response_packet) - self.assertEqual(MessageType.READ_RESPONSE, message_type) - self.assertEqual(None, session_token) self.assertEqual([], messages) self.assertEqual(False, more_messages) diff --git a/tests/packets/test_create_request.py b/tests/packets/test_create_request.py index 093f0ee..8d06603 100644 --- a/tests/packets/test_create_request.py +++ b/tests/packets/test_create_request.py @@ -5,9 +5,7 @@ from src.message_type import MessageType from src.packets.create_request import CreateRequest -from src.packets.packet import Packet -DUMMY_SESSION_TOKEN = b"01234567890123456789012345678901" CREATE_PACKET_HEADER_SIZE = struct.calcsize(CreateRequest.struct_format) @@ -19,15 +17,12 @@ def test_receiver_name_length_encoding(self) -> None: receiver_name = "Jake" message = "Hello, World!" packet = CreateRequest( - DUMMY_SESSION_TOKEN, receiver_name, message, ).to_bytes() - _, _, payload = Packet.decode_packet(packet) - expected = len(receiver_name.encode()) - actual = payload[0] + actual = packet[0] self.assertEqual(expected, actual) def test_message_length_encoding(self) -> None: @@ -35,15 +30,12 @@ def test_message_length_encoding(self) -> None: receiver_name = "Jay" message = "Hello, World!" packet = CreateRequest( - DUMMY_SESSION_TOKEN, receiver_name, message, ).to_bytes() - _, _, payload = Packet.decode_packet(packet) - expected = len(message.encode()) - actual = (payload[1] << 8) | (payload[2] & 0xFF) + actual = (packet[1] << 8) | (packet[2] & 0xFF) self.assertEqual(expected, actual) def test_receiver_name_encoding(self) -> None: @@ -51,15 +43,12 @@ def test_receiver_name_encoding(self) -> None: receiver_name = "Jesse" message = "Hello, World!" packet = CreateRequest( - DUMMY_SESSION_TOKEN, receiver_name, message, ).to_bytes() - _, _, payload = Packet.decode_packet(packet) - expected = receiver_name - actual = payload[ + actual = packet[ CREATE_PACKET_HEADER_SIZE : CREATE_PACKET_HEADER_SIZE + len(receiver_name.encode()) ].decode() @@ -70,17 +59,14 @@ def test_message_encoding(self) -> None: receiver_name = "Jimmy" message = "Hello, World!" packet = CreateRequest( - DUMMY_SESSION_TOKEN, receiver_name, message, ).to_bytes() - _, _, payload = Packet.decode_packet(packet) - start_index = CREATE_PACKET_HEADER_SIZE + len(receiver_name.encode()) expected = message - actual = payload[start_index : start_index + len(message.encode())].decode() + actual = packet[start_index : start_index + len(message.encode())].decode() self.assertEqual(expected, actual) @@ -94,12 +80,10 @@ def setUp(self) -> None: self.receiver_name = "Jonty" self.message = "Hello, World!" - packet = CreateRequest( - DUMMY_SESSION_TOKEN, + self.packet = CreateRequest( self.receiver_name, self.message, ).to_bytes() - _, _, self.packet = Packet.decode_packet(packet) def test_receiver_name_decoding(self) -> None: """Tests that the receiver's name is decoded correctly.""" diff --git a/tests/packets/test_packet.py b/tests/packets/test_packet.py index 849d415..ef9858d 100644 --- a/tests/packets/test_packet.py +++ b/tests/packets/test_packet.py @@ -3,7 +3,6 @@ import unittest from typing import Any -from src.message_type import MessageType from src.packets.packet import Packet @@ -21,7 +20,6 @@ def test_fail_subclass(self) -> None: class NoStructFormat( Packet, struct_format="invalid format", - message_type=MessageType.LOGIN, ): """Invalid ``struct_format`` passed, so class will not be created.""" diff --git a/tests/packets/test_read_request.py b/tests/packets/test_read_request.py index 73c5f3d..db52df9 100644 --- a/tests/packets/test_read_request.py +++ b/tests/packets/test_read_request.py @@ -2,11 +2,8 @@ import unittest -from src.packets.packet import Packet from src.packets.read_request import ReadRequest -DUMMY_SESSION_TOKEN = b"01234567890123456789012345678901" - class TestReadRequestEncoding(unittest.TestCase): """Test suite for encoding ReadRequest packets.""" @@ -14,21 +11,19 @@ class TestReadRequestEncoding(unittest.TestCase): def test_user_name_length_encoding(self) -> None: """Tests that the length of the user's name is encoded correctly.""" user_name = "Johnny" - packet = ReadRequest(DUMMY_SESSION_TOKEN, user_name).to_bytes() - _, _, payload = Packet.decode_packet(packet) + packet = ReadRequest(user_name).to_bytes() expected = len(user_name.encode()) - actual = payload[0] + actual = packet[0] self.assertEqual(expected, actual) def test_user_name_encoding(self) -> None: """Tests that the user's name is placed correctly in the packet.""" user_name = "Johnny" - packet = ReadRequest(DUMMY_SESSION_TOKEN, user_name).to_bytes() - _, _, payload = Packet.decode_packet(packet) + packet = ReadRequest(user_name).to_bytes() expected = user_name - actual = payload[1:].decode() + actual = packet[1:].decode() self.assertEqual(expected, actual) @@ -39,8 +34,7 @@ def setUp(self) -> None: """Set up the testing environment.""" self.user_name = "Jamie" - packet = ReadRequest(DUMMY_SESSION_TOKEN, self.user_name).to_bytes() - _, _, self.packet = Packet.decode_packet(packet) + self.packet = ReadRequest(self.user_name).to_bytes() def test_user_name_decoding(self) -> None: """Tests that the user's name is decoded correctly.""" diff --git a/tests/packets/test_read_response.py b/tests/packets/test_read_response.py index b254fa1..b04ad17 100644 --- a/tests/packets/test_read_response.py +++ b/tests/packets/test_read_response.py @@ -2,7 +2,6 @@ import unittest -from src.packets.packet import Packet from src.packets.read_response import ReadResponse @@ -16,7 +15,6 @@ def test_num_messages_encoding(self) -> None: ("John", b"Hello Harry!"), ] packet = ReadResponse(messages).to_bytes() - _, _, packet = Packet.decode_packet(packet) expected = len(messages) actual = packet[0] @@ -26,7 +24,6 @@ def test_more_messages_encoding(self) -> None: """Tests that the more messages flag is encoded correctly.""" messages: list[tuple[str, bytes]] = [] packet = ReadResponse(messages).to_bytes() - _, _, packet = Packet.decode_packet(packet) expected = False actual = packet[1] @@ -44,8 +41,6 @@ def test_messages_decoding(self) -> None: ] packet = ReadResponse(messages).to_bytes() - _, _, packet = Packet.decode_packet(packet) - expected = [ ("Harry", "Hello John!"), ("John", "Hello Harry!"), @@ -60,7 +55,6 @@ def test_more_messages_decoding_false(self) -> None: ("John", b"Hello Harry!"), ] packet = ReadResponse(messages).to_bytes() - _, _, packet = Packet.decode_packet(packet) expected = False actual = ReadResponse.decode_packet(packet)[1] @@ -71,8 +65,6 @@ def test_more_messages_decoding_true(self) -> None: messages = [("Harry", b"Hello John!")] * 256 packet = ReadResponse(messages).to_bytes() - _, _, packet = Packet.decode_packet(packet) - expected = True actual = ReadResponse.decode_packet(packet)[1] self.assertEqual(expected, actual) From 1c200cb36e683638c9222973a10baaf8f79806f1 Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Tue, 19 Nov 2024 09:22:23 +1300 Subject: [PATCH 45/48] updated codecov-action version --- .github/workflows/unittests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/unittests.yml b/.github/workflows/unittests.yml index 2ebdef9..c651825 100644 --- a/.github/workflows/unittests.yml +++ b/.github/workflows/unittests.yml @@ -35,6 +35,6 @@ jobs: coverage xml - name: Upload coverage reports to Codecov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v5.0.2 env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} From 63c20926cb78b27c513f19030fe5b6fc09ab61d8 Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Tue, 19 Nov 2024 11:36:52 +1300 Subject: [PATCH 46/48] removed message type from client's command line arguments and enabled infinite loop of making requests --- client/client.py | 52 ++++++++++++++++++------- tests/applications/test_client.py | 10 ++--- tests/integration/test_more_messages.py | 4 +- 3 files changed, 43 insertions(+), 23 deletions(-) diff --git a/client/client.py b/client/client.py index 1550c62..adbaee6 100644 --- a/client/client.py +++ b/client/client.py @@ -35,30 +35,26 @@ def __init__(self, arguments: list[str]) -> None: """Initialise the client with specified arguments. :param arguments: A list containing the host name, port number, - username, and message type. + and username. """ super().__init__( OrderedDict( host_name=parse_hostname, port_number=PortNumber, user_name=parse_username, - message_type=MessageType.from_str, ), ) - parsed_arguments: tuple[str, PortNumber, str, MessageType] + parsed_arguments: tuple[str, PortNumber, str] parsed_arguments = self.parse_arguments(arguments) - self.host_name, self.port_number, self.user_name, self.message_type = ( - parsed_arguments - ) + self.host_name, self.port_number, self.user_name = parsed_arguments logger.debug( - "Client for %s port %s created by %s to send %s request", + "Client for %s port %s created by %s", self.host_name, self.port_number, self.user_name, - self.message_type.name.lower(), ) self.public_key, self.__private_key = rsa.newkeys(512) @@ -111,7 +107,7 @@ def send_request( logger.info( "%s record sent as %s", - self.message_type.name.capitalize(), + message_type.name.capitalize(), self.user_name, ) @@ -261,12 +257,40 @@ def run(self) -> None: :param message: The message to send. Will request from ``stdin`` if not present. Defaults to ``None``. """ - if self.message_type not in SEND_REQUEST_MAPPING: - logging.error("Given message type is not a valid request!") - return + help_text = ( + "'register': Register your name and public key with the server.\n" + "'login': Get a token from the server for sending and receiving messages\n" + "'key': Request a user's public key. Currently not useful.\n" + "'create': Send a message to another user.\n" + "'read': Get all messages sent to you.\n" + "'help': Show this message.\n" + "'exit': Quit the application.\n" + ) + + logger.info(help_text) + + while True: + user_input = input("Please enter a request type: ") + + if user_input == "exit": + return + + if user_input == "help": + logger.info(help_text) + continue + + try: + message_type = MessageType.from_str(user_input) + except ValueError: + logger.warning("Invalid message type") + continue + + if message_type not in SEND_REQUEST_MAPPING: + logging.warning("Given message type is not a valid request!") + continue - send_function = SEND_REQUEST_MAPPING[self.message_type] - send_function(self) + send_function = SEND_REQUEST_MAPPING[message_type] + send_function(self) ClientSendFunction: TypeAlias = Callable[[Client], bytes | None] diff --git a/tests/applications/test_client.py b/tests/applications/test_client.py index 1af47d8..58bce28 100644 --- a/tests/applications/test_client.py +++ b/tests/applications/test_client.py @@ -20,7 +20,7 @@ class TestClient(unittest.TestCase): def test_construction(self) -> None: """Tests that a Client object can be constructed given correct arguments.""" - Client([TestClient.hostname, str(TestClient.port_number), "Alice", "create"]) + Client([TestClient.hostname, str(TestClient.port_number), "Alice"]) def test_construction_raise_error(self) -> None: """Tests that a Client object cannot be constructed given invalid arguments.""" @@ -33,7 +33,7 @@ def test_construction_raise_error(self) -> None: def test_send_create_request(self) -> None: """Tests that a Client object can send a create request.""" client = Client( - [TestClient.hostname, str(TestClient.port_number), "Alice", "create"], + [TestClient.hostname, str(TestClient.port_number), "Alice"], ) receiver_name = "John" message = "Hello John" @@ -46,11 +46,7 @@ def test_send_create_request(self) -> None: client.session_token = DUMMY_SESSION_TOKEN # Send message request from the client - client.send_request( - CreateRequest(receiver_name, message), - MessageType.CREATE, - expect_response=False, - ) + client.send_create_request(receiver_name, message) # Accept connection from the client connection_socket, _ = welcoming_socket.accept() diff --git a/tests/integration/test_more_messages.py b/tests/integration/test_more_messages.py index 9a461fe..b5a3925 100644 --- a/tests/integration/test_more_messages.py +++ b/tests/integration/test_more_messages.py @@ -34,13 +34,13 @@ def test_more_messages(self) -> None: server_thread.start() recipient_client = client.Client( - [self.HOST_NAME, self.PORT_NUMBER, self.RECIPIENT_NAME, "read"], + [self.HOST_NAME, self.PORT_NUMBER, self.RECIPIENT_NAME], ) recipient_client.send_registration_request() for sender_name in names: sender_client = client.Client( - [self.HOST_NAME, self.PORT_NUMBER, sender_name, "create"], + [self.HOST_NAME, self.PORT_NUMBER, sender_name], ) sender_client.send_registration_request() sender_client.send_login_request() From 949a3276bc26607e8fdb5e58235e9e11286e24a7 Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Tue, 19 Nov 2024 11:42:27 +1300 Subject: [PATCH 47/48] added rsa to dependancies in pyproject.toml --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 538114d..9f326fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,6 +5,7 @@ authors = [{ name = "Harry Parkes", email = "harrydparkes@proton.me" }] description = "Sever-Client program pair capable of delivering messages between clients" readme = "README.md" requires-python = ">=3.10" +dependencies = ["rsa"] classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: GNU Affero General Public License v3", From bfc2850d0aa496e87f11c70592e921c2ea04f1a6 Mon Sep 17 00:00:00 2001 From: Harry Parkes Date: Tue, 19 Nov 2024 11:44:33 +1300 Subject: [PATCH 48/48] unittest CI script now installs rsa --- .github/workflows/unittests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/unittests.yml b/.github/workflows/unittests.yml index c651825..62ec5ec 100644 --- a/.github/workflows/unittests.yml +++ b/.github/workflows/unittests.yml @@ -27,7 +27,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install coverage + pip install coverage rsa - name: Run all unit tests run: |