Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NAS-132488 / 25.04 / Fix websocket connection crash #17

Merged
merged 2 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 46 additions & 13 deletions truenas_api_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,25 @@ def _recv(self, message: JSONRPCMessage):
call.returned.set()
self._unregister_call(call)
else:
logger.error('Received a response for non-registered method call %r', message['id'])
if 'result' in message:
logger.error('Received a success response for non-registered method call %r', message['id'])
elif 'error' in message:
try:
error = self._parse_error_and_unpickle_exception(message['error'])[0]
except Exception:
logger.error('Unhandled exception in JSONRPCClient._parse_error', exc_info=True)
error = None

if message['id'] is None:
logger.error('Received a global connection error: %r', error)
else:
logger.error('Received an error response for non-registered method call %r: %r',
message['id'], error)

if error:
self._broadcast_error(error)
else:
logger.error('Received a response for non-registered method call %r', message['id'])
else:
logger.error('Received unknown message %r', message)
except Exception:
Expand All @@ -544,24 +562,35 @@ def _parse_error(self, error: ErrorObj, call: Call):
Args:
error: The JSON object received in an error Response.
call: The associated `Call` object with which to store the `ClientException`.
"""
call.error, call.py_exception = self._parse_error_and_unpickle_exception(error)

def _parse_error_and_unpickle_exception(self, error: ErrorObj) -> tuple[ClientException, Exception | None]:
"""Convert an error received from the server into a `ClientException` and, possibly, unpickle original
exception.

Args:
error: The JSON object received in an error Response.
"""
code = JSONRPCError(error['code'])
py_exception = None
if self._py_exceptions and code in [JSONRPCError.INVALID_PARAMS, JSONRPCError.TRUENAS_CALL_ERROR]:
data = error['data']
call.error = ClientException(data['reason'], data['error'], data['trace'], data['extra'])
error = ClientException(data['reason'], data['error'], data['trace'], data['extra'])
if 'py_exception' in data:
try:
call.py_exception = pickle.loads(b64decode(data['py_exception']))
py_exception = pickle.loads(b64decode(data['py_exception']))
except Exception as e:
logger.warning("Error unpickling call exception: %r", e)
elif code == JSONRPCError.INVALID_PARAMS:
call.error = ValidationErrors(error['data']['extra'])
error = ValidationErrors(error['data']['extra'])
elif code == JSONRPCError.TRUENAS_CALL_ERROR:
data = error['data']
call.error = ClientException(data['reason'], data['error'], data['trace'], data['extra'])
error = ClientException(data['reason'], data['error'], data['trace'], data['extra'])
else:
call.error = ClientException(code.name)
error = ClientException(error.get('message') or code.name)

return error, py_exception

def _run_callback(self, event: _Payload, args: Iterable[str], kwargs: CollectionUpdateParams):
"""Call the passed `_Payload`'s callback function.
Expand All @@ -585,7 +614,7 @@ def on_open(self):
"""Make an API call to `core.set_options` to configure how middlewared sends its responses."""
self._set_options_call = self.call("core.set_options", {"py_exceptions": self._py_exceptions}, background=True)

def on_close(self, code: int, reason: str | None=None):
def on_close(self, code: int, reason: str | None = None):
"""Close this `JSONRPCClient` in response to the `WebSocketApp` closing.

End all unanswered calls and unreturned jobs with an error.
Expand All @@ -600,9 +629,14 @@ def on_close(self, code: int, reason: str | None=None):
self._connection_error = error
self._connected.set()

self._broadcast_error(ClientException(error, errno.ECONNABORTED))

self._closed.set()

def _broadcast_error(self, error: ClientException):
for call in self._calls.values():
if not call.returned.is_set():
call.error = ClientException(error, errno.ECONNABORTED)
call.error = error
call.returned.set()

for job in self._jobs.values():
Expand All @@ -611,17 +645,16 @@ def on_close(self, code: int, reason: str | None=None):
event = job['__ready'] = Event()

if not event.is_set():
job['error'] = error
job['exception'] = error
error_repr = repr(error)
job['error'] = error_repr
job['exception'] = error_repr
job['exc_info'] = {
'type': 'Exception',
'repr': error,
'repr': error_repr,
'extra': None,
}
event.set()

self._closed.set()

def _register_call(self, call: Call):
"""Save a `Call` and index it by its id."""
self._calls[call.id] = call
Expand Down
40 changes: 25 additions & 15 deletions truenas_api_client/ejson.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
import json


class EJSONDecodeError(ValueError):
pass


class JSONEncoder(json.JSONEncoder):
"""Custom JSON encoder that extends the default encoder to handle more types.

Expand Down Expand Up @@ -63,21 +67,27 @@ def object_hook(obj: dict):
Passed as a kwarg to a JSON deserialization function like `json.dump()`.

"""
obj_len = len(obj)
if obj_len == 1:
if '$date' in obj:
return datetime.fromtimestamp(obj['$date'] / 1000, tz=timezone.utc) + timedelta(milliseconds=obj['$date'] % 1000)
if '$time' in obj:
return time(*[int(i) for i in obj['$time'].split(':')[:4]]) # type: ignore
if '$set' in obj:
return set(obj['$set'])
if '$ipv4_interface' in obj:
return IPv4Interface(obj['$ipv4_interface'])
if '$ipv6_interface' in obj:
return IPv6Interface(obj['$ipv6_interface'])
if obj_len == 2 and '$type' in obj and '$value' in obj:
if obj['$type'] == 'date':
return date(*[int(i) for i in obj['$value'].split('-')])
error_key = '<unknown>'
try:
obj_len = len(obj)
if obj_len == 1:
error_key = list(obj.keys())[0]
if '$date' in obj:
return datetime.fromtimestamp(obj['$date'] / 1000, tz=timezone.utc) + timedelta(milliseconds=obj['$date'] % 1000)
if '$time' in obj:
return time(*[int(i) for i in obj['$time'].split(':')[:4]]) # type: ignore
if '$set' in obj:
return set(obj['$set'])
if '$ipv4_interface' in obj:
return IPv4Interface(obj['$ipv4_interface'])
if '$ipv6_interface' in obj:
return IPv6Interface(obj['$ipv6_interface'])
if obj_len == 2 and '$type' in obj and '$value' in obj:
error_key = obj['$type']
if obj['$type'] == 'date':
return date(*[int(i) for i in obj['$value'].split('-')])
except Exception as e:
raise EJSONDecodeError(f'Error parsing {error_key}: {e}')
return obj


Expand Down
Loading