Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add type hints around crypto module #1714

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions asyncua/client/ua_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def _process_received_data(self, data: bytes) -> None:
return
msg = self._connection.receive_from_header_and_body(header, buf)
self._process_received_message(msg)
if header.MessageType == ua.MessageType.SecureOpen:
if header.MessageType == ua.MessageType.SecureOpen and isinstance(msg,ua.Message):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that check for isinstance cannot be correct. if that is because of mypy then annotate the "recevie_from_header_xx" over

params: ua.OpenSecureChannelParameters = self._open_secure_channel_exchange
response: ua.OpenSecureChannelResponse = struct_from_binary(ua.OpenSecureChannelResponse, msg.body())
response.ResponseHeader.ServiceResult.check()
Expand All @@ -107,7 +107,7 @@ def _process_received_data(self, data: bytes) -> None:
self.disconnect_socket()
return

def _process_received_message(self, msg: Union[ua.Message, ua.Acknowledge, ua.ErrorMessage]):
def _process_received_message(self, msg: Union[None,ua.Message, ua.Acknowledge, ua.ErrorMessage]):
if msg is None:
pass
elif isinstance(msg, ua.Message):
Expand Down
59 changes: 32 additions & 27 deletions asyncua/common/connection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass
import hashlib
from datetime import datetime, timedelta, timezone
from typing import Optional, List, TYPE_CHECKING, Union
import logging
import copy

Expand All @@ -14,6 +15,10 @@
class InvalidSignature(Exception): # type: ignore
pass

if TYPE_CHECKING:
from asyncua.common.utils import Buffer
from asyncua.ua.uaprotocol_hand import SecurityPolicy, SecurityPolicyFactory

_logger = logging.getLogger('asyncua.uaprotocol')


Expand Down Expand Up @@ -51,7 +56,7 @@ def is_chunk_count_within_limit(self, sz: int) -> bool:
_logger.error("Number of message chunks: %s is > configured max chunk count: %s", sz, self.max_chunk_count)
return within_limit

def create_acknowledge_and_set_limits(self, msg: ua.Hello) -> ua.Acknowledge:
def create_acknowledge_and_set_limits(self, msg: "ua.Hello") -> "ua.Acknowledge":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that too cannot be correct. should be possible to use from __future__ import annotations stuff to avoid adding "

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there should be no reason to use " for typing

ack = ua.Acknowledge()
ack.ReceiveBufferSize = min(msg.ReceiveBufferSize, self.max_send_buffer)
ack.SendBufferSize = min(msg.SendBufferSize, self.max_recv_buffer)
Expand All @@ -64,14 +69,14 @@ def create_acknowledge_and_set_limits(self, msg: ua.Hello) -> ua.Acknowledge:
_logger.info("updating server limits to: %s", self)
return ack

def create_hello_limits(self, msg: ua.Hello) -> ua.Hello:
def create_hello_limits(self, msg: "ua.Hello") -> "ua.Hello":
msg.ReceiveBufferSize = self.max_recv_buffer
msg.SendBufferSize = self.max_send_buffer
msg.MaxChunkCount = self.max_chunk_count
msg.MaxMessageSize = self.max_chunk_count
return msg

def update_client_limits(self, msg: ua.Acknowledge) -> None:
def update_client_limits(self, msg: "ua.Acknowledge") -> None:
self.max_chunk_count = msg.MaxChunkCount
self.max_recv_buffer = msg.ReceiveBufferSize
self.max_send_buffer = msg.SendBufferSize
Expand Down Expand Up @@ -105,7 +110,7 @@ def from_binary(security_policy, data):
return MessageChunk.from_header_and_body(security_policy, h, data, use_prev_key=True)

@staticmethod
def from_header_and_body(security_policy, header, buf, use_prev_key=False):
def from_header_and_body(security_policy: "SecurityPolicy", header, buf, use_prev_key=False):
if not len(buf) >= header.body_size:
raise ValueError('Full body expected here')
data = buf.copy(header.body_size)
Expand Down Expand Up @@ -156,7 +161,7 @@ def max_body_size(crypto, max_chunk_size):
return max_plain_size - ua.SequenceHeader.max_size() - crypto.signature_size() - crypto.min_padding_size()

