diff --git a/setup.cfg b/setup.cfg index eff21ae..a708ea4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = sas-airflow-provider -version = 0.0.7 +version = 0.0.8 author = SAS author_email = andrew.shakinovsky@sas.com description = Enables execution of Studio Flows and Jobs from Airflow diff --git a/src/sas_airflow_provider/hooks/sas.py b/src/sas_airflow_provider/hooks/sas.py index 5fc9978..fce53a2 100644 --- a/src/sas_airflow_provider/hooks/sas.py +++ b/src/sas_airflow_provider/hooks/sas.py @@ -100,4 +100,7 @@ def _create_session_for_connection(self): session.put = lambda *args, **kwargs: requests.Session.put( # type: ignore session, urllib.parse.urljoin(root_url, args[0]), *args[1:], **kwargs ) + session.delete = lambda *args, **kwargs: requests.Session.delete( # type: ignore + session, urllib.parse.urljoin(root_url, args[0]), *args[1:], **kwargs + ) return session diff --git a/src/sas_airflow_provider/operators/sas_create_session.py b/src/sas_airflow_provider/operators/sas_create_session.py index b481767..9469c10 100644 --- a/src/sas_airflow_provider/operators/sas_create_session.py +++ b/src/sas_airflow_provider/operators/sas_create_session.py @@ -66,7 +66,7 @@ def execute(self, context): self.xcom_push(context, 'compute_session_id', self.compute_session_id) # support retry if API-calls fails for whatever reason except Exception as e: - raise AirflowException(f"SASComputeCodeExecOperator error: {str(e)}") + raise AirflowException(f"SASComputeCreateSession error: {str(e)}") return 1 diff --git a/src/sas_airflow_provider/operators/sas_delete_session.py b/src/sas_airflow_provider/operators/sas_delete_session.py new file mode 100644 index 0000000..7b0a8a4 --- /dev/null +++ b/src/sas_airflow_provider/operators/sas_delete_session.py @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from sas_airflow_provider.hooks.sas import SasHook +from sas_airflow_provider.util.util import \ + create_or_connect_to_session, find_named_compute_session, end_compute_session + + +class SASComputeDeleteSession(BaseOperator): + """ + Delete a Compute session. either a session_name or a session_id should be provided. + The result is pushed as a True/False xcom named disconnect_succeeded + + :param connection_name: (optional) name of the connection to use. The connection should be defined + as an HTTP connection in Airflow. If not specified then the default is used + :param session_name: (optional) name of the session to delete + :param session_id: (optiona) id of the session to delete + """ + + ui_color = "#CCE5FF" + ui_fgcolor = "black" + + # template fields are fields which can be templated out in the Airflow task using {{ }} + template_fields: Sequence[str] = ("compute_session_id", "compute_session_name") + + def __init__( + self, + connection_name=None, + compute_session_name="", + compute_session_id="", + **kwargs, + ) -> None: + if not compute_session_id and not compute_session_name: + raise AirflowException(f"Either session_name or session_id must be provided") + super().__init__(**kwargs) + self.connection = None + self.connection_name = connection_name + self.compute_session_name = compute_session_name + self.compute_session_id = compute_session_id + self.success=False + + def execute(self, context): + try: + self.log.info("Authenticate connection") + h = SasHook(self.connection_name) + self.connection = h.get_conn() + self._delete_compute() + self.xcom_push(context, 'disconnect_succeeded', self.success) + # support retry if API-calls fails for whatever reason + except Exception as e: + raise AirflowException(f"SASComputeDeleteSession error: {str(e)}") + + return 1 + + def _delete_compute(self): + if self.compute_session_name: + self.log.info(f"Find session named {self.compute_session_name}") + sesh = find_named_compute_session(self.connection, self.compute_session_name) + if sesh: + self.compute_session_id = sesh["id"] + else: + self.log.info(f"Session named {self.compute_session_name} not found") + return + self.log.info(f"Delete session with id {self.compute_session_id}") + self.success = end_compute_session(self.connection, self.compute_session_id) + + + + diff --git a/src/sas_airflow_provider/util/util.py b/src/sas_airflow_provider/util/util.py index c85adbf..706f7dd 100644 --- a/src/sas_airflow_provider/util/util.py +++ b/src/sas_airflow_provider/util/util.py @@ -129,6 +129,16 @@ def dump_logs(session, job): if t != "title": print(f'{line["line"]}') +def find_named_compute_session(session: requests.Session, name: str) -> dict: + # find session with given name + response = session.get(f"/compute/sessions?filter=eq(name, {name})") + if not response.ok: + raise RuntimeError(f"Find sessions failed: {response.status_code}") + sessions = response.json() + if sessions["count"] > 0: + print(f"Existing session named '{name}' was found") + return sessions["items"][0] + return {} def create_or_connect_to_session(session: requests.Session, context_name: str, name: str) -> dict: """ @@ -139,14 +149,9 @@ def create_or_connect_to_session(session: requests.Session, context_name: str, n :param name: name of session to find :return: session object """ - # find session with given name - response = session.get(f"/compute/sessions?filter=eq(name, {name})") - if not response.ok: - raise RuntimeError(f"Find sessions failed: {response.status_code}") - sessions = response.json() - if sessions["count"] > 0: - print(f"Existing session named '{name}' was found") - return sessions["items"][0] + compute_session = find_named_compute_session(session, name) + if compute_session: + return compute_session print(f"Compute session named '{name}' does not exist, a new one will be created") # find compute context @@ -171,3 +176,10 @@ def create_or_connect_to_session(session: requests.Session, context_name: str, n raise RuntimeError(f"Failed to create session: {response.text}") return response.json() + +def end_compute_session(session: requests.Session, id): + uri = f'/compute/sessions/{id}' + response = session.delete(uri) + if not response.ok: + return False + return True diff --git a/tests/system/sas_create_delete_session.py b/tests/system/sas_create_delete_session.py new file mode 100644 index 0000000..21e85cf --- /dev/null +++ b/tests/system/sas_create_delete_session.py @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from datetime import datetime +from airflow import DAG +from sas_airflow_provider.operators.sas_create_session import SASComputeCreateSession +from sas_airflow_provider.operators.sas_delete_session import SASComputeDeleteSession + +dag = DAG('demo_create_delete', description='Create and delete sessions', + schedule="@once", + start_date=datetime(2022, 6, 1), catchup=False) + +task0 = SASComputeCreateSession(task_id="create_sess", dag=dag) + +task1 = SASComputeDeleteSession(task_id='delete_sess', + compute_session_id="{{ ti.xcom_pull(key='compute_session_id', task_ids=[" + "'create_sess'])|first }}", + dag=dag) + + +task2 = SASComputeDeleteSession(task_id='delete_sess_named', + compute_session_name="Airflow-Session", + dag=dag) + +task0 >> task1 >> task2 +if __name__ == '__main__': + dag.test()