diff --git a/tabpy/tabpy_tools/client.py b/tabpy/tabpy_tools/client.py index 684fdc0e..2fb09e54 100644 --- a/tabpy/tabpy_tools/client.py +++ b/tabpy/tabpy_tools/client.py @@ -1,4 +1,5 @@ import copy +import inspect from re import compile import time import requests @@ -49,7 +50,7 @@ def _check_endpoint_name(name): class Client: - def __init__(self, endpoint, query_timeout=1000): + def __init__(self, endpoint, query_timeout=1000, remote_server=False, localhost_endpoint=None): """ Connects to a running server. @@ -63,10 +64,19 @@ def __init__(self, endpoint, query_timeout=1000): query_timeout : float, optional The timeout for query operations. + + remote_server : bool, optional + Whether client is a remote TabPy server. + + localhost_endpoint : str, optional + The localhost endpoint with potentially different protocol and + port compared to the main endpoint parameter. """ _check_hostname(endpoint) self._endpoint = endpoint + self._remote_server = remote_server + self._localhost_endpoint = localhost_endpoint session = requests.session() session.verify = False @@ -232,6 +242,13 @@ def deploy(self, name, obj, description="", schema=None, override=False, is_publ -------- remove, get_endpoints """ + if self._remote_server: + self._remote_deploy( + name, obj, + description=description, schema=schema, override=override, is_public=is_public + ) + return + endpoint = self.get_endpoints().get(name) version = 1 if endpoint: @@ -379,6 +396,7 @@ def _gen_endpoint(self, name, obj, description, version=1, schema=None, is_publi description = obj.__doc__.strip() or "" if isinstance(obj.__doc__, str) else "" endpoint_object = CustomQueryObject(query=obj, description=description,) + docstring = inspect.getdoc(obj) or "-- no docstring found in query function --" return { "name": name, @@ -390,6 +408,7 @@ def _gen_endpoint(self, name, obj, description, version=1, schema=None, is_publi "methods": endpoint_object.get_methods(), "required_files": [], "required_packages": [], + "docstring": docstring, "schema": copy.copy(schema), "is_public": is_public, } @@ -419,6 +438,7 @@ def _wait_for_endpoint_deployment( logger.info( f"Waiting for endpoint {endpoint_name} to deploy to " f"version {version}" ) + time.sleep(interval) start = time.time() while True: ep_status = self.get_status() @@ -447,6 +467,67 @@ def _wait_for_endpoint_deployment( logger.info(f"Sleeping {interval}...") time.sleep(interval) + def _remote_deploy(self, name, obj, description="", schema=None, override=False, is_public=False): + """ + Remotely deploy a Python function using the /evaluate endpoint. Takes the same inputs + as deploy. + """ + remote_script = self._gen_remote_script() + remote_script += f"{inspect.getsource(obj)}\n" + + remote_script += ( + f"client.deploy(" + f"'{name}', {obj.__name__}, '{description}', " + f"override={override}, is_public={is_public}, schema={schema}" + f")" + ) + + self._evaluate_remote_script(remote_script) + + def _gen_remote_script(self): + """ + Generates a remote script for TabPy client connection with credential handling. + + Returns: + str: A Python script to establish a TabPy client connection + """ + remote_script = [ + "from tabpy.tabpy_tools.client import Client", + f"client = Client('{self._localhost_endpoint or self._endpoint}')" + ] + + remote_script.append( + f"client.set_credentials('{auth.username}', '{auth.password}')" + ) if (auth := self._service.service_client.network_wrapper.auth) else None + + return "\n".join(remote_script) + "\n" + + def _evaluate_remote_script(self, remote_script): + """ + Uses TabPy /evaluate endpoint to execute a remote TabPy client script. + + Parameters + ---------- + remote_script : str + The script to execute remotely. + """ + print(f"Remote script:\n{remote_script}") + url = f"{self._endpoint}evaluate" + headers = {"Content-Type": "application/json"} + payload = {"data": {}, "script": remote_script} + + response = requests.post( + url, + headers=headers, + auth=self._service.service_client.network_wrapper.auth, + json=payload + ) + + log_message = response.text.replace('null', 'success') + if "Ad-hoc scripts have been disabled" in log_message: + log_message += "\n[Connecting to this TabPy server with remote_server=True is not allowed.]" + print(f"\n{response.status_code} - {log_message}\n") + def set_credentials(self, username, password): """ Set credentials for all the TabPy client-server communication