@staticmethod
def message_to_chunks(security_policy, body, max_chunk_size, message_type=ua.MessageType.SecureMessage, channel_id=1, request_id=1, token_id=1):
def message_to_chunks(security_policy: "SecurityPolicy", body, max_chunk_size, message_type=ua.MessageType.SecureMessage, channel_id=1, request_id=1, token_id=1):
"""
Pack message body (as binary string) into one or more chunks.
Size of each chunk will not exceed max_chunk_size.
Expand All @@ -179,7 +184,7 @@ def message_to_chunks(security_policy, body, max_chunk_size, message_type=ua.Mes
crypto = security_policy.symmetric_cryptography
max_size = MessageChunk.max_body_size(crypto, max_chunk_size)

chunks = []
chunks: List[MessageChunk] = []
for i in range(0, len(body), max_size):
part = body[i:i + max_size]
if i + max_size >= len(body):
Expand All @@ -204,22 +209,22 @@ class SecureConnection:
"""
Common logic for client and server
"""
def __init__(self, security_policy, limits: TransportLimits):
self._sequence_number = 0
self._peer_sequence_number = None
self._incoming_parts = []
self.security_policy = security_policy
self._policies = []
self._open = False
def __init__(self, security_policy: "SecurityPolicy", limits: TransportLimits) -> None:
self._sequence_number: int = 0
self._peer_sequence_number: Optional[int] = None
self._incoming_parts: List[MessageChunk] = []
self.security_policy: SecurityPolicy = security_policy
self._policies: List[SecurityPolicyFactory] = []
self._open: bool = False
self.security_token = ua.ChannelSecurityToken()
self.next_security_token = ua.ChannelSecurityToken()
self.prev_security_token = ua.ChannelSecurityToken()
self.local_nonce = 0
self.remote_nonce = 0
self._allow_prev_token = False
self._limits = limits
self.local_nonce: int = 0
self.remote_nonce:int = 0
self._allow_prev_token: bool = False
self._limits: TransportLimits = limits

def set_channel(self, params, request_type, client_nonce):
def set_channel(self, params, request_type, client_nonce) -> None:
"""
Called on client side when getting secure channel data from server.
"""
Expand All @@ -241,7 +246,7 @@ def set_channel(self, params, request_type, client_nonce):

self._allow_prev_token = True

def open(self, params, server):
def open(self, params, server) -> ua.OpenSecureChannelResult:
"""
Called on server side to open secure channel.
"""
Expand Down Expand Up @@ -276,32 +281,32 @@ def open(self, params, server):

return response

def close(self):
def close(self) -> None:
self._open = False

def is_open(self):
def is_open(self) -> bool:
return self._open

def set_policy_factories(self, policies):
def set_policy_factories(self, policies: "List[SecurityPolicyFactory]") -> None:
"""
Set a list of available security policies.
Use this in servers with multiple endpoints with different security.
"""
self._policies = policies

@staticmethod
def _policy_matches(policy, uri, mode=None):
def _policy_matches(policy: "SecurityPolicy", uri, mode=None) -> bool:
return policy.URI == uri and (mode is None or policy.Mode == mode)

def select_policy(self, uri, peer_certificate, mode=None):
def select_policy(self, uri: str, peer_certificate, mode=None):
for policy in self._policies:
if policy.matches(uri, mode):
self.security_policy = policy.create(peer_certificate)
return
if self.security_policy.URI != uri or (mode is not None and self.security_policy.Mode != mode):
raise ua.UaError(f"No matching policy: {uri}, {mode}")

def revolve_tokens(self):
def revolve_tokens(self) -> None:
"""
Revolve security tokens of the security channel. Start using the
next security token negotiated during the renewal of the channel and
Expand Down Expand Up @@ -389,7 +394,7 @@ def _check_incoming_chunk(self, chunk):
raise ua.UaError(f"Received chunk: {chunk} with wrong sequence expecting:" f" {self._peer_sequence_number}, received: {seq_num}," f" spec says to close connection")
self._peer_sequence_number = seq_num

