Code improvements for example notebooks #47

Closed · wants to merge 1 commit
5 changes: 3 additions & 2 deletions notebooks/source/cloudtrail_ingest.py
@@ -38,7 +38,8 @@
 checkpointPath=dbutils.widgets.get("Checkpoint Path")
 tableName=dbutils.widgets.get("Table Name")
 regionName=dbutils.widgets.get("Region Name")
-if ((cloudTrailLogsPath==None or cloudTrailLogsPath=="")or(deltaOutputPath==None or deltaOutputPath=="")or(checkpointPath==None or checkpointPath=="")or(tableName==None or tableName=="")or(regionName==None or regionName=="")):
+
+if ((cloudTrailLogsPath==None or cloudTrailLogsPath=="") or (deltaOutputPath==None or deltaOutputPath=="") or (checkpointPath==None or checkpointPath=="") or (tableName==None or tableName=="") or (regionName==None or regionName=="")):
     dbutils.notebook.exit("All parameters are mandatory. Ensure correct values of all parameters are specified.")
 
 # COMMAND ----------
@@ -170,5 +171,5 @@

 # COMMAND ----------
 
-create_table_query="CREATE TABLE IF NOT EXISTS "+tableName+" USING DELTA LOCATION '"+ deltaOutputPath +"'"
+create_table_query=f"CREATE TABLE IF NOT EXISTS {tableName} USING DELTA LOCATION '{deltaOutputPath}'"
 spark.sql(create_table_query)
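Editor's note: the spaced-out condition reads better, but the five repeated None-or-empty checks could also collapse into a single loop. A minimal sketch, not part of this PR; the first two widget labels are assumptions inferred from the variable names, since only "Checkpoint Path", "Table Name", and "Region Name" appear in this hunk:

```python
# Hypothetical refactor: validate all mandatory widgets in one pass.
required = ["CloudTrail Logs Path", "Delta Output Path",  # assumed labels
            "Checkpoint Path", "Table Name", "Region Name"]
values = {name: dbutils.widgets.get(name) for name in required}
missing = [name for name, value in values.items() if not value]  # catches None and ""
if missing:
    dbutils.notebook.exit("All parameters are mandatory. Missing: " + ", ".join(missing))
```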
4 changes: 2 additions & 2 deletions notebooks/source/cloudtrail_insights_ingest.py
@@ -38,7 +38,7 @@
 checkpointPath=dbutils.widgets.get("Checkpoint Path")
 tableName=dbutils.widgets.get("Table Name")
 regionName=dbutils.widgets.get("Region Name")
-if ((cloudTrailInsightsPath==None or cloudTrailInsightsPath=="")or(deltaOutputPath==None or deltaOutputPath=="")or(checkpointPath==None or checkpointPath=="")or(tableName==None or tableName=="")or(regionName==None or regionName=="")):
+if ((cloudTrailInsightsPath==None or cloudTrailInsightsPath=="") or (deltaOutputPath==None or deltaOutputPath=="") or (checkpointPath==None or checkpointPath=="") or (tableName==None or tableName=="") or (regionName==None or regionName=="")):
     dbutils.notebook.exit("All parameters are mandatory. Ensure correct values of all parameters are specified.")
 
 # COMMAND ----------
@@ -144,5 +144,5 @@

 # COMMAND ----------
 
-create_table_query="CREATE TABLE IF NOT EXISTS "+tableName+" USING DELTA LOCATION '"+ deltaOutputPath +"'"
+create_table_query=f"CREATE TABLE IF NOT EXISTS {tableName} USING DELTA LOCATION '{deltaOutputPath}'"
 spark.sql(create_table_query)
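The same f-string DDL pattern now appears in both ingest notebooks. Since tableName is interpolated directly into SQL, a light identifier check before building the statement may be worth considering; a hedged sketch, not part of this PR:

```python
import re

# Hypothetical guard: only allow simple (optionally dot-qualified) identifiers
# before interpolating the table name into DDL.
if not re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*(\.[A-Za-z_][A-Za-z0-9_]*)*", tableName):
    dbutils.notebook.exit(f"Invalid table name: {tableName}")
create_table_query = f"CREATE TABLE IF NOT EXISTS {tableName} USING DELTA LOCATION '{deltaOutputPath}'"
spark.sql(create_table_query)
```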
10 changes: 5 additions & 5 deletions notebooks/source/pull_from_splunk.py
@@ -117,25 +117,25 @@ def __init__(self,splunk_address,splunk_port,splunk_username,splunk_namespace,ss
         self.splunk_namespace = splunk_namespace
     @property
     def auth_url(self):
-        auth_url = "https://{}:{}/services/auth/login".format(self.splunk_address,self.splunk_port)
+        auth_url = f"https://{self.splunk_address}:{self.splunk_port}/services/auth/login"
         return (auth_url)
 
     @property
     def mgmt_segment(self):
-        mgmt_segment_part = "https://{}:{}/servicesNS/{}/{}/".format(self.splunk_address,self.splunk_port,self.splunk_username,self.splunk_namespace)
+        mgmt_segment_part = f"https://{self.splunk_address}:{self.splunk_port}/servicesNS/{self.splunk_username}/{self.splunk_namespace}/"
         return (mgmt_segment_part)
 
     def connect(self,splunk_password):
         try:
             response = requests.post(
                 self.auth_url,
-                data={"username":self.splunk_username,
-                      "password":splunk_password},verify=self.ssl_verify)
+                data={"username": self.splunk_username, "password": splunk_password},
+                verify=self.ssl_verify)
             session = XML(response.text).findtext("./sessionKey")
             if (session=="" or session==None):
                 dbutils.notebook.exit("Issue in Authentication : Type - "+XML(response.text).find("./messages/msg").attrib["type"]+"\n Message - "+XML(response.text).findtext("./messages/msg"))
             else :
-                self.token = "Splunk {}".format(session)
+                self.token = f"Splunk {session}"
         except HTTPError as e:
             if e.status == 401:
                 raise AuthenticationError("Login failed.", e)
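For context, the connect() flow touched above boils down to: POST credentials to Splunk's /services/auth/login endpoint, read the sessionKey from the XML response, and keep "Splunk &lt;key&gt;" as the Authorization token. A standalone sketch under those assumptions; host, port, and credentials are placeholders:

```python
import requests
from xml.etree.ElementTree import XML

splunk_host, splunk_port = "splunk.example.com", "8089"  # placeholders (8089 is Splunk's default management port)
username, password = "admin", "changeme"                 # placeholders

auth_url = f"https://{splunk_host}:{splunk_port}/services/auth/login"
response = requests.post(auth_url, data={"username": username, "password": password},
                         verify=False)  # mirrors the notebook's configurable ssl_verify
session_key = XML(response.text).findtext("./sessionKey")
if not session_key:
    raise RuntimeError("Authentication failed: no sessionKey in response")
token = f"Splunk {session_key}"  # value for the Authorization header on later calls
```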
100 changes: 58 additions & 42 deletions notebooks/source/push_to_splunk.py
@@ -1,6 +1,12 @@
 # Databricks notebook source
 # MAGIC %md
+# MAGIC
+# MAGIC ## Introduction
+# MAGIC
+# MAGIC This notebook demonstrates how to push data to Splunk using Splunk's HTTP Event Collector. ***Please note that it is intended only for sending a limited number of events (dozens to hundreds) to Splunk, not for sending large amounts of data, as it collects all data on the driver node!***
+# MAGIC
 # MAGIC ## Input parameters from the user
 # MAGIC
 # MAGIC This section gives a brief description of each parameter.
 # MAGIC <br>Before using the notebook, please go through its user documentation to use it effectively.
 # MAGIC 1. **Protocol** ***(Mandatory Field)*** : The protocol on which the Splunk HTTP Event Collector (HEC) runs. Splunk HEC runs on `https` if the Enable SSL checkbox is selected while configuring the Splunk HEC token in Splunk; otherwise it runs on `http`. If you do not have access to the Splunk HEC configuration page, ask your Splunk admin whether the `Enable SSL` checkbox is selected.
@@ -119,13 +125,23 @@
 import traceback
 import os
 import uuid
+import copy
+from datetime import date, datetime
 requests.packages.urllib3.disable_warnings()
 
 
+def json_serializer(obj):
+    if isinstance(obj, (date, datetime)):
+        return obj.isoformat()
+    return str(obj)
+
 class HttpEventCollector:
 
     maxByteLength = 1000000
 
-    def __init__(self,protocol,splunk_address,splunk_port,splunk_hec_token,index,source,sourcetype,host,ssl_verify="false"):
+    def __init__(self, protocol, splunk_address, splunk_port, splunk_hec_token,
+                 index, source, sourcetype, host, ssl_verify="false"):
 
         self.protocol = protocol
         if (splunk_address=="" or splunk_address==None):
@@ -153,7 +169,15 @@ def __init__(self,protocol,splunk_address,splunk_port,splunk_hec_token,index,sou
         self.host = host
         self.batch_events = []
         self.current_byte_length = 0
 
+        self.parameters = {}
+        if self.sourcetype:
+            self.parameters["sourcetype"] = self.sourcetype
+        if self.source:
+            self.parameters["source"] = self.source
+        if self.index:
+            self.parameters["index"] = self.index
+        if self.host:
+            self.parameters["host"] = self.host
 
     def requests_retry_session(self,retries=3):
         session = requests.Session()
@@ -165,82 +189,74 @@ def requests_retry_session(self,retries=3):
     @property
     def server_uri(self):
         # splunk HEC url used to push data
-        endpoint="/raw?channel="+str(uuid.uuid1())
-        server_uri = '%s://%s:%s/services/collector%s' % (self.protocol, self.splunk_address, self.splunk_port, endpoint)
-        return (server_uri)
-
-    @property
-    def parameters(self):
-        params={}
-        if not( self.sourcetype == None or self.sourcetype == ""):
-            params.update({"sourcetype":self.sourcetype})
-        if not( self.source == None or self.source == ""):
-            params.update({"source":self.source})
-        if not( self.index == None or self.index == ""):
-            params.update({"index":self.index})
-        if not( self.host == None or self.host == ""):
-            params.update({"host":self.host})
-        return (params)
+        endpoint=f"?channel={uuid.uuid1()}"
+        server_uri = f'{self.protocol}://{self.splunk_address}:{self.splunk_port}/services/collector{endpoint}'
+        return (server_uri)

     def batch_and_push_event(self,event):
         # divide the result payload into batches and push to splunk HEC
-        payload_string = str(event)
-        if not payload_string.endswith("\n"):
-            payload_string=payload_string+"\n"
+        data = copy.copy(self.parameters)
+        data["event"] = event.asDict(True)
+        payload_string = json.dumps(data, default=json_serializer)
         payload_length = len(payload_string)
 
         if ((self.current_byte_length+payload_length) > self.maxByteLength ):
-            self.push_event()
-            self.batch_events = []
-            self.current_byte_length = 0
+            self.push_events_and_cleanup()
 
         self.batch_events.append(payload_string)
         self.current_byte_length += payload_length

-    def push_event(self):
+    def push_events_and_cleanup(self):
+        self.push_events()
+        self.batch_events = []
+        self.current_byte_length = 0
+
+    def push_events(self):
+        if len(self.batch_events) == 0:
+            # Nothing to push
+            return
+
         # Function to push data to splunk
-        payload = " ".join(self.batch_events)
+        payload = "\n".join(self.batch_events)
         headers = {'Authorization':'Splunk '+self.token}
-        response = self.requests_retry_session().post(self.server_uri, data=payload, headers=headers,params=self.parameters, verify=self.ssl_verify)
+        response = self.requests_retry_session().post(self.server_uri, data=payload, headers=headers, verify=self.ssl_verify)
         if not (response.status_code==200 or response.status_code==201) :
             raise Exception("Response status : {} .Response message : {}".format(str(response.status_code),response.text))



 # COMMAND ----------
 
 from pyspark.sql.functions import *
 
 if(advancedQuery):
     full_query=advancedQuery
 elif (table and database):
-    basic_query="select * from "+database+"."+table+" "
+    basic_query = f"select * from {database}.{table} "
     if (filterQuery == None or filterQuery == "" ) :
-        full_query=basic_query
+        full_query = basic_query
     else :
         full_query = basic_query+filterQuery
 else:
     dbutils.notebook.exit("Advanced Query or Table name and Database name are required. Please check input values.")
 try :
     read_data=spark.sql(full_query)
-    events_list=read_data.toJSON().collect()
+    events_list=read_data.collect()
 
 except Exception as e:
-    print ("Some error occurred while running query. The filter may be incorrect : ".format(e))
+    print(f"Some error occurred while running query. The filter may be incorrect : {e}")
     traceback.print_exc()
-    exit()
+    raise e

 try :
-    http_event_collector_instance=HttpEventCollector(protocol,splunkAddress,splunkPort,splunkHecToken,index,source,sourcetype,host,ssl_verify=sslVerify)
+    http_event_collector_instance=HttpEventCollector(protocol,splunkAddress,splunkPort,
+                                                     splunkHecToken,index,source,sourcetype,
+                                                     host,ssl_verify=sslVerify)
 
     for each in events_list:
         http_event_collector_instance.batch_and_push_event(each)
-    if(len(http_event_collector_instance.batch_events)>0):
-        http_event_collector_instance.push_event()
-        http_event_collector_instance.batch_events = []
-        http_event_collector_instance.current_byte_length = 0
+    http_event_collector_instance.push_events_and_cleanup()
 
 except Exception as ex:
     print ("Some error occurred.")
     traceback.print_exc()
-    exit()
+    raise ex
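Putting the push_to_splunk changes together: rows are collected to the driver, each row is serialized as a JSON HEC event carrying the index/source/sourcetype/host metadata, events are batched up to maxByteLength, and push_events_and_cleanup() flushes the remainder. A hedged usage sketch of the reworked class; all parameter values are placeholders, and 8088 is only Splunk's default HEC port:

```python
# Hypothetical usage, assuming the HttpEventCollector defined above is in scope
# on a Databricks cluster (spark is the session provided by the runtime).
collector = HttpEventCollector(protocol="https", splunk_address="splunk.example.com",
                               splunk_port="8088", splunk_hec_token="<hec-token>",
                               index="main", source="databricks", sourcetype="_json",
                               host="demo-cluster", ssl_verify=False)

rows = spark.sql("select * from my_db.my_table limit 100").collect()  # small result sets only
for row in rows:
    collector.batch_and_push_event(row)  # batches until maxByteLength, then auto-flushes
collector.push_events_and_cleanup()      # flush whatever is left in the final batch
```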