Skip to content

Commit

Permalink
Add remote_server capability.
Browse files Browse the repository at this point in the history
  • Loading branch information
jakeichikawasalesforce committed Nov 20, 2024
1 parent d8207c0 commit ff526b2
Showing 1 changed file with 82 additions and 1 deletion.
83 changes: 82 additions & 1 deletion tabpy/tabpy_tools/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import inspect
from re import compile
import time
import requests
Expand Down Expand Up @@ -49,7 +50,7 @@ def _check_endpoint_name(name):


class Client:
def __init__(self, endpoint, query_timeout=1000):
def __init__(self, endpoint, query_timeout=1000, remote_server=False, localhost_endpoint=None):
"""
Connects to a running server.

Expand All @@ -63,10 +64,19 @@ def __init__(self, endpoint, query_timeout=1000):

query_timeout : float, optional
The timeout for query operations.

remote_server : bool, optional
Whether client is a remote TabPy server.

localhost_endpoint : str, optional
The localhost endpoint with potentially different protocol and
port compared to the main endpoint parameter.
"""
_check_hostname(endpoint)

self._endpoint = endpoint
self._remote_server = remote_server
self._localhost_endpoint = localhost_endpoint

session = requests.session()
session.verify = False
Expand Down Expand Up @@ -232,6 +242,13 @@ def deploy(self, name, obj, description="", schema=None, override=False, is_publ
--------
remove, get_endpoints
"""
if self._remote_server:
self._remote_deploy(
name, obj,
description=description, schema=schema, override=override, is_public=is_public
)
return

endpoint = self.get_endpoints().get(name)
version = 1
if endpoint:
Expand Down Expand Up @@ -379,6 +396,7 @@ def _gen_endpoint(self, name, obj, description, version=1, schema=None, is_publi
description = obj.__doc__.strip() or "" if isinstance(obj.__doc__, str) else ""

endpoint_object = CustomQueryObject(query=obj, description=description,)
docstring = inspect.getdoc(obj) or "-- no docstring found in query function --"

return {
"name": name,
Expand All @@ -390,6 +408,7 @@ def _gen_endpoint(self, name, obj, description, version=1, schema=None, is_publi
"methods": endpoint_object.get_methods(),
"required_files": [],
"required_packages": [],
"docstring": docstring,
"schema": copy.copy(schema),
"is_public": is_public,
}
Expand Down Expand Up @@ -419,6 +438,7 @@ def _wait_for_endpoint_deployment(
logger.info(
f"Waiting for endpoint {endpoint_name} to deploy to " f"version {version}"
)
time.sleep(interval)
start = time.time()
while True:
ep_status = self.get_status()
Expand Down Expand Up @@ -447,6 +467,67 @@ def _wait_for_endpoint_deployment(
logger.info(f"Sleeping {interval}...")
time.sleep(interval)

def _remote_deploy(self, name, obj, description="", schema=None, override=False, is_public=False):
"""
Remotely deploy a Python function using the /evaluate endpoint. Takes the same inputs
as deploy.
"""
remote_script = self._gen_remote_script()
remote_script += f"{inspect.getsource(obj)}\n"

remote_script += (
f"client.deploy("
f"'{name}', {obj.__name__}, '{description}', "
f"override={override}, is_public={is_public}, schema={schema}"
f")"
)

self._evaluate_remote_script(remote_script)

def _gen_remote_script(self):
"""
Generates a remote script for TabPy client connection with credential handling.

Returns:
str: A Python script to establish a TabPy client connection
"""
remote_script = [
"from tabpy.tabpy_tools.client import Client",
f"client = Client('{self._localhost_endpoint or self._endpoint}')"
]

remote_script.append(
f"client.set_credentials('{auth.username}', '{auth.password}')"
) if (auth := self._service.service_client.network_wrapper.auth) else None

return "\n".join(remote_script) + "\n"

def _evaluate_remote_script(self, remote_script):
"""
Uses TabPy /evaluate endpoint to execute a remote TabPy client script.

Parameters
----------
remote_script : str
The script to execute remotely.
"""
print(f"Remote script:\n{remote_script}")
url = f"{self._endpoint}evaluate"
headers = {"Content-Type": "application/json"}
payload = {"data": {}, "script": remote_script}

response = requests.post(
url,
headers=headers,
auth=self._service.service_client.network_wrapper.auth,
json=payload
)

log_message = response.text.replace('null', 'success')
if "Ad-hoc scripts have been disabled" in log_message:
log_message += "\n[Connecting to this TabPy server with remote_server=True is not allowed.]"
print(f"\n{response.status_code} - {log_message}\n")

def set_credentials(self, username, password):
"""
Set credentials for all the TabPy client-server communication
Expand Down

0 comments on commit ff526b2

Please sign in to comment.