diff --git a/assemblyline_service_client/task_handler.py b/assemblyline_service_client/task_handler.py index d30b938..71e7c6a 100644 --- a/assemblyline_service_client/task_handler.py +++ b/assemblyline_service_client/task_handler.py @@ -173,7 +173,7 @@ def cleanup_working_directory(self, folder_path): except Exception: pass - def request_with_retries(self, method: str, url: str, max_retry=None, **kwargs): + def request_with_retries(self, method: str, url: str, get_api_response=True, max_retry=None, **kwargs): if 'headers' in kwargs: self.session.headers.update(kwargs['headers']) kwargs.pop('headers') @@ -188,13 +188,16 @@ def request_with_retries(self, method: str, url: str, max_retry=None, **kwargs): func = getattr(self.session, method) resp = func(url, **kwargs) - if resp.status_code == 400 and resp.json(): - self.log.exception(resp.json()['api_error_message']) - raise ServiceServerException(resp.json()['api_error_message']) - else: - resp.raise_for_status() + if get_api_response: + if resp.status_code == 400 and resp.json(): + self.log.exception(resp.json()['api_error_message']) + raise ServiceServerException(resp.json()['api_error_message']) + else: + resp.raise_for_status() - return resp.json()['api_response'] + return resp.json()['api_response'] + else: + return resp except requests.ConnectionError: msg = f"Cannot reach service server. Retrying after {back_off_time}s." if retry < 2: @@ -349,7 +352,8 @@ def download_file(self, sha256, sid) -> Optional[str]: received_file_sha256 = '' file_path = None self.log.info(f"[{sid}] Downloading file: {sha256}") - r = self.request_with_retries('get', self._path('file', sha256), max_retry=3, headers=self.headers) + r = self.request_with_retries('get', self._path('file', sha256), + get_api_response=False, max_retry=3, headers=self.headers) if r is not None: if r.status_code == 404: self.log.error(f"[{sid}] Requested file not found in the system: {sha256}")