From 30d8c5114879d00d4a8cf58c1274555398483765 Mon Sep 17 00:00:00 2001
From: Alex Ott
Date: Thu, 30 Nov 2023 16:28:17 +0100
Subject: [PATCH] Code improvements for example notebooks

For Splunk ingestion, don't use the raw endpoint, but the normal JSON
endpoint.
---
 notebooks/source/cloudtrail_ingest.py        |   5 +-
 .../source/cloudtrail_insights_ingest.py     |   4 +-
 notebooks/source/pull_from_splunk.py         |  10 +-
 notebooks/source/push_to_splunk.py           | 100 ++++++++++--------
 4 files changed, 68 insertions(+), 51 deletions(-)

diff --git a/notebooks/source/cloudtrail_ingest.py b/notebooks/source/cloudtrail_ingest.py
index 08eefca..6b6e650 100644
--- a/notebooks/source/cloudtrail_ingest.py
+++ b/notebooks/source/cloudtrail_ingest.py
@@ -38,7 +38,8 @@
 checkpointPath=dbutils.widgets.get("Checkpoint Path")
 tableName=dbutils.widgets.get("Table Name")
 regionName=dbutils.widgets.get("Region Name")
-if ((cloudTrailLogsPath==None or cloudTrailLogsPath=="")or(deltaOutputPath==None or deltaOutputPath=="")or(checkpointPath==None or checkpointPath=="")or(tableName==None or tableName=="")or(regionName==None or regionName=="")):
+
+if ((cloudTrailLogsPath==None or cloudTrailLogsPath=="") or (deltaOutputPath==None or deltaOutputPath=="") or (checkpointPath==None or checkpointPath=="") or (tableName==None or tableName=="") or (regionName==None or regionName=="")):
   dbutils.notebook.exit("All parameters are mandatory. Ensure correct values of all parameters are specified.")
 
 # COMMAND ----------
@@ -170,5 +171,5 @@
 
 # COMMAND ----------
 
-create_table_query="CREATE TABLE IF NOT EXISTS "+tableName+" USING DELTA LOCATION '"+ deltaOutputPath +"'"
+create_table_query=f"CREATE TABLE IF NOT EXISTS {tableName} USING DELTA LOCATION '{deltaOutputPath}'"
 spark.sql(create_table_query)
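A side note on the parameter check reformatted above: the same validation can be expressed without the long boolean chain. A minimal sketch, assuming the standard dbutils.widgets API; the "CloudTrail Logs Path" and "Delta Output Path" widget names are inferred from the variable names and are assumptions, not taken from the patch:

    # Hypothetical compact rewrite of the mandatory-parameter check.
    # The first two widget names are assumptions inferred from variable names.
    required = ["CloudTrail Logs Path", "Delta Output Path",
                "Checkpoint Path", "Table Name", "Region Name"]
    values = [dbutils.widgets.get(name) for name in required]
    if any(v is None or v == "" for v in values):
        dbutils.notebook.exit("All parameters are mandatory. "
                              "Ensure correct values of all parameters are specified.")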
diff --git a/notebooks/source/cloudtrail_insights_ingest.py b/notebooks/source/cloudtrail_insights_ingest.py
index ca3d0a0..2be2ecc 100644
--- a/notebooks/source/cloudtrail_insights_ingest.py
+++ b/notebooks/source/cloudtrail_insights_ingest.py
@@ -38,7 +38,7 @@
 checkpointPath=dbutils.widgets.get("Checkpoint Path")
 tableName=dbutils.widgets.get("Table Name")
 regionName=dbutils.widgets.get("Region Name")
-if ((cloudTrailInsightsPath==None or cloudTrailInsightsPath=="")or(deltaOutputPath==None or deltaOutputPath=="")or(checkpointPath==None or checkpointPath=="")or(tableName==None or tableName=="")or(regionName==None or regionName=="")):
+if ((cloudTrailInsightsPath==None or cloudTrailInsightsPath=="") or (deltaOutputPath==None or deltaOutputPath=="") or (checkpointPath==None or checkpointPath=="") or (tableName==None or tableName=="") or (regionName==None or regionName=="")):
   dbutils.notebook.exit("All parameters are mandatory. Ensure correct values of all parameters are specified.")
 
 # COMMAND ----------
@@ -144,5 +144,5 @@
 
 # COMMAND ----------
 
-create_table_query="CREATE TABLE IF NOT EXISTS "+tableName+" USING DELTA LOCATION '"+ deltaOutputPath +"'"
+create_table_query=f"CREATE TABLE IF NOT EXISTS {tableName} USING DELTA LOCATION '{deltaOutputPath}'"
 spark.sql(create_table_query)

diff --git a/notebooks/source/pull_from_splunk.py b/notebooks/source/pull_from_splunk.py
index 3f7bdf1..9d3194d 100644
--- a/notebooks/source/pull_from_splunk.py
+++ b/notebooks/source/pull_from_splunk.py
@@ -117,25 +117,25 @@ def __init__(self,splunk_address,splunk_port,splunk_username,splunk_namespace,ss
     self.splunk_namespace = splunk_namespace
 
   @property
   def auth_url(self):
-    auth_url = "https://{}:{}/services/auth/login".format(self.splunk_address,self.splunk_port)
+    auth_url = f"https://{self.splunk_address}:{self.splunk_port}/services/auth/login"
     return (auth_url)
 
   @property
   def mgmt_segment(self):
-    mgmt_segment_part = "https://{}:{}/servicesNS/{}/{}/".format(self.splunk_address,self.splunk_port,self.splunk_username,self.splunk_namespace)
+    mgmt_segment_part = f"https://{self.splunk_address}:{self.splunk_port}/servicesNS/{self.splunk_username}/{self.splunk_namespace}/"
     return (mgmt_segment_part)
 
   def connect(self,splunk_password):
     try:
       response = requests.post(
         self.auth_url,
-        data={"username":self.splunk_username,
-              "password":splunk_password},verify=self.ssl_verify)
+        data={"username": self.splunk_username, "password": splunk_password},
+        verify=self.ssl_verify)
       session = XML(response.text).findtext("./sessionKey")
       if (session=="" or session==None):
         dbutils.notebook.exit("Issue in Authentication : Type - "+XML(response.text).find("./messages/msg").attrib["type"]+"\n Message - "+XML(response.text).findtext("./messages/msg"))
       else :
-        self.token = "Splunk {}".format(session)
+        self.token = f"Splunk {session}"
     except HTTPError as e:
       if e.status == 401:
         raise AuthenticationError("Login failed.", e)
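For orientation, the session key that connect() stores in self.token is meant to be sent as an Authorization header on subsequent management-API calls under mgmt_segment. A minimal sketch of starting a search job with it, assuming the standard Splunk search/jobs REST endpoint; the helper name and query are illustrative, not part of the notebook:

    # Hypothetical follow-up call using the token set by connect();
    # the helper name and query string are illustrative only.
    import requests

    def start_search(client, query):
        response = requests.post(
            client.mgmt_segment + "search/jobs",
            headers={"Authorization": client.token},  # "Splunk <sessionKey>"
            data={"search": f"search {query}", "output_mode": "json"},
            verify=client.ssl_verify)
        response.raise_for_status()
        return response.json()["sid"]  # job id, polled later for results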
diff --git a/notebooks/source/push_to_splunk.py b/notebooks/source/push_to_splunk.py
index 4b2622c..5efffb4 100644
--- a/notebooks/source/push_to_splunk.py
+++ b/notebooks/source/push_to_splunk.py
@@ -1,6 +1,12 @@
 # Databricks notebook source
 # MAGIC %md
+# MAGIC
+# MAGIC ## Introduction
+# MAGIC
+# MAGIC This notebook demonstrates how to push data to Splunk using Splunk's HTTP Event Collector. ***Please note that it is intended only for sending a limited number of events (dozens to hundreds) to Splunk, not for pushing large amounts of data, as it collects all data on the driver node!***
+# MAGIC
 # MAGIC ## Input parameters from the user
+# MAGIC
 # MAGIC This gives a brief description of each parameter.
 # MAGIC Before using the notebook, please go through the user documentation of this notebook to use the notebook effectively.
 # MAGIC 1. **Protocol** ***(Mandatory Field)*** : The protocol on which Splunk HTTP Event Collector(HEC) runs. Splunk HEC runs on `https` if Enable SSL checkbox is selected while configuring Splunk HEC Token in Splunk, else it runs on `http` protocol. If you do not have access to the Splunk HEC Configuration page, you can ask your Splunk Admin if the `Enable SSL checkbox` is selected or not.
@@ -119,13 +125,23 @@
 import traceback
 import os
 import uuid
+import copy
+from datetime import date, datetime
 
 requests.packages.urllib3.disable_warnings()
-
+
+
+def json_serializer(obj):
+  if isinstance(obj, (date, datetime)):
+    return obj.isoformat()
+  return str(obj)
+
 class HttpEventCollector:
   maxByteLength = 1000000
 
-  def __init__(self,protocol,splunk_address,splunk_port,splunk_hec_token,index,source,sourcetype,host,ssl_verify="false"):
+  def __init__(self, protocol, splunk_address, splunk_port, splunk_hec_token,
+               index, source, sourcetype, host, ssl_verify="false"):
     self.protocol = protocol
 
     if (splunk_address=="" or splunk_address==None):
@@ -153,7 +169,15 @@ def __init__(self,protocol,splunk_address,splunk_port,splunk_hec_token,index,sou
     self.host = host
     self.batch_events = []
     self.current_byte_length = 0
-
+    self.parameters = {}
+    if self.sourcetype:
+      self.parameters["sourcetype"] = self.sourcetype
+    if self.source:
+      self.parameters["source"] = self.source
+    if self.index:
+      self.parameters["index"] = self.index
+    if self.host:
+      self.parameters["host"] = self.host
 
   def requests_retry_session(self,retries=3):
     session = requests.Session()
@@ -165,48 +189,40 @@ def __init__(self,protocol,splunk_address,splunk_port,splunk_hec_token,index,sou
   @property
   def server_uri(self):
     # splunk HEC url used to push data
-    endpoint="/raw?channel="+str(uuid.uuid1())
-    server_uri = '%s://%s:%s/services/collector%s' % (self.protocol, self.splunk_address, self.splunk_port, endpoint)
-    return (server_uri)
-
-  @property
-  def parameters(self):
-    params={}
-    if not( self.sourcetype == None or self.sourcetype == ""):
-      params.update({"sourcetype":self.sourcetype})
-    if not( self.source == None or self.source == ""):
-      params.update({"source":self.source})
-    if not( self.index == None or self.index == ""):
-      params.update({"index":self.index})
-    if not( self.host == None or self.host == ""):
-      params.update({"host":self.host})
-    return (params)
+    endpoint=f"?channel={uuid.uuid1()}"
+    server_uri = f'{self.protocol}://{self.splunk_address}:{self.splunk_port}/services/collector{endpoint}'
+    return (server_uri)
 
   def batch_and_push_event(self,event):
     # divide the resut payload into batches and push to splunk HEC
-    payload_string = str(event)
-    if not payload_string.endswith("\n"):
-      payload_string=payload_string+"\n"
+    data = copy.copy(self.parameters)
+    data["event"] = event.asDict(True)
+    payload_string = json.dumps(data, default=json_serializer)
    payload_length = len(payload_string)
     if ((self.current_byte_length+payload_length) > self.maxByteLength ):
-      self.push_event()
-      self.batch_events = []
-      self.current_byte_length = 0
+      self.push_events_and_cleanup()
     self.batch_events.append(payload_string)
     self.current_byte_length += payload_length
 
-  def push_event(self):
+  def push_events_and_cleanup(self):
+    self.push_events()
+    self.batch_events = []
+    self.current_byte_length = 0
+
+  def push_events(self):
+    if len(self.batch_events) == 0:
+      # Nothing to push
+      return
+
     # Function to push data to splunk
-    payload = " ".join(self.batch_events)
+    payload = "\n".join(self.batch_events)
     headers = {'Authorization':'Splunk '+self.token}
-    response = self.requests_retry_session().post(self.server_uri, data=payload, headers=headers,params=self.parameters, verify=self.ssl_verify)
+    response = self.requests_retry_session().post(self.server_uri, data=payload, headers=headers, verify=self.ssl_verify)
     if not (response.status_code==200 or response.status_code==201) :
       raise Exception("Response status : {} .Response message : {}".format(str(response.status_code),response.text))
-
-
 # COMMAND ----------
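To make the endpoint switch concrete: the raw endpoint accepted bare stringified rows, while /services/collector expects each event wrapped in a JSON object, with the routing metadata (sourcetype, index, source, host) inlined alongside the "event" field. A minimal sketch of the payload batch_and_push_event now builds; all field values are invented, and this assumes json is imported earlier in the notebook, which the patch's json.dumps call also relies on:

    # Illustrative HEC JSON payload; field values are made up.
    import json

    event_line = json.dumps({
        "sourcetype": "_json",                          # from self.parameters
        "index": "main",
        "event": {"user": "alice", "action": "login"},  # one Row as a dict
    })
    # Newline-separated JSON objects are accepted by HEC as multiple
    # events in a single POST to /services/collector.
    payload = "\n".join([event_line, event_line])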
"\n".join(self.batch_events) headers = {'Authorization':'Splunk '+self.token} - response = self.requests_retry_session().post(self.server_uri, data=payload, headers=headers,params=self.parameters, verify=self.ssl_verify) + response = self.requests_retry_session().post(self.server_uri, data=payload, headers=headers, verify=self.ssl_verify) if not (response.status_code==200 or response.status_code==201) : raise Exception("Response status : {} .Response message : {}".format(str(response.status_code),response.text)) - - # COMMAND ---------- from pyspark.sql.functions import * @@ -214,33 +230,33 @@ def push_event(self): if(advancedQuery): full_query=advancedQuery elif (table and database): - basic_query="select * from "+database+"."+table+" " + basic_query = f"select * from {database}.{table} " if (filterQuery == None or filterQuery == "" ) : - full_query=basic_query + full_query = basic_query else : full_query = basic_query+filterQuery else: dbutils.notebook.exit("Advanced Query or Table name and Database name are required.Please check input values.") try : read_data=spark.sql(full_query) - events_list=read_data.toJSON().collect() + events_list=read_data.collect() + except Exception as e: - print ("Some error occurred while running query. The filter may be incorrect : ".format(e)) + print(f"Some error occurred while running query. The filter may be incorrect : {e}") traceback.print_exc() - exit() + raise ex try : - http_event_collector_instance=HttpEventCollector(protocol,splunkAddress,splunkPort,splunkHecToken,index,source,sourcetype,host,ssl_verify=sslVerify) + http_event_collector_instance=HttpEventCollector(protocol,splunkAddress,splunkPort, + splunkHecToken,index,source,sourcetype, + host,ssl_verify=sslVerify) for each in events_list: http_event_collector_instance.batch_and_push_event(each) - if(len(http_event_collector_instance.batch_events)>0): - http_event_collector_instance.push_event() - http_event_collector_instance.batch_events = [] - http_event_collector_instance.current_byte_length = 0 - + + http_event_collector_instance.push_events_and_cleanup() except Exception as ex: print ("Some error occurred.") traceback.print_exc() - exit() + raise ex