Code improvements for example notebooks
For Splunk ingestion, don't use the raw endpoint, but the normal JSON endpoint.
alexott committed Nov 30, 2023
1 parent da0da80 commit 30d8c51
Showing 4 changed files with 68 additions and 51 deletions.
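The gist of the change: Splunk's HTTP Event Collector exposes a raw endpoint (`/services/collector/raw`), which treats the request body as opaque text, and a JSON endpoint (`/services/collector`), which expects structured JSON events carrying their own metadata. A minimal sketch of the difference, assuming a reachable HEC instance; the host, token, channel, and sourcetype below are placeholders:

```python
import json
import requests

# Placeholders -- substitute a real HEC host and token
HEC = "https://splunk.example.com:8088/services/collector"
HEADERS = {"Authorization": "Splunk 00000000-0000-0000-0000-000000000000"}

# Raw endpoint: body is opaque text; metadata goes in query parameters,
# and a channel identifier is required
requests.post(f"{HEC}/raw",
              params={"channel": "some-uuid", "sourcetype": "my_data"},
              headers=HEADERS, data="a raw line of text", verify=False)

# JSON endpoint: metadata and the event travel together as one JSON object
event = {"sourcetype": "my_data", "event": {"col1": "value1", "col2": 42}}
requests.post(HEC, headers=HEADERS, data=json.dumps(event), verify=False)
```

The JSON endpoint lets each event carry its own structure and metadata, which is why `push_to_splunk.py` below switches to it.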
5 changes: 3 additions & 2 deletions notebooks/source/cloudtrail_ingest.py
@@ -38,7 +38,8 @@
 checkpointPath=dbutils.widgets.get("Checkpoint Path")
 tableName=dbutils.widgets.get("Table Name")
 regionName=dbutils.widgets.get("Region Name")
-if ((cloudTrailLogsPath==None or cloudTrailLogsPath=="")or(deltaOutputPath==None or deltaOutputPath=="")or(checkpointPath==None or checkpointPath=="")or(tableName==None or tableName=="")or(regionName==None or regionName=="")):
+
+if ((cloudTrailLogsPath==None or cloudTrailLogsPath=="") or (deltaOutputPath==None or deltaOutputPath=="") or (checkpointPath==None or checkpointPath=="") or (tableName==None or tableName=="") or (regionName==None or regionName=="")):
     dbutils.notebook.exit("All parameters are mandatory. Ensure correct values of all parameters are specified.")
 
 # COMMAND ----------
@@ -170,5 +171,5 @@
 
 # COMMAND ----------
 
-create_table_query="CREATE TABLE IF NOT EXISTS "+tableName+" USING DELTA LOCATION '"+ deltaOutputPath +"'"
+create_table_query=f"CREATE TABLE IF NOT EXISTS {tableName} USING DELTA LOCATION '{deltaOutputPath}'"
 spark.sql(create_table_query)
4 changes: 2 additions & 2 deletions notebooks/source/cloudtrail_insights_ingest.py
@@ -38,7 +38,7 @@
 checkpointPath=dbutils.widgets.get("Checkpoint Path")
 tableName=dbutils.widgets.get("Table Name")
 regionName=dbutils.widgets.get("Region Name")
-if ((cloudTrailInsightsPath==None or cloudTrailInsightsPath=="")or(deltaOutputPath==None or deltaOutputPath=="")or(checkpointPath==None or checkpointPath=="")or(tableName==None or tableName=="")or(regionName==None or regionName=="")):
+if ((cloudTrailInsightsPath==None or cloudTrailInsightsPath=="") or (deltaOutputPath==None or deltaOutputPath=="") or (checkpointPath==None or checkpointPath=="") or (tableName==None or tableName=="") or (regionName==None or regionName=="")):
     dbutils.notebook.exit("All parameters are mandatory. Ensure correct values of all parameters are specified.")
 
 # COMMAND ----------
@@ -144,5 +144,5 @@
 
 # COMMAND ----------
 
-create_table_query="CREATE TABLE IF NOT EXISTS "+tableName+" USING DELTA LOCATION '"+ deltaOutputPath +"'"
+create_table_query=f"CREATE TABLE IF NOT EXISTS {tableName} USING DELTA LOCATION '{deltaOutputPath}'"
 spark.sql(create_table_query)
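Aside from the f-string cleanups, both CloudTrail notebooks share the same long None-or-empty validation chain. For comparison only (not part of this commit), the same check can be written more compactly; the sketch below reuses the variable names from `cloudtrail_ingest.py`:

```python
# Hypothetical compaction of the validation above -- not part of this commit
params = [cloudTrailLogsPath, deltaOutputPath, checkpointPath, tableName, regionName]
if any(p is None or p == "" for p in params):
    dbutils.notebook.exit("All parameters are mandatory. Ensure correct values of all parameters are specified.")
```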
10 changes: 5 additions & 5 deletions notebooks/source/pull_from_splunk.py
@@ -117,25 +117,25 @@ def __init__(self,splunk_address,splunk_port,splunk_username,splunk_namespace,ss
         self.splunk_namespace = splunk_namespace
     @property
     def auth_url(self):
-        auth_url = "https://{}:{}/services/auth/login".format(self.splunk_address,self.splunk_port)
+        auth_url = f"https://{self.splunk_address}:{self.splunk_port}/services/auth/login"
         return (auth_url)
 
     @property
     def mgmt_segment(self):
-        mgmt_segment_part = "https://{}:{}/servicesNS/{}/{}/".format(self.splunk_address,self.splunk_port,self.splunk_username,self.splunk_namespace)
+        mgmt_segment_part = f"https://{self.splunk_address}:{self.splunk_port}/servicesNS/{self.splunk_username}/{self.splunk_namespace}/"
         return (mgmt_segment_part)
 
     def connect(self,splunk_password):
         try:
             response = requests.post(
                 self.auth_url,
-                data={"username":self.splunk_username,
-                      "password":splunk_password},verify=self.ssl_verify)
+                data={"username": self.splunk_username, "password": splunk_password},
+                verify=self.ssl_verify)
             session = XML(response.text).findtext("./sessionKey")
             if (session=="" or session==None):
                 dbutils.notebook.exit("Issue in Authentication : Type - "+XML(response.text).find("./messages/msg").attrib["type"]+"\n Message - "+XML(response.text).findtext("./messages/msg"))
             else :
-                self.token = "Splunk {}".format(session)
+                self.token = f"Splunk {session}"
         except HTTPError as e:
             if e.status == 401:
                 raise AuthenticationError("Login failed.", e)
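For context on `pull_from_splunk.py`: `connect` logs in via Splunk's REST API on the management port and extracts a session key, which then authorizes subsequent calls. A standalone sketch of that flow (the host, port, credentials, and the follow-up endpoint are placeholders chosen for illustration):

```python
import requests
from xml.etree.ElementTree import XML

BASE = "https://splunk.example.com:8089"  # placeholder management host/port

# Log in and extract the session key, as connect() does above
resp = requests.post(f"{BASE}/services/auth/login",
                     data={"username": "admin", "password": "changeme"},
                     verify=False)
session_key = XML(resp.text).findtext("./sessionKey")

# The key goes into an Authorization header on later REST calls,
# e.g. listing search jobs
headers = {"Authorization": f"Splunk {session_key}"}
jobs = requests.get(f"{BASE}/services/search/jobs", headers=headers, verify=False)
print(jobs.status_code)
```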
100 changes: 58 additions & 42 deletions notebooks/source/push_to_splunk.py
@@ -1,6 +1,12 @@
 # Databricks notebook source
 # MAGIC %md
+# MAGIC
+# MAGIC ## Introduction
+# MAGIC
+# MAGIC This notebook demonstrates how to push data to Splunk using Splunk's HTTP Event Collector. ***Please note that it is intended only for sending a limited number of events (dozens or hundreds) to Splunk, not for sending large amounts of data, as it collects all data onto the driver node!***
+# MAGIC
 # MAGIC ## Input parameters from the user
 # MAGIC
 # MAGIC This gives a brief description of each parameter.
 # MAGIC <br>Before using this notebook, please go through its user documentation to use it effectively.
 # MAGIC 1. **Protocol** ***(Mandatory Field)*** : The protocol on which the Splunk HTTP Event Collector (HEC) runs. Splunk HEC runs on `https` if the Enable SSL checkbox is selected while configuring the Splunk HEC token in Splunk; otherwise it runs on `http`. If you do not have access to the Splunk HEC configuration page, you can ask your Splunk admin whether the `Enable SSL` checkbox is selected.
@@ -119,13 +125,23 @@
 import traceback
 import os
 import uuid
+import copy
+from datetime import date, datetime
 requests.packages.urllib3.disable_warnings()
 
 
+def json_serializer(obj):
+    if isinstance(obj, (date, datetime)):
+        return obj.isoformat()
+    return str(obj)
+
 class HttpEventCollector:
 
     maxByteLength = 1000000
 
-    def __init__(self,protocol,splunk_address,splunk_port,splunk_hec_token,index,source,sourcetype,host,ssl_verify="false"):
+    def __init__(self, protocol, splunk_address, splunk_port, splunk_hec_token,
+                 index, source, sourcetype, host, ssl_verify="false"):
 
         self.protocol = protocol
         if (splunk_address=="" or splunk_address==None):
@@ -153,7 +169,15 @@ def __init__(self,protocol,splunk_address,splunk_port,splunk_hec_token,index,source,sourcetype,host,ssl_verify="false"):
         self.host = host
         self.batch_events = []
         self.current_byte_length = 0
-
+        self.parameters = {}
+        if self.sourcetype:
+            self.parameters["sourcetype"] = self.sourcetype
+        if self.source:
+            self.parameters["source"] = self.source
+        if self.index:
+            self.parameters["index"] = self.index
+        if self.host:
+            self.parameters["host"] = self.host
 
     def requests_retry_session(self,retries=3):
         session = requests.Session()
@@ -165,82 +189,74 @@ def requests_retry_session(self,retries=3):
     @property
     def server_uri(self):
         # splunk HEC url used to push data
-        endpoint="/raw?channel="+str(uuid.uuid1())
-        server_uri = '%s://%s:%s/services/collector%s' % (self.protocol, self.splunk_address, self.splunk_port, endpoint)
-        return (server_uri)
-
-    @property
-    def parameters(self):
-        params={}
-        if not( self.sourcetype == None or self.sourcetype == ""):
-            params.update({"sourcetype":self.sourcetype})
-        if not( self.source == None or self.source == ""):
-            params.update({"source":self.source})
-        if not( self.index == None or self.index == ""):
-            params.update({"index":self.index})
-        if not( self.host == None or self.host == ""):
-            params.update({"host":self.host})
-        return (params)
+        endpoint=f"?channel={uuid.uuid1()}"
+        server_uri = f'{self.protocol}://{self.splunk_address}:{self.splunk_port}/services/collector{endpoint}'
+        return (server_uri)

     def batch_and_push_event(self,event):
         # divide the result payload into batches and push to splunk HEC
-        payload_string = str(event)
-        if not payload_string.endswith("\n"):
-            payload_string=payload_string+"\n"
+        data = copy.copy(self.parameters)
+        data["event"] = event.asDict(True)
+        payload_string = json.dumps(data, default=json_serializer)
         payload_length = len(payload_string)
 
         if ((self.current_byte_length+payload_length) > self.maxByteLength ):
-            self.push_event()
-            self.batch_events = []
-            self.current_byte_length = 0
+            self.push_events_and_cleanup()
 
         self.batch_events.append(payload_string)
         self.current_byte_length += payload_length
 
-    def push_event(self):
+    def push_events_and_cleanup(self):
+        self.push_events()
+        self.batch_events = []
+        self.current_byte_length = 0
+
+    def push_events(self):
+        if len(self.batch_events) == 0:
+            # Nothing to push
+            return
+
         # Function to push data to splunk
-        payload = " ".join(self.batch_events)
+        payload = "\n".join(self.batch_events)
         headers = {'Authorization':'Splunk '+self.token}
-        response = self.requests_retry_session().post(self.server_uri, data=payload, headers=headers,params=self.parameters, verify=self.ssl_verify)
+        response = self.requests_retry_session().post(self.server_uri, data=payload, headers=headers, verify=self.ssl_verify)
         if not (response.status_code==200 or response.status_code==201) :
             raise Exception("Response status : {} .Response message : {}".format(str(response.status_code),response.text))
 
 
 # COMMAND ----------
 
 from pyspark.sql.functions import *
 
 if(advancedQuery):
     full_query=advancedQuery
 elif (table and database):
-    basic_query="select * from "+database+"."+table+" "
+    basic_query = f"select * from {database}.{table} "
     if (filterQuery == None or filterQuery == "" ) :
-        full_query=basic_query
+        full_query = basic_query
     else :
         full_query = basic_query+filterQuery
 else:
     dbutils.notebook.exit("Advanced Query or Table name and Database name are required. Please check input values.")
 try :
     read_data=spark.sql(full_query)
-    events_list=read_data.toJSON().collect()
+    events_list=read_data.collect()
 
 except Exception as e:
-    print ("Some error occurred while running query. The filter may be incorrect : ".format(e))
+    print(f"Some error occurred while running query. The filter may be incorrect : {e}")
     traceback.print_exc()
-    exit()
+    raise e

 try :
-    http_event_collector_instance=HttpEventCollector(protocol,splunkAddress,splunkPort,splunkHecToken,index,source,sourcetype,host,ssl_verify=sslVerify)
+    http_event_collector_instance=HttpEventCollector(protocol,splunkAddress,splunkPort,
+                                                     splunkHecToken,index,source,sourcetype,
+                                                     host,ssl_verify=sslVerify)
 
     for each in events_list:
         http_event_collector_instance.batch_and_push_event(each)
-    if(len(http_event_collector_instance.batch_events)>0):
-        http_event_collector_instance.push_event()
-        http_event_collector_instance.batch_events = []
-        http_event_collector_instance.current_byte_length = 0
+    http_event_collector_instance.push_events_and_cleanup()
 
 except Exception as ex:
     print ("Some error occurred.")
     traceback.print_exc()
-    exit()
+    raise ex
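End to end, the reworked notebook merges each Spark `Row` with the precomputed metadata into a single JSON object and posts batches newline-delimited to the JSON endpoint. A self-contained sketch of that batching step, with plain dicts standing in for `Row` objects and placeholder metadata values:

```python
import json
from datetime import datetime

# Placeholder metadata, mirroring self.parameters in HttpEventCollector
parameters = {"sourcetype": "my_sourcetype", "index": "main"}

def json_serializer(obj):
    # Same fallback as the notebook: ISO format for dates, str() otherwise
    return obj.isoformat() if isinstance(obj, datetime) else str(obj)

# Plain dicts stand in for Spark Rows here
rows = [{"eventTime": datetime(2023, 11, 30, 12, 0), "eventName": "ConsoleLogin"},
        {"eventTime": datetime(2023, 11, 30, 12, 5), "eventName": "AssumeRole"}]

batch = []
for row in rows:
    data = dict(parameters)   # copy.copy(self.parameters) in the notebook
    data["event"] = row       # row.asDict(True) for a real Spark Row
    batch.append(json.dumps(data, default=json_serializer))

# HEC's JSON endpoint accepts several events concatenated in one request body
payload = "\n".join(batch)
print(payload)
```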