def receive_from_header_and_body(self, header, body):
def receive_from_header_and_body(self, header: ua.Header, body: "Buffer") -> Union[None,ua.Message,"ua.Hello",ua.Acknowledge,ua.ErrorMessage]:
"""
Convert MessageHeader and binary body to OPC UA TCP message (see OPC UA
specs Part 6, 7.1: Hello, Acknowledge or ErrorMessage), or a Message
Expand Down Expand Up @@ -430,7 +435,7 @@ def receive_from_header_and_body(self, header, body):
return msg
raise ua.UaError(f"Unsupported message type {header.MessageType}")

def _receive(self, msg):
def _receive(self, msg: MessageChunk) -> Optional[ua.Message]:
if msg.MessageHeader.packet_size > self._limits.max_recv_buffer:
self._incoming_parts = []
_logger.error("Message size: %s is > chunk max size: %s", msg.MessageHeader.packet_size, self._limits.max_recv_buffer)
Expand Down
29 changes: 19 additions & 10 deletions asyncua/crypto/permission_rules.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
from abc import abstractmethod
from asyncua import ua
from asyncua.server.users import UserRole
from abc import ABC

WRITE_TYPES = [
from typing import TYPE_CHECKING, Tuple, Dict, Set

if TYPE_CHECKING:
from asyncua.server.users import User
from asyncua.common.utils import Buffer

WRITE_TYPES: Tuple[int,...] = (
ua.ObjectIds.WriteRequest_Encoding_DefaultBinary,
ua.ObjectIds.RegisterServerRequest_Encoding_DefaultBinary,
ua.ObjectIds.RegisterServer2Request_Encoding_DefaultBinary,
Expand All @@ -11,9 +19,9 @@
ua.ObjectIds.DeleteReferencesRequest_Encoding_DefaultBinary,
ua.ObjectIds.RegisterNodesRequest_Encoding_DefaultBinary,
ua.ObjectIds.UnregisterNodesRequest_Encoding_DefaultBinary
]
)

READ_TYPES = [
READ_TYPES: Tuple[int,...] = (
ua.ObjectIds.CreateSessionRequest_Encoding_DefaultBinary,
ua.ObjectIds.CloseSessionRequest_Encoding_DefaultBinary,
ua.ObjectIds.ActivateSessionRequest_Encoding_DefaultBinary,
Expand All @@ -33,16 +41,17 @@
ua.ObjectIds.CloseSecureChannelRequest_Encoding_DefaultBinary,
ua.ObjectIds.CallRequest_Encoding_DefaultBinary,
ua.ObjectIds.SetMonitoringModeRequest_Encoding_DefaultBinary,
ua.ObjectIds.SetPublishingModeRequest_Encoding_DefaultBinary
]
ua.ObjectIds.SetPublishingModeRequest_Encoding_DefaultBinary,
)


class PermissionRuleset:
class PermissionRuleset(ABC):
"""
Base class for permission ruleset
"""

def check_validity(self, user, action_type, body):
@abstractmethod
def check_validity(self, user: "User", action_type_id: ua.NodeId, body: "Buffer") -> bool:
raise NotImplementedError


Expand All @@ -52,16 +61,16 @@ class SimpleRoleRuleset(PermissionRuleset):
Admins alone can write, admins and users can read, and anonymous users can't do anything.
"""

def __init__(self):
def __init__(self) -> None:
write_ids = list(map(ua.NodeId, WRITE_TYPES))
read_ids = list(map(ua.NodeId, READ_TYPES))
self._permission_dict = {
self._permission_dict: Dict[UserRole, Set[ua.NodeId]] = {
UserRole.Admin: set().union(write_ids, read_ids),
UserRole.User: set().union(read_ids),
UserRole.Anonymous: set()
}

def check_validity(self, user, action_type_id, body):
def check_validity(self, user: "User", action_type_id: ua.NodeId, body: "Buffer") -> bool:
if action_type_id in self._permission_dict[user.role]:
return True
else:
Expand Down
16 changes: 8 additions & 8 deletions asyncua/crypto/security_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,14 @@ class Verifier:
__metaclass__ = ABCMeta

@abstractmethod
def signature_size(self):
def signature_size(self) -> None:
pass

@abstractmethod
def verify(self, data, signature):
def verify(self, data, signature) -> None:
pass

def reset(self):
def reset(self) -> None:
attrs = self.__dict__
for k in attrs:
attrs[k] = None
Expand All @@ -70,11 +70,11 @@ class Encryptor:
__metaclass__ = ABCMeta

@abstractmethod
def plain_block_size(self):
def plain_block_size(self) -> int:
pass

@abstractmethod
def encrypted_block_size(self):
def encrypted_block_size(self) -> int:
pass

@abstractmethod
Expand All @@ -90,18 +90,18 @@ class Decryptor:
__metaclass__ = ABCMeta

@abstractmethod
def plain_block_size(self):
def plain_block_size(self) -> int:
pass

@abstractmethod
def encrypted_block_size(self):
def encrypted_block_size(self) -> int:
pass

@abstractmethod
def decrypt(self, data):
pass

def reset(self):
def reset(self) -> None:
attrs = self.__dict__
for k in attrs:
attrs[k] = None
Expand Down
Loading