From eb3eb4cb76702349e83d39b90d6915c63da325cd Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Fri, 11 Oct 2024 03:58:09 +0100 Subject: [PATCH 01/35] refactor provider form component --- .../app/providers/provider-form-scopes.tsx | 29 +- keep-ui/app/providers/provider-form.tsx | 1040 ++++++++++------- keep-ui/app/providers/providers-tiles.tsx | 49 +- keep-ui/app/providers/providers.tsx | 13 +- keep-ui/app/workflows/workflow-tile.tsx | 33 - 5 files changed, 642 insertions(+), 522 deletions(-) diff --git a/keep-ui/app/providers/provider-form-scopes.tsx b/keep-ui/app/providers/provider-form-scopes.tsx index 6820872b4..d5b60130e 100644 --- a/keep-ui/app/providers/provider-form-scopes.tsx +++ b/keep-ui/app/providers/provider-form-scopes.tsx @@ -22,29 +22,24 @@ import "./provider-form-scopes.css"; const ProviderFormScopes = ({ provider, validatedScopes, - installedProvidersMode = false, refreshLoading, - triggerRevalidateScope, + onRevalidate, }: { provider: Provider; validatedScopes: { [key: string]: string | boolean }; - installedProvidersMode?: boolean; refreshLoading: boolean; - triggerRevalidateScope: any; + onRevalidate: () => void; }) => { return ( Scopes - {installedProvidersMode && ( + {provider.installed && ( - - handleDictInputChange(configKey, value)} - error={Object.keys(inputErrors).includes(configKey)} - disabled={provider.provisioned} - /> - - ); - case "file": - return ( - <> - {renderFieldHeader()} - - { - if (e.target.files && e.target.files[0]) { - setSelectedFile(e.target.files[0].name); - } - handleInputChange(e); - }} - disabled={provider.provisioned} - /> - - ); - default: - return ( - <> - {renderFieldHeader()} - - - ); - } - }; - - - - const requiredConfigs = Object.entries(provider.config) - .filter(([_, config]) => config.required && !config.config_main_group) - .reduce((acc, [key, value]) => ({ ...acc, [key]: value }), {}); - - const optionalConfigs = Object.entries(provider.config) - .filter(([_, config]) => !config.required && !config.hidden && !config.config_main_group) - .reduce((acc, [key, value]) => ({ ...acc, [key]: value }), {}); - - const groupConfigsByMainGroup = (configs) => { - return Object.entries(configs).reduce((acc, [key, config]) => { - const mainGroup = config.config_main_group; - if (mainGroup) { - if (!acc[mainGroup]) { - acc[mainGroup] = {}; - } - acc[mainGroup][key] = config; - } - return acc; - }, {}); - }; - - const groupConfigsBySubGroup = (configs) => { - return Object.entries(configs).reduce((acc, [key, config]) => { - const subGroup = config.config_sub_group || 'default'; - if (!acc[subGroup]) { - acc[subGroup] = {}; - } - acc[subGroup][key] = config; - return acc; - }, {}); - }; - - const getSubGroups = (configs) => { - return [...new Set(Object.values(configs).map(config => config.config_sub_group))].filter(Boolean); - }; - - const renderGroupFields = (groupName, groupConfigs) => { - const subGroups = groupConfigsBySubGroup(groupConfigs); - const subGroupNames = getSubGroups(groupConfigs); - - if (subGroupNames.length === 0) { - // If no subgroups, render fields directly - return ( - - {groupName.charAt(0).toUpperCase() + groupName.slice(1)} - {Object.entries(groupConfigs).map(([configKey, config]) => ( -
- {renderFormField(configKey, config)} -
- ))} -
- ); - } - - return ( - - {groupName.charAt(0).toUpperCase() + groupName.slice(1)} - setActiveTabsState(prev => ({...prev, [groupName]: subGroupNames[index]}))} - > - - {subGroupNames.map((subGroup) => ( - {subGroup.replace('_', ' ').toUpperCase()} - ))} - - - {subGroupNames.map((subGroup) => ( - - {Object.entries(subGroups[subGroup] || {}).map(([configKey, config]) => ( -
- {renderFormField(configKey, config)} -
- ))} -
- ))} -
-
-
- ); - }; - - const groupedConfigs = groupConfigsByMainGroup(provider.config); - console.log("ProviderForm component loaded"); return (
@@ -721,9 +475,10 @@ const ProviderForm = ({
- {provider.provisioned && + {provider.provisioned && (
- } + )} {provider.provider_description && ( {provider.provider_description} @@ -770,13 +525,12 @@ const ProviderForm = ({ />
)} - {provider.scopes?.length > 0 && ( + {provider.scopes && provider.scopes.length > 0 && ( )}
@@ -784,6 +538,7 @@ const ProviderForm = ({ {provider.oauth2_url && !provider.installed ? ( <> {installedProvidersMode && Object.keys(provider.config).length > 0 && ( <> -
@@ -985,4 +762,407 @@ const ProviderForm = ({ ); }; +function GroupFields({ + groupName, + fields, + data, + errors, + disabled, + onChange, +}: { + groupName: string; + fields: Provider["config"]; + data: ProviderFormData; + errors: InputErrors; + disabled: boolean; + onChange: (key: string, value: ProviderFormValue) => void; +}) { + const subGroups = useMemo(() => getConfigBySubGroup(fields), [fields]); + + if (Object.keys(subGroups).length === 0) { + // If no subgroups, render fields directly + return ( + + {groupName} + {Object.entries(fields).map(([field, config]) => ( +
+ +
+ ))} +
+ ); + } + + return ( + + {groupName} + + + {Object.keys(subGroups).map((name) => ( + + {name} + + ))} + + + {Object.entries(subGroups).map(([name, subGroup]) => ( + + {Object.entries(subGroup).map(([field, config]) => ( +
+ +
+ ))} +
+ ))} +
+
+
+ ); +} + +function FormField({ + id, + config, + value, + error, + disabled, + title, + onChange, +}: { + id: string; + config: ProviderAuthConfig; + value: ProviderFormValue; + error?: string; + disabled: boolean; + title?: string; + onChange: (key: string, value: ProviderFormValue) => void; +}) { + function handleInputChange(event: React.ChangeEvent) { + let value; + const files = event.target.files; + const name = event.target.name; + + // If the input is a file, retrieve the file object, otherwise retrieve the value + if (files && files.length > 0) { + value = files[0]; // Assumes single file upload + } else { + value = event.target.value; + } + + onChange(name, value); + } + + switch (config.type) { + case "select": + return ( + onChange(id, value)} + /> + ); + case "form": + return ( + onChange(id, data)} + onChange={(value) => onChange(id, value)} + /> + ); + case "file": + return ( + + ); + default: + return ( + + ); + } +} + +function TextField({ + id, + config, + value, + error, + disabled, + title, + onChange, +}: { + id: string; + config: ProviderAuthConfig; + value: ProviderFormValue; + error?: string; + disabled: boolean; + title?: string; + onChange: (e: React.ChangeEvent) => void; +}) { + return ( + <> + + + + ); +} + +function SelectField({ + id, + config, + value, + error, + disabled, + onChange, +}: { + id: string; + config: ProviderAuthConfig; + value: ProviderFormValue; + error?: string; + disabled: boolean; + onChange: (value: string) => void; +}) { + return ( + <> + + + + ); +} + +function FileField({ + id, + config, + disabled, + onChange, +}: { + id: string; + config: ProviderAuthConfig; + disabled: boolean; + onChange: (e: React.ChangeEvent) => void; +}) { + const [selected, setSelected] = useState(); + const ref = useRef(null); + + function handleClick(e: React.MouseEvent) { + e.preventDefault(); + if (ref.current) ref.current.click(); + } + return ( + <> + + + { + if (e.target.files && e.target.files[0]) { + setSelected(e.target.files[0].name); + } + onChange(e); + }} + disabled={disabled} + /> + + ); +} + +function KVForm({ + id, + config, + value, + error, + disabled, + onAdd, + onChange, +}: { + id: string; + config: ProviderAuthConfig; + value: ProviderFormValue; + error?: string; + disabled: boolean; + onAdd: (data: KVFormData) => void; + onChange: (value: KVFormData) => void; +}) { + function handleAdd() { + const newData = Array.isArray(value) + ? [...value, { key: "", value: "" }] + : [{ key: "", value: "" }]; + onAdd(newData); + } + + return ( +
+
+ + +
+ {Array.isArray(value) && ( + + )} +
+ ); +} + +const KVInput = ({ + name, + data, + onChange, + error, +}: { + name: string; + data: KVFormData; + onChange: (entries: KVFormData) => void; + error?: string; +}) => { + const handleEntryChange = (index: number, name: string, value: string) => { + const newEntries = data.map((entry, i) => + i === index ? { ...entry, [name]: value } : entry + ); + onChange(newEntries); + }; + + const removeEntry = (index: number) => { + const newEntries = data.filter((_, i) => i !== index); + onChange(newEntries); + }; + + return ( +
+ {data.map((entry, index) => ( +
+ handleEntryChange(index, "key", e.target.value)} + placeholder="Key" + className="mr-2" + /> + handleEntryChange(index, "value", e.target.value)} + placeholder="Value" + className="mr-2" + /> +
+ ))} +
+ ); +}; + +function FieldLabel({ + id, + config, +}: { + id: string; + config: ProviderAuthConfig; +}) { + return ( + + ); +} + export default ProviderForm; diff --git a/keep-ui/app/providers/providers-tiles.tsx b/keep-ui/app/providers/providers-tiles.tsx index 2a6088ecf..d09e9e964 100644 --- a/keep-ui/app/providers/providers-tiles.tsx +++ b/keep-ui/app/providers/providers-tiles.tsx @@ -26,12 +26,9 @@ const ProvidersTiles = ({ }) => { const searchParams = useSearchParams(); const [openPanel, setOpenPanel] = useState(false); - const [panelSize, setPanelSize] = useState(40); const [selectedProvider, setSelectedProvider] = useState( null ); - const [formValues, setFormValues] = useState<{ [key: string]: string }>({}); - const [formErrors, setFormErrors] = useState<{ [key: string]: string }>({}); const providerType = searchParams?.get("provider_type"); const providerName = searchParams?.get("provider_name"); @@ -45,53 +42,21 @@ const ProvidersTiles = ({ if (provider) { setSelectedProvider(provider); - if (providerName) { - setFormValues({ - provider_name: providerName, - }); - } setOpenPanel(true); } } }, [providerType, providerName, providers]); - useEffect(() => { - const pageWidth = window.innerWidth; - - if (pageWidth < 640) { - setPanelSize(100); - } else { - setPanelSize(40); - } - }, [openPanel]); - - const handleFormChange = ( - updatedFormValues: Record, - updatedFormErrors: Record - ) => { - setFormValues(updatedFormValues); - setFormErrors(updatedFormErrors); - }; - const handleConnectProvider = (provider: Provider) => { // on linked providers, don't open the modal if (provider.linked) return; - setSelectedProvider(provider); - if (installedProvidersMode) { - setFormValues({ - provider_name: provider.details.name!, - ...provider.details?.authentication, - }); - } setOpenPanel(true); }; const handleCloseModal = () => { setOpenPanel(false); setSelectedProvider(null); - setFormValues({}); - setFormErrors({}); }; const handleConnecting = (isConnecting: boolean, isConnected: boolean) => { @@ -111,12 +76,17 @@ const ProvidersTiles = ({ }; const sortedProviders = providers - .filter(provider => Object.keys(provider.config || {}).length > 0 || (provider.tags && provider.tags.includes('alert'))) + .filter( + (provider) => + Object.keys(provider.config || {}).length > 0 || + (provider.tags && provider.tags.includes("alert")) + ) .sort( (a, b) => Number(b.can_setup_webhook) - Number(a.can_setup_webhook) || Number(b.supports_webhook) - Number(a.supports_webhook) || - Number(b.oauth2_url ? true : false) - Number(a.oauth2_url ? true : false) + Number(b.oauth2_url ? true : false) - + Number(a.oauth2_url ? true : false) ); return ( @@ -148,16 +118,13 @@ const ProvidersTiles = ({ {selectedProvider && ( ; sensitive?: boolean; hidden?: boolean; type?: string; file_type?: string; + config_main_group?: string; + config_sub_group?: string; } export interface ProviderMethodParam { @@ -51,7 +54,13 @@ interface AlertDistritbuionData { number: number; } -export type TProviderLabels = 'alert' | 'topology' | 'messaging' | 'ticketing' | 'data' | 'queue'; +export type TProviderLabels = + | "alert" + | "topology" + | "messaging" + | "ticketing" + | "data" + | "queue"; export interface Provider { // key value pair of auth method name and auth method config diff --git a/keep-ui/app/workflows/workflow-tile.tsx b/keep-ui/app/workflows/workflow-tile.tsx index 7e3da37ba..b29ba7de8 100644 --- a/keep-ui/app/workflows/workflow-tile.tsx +++ b/keep-ui/app/workflows/workflow-tile.tsx @@ -272,8 +272,6 @@ function WorkflowTile({ workflow }: { workflow: Workflow }) { const [selectedProvider, setSelectedProvider] = useState( null ); - const [formValues, setFormValues] = useState<{ [key: string]: string }>({}); - const [formErrors, setFormErrors] = useState<{ [key: string]: string }>({}); const [openTriggerModal, setOpenTriggerModal] = useState(false); const alertSource = workflow?.triggers @@ -293,23 +291,12 @@ function WorkflowTile({ workflow }: { workflow: Workflow }) { const handleConnectProvider = (provider: FullProvider) => { setSelectedProvider(provider); // prepopulate it with the name - setFormValues({ provider_name: provider.details.name || "" }); setOpenPanel(true); }; const handleCloseModal = () => { setOpenPanel(false); setSelectedProvider(null); - setFormValues({}); - setFormErrors({}); - }; - // Function to handle form change - const handleFormChange = ( - updatedFormValues: Record, - updatedFormErrors: Record - ) => { - setFormValues(updatedFormValues); - setFormErrors(updatedFormErrors); }; const handleDeleteClick = async () => { @@ -672,9 +659,6 @@ function WorkflowTile({ workflow }: { workflow: Workflow }) { {selectedProvider && ( ( null ); - const [formValues, setFormValues] = useState<{ [key: string]: string }>({}); - const [formErrors, setFormErrors] = useState<{ [key: string]: string }>({}); const { providers } = useFetchProviders(); const { @@ -741,24 +723,12 @@ export function WorkflowTileOld({ workflow }: { workflow: Workflow }) { const handleConnectProvider = (provider: FullProvider) => { setSelectedProvider(provider); - // prepopulate it with the name - setFormValues({ provider_name: provider.details.name || "" }); setOpenPanel(true); }; const handleCloseModal = () => { setOpenPanel(false); setSelectedProvider(null); - setFormValues({}); - setFormErrors({}); - }; - // Function to handle form change - const handleFormChange = ( - updatedFormValues: Record, - updatedFormErrors: Record - ) => { - setFormValues(updatedFormValues); - setFormErrors(updatedFormErrors); }; const handleDeleteClick = async () => { @@ -1011,9 +981,6 @@ export function WorkflowTileOld({ workflow }: { workflow: Workflow }) { {selectedProvider && ( Date: Tue, 15 Oct 2024 12:06:03 +0100 Subject: [PATCH 02/35] add backend provider input validations --- .../appdynamics_provider.py | 3 ++- keep/providers/base/base_provider.py | 20 +++++++++++++- .../datadog_provider/datadog_provider.py | 21 +++++++-------- .../grafana_provider/grafana_provider.py | 8 +++--- .../kibana_provider/kibana_provider.py | 27 ++++++++++++++----- .../newrelic_provider/newrelic_provider.py | 16 +++-------- keep/validation/__init__.py | 0 keep/validation/fields.py | 10 +++++++ 8 files changed, 69 insertions(+), 36 deletions(-) create mode 100644 keep/validation/__init__.py create mode 100644 keep/validation/fields.py diff --git a/keep/providers/appdynamics_provider/appdynamics_provider.py b/keep/providers/appdynamics_provider/appdynamics_provider.py index d0f5e8867..bacee4e40 100644 --- a/keep/providers/appdynamics_provider/appdynamics_provider.py +++ b/keep/providers/appdynamics_provider/appdynamics_provider.py @@ -58,11 +58,12 @@ class AppdynamicsProviderAuthConfig: "hint": "the app instance in which the webhook should be installed", }, ) - host: str = dataclasses.field( + host: pydantic.AnyHttpUrl = dataclasses.field( metadata={ "required": True, "description": "AppDynamics host", "hint": "e.g. https://baseball202404101029219.saas.appdynamics.com", + "validation": "any_http_url" }, ) diff --git a/keep/providers/base/base_provider.py b/keep/providers/base/base_provider.py index 758c848de..2c9560b12 100644 --- a/keep/providers/base/base_provider.py +++ b/keep/providers/base/base_provider.py @@ -19,12 +19,13 @@ import requests from keep.api.bl.enrichments_bl import EnrichmentsBl -from keep.api.core.db import get_custom_deduplication_rule, get_enrichments +from keep.api.core.db import get_custom_deduplication_rule, get_enrichments, get_provider_by_name from keep.api.models.alert import AlertDto, AlertSeverity, AlertStatus from keep.api.models.db.alert import AlertActionType from keep.api.models.db.topology import TopologyServiceInDto from keep.api.utils.enrichment_helpers import parse_and_enrich_deleted_and_assignees from keep.contextmanager.contextmanager import ContextManager +from keep.parser.parser import Parser from keep.providers.models.provider_config import ProviderConfig, ProviderScope from keep.providers.models.provider_method import ProviderMethod @@ -683,6 +684,23 @@ def simulate_alert(cls) -> dict: return simulated_alert + @property + def is_installed(self) -> bool: + """ + Check if provider has been recorded in the database. + """ + provider = get_provider_by_name(self.context_manager.tenant_id, self.config.name) + return provider is not None + + @property + def is_provisioned(self) -> bool: + """ + Check if provider exist in env provisioning. + """ + parser = Parser() + parser._parse_providers_from_env(self.context_manager) + return self.config.name in self.context_manager.providers_context + class BaseTopologyProvider(BaseProvider): def pull_topology(self) -> list[TopologyServiceInDto]: diff --git a/keep/providers/datadog_provider/datadog_provider.py b/keep/providers/datadog_provider/datadog_provider.py index e19881f5d..03d91e3ab 100644 --- a/keep/providers/datadog_provider/datadog_provider.py +++ b/keep/providers/datadog_provider/datadog_provider.py @@ -13,30 +13,28 @@ import requests from datadog_api_client import ApiClient, Configuration from datadog_api_client.api_client import Endpoint -from datadog_api_client.exceptions import ( - ApiException, - ForbiddenException, - NotFoundException, -) +from datadog_api_client.exceptions import (ApiException, ForbiddenException, + NotFoundException) from datadog_api_client.v1.api.events_api import EventsApi from datadog_api_client.v1.api.logs_api import LogsApi from datadog_api_client.v1.api.metrics_api import MetricsApi from datadog_api_client.v1.api.monitors_api import MonitorsApi -from datadog_api_client.v1.api.webhooks_integration_api import WebhooksIntegrationApi +from datadog_api_client.v1.api.webhooks_integration_api import \ + WebhooksIntegrationApi from datadog_api_client.v1.model.monitor import Monitor from datadog_api_client.v1.model.monitor_options import MonitorOptions from datadog_api_client.v1.model.monitor_thresholds import MonitorThresholds from datadog_api_client.v1.model.monitor_type import MonitorType -from datadog_api_client.v2.api.service_definition_api import ServiceDefinitionApi +from datadog_api_client.v2.api.service_definition_api import \ + ServiceDefinitionApi from keep.api.models.alert import AlertDto, AlertSeverity, AlertStatus from keep.api.models.db.topology import TopologyServiceInDto from keep.contextmanager.contextmanager import ContextManager from keep.providers.base.base_provider import BaseTopologyProvider from keep.providers.base.provider_exceptions import GetAlertException -from keep.providers.datadog_provider.datadog_alert_format_description import ( - DatadogAlertFormatDescription, -) +from keep.providers.datadog_provider.datadog_alert_format_description import \ + DatadogAlertFormatDescription from keep.providers.models.provider_config import ProviderConfig, ProviderScope from keep.providers.models.provider_method import ProviderMethod from keep.providers.providers_factory import ProvidersFactory @@ -70,12 +68,13 @@ class DatadogProviderAuthConfig: }, default="", ) - domain: str = dataclasses.field( + domain: pydantic.HttpUrl = dataclasses.field( metadata={ "required": False, "description": "Datadog API domain", "sensitive": False, "hint": "https://api.datadoghq.com", + "validation": "http_url" }, default="https://api.datadoghq.com", ) diff --git a/keep/providers/grafana_provider/grafana_provider.py b/keep/providers/grafana_provider/grafana_provider.py index b3b9cf1f6..607f60a0d 100644 --- a/keep/providers/grafana_provider/grafana_provider.py +++ b/keep/providers/grafana_provider/grafana_provider.py @@ -15,9 +15,8 @@ from keep.exceptions.provider_exception import ProviderException from keep.providers.base.base_provider import BaseProvider from keep.providers.base.provider_exceptions import GetAlertException -from keep.providers.grafana_provider.grafana_alert_format_description import ( - GrafanaAlertFormatDescription, -) +from keep.providers.grafana_provider.grafana_alert_format_description import \ + GrafanaAlertFormatDescription from keep.providers.models.provider_config import ProviderConfig, ProviderScope from keep.providers.providers_factory import ProvidersFactory @@ -36,11 +35,12 @@ class GrafanaProviderAuthConfig: "sensitive": True, }, ) - host: str = dataclasses.field( + host: pydantic.HttpUrl = dataclasses.field( metadata={ "required": True, "description": "Grafana host", "hint": "e.g. https://keephq.grafana.net", + "validation": "http_url" }, ) diff --git a/keep/providers/kibana_provider/kibana_provider.py b/keep/providers/kibana_provider/kibana_provider.py index 98a7273e8..425bd0bb6 100644 --- a/keep/providers/kibana_provider/kibana_provider.py +++ b/keep/providers/kibana_provider/kibana_provider.py @@ -12,6 +12,7 @@ import pydantic import requests from fastapi import HTTPException +from pydantic import AnyHttpUrl, conint from keep.api.models.alert import AlertDto, AlertSeverity, AlertStatus from keep.contextmanager.contextmanager import ContextManager @@ -31,15 +32,21 @@ class KibanaProviderAuthConfig: "sensitive": True, } ) - kibana_host: str = dataclasses.field( + kibana_host: AnyHttpUrl = dataclasses.field( metadata={ "required": True, - "description": "Kibana Host (e.g. keep.kb.us-central1.gcp.cloud.es.io)", + "description": "Kibana Host", + "hint": "https://keep.kb.us-central1.gcp.cloud.es.io", + "validation": "any_http_url" } ) - kibana_port: str = dataclasses.field( - metadata={"required": False, "description": "Kibana Port (defaults to 9243)"}, - default="9243", + kibana_port: conint(ge=1, le=65_535) = dataclasses.field( + metadata={ + "required": False, + "description": "Kibana Port (defaults to 9243)", + "validation": "port" + }, + default=9243, ) @@ -212,7 +219,7 @@ def request( headers["Authorization"] = f"ApiKey {self.authentication_config.api_key}" headers["kbn-xsrf"] = "reporting" response: requests.Response = getattr(requests, method.lower())( - f"https://{self.authentication_config.kibana_host}:{self.authentication_config.kibana_port}/{uri}", + f"{self.authentication_config.kibana_host}:{self.authentication_config.kibana_port}/{uri}", headers=headers, **kwargs, ) @@ -434,6 +441,14 @@ def setup_webhook( self.logger.info("Done setting up webhooks") def validate_config(self): + # In order not to prepend the url scheme while making a request, + # we added proper url validation but, we also have to handle previously + # installed and provisioned providers to avoid validation errors. + if self.is_installed or self.is_provisioned: + host = self.config.authentication['kibana_host'] + host = "https://" + host if not (host.starts_with("http://") or host.starts_with("https://")) else host + self.config.authentication['kibana_host'] = host + self.authentication_config = KibanaProviderAuthConfig( **self.config.authentication ) diff --git a/keep/providers/newrelic_provider/newrelic_provider.py b/keep/providers/newrelic_provider/newrelic_provider.py index 49ecaa52b..e05e40855 100644 --- a/keep/providers/newrelic_provider/newrelic_provider.py +++ b/keep/providers/newrelic_provider/newrelic_provider.py @@ -16,6 +16,7 @@ from keep.exceptions.provider_exception import ProviderException from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig, ProviderScope +from keep.validation.fields import HttpsUrl @pydantic.dataclasses.dataclass @@ -37,10 +38,11 @@ class NewrelicProviderAuthConfig: account_id: str = dataclasses.field( metadata={"required": True, "description": "New Relic account ID"} ) - new_relic_api_url: str = dataclasses.field( + new_relic_api_url: HttpsUrl = dataclasses.field( metadata={ "required": False, "description": "New Relic API URL", + "validation": "https_url" }, default="https://api.newrelic.com", ) @@ -120,20 +122,8 @@ def dispose(self): def validate_config(self): """ Validates required configuration for New-Relic provider. - - Raises: - ProviderConfigException: user or account is missing in authentication. - ProviderConfigException: private key - ProviderConfigException: new_relic_api_url must start with https """ self.newrelic_config = NewrelicProviderAuthConfig(**self.config.authentication) - if ( - self.newrelic_config.new_relic_api_url - and not self.newrelic_config.new_relic_api_url.startswith("https") - ): - raise ProviderConfigException( - "New Relic API URL must start with https", self.provider_id - ) def __make_add_webhook_destination_query(self, url: str, name: str) -> dict: query = f"""mutation {{ diff --git a/keep/validation/__init__.py b/keep/validation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/keep/validation/fields.py b/keep/validation/fields.py new file mode 100644 index 000000000..1f476959e --- /dev/null +++ b/keep/validation/fields.py @@ -0,0 +1,10 @@ +from pydantic import HttpUrl + + +class HttpsUrl(HttpUrl): + scheme = {'https'} + + @staticmethod + def get_default_parts(parts): + return {'port': '443'} + From 929dea108f327b51be4fc6de009c90c2d9a95378 Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Wed, 16 Oct 2024 18:29:41 +0100 Subject: [PATCH 03/35] fix circular import errors --- keep-ui/next.config.js | 7 ++++--- keep/parser/parser.py | 3 +-- keep/providers/base/base_provider.py | 8 +++++--- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/keep-ui/next.config.js b/keep-ui/next.config.js index bf8da95e9..708a43c6f 100644 --- a/keep-ui/next.config.js +++ b/keep-ui/next.config.js @@ -66,8 +66,9 @@ const nextConfig = { }, }; -const withBundleAnalyzer = require("@next/bundle-analyzer")({ - enabled: process.env.ANALYZE === "true", -}); +const withBundleAnalyzer = + process.env.ANALYZE === "true" + ? require("@next/bundle-analyzer")({ enabled: true }) + : (config) => config; module.exports = withBundleAnalyzer(nextConfig); diff --git a/keep/parser/parser.py b/keep/parser/parser.py index 9e2f2f215..98de0e2fa 100644 --- a/keep/parser/parser.py +++ b/keep/parser/parser.py @@ -9,7 +9,6 @@ from keep.actions.actions_factory import ActionsCRUD from keep.api.core.db import get_workflow_id from keep.contextmanager.contextmanager import ContextManager -from keep.providers.base.base_provider import BaseProvider from keep.providers.providers_factory import ProvidersFactory from keep.step.step import Step, StepType from keep.step.step_provider_parameter import StepProviderParameter @@ -314,7 +313,7 @@ def _inject_env_variables(self, config): def _parse_providers_from_workflow( self, context_manager: ContextManager, workflow: dict - ) -> typing.List[BaseProvider]: + ): context_manager.providers_context.update(workflow.get("providers")) self.logger.debug("Workflow providers parsed successfully") diff --git a/keep/providers/base/base_provider.py b/keep/providers/base/base_provider.py index 2c9560b12..f3abfca41 100644 --- a/keep/providers/base/base_provider.py +++ b/keep/providers/base/base_provider.py @@ -19,13 +19,14 @@ import requests from keep.api.bl.enrichments_bl import EnrichmentsBl -from keep.api.core.db import get_custom_deduplication_rule, get_enrichments, get_provider_by_name +from keep.api.core.db import (get_custom_deduplication_rule, get_enrichments, + get_provider_by_name) from keep.api.models.alert import AlertDto, AlertSeverity, AlertStatus from keep.api.models.db.alert import AlertActionType from keep.api.models.db.topology import TopologyServiceInDto -from keep.api.utils.enrichment_helpers import parse_and_enrich_deleted_and_assignees +from keep.api.utils.enrichment_helpers import \ + parse_and_enrich_deleted_and_assignees from keep.contextmanager.contextmanager import ContextManager -from keep.parser.parser import Parser from keep.providers.models.provider_config import ProviderConfig, ProviderScope from keep.providers.models.provider_method import ProviderMethod @@ -697,6 +698,7 @@ def is_provisioned(self) -> bool: """ Check if provider exist in env provisioning. """ + from keep.parser.parser import Parser parser = Parser() parser._parse_providers_from_env(self.context_manager) return self.config.name in self.context_manager.providers_context From 16f925543bd20fe909c8bb2fa2c7899b52d09365 Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Fri, 18 Oct 2024 16:25:04 +0100 Subject: [PATCH 04/35] add client-side validation with zod --- keep-ui/app/providers/provider-form.tsx | 223 +++++++++++++----- keep-ui/app/providers/providers.tsx | 2 +- .../bigquery_provider/bigquery_provider.py | 2 +- keep/providers/gke_provider/gke_provider.py | 2 +- keep/providers/ssh_provider/ssh_provider.py | 3 +- 5 files changed, 165 insertions(+), 67 deletions(-) diff --git a/keep-ui/app/providers/provider-form.tsx b/keep-ui/app/providers/provider-form.tsx index 07739db66..b4d868777 100644 --- a/keep-ui/app/providers/provider-form.tsx +++ b/keep-ui/app/providers/provider-form.tsx @@ -51,6 +51,7 @@ import { useSearchParams } from "next/navigation"; import "./provider-form.css"; import { toast } from "react-toastify"; import { useProviders } from "@/utils/hooks/useProviders"; +import { z } from "zod"; type ProviderFormProps = { provider: Provider; @@ -132,11 +133,141 @@ function getConfigGroup(type: "config_main_group" | "config_sub_group") { const getConfigByMainGroup = getConfigGroup("config_main_group"); const getConfigBySubGroup = getConfigGroup("config_sub_group"); +function getInitialFormValues(provider: Provider) { + const initialValues: ProviderFormData = { + provider_id: provider.id, + install_webhook: provider.can_setup_webhook ?? false, + }; + if (provider.installed) + Object.assign(initialValues, { + provider_name: provider.details.name, + ...provider.details.authentication, + }); + + // Set default values for select inputs + Object.entries(provider.config).forEach(([field, config]) => { + if (config.type === "select" && config.default && !initialValues[field]) { + initialValues[field] = config.default; + } + }); + + return initialValues; +} + +function getZodSchema(fields: Provider["config"]) { + const required_error = "This field is required"; + const emptyStringToNull = z + .string() + .transform((val) => (val.length === 0 ? null : val)); + const kvPairs = Object.entries(fields).map(([field, config]) => { + if (config.type === "form") { + const baseFormSchema = z.record(z.string(), z.string()).array(); + const formSchema = config.required + ? baseFormSchema.nonempty({ + message: "At least one key-value entry should be provided.", + }) + : baseFormSchema.optional(); + return [field, formSchema]; + } + + if (config.type === "file") { + const baseFileSchema = z + .instanceof(File, { message: "Please upload a file here." }) + .refine( + (file) => { + if (config.file_type == undefined) return true; + if (config.file_type.length <= 1) return true; + return config.file_type.includes(file.type); + }, + { + message: + config.file_type && config.file_type?.split(",").length > 1 + ? `File type should be one of ${config.file_type}.` + : `File should be of type ${config.file_type}.`, + } + ); + const fileSchema = config.required + ? baseFileSchema + : baseFileSchema.optional(); + return [field, fileSchema]; + } + + const urlSchema = z + .string({ required_error }) + .url({ message: "Please provide a valid url, e.g https://example.com" }); + const urlTldSchema = z.string().regex(new RegExp(/\.[a-z]{2,63}$/), { + message: "Url must contain a valid TLD e.g .com, .io, .dev, .co.uk", + }); + const baseAnyHttpSchema = urlSchema.refine( + (url) => url.startsWith("http://") || url.startsWith("https://"), + { message: "A url with `http` or `https` protocol is reuquired." } + ); + const baseHttpSchema = baseAnyHttpSchema.and(urlTldSchema); + const baseHttpsSchema = urlSchema + .refine((url) => url.startsWith("https://"), { + message: "A url with `https` protocol is required.", + }) + .and(urlTldSchema); + + if (config.validation === "any_http_url") { + const anyHttpSchema = config.required + ? baseAnyHttpSchema + : emptyStringToNull.pipe(baseAnyHttpSchema.nullish()); + return [field, anyHttpSchema]; + } + + if (config.validation === "http_url") { + const httpSchema = config.required + ? baseHttpSchema + : emptyStringToNull.pipe(baseHttpSchema.nullish()); + return [field, httpSchema]; + } + if (config.validation === "https_url") { + const httpsSchema = config.required + ? baseHttpsSchema + : emptyStringToNull.pipe(baseHttpsSchema.nullish()); + return [field, httpsSchema]; + } + if (config.validation === "tld") { + const baseTldSchema = z + .string({ required_error }) + .regex(new RegExp(/\.[a-z]{2,63}$/), { + message: "Please provide a valid TLD e.g .com, .io, .dev, .net", + }); + const tldSchema = config.required + ? baseTldSchema + : baseTldSchema.optional(); + return [field, tldSchema]; + } + if (config.validation === "port") { + const basePortSchema = z + .number({ required_error }) + .min(1, { message: "Invalid port number" }) + .max(65_535, { message: "Invalid port number" }); + const portSchema = config.required + ? basePortSchema + : basePortSchema.optional(); + return [field, portSchema]; + } + return [ + field, + config.required + ? z.string({ required_error }).min(1, { message: required_error }) + : z.string().optional(), + ]; + }); + return z.object({ + provider_name: z + .string({ required_error }) + .min(1, { message: required_error }), + ...Object.fromEntries(kvPairs), + }); +} + const providerNameFieldConfig: ProviderAuthConfig = { required: true, description: "Provider Name", placeholder: "Enter provider name", - validation: "", default: null, }; @@ -151,31 +282,9 @@ const ProviderForm = ({ console.log("Loading the ProviderForm component"); const { mutate } = useProviders(); const searchParams = useSearchParams(); - const [formValues, setFormValues] = useState(() => { - const initialValues: ProviderFormData = { - provider_id: provider.id, - install_webhook: provider.can_setup_webhook ?? false, - }; - if (provider.installed) - Object.assign(initialValues, { - provider_name: provider.details.name, - ...provider.details.authentication, - }); - - // Set default values for select inputs - Object.entries(provider.config).forEach(([configKey, method]) => { - if ( - method.type === "select" && - method.default && - !initialValues[configKey] - ) { - initialValues[configKey] = method.default; - } - }); - - return initialValues; - }); - + const [formValues, setFormValues] = useState(() => + getInitialFormValues(provider) + ); const [formErrors, setFormErrors] = useState(null); const [inputErrors, setInputErrors] = useState({}); // Related to scopes @@ -197,6 +306,7 @@ const ProviderForm = ({ () => getConfigByMainGroup(provider.config), [provider] ); + const zodSchema = useMemo(() => getZodSchema(provider.config), [provider]); const { data: session } = useSession(); const accessToken = session?.accessToken; @@ -267,27 +377,6 @@ const ProviderForm = ({ } } - const validateForm = (updatedFormValues: ProviderFormData) => { - const errors: InputErrors = {}; - for (const [configKey, method] of Object.entries(provider.config)) { - if (!formValues[configKey] && method.required) { - errors[configKey] = "This field is required"; - } - if ( - "validation" in method && - formValues[configKey] - // TODO:add form validation here - ) { - errors[configKey] = ""; - } - if (!formValues.provider_name) { - errors["provider_name"] = "This field is required"; - } - } - setInputErrors(errors); - return errors; - }; - function handleFormChange(key: string, value: ProviderFormValue) { setFormValues((prev) => { const prevValue = prev[key]; @@ -314,21 +403,18 @@ const ProviderForm = ({ })); }; - const validate = () => { - const errors = validateForm(formValues); - if (Object.keys(errors).length === 0) { - return true; - } else { - setFormErrors( - `Missing required fields: ${JSON.stringify( - Object.keys(errors), - null, - 4 - )}` - ); - return false; - } - }; + function validate() { + const validation = zodSchema.safeParse(formValues); + if (validation.success) return true; + const errors: InputErrors = {}; + Object.entries(validation.error.format()).forEach(([field, err]) => { + err && typeof err === "object" && !Array.isArray(err) + ? (errors[field] = err._errors[0]) + : null; + }); + setInputErrors(errors); + return false; + } const submit = ( requestUrl: string, @@ -368,7 +454,9 @@ const ProviderForm = ({ // If the response is not okay, throw the error message return response_json.then((errorData) => { if (response.status === 400) { - throw `${errorData.detail}`; + if ("detail" in errorData) throw `${errorData.detail}`; + if ("message" in errorData) throw `${errorData.messsage}`; + throw `${errorData}`; } if (response.status === 409) { throw `Provider with name ${formValues.provider_name} already exists`; @@ -889,6 +977,7 @@ function FormField({ @@ -986,11 +1075,13 @@ function FileField({ id, config, disabled, + error, onChange, }: { id: string; config: ProviderAuthConfig; disabled: boolean; + error?: string; onChange: (e: React.ChangeEvent) => void; }) { const [selected, setSelected] = useState(); @@ -1028,6 +1119,9 @@ function FileField({ }} disabled={disabled} /> + {error && error?.length > 0 && ( +

{error}

+ )} ); } @@ -1076,6 +1170,9 @@ function KVForm({ {Array.isArray(value) && ( )} + {error && error?.length > 0 && ( +

{error}

+ )}
); } diff --git a/keep-ui/app/providers/providers.tsx b/keep-ui/app/providers/providers.tsx index ff2cf37a4..d4a156f9f 100644 --- a/keep-ui/app/providers/providers.tsx +++ b/keep-ui/app/providers/providers.tsx @@ -9,7 +9,7 @@ export interface ProviderAuthConfig { options?: Array; sensitive?: boolean; hidden?: boolean; - type?: string; + type?: "select" | "form" | "file"; file_type?: string; config_main_group?: string; config_sub_group?: string; diff --git a/keep/providers/bigquery_provider/bigquery_provider.py b/keep/providers/bigquery_provider/bigquery_provider.py index 45e186963..96491bc69 100644 --- a/keep/providers/bigquery_provider/bigquery_provider.py +++ b/keep/providers/bigquery_provider/bigquery_provider.py @@ -26,7 +26,7 @@ class BigqueryProviderAuthConfig: "sensitive": True, "type": "file", "name": "service_account_json", - "file_type": ".json", # this is used to filter the file type in the UI + "file_type": "application/json", }, ) project_id: Optional[str] = dataclasses.field( diff --git a/keep/providers/gke_provider/gke_provider.py b/keep/providers/gke_provider/gke_provider.py index 2b44742e2..8efa2938f 100644 --- a/keep/providers/gke_provider/gke_provider.py +++ b/keep/providers/gke_provider/gke_provider.py @@ -25,7 +25,7 @@ class GkeProviderAuthConfig: "sensitive": True, "type": "file", "name": "service_account_json", - "file_type": ".json", # this is used to filter the file type in the UI + "file_type": "application/json", } ) cluster_name: str = dataclasses.field( diff --git a/keep/providers/ssh_provider/ssh_provider.py b/keep/providers/ssh_provider/ssh_provider.py index e48f9d0ad..2b430a738 100644 --- a/keep/providers/ssh_provider/ssh_provider.py +++ b/keep/providers/ssh_provider/ssh_provider.py @@ -36,7 +36,8 @@ class SshProviderAuthConfig: "sensitive": True, "type": "file", "name": "pkey", - "file_type": "*", + "file_type": "text/plain, application/x-pem-file, application/x-putty-private-key, "+ + "application/x-ed25519-key, application/pkcs8, application/octet-stream", "config_sub_group": "private_key", "config_main_group": "authentication", }, From c6844fc6b23f6c55a8c7e28ddad789485d8fff64 Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Sun, 20 Oct 2024 23:54:52 +0100 Subject: [PATCH 05/35] add backend validation for providers --- keep-ui/app/providers/provider-form.tsx | 152 ++++++++---------- keep-ui/app/providers/providers.tsx | 2 +- .../kibana_provider/kibana_provider.py | 8 +- .../openobserve_provider.py | 27 ++-- .../prometheus_provider.py | 2 +- .../sentry_provider/sentry_provider.py | 4 +- .../site24x7_provider/site24x7_provider.py | 7 +- .../splunk_provider/splunk_provider.py | 10 +- .../zabbix_provider/zabbix_provider.py | 3 +- keep/validation/fields.py | 3 +- 10 files changed, 103 insertions(+), 115 deletions(-) diff --git a/keep-ui/app/providers/provider-form.tsx b/keep-ui/app/providers/provider-form.tsx index b4d868777..08fb13681 100644 --- a/keep-ui/app/providers/provider-form.tsx +++ b/keep-ui/app/providers/provider-form.tsx @@ -156,6 +156,7 @@ function getInitialFormValues(provider: Provider) { function getZodSchema(fields: Provider["config"]) { const required_error = "This field is required"; + const portError = "Invalid port number"; const emptyStringToNull = z .string() .transform((val) => (val.length === 0 ? null : val)); @@ -196,7 +197,7 @@ function getZodSchema(fields: Provider["config"]) { .string({ required_error }) .url({ message: "Please provide a valid url, e.g https://example.com" }); const urlTldSchema = z.string().regex(new RegExp(/\.[a-z]{2,63}$/), { - message: "Url must contain a valid TLD e.g .com, .io, .dev, .co.uk", + message: "Url must contain a valid TLD e.g .com, .io, .dev, .net", }); const baseAnyHttpSchema = urlSchema.refine( (url) => url.startsWith("http://") || url.startsWith("https://"), @@ -240,25 +241,29 @@ function getZodSchema(fields: Provider["config"]) { return [field, tldSchema]; } if (config.validation === "port") { - const basePortSchema = z - .number({ required_error }) - .min(1, { message: "Invalid port number" }) - .max(65_535, { message: "Invalid port number" }); + const basePortSchema = z.coerce + .number({ required_error, invalid_type_error: portError }) + .min(1, { message: portError }) + .max(65_535, { message: portError }); const portSchema = config.required ? basePortSchema - : basePortSchema.optional(); + : emptyStringToNull.pipe(basePortSchema.nullish()); return [field, portSchema]; } return [ field, config.required - ? z.string({ required_error }).min(1, { message: required_error }) + ? z + .string({ required_error }) + .trim() + .min(1, { message: required_error }) : z.string().optional(), ]; }); return z.object({ provider_name: z .string({ required_error }) + .trim() .min(1, { message: required_error }), ...Object.fromEntries(kvPairs), }); @@ -416,10 +421,7 @@ const ProviderForm = ({ return false; } - const submit = ( - requestUrl: string, - method: string = "POST" - ): Promise => { + async function submit(requestUrl: string, method: string = "POST") { let headers = { Authorization: `Bearer ${accessToken}`, "Content-Type": "application/json", @@ -447,85 +449,65 @@ const ProviderForm = ({ method: method, headers: headers, body: body, - }) - .then((response) => { - const response_json = response.json(); - if (!response.ok) { - // If the response is not okay, throw the error message - return response_json.then((errorData) => { - if (response.status === 400) { - if ("detail" in errorData) throw `${errorData.detail}`; - if ("message" in errorData) throw `${errorData.messsage}`; - throw `${errorData}`; - } - if (response.status === 409) { - throw `Provider with name ${formValues.provider_name} already exists`; - } - const errorDetail = errorData.detail; - if (response.status === 412) { - setProviderValidatedScopes(errorDetail); - } - throw `${provider.type} scopes are invalid: ${JSON.stringify( - errorDetail, - null, - 4 - )}`; - }); - } - return response_json; - }) - .then((data) => { - setFormErrors(""); - return data; - }); - }; + }); + } - const handleUpdateClick = (e: any) => { + async function handleSubmitError(response: Response) { + const status = response.status; + const data = await response.json(); + const error = + "detail" in data ? data.detail : "message" in data ? data.message : data; + if (status === 400) setFormErrors(error); + if (response.status === 409) + setFormErrors( + `Provider with name ${formValues.provider_name} already exists` + ); + if (response.status === 412) setProviderValidatedScopes(error); + } + + async function handleUpdateClick() { if (provider.webhook_required) callInstallWebhook(); - e.preventDefault(); - if (validate()) { - setIsLoading(true); - submit(`${getApiURL()}/providers/${provider.id}`, "PUT") - .then((data) => { - setIsLoading(false); - mutate(); - }) - .catch((error) => { - const updatedFormErrors = error.toString(); - setFormErrors(updatedFormErrors); - setIsLoading(false); - }); + if (!validate()) return; + setIsLoading(true); + const response = await submit( + `${getApiURL()}/providers/${provider.id}`, + "PUT" + ); + if (response.ok) { + setIsLoading(false); + mutate(); + } else { + handleSubmitError(response); + setIsLoading(false); } - }; + } - const handleConnectClick = async () => { - if (validate()) { - setIsLoading(true); - onConnectChange?.(true, false); - submit(`${getApiURL()}/providers/install`) - .then(async (data) => { - console.log("Connect Result:", data); - setIsLoading(false); - onConnectChange?.(false, true); - if ( - formValues.install_webhook && - provider.can_setup_webhook && - accessToken && - !isLocalhost - ) { - // mutate after webhook installation - await installWebhook(data as Provider, accessToken); - } - mutate(); - }) - .catch((error) => { - const updatedFormErrors = error.toString(); - setFormErrors(updatedFormErrors); - setIsLoading(false); - onConnectChange?.(false, false); - }); + async function handleConnectClick() { + if (!validate()) return; + setIsLoading(true); + onConnectChange?.(true, false); + const response = await submit(`${getApiURL()}/providers/install`); + if (response.ok) { + const data = await response.json(); + console.log("Connect Result:", data); + setIsLoading(false); + onConnectChange?.(false, true); + if ( + formValues.install_webhook && + provider.can_setup_webhook && + accessToken && + !isLocalhost + ) { + // mutate after webhook installation + await installWebhook(data as Provider, accessToken); + } + mutate(); + } else { + handleSubmitError(response); + setIsLoading(false); + onConnectChange?.(false, false); } - }; + } const installOrUpdateWebhookEnabled = provider.scopes ?.filter((scope) => scope.mandatory_for_webhook) diff --git a/keep-ui/app/providers/providers.tsx b/keep-ui/app/providers/providers.tsx index d4a156f9f..3911cfe2a 100644 --- a/keep-ui/app/providers/providers.tsx +++ b/keep-ui/app/providers/providers.tsx @@ -2,7 +2,7 @@ export interface ProviderAuthConfig { description: string; hint?: string; placeholder?: string; - validation: string; // regex + validation?: "any_http_url" | "http_url" | "https_url" | "port" | "tld"; required?: boolean; value?: string; default: string | number | boolean | null; diff --git a/keep/providers/kibana_provider/kibana_provider.py b/keep/providers/kibana_provider/kibana_provider.py index 425bd0bb6..1a290c28f 100644 --- a/keep/providers/kibana_provider/kibana_provider.py +++ b/keep/providers/kibana_provider/kibana_provider.py @@ -12,13 +12,14 @@ import pydantic import requests from fastapi import HTTPException -from pydantic import AnyHttpUrl, conint +from pydantic import AnyHttpUrl from keep.api.models.alert import AlertDto, AlertSeverity, AlertStatus from keep.contextmanager.contextmanager import ContextManager from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig, ProviderScope from keep.providers.providers_factory import ProvidersFactory +from keep.validation.fields import UrlPort @pydantic.dataclasses.dataclass @@ -40,7 +41,7 @@ class KibanaProviderAuthConfig: "validation": "any_http_url" } ) - kibana_port: conint(ge=1, le=65_535) = dataclasses.field( + kibana_port: UrlPort = dataclasses.field( metadata={ "required": False, "description": "Kibana Port (defaults to 9243)", @@ -441,9 +442,6 @@ def setup_webhook( self.logger.info("Done setting up webhooks") def validate_config(self): - # In order not to prepend the url scheme while making a request, - # we added proper url validation but, we also have to handle previously - # installed and provisioned providers to avoid validation errors. if self.is_installed or self.is_provisioned: host = self.config.authentication['kibana_host'] host = "https://" + host if not (host.starts_with("http://") or host.starts_with("https://")) else host diff --git a/keep/providers/openobserve_provider/openobserve_provider.py b/keep/providers/openobserve_provider/openobserve_provider.py index 14a0d72e1..834a4bdc7 100644 --- a/keep/providers/openobserve_provider/openobserve_provider.py +++ b/keep/providers/openobserve_provider/openobserve_provider.py @@ -17,6 +17,7 @@ from keep.contextmanager.contextmanager import ContextManager from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig, ProviderScope +from keep.validation.fields import UrlPort class ResourceAlreadyExists(Exception): @@ -45,19 +46,21 @@ class OpenobserveProviderAuthConfig: "sensitive": True, }, ) - openObserveHost: str = dataclasses.field( + openObserveHost: pydantic.AnyHttpUrl = dataclasses.field( metadata={ "required": True, - "description": "OpenObserve host url || default: localhost", - "hint": "Eg. localhost", + "description": "OpenObserve host url", + "hint": "e.g. http://localhost", + "validation": "any_http_url" }, ) - openObservePort: str = dataclasses.field( + openObservePort: UrlPort = dataclasses.field( metadata={ "required": True, - "description": "OpenObserve Host|| default: 5080", - "hint": "Eg. 5080", + "description": "OpenObserve Port", + "hint": "e.g. 5080", + "validation": "port" }, ) organisationID: str = dataclasses.field( @@ -104,17 +107,15 @@ def dispose(self): def validate_config(self): """ Validates required configuration for OpenObserve provider. - """ + if self.is_installed or self.is_provisioned: + host = self.config.authentication['openObserveHost'] + host = "https://" + host if not (host.starts_with("http://") or host.starts_with("https://")) else host + self.config.authentication['openObserveHost'] = host + self.authentication_config = OpenobserveProviderAuthConfig( **self.config.authentication ) - if not self.authentication_config.openObserveHost.startswith( - "https://" - ) and not self.authentication_config.openObserveHost.startswith("http://"): - self.authentication_config.openObserveHost = ( - f"https://{self.authentication_config.openObserveHost}" - ) def __get_url(self, paths: List[str] = [], query_params: dict = None, **kwargs): """ diff --git a/keep/providers/prometheus_provider/prometheus_provider.py b/keep/providers/prometheus_provider/prometheus_provider.py index 19d57f9df..8841c111b 100644 --- a/keep/providers/prometheus_provider/prometheus_provider.py +++ b/keep/providers/prometheus_provider/prometheus_provider.py @@ -18,7 +18,7 @@ @pydantic.dataclasses.dataclass class PrometheusProviderAuthConfig: - url: str = dataclasses.field( + url: pydantic.AnyHttpUrl = dataclasses.field( metadata={ "required": True, "description": "Prometheus server URL", diff --git a/keep/providers/sentry_provider/sentry_provider.py b/keep/providers/sentry_provider/sentry_provider.py index 4f82e90ca..927c919c3 100644 --- a/keep/providers/sentry_provider/sentry_provider.py +++ b/keep/providers/sentry_provider/sentry_provider.py @@ -15,6 +15,7 @@ from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig, ProviderScope from keep.providers.providers_factory import ProvidersFactory +from keep.validation.fields import HttpsUrl @pydantic.dataclasses.dataclass @@ -32,12 +33,13 @@ class SentryProviderAuthConfig: organization_slug: str = dataclasses.field( metadata={"required": True, "description": "Sentry organization slug"} ) - api_url: str = dataclasses.field( + api_url: HttpsUrl = dataclasses.field( metadata={ "required": False, "description": "Sentry API URL", "hint": "https://sentry.io/api/0 (see https://docs.sentry.io/api/)", "sensitive": False, + "validation": "https_url" }, default="https://sentry.io/api/0", ) diff --git a/keep/providers/site24x7_provider/site24x7_provider.py b/keep/providers/site24x7_provider/site24x7_provider.py index c474155e5..b28bc30b8 100644 --- a/keep/providers/site24x7_provider/site24x7_provider.py +++ b/keep/providers/site24x7_provider/site24x7_provider.py @@ -29,7 +29,7 @@ class Site24X7ProviderAuthConfig: zohoRefreshToken: str = dataclasses.field( metadata={ "required": True, - "description": "ZohoRefreshToken", + "description": "Zoho Refresh Token", "hint": "Refresh token for Zoho authentication", "sensitive": True, }, @@ -37,7 +37,7 @@ class Site24X7ProviderAuthConfig: zohoClientId: str = dataclasses.field( metadata={ "required": True, - "description": "ZohoClientId", + "description": "Zoho Client Id", "hint": "Client Secret for Zoho authentication.", "sensitive": True, }, @@ -45,7 +45,7 @@ class Site24X7ProviderAuthConfig: zohoClientSecret: str = dataclasses.field( metadata={ "required": True, - "description": "ZohoClientSecret", + "description": "Zoho Client Secret", "hint": "Password associated with yur account", "sensitive": True, }, @@ -55,6 +55,7 @@ class Site24X7ProviderAuthConfig: "required": True, "description": "Zoho Account's TLD (.com | .eu | .com.cn | .in | .au | .jp)", "hint": "Possible: .com | .eu | .com.cn | .in | .com.au | .jp", + "validation": "tld" }, ) diff --git a/keep/providers/splunk_provider/splunk_provider.py b/keep/providers/splunk_provider/splunk_provider.py index b33e3342b..d3d9d8943 100644 --- a/keep/providers/splunk_provider/splunk_provider.py +++ b/keep/providers/splunk_provider/splunk_provider.py @@ -2,17 +2,18 @@ import datetime import json import logging +from xml.etree.ElementTree import ParseError import pydantic -from splunklib.client import connect from splunklib.binding import AuthenticationError, HTTPError -from xml.etree.ElementTree import ParseError +from splunklib.client import connect from keep.api.models.alert import AlertDto, AlertSeverity from keep.contextmanager.contextmanager import ContextManager from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig, ProviderScope from keep.providers.providers_factory import ProvidersFactory +from keep.validation.fields import UrlPort @pydantic.dataclasses.dataclass @@ -31,9 +32,10 @@ class SplunkProviderAuthConfig: }, default="localhost", ) - port: int = dataclasses.field( + port: UrlPort = dataclasses.field( metadata={ "description": "Splunk Port (default is 8089)", + "validation": "port" }, default=8089, ) @@ -75,8 +77,8 @@ def __init__( def __debug_fetch_users_response(self): try: - from splunklib.client import PATH_USERS import requests + from splunklib.client import PATH_USERS response = requests.get( f"https://{self.authentication_config.host}:{self.authentication_config.port}/services/{PATH_USERS}", diff --git a/keep/providers/zabbix_provider/zabbix_provider.py b/keep/providers/zabbix_provider/zabbix_provider.py index b7bca0ad3..7b21f6339 100644 --- a/keep/providers/zabbix_provider/zabbix_provider.py +++ b/keep/providers/zabbix_provider/zabbix_provider.py @@ -30,12 +30,13 @@ class ZabbixProviderAuthConfig: Zabbix authentication configuration. """ - zabbix_frontend_url: str = dataclasses.field( + zabbix_frontend_url: pydantic.AnyHttpUrl = dataclasses.field( metadata={ "required": True, "description": "Zabbix Frontend URL", "hint": "https://zabbix.example.com", "sensitive": False, + "validation": "any_http_url" } ) auth_token: str = dataclasses.field( diff --git a/keep/validation/fields.py b/keep/validation/fields.py index 1f476959e..3ac977d84 100644 --- a/keep/validation/fields.py +++ b/keep/validation/fields.py @@ -1,4 +1,4 @@ -from pydantic import HttpUrl +from pydantic import HttpUrl, conint class HttpsUrl(HttpUrl): @@ -8,3 +8,4 @@ class HttpsUrl(HttpUrl): def get_default_parts(parts): return {'port': '443'} +UrlPort = conint(ge=1, le=65_535) From d6eaa40a58b637f30d4623ba1f13f1a39156f890 Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Mon, 21 Oct 2024 23:55:05 +0100 Subject: [PATCH 06/35] add backend validation for 7 providers --- .../auth0_provider/auth0_provider.py | 15 +- .../centreon_provider/centreon_provider.py | 397 +++++++++--------- .../clickhouse_provider.py | 45 +- .../discord_provider/discord_provider.py | 4 +- .../elastic_provider/elastic_provider.py | 35 +- .../slack_provider/slack_provider.py | 5 +- .../victoriametrics_provider.py | 4 +- 7 files changed, 269 insertions(+), 236 deletions(-) diff --git a/keep/providers/auth0_provider/auth0_provider.py b/keep/providers/auth0_provider/auth0_provider.py index e61598ad0..32e08ad65 100644 --- a/keep/providers/auth0_provider/auth0_provider.py +++ b/keep/providers/auth0_provider/auth0_provider.py @@ -1,6 +1,7 @@ """ Auth0 provider. """ + import dataclasses import datetime import os @@ -10,6 +11,7 @@ from keep.contextmanager.contextmanager import ContextManager from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig +from keep.validation.fields import HttpsUrl @dataclasses.dataclass @@ -26,12 +28,13 @@ class Auth0ProviderAuthConfig: "hint": "https://manage.auth0.com/dashboard/us/YOUR_ACCOUNT/apis/management/explorer", }, ) - domain: str = dataclasses.field( + domain: HttpsUrl = dataclasses.field( default=None, metadata={ "required": True, "description": "Auth0 Domain", "hint": "tenantname.us.auth0.com", + "validation": "https_url", }, ) @@ -52,10 +55,12 @@ def __init__( def validate_config(self): """ Validates required configuration for Auth0 provider. - """ - if self.config.authentication is None: - self.config.authentication = {} + if self.is_installed or self.is_provisioned: + host = self.config.authentication["domain"] + host = "https://" + host if not (host.starts_with("https://")) else host + self.config.authentication["domain"] = host + self.authentication_config = Auth0ProviderAuthConfig( **self.config.authentication ) @@ -74,7 +79,7 @@ def _query(self, log_type: str, from_: str = None, **kwargs: dict): Returns: _type_: _description_ """ - url = f"https://{self.authentication_config.domain}/api/v2/logs" + url = f"{self.authentication_config.domain}/api/v2/logs" headers = { "content-type": "application/json", "Authorization": f"Bearer {self.authentication_config.token}", diff --git a/keep/providers/centreon_provider/centreon_provider.py b/keep/providers/centreon_provider/centreon_provider.py index 7d2f3e4fe..3c02111d6 100644 --- a/keep/providers/centreon_provider/centreon_provider.py +++ b/keep/providers/centreon_provider/centreon_provider.py @@ -3,217 +3,234 @@ """ import dataclasses +import datetime import pydantic import requests -import datetime -from keep.api.models.alert import AlertDto, AlertStatus, AlertSeverity -from keep.exceptions.provider_exception import ProviderException +from keep.api.models.alert import AlertDto, AlertSeverity, AlertStatus from keep.contextmanager.contextmanager import ContextManager +from keep.exceptions.provider_exception import ProviderException from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig, ProviderScope + @pydantic.dataclasses.dataclass class CentreonProviderAuthConfig: - """ - CentreonProviderAuthConfig is a class that holds the authentication information for the CentreonProvider. - """ + """ + CentreonProviderAuthConfig is a class that holds the authentication information for the CentreonProvider. + """ + + host_url: pydantic.HttpUrl = dataclasses.field( + metadata={ + "required": True, + "description": "Centreon Host URL", + "sensitive": False, + "validation": "http_url", + }, + default=None, + ) + + api_token: str = dataclasses.field( + metadata={ + "required": True, + "description": "Centreon API Token", + "sensitive": True, + }, + default=None, + ) - host_url: str = dataclasses.field( - metadata={ - "required": True, - "description": "Centreon Host URL", - "sensitive": False, - }, - default=None, - ) - - api_token: str = dataclasses.field( - metadata={ - "required": True, - "description": "Centreon API Token", - "sensitive": True, - }, - default=None, - ) class CentreonProvider(BaseProvider): - PROVIDER_DISPLAY_NAME = "Centreon" - PROVIDER_TAGS = ["alert"] + PROVIDER_DISPLAY_NAME = "Centreon" + PROVIDER_TAGS = ["alert"] - PROVIDER_SCOPES = [ - ProviderScope( - name="authenticated", - description="User is authenticated" - ), - ] + PROVIDER_SCOPES = [ + ProviderScope(name="authenticated", description="User is authenticated"), + ] - """ + """ Centreon only supports the following host state (UP = 0, DOWN = 2, UNREA = 3) https://docs.centreon.com/docs/api/rest-api-v1/#realtime-information """ - STATUS_MAP = { - 2: AlertStatus.FIRING, - 3: AlertStatus.FIRING, - 0: AlertStatus.RESOLVED, - } - - SEVERITY_MAP = { - "CRITICAL": AlertSeverity.CRITICAL, - "WARNING": AlertSeverity.WARNING, - "UNKNOWN": AlertSeverity.INFO, - "OK": AlertSeverity.LOW, - "PENDING": AlertSeverity.INFO, - } - - def __init__( - self, context_manager: ContextManager, provider_id: str,config: ProviderConfig + STATUS_MAP = { + 2: AlertStatus.FIRING, + 3: AlertStatus.FIRING, + 0: AlertStatus.RESOLVED, + } + + SEVERITY_MAP = { + "CRITICAL": AlertSeverity.CRITICAL, + "WARNING": AlertSeverity.WARNING, + "UNKNOWN": AlertSeverity.INFO, + "OK": AlertSeverity.LOW, + "PENDING": AlertSeverity.INFO, + } + + def __init__( + self, context_manager: ContextManager, provider_id: str, config: ProviderConfig ): - super().__init__(context_manager, provider_id, config) + super().__init__(context_manager, provider_id, config) + + def dispose(self): + pass + + def validate_config(self): + """ + Validates the configuration of the Centreon provider. + """ + self.authentication_config = CentreonProviderAuthConfig( + **self.config.authentication + ) + + def __get_url(self, params: str): + url = self.authentication_config.host_url + "/centreon/api/index.php?" + params + return url + + def __get_headers(self): + return { + "Content-Type": "application/json", + "centreon-auth-token": self.authentication_config.api_token, + } + + def validate_scopes(self) -> dict[str, bool | str]: + """ + Validate the scopes of the provider. + """ + try: + response = requests.get( + self.__get_url("object=centreon_realtime_hosts&action=list"), + headers=self.__get_headers(), + ) + if response.ok: + scopes = {"authenticated": True} + else: + scopes = { + "authenticated": f"Error validating scopes: {response.status_code} {response.text}" + } + except Exception as e: + scopes = { + "authenticated": f"Error validating scopes: {e}", + } + + return scopes + + def __get_host_status(self) -> list[AlertDto]: + try: + url = self.__get_url("object=centreon_realtime_hosts&action=list") + response = requests.get(url, headers=self.__get_headers()) + + if not response.ok: + self.logger.error( + "Failed to get host status from Centreon: %s", response.json() + ) + raise ProviderException("Failed to get host status from Centreon") + + return [ + AlertDto( + id=host["id"], + name=host["name"], + address=host["address"], + description=host["output"], + status=host["state"], + severity=host["output"].split()[0], + instance_name=host["instance_name"], + acknowledged=host["acknowledged"], + max_check_attempts=host["max_check_attempts"], + lastReceived=datetime.datetime.fromtimestamp( + host["last_check"] + ).isoformat(), + source=["centreon"], + ) + for host in response.json() + ] + + except Exception as e: + self.logger.error("Error getting host status from Centreon: %s", e) + raise ProviderException(f"Error getting host status from Centreon: {e}") + + def __get_service_status(self) -> list[AlertDto]: + try: + url = self.__get_url("object=centreon_realtime_services&action=list") + response = requests.get(url, headers=self.__get_headers()) + + if not response.ok: + self.logger.error( + "Failed to get service status from Centreon: %s", response.json() + ) + raise ProviderException("Failed to get service status from Centreon") + + return [ + AlertDto( + id=service["service_id"], + host_id=service["host_id"], + name=service["name"], + description=service["description"], + status=service["state"], + severity=service["output"].split(":")[0], + acknowledged=service["acknowledged"], + max_check_attempts=service["max_check_attempts"], + lastReceived=datetime.datetime.fromtimestamp( + service["last_check"] + ).isoformat(), + source=["centreon"], + ) + for service in response.json() + ] + + except Exception as e: + self.logger.error("Error getting service status from Centreon: %s", e) + raise ProviderException(f"Error getting service status from Centreon: {e}") + + def _get_alerts(self) -> list[AlertDto]: + alerts = [] + try: + self.logger.info("Collecting alerts (host status) from Centreon") + host_status_alerts = self.__get_host_status() + alerts.extend(host_status_alerts) + except Exception as e: + self.logger.error("Error getting host status from Centreon: %s", e) + + try: + self.logger.info("Collecting alerts (service status) from Centreon") + service_status_alerts = self.__get_service_status() + alerts.extend(service_status_alerts) + except Exception as e: + self.logger.error("Error getting service status from Centreon: %s", e) + + return alerts - def dispose(self): - pass - def validate_config(self): - """ - Validates the configuration of the Centreon provider. - """ - self.authentication_config = CentreonProviderAuthConfig(**self.config.authentication) - - def __get_url(self, params: str): - url = self.authentication_config.host_url + "/centreon/api/index.php?" + params - return url - - def __get_headers(self): - return { - "Content-Type": "application/json", - "centreon-auth-token": self.authentication_config.api_token, - } - - def validate_scopes(self) -> dict[str, bool | str]: - """ - Validate the scopes of the provider. - """ - try: - response = requests.get(self.__get_url("object=centreon_realtime_hosts&action=list"), headers=self.__get_headers()) - if response.ok: - scopes = { - "authenticated": True - } - else: - scopes = { - "authenticated": f"Error validating scopes: {response.status_code} {response.text}" - } - except Exception as e: - scopes = { - "authenticated": f"Error validating scopes: {e}", - } - - return scopes - - def __get_host_status(self) -> list[AlertDto]: - try: - url = self.__get_url("object=centreon_realtime_hosts&action=list") - response = requests.get(url, headers=self.__get_headers()) - - if not response.ok: - self.logger.error("Failed to get host status from Centreon: %s", response.json()) - raise ProviderException("Failed to get host status from Centreon") - - return [AlertDto( - id=host["id"], - name=host["name"], - address=host["address"], - description=host["output"], - status=host["state"], - severity=host["output"].split()[0], - instance_name=host["instance_name"], - acknowledged=host["acknowledged"], - max_check_attempts=host["max_check_attempts"], - lastReceived=datetime.datetime.fromtimestamp(host["last_check"]).isoformat(), - source=["centreon"] - ) for host in response.json()] - - except Exception as e: - self.logger.error("Error getting host status from Centreon: %s", e) - raise ProviderException(f"Error getting host status from Centreon: {e}") - - def __get_service_status(self) -> list[AlertDto]: - try: - url = self.__get_url("object=centreon_realtime_services&action=list") - response = requests.get(url, headers=self.__get_headers()) - - if not response.ok: - self.logger.error("Failed to get service status from Centreon: %s", response.json()) - raise ProviderException("Failed to get service status from Centreon") - - return [AlertDto( - id=service["service_id"], - host_id=service["host_id"], - name=service["name"], - description=service["description"], - status=service["state"], - severity=service["output"].split(":")[0], - acknowledged=service["acknowledged"], - max_check_attempts=service["max_check_attempts"], - lastReceived=datetime.datetime.fromtimestamp(service["last_check"]).isoformat(), - source=["centreon"] - ) for service in response.json()] - - except Exception as e: - self.logger.error("Error getting service status from Centreon: %s", e) - raise ProviderException(f"Error getting service status from Centreon: {e}") - - def _get_alerts(self) -> list[AlertDto]: - alerts = [] - try: - self.logger.info("Collecting alerts (host status) from Centreon") - host_status_alerts = self.__get_host_status() - alerts.extend(host_status_alerts) - except Exception as e: - self.logger.error("Error getting host status from Centreon: %s", e) - - try: - self.logger.info("Collecting alerts (service status) from Centreon") - service_status_alerts = self.__get_service_status() - alerts.extend(service_status_alerts) - except Exception as e: - self.logger.error("Error getting service status from Centreon: %s", e) - - return alerts - if __name__ == "__main__": - import logging - - logging.basicConfig(level=logging.DEBUG, handlers=[logging.StreamHandler()]) - context_manager = ContextManager( - tenant_id="singletenant", - workflow_id="test", - ) - - import os - - host_url = os.environ.get("CENTREON_HOST_URL") - api_token = os.environ.get("CENTREON_API_TOKEN") - - if host_url is None: - raise ProviderException("CENTREON_HOST_URL is not set") - - config = ProviderConfig( - description="Centreon Provider", - authentication={ - "host_url": host_url, - "api_token": api_token, - }, - ) - - provider = CentreonProvider( - context_manager, - provider_id="centreon", - config=config, - ) - - provider._get_alerts() + import logging + + logging.basicConfig(level=logging.DEBUG, handlers=[logging.StreamHandler()]) + context_manager = ContextManager( + tenant_id="singletenant", + workflow_id="test", + ) + + import os + + host_url = os.environ.get("CENTREON_HOST_URL") + api_token = os.environ.get("CENTREON_API_TOKEN") + + if host_url is None: + raise ProviderException("CENTREON_HOST_URL is not set") + + config = ProviderConfig( + description="Centreon Provider", + authentication={ + "host_url": host_url, + "api_token": api_token, + }, + ) + + provider = CentreonProvider( + context_manager, + provider_id="centreon", + config=config, + ) + provider._get_alerts() diff --git a/keep/providers/clickhouse_provider/clickhouse_provider.py b/keep/providers/clickhouse_provider/clickhouse_provider.py index 4b35cefee..307e0be5a 100644 --- a/keep/providers/clickhouse_provider/clickhouse_provider.py +++ b/keep/providers/clickhouse_provider/clickhouse_provider.py @@ -6,13 +6,13 @@ import os import pydantic - from clickhouse_driver import connect from clickhouse_driver.dbapi.extras import DictCursor from keep.contextmanager.contextmanager import ContextManager from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig, ProviderScope +from keep.validation.fields import UrlPort @pydantic.dataclasses.dataclass @@ -21,16 +21,25 @@ class ClickhouseProviderAuthConfig: metadata={"required": True, "description": "Clickhouse username"} ) password: str = dataclasses.field( - metadata={"required": True, "description": "Clickhouse password", "sensitive": True} + metadata={ + "required": True, + "description": "Clickhouse password", + "sensitive": True, + } ) host: str = dataclasses.field( metadata={"required": True, "description": "Clickhouse hostname"} ) - port: str = dataclasses.field( - metadata={"required": True, "description": "Clickhouse port"} + port: UrlPort = dataclasses.field( + metadata={ + "required": True, + "description": "Clickhouse port", + "validation": "port", + } ) database: str | None = dataclasses.field( - metadata={"required": False, "description": "Clickhouse database name"}, default=None + metadata={"required": False, "description": "Clickhouse database name"}, + default=None, ) @@ -60,13 +69,13 @@ def validate_scopes(self): """ try: client = self.__generate_client() - + cursor = client.cursor() - cursor.execute('SHOW TABLES') - + cursor.execute("SHOW TABLES") + tables = cursor.fetchall() self.logger.info(f"Tables: {tables}") - + cursor.close() client.close() @@ -88,11 +97,11 @@ def __generate_client(self): clickhouse_driver.Connection: Clickhouse connection object """ - user=self.authentication_config.username - password=self.authentication_config.password - host=self.authentication_config.host - database=self.authentication_config.database - port=self.authentication_config.port + user = self.authentication_config.username + password = self.authentication_config.password + host = self.authentication_config.host + database = self.authentication_config.database + port = self.authentication_config.port dsn = f"clickhouse://{user}:{password}@{host}:{port}/{database}" @@ -121,9 +130,7 @@ def _query(self, query="", single_row=False, **kwargs: dict) -> list | tuple: """ return self._notify(query=query, single_row=single_row, **kwargs) - def _notify( - self, query="", single_row=False, **kwargs: dict - ) -> list | tuple: + def _notify(self, query="", single_row=False, **kwargs: dict) -> list | tuple: """ Executes a query against the Clickhouse database. @@ -160,5 +167,7 @@ def _notify( workflow_id="test", ) clickhouse_provider = ClickhouseProvider(context_manager, "clickhouse-prod", config) - results = clickhouse_provider.query(query="SELECT * FROM logs_table ORDER BY timestamp DESC LIMIT 1") + results = clickhouse_provider.query( + query="SELECT * FROM logs_table ORDER BY timestamp DESC LIMIT 1" + ) print(results) diff --git a/keep/providers/discord_provider/discord_provider.py b/keep/providers/discord_provider/discord_provider.py index a9bd7dfb3..67f72263e 100644 --- a/keep/providers/discord_provider/discord_provider.py +++ b/keep/providers/discord_provider/discord_provider.py @@ -11,17 +11,19 @@ from keep.exceptions.provider_exception import ProviderException from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig +from keep.validation.fields import HttpsUrl @pydantic.dataclasses.dataclass class DiscordProviderAuthConfig: """Discord authentication configuration.""" - webhook_url: str = dataclasses.field( + webhook_url: HttpsUrl = dataclasses.field( metadata={ "required": True, "description": "Discord Webhook Url", "sensitive": True, + "validation": "https_url", } ) diff --git a/keep/providers/elastic_provider/elastic_provider.py b/keep/providers/elastic_provider/elastic_provider.py index f6b4dae24..06d8c50c9 100644 --- a/keep/providers/elastic_provider/elastic_provider.py +++ b/keep/providers/elastic_provider/elastic_provider.py @@ -1,6 +1,7 @@ """ Elasticsearch provider. """ + import dataclasses import json @@ -8,7 +9,6 @@ from elasticsearch import Elasticsearch from keep.contextmanager.contextmanager import ContextManager -from keep.exceptions.provider_config_exception import ProviderConfigException from keep.exceptions.provider_connection_failed import ProviderConnectionFailed from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig @@ -26,8 +26,13 @@ class ElasticProviderAuthConfig: "sensitive": True, } ) - host: str = dataclasses.field( - default="", metadata={"required": False, "description": "Elasticsearch host"} + host: pydantic.HttpUrl = dataclasses.field( + default="", + metadata={ + "required": False, + "description": "Elasticsearch host", + "validation": "http_url", + }, ) cloud_id: str = dataclasses.field( default="", @@ -37,8 +42,8 @@ class ElasticProviderAuthConfig: @pydantic.root_validator def check_host_or_cloud_id(cls, values): host, cloud_id = values.get("host"), values.get("cloud_id") - if host == "" and cloud_id == "": - raise ValueError("either host or cloud_id must be provided") + if host is None and cloud_id is None: + raise ValueError("Missing host or cloud_id in provider config") return values @@ -63,9 +68,9 @@ def __initialize_client(self) -> Elasticsearch: """ Initialize the ElasticSearch client for the provider. """ - api_key = self.config.authentication.get("api_key") - host = self.config.authentication.get("host") - cloud_id = self.config.authentication.get("cloud_id") + api_key = self.authentication_config.api_key + host = self.authentication_config.host + cloud_id = self.authentication_config.cloud_id # Elastic.co requires you to connect with cloud_id if cloud_id: @@ -84,17 +89,9 @@ def validate_config(self): """ Validate the provider config. """ - if not self.config.authentication.get( - "host" - ) and not self.config.authentication.get("cloud_id"): - raise ProviderConfigException( - "Missing host or cloud_id in provider config", - provider_id=self.provider_id, - ) - if "api_key" not in self.config.authentication: - raise ProviderConfigException( - "Missing api_key in provider config", provider_id=self.provider_id - ) + self.authentication_config = ElasticProviderAuthConfig( + **self.config.authentication + ) @staticmethod def get_neccessary_config_keys(): diff --git a/keep/providers/slack_provider/slack_provider.py b/keep/providers/slack_provider/slack_provider.py index e5cafb8d4..5cb875764 100644 --- a/keep/providers/slack_provider/slack_provider.py +++ b/keep/providers/slack_provider/slack_provider.py @@ -13,17 +13,18 @@ from keep.exceptions.provider_exception import ProviderException from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig +from keep.validation.fields import HttpsUrl @pydantic.dataclasses.dataclass class SlackProviderAuthConfig: """Slack authentication configuration.""" - webhook_url: str = dataclasses.field( + webhook_url: HttpsUrl = dataclasses.field( metadata={ "required": True, "description": "Slack Webhook Url", - "sensitive": True, + "validation": "https_url", }, default="", ) diff --git a/keep/providers/victoriametrics_provider/victoriametrics_provider.py b/keep/providers/victoriametrics_provider/victoriametrics_provider.py index 18e0eaec3..6f93e171e 100644 --- a/keep/providers/victoriametrics_provider/victoriametrics_provider.py +++ b/keep/providers/victoriametrics_provider/victoriametrics_provider.py @@ -12,6 +12,7 @@ from keep.contextmanager.contextmanager import ContextManager from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig, ProviderScope +from keep.validation.fields import UrlPort class ResourceAlreadyExists(Exception): @@ -33,11 +34,12 @@ class VictoriametricsProviderAuthConfig: }, ) - VMAlertPort: int = dataclasses.field( + VMAlertPort: UrlPort = dataclasses.field( metadata={ "required": True, "description": "The port number on which VMAlert is listening. This should match the port configured in your VMAlert setup.", "hint": "Example: 8880 (if VMAlert is set to listen on port 8880)", + "validation": "port" }, ) From d28ba1a9a400d1768c242b33184ac7f03e180d1d Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Tue, 22 Oct 2024 23:50:59 +0100 Subject: [PATCH 07/35] add backend provider validation & switch input --- keep-ui/app/providers/provider-form.tsx | 53 ++++++++++++++- keep-ui/app/providers/providers.tsx | 2 +- .../auth0_provider/auth0_provider.py | 18 ++--- .../elastic_provider/elastic_provider.py | 9 +-- .../google_chat_provider.py | 13 ++-- .../grafana_incident_provider.py | 66 ++++++++++++------- .../grafana_oncall_provider.py | 4 +- .../ilert_provider/ilert_provider.py | 18 ++++- keep/providers/jira_provider/jira_provider.py | 13 +++- .../kibana_provider/kibana_provider.py | 5 +- .../kubernetes_provider.py | 9 +-- .../openobserve_provider.py | 2 +- .../slack_provider/slack_provider.py | 1 - 13 files changed, 148 insertions(+), 65 deletions(-) diff --git a/keep-ui/app/providers/provider-form.tsx b/keep-ui/app/providers/provider-form.tsx index 08fb13681..829b92b6f 100644 --- a/keep-ui/app/providers/provider-form.tsx +++ b/keep-ui/app/providers/provider-form.tsx @@ -26,6 +26,7 @@ import { AccordionHeader, AccordionBody, Badge, + Switch, } from "@tremor/react"; import { ExclamationCircleIcon, @@ -144,9 +145,14 @@ function getInitialFormValues(provider: Provider) { ...provider.details.authentication, }); - // Set default values for select inputs + // Set default values for select & switch inputs Object.entries(provider.config).forEach(([field, config]) => { - if (config.type === "select" && config.default && !initialValues[field]) { + if ( + config.type && + ["select", "switch"].includes(config.type) && + config.default !== null && + !initialValues[field] + ) { initialValues[field] = config.default; } }); @@ -159,7 +165,8 @@ function getZodSchema(fields: Provider["config"]) { const portError = "Invalid port number"; const emptyStringToNull = z .string() - .transform((val) => (val.length === 0 ? null : val)); + .optional() + .transform((val) => (val?.length === 0 ? null : val)); const kvPairs = Object.entries(fields).map(([field, config]) => { if (config.type === "form") { const baseFormSchema = z.record(z.string(), z.string()).array(); @@ -193,6 +200,13 @@ function getZodSchema(fields: Provider["config"]) { return [field, fileSchema]; } + if (config.type === "switch") { + const switchSchema = config.required + ? z.boolean() + : z.boolean().optional(); + return [field, switchSchema]; + } + const urlSchema = z .string({ required_error }) .url({ message: "Please provide a valid url, e.g https://example.com" }); @@ -964,6 +978,16 @@ function FormField({ onChange={handleInputChange} /> ); + case "switch": + return ( + onChange(id, value)} + /> + ); default: return ( void; +}) { + if (typeof value !== "boolean") return null; + + return ( +
+ + +
+ ); +} + function FieldLabel({ id, config, diff --git a/keep-ui/app/providers/providers.tsx b/keep-ui/app/providers/providers.tsx index 3911cfe2a..4c06b24a6 100644 --- a/keep-ui/app/providers/providers.tsx +++ b/keep-ui/app/providers/providers.tsx @@ -9,7 +9,7 @@ export interface ProviderAuthConfig { options?: Array; sensitive?: boolean; hidden?: boolean; - type?: "select" | "form" | "file"; + type?: "select" | "form" | "file" | "switch"; file_type?: string; config_main_group?: string; config_sub_group?: string; diff --git a/keep/providers/auth0_provider/auth0_provider.py b/keep/providers/auth0_provider/auth0_provider.py index 32e08ad65..9b9eff405 100644 --- a/keep/providers/auth0_provider/auth0_provider.py +++ b/keep/providers/auth0_provider/auth0_provider.py @@ -20,21 +20,21 @@ class Auth0ProviderAuthConfig: Auth0 authentication configuration. """ - token: str = dataclasses.field( - default=None, + domain: HttpsUrl = dataclasses.field( metadata={ "required": True, - "description": "Auth0 API Token", - "hint": "https://manage.auth0.com/dashboard/us/YOUR_ACCOUNT/apis/management/explorer", + "description": "Auth0 Domain", + "hint": "https://tenantname.us.auth0.com", + "validation": "https_url", }, ) - domain: HttpsUrl = dataclasses.field( + + token: str = dataclasses.field( default=None, metadata={ "required": True, - "description": "Auth0 Domain", - "hint": "tenantname.us.auth0.com", - "validation": "https_url", + "description": "Auth0 API Token", + "hint": "https://manage.auth0.com/dashboard/us/YOUR_ACCOUNT/apis/management/explorer", }, ) @@ -58,7 +58,7 @@ def validate_config(self): """ if self.is_installed or self.is_provisioned: host = self.config.authentication["domain"] - host = "https://" + host if not (host.starts_with("https://")) else host + host = "https://" + host if not host.startswith("https://") else host self.config.authentication["domain"] = host self.authentication_config = Auth0ProviderAuthConfig( diff --git a/keep/providers/elastic_provider/elastic_provider.py b/keep/providers/elastic_provider/elastic_provider.py index 06d8c50c9..d414c578a 100644 --- a/keep/providers/elastic_provider/elastic_provider.py +++ b/keep/providers/elastic_provider/elastic_provider.py @@ -4,6 +4,7 @@ import dataclasses import json +import typing import pydantic from elasticsearch import Elasticsearch @@ -26,16 +27,16 @@ class ElasticProviderAuthConfig: "sensitive": True, } ) - host: pydantic.HttpUrl = dataclasses.field( - default="", + host: typing.Optional[pydantic.HttpUrl] = dataclasses.field( + default=None, metadata={ "required": False, "description": "Elasticsearch host", "validation": "http_url", }, ) - cloud_id: str = dataclasses.field( - default="", + cloud_id: typing.Optional[str] = dataclasses.field( + default=None, metadata={"required": False, "description": "Elasticsearch cloud id"}, ) diff --git a/keep/providers/google_chat_provider/google_chat_provider.py b/keep/providers/google_chat_provider/google_chat_provider.py index ba27d4624..a09215dd6 100644 --- a/keep/providers/google_chat_provider/google_chat_provider.py +++ b/keep/providers/google_chat_provider/google_chat_provider.py @@ -1,26 +1,28 @@ +import dataclasses import os + import pydantic -import dataclasses import requests from keep.contextmanager.contextmanager import ContextManager from keep.exceptions.provider_exception import ProviderException from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig +from keep.validation.fields import HttpsUrl @pydantic.dataclasses.dataclass class GoogleChatProviderAuthConfig: """Google Chat authentication configuration.""" - webhook_url: str = dataclasses.field( + webhook_url: HttpsUrl = dataclasses.field( metadata={ "name": "webhook_url", "description": "Google Chat Webhook Url", "required": True, "sensitive": True, + "validation": "https_url", }, - default="", ) @@ -31,7 +33,7 @@ class GoogleChatProvider(BaseProvider): PROVIDER_TAGS = ["messaging"] def __init__( - self, context_manager: ContextManager, provider_id: str, config: ProviderConfig + self, context_manager: ContextManager, provider_id: str, config: ProviderConfig ): super().__init__(context_manager, provider_id, config) @@ -40,9 +42,6 @@ def validate_config(self): **self.config.authentication ) - if not self.authentication_config.webhook_url: - raise ProviderException("Google Chat webhook URL is required") - def dispose(self): """ No need to dispose of anything, so just do nothing. diff --git a/keep/providers/grafana_incident_provider/grafana_incident_provider.py b/keep/providers/grafana_incident_provider/grafana_incident_provider.py index c8474e9b0..68a5c6b1e 100644 --- a/keep/providers/grafana_incident_provider/grafana_incident_provider.py +++ b/keep/providers/grafana_incident_provider/grafana_incident_provider.py @@ -3,27 +3,30 @@ """ import dataclasses -import pydantic +from urllib.parse import urljoin +import pydantic import requests -from urllib.parse import urljoin - from keep.api.models.alert import AlertDto, AlertSeverity, AlertStatus from keep.contextmanager.contextmanager import ContextManager from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig, ProviderScope +from keep.validation.fields import HttpsUrl + @pydantic.dataclasses.dataclass class GrafanaIncidentProviderAuthConfig: """ GrafanaIncidentProviderAuthConfig is a class that allows to authenticate in Grafana Incident. """ - host_url: str = dataclasses.field( + + host_url: HttpsUrl = dataclasses.field( metadata={ "required": True, "description": "Grafana Host URL", "sensitive": False, + "validation": "https_url", }, default=None, ) @@ -37,6 +40,7 @@ class GrafanaIncidentProviderAuthConfig: default=None, ) + class GrafanaIncidentProvider(BaseProvider): PROVIDER_DISPLAY_NAME = "Grafana Incident" PROVIDER_TAGS = ["alert"] @@ -55,19 +59,16 @@ class GrafanaIncidentProvider(BaseProvider): "Minor": AlertSeverity.LOW, } - STATUS_MAP = { - "active": AlertStatus.FIRING, - "resolved": AlertStatus.RESOLVED - } + STATUS_MAP = {"active": AlertStatus.FIRING, "resolved": AlertStatus.RESOLVED} def __init__( - self, context_manager: ContextManager, provider_id: str, config: ProviderConfig + self, context_manager: ContextManager, provider_id: str, config: ProviderConfig ): super().__init__(context_manager, provider_id, config) def dispose(self): pass - + def validate_config(self): """ Validate the configuration of the provider. @@ -91,49 +92,61 @@ def validate_scopes(self) -> dict[str, bool | str]: """ try: response = requests.post( - urljoin(self.authentication_config.host_url, "/api/plugins/grafana-incident-app/resources/api/v1/IncidentsService.QueryIncidentPreviews"), + urljoin( + self.authentication_config.host_url, + "/api/plugins/grafana-incident-app/resources/api/v1/IncidentsService.QueryIncidentPreviews", + ), headers=self.__get_headers(), json={ "query": { "limit": 10, "orderDirection": "DESC", - "orderField": "createdTime" + "orderField": "createdTime", } - } + }, ) if response.status_code == 200: return {"authenticated": True} else: self.logger.error(f"Failed to validate scopes: {response.status_code}") - scopes = {"authenticated": f"Unable to query incidents: {response.status_code}"} + scopes = { + "authenticated": f"Unable to query incidents: {response.status_code}" + } except Exception as e: self.logger.error(f"Failed to validate scopes: {e}") scopes = {"authenticated": f"Unable to query incidents: {e}"} return scopes - + def _get_alerts(self) -> list[AlertDto]: """ Get the alerts from Grafana Incident. """ try: response = requests.post( - urljoin(self.authentication_config.host_url, "/api/plugins/grafana-incident-app/resources/api/v1/IncidentsService.QueryIncidentPreviews"), + urljoin( + self.authentication_config.host_url, + "/api/plugins/grafana-incident-app/resources/api/v1/IncidentsService.QueryIncidentPreviews", + ), headers=self.__get_headers(), json={ "query": { "limit": 10, "orderDirection": "DESC", - "orderField": "createdTime" + "orderField": "createdTime", } - } + }, ) if not response.ok: - self.logger.error(f"Failed to get incidents from grafana incident: {response.status_code}") - raise Exception(f"Failed to get incidents from grafana incident: {response.status_code} - {response.text}") - + self.logger.error( + f"Failed to get incidents from grafana incident: {response.status_code}" + ) + raise Exception( + f"Failed to get incidents from grafana incident: {response.status_code} - {response.text}" + ) + return [ AlertDto( id=incident["incidentID"], @@ -158,7 +171,7 @@ def _get_alerts(self) -> list[AlertDto]: incidentMembershipPreview=incident["incidentMembershipPreview"], fieldValues=incident["fieldValues"], version=incident["version"], - source=["grafana_incident"] + source=["grafana_incident"], ) for incident in response.json()["incidentPreviews"] ] @@ -166,7 +179,8 @@ def _get_alerts(self) -> list[AlertDto]: except Exception as e: self.logger.error(f"Failed to get incidents from grafana incident: {e}") raise Exception(f"Failed to get incidents from grafana incident: {e}") - + + if __name__ == "__main__": import logging @@ -182,8 +196,10 @@ def _get_alerts(self) -> list[AlertDto]: api_token = os.getenv("GRAFANA_SERVICE_ACCOUNT_TOKEN") if host_url is None or api_token is None: - raise Exception("GRAFANA_HOST_URL and GRAFANA_SERVICE_ACCOUNT_TOKEN environment variables are required") - + raise Exception( + "GRAFANA_HOST_URL and GRAFANA_SERVICE_ACCOUNT_TOKEN environment variables are required" + ) + config = ProviderConfig( description="Grafana Incident Provider", authentication={ diff --git a/keep/providers/grafana_oncall_provider/grafana_oncall_provider.py b/keep/providers/grafana_oncall_provider/grafana_oncall_provider.py index d264f77cb..fec1db322 100644 --- a/keep/providers/grafana_oncall_provider/grafana_oncall_provider.py +++ b/keep/providers/grafana_oncall_provider/grafana_oncall_provider.py @@ -1,6 +1,7 @@ """ Grafana Provider is a class that allows to ingest/digest data from Grafana. """ + import dataclasses import random from typing import Literal @@ -27,11 +28,12 @@ class GrafanaOncallProviderAuthConfig: "hint": "Grafana OnCall API Token", }, ) - host: str = dataclasses.field( + host: pydantic.AnyHttpUrl = dataclasses.field( metadata={ "required": True, "description": "Grafana OnCall Host", "hint": "E.g. https://keephq.grafana.net", + "validation": "any_http_url", }, ) diff --git a/keep/providers/ilert_provider/ilert_provider.py b/keep/providers/ilert_provider/ilert_provider.py index 748e67c63..f5a5071cd 100644 --- a/keep/providers/ilert_provider/ilert_provider.py +++ b/keep/providers/ilert_provider/ilert_provider.py @@ -16,6 +16,7 @@ from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig, ProviderScope from keep.providers.providers_factory import ProvidersFactory +from keep.validation.fields import HttpsUrl class IlertIncidentStatus(str, enum.Enum): @@ -43,11 +44,12 @@ class IlertProviderAuthConfig: "sensitive": True, } ) - ilert_host: str = dataclasses.field( + ilert_host: HttpsUrl = dataclasses.field( metadata={ "required": False, "description": "ILert API host", "hint": "https://api.ilert.com/api", + "validation": "https_url" }, default="https://api.ilert.com/api", ) @@ -103,15 +105,25 @@ def validate_scopes(self): for scope in self.PROVIDER_SCOPES: try: if scope.name == "read_permission": - requests.get( + res = requests.get( f"{self.authentication_config.ilert_host}/incidents", headers={ "Authorization": self.authentication_config.ilert_token }, ) + res.raise_for_status() scopes[scope.name] = True elif scope.name == "write_permission": - # TODO: find a way to validate write_permissions, for now it is always "validated" sucessfully. + res = requests.get( + f"{self.authentication_config.ilert_host}/users/current", + headers={ + "Authorization": self.authentication_config.ilert_token + }, + ) + res.raise_for_status() + data = res.json() + if data['role'] not in ["USER", "ADMIN"]: + scopes[scope.name] = "User role & permisisions may be limited." scopes[scope.name] = True except Exception as e: self.logger.warning( diff --git a/keep/providers/jira_provider/jira_provider.py b/keep/providers/jira_provider/jira_provider.py index b79216594..cef86fd4f 100644 --- a/keep/providers/jira_provider/jira_provider.py +++ b/keep/providers/jira_provider/jira_provider.py @@ -14,6 +14,7 @@ from keep.exceptions.provider_exception import ProviderException from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig, ProviderScope +from keep.validation.fields import HttpsUrl @pydantic.dataclasses.dataclass @@ -37,13 +38,14 @@ class JiraProviderAuthConfig: "documentation_url": "https://support.atlassian.com/atlassian-account/docs/manage-api-tokens-for-your-atlassian-account/#Create-an-API-token", } ) - host: str = dataclasses.field( + host: HttpsUrl = dataclasses.field( metadata={ "required": True, "description": "Atlassian Jira Host", "sensitive": False, "documentation_url": "https://support.atlassian.com/atlassian-account/docs/manage-api-tokens-for-your-atlassian-account/#Create-an-API-token", - "hint": "keephq.atlassian.net", + "hint": "https://keephq.atlassian.net", + "validation": "https_url" } ) @@ -151,6 +153,11 @@ def validate_scopes(self): return scopes def validate_config(self): + if self.is_installed or self.is_provisioned: + host = self.config.authentication['host'] + host = "https://" + host if not host.startswith("https://") else host + self.config.authentication['host'] = host + self.authentication_config = JiraProviderAuthConfig( **self.config.authentication ) @@ -483,7 +490,7 @@ def _query(self, ticket_id="", board_id="", **kwargs: dict): """ if not ticket_id: request_url = ( - f"https://{self.jira_host}/rest/agile/1.0/board/{board_id}/issue" + f"{self.jira_host}/rest/agile/1.0/board/{board_id}/issue" ) response = requests.get(request_url, auth=self.__get_auth(), verify=False) if not response.ok: diff --git a/keep/providers/kibana_provider/kibana_provider.py b/keep/providers/kibana_provider/kibana_provider.py index 1a290c28f..578dba332 100644 --- a/keep/providers/kibana_provider/kibana_provider.py +++ b/keep/providers/kibana_provider/kibana_provider.py @@ -12,7 +12,6 @@ import pydantic import requests from fastapi import HTTPException -from pydantic import AnyHttpUrl from keep.api.models.alert import AlertDto, AlertSeverity, AlertStatus from keep.contextmanager.contextmanager import ContextManager @@ -33,7 +32,7 @@ class KibanaProviderAuthConfig: "sensitive": True, } ) - kibana_host: AnyHttpUrl = dataclasses.field( + kibana_host: pydantic.AnyHttpUrl = dataclasses.field( metadata={ "required": True, "description": "Kibana Host", @@ -444,7 +443,7 @@ def setup_webhook( def validate_config(self): if self.is_installed or self.is_provisioned: host = self.config.authentication['kibana_host'] - host = "https://" + host if not (host.starts_with("http://") or host.starts_with("https://")) else host + host = "https://" + host if not (host.startswith("http://") or host.startswith("https://")) else host self.config.authentication['kibana_host'] = host self.authentication_config = KibanaProviderAuthConfig( diff --git a/keep/providers/kubernetes_provider/kubernetes_provider.py b/keep/providers/kubernetes_provider/kubernetes_provider.py index 2dd85709b..703bdecd7 100644 --- a/keep/providers/kubernetes_provider/kubernetes_provider.py +++ b/keep/providers/kubernetes_provider/kubernetes_provider.py @@ -1,13 +1,13 @@ -import pydantic import dataclasses +import datetime +import pydantic from kubernetes import client from kubernetes.client.rest import ApiException -import datetime from keep.contextmanager.contextmanager import ContextManager from keep.providers.base.base_provider import BaseProvider -from keep.providers.models.provider_config import ProviderScope, ProviderConfig +from keep.providers.models.provider_config import ProviderConfig, ProviderScope @pydantic.dataclasses.dataclass @@ -36,9 +36,10 @@ class KubernetesProviderAuthConfig: default=True, metadata={ "name": "insecure", - "description": "Whether to skip tls verification (default: True)", + "description": "Skip TLS verification", "required": False, "sensitive": False, + "type": "switch" }, ) diff --git a/keep/providers/openobserve_provider/openobserve_provider.py b/keep/providers/openobserve_provider/openobserve_provider.py index 834a4bdc7..2ef298e9a 100644 --- a/keep/providers/openobserve_provider/openobserve_provider.py +++ b/keep/providers/openobserve_provider/openobserve_provider.py @@ -110,7 +110,7 @@ def validate_config(self): """ if self.is_installed or self.is_provisioned: host = self.config.authentication['openObserveHost'] - host = "https://" + host if not (host.starts_with("http://") or host.starts_with("https://")) else host + host = "https://" + host if not (host.startswith("http://") or host.startswith("https://")) else host self.config.authentication['openObserveHost'] = host self.authentication_config = OpenobserveProviderAuthConfig( diff --git a/keep/providers/slack_provider/slack_provider.py b/keep/providers/slack_provider/slack_provider.py index 5cb875764..573958d49 100644 --- a/keep/providers/slack_provider/slack_provider.py +++ b/keep/providers/slack_provider/slack_provider.py @@ -26,7 +26,6 @@ class SlackProviderAuthConfig: "description": "Slack Webhook Url", "validation": "https_url", }, - default="", ) access_token: str = dataclasses.field( metadata={ From 83d7ed0233aec8f6a47b641a732d2c2ca2b48e4b Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Thu, 24 Oct 2024 18:29:44 +0100 Subject: [PATCH 08/35] add validation for `any_url` - url with any scheme --- keep-ui/app/providers/provider-form.tsx | 27 +++++++++++++++---------- keep-ui/app/providers/providers.tsx | 8 +++++++- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/keep-ui/app/providers/provider-form.tsx b/keep-ui/app/providers/provider-form.tsx index 829b92b6f..d8ed04c8b 100644 --- a/keep-ui/app/providers/provider-form.tsx +++ b/keep-ui/app/providers/provider-form.tsx @@ -146,16 +146,13 @@ function getInitialFormValues(provider: Provider) { }); // Set default values for select & switch inputs - Object.entries(provider.config).forEach(([field, config]) => { - if ( - config.type && - ["select", "switch"].includes(config.type) && - config.default !== null && - !initialValues[field] - ) { + for (const [field, config] of Object.entries(provider.config)) { + if (field in initialValues) continue; + if (config.default === null) continue; + if (config.type && ["select", "switch"].includes(config.type)) { initialValues[field] = config.default; } - }); + } return initialValues; } @@ -207,9 +204,10 @@ function getZodSchema(fields: Provider["config"]) { return [field, switchSchema]; } - const urlSchema = z - .string({ required_error }) - .url({ message: "Please provide a valid url, e.g https://example.com" }); + const urlSchema = z.string({ required_error }).url({ + message: + "Please provide a valid url, with a scheme & hostname as required.", + }); const urlTldSchema = z.string().regex(new RegExp(/\.[a-z]{2,63}$/), { message: "Url must contain a valid TLD e.g .com, .io, .dev, .net", }); @@ -224,6 +222,13 @@ function getZodSchema(fields: Provider["config"]) { }) .and(urlTldSchema); + if (config.validation === "any_url") { + const anyUrlSchema = config.required + ? urlSchema + : emptyStringToNull.pipe(urlSchema.nullish()); + return [field, anyUrlSchema]; + } + if (config.validation === "any_http_url") { const anyHttpSchema = config.required ? baseAnyHttpSchema diff --git a/keep-ui/app/providers/providers.tsx b/keep-ui/app/providers/providers.tsx index 4c06b24a6..1e03bf640 100644 --- a/keep-ui/app/providers/providers.tsx +++ b/keep-ui/app/providers/providers.tsx @@ -2,7 +2,13 @@ export interface ProviderAuthConfig { description: string; hint?: string; placeholder?: string; - validation?: "any_http_url" | "http_url" | "https_url" | "port" | "tld"; + validation?: + | "any_url" + | "any_http_url" + | "http_url" + | "https_url" + | "port" + | "tld"; required?: boolean; value?: string; default: string | number | boolean | null; From f415058835285ff35185fea5e0939b3b77a7b342 Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Thu, 24 Oct 2024 18:30:16 +0100 Subject: [PATCH 09/35] add backend validation for 7 providers --- .../kafka_provider/kafka_provider.py | 2 -- .../kubernetes_provider.py | 3 ++- .../mattermost_provider.py | 7 +++--- .../mongodb_provider/mongodb_provider.py | 23 +++++++++++++++---- .../mysql_provider/mysql_provider.py | 8 +++++-- keep/providers/ntfy_provider/ntfy_provider.py | 3 ++- .../openshift_provider/openshift_provider.py | 15 ++++++------ 7 files changed, 40 insertions(+), 21 deletions(-) diff --git a/keep/providers/kafka_provider/kafka_provider.py b/keep/providers/kafka_provider/kafka_provider.py index d79a1c293..3277b6aaa 100644 --- a/keep/providers/kafka_provider/kafka_provider.py +++ b/keep/providers/kafka_provider/kafka_provider.py @@ -6,7 +6,6 @@ import logging import pydantic - # from confluent_kafka import Consumer, KafkaError, KafkaException from kafka import KafkaConsumer from kafka.errors import KafkaError, NoBrokersAvailable @@ -167,7 +166,6 @@ def dispose(self): def validate_config(self): """ Validates required configuration for Kafka provider. - """ self.authentication_config = KafkaProviderAuthConfig( **self.config.authentication diff --git a/keep/providers/kubernetes_provider/kubernetes_provider.py b/keep/providers/kubernetes_provider/kubernetes_provider.py index 703bdecd7..13a519612 100644 --- a/keep/providers/kubernetes_provider/kubernetes_provider.py +++ b/keep/providers/kubernetes_provider/kubernetes_provider.py @@ -14,13 +14,14 @@ class KubernetesProviderAuthConfig: """Kubernetes authentication configuration.""" - api_server: str = dataclasses.field( + api_server: pydantic.AnyHttpUrl = dataclasses.field( default=None, metadata={ "name": "api_server", "description": "The kubernetes api server url", "required": True, "sensitive": False, + "validation": "any_http_url" }, ) token: str = dataclasses.field( diff --git a/keep/providers/mattermost_provider/mattermost_provider.py b/keep/providers/mattermost_provider/mattermost_provider.py index 4a2d6b312..5fe6ee41f 100644 --- a/keep/providers/mattermost_provider/mattermost_provider.py +++ b/keep/providers/mattermost_provider/mattermost_provider.py @@ -13,11 +13,12 @@ class MattermostProviderAuthConfig: """Mattermost authentication configuration.""" - webhook_url: str = dataclasses.field( + webhook_url: pydantic.AnyHttpUrl = dataclasses.field( metadata={ "required": True, "description": "Mattermost Webhook Url", "sensitive": True, + "validation": "any_http_url", } ) @@ -36,8 +37,6 @@ def validate_config(self): self.authentication_config = MattermostProviderAuthConfig( **self.config.authentication ) - if not self.authentication_config.webhook_url: - raise Exception("Mattermost webhook URL is required") def dispose(self): """ @@ -59,7 +58,7 @@ def _notify(self, message="", blocks=[], channel="", **kwargs: dict): webhook_url = self.authentication_config.webhook_url payload = {"text": message, "blocks": blocks} # channel is currently bugged (and unnecessary, as a webhook url is already one per channel) and so it is ignored for now - #if channel: + # if channel: # payload["channel"] = channel response = requests.post(webhook_url, json=payload, verify=False) diff --git a/keep/providers/mongodb_provider/mongodb_provider.py b/keep/providers/mongodb_provider/mongodb_provider.py index 50bd7d939..76e7fc14e 100644 --- a/keep/providers/mongodb_provider/mongodb_provider.py +++ b/keep/providers/mongodb_provider/mongodb_provider.py @@ -10,17 +10,19 @@ from pymongo import MongoClient from keep.contextmanager.contextmanager import ContextManager +from keep.exceptions.provider_config_exception import ProviderConfigException from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig, ProviderScope @pydantic.dataclasses.dataclass class MongodbProviderAuthConfig: - host: str = dataclasses.field( + host: pydantic.AnyUrl = dataclasses.field( metadata={ "required": True, "description": "Mongo host_uri", - "hint": "any valid mongo host_uri like host:port, user:paassword@host:port?authSource", + "hint": "any valid mongo host_uri like mongodb://host:port, user:paassword@host:port?authSource", + "validation": "any_url", } ) username: str = dataclasses.field( @@ -77,7 +79,9 @@ def validate_scopes(self): """ try: client = self.__generate_client() - client.admin.command('ping') # will raise an exception if the server is not available + client.admin.command( + "ping" + ) # will raise an exception if the server is not available client.close() scopes = { "connect_to_server": True, @@ -118,7 +122,9 @@ def __generate_client(self): and k != "additional_options" # additional_options will go seperately and k != "database" } # database is not a valid mongo option - client = MongoClient(**client_conf, **additional_options, serverSelectionTimeoutMS=10000) # 10 seconds timeout + client = MongoClient( + **client_conf, **additional_options, serverSelectionTimeoutMS=10000 + ) # 10 seconds timeout return client def dispose(self): @@ -131,6 +137,15 @@ def validate_config(self): """ Validates required configuration for MongoDB's provider. """ + host = self.config.authentication["host"] + if host is None: + raise ProviderConfigException("Please provide a value for `host`") + host = ( + "mongodb://" + host + if not (host.startswith("mongodb://") or host.startwith("mongodb+srv://")) + else host + ) + self.authentication_config = MongodbProviderAuthConfig( **self.config.authentication ) diff --git a/keep/providers/mysql_provider/mysql_provider.py b/keep/providers/mysql_provider/mysql_provider.py index 5ace58a75..a96b335c2 100644 --- a/keep/providers/mysql_provider/mysql_provider.py +++ b/keep/providers/mysql_provider/mysql_provider.py @@ -21,8 +21,12 @@ class MysqlProviderAuthConfig: password: str = dataclasses.field( metadata={"required": True, "description": "MySQL password", "sensitive": True} ) - host: str = dataclasses.field( - metadata={"required": True, "description": "MySQL hostname"} + host: pydantic.AnyUrl = dataclasses.field( + metadata={ + "required": True, + "description": "MySQL hostname", + "validation": "any_url", + } ) database: str | None = dataclasses.field( metadata={"required": False, "description": "MySQL database name"}, default=None diff --git a/keep/providers/ntfy_provider/ntfy_provider.py b/keep/providers/ntfy_provider/ntfy_provider.py index 992f750cb..383b56657 100644 --- a/keep/providers/ntfy_provider/ntfy_provider.py +++ b/keep/providers/ntfy_provider/ntfy_provider.py @@ -30,12 +30,13 @@ class NtfyProviderAuthConfig: default=None, ) - host: str = dataclasses.field( + host: pydantic.AnyHttpUrl | None = dataclasses.field( metadata={ "required": False, "description": "Ntfy Host URL (For self-hosted Ntfy only)", "sensitive": False, "hint": "http://localhost:80", + "validation": "any_http_url", }, default=None, ) diff --git a/keep/providers/openshift_provider/openshift_provider.py b/keep/providers/openshift_provider/openshift_provider.py index acfd90f6f..5ba46faa9 100644 --- a/keep/providers/openshift_provider/openshift_provider.py +++ b/keep/providers/openshift_provider/openshift_provider.py @@ -1,9 +1,10 @@ -import pydantic -import openshift_client as oc -from openshift_client import OpenShiftPythonException, Context import dataclasses import traceback +import openshift_client as oc +import pydantic +from openshift_client import Context, OpenShiftPythonException + from keep.contextmanager.contextmanager import ContextManager from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig, ProviderScope @@ -13,17 +14,16 @@ class OpenshiftProviderAuthConfig: """Openshift authentication configuration.""" - api_server: str = dataclasses.field( - default=None, + api_server: pydantic.AnyHttpUrl = dataclasses.field( metadata={ "name": "api_server", "description": "The openshift api server url", "required": True, "sensitive": False, + "validation": "any_http_url" }, ) token: str = dataclasses.field( - default=None, metadata={ "name": "token", "description": "The openshift token", @@ -35,9 +35,10 @@ class OpenshiftProviderAuthConfig: default=False, metadata={ "name": "insecure", - "description": "Whether to skip tls verification", + "description": "Skip TLS verification", "required": False, "sensitive": False, + "type": "switch" }, ) From 7518ad2280ed0eb11296bab6108984a74404653c Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Sat, 26 Oct 2024 23:42:26 +0100 Subject: [PATCH 10/35] add backend validation for 7 providers --- .../opsgenie_provider/opsgenie_provider.py | 3 ++- .../postgres_provider/postgres_provider.py | 22 +++++++++---------- .../servicenow_provider.py | 11 +++------- keep/providers/smtp_provider/smtp_provider.py | 10 +++++---- .../squadcast_provider/squadcast_provider.py | 7 ++++-- keep/providers/ssh_provider/ssh_provider.py | 5 +++-- .../teams_provider/teams_provider.py | 4 +++- 7 files changed, 32 insertions(+), 30 deletions(-) diff --git a/keep/providers/opsgenie_provider/opsgenie_provider.py b/keep/providers/opsgenie_provider/opsgenie_provider.py index 231673dd5..512575954 100644 --- a/keep/providers/opsgenie_provider/opsgenie_provider.py +++ b/keep/providers/opsgenie_provider/opsgenie_provider.py @@ -16,7 +16,8 @@ class OpsgenieProviderAuthConfig: api_key: str = dataclasses.field( metadata={ "required": True, - "description": "Ops genie api key (https://support.atlassian.com/opsgenie/docs/api-key-management/)", + "description": "Ops genie api key", + "hint": "https://support.atlassian.com/opsgenie/docs/api-key-management/", "sensitive": True, }, ) diff --git a/keep/providers/postgres_provider/postgres_provider.py b/keep/providers/postgres_provider/postgres_provider.py index 1191f80db..1b188fbe7 100644 --- a/keep/providers/postgres_provider/postgres_provider.py +++ b/keep/providers/postgres_provider/postgres_provider.py @@ -11,6 +11,7 @@ from keep.contextmanager.contextmanager import ContextManager from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig, ProviderScope +from keep.validation.fields import UrlPort @pydantic.dataclasses.dataclass @@ -32,8 +33,13 @@ class PostgresProviderAuthConfig: metadata={"required": False, "description": "Postgres database name"}, default=None, ) - port: str | None = dataclasses.field( - default="5432", metadata={"required": False, "description": "Postgres port"} + port: UrlPort | None = dataclasses.field( + default=5432, + metadata={ + "required": False, + "description": "Postgres port", + "validation": "port", + }, ) @@ -104,11 +110,7 @@ def validate_config(self): **self.config.authentication ) - def _query( - self, - query: str, - **kwargs: dict - ) -> list | tuple: + def _query(self, query: str, **kwargs: dict) -> list | tuple: """ Executes a query against the Postgres database. @@ -135,11 +137,7 @@ def _query( # Close the database connection conn.close() - def _notify( - self, - query: str, - **kwargs - ): + def _notify(self, query: str, **kwargs): """ Notifies the Postgres database. """ diff --git a/keep/providers/servicenow_provider/servicenow_provider.py b/keep/providers/servicenow_provider/servicenow_provider.py index 46ed94447..eb243e42f 100644 --- a/keep/providers/servicenow_provider/servicenow_provider.py +++ b/keep/providers/servicenow_provider/servicenow_provider.py @@ -14,18 +14,20 @@ from keep.exceptions.provider_exception import ProviderException from keep.providers.base.base_provider import BaseTopologyProvider from keep.providers.models.provider_config import ProviderConfig, ProviderScope +from keep.validation.fields import HttpsUrl @pydantic.dataclasses.dataclass class ServicenowProviderAuthConfig: """ServiceNow authentication configuration.""" - service_now_base_url: str = dataclasses.field( + service_now_base_url: HttpsUrl = dataclasses.field( metadata={ "required": True, "description": "The base URL of the ServiceNow instance", "sensitive": False, "hint": "https://dev12345.service-now.com", + "validation": "https_url" } ) @@ -66,13 +68,6 @@ def __init__( ): super().__init__(context_manager, provider_id, config) - @property - def service_now_base_url(self): - # if not starts with http: - if not self.authentication_config.service_now_base_url.startswith("http"): - return f"https://{self.authentication_config.service_now_base_url}" - return self.authentication_config.service_now_base_url - def validate_scopes(self): """ Validates that the user has the required scopes to use the provider. diff --git a/keep/providers/smtp_provider/smtp_provider.py b/keep/providers/smtp_provider/smtp_provider.py index c0606349a..9b61a608f 100644 --- a/keep/providers/smtp_provider/smtp_provider.py +++ b/keep/providers/smtp_provider/smtp_provider.py @@ -4,15 +4,16 @@ import dataclasses import typing - -import pydantic -from smtplib import SMTP, SMTP_SSL from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText +from smtplib import SMTP, SMTP_SSL + +import pydantic from keep.contextmanager.contextmanager import ContextManager from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig, ProviderScope +from keep.validation.fields import UrlPort @pydantic.dataclasses.dataclass @@ -42,11 +43,12 @@ class SmtpProviderAuthConfig: } ) - smtp_port: int = dataclasses.field( + smtp_port: UrlPort = dataclasses.field( metadata={ "required": True, "description": "SMTP port", "config_main_group": "authentication", + "validation": "port" } ) diff --git a/keep/providers/squadcast_provider/squadcast_provider.py b/keep/providers/squadcast_provider/squadcast_provider.py index 412b40d40..a2675a461 100644 --- a/keep/providers/squadcast_provider/squadcast_provider.py +++ b/keep/providers/squadcast_provider/squadcast_provider.py @@ -1,6 +1,7 @@ """ SquadcastProvider is a class that implements the Squadcast API and allows creating incidents and notes. """ + import dataclasses import json @@ -12,6 +13,7 @@ from keep.exceptions.provider_config_exception import ProviderConfigException from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig, ProviderScope +from keep.validation.fields import HttpsUrl @pydantic.dataclasses.dataclass @@ -33,12 +35,13 @@ class SquadcastProviderAuthConfig: }, default=None, ) - webhook_url: str | None = dataclasses.field( + webhook_url: HttpsUrl | None = dataclasses.field( metadata={ "required": False, "description": "Incident webhook url", "hint": "https://support.squadcast.com/integrations/incident-webhook-incident-webhook-api", "sensitive": True, + "validation": "https_url", }, default=None, ) @@ -134,7 +137,7 @@ def _create_incidents( # append body to additional_json we are doing this way because we don't want to override the core body fields body = json.dumps({**json.loads(additional_json), **json.loads(body)}) - + return requests.post( self.authentication_config.webhook_url, data=body, headers=headers ) diff --git a/keep/providers/ssh_provider/ssh_provider.py b/keep/providers/ssh_provider/ssh_provider.py index 2b430a738..14dc56356 100644 --- a/keep/providers/ssh_provider/ssh_provider.py +++ b/keep/providers/ssh_provider/ssh_provider.py @@ -13,6 +13,7 @@ from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig from keep.providers.providers_factory import ProvidersFactory +from keep.validation.fields import UrlPort @pydantic.dataclasses.dataclass @@ -26,8 +27,8 @@ class SshProviderAuthConfig: user: str = dataclasses.field( metadata={"required": True, "description": "SSH user"} ) - port: int = dataclasses.field( - default=22, metadata={"required": False, "description": "SSH port"} + port: UrlPort = dataclasses.field( + default=22, metadata={"required": False, "description": "SSH port", "validation": "port"} ) pkey: typing.Optional[str] = dataclasses.field( default=None, diff --git a/keep/providers/teams_provider/teams_provider.py b/keep/providers/teams_provider/teams_provider.py index 1749677e2..739d57da8 100644 --- a/keep/providers/teams_provider/teams_provider.py +++ b/keep/providers/teams_provider/teams_provider.py @@ -10,17 +10,19 @@ from keep.exceptions.provider_exception import ProviderException from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig +from keep.validation.fields import HttpsUrl @pydantic.dataclasses.dataclass class TeamsProviderAuthConfig: """Teams authentication configuration.""" - webhook_url: str = dataclasses.field( + webhook_url: HttpsUrl = dataclasses.field( metadata={ "required": True, "description": "Teams Webhook Url", "sensitive": True, + "validation": "https_url" } ) From fa727060e3dc199e0c8e0574802f21d87d1a0f9a Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Sun, 27 Oct 2024 22:59:33 +0100 Subject: [PATCH 11/35] fix form value empty string in state bug --- keep-ui/app/providers/provider-form.tsx | 30 ++++++++++++------- .../uptimekuma_provider.py | 13 ++------ .../webhook_provider/webhook_provider.py | 3 +- 3 files changed, 24 insertions(+), 22 deletions(-) diff --git a/keep-ui/app/providers/provider-form.tsx b/keep-ui/app/providers/provider-form.tsx index d8ed04c8b..15e913709 100644 --- a/keep-ui/app/providers/provider-form.tsx +++ b/keep-ui/app/providers/provider-form.tsx @@ -402,15 +402,25 @@ const ProviderForm = ({ } function handleFormChange(key: string, value: ProviderFormValue) { - setFormValues((prev) => { - const prevValue = prev[key]; - const updatedValues = { - ...prev, - [key]: - Array.isArray(value) && Array.isArray(prevValue) ? [...value] : value, - }; - return updatedValues; - }); + if (typeof value === "string" && value.trim().length === 0) { + setFormValues((prev) => { + const updated = structuredClone(prev); + delete updated[key]; + return updated; + }); + } else { + setFormValues((prev) => { + const prevValue = prev[key]; + const updatedValues = { + ...prev, + [key]: + Array.isArray(value) && Array.isArray(prevValue) + ? [...value] + : value, + }; + return updatedValues; + }); + } if (Object.keys(inputErrors).includes(key) && value !== "") { const updatedInputErrors = { ...inputErrors }; @@ -1032,7 +1042,7 @@ function TextField({ type={config.sensitive ? "password" : "text"} id={id} name={id} - value={value?.toString()} + value={value?.toString() ?? ""} onChange={onChange} autoComplete="off" error={Boolean(error)} diff --git a/keep/providers/uptimekuma_provider/uptimekuma_provider.py b/keep/providers/uptimekuma_provider/uptimekuma_provider.py index 0179db1b2..865c468d2 100644 --- a/keep/providers/uptimekuma_provider/uptimekuma_provider.py +++ b/keep/providers/uptimekuma_provider/uptimekuma_provider.py @@ -9,7 +9,6 @@ from keep.api.models.alert import AlertDto, AlertStatus from keep.contextmanager.contextmanager import ContextManager -from keep.exceptions.provider_exception import ProviderException from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig, ProviderScope @@ -20,13 +19,13 @@ class UptimekumaProviderAuthConfig: UptimekumaProviderAuthConfig is a class that holds the authentication information for the UptimekumaProvider. """ - host_url: str = dataclasses.field( + host_url: pydantic.AnyHttpUrl = dataclasses.field( metadata={ "required": True, "description": "UptimeKuma Host URL", "sensitive": False, + "validation": "any_http_url" }, - default=None, ) username: str = dataclasses.field( @@ -35,7 +34,6 @@ class UptimekumaProviderAuthConfig: "description": "UptimeKuma Username", "sensitive": False, }, - default=None, ) password: str = dataclasses.field( @@ -44,7 +42,6 @@ class UptimekumaProviderAuthConfig: "description": "UptimeKuma Password", "sensitive": True, }, - default=None, ) @@ -89,12 +86,6 @@ def validate_config(self): self.authentication_config = UptimekumaProviderAuthConfig( **self.config.authentication ) - if self.authentication_config.host_url is None: - raise ProviderException("UptimeKuma Host URL is required") - if self.authentication_config.username is None: - raise ProviderException("UptimeKuma Username is required") - if self.authentication_config.password is None: - raise ProviderException("UptimeKuma Password is required") def _get_heartbeats(self): try: diff --git a/keep/providers/webhook_provider/webhook_provider.py b/keep/providers/webhook_provider/webhook_provider.py index aa6976579..97f41604e 100644 --- a/keep/providers/webhook_provider/webhook_provider.py +++ b/keep/providers/webhook_provider/webhook_provider.py @@ -22,10 +22,11 @@ class WebhookProviderAuthConfig: Webhook authentication configuration. """ - url: str = dataclasses.field( + url: pydantic.AnyHttpUrl = dataclasses.field( metadata={ "required": True, "description": "Webhook URL", + "validation": "any_http_url" } ) From 990ff831e64076b3fb91f62fc4a0e9ae02ae48fd Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Thu, 31 Oct 2024 19:55:27 +0100 Subject: [PATCH 12/35] remove `http_url` validation type --- keep/providers/centreon_provider/centreon_provider.py | 4 ++-- keep/providers/datadog_provider/datadog_provider.py | 5 +++-- keep/providers/elastic_provider/elastic_provider.py | 4 ++-- keep/providers/grafana_provider/grafana_provider.py | 11 ++--------- 4 files changed, 9 insertions(+), 15 deletions(-) diff --git a/keep/providers/centreon_provider/centreon_provider.py b/keep/providers/centreon_provider/centreon_provider.py index 3c02111d6..039c59b67 100644 --- a/keep/providers/centreon_provider/centreon_provider.py +++ b/keep/providers/centreon_provider/centreon_provider.py @@ -21,12 +21,12 @@ class CentreonProviderAuthConfig: CentreonProviderAuthConfig is a class that holds the authentication information for the CentreonProvider. """ - host_url: pydantic.HttpUrl = dataclasses.field( + host_url: pydantic.AnyHttpUrl | None = dataclasses.field( metadata={ "required": True, "description": "Centreon Host URL", "sensitive": False, - "validation": "http_url", + "validation": "any_http_url", }, default=None, ) diff --git a/keep/providers/datadog_provider/datadog_provider.py b/keep/providers/datadog_provider/datadog_provider.py index 03d91e3ab..c9dcdb3ce 100644 --- a/keep/providers/datadog_provider/datadog_provider.py +++ b/keep/providers/datadog_provider/datadog_provider.py @@ -38,6 +38,7 @@ from keep.providers.models.provider_config import ProviderConfig, ProviderScope from keep.providers.models.provider_method import ProviderMethod from keep.providers.providers_factory import ProvidersFactory +from keep.validation.fields import HttpsUrl logger = logging.getLogger(__name__) @@ -68,13 +69,13 @@ class DatadogProviderAuthConfig: }, default="", ) - domain: pydantic.HttpUrl = dataclasses.field( + domain: HttpsUrl = dataclasses.field( metadata={ "required": False, "description": "Datadog API domain", "sensitive": False, "hint": "https://api.datadoghq.com", - "validation": "http_url" + "validation": "https_url" }, default="https://api.datadoghq.com", ) diff --git a/keep/providers/elastic_provider/elastic_provider.py b/keep/providers/elastic_provider/elastic_provider.py index d414c578a..541929499 100644 --- a/keep/providers/elastic_provider/elastic_provider.py +++ b/keep/providers/elastic_provider/elastic_provider.py @@ -27,12 +27,12 @@ class ElasticProviderAuthConfig: "sensitive": True, } ) - host: typing.Optional[pydantic.HttpUrl] = dataclasses.field( + host: pydantic.AnyHttpUrl | None = dataclasses.field( default=None, metadata={ "required": False, "description": "Elasticsearch host", - "validation": "http_url", + "validation": "any_http_url", }, ) cloud_id: typing.Optional[str] = dataclasses.field( diff --git a/keep/providers/grafana_provider/grafana_provider.py b/keep/providers/grafana_provider/grafana_provider.py index 607f60a0d..d966f8a31 100644 --- a/keep/providers/grafana_provider/grafana_provider.py +++ b/keep/providers/grafana_provider/grafana_provider.py @@ -35,12 +35,12 @@ class GrafanaProviderAuthConfig: "sensitive": True, }, ) - host: pydantic.HttpUrl = dataclasses.field( + host: pydantic.AnyHttpUrl = dataclasses.field( metadata={ "required": True, "description": "Grafana host", "hint": "e.g. https://keephq.grafana.net", - "validation": "http_url" + "validation": "any_http_url" }, ) @@ -110,17 +110,10 @@ def dispose(self): def validate_config(self): """ Validates required configuration for Grafana provider. - """ self.authentication_config = GrafanaProviderAuthConfig( **self.config.authentication ) - if not self.authentication_config.host.startswith( - "https://" - ) and not self.authentication_config.host.startswith("http://"): - self.authentication_config.host = ( - f"https://{self.authentication_config.host}" - ) def validate_scopes(self) -> dict[str, bool | str]: headers = {"Authorization": f"Bearer {self.authentication_config.token}"} From 6e78687d35c973e5fdcb5754bd42fcb1e3de6f7c Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Thu, 31 Oct 2024 19:57:25 +0100 Subject: [PATCH 13/35] add validation tests for `any_http_url1 and `port` --- tests/e2e_tests/test_end_to_end.py | 50 ++++++++++++++++++++++++++++-- 1 file changed, 47 insertions(+), 3 deletions(-) diff --git a/tests/e2e_tests/test_end_to_end.py b/tests/e2e_tests/test_end_to_end.py index 7dcabfa25..5e9e4240a 100644 --- a/tests/e2e_tests/test_end_to_end.py +++ b/tests/e2e_tests/test_end_to_end.py @@ -9,8 +9,8 @@ # for mysql: docker compose --project-directory . -f tests/e2e_tests/docker-compose-e2e-mysql.yml up -d # for postgres: docker compose --project-directory . -f tests/e2e_tests/docker-compose-e2e-postgres.yml up -d # 2. Run the tests using pytest. -# e.g. poetry run coverage run --branch -m pytest -s tests/e2e_tests/ -# NOTE: to clean the database, run +# e.g. poetry run coverage run --branch -m pytest -s tests/e2e_tests/ +# NOTE: to clean the database, run # docker compose stop # docker compose --project-directory . -f tests/e2e_tests/docker-compose-e2e-mysql.yml down --volumes # docker compose --project-directory . -f tests/e2e_tests/docker-compose-e2e-postgres.yml down --volumes @@ -33,10 +33,11 @@ # - Spin up the environment using docker-compose. # - Run "playwright codegen localhost:3000" # - Copy the generated code to a new test function. -import re import string import sys +from playwright.sync_api import expect + # Running the tests in GitHub Actions: # - Look at the test-pr-e2e.yml file in the .github/workflows directory. @@ -150,3 +151,46 @@ def test_providers_page_is_accessible(browser): with open(current_test_name + ".html", "w") as f: f.write(browser.content()) raise + + +def test_provider_validation(browser): + """ + Test field validation for provider fields. + """ + browser.goto( + "http://localhost:3000/signin?callbackUrl=http%3A%2F%2Flocalhost%3A3000%2Fproviders" + ) + # using Kibana Provider + browser.goto("http://localhost:3000/providers") + browser.locator("button:has-text('Kibana'):has-text('alert')").click() + # test required fields + connect_btn = browser.get_by_role("button", name="Connect", exact=True) + connect_btn.click() + expect(browser.get_by_text("This field is required")).to_have_count(3) + # test `any_http_url` field validation + browser.get_by_placeholder("Enter provider name").fill("random name") + browser.get_by_placeholder("Enter api_key").fill("random api key") + browser.get_by_placeholder("Enter kibana_host").fill("invalid url") + connect_btn.click() + expect(browser.locator("p.tremor-TextInput-errorMessage")).to_have_count(1) + browser.get_by_placeholder("Enter kibana_host").fill("http://localhost") + connect_btn.click() + expect(browser.locator("p.tremor-TextInput-errorMessage")).to_be_hidden() + browser.get_by_placeholder("Enter kibana_host").fill( + "https://keep.kb.us-central1.gcp.cloud.es.io" + ) + connect_btn.click() + expect(browser.locator("p.tremor-TextInput-errorMessage")).to_be_hidden() + # test `port` field validation + browser.get_by_placeholder("Enter kibana_port").fill("invalid port") + connect_btn.click() + expect(browser.locator("p.tremor-TextInput-errorMessage")).to_have_count(1) + browser.get_by_placeholder("Enter kibana_port").fill("0") + connect_btn.click() + expect(browser.locator("p.tremor-TextInput-errorMessage")).to_have_count(1) + browser.get_by_placeholder("Enter kibana_port").fill("65_536") + connect_btn.click() + expect(browser.locator("p.tremor-TextInput-errorMessage")).to_have_count(1) + browser.get_by_placeholder("Enter kibana_port").fill("9243") + connect_btn.click() + expect(browser.locator("p.tremor-TextInput-errorMessage")).to_be_hidden() From a18e2fa852c095cdce81a7f4f6fcb392aa11d91f Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Fri, 1 Nov 2024 03:54:35 +0100 Subject: [PATCH 14/35] add validation tests for https_url, any_url & tld --- tests/e2e_tests/test_end_to_end.py | 68 ++++++++++++++++++++++++++---- 1 file changed, 60 insertions(+), 8 deletions(-) diff --git a/tests/e2e_tests/test_end_to_end.py b/tests/e2e_tests/test_end_to_end.py index 5e9e4240a..f74108df9 100644 --- a/tests/e2e_tests/test_end_to_end.py +++ b/tests/e2e_tests/test_end_to_end.py @@ -165,32 +165,84 @@ def test_provider_validation(browser): browser.locator("button:has-text('Kibana'):has-text('alert')").click() # test required fields connect_btn = browser.get_by_role("button", name="Connect", exact=True) + error_msg = browser.locator("p.tremor-TextInput-errorMessage") connect_btn.click() - expect(browser.get_by_text("This field is required")).to_have_count(3) + expect(error_msg).to_have_count(3) # test `any_http_url` field validation browser.get_by_placeholder("Enter provider name").fill("random name") browser.get_by_placeholder("Enter api_key").fill("random api key") browser.get_by_placeholder("Enter kibana_host").fill("invalid url") connect_btn.click() - expect(browser.locator("p.tremor-TextInput-errorMessage")).to_have_count(1) + expect(error_msg).to_have_count(1) browser.get_by_placeholder("Enter kibana_host").fill("http://localhost") connect_btn.click() - expect(browser.locator("p.tremor-TextInput-errorMessage")).to_be_hidden() + expect(error_msg).to_be_hidden() browser.get_by_placeholder("Enter kibana_host").fill( "https://keep.kb.us-central1.gcp.cloud.es.io" ) connect_btn.click() - expect(browser.locator("p.tremor-TextInput-errorMessage")).to_be_hidden() + expect(error_msg).to_be_hidden() # test `port` field validation browser.get_by_placeholder("Enter kibana_port").fill("invalid port") connect_btn.click() - expect(browser.locator("p.tremor-TextInput-errorMessage")).to_have_count(1) + expect(error_msg).to_have_count(1) browser.get_by_placeholder("Enter kibana_port").fill("0") connect_btn.click() - expect(browser.locator("p.tremor-TextInput-errorMessage")).to_have_count(1) + expect(error_msg).to_have_count(1) browser.get_by_placeholder("Enter kibana_port").fill("65_536") connect_btn.click() - expect(browser.locator("p.tremor-TextInput-errorMessage")).to_have_count(1) + expect(error_msg).to_have_count(1) browser.get_by_placeholder("Enter kibana_port").fill("9243") connect_btn.click() - expect(browser.locator("p.tremor-TextInput-errorMessage")).to_be_hidden() + expect(error_msg).to_be_hidden() + + # using Teams Provider + browser.goto("http://localhost:3000/providers") + browser.locator("button:has-text('Teams'):has-text('messaging')").click() + # test `https_url` field validation + browser.get_by_placeholder("Enter provider name").fill("random name") + browser.get_by_placeholder("Enter webhook_url").fill("random url") + connect_btn.click() + expect(error_msg).to_have_count(1) + browser.get_by_placeholder("Enter webhook_url").fill("http://localhost") + connect_btn.click() + expect(error_msg).to_have_count(1) + browser.get_by_placeholder("Enter webhook_url").fill("http://example.com") + connect_btn.click() + expect(error_msg).to_have_count(1) + browser.get_by_placeholder("Enter webhook_url").fill("https://example.com") + connect_btn.click() + expect(error_msg).to_be_hidden() + + # using Site24x7 Provider + browser.goto("http://localhost:3000/providers") + browser.locator("button:has-text('Site24x7'):has-text('alert')").click() + # test `tld` field validation + browser.get_by_placeholder("Enter provider name").fill("random name") + browser.get_by_placeholder("Enter zohoRefreshToken").fill("random") + browser.get_by_placeholder("Enter zohoClientId").fill("random") + browser.get_by_placeholder("Enter zohoClientSecret").fill("random") + browser.get_by_placeholder("Enter zohoAccountTLD").fill("") + connect_btn.click() + expect(error_msg).to_have_count(1) + browser.get_by_placeholder("Enter zohoAccountTLD").fill("random") + connect_btn.click() + expect(error_msg).to_have_count(1) + browser.get_by_placeholder("Enter zohoAccountTLD").fill(".com") + connect_btn.click() + expect(error_msg).to_be_hidden() + + # using MongoDB Provider + browser.goto("http://localhost:3000/providers") + browser.locator("button:has-text('MongoDB'):has-text('data')").click() + # test `any_url` field validation + browser.get_by_placeholder("Enter provider name").fill("random name") + browser.get_by_placeholder("Enter host").fill("random") + connect_btn.click() + expect(error_msg).to_have_count(1) + browser.get_by_placeholder("Enter host").fill("host.com:5000") + connect_btn.click() + expect(error_msg).to_have_count(1) + browser.get_by_placeholder("Enter host").fill("mongodb://host.com:3000") + connect_btn.click() + expect(error_msg).to_be_hidden() From c7b939ffafb1976166c4eed0e742b145d66b5c4e Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Sat, 2 Nov 2024 22:19:52 +0100 Subject: [PATCH 15/35] add provider form validation file --- keep-ui/app/providers/form-validation.ts | 232 +++++++++++++++++++++++ keep-ui/app/providers/provider-form.tsx | 133 +------------ 2 files changed, 233 insertions(+), 132 deletions(-) create mode 100644 keep-ui/app/providers/form-validation.ts diff --git a/keep-ui/app/providers/form-validation.ts b/keep-ui/app/providers/form-validation.ts new file mode 100644 index 000000000..9b29921ed --- /dev/null +++ b/keep-ui/app/providers/form-validation.ts @@ -0,0 +1,232 @@ +import { z } from "zod"; +import { Provider } from "./providers"; + +type UrlOptions = { + protocols: string[]; + requireTld: boolean; + requireProtocol: boolean; + requirePort: boolean; + validateLength: boolean; + maxLength: number; +}; + +type ValidatorRes = { success: true } | { success: false; msg: string }; + +const defaultUrlOptions: UrlOptions = { + protocols: [], + requireTld: false, + requireProtocol: true, + requirePort: false, + validateLength: true, + maxLength: 2 ** 16, +}; + +function mergeOptions>( + defaults: T, + opts?: T +) { + if (!opts) return defaults; + for (const key in defaults) { + if (typeof opts[key] === "undefined") { + opts[key] = defaults[key]; + } + } + return opts; +} + +const error = (msg: string) => ({ success: false, msg }); +const urlError = { success: false, msg: "Please provide a valid URL." }; +const protocolError = { + success: false, + msg: "A valid URL protocol is required.", +}; +const relProtocolError = { + success: false, + msg: "A protocol-relavie URL is not allowed.", +}; + +function getProtocolError(opts: UrlOptions["protocols"]) { + if (opts.length === 0) return protocolError; + if (opts.length === 1) return error(`A URL with \`${opts[0]}\` is required.`); + if (opts.length === 2) + return error(`A URL with \`${opts[0]}\` or \`${opts[1]}\` is required.`); + const lst = opts.length - 1; + const wrap = (x: string, y: string) => `\`${x} + ${y}\``; + const optsStr = opts.reduce( + (acc, p, i) => + i === 0 + ? wrap(acc, p) + : i === lst + ? wrap(acc, `or ${p}`) + : wrap(acc, `, ${p}`), + "" + ); + return error(`A URL with one of ${optsStr} is required.`); +} + +function isUrl(url: string, options?: UrlOptions): ValidatorRes { + const opts = mergeOptions(defaultUrlOptions, options); + + if (url.length === 0 || /[\s<>]/.test(url)) return urlError; + if (opts.validateLength && url.length > opts.maxLength) { + return { + success: false, + msg: `Invalid url length, max of ${opts.maxLength} expected.`, + }; + } + + let _url = url; + let protocol: string; + let host: string; + let hostname: string; + let port: number; + let portStr: string | null; + let split: string[]; + let ipv6: string | null; + + split = url.split("#"); + _url = split.shift() ?? ""; + + split = url.split("?"); + _url = split.shift() ?? ""; + + if (_url.slice(0, 2) === "//") return relProtocolError; + + split = url.split("://"); + protocol = split?.shift()?.toLowerCase() ?? ""; + if (opts.requireProtocol && opts.protocols.indexOf(protocol) === -1) + return getProtocolError(opts.protocols); + + return { success: true }; +} + +export function getZodSchema(fields: Provider["config"]) { + const required_error = "This field is required"; + const portError = "Invalid port number"; + const emptyStringToNull = z + .string() + .optional() + .transform((val) => (val?.length === 0 ? null : val)); + const kvPairs = Object.entries(fields).map(([field, config]) => { + if (config.type === "form") { + const baseFormSchema = z.record(z.string(), z.string()).array(); + const formSchema = config.required + ? baseFormSchema.nonempty({ + message: "At least one key-value entry should be provided.", + }) + : baseFormSchema.optional(); + return [field, formSchema]; + } + + if (config.type === "file") { + const baseFileSchema = z + .instanceof(File, { message: "Please upload a file here." }) + .refine( + (file) => { + if (config.file_type == undefined) return true; + if (config.file_type.length <= 1) return true; + return config.file_type.includes(file.type); + }, + { + message: + config.file_type && config.file_type?.split(",").length > 1 + ? `File type should be one of ${config.file_type}.` + : `File should be of type ${config.file_type}.`, + } + ); + const fileSchema = config.required + ? baseFileSchema + : baseFileSchema.optional(); + return [field, fileSchema]; + } + + if (config.type === "switch") { + const switchSchema = config.required + ? z.boolean() + : z.boolean().optional(); + return [field, switchSchema]; + } + + const urlSchema = z.string({ required_error }).url({ + message: + "Please provide a valid url, with a scheme & hostname as required.", + }); + const urlTldSchema = z.string().regex(new RegExp(/\.[a-z]{2,63}$/), { + message: "Url must contain a valid TLD e.g .com, .io, .dev, .net", + }); + const baseAnyHttpSchema = urlSchema.refine( + (url) => url.startsWith("http://") || url.startsWith("https://"), + { message: "A url with `http` or `https` protocol is reuquired." } + ); + const baseHttpSchema = baseAnyHttpSchema.and(urlTldSchema); + const baseHttpsSchema = urlSchema + .refine((url) => url.startsWith("https://"), { + message: "A url with `https` protocol is required.", + }) + .and(urlTldSchema); + + if (config.validation === "any_url") { + const anyUrlSchema = config.required + ? urlSchema + : emptyStringToNull.pipe(urlSchema.nullish()); + return [field, anyUrlSchema]; + } + + if (config.validation === "any_http_url") { + const anyHttpSchema = config.required + ? baseAnyHttpSchema + : emptyStringToNull.pipe(baseAnyHttpSchema.nullish()); + return [field, anyHttpSchema]; + } + + if (config.validation === "http_url") { + const httpSchema = config.required + ? baseHttpSchema + : emptyStringToNull.pipe(baseHttpSchema.nullish()); + return [field, httpSchema]; + } + if (config.validation === "https_url") { + const httpsSchema = config.required + ? baseHttpsSchema + : emptyStringToNull.pipe(baseHttpsSchema.nullish()); + return [field, httpsSchema]; + } + if (config.validation === "tld") { + const baseTldSchema = z + .string({ required_error }) + .regex(new RegExp(/\.[a-z]{2,63}$/), { + message: "Please provide a valid TLD e.g .com, .io, .dev, .net", + }); + const tldSchema = config.required + ? baseTldSchema + : baseTldSchema.optional(); + return [field, tldSchema]; + } + if (config.validation === "port") { + const basePortSchema = z.coerce + .number({ required_error, invalid_type_error: portError }) + .min(1, { message: portError }) + .max(65_535, { message: portError }); + const portSchema = config.required + ? basePortSchema + : emptyStringToNull.pipe(basePortSchema.nullish()); + return [field, portSchema]; + } + return [ + field, + config.required + ? z + .string({ required_error }) + .trim() + .min(1, { message: required_error }) + : z.string().optional(), + ]; + }); + return z.object({ + provider_name: z + .string({ required_error }) + .trim() + .min(1, { message: required_error }), + ...Object.fromEntries(kvPairs), + }); +} diff --git a/keep-ui/app/providers/provider-form.tsx b/keep-ui/app/providers/provider-form.tsx index 15e913709..8e3085647 100644 --- a/keep-ui/app/providers/provider-form.tsx +++ b/keep-ui/app/providers/provider-form.tsx @@ -52,7 +52,7 @@ import { useSearchParams } from "next/navigation"; import "./provider-form.css"; import { toast } from "react-toastify"; import { useProviders } from "@/utils/hooks/useProviders"; -import { z } from "zod"; +import { getZodSchema } from "./form-validation"; type ProviderFormProps = { provider: Provider; @@ -157,137 +157,6 @@ function getInitialFormValues(provider: Provider) { return initialValues; } -function getZodSchema(fields: Provider["config"]) { - const required_error = "This field is required"; - const portError = "Invalid port number"; - const emptyStringToNull = z - .string() - .optional() - .transform((val) => (val?.length === 0 ? null : val)); - const kvPairs = Object.entries(fields).map(([field, config]) => { - if (config.type === "form") { - const baseFormSchema = z.record(z.string(), z.string()).array(); - const formSchema = config.required - ? baseFormSchema.nonempty({ - message: "At least one key-value entry should be provided.", - }) - : baseFormSchema.optional(); - return [field, formSchema]; - } - - if (config.type === "file") { - const baseFileSchema = z - .instanceof(File, { message: "Please upload a file here." }) - .refine( - (file) => { - if (config.file_type == undefined) return true; - if (config.file_type.length <= 1) return true; - return config.file_type.includes(file.type); - }, - { - message: - config.file_type && config.file_type?.split(",").length > 1 - ? `File type should be one of ${config.file_type}.` - : `File should be of type ${config.file_type}.`, - } - ); - const fileSchema = config.required - ? baseFileSchema - : baseFileSchema.optional(); - return [field, fileSchema]; - } - - if (config.type === "switch") { - const switchSchema = config.required - ? z.boolean() - : z.boolean().optional(); - return [field, switchSchema]; - } - - const urlSchema = z.string({ required_error }).url({ - message: - "Please provide a valid url, with a scheme & hostname as required.", - }); - const urlTldSchema = z.string().regex(new RegExp(/\.[a-z]{2,63}$/), { - message: "Url must contain a valid TLD e.g .com, .io, .dev, .net", - }); - const baseAnyHttpSchema = urlSchema.refine( - (url) => url.startsWith("http://") || url.startsWith("https://"), - { message: "A url with `http` or `https` protocol is reuquired." } - ); - const baseHttpSchema = baseAnyHttpSchema.and(urlTldSchema); - const baseHttpsSchema = urlSchema - .refine((url) => url.startsWith("https://"), { - message: "A url with `https` protocol is required.", - }) - .and(urlTldSchema); - - if (config.validation === "any_url") { - const anyUrlSchema = config.required - ? urlSchema - : emptyStringToNull.pipe(urlSchema.nullish()); - return [field, anyUrlSchema]; - } - - if (config.validation === "any_http_url") { - const anyHttpSchema = config.required - ? baseAnyHttpSchema - : emptyStringToNull.pipe(baseAnyHttpSchema.nullish()); - return [field, anyHttpSchema]; - } - - if (config.validation === "http_url") { - const httpSchema = config.required - ? baseHttpSchema - : emptyStringToNull.pipe(baseHttpSchema.nullish()); - return [field, httpSchema]; - } - if (config.validation === "https_url") { - const httpsSchema = config.required - ? baseHttpsSchema - : emptyStringToNull.pipe(baseHttpsSchema.nullish()); - return [field, httpsSchema]; - } - if (config.validation === "tld") { - const baseTldSchema = z - .string({ required_error }) - .regex(new RegExp(/\.[a-z]{2,63}$/), { - message: "Please provide a valid TLD e.g .com, .io, .dev, .net", - }); - const tldSchema = config.required - ? baseTldSchema - : baseTldSchema.optional(); - return [field, tldSchema]; - } - if (config.validation === "port") { - const basePortSchema = z.coerce - .number({ required_error, invalid_type_error: portError }) - .min(1, { message: portError }) - .max(65_535, { message: portError }); - const portSchema = config.required - ? basePortSchema - : emptyStringToNull.pipe(basePortSchema.nullish()); - return [field, portSchema]; - } - return [ - field, - config.required - ? z - .string({ required_error }) - .trim() - .min(1, { message: required_error }) - : z.string().optional(), - ]; - }); - return z.object({ - provider_name: z - .string({ required_error }) - .trim() - .min(1, { message: required_error }), - ...Object.fromEntries(kvPairs), - }); -} - const providerNameFieldConfig: ProviderAuthConfig = { required: true, description: "Provider Name", From 4e77c5c54f9d2c8e13bf13a1506d3734e22bf220 Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Mon, 4 Nov 2024 23:48:23 +0100 Subject: [PATCH 16/35] add new validation logic for url --- keep-ui/app/providers/form-validation.ts | 205 +++++++++++++++-------- 1 file changed, 136 insertions(+), 69 deletions(-) diff --git a/keep-ui/app/providers/form-validation.ts b/keep-ui/app/providers/form-validation.ts index 9b29921ed..d7dff4325 100644 --- a/keep-ui/app/providers/form-validation.ts +++ b/keep-ui/app/providers/form-validation.ts @@ -1,7 +1,7 @@ import { z } from "zod"; import { Provider } from "./providers"; -type UrlOptions = { +type URLOptions = { protocols: string[]; requireTld: boolean; requireProtocol: boolean; @@ -12,7 +12,7 @@ type UrlOptions = { type ValidatorRes = { success: true } | { success: false; msg: string }; -const defaultUrlOptions: UrlOptions = { +const defaultURLOptions: URLOptions = { protocols: [], requireTld: false, requireProtocol: true, @@ -23,81 +23,159 @@ const defaultUrlOptions: UrlOptions = { function mergeOptions>( defaults: T, - opts?: T -) { + opts?: Partial +): T { if (!opts) return defaults; - for (const key in defaults) { - if (typeof opts[key] === "undefined") { - opts[key] = defaults[key]; - } - } - return opts; + return { ...defaults, ...opts }; } const error = (msg: string) => ({ success: false, msg }); -const urlError = { success: false, msg: "Please provide a valid URL." }; -const protocolError = { - success: false, - msg: "A valid URL protocol is required.", -}; -const relProtocolError = { - success: false, - msg: "A protocol-relavie URL is not allowed.", -}; +const urlError = error("Please provide a valid URL."); +const protocolError = error("A valid URL protocol is required."); +const relProtocolError = error("A protocol-relavie URL is not allowed."); +const missingPortError = error("A URL with a port number is required."); +const portError = error("Invalid port number."); +const hostError = error("Invalid URL host."); +const hostWildcardError = error("Wildcard in URL host is not allowed"); +const tldError = error( + "URL must contain a valid TLD e.g .com, .io, .dev, .net" +); -function getProtocolError(opts: UrlOptions["protocols"]) { +function getProtocolError(opts: URLOptions["protocols"]) { if (opts.length === 0) return protocolError; - if (opts.length === 1) return error(`A URL with \`${opts[0]}\` is required.`); + if (opts.length === 1) + return error(`A URL with \`${opts[0]}\` protocol is required.`); if (opts.length === 2) - return error(`A URL with \`${opts[0]}\` or \`${opts[1]}\` is required.`); + return error( + `A URL with \`${opts[0]}\` or \`${opts[1]}\` protocol is required.` + ); const lst = opts.length - 1; - const wrap = (x: string, y: string) => `\`${x} + ${y}\``; + const wrap = (acc: string, p: string) => acc + `\`${p}\``; const optsStr = opts.reduce( (acc, p, i) => - i === 0 + i === lst ? wrap(acc, p) - : i === lst - ? wrap(acc, `or ${p}`) - : wrap(acc, `, ${p}`), + : i === lst - 1 + ? wrap(acc, p) + " or " + : wrap(acc, p) + ", ", "" ); - return error(`A URL with one of ${optsStr} is required.`); + return error(`A URL with one of ${optsStr} protocols is required.`); +} + +function isFQDN(str: string, options?: Partial): ValidatorRes { + const opts = mergeOptions(defaultURLOptions, options); + + if (str[str.length - 1] === ".") return hostError; // trailing dot not allowed + if (str.indexOf("*.") === 0) return hostWildcardError; // wildcard not allowed + + const parts = str.split("."); + const tld = parts[parts.length - 1]; + const tldRegex = + /^([a-z\u00A1-\u00A8\u00AA-\uD7FF\uF900-\uFDCF\uFDF0-\uFFEF]{2,}|xn[a-z0-9-]{2,})$/i; + + if ( + opts.requireTld && + (parts.length < 2 || !tldRegex.test(tld) || /\s/.test(tld)) + ) + return tldError; + + const partsValid = parts.every((part) => { + if (!/^[a-z_\u00a1-\uffff0-9-]+$/i.test(part)) { + return false; + } + + // disallow full-width chars + if (/[\uff01-\uff5e]/.test(part)) { + return false; + } + + // disallow parts starting or ending with hyphen + if (/^-|-$/.test(part)) { + return false; + } + + return true; + }); + + return partsValid ? { success: true } : hostError; +} + +function isIP(str: string) { + const validation = z.string().ip().safeParse(str); + return validation.success; } -function isUrl(url: string, options?: UrlOptions): ValidatorRes { - const opts = mergeOptions(defaultUrlOptions, options); +function isURL(str: string, options?: Partial): ValidatorRes { + const opts = mergeOptions(defaultURLOptions, options); - if (url.length === 0 || /[\s<>]/.test(url)) return urlError; - if (opts.validateLength && url.length > opts.maxLength) { - return { - success: false, - msg: `Invalid url length, max of ${opts.maxLength} expected.`, - }; + if (str.length === 0 || /[\s<>]/.test(str)) return urlError; + if (opts.validateLength && str.length > opts.maxLength) { + return error(`Invalid url length, max of ${opts.maxLength} expected.`); } - let _url = url; - let protocol: string; + let url = str; let host: string; - let hostname: string; let port: number; - let portStr: string | null; + let portStr: string = ""; let split: string[]; - let ipv6: string | null; split = url.split("#"); - _url = split.shift() ?? ""; + url = split.shift() ?? ""; split = url.split("?"); - _url = split.shift() ?? ""; + url = split.shift() ?? ""; - if (_url.slice(0, 2) === "//") return relProtocolError; + if (url.slice(0, 2) === "//") return relProtocolError; split = url.split("://"); - protocol = split?.shift()?.toLowerCase() ?? ""; + const protocol = split?.shift()?.toLowerCase() ?? ""; if (opts.requireProtocol && opts.protocols.indexOf(protocol) === -1) return getProtocolError(opts.protocols); + url = split.join("://"); + + split = url.split("/"); + url = split.shift() ?? ""; + if (!url.length) return urlError; + + split = url.split("@"); + if (split.length > 1 && !split[0]) return urlError; + if (split.length > 1) { + const auth = split.shift() ?? ""; + if (auth.split(":").length > 2) return urlError; + const [user, pass] = auth.split(":"); + if (!user && !pass) return urlError; + } + + const hostname = split.join("@"); + const wrapped_ipv6 = /^\[([^\]]+)\](?::([0-9]+))?$/; + const ipv6Match = hostname.match(wrapped_ipv6); + if (ipv6Match) { + host = ipv6Match[1]; + portStr = ipv6Match[2]; + } else { + split = hostname.split(":"); + host = split.shift() ?? ""; + if (split.length) portStr = split.join(":"); + } + + if (portStr.length) { + port = parseInt(portStr, 10); + if (Number.isNaN(port)) return missingPortError; + if (port <= 0 || port > 65_535) return portError; + } else if (opts.requirePort) return missingPortError; - return { success: true }; + if (!host) return hostError; + if (isIP(host)) return { success: true }; + return isFQDN(host); +} + +function addZodErr(valdn: ValidatorRes, ctx: z.RefinementCtx) { + if (valdn.success) return; + ctx.addIssue({ + code: z.ZodIssueCode.custom, + message: valdn.msg, + }); } export function getZodSchema(fields: Provider["config"]) { @@ -107,6 +185,7 @@ export function getZodSchema(fields: Provider["config"]) { .string() .optional() .transform((val) => (val?.length === 0 ? null : val)); + const kvPairs = Object.entries(fields).map(([field, config]) => { if (config.type === "form") { const baseFormSchema = z.record(z.string(), z.string()).array(); @@ -147,23 +226,19 @@ export function getZodSchema(fields: Provider["config"]) { return [field, switchSchema]; } - const urlSchema = z.string({ required_error }).url({ - message: - "Please provide a valid url, with a scheme & hostname as required.", + const urlStr = z.string({ required_error }); + const urlSchema = urlStr.superRefine((url, ctx) => { + const valdn = isURL(url); + addZodErr(valdn, ctx); }); - const urlTldSchema = z.string().regex(new RegExp(/\.[a-z]{2,63}$/), { - message: "Url must contain a valid TLD e.g .com, .io, .dev, .net", + const baseAnyHttpSchema = urlStr.superRefine((url, ctx) => { + const valdn = isURL(url, { protocols: ["http", "https"] }); + addZodErr(valdn, ctx); + }); + const baseHttpsSchema = urlStr.superRefine((url, ctx) => { + const valdn = isURL(url, { requireTld: true, protocols: ["https"] }); + addZodErr(valdn, ctx); }); - const baseAnyHttpSchema = urlSchema.refine( - (url) => url.startsWith("http://") || url.startsWith("https://"), - { message: "A url with `http` or `https` protocol is reuquired." } - ); - const baseHttpSchema = baseAnyHttpSchema.and(urlTldSchema); - const baseHttpsSchema = urlSchema - .refine((url) => url.startsWith("https://"), { - message: "A url with `https` protocol is required.", - }) - .and(urlTldSchema); if (config.validation === "any_url") { const anyUrlSchema = config.required @@ -171,20 +246,12 @@ export function getZodSchema(fields: Provider["config"]) { : emptyStringToNull.pipe(urlSchema.nullish()); return [field, anyUrlSchema]; } - if (config.validation === "any_http_url") { const anyHttpSchema = config.required ? baseAnyHttpSchema : emptyStringToNull.pipe(baseAnyHttpSchema.nullish()); return [field, anyHttpSchema]; } - - if (config.validation === "http_url") { - const httpSchema = config.required - ? baseHttpSchema - : emptyStringToNull.pipe(baseHttpSchema.nullish()); - return [field, httpSchema]; - } if (config.validation === "https_url") { const httpsSchema = config.required ? baseHttpsSchema From 0369bd6f9c5e1d2dd97ccbbe1c5a130930acd1b3 Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Wed, 6 Nov 2024 00:32:25 +0100 Subject: [PATCH 17/35] add validation for urls without scheme --- keep-ui/app/providers/form-validation.ts | 41 ++++++++----- keep-ui/app/providers/providers.tsx | 2 +- .../postgres_provider/postgres_provider.py | 10 +++- keep/validation/fields.py | 58 ++++++++++++++++++- 4 files changed, 91 insertions(+), 20 deletions(-) diff --git a/keep-ui/app/providers/form-validation.ts b/keep-ui/app/providers/form-validation.ts index d7dff4325..aa4771a32 100644 --- a/keep-ui/app/providers/form-validation.ts +++ b/keep-ui/app/providers/form-validation.ts @@ -167,7 +167,7 @@ function isURL(str: string, options?: Partial): ValidatorRes { if (!host) return hostError; if (isIP(host)) return { success: true }; - return isFQDN(host); + return isFQDN(host, opts); } function addZodErr(valdn: ValidatorRes, ctx: z.RefinementCtx) { @@ -227,37 +227,51 @@ export function getZodSchema(fields: Provider["config"]) { } const urlStr = z.string({ required_error }); - const urlSchema = urlStr.superRefine((url, ctx) => { - const valdn = isURL(url); - addZodErr(valdn, ctx); - }); - const baseAnyHttpSchema = urlStr.superRefine((url, ctx) => { - const valdn = isURL(url, { protocols: ["http", "https"] }); - addZodErr(valdn, ctx); - }); - const baseHttpsSchema = urlStr.superRefine((url, ctx) => { - const valdn = isURL(url, { requireTld: true, protocols: ["https"] }); - addZodErr(valdn, ctx); - }); if (config.validation === "any_url") { + const urlSchema = urlStr.superRefine((url, ctx) => { + const valdn = isURL(url); + addZodErr(valdn, ctx); + }); const anyUrlSchema = config.required ? urlSchema : emptyStringToNull.pipe(urlSchema.nullish()); return [field, anyUrlSchema]; } + if (config.validation === "any_http_url") { + const baseAnyHttpSchema = urlStr.superRefine((url, ctx) => { + const valdn = isURL(url, { protocols: ["http", "https"] }); + addZodErr(valdn, ctx); + }); const anyHttpSchema = config.required ? baseAnyHttpSchema : emptyStringToNull.pipe(baseAnyHttpSchema.nullish()); return [field, anyHttpSchema]; } + if (config.validation === "https_url") { + const baseHttpsSchema = urlStr.superRefine((url, ctx) => { + const valdn = isURL(url, { requireTld: true, protocols: ["https"] }); + addZodErr(valdn, ctx); + }); const httpsSchema = config.required ? baseHttpsSchema : emptyStringToNull.pipe(baseHttpsSchema.nullish()); return [field, httpsSchema]; } + + if (config.validation === "no_scheme_url") { + const baseNoSchemeSchema = urlStr.superRefine((url, ctx) => { + const valdn = isURL(url, { requireProtocol: false }); + addZodErr(valdn, ctx); + }); + const noSchemeSchema = config.required + ? baseNoSchemeSchema + : emptyStringToNull.pipe(baseNoSchemeSchema.nullish()); + return [field, noSchemeSchema]; + } + if (config.validation === "tld") { const baseTldSchema = z .string({ required_error }) @@ -269,6 +283,7 @@ export function getZodSchema(fields: Provider["config"]) { : baseTldSchema.optional(); return [field, tldSchema]; } + if (config.validation === "port") { const basePortSchema = z.coerce .number({ required_error, invalid_type_error: portError }) diff --git a/keep-ui/app/providers/providers.tsx b/keep-ui/app/providers/providers.tsx index 1e03bf640..faa74aa10 100644 --- a/keep-ui/app/providers/providers.tsx +++ b/keep-ui/app/providers/providers.tsx @@ -5,8 +5,8 @@ export interface ProviderAuthConfig { validation?: | "any_url" | "any_http_url" - | "http_url" | "https_url" + | "no_scheme_url" | "port" | "tld"; required?: boolean; diff --git a/keep/providers/postgres_provider/postgres_provider.py b/keep/providers/postgres_provider/postgres_provider.py index 1b188fbe7..4c032d1b5 100644 --- a/keep/providers/postgres_provider/postgres_provider.py +++ b/keep/providers/postgres_provider/postgres_provider.py @@ -11,7 +11,7 @@ from keep.contextmanager.contextmanager import ContextManager from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig, ProviderScope -from keep.validation.fields import UrlPort +from keep.validation.fields import NoSchemeUrl, UrlPort @pydantic.dataclasses.dataclass @@ -26,8 +26,12 @@ class PostgresProviderAuthConfig: "sensitive": True, } ) - host: str = dataclasses.field( - metadata={"required": True, "description": "Postgres hostname"} + host: NoSchemeUrl = dataclasses.field( + metadata={ + "required": True, + "description": "Postgres hostname", + "validation": "no_scheme_url", + } ) database: str | None = dataclasses.field( metadata={"required": False, "description": "Postgres database name"}, diff --git a/keep/validation/fields.py b/keep/validation/fields.py index 3ac977d84..aae4bc8a2 100644 --- a/keep/validation/fields.py +++ b/keep/validation/fields.py @@ -1,11 +1,63 @@ -from pydantic import HttpUrl, conint +from typing import Optional + +from pydantic import AnyUrl, HttpUrl, conint, errors +from pydantic.networks import Parts class HttpsUrl(HttpUrl): - scheme = {'https'} + scheme = {"https"} @staticmethod def get_default_parts(parts): - return {'port': '443'} + return {"port": "443"} + UrlPort = conint(ge=1, le=65_535) + + +class NoSchemeUrl(AnyUrl): + """Override to allow url without a scheme.""" + + @classmethod + def build( + cls, + *, + scheme: str, + user: Optional[str] = None, + password: Optional[str] = None, + host: str, + port: Optional[str] = None, + path: Optional[str] = None, + query: Optional[str] = None, + fragment: Optional[str] = None, + **_kwargs: str, + ) -> str: + url = super().build( + scheme=scheme, + user=user, + password=password, + host=host, + port=port, + path=path, + query=query, + fragment=fragment, + **_kwargs, + ) + return url.split("://")[1] + + @classmethod + def validate_parts(cls, parts: Parts, validate_port: bool = True) -> Parts: + """ + In this override, we removed validation for url scheme. + """ + + parts["scheme"] = "foo" + + if validate_port: + cls._validate_port(parts["port"]) + + user = parts["user"] + if cls.user_required and user is None: + raise errors.UrlUserInfoError() + + return parts From ab9908f558bda795568a2b195412d034ec8ac806 Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Wed, 6 Nov 2024 23:51:56 +0100 Subject: [PATCH 18/35] complete validation for 8 providers --- keep-ui/app/providers/form-validation.ts | 144 ++++++++---------- keep-ui/app/providers/provider-form.tsx | 4 +- .../gitlab_provider/gitlab_provider.py | 3 +- .../jiraonprem_provider.py | 3 +- .../kafka_provider/kafka_provider.py | 3 +- .../redmine_provider/redmine_provider.py | 58 +++++-- keep/providers/smtp_provider/smtp_provider.py | 47 +++--- .../splunk_provider/splunk_provider.py | 17 ++- keep/providers/ssh_provider/ssh_provider.py | 18 ++- .../victoriametrics_provider.py | 7 +- keep/validation/fields.py | 8 +- 11 files changed, 170 insertions(+), 142 deletions(-) diff --git a/keep-ui/app/providers/form-validation.ts b/keep-ui/app/providers/form-validation.ts index aa4771a32..1ef73de28 100644 --- a/keep-ui/app/providers/form-validation.ts +++ b/keep-ui/app/providers/form-validation.ts @@ -30,24 +30,23 @@ function mergeOptions>( } const error = (msg: string) => ({ success: false, msg }); -const urlError = error("Please provide a valid URL."); -const protocolError = error("A valid URL protocol is required."); -const relProtocolError = error("A protocol-relavie URL is not allowed."); -const missingPortError = error("A URL with a port number is required."); -const portError = error("Invalid port number."); -const hostError = error("Invalid URL host."); +const urlError = error("Please provide a valid URL"); +const protocolError = error("A valid URL protocol is required"); +const relProtocolError = error("A protocol-relavie URL is not allowed"); +const missingPortError = error("A URL with a port number is required"); +const portError = error("Invalid port number"); +const hostError = error("Invalid URL host"); const hostWildcardError = error("Wildcard in URL host is not allowed"); const tldError = error( "URL must contain a valid TLD e.g .com, .io, .dev, .net" ); function getProtocolError(opts: URLOptions["protocols"]) { - if (opts.length === 0) return protocolError; if (opts.length === 1) - return error(`A URL with \`${opts[0]}\` protocol is required.`); + return error(`A URL with \`${opts[0]}\` protocol is required`); if (opts.length === 2) return error( - `A URL with \`${opts[0]}\` or \`${opts[1]}\` protocol is required.` + `A URL with \`${opts[0]}\` or \`${opts[1]}\` protocol is required` ); const lst = opts.length - 1; const wrap = (acc: string, p: string) => acc + `\`${p}\``; @@ -60,7 +59,7 @@ function getProtocolError(opts: URLOptions["protocols"]) { : wrap(acc, p) + ", ", "" ); - return error(`A URL with one of ${optsStr} protocols is required.`); + return error(`A URL with one of ${optsStr} protocols is required`); } function isFQDN(str: string, options?: Partial): ValidatorRes { @@ -128,16 +127,20 @@ function isURL(str: string, options?: Partial): ValidatorRes { if (url.slice(0, 2) === "//") return relProtocolError; + // extract protocol & validate split = url.split("://"); - const protocol = split?.shift()?.toLowerCase() ?? ""; - if (opts.requireProtocol && opts.protocols.indexOf(protocol) === -1) - return getProtocolError(opts.protocols); + if (split.length > 1) { + const protocol = split?.shift()?.toLowerCase() ?? ""; + if (opts.protocols.length && opts.protocols.indexOf(protocol) === -1) + return getProtocolError(opts.protocols); + } else if (split.length > 2 || opts.requireProtocol) return protocolError; url = split.join("://"); split = url.split("/"); url = split.shift() ?? ""; if (!url.length) return urlError; + // extract auth details & validate split = url.split("@"); if (split.length > 1 && !split[0]) return urlError; if (split.length > 1) { @@ -146,8 +149,9 @@ function isURL(str: string, options?: Partial): ValidatorRes { const [user, pass] = auth.split(":"); if (!user && !pass) return urlError; } - const hostname = split.join("@"); + + // extract ipv6 & port const wrapped_ipv6 = /^\[([^\]]+)\](?::([0-9]+))?$/; const ipv6Match = hostname.match(wrapped_ipv6); if (ipv6Match) { @@ -161,7 +165,7 @@ function isURL(str: string, options?: Partial): ValidatorRes { if (portStr.length) { port = parseInt(portStr, 10); - if (Number.isNaN(port)) return missingPortError; + if (Number.isNaN(port)) return urlError; if (port <= 0 || port > 65_535) return portError; } else if (opts.requirePort) return missingPortError; @@ -170,35 +174,37 @@ function isURL(str: string, options?: Partial): ValidatorRes { return isFQDN(host, opts); } -function addZodErr(valdn: ValidatorRes, ctx: z.RefinementCtx) { - if (valdn.success) return; - ctx.addIssue({ - code: z.ZodIssueCode.custom, - message: valdn.msg, +const required_error = "This field is required"; + +function getBaseUrlSchema(options?: Partial) { + const urlStr = z.string({ required_error }); + const schema = urlStr.superRefine((url, ctx) => { + const valdn = isURL(url, options); + if (valdn.success) return; + ctx.addIssue({ + code: z.ZodIssueCode.custom, + message: valdn.msg, + }); }); + return schema; } export function getZodSchema(fields: Provider["config"]) { - const required_error = "This field is required"; const portError = "Invalid port number"; - const emptyStringToNull = z - .string() - .optional() - .transform((val) => (val?.length === 0 ? null : val)); const kvPairs = Object.entries(fields).map(([field, config]) => { if (config.type === "form") { - const baseFormSchema = z.record(z.string(), z.string()).array(); - const formSchema = config.required - ? baseFormSchema.nonempty({ + const baseSchema = z.record(z.string(), z.string()).array(); + const schema = config.required + ? baseSchema.nonempty({ message: "At least one key-value entry should be provided.", }) - : baseFormSchema.optional(); - return [field, formSchema]; + : baseSchema.optional(); + return [field, schema]; } if (config.type === "file") { - const baseFileSchema = z + const baseSchema = z .instanceof(File, { message: "Please upload a file here." }) .refine( (file) => { @@ -213,86 +219,60 @@ export function getZodSchema(fields: Provider["config"]) { : `File should be of type ${config.file_type}.`, } ); - const fileSchema = config.required - ? baseFileSchema - : baseFileSchema.optional(); - return [field, fileSchema]; + const schema = config.required ? baseSchema : baseSchema.optional(); + return [field, schema]; } if (config.type === "switch") { - const switchSchema = config.required - ? z.boolean() - : z.boolean().optional(); - return [field, switchSchema]; + const schema = config.required ? z.boolean() : z.boolean().optional(); + return [field, schema]; } - const urlStr = z.string({ required_error }); - if (config.validation === "any_url") { - const urlSchema = urlStr.superRefine((url, ctx) => { - const valdn = isURL(url); - addZodErr(valdn, ctx); - }); - const anyUrlSchema = config.required - ? urlSchema - : emptyStringToNull.pipe(urlSchema.nullish()); - return [field, anyUrlSchema]; + const baseSchema = getBaseUrlSchema(); + const schema = config.required ? baseSchema : baseSchema.optional(); + return [field, schema]; } if (config.validation === "any_http_url") { - const baseAnyHttpSchema = urlStr.superRefine((url, ctx) => { - const valdn = isURL(url, { protocols: ["http", "https"] }); - addZodErr(valdn, ctx); - }); - const anyHttpSchema = config.required - ? baseAnyHttpSchema - : emptyStringToNull.pipe(baseAnyHttpSchema.nullish()); - return [field, anyHttpSchema]; + const baseSchema = getBaseUrlSchema({ protocols: ["http", "https"] }); + const schema = config.required ? baseSchema : baseSchema.optional(); + return [field, schema]; } if (config.validation === "https_url") { - const baseHttpsSchema = urlStr.superRefine((url, ctx) => { - const valdn = isURL(url, { requireTld: true, protocols: ["https"] }); - addZodErr(valdn, ctx); + const baseSchema = getBaseUrlSchema({ + protocols: ["https"], + requireTld: true, + maxLength: 2083, }); - const httpsSchema = config.required - ? baseHttpsSchema - : emptyStringToNull.pipe(baseHttpsSchema.nullish()); - return [field, httpsSchema]; + const schema = config.required ? baseSchema : baseSchema.optional(); + return [field, schema]; } if (config.validation === "no_scheme_url") { - const baseNoSchemeSchema = urlStr.superRefine((url, ctx) => { - const valdn = isURL(url, { requireProtocol: false }); - addZodErr(valdn, ctx); - }); - const noSchemeSchema = config.required - ? baseNoSchemeSchema - : emptyStringToNull.pipe(baseNoSchemeSchema.nullish()); - return [field, noSchemeSchema]; + const baseSchema = getBaseUrlSchema({ requireProtocol: false }); + const schema = config.required ? baseSchema : baseSchema.optional(); + return [field, schema]; } if (config.validation === "tld") { - const baseTldSchema = z + const baseSchema = z .string({ required_error }) .regex(new RegExp(/\.[a-z]{2,63}$/), { message: "Please provide a valid TLD e.g .com, .io, .dev, .net", }); - const tldSchema = config.required - ? baseTldSchema - : baseTldSchema.optional(); - return [field, tldSchema]; + const schema = config.required ? baseSchema : baseSchema.optional(); + return [field, schema]; } if (config.validation === "port") { - const basePortSchema = z.coerce + const baseSchema = z.coerce .number({ required_error, invalid_type_error: portError }) .min(1, { message: portError }) .max(65_535, { message: portError }); - const portSchema = config.required - ? basePortSchema - : emptyStringToNull.pipe(basePortSchema.nullish()); - return [field, portSchema]; + const schema = config.required ? baseSchema : baseSchema.optional(); + return [field, schema]; } return [ field, diff --git a/keep-ui/app/providers/provider-form.tsx b/keep-ui/app/providers/provider-form.tsx index 8e3085647..52c348848 100644 --- a/keep-ui/app/providers/provider-form.tsx +++ b/keep-ui/app/providers/provider-form.tsx @@ -356,11 +356,11 @@ const ProviderForm = ({ const error = "detail" in data ? data.detail : "message" in data ? data.message : data; if (status === 400) setFormErrors(error); - if (response.status === 409) + if (status === 409) setFormErrors( `Provider with name ${formValues.provider_name} already exists` ); - if (response.status === 412) setProviderValidatedScopes(error); + if (status === 412) setProviderValidatedScopes(error); } async function handleUpdateClick() { diff --git a/keep/providers/gitlab_provider/gitlab_provider.py b/keep/providers/gitlab_provider/gitlab_provider.py index 90563a15e..0674aed2a 100644 --- a/keep/providers/gitlab_provider/gitlab_provider.py +++ b/keep/providers/gitlab_provider/gitlab_provider.py @@ -18,12 +18,13 @@ class GitlabProviderAuthConfig: """GitLab authentication configuration.""" - host: str = dataclasses.field( + host: pydantic.AnyHttpUrl = dataclasses.field( metadata={ "required": True, "description": "GitLab Host", "sensitive": False, "hint": "example.gitlab.com", + "validation": "any_http_url" } ) diff --git a/keep/providers/jiraonprem_provider/jiraonprem_provider.py b/keep/providers/jiraonprem_provider/jiraonprem_provider.py index f6f48379e..2e623424c 100644 --- a/keep/providers/jiraonprem_provider/jiraonprem_provider.py +++ b/keep/providers/jiraonprem_provider/jiraonprem_provider.py @@ -20,12 +20,13 @@ class JiraonpremProviderAuthConfig: """Jira On Prem authentication configuration.""" - host: str = dataclasses.field( + host: pydantic.AnyHttpUrl = dataclasses.field( metadata={ "required": True, "description": "Jira Host", "sensitive": False, "hint": "jira.onprem.com", + "validation": "any_http_url" } ) diff --git a/keep/providers/kafka_provider/kafka_provider.py b/keep/providers/kafka_provider/kafka_provider.py index 3277b6aaa..af2a9bcac 100644 --- a/keep/providers/kafka_provider/kafka_provider.py +++ b/keep/providers/kafka_provider/kafka_provider.py @@ -22,11 +22,12 @@ class KafkaProviderAuthConfig: Kafka authentication configuration. """ - host: str = dataclasses.field( + host: pydantic.AnyUrl = dataclasses.field( metadata={ "required": True, "description": "Kafka host", "hint": "e.g. https://kafka:9092", + "validation": "any_url" }, ) topic: str = dataclasses.field( diff --git a/keep/providers/redmine_provider/redmine_provider.py b/keep/providers/redmine_provider/redmine_provider.py index 27150f3e4..e7e40d953 100644 --- a/keep/providers/redmine_provider/redmine_provider.py +++ b/keep/providers/redmine_provider/redmine_provider.py @@ -17,12 +17,13 @@ class RedmineProviderAuthConfig: """Redmine authentication configuration.""" - host: str = dataclasses.field( + host: pydantic.AnyHttpUrl = dataclasses.field( metadata={ "required": True, "description": "Redmine Host", "sensitive": False, "hint": "http://localhost:8080", + "validation": "any_http_url", } ) @@ -51,7 +52,7 @@ class RedmineProvider(BaseProvider): PROVIDER_TAGS = ["ticketing"] def __init__( - self, context_manager: ContextManager, provider_id: str, config: ProviderConfig + self, context_manager: ContextManager, provider_id: str, config: ProviderConfig ): self._host = None super().__init__(context_manager, provider_id, config) @@ -69,16 +70,23 @@ def validate_scopes(self): try: resp.raise_for_status() if resp.status_code == 200: - scopes = { - "authenticated": True - } + scopes = {"authenticated": True} else: - self.logger.error(f"Failed to validate scope for {self.provider_id}", extra=resp.json()) + self.logger.error( + f"Failed to validate scope for {self.provider_id}", + extra=resp.json(), + ) scopes = { - "authenticated": {"status_code": resp.status_code, "error": resp.json()} + "authenticated": { + "status_code": resp.status_code, + "error": resp.json(), + } } except HTTPError as e: - self.logger.error(f"HTTPError while validating scope for {self.provider_id}", extra={"error": str(e)}) + self.logger.error( + f"HTTPError while validating scope for {self.provider_id}", + extra={"error": str(e)}, + ) scopes = { "authenticated": {"status_code": resp.status_code, "error": str(e)} } @@ -98,7 +106,7 @@ def __redmine_url(self): # if the user explicitly supplied a host with http/https, use it if self.authentication_config.host.startswith( - "http://" + "http://" ) or self.authentication_config.host.startswith("https://"): self._host = self.authentication_config.host return self.authentication_config.host.rstrip("/") @@ -144,18 +152,36 @@ def __build_payload_from_kwargs(self, kwargs: dict): params[param] = kwargs[param] return params - def _notify(self, project_id: str, subject: str, priority_id: str, description: str = "", - **kwargs: dict): + def _notify( + self, + project_id: str, + subject: str, + priority_id: str, + description: str = "", + **kwargs: dict, + ): self.logger.info("Creating an issue in redmine") payload = self.__build_payload_from_kwargs( - kwargs={**kwargs, 'subject': subject, 'description': description, "project_id": project_id, - "priority_id": priority_id}) - resp = requests.post(f"{self.__redmine_url}/issues.json", headers=self.__get_headers(), - json={'issue': payload}) + kwargs={ + **kwargs, + "subject": subject, + "description": description, + "project_id": project_id, + "priority_id": priority_id, + } + ) + resp = requests.post( + f"{self.__redmine_url}/issues.json", + headers=self.__get_headers(), + json={"issue": payload}, + ) try: resp.raise_for_status() except HTTPError as e: self.logger.error("Error While creating Redmine Issue") raise Exception(f"Failed to create issue: {str(e)}") - self.logger.info("Successfully created a Redmine Issue", extra={"status_code": resp.status_code}) + self.logger.info( + "Successfully created a Redmine Issue", + extra={"status_code": resp.status_code}, + ) return resp.json() diff --git a/keep/providers/smtp_provider/smtp_provider.py b/keep/providers/smtp_provider/smtp_provider.py index 9b61a608f..e0893b9d4 100644 --- a/keep/providers/smtp_provider/smtp_provider.py +++ b/keep/providers/smtp_provider/smtp_provider.py @@ -13,7 +13,7 @@ from keep.contextmanager.contextmanager import ContextManager from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig, ProviderScope -from keep.validation.fields import UrlPort +from keep.validation.fields import NoSchemeUrl, UrlPort @pydantic.dataclasses.dataclass @@ -35,11 +35,12 @@ class SmtpProviderAuthConfig: } ) - smtp_server: str = dataclasses.field( + smtp_server: NoSchemeUrl = dataclasses.field( metadata={ "required": True, "description": "SMTP Server Address", "config_main_group": "authentication", + "validation": "no_scheme_url", } ) @@ -48,7 +49,7 @@ class SmtpProviderAuthConfig: "required": True, "description": "SMTP port", "config_main_group": "authentication", - "validation": "port" + "validation": "port", } ) @@ -77,18 +78,18 @@ class SmtpProvider(BaseProvider): PROVIDER_DISPLAY_NAME = "SMTP" def __init__( - self, context_manager: ContextManager, provider_id: str, config: ProviderConfig + self, context_manager: ContextManager, provider_id: str, config: ProviderConfig ): super().__init__(context_manager, provider_id, config) def dispose(self): pass - + def validate_config(self): self.authentication_config = SmtpProviderAuthConfig( **self.config.authentication ) - + def validate_scopes(self): """ Validate that the scopes provided are correct. @@ -99,7 +100,7 @@ def validate_scopes(self): return {"send_email": True} except Exception as e: return {"send_email": str(e)} - + def generate_smtp_client(self): """ Generate an SMTP client. @@ -110,18 +111,20 @@ def generate_smtp_client(self): smtp_port = self.authentication_config.smtp_port encryption = self.authentication_config.encryption - if (encryption == "SSL"): + if encryption == "SSL": smtp = SMTP_SSL(smtp_server, smtp_port) smtp.login(smtp_username, smtp_password) return smtp - - elif (encryption == "TLS"): + + elif encryption == "TLS": smtp = SMTP(smtp_server, smtp_port) smtp.starttls() smtp.login(smtp_username, smtp_password) return smtp - - def send_email(self, from_email: str, from_name: str, to_email: str, subject: str, body: str): + + def send_email( + self, from_email: str, from_name: str, to_email: str, subject: str, body: str + ): """ Send an email using SMTP protocol. """ @@ -129,9 +132,9 @@ def send_email(self, from_email: str, from_name: str, to_email: str, subject: st if from_name == "": msg["From"] = from_email msg["From"] = f"{from_name} <{from_email}>" - msg['To'] = to_email - msg['Subject'] = subject - msg.attach(MIMEText(body, 'plain')) + msg["To"] = to_email + msg["Subject"] = subject + msg.attach(MIMEText(body, "plain")) try: smtp = self.generate_smtp_client() @@ -140,18 +143,16 @@ def send_email(self, from_email: str, from_name: str, to_email: str, subject: st except Exception as e: raise Exception(f"Failed to send email: {str(e)}") - - def _notify(self, from_email: str, from_name: str, to_email: str, subject: str, body: str): + + def _notify( + self, from_email: str, from_name: str, to_email: str, subject: str, body: str + ): """ Send an email using SMTP protocol. """ self.send_email(from_email, from_name, to_email, subject, body) - return { - "from": from_email, - "to": to_email, - "subject": subject, - "body": body - } + return {"from": from_email, "to": to_email, "subject": subject, "body": body} + if __name__ == "__main__": import logging diff --git a/keep/providers/splunk_provider/splunk_provider.py b/keep/providers/splunk_provider/splunk_provider.py index d3d9d8943..ab637bfdb 100644 --- a/keep/providers/splunk_provider/splunk_provider.py +++ b/keep/providers/splunk_provider/splunk_provider.py @@ -13,7 +13,7 @@ from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig, ProviderScope from keep.providers.providers_factory import ProvidersFactory -from keep.validation.fields import UrlPort +from keep.validation.fields import NoSchemeUrl, UrlPort @pydantic.dataclasses.dataclass @@ -26,9 +26,10 @@ class SplunkProviderAuthConfig: } ) - host: str = dataclasses.field( + host: NoSchemeUrl = dataclasses.field( metadata={ "description": "Splunk Host (default is localhost)", + "validation": "no_scheme_url" }, default="localhost", ) @@ -39,6 +40,14 @@ class SplunkProviderAuthConfig: }, default=8089, ) + verify: bool = dataclasses.field( + metadata={ + "description": "Enable SSL verification", + "hint": "An `https` protocol will be used if enabled.", + "type": "switch" + }, + default=True, + ) class SplunkProvider(BaseProvider): @@ -106,6 +115,8 @@ def validate_scopes(self) -> dict[str, bool | str]: token=self.authentication_config.api_key, host=self.authentication_config.host, port=self.authentication_config.port, + scheme='https' if self.authentication_config.verify else 'http', + verify=self.authentication_config.verify ) self.logger.debug("Connected to Splunk", extra={"service": service}) @@ -216,6 +227,8 @@ def setup_webhook( token=self.authentication_config.api_key, host=self.authentication_config.host, port=self.authentication_config.port, + scheme='https' if self.authentication_config.verify else 'http', + verify=self.authentication_config.verify ) for saved_search in service.saved_searches: existing_webhook_url = saved_search["_state"]["content"].get( diff --git a/keep/providers/ssh_provider/ssh_provider.py b/keep/providers/ssh_provider/ssh_provider.py index 14dc56356..19e91f23a 100644 --- a/keep/providers/ssh_provider/ssh_provider.py +++ b/keep/providers/ssh_provider/ssh_provider.py @@ -13,22 +13,26 @@ from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig from keep.providers.providers_factory import ProvidersFactory -from keep.validation.fields import UrlPort +from keep.validation.fields import NoSchemeUrl, UrlPort @pydantic.dataclasses.dataclass class SshProviderAuthConfig: """SSH authentication configuration.""" - # TODO: validate hostname because it seems pydantic doesn't have a validator for it - host: str = dataclasses.field( - metadata={"required": True, "description": "SSH hostname"} + host: NoSchemeUrl = dataclasses.field( + metadata={ + "required": True, + "description": "SSH hostname", + "validation": "no_scheme_url", + } ) user: str = dataclasses.field( metadata={"required": True, "description": "SSH user"} ) port: UrlPort = dataclasses.field( - default=22, metadata={"required": False, "description": "SSH port", "validation": "port"} + default=22, + metadata={"required": False, "description": "SSH port", "validation": "port"}, ) pkey: typing.Optional[str] = dataclasses.field( default=None, @@ -37,8 +41,8 @@ class SshProviderAuthConfig: "sensitive": True, "type": "file", "name": "pkey", - "file_type": "text/plain, application/x-pem-file, application/x-putty-private-key, "+ - "application/x-ed25519-key, application/pkcs8, application/octet-stream", + "file_type": "text/plain, application/x-pem-file, application/x-putty-private-key, " + + "application/x-ed25519-key, application/pkcs8, application/octet-stream", "config_sub_group": "private_key", "config_main_group": "authentication", }, diff --git a/keep/providers/victoriametrics_provider/victoriametrics_provider.py b/keep/providers/victoriametrics_provider/victoriametrics_provider.py index 6f93e171e..bb67741a4 100644 --- a/keep/providers/victoriametrics_provider/victoriametrics_provider.py +++ b/keep/providers/victoriametrics_provider/victoriametrics_provider.py @@ -26,11 +26,12 @@ class VictoriametricsProviderAuthConfig: vmalert authentication configuration. """ - VMAlertHost: str = dataclasses.field( + VMAlertHost: pydantic.AnyHttpUrl = dataclasses.field( metadata={ "required": True, "description": "The hostname or IP address where VMAlert is running. This can be a local or remote server address.", - "hint": "Example: 'localhost', '192.168.1.100', or 'vmalert.mydomain.com'", + "hint": "Example: 'http://localhost', 'http://192.168.1.100', or 'https://vmalert.mydomain.com'", + "validation": "any_http_url", }, ) @@ -39,7 +40,7 @@ class VictoriametricsProviderAuthConfig: "required": True, "description": "The port number on which VMAlert is listening. This should match the port configured in your VMAlert setup.", "hint": "Example: 8880 (if VMAlert is set to listen on port 8880)", - "validation": "port" + "validation": "port", }, ) diff --git a/keep/validation/fields.py b/keep/validation/fields.py index aae4bc8a2..1f43283e5 100644 --- a/keep/validation/fields.py +++ b/keep/validation/fields.py @@ -3,6 +3,8 @@ from pydantic import AnyUrl, HttpUrl, conint, errors from pydantic.networks import Parts +UrlPort = conint(ge=1, le=65_535) + class HttpsUrl(HttpUrl): scheme = {"https"} @@ -12,9 +14,6 @@ def get_default_parts(parts): return {"port": "443"} -UrlPort = conint(ge=1, le=65_535) - - class NoSchemeUrl(AnyUrl): """Override to allow url without a scheme.""" @@ -51,7 +50,8 @@ def validate_parts(cls, parts: Parts, validate_port: bool = True) -> Parts: In this override, we removed validation for url scheme. """ - parts["scheme"] = "foo" + scheme = parts["scheme"] + parts["scheme"] = "foo" if scheme is None else scheme if validate_port: cls._validate_port(parts["port"]) From ac0a24629ab2039882cbf27d3752a2ca5de87c96 Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Sun, 10 Nov 2024 12:59:54 +0100 Subject: [PATCH 19/35] add more provider tests and clean up form --- keep-ui/app/providers/provider-form.tsx | 41 ++++++++++++------------- tests/e2e_tests/test_end_to_end.py | 19 ++++++++++++ 2 files changed, 39 insertions(+), 21 deletions(-) diff --git a/keep-ui/app/providers/provider-form.tsx b/keep-ui/app/providers/provider-form.tsx index 499d30d77..72ba71dcc 100644 --- a/keep-ui/app/providers/provider-form.tsx +++ b/keep-ui/app/providers/provider-form.tsx @@ -1,5 +1,3 @@ -// TODO: refactor this file and separate in to smaller components -// There's also a lot of s**t in here, but it works for now 🤷‍♂️ import React, { useState, useRef, useMemo } from "react"; import { useSession } from "next-auth/react"; import { Provider, ProviderAuthConfig } from "./providers"; @@ -276,6 +274,7 @@ const ProviderForm = ({ function handleFormChange(key: string, value: ProviderFormValue) { if (typeof value === "string" && value.trim().length === 0) { + // remove fields with empty string value setFormValues((prev) => { const updated = structuredClone(prev); delete updated[key]; @@ -334,9 +333,8 @@ const ProviderForm = ({ }; async function submit(requestUrl: string, method: string = "POST") { - let headers = { + const headers = { Authorization: `Bearer ${accessToken}`, - "Content-Type": "application/json", }; let body; @@ -351,9 +349,9 @@ const ProviderForm = ({ : formData.append(key, value.toString()); } body = formData; - headers["Content-Type"] = "multipart/form-data"; } else { // Standard JSON for non-file submissions + Object.assign(headers, { "Content-Type": "application/json" }); body = JSON.stringify(formValues); } @@ -369,12 +367,15 @@ const ProviderForm = ({ const data = await response.json(); const error = "detail" in data ? data.detail : "message" in data ? data.message : data; - if (status === 400) setFormErrors(error); - if (status === 409) + if (status === 409) { setFormErrors( `Provider with name ${formValues.provider_name} already exists` ); - if (status === 412) setProviderValidatedScopes(error); + } else if (status === 412) { + setProviderValidatedScopes(error); + } else { + setFormErrors(error); + } } async function handleUpdateClick() { @@ -704,6 +705,7 @@ const ProviderForm = ({ - {Array.isArray(value) && ( - - )} + {Array.isArray(value) && } {error && error?.length > 0 && (

{error}

)} @@ -1135,15 +1138,11 @@ function KVForm({ } const KVInput = ({ - name, data, onChange, - error, }: { - name: string; data: KVFormData; onChange: (entries: KVFormData) => void; - error?: string; }) => { const handleEntryChange = (index: number, name: string, value: string) => { const newEntries = data.map((entry, i) => diff --git a/tests/e2e_tests/test_end_to_end.py b/tests/e2e_tests/test_end_to_end.py index f74108df9..9bf1616a4 100644 --- a/tests/e2e_tests/test_end_to_end.py +++ b/tests/e2e_tests/test_end_to_end.py @@ -246,3 +246,22 @@ def test_provider_validation(browser): browser.get_by_placeholder("Enter host").fill("mongodb://host.com:3000") connect_btn.click() expect(error_msg).to_be_hidden() + + # using Postgres provider + browser.goto("http://localhost:3000/providers") + browser.locator("button:has-text('PostgreSQL'):has-text('data')").click() + # test `no_scheme_url` field validation + # - on the frontend: url with/without scheme validates. + # - on the backend: scheme is removed during validation. + browser.get_by_placeholder("Enter provider name").fill("random name") + browser.get_by_placeholder("Enter username").fill("username") + browser.get_by_placeholder("Enter password").fill("password") + browser.get_by_placeholder("Enter host").fill("*.") + connect_btn.click() + expect(error_msg).to_have_count(1) + browser.get_by_placeholder("Enter host").fill("localhost:5000") + connect_btn.click() + expect(error_msg).to_be_hidden() + browser.get_by_placeholder("Enter host").fill("https://host.com:3000") + connect_btn.click() + expect(error_msg).to_be_hidden() From 9cdf0efd0df9e54d0fcc0d60872a7ab6bd9b4816 Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Sun, 10 Nov 2024 15:45:32 +0100 Subject: [PATCH 20/35] fix file type validation bug on provider config --- keep-ui/app/providers/form-validation.ts | 8 ++++++-- keep-ui/app/providers/provider-form.tsx | 8 +++++++- .../gcpmonitoring_provider/gcpmonitoring_provider.py | 2 +- keep/providers/newrelic_provider/newrelic_provider.py | 2 +- 4 files changed, 15 insertions(+), 5 deletions(-) diff --git a/keep-ui/app/providers/form-validation.ts b/keep-ui/app/providers/form-validation.ts index 1ef73de28..6a58467be 100644 --- a/keep-ui/app/providers/form-validation.ts +++ b/keep-ui/app/providers/form-validation.ts @@ -189,7 +189,7 @@ function getBaseUrlSchema(options?: Partial) { return schema; } -export function getZodSchema(fields: Provider["config"]) { +export function getZodSchema(fields: Provider["config"], installed: boolean) { const portError = "Invalid port number"; const kvPairs = Object.entries(fields).map(([field, config]) => { @@ -206,11 +206,15 @@ export function getZodSchema(fields: Provider["config"]) { if (config.type === "file") { const baseSchema = z .instanceof(File, { message: "Please upload a file here." }) + .or(z.string()) .refine( (file) => { if (config.file_type == undefined) return true; if (config.file_type.length <= 1) return true; - return config.file_type.includes(file.type); + if (typeof file === "string" && installed) return true; + return ( + typeof file !== "string" && config.file_type.includes(file.type) + ); }, { message: diff --git a/keep-ui/app/providers/provider-form.tsx b/keep-ui/app/providers/provider-form.tsx index 72ba71dcc..f5089ec9c 100644 --- a/keep-ui/app/providers/provider-form.tsx +++ b/keep-ui/app/providers/provider-form.tsx @@ -199,7 +199,10 @@ const ProviderForm = ({ () => getConfigByMainGroup(provider.config), [provider] ); - const zodSchema = useMemo(() => getZodSchema(provider.config), [provider]); + const zodSchema = useMemo( + () => getZodSchema(provider.config, provider.installed), + [provider] + ); const apiUrl = useApiUrl(); @@ -373,6 +376,9 @@ const ProviderForm = ({ ); } else if (status === 412) { setProviderValidatedScopes(error); + setFormErrors( + `Provider scopes validation failed: ${JSON.stringify(error, null, 4)}` + ); } else { setFormErrors(error); } diff --git a/keep/providers/gcpmonitoring_provider/gcpmonitoring_provider.py b/keep/providers/gcpmonitoring_provider/gcpmonitoring_provider.py index dce283442..9027e93a4 100644 --- a/keep/providers/gcpmonitoring_provider/gcpmonitoring_provider.py +++ b/keep/providers/gcpmonitoring_provider/gcpmonitoring_provider.py @@ -39,7 +39,7 @@ class GcpmonitoringProviderAuthConfig: "sensitive": True, "type": "file", "name": "service_account_json", - "file_type": ".json", # this is used to filter the file type in the UI + "file_type": "application/json", # this is used to filter the file type in the UI } ) diff --git a/keep/providers/newrelic_provider/newrelic_provider.py b/keep/providers/newrelic_provider/newrelic_provider.py index e05e40855..fbf1926b3 100644 --- a/keep/providers/newrelic_provider/newrelic_provider.py +++ b/keep/providers/newrelic_provider/newrelic_provider.py @@ -162,7 +162,7 @@ def __make_delete_webhook_destination_query(self, destination_id: str): } def validate_scopes(self) -> dict[str, bool | str]: - scopes = {scope.name: False for scope in self.PROVIDER_SCOPES} + scopes = {scope.name: "Invalid" for scope in self.PROVIDER_SCOPES} read_scopes = [key for key in scopes.keys() if "read" in key] try: From 605c18805ed7b23c6495140eaf7840797ed80a02 Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Sun, 10 Nov 2024 22:13:09 +0100 Subject: [PATCH 21/35] add backend provider config validation --- keep-ui/app/providers/form-validation.ts | 14 ++++++++++---- keep-ui/app/providers/providers-tiles.tsx | 4 +++- keep/providers/cilium_provider/cilium_provider.py | 6 ++++-- .../clickhouse_provider/clickhouse_provider.py | 10 +++++++--- .../providers/graylog_provider/graylog_provider.py | 8 +++++--- 5 files changed, 29 insertions(+), 13 deletions(-) diff --git a/keep-ui/app/providers/form-validation.ts b/keep-ui/app/providers/form-validation.ts index 6a58467be..163028a40 100644 --- a/keep-ui/app/providers/form-validation.ts +++ b/keep-ui/app/providers/form-validation.ts @@ -133,6 +133,8 @@ function isURL(str: string, options?: Partial): ValidatorRes { const protocol = split?.shift()?.toLowerCase() ?? ""; if (opts.protocols.length && opts.protocols.indexOf(protocol) === -1) return getProtocolError(opts.protocols); + } else if (opts.requireProtocol && opts.protocols.length) { + return getProtocolError(opts.protocols); } else if (split.length > 2 || opts.requireProtocol) return protocolError; url = split.join("://"); @@ -271,10 +273,14 @@ export function getZodSchema(fields: Provider["config"], installed: boolean) { } if (config.validation === "port") { - const baseSchema = z.coerce - .number({ required_error, invalid_type_error: portError }) - .min(1, { message: portError }) - .max(65_535, { message: portError }); + const baseSchema = z + .string({ required_error }) + .pipe( + z.coerce + .number({ invalid_type_error: portError }) + .min(1, { message: portError }) + .max(65_535, { message: portError }) + ); const schema = config.required ? baseSchema : baseSchema.optional(); return [field, schema]; } diff --git a/keep-ui/app/providers/providers-tiles.tsx b/keep-ui/app/providers/providers-tiles.tsx index 078bc289b..e3050706e 100644 --- a/keep-ui/app/providers/providers-tiles.tsx +++ b/keep-ui/app/providers/providers-tiles.tsx @@ -115,7 +115,9 @@ const ProvidersTiles = ({ diff --git a/keep/providers/cilium_provider/cilium_provider.py b/keep/providers/cilium_provider/cilium_provider.py index 522bc9218..92cd0464e 100644 --- a/keep/providers/cilium_provider/cilium_provider.py +++ b/keep/providers/cilium_provider/cilium_provider.py @@ -7,7 +7,8 @@ from keep.api.models.db.topology import TopologyServiceInDto from keep.contextmanager.contextmanager import ContextManager from keep.providers.base.base_provider import BaseTopologyProvider -from keep.providers.cilium_provider.grpc.observer_pb2 import FlowFilter, GetFlowsRequest +from keep.providers.cilium_provider.grpc.observer_pb2 import (FlowFilter, + GetFlowsRequest) from keep.providers.cilium_provider.grpc.observer_pb2_grpc import ObserverStub from keep.providers.models.provider_config import ProviderConfig @@ -21,7 +22,8 @@ class CiliumProviderAuthConfig: "required": True, "description": "The base endpoint of the cilium hubble relay", "sensitive": False, - "hint": "localhost:4245", + "hint": "http://localhost:4245", + "validation": "any_http_url" } ) diff --git a/keep/providers/clickhouse_provider/clickhouse_provider.py b/keep/providers/clickhouse_provider/clickhouse_provider.py index 307e0be5a..b0c65966f 100644 --- a/keep/providers/clickhouse_provider/clickhouse_provider.py +++ b/keep/providers/clickhouse_provider/clickhouse_provider.py @@ -12,7 +12,7 @@ from keep.contextmanager.contextmanager import ContextManager from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig, ProviderScope -from keep.validation.fields import UrlPort +from keep.validation.fields import NoSchemeUrl, UrlPort @pydantic.dataclasses.dataclass @@ -27,8 +27,12 @@ class ClickhouseProviderAuthConfig: "sensitive": True, } ) - host: str = dataclasses.field( - metadata={"required": True, "description": "Clickhouse hostname"} + host: NoSchemeUrl = dataclasses.field( + metadata={ + "required": True, + "description": "Clickhouse hostname", + "validation": "no_scheme_url", + } ) port: UrlPort = dataclasses.field( metadata={ diff --git a/keep/providers/graylog_provider/graylog_provider.py b/keep/providers/graylog_provider/graylog_provider.py index 20ad94d4c..ff97a5b08 100644 --- a/keep/providers/graylog_provider/graylog_provider.py +++ b/keep/providers/graylog_provider/graylog_provider.py @@ -5,7 +5,7 @@ import dataclasses import math import uuid -from datetime import datetime, timezone, timedelta +from datetime import datetime, timedelta, timezone from typing import List from urllib.parse import urlencode, urljoin, urlparse @@ -44,11 +44,12 @@ class GraylogProviderAuthConfig: "sensitive": True, }, ) - deployment_url: str = dataclasses.field( + deployment_url: pydantic.AnyHttpUrl = dataclasses.field( metadata={ "required": True, "description": "Deployment Url", "hint": "Example: http://127.0.0.1:9000", + "validation": "any_http_url" }, ) @@ -531,10 +532,11 @@ def _format_alert(event: dict, provider_instance: BaseProvider) -> AlertDto: @classmethod def simulate_alert(cls) -> dict: - from keep.providers.graylog_provider.alerts_mock import ALERTS import random import string + from keep.providers.graylog_provider.alerts_mock import ALERTS + # Use the provided ALERTS structure alert_data = ALERTS.copy() From 9cb18e34caaa05b2c3cbb3fcee076e3a650285a3 Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Thu, 14 Nov 2024 18:50:33 +0100 Subject: [PATCH 22/35] update provider config --- keep_clickhouse_e53b19ac051549aeb333e9c07cb25db8 | 1 + keep_prometheus_06e9495272b34630b4170634336ba8e1 | 1 + keep_victoriametrics_c77ff2c5e09d4ec38335f752cd4f8a39 | 1 + 3 files changed, 3 insertions(+) create mode 100644 keep_clickhouse_e53b19ac051549aeb333e9c07cb25db8 create mode 100644 keep_prometheus_06e9495272b34630b4170634336ba8e1 create mode 100644 keep_victoriametrics_c77ff2c5e09d4ec38335f752cd4f8a39 diff --git a/keep_clickhouse_e53b19ac051549aeb333e9c07cb25db8 b/keep_clickhouse_e53b19ac051549aeb333e9c07cb25db8 new file mode 100644 index 000000000..4a45d4459 --- /dev/null +++ b/keep_clickhouse_e53b19ac051549aeb333e9c07cb25db8 @@ -0,0 +1 @@ +{"authentication": {"host": "http://localhost", "port": 1234, "username": "keep", "password": "keep", "database": "keep-db"}, "name": "keepClickhouse1"} \ No newline at end of file diff --git a/keep_prometheus_06e9495272b34630b4170634336ba8e1 b/keep_prometheus_06e9495272b34630b4170634336ba8e1 new file mode 100644 index 000000000..04a81c6aa --- /dev/null +++ b/keep_prometheus_06e9495272b34630b4170634336ba8e1 @@ -0,0 +1 @@ +{"authentication": {"url": "http://localhost", "port": 9090}, "name": "keepPrometheus"} \ No newline at end of file diff --git a/keep_victoriametrics_c77ff2c5e09d4ec38335f752cd4f8a39 b/keep_victoriametrics_c77ff2c5e09d4ec38335f752cd4f8a39 new file mode 100644 index 000000000..f8995beb5 --- /dev/null +++ b/keep_victoriametrics_c77ff2c5e09d4ec38335f752cd4f8a39 @@ -0,0 +1 @@ +{"authentication": {"VMAlertHost": "http://localhost", "VMAlertPort": 1234}, "name": "keepVictoriaMetrics"} \ No newline at end of file From b9d0f4b21f29c5101cba4d444d49cd91c0edf20b Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Mon, 18 Nov 2024 10:07:03 +0100 Subject: [PATCH 23/35] add zod --- keep-ui/package-lock.json | 1 + keep-ui/package.json | 1 + 2 files changed, 2 insertions(+) diff --git a/keep-ui/package-lock.json b/keep-ui/package-lock.json index d60805ac1..fcfa2fae9 100644 --- a/keep-ui/package-lock.json +++ b/keep-ui/package-lock.json @@ -98,6 +98,7 @@ "tailwindcss": "^3.4.1", "uuid": "^8.3.2", "yaml": "^2.2.2", + "zod": "^3.23.8", "zustand": "^5.0.1" }, "devDependencies": { diff --git a/keep-ui/package.json b/keep-ui/package.json index a1ace7294..08b8d9ca9 100644 --- a/keep-ui/package.json +++ b/keep-ui/package.json @@ -99,6 +99,7 @@ "tailwindcss": "^3.4.1", "uuid": "^8.3.2", "yaml": "^2.2.2", + "zod": "^3.23.8", "zustand": "^5.0.1" }, "devDependencies": { From aa13f255df89efcb5ea8185f1a24eeda8f597be4 Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Sat, 23 Nov 2024 03:33:55 +0100 Subject: [PATCH 24/35] HttpsUrl should handle validation & transformation --- keep/providers/auth0_provider/auth0_provider.py | 13 ++++--------- keep/providers/jira_provider/jira_provider.py | 5 ----- keep/validation/fields.py | 2 +- 3 files changed, 5 insertions(+), 15 deletions(-) diff --git a/keep/providers/auth0_provider/auth0_provider.py b/keep/providers/auth0_provider/auth0_provider.py index 9b9eff405..cbe258e79 100644 --- a/keep/providers/auth0_provider/auth0_provider.py +++ b/keep/providers/auth0_provider/auth0_provider.py @@ -30,9 +30,9 @@ class Auth0ProviderAuthConfig: ) token: str = dataclasses.field( - default=None, metadata={ "required": True, + "sensitive": True, "description": "Auth0 API Token", "hint": "https://manage.auth0.com/dashboard/us/YOUR_ACCOUNT/apis/management/explorer", }, @@ -56,11 +56,6 @@ def validate_config(self): """ Validates required configuration for Auth0 provider. """ - if self.is_installed or self.is_provisioned: - host = self.config.authentication["domain"] - host = "https://" + host if not host.startswith("https://") else host - self.config.authentication["domain"] = host - self.authentication_config = Auth0ProviderAuthConfig( **self.config.authentication ) @@ -91,9 +86,9 @@ def _query(self, log_type: str, from_: str = None, **kwargs: dict): "per_page": 100, # specify the number of entries per page } if from_: - params[ - "q" - ] = f"({params['q']}) AND (date:[{from_} TO {datetime.datetime.now().isoformat()}])" + params["q"] = ( + f"({params['q']}) AND (date:[{from_} TO {datetime.datetime.now().isoformat()}])" + ) response = requests.get(url, headers=headers, params=params) response.raise_for_status() logs = response.json() diff --git a/keep/providers/jira_provider/jira_provider.py b/keep/providers/jira_provider/jira_provider.py index cef86fd4f..5f4661b40 100644 --- a/keep/providers/jira_provider/jira_provider.py +++ b/keep/providers/jira_provider/jira_provider.py @@ -153,11 +153,6 @@ def validate_scopes(self): return scopes def validate_config(self): - if self.is_installed or self.is_provisioned: - host = self.config.authentication['host'] - host = "https://" + host if not host.startswith("https://") else host - self.config.authentication['host'] = host - self.authentication_config = JiraProviderAuthConfig( **self.config.authentication ) diff --git a/keep/validation/fields.py b/keep/validation/fields.py index 1f43283e5..b82ac101c 100644 --- a/keep/validation/fields.py +++ b/keep/validation/fields.py @@ -11,7 +11,7 @@ class HttpsUrl(HttpUrl): @staticmethod def get_default_parts(parts): - return {"port": "443"} + return {"scheme": "https", "port": "443"} class NoSchemeUrl(AnyUrl): From b5126acc5c003bfc335cbe091bebb7629909f65e Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Sat, 23 Nov 2024 10:09:19 +0100 Subject: [PATCH 25/35] cleanup provider validation logic --- .../app/(keep)/providers/provider-form.tsx | 24 ++++++++++++------- keep/parser/parser.py | 2 +- .../centreon_provider/centreon_provider.py | 4 ++-- .../cilium_provider/cilium_provider.py | 7 +++--- .../clickhouse_provider.py | 1 + .../gitlab_provider/gitlab_provider.py | 2 +- .../ilert_provider/ilert_provider.py | 9 +++++-- .../kibana_provider/kibana_provider.py | 5 ++-- .../mongodb_provider/mongodb_provider.py | 9 ++++--- .../openobserve_provider.py | 5 ++-- .../slack_provider/slack_provider.py | 1 + 11 files changed, 42 insertions(+), 27 deletions(-) diff --git a/keep-ui/app/(keep)/providers/provider-form.tsx b/keep-ui/app/(keep)/providers/provider-form.tsx index 93de36b4b..233ade385 100644 --- a/keep-ui/app/(keep)/providers/provider-form.tsx +++ b/keep-ui/app/(keep)/providers/provider-form.tsx @@ -100,19 +100,25 @@ function base64urlencode(a: ArrayBuffer) { return btoa(str).replace(/\+/g, "-").replace(/\//g, "_").replace(/=+$/, ""); } +function getConfigsFromArr(arr: [string, ProviderAuthConfig][]) { + const configs: Provider["config"] = {}; + arr.forEach(([key, value]) => (configs[key] = value)); + return configs; +} + function getRequiredConfigs(config: Provider["config"]): Provider["config"] { - return Object.entries(config) - .filter(([_, config]) => config.required && !config.config_main_group) - .reduce((acc, [key, value]) => ({ ...acc, [key]: value }), {}); + const configs = Object.entries(config).filter( + ([_, config]) => config.required && !config.config_main_group + ); + return getConfigsFromArr(configs); } function getOptionalConfigs(config: Provider["config"]): Provider["config"] { - return Object.entries(config) - .filter( - ([_, config]) => - !config.required && !config.hidden && !config.config_main_group - ) - .reduce((acc, [key, value]) => ({ ...acc, [key]: value }), {}); + const configs = Object.entries(config).filter( + ([_, config]) => + config.required && !config.hidden && !config.config_main_group + ); + return getConfigsFromArr(configs); } function getConfigGroup(type: "config_main_group" | "config_sub_group") { diff --git a/keep/parser/parser.py b/keep/parser/parser.py index 98de0e2fa..2eb891aa0 100644 --- a/keep/parser/parser.py +++ b/keep/parser/parser.py @@ -313,7 +313,7 @@ def _inject_env_variables(self, config): def _parse_providers_from_workflow( self, context_manager: ContextManager, workflow: dict - ): + ) -> None: context_manager.providers_context.update(workflow.get("providers")) self.logger.debug("Workflow providers parsed successfully") diff --git a/keep/providers/centreon_provider/centreon_provider.py b/keep/providers/centreon_provider/centreon_provider.py index 039c59b67..b11242624 100644 --- a/keep/providers/centreon_provider/centreon_provider.py +++ b/keep/providers/centreon_provider/centreon_provider.py @@ -148,7 +148,7 @@ def __get_host_status(self) -> list[AlertDto]: except Exception as e: self.logger.error("Error getting host status from Centreon: %s", e) - raise ProviderException(f"Error getting host status from Centreon: {e}") + raise ProviderException(f"Error getting host status from Centreon: {e}") from e def __get_service_status(self) -> list[AlertDto]: try: @@ -181,7 +181,7 @@ def __get_service_status(self) -> list[AlertDto]: except Exception as e: self.logger.error("Error getting service status from Centreon: %s", e) - raise ProviderException(f"Error getting service status from Centreon: {e}") + raise ProviderException(f"Error getting service status from Centreon: {e}") from e def _get_alerts(self) -> list[AlertDto]: alerts = [] diff --git a/keep/providers/cilium_provider/cilium_provider.py b/keep/providers/cilium_provider/cilium_provider.py index 92cd0464e..fba1c19a5 100644 --- a/keep/providers/cilium_provider/cilium_provider.py +++ b/keep/providers/cilium_provider/cilium_provider.py @@ -11,19 +11,20 @@ GetFlowsRequest) from keep.providers.cilium_provider.grpc.observer_pb2_grpc import ObserverStub from keep.providers.models.provider_config import ProviderConfig +from keep.validation.fields import NoSchemeUrl @pydantic.dataclasses.dataclass class CiliumProviderAuthConfig: """Cilium authentication configuration.""" - cilium_base_endpoint: str = dataclasses.field( + cilium_base_endpoint: NoSchemeUrl = dataclasses.field( metadata={ "required": True, "description": "The base endpoint of the cilium hubble relay", "sensitive": False, - "hint": "http://localhost:4245", - "validation": "any_http_url" + "hint": "localhost:4245", + "validation": "no_scheme_url" } ) diff --git a/keep/providers/clickhouse_provider/clickhouse_provider.py b/keep/providers/clickhouse_provider/clickhouse_provider.py index b0c65966f..60c3817ab 100644 --- a/keep/providers/clickhouse_provider/clickhouse_provider.py +++ b/keep/providers/clickhouse_provider/clickhouse_provider.py @@ -164,6 +164,7 @@ def _notify(self, query="", single_row=False, **kwargs: dict) -> list | tuple: "password": os.environ.get("CLICKHOUSE_PASSWORD"), "host": os.environ.get("CLICKHOUSE_HOST"), "database": os.environ.get("CLICKHOUSE_DATABASE"), + "port": os.environ.get("CLICKHOUSE_PORT") } ) context_manager = ContextManager( diff --git a/keep/providers/gitlab_provider/gitlab_provider.py b/keep/providers/gitlab_provider/gitlab_provider.py index 0674aed2a..3df1852d1 100644 --- a/keep/providers/gitlab_provider/gitlab_provider.py +++ b/keep/providers/gitlab_provider/gitlab_provider.py @@ -23,7 +23,7 @@ class GitlabProviderAuthConfig: "required": True, "description": "GitLab Host", "sensitive": False, - "hint": "example.gitlab.com", + "hint": "http://example.gitlab.com", "validation": "any_http_url" } ) diff --git a/keep/providers/ilert_provider/ilert_provider.py b/keep/providers/ilert_provider/ilert_provider.py index f5a5071cd..acb9723a9 100644 --- a/keep/providers/ilert_provider/ilert_provider.py +++ b/keep/providers/ilert_provider/ilert_provider.py @@ -119,12 +119,17 @@ def validate_scopes(self): headers={ "Authorization": self.authentication_config.ilert_token }, + timeout=10 ) res.raise_for_status() data = res.json() if data['role'] not in ["USER", "ADMIN"]: - scopes[scope.name] = "User role & permisisions may be limited." - scopes[scope.name] = True + warning_msg = f"User role '{data['role']}' has limited permissions" + self.logger.warning(warning_msg) + scopes[scope.name] = warning_msg + else: + self.logger.debug(f"Write permission validated successfully for role: {data['role']}") + scopes[scope.name] = True except Exception as e: self.logger.warning( "Failed to validate scope", diff --git a/keep/providers/kibana_provider/kibana_provider.py b/keep/providers/kibana_provider/kibana_provider.py index 578dba332..251c491d5 100644 --- a/keep/providers/kibana_provider/kibana_provider.py +++ b/keep/providers/kibana_provider/kibana_provider.py @@ -443,8 +443,9 @@ def setup_webhook( def validate_config(self): if self.is_installed or self.is_provisioned: host = self.config.authentication['kibana_host'] - host = "https://" + host if not (host.startswith("http://") or host.startswith("https://")) else host - self.config.authentication['kibana_host'] = host + if not (host.startswith("http://") or host.startswith("https://")): + scheme = "http://" if "localhost" in host else "https://" + self.config.authentication['kibana_host'] = scheme + host self.authentication_config = KibanaProviderAuthConfig( **self.config.authentication diff --git a/keep/providers/mongodb_provider/mongodb_provider.py b/keep/providers/mongodb_provider/mongodb_provider.py index 76e7fc14e..6cfc71abe 100644 --- a/keep/providers/mongodb_provider/mongodb_provider.py +++ b/keep/providers/mongodb_provider/mongodb_provider.py @@ -140,11 +140,10 @@ def validate_config(self): host = self.config.authentication["host"] if host is None: raise ProviderConfigException("Please provide a value for `host`") - host = ( - "mongodb://" + host - if not (host.startswith("mongodb://") or host.startwith("mongodb+srv://")) - else host - ) + if not host.strip(): + raise ProviderConfigException("Host cannot be empty") + if not (host.startswith("mongodb://") or host.startswith("mongodb+srv://")): + host = f"mongodb://{host}" self.authentication_config = MongodbProviderAuthConfig( **self.config.authentication diff --git a/keep/providers/openobserve_provider/openobserve_provider.py b/keep/providers/openobserve_provider/openobserve_provider.py index 9ceb97166..46c150a19 100644 --- a/keep/providers/openobserve_provider/openobserve_provider.py +++ b/keep/providers/openobserve_provider/openobserve_provider.py @@ -110,8 +110,9 @@ def validate_config(self): """ if self.is_installed or self.is_provisioned: host = self.config.authentication['openObserveHost'] - host = "https://" + host if not (host.startswith("http://") or host.startswith("https://")) else host - self.config.authentication['openObserveHost'] = host + if not (host.startswith("http://") or host.startswith("https://")): + scheme = "http://" if "localhost" in host else "https://" + self.config.authentication['openObserveHost'] = scheme + host self.authentication_config = OpenobserveProviderAuthConfig( **self.config.authentication diff --git a/keep/providers/slack_provider/slack_provider.py b/keep/providers/slack_provider/slack_provider.py index 911a7558b..a82febcb3 100644 --- a/keep/providers/slack_provider/slack_provider.py +++ b/keep/providers/slack_provider/slack_provider.py @@ -25,6 +25,7 @@ class SlackProviderAuthConfig: "required": True, "description": "Slack Webhook Url", "validation": "https_url", + "sensitive": True }, ) access_token: str = dataclasses.field( From 32da06714925e27f2283c42809bb7e0b3b828caf Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Sat, 23 Nov 2024 17:25:27 +0100 Subject: [PATCH 26/35] fix provider form render bug --- keep-ui/app/(keep)/providers/provider-form.tsx | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/keep-ui/app/(keep)/providers/provider-form.tsx b/keep-ui/app/(keep)/providers/provider-form.tsx index 233ade385..cb687fb4f 100644 --- a/keep-ui/app/(keep)/providers/provider-form.tsx +++ b/keep-ui/app/(keep)/providers/provider-form.tsx @@ -116,7 +116,7 @@ function getRequiredConfigs(config: Provider["config"]): Provider["config"] { function getOptionalConfigs(config: Provider["config"]): Provider["config"] { const configs = Object.entries(config).filter( ([_, config]) => - config.required && !config.hidden && !config.config_main_group + !config.required && !config.hidden && !config.config_main_group ); return getConfigsFromArr(configs); } @@ -145,11 +145,10 @@ function getInitialFormValues(provider: Provider) { install_webhook: provider.can_setup_webhook ?? false, pulling_enabled: provider.pulling_enabled, }; - if (!provider.installed) return initialValues; Object.assign(initialValues, { - provider_name: provider.details.name, - ...provider.details.authentication, + provider_name: provider.details?.name, + ...provider.details?.authentication, }); // Set default values for select & switch inputs From 860873b68be0e7aaea4c49becffa49eb1e7cdedb Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Sun, 24 Nov 2024 10:45:58 +0100 Subject: [PATCH 27/35] add tests for custom validation fields --- keep/validation/fields.py | 42 ++++-------- tests/test_provider_validation_fields.py | 86 ++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 29 deletions(-) create mode 100644 tests/test_provider_validation_fields.py diff --git a/keep/validation/fields.py b/keep/validation/fields.py index b82ac101c..aa3e5b237 100644 --- a/keep/validation/fields.py +++ b/keep/validation/fields.py @@ -7,42 +7,26 @@ class HttpsUrl(HttpUrl): - scheme = {"https"} + """Validate https url, coerce if no scheme, throw if wrong scheme.""" + + allowed_schemes = {"https"} + + def __new__(cls, url: Optional[str], **kwargs) -> object: + _url = url if url is not None and url.startswith("https://") else None + return super().__new__(cls, _url, **kwargs) @staticmethod - def get_default_parts(parts): + def get_default_parts(parts: Parts) -> Parts: return {"scheme": "https", "port": "443"} class NoSchemeUrl(AnyUrl): - """Override to allow url without a scheme.""" + """Validate url with any scheme, remove scheme in output.""" - @classmethod - def build( - cls, - *, - scheme: str, - user: Optional[str] = None, - password: Optional[str] = None, - host: str, - port: Optional[str] = None, - path: Optional[str] = None, - query: Optional[str] = None, - fragment: Optional[str] = None, - **_kwargs: str, - ) -> str: - url = super().build( - scheme=scheme, - user=user, - password=password, - host=host, - port=port, - path=path, - query=query, - fragment=fragment, - **_kwargs, - ) - return url.split("://")[1] + def __new__(cls, url: Optional[str], **kwargs) -> object: + _url = cls.build(**kwargs) if url is None else url + _url = _url.split("://")[1] if "://" in _url else _url + return super().__new__(cls, _url, **kwargs) @classmethod def validate_parts(cls, parts: Parts, validate_port: bool = True) -> Parts: diff --git a/tests/test_provider_validation_fields.py b/tests/test_provider_validation_fields.py new file mode 100644 index 000000000..6f314ef64 --- /dev/null +++ b/tests/test_provider_validation_fields.py @@ -0,0 +1,86 @@ +import pytest +from pydantic import BaseModel, ValidationError + +from keep.validation.fields import HttpsUrl, NoSchemeUrl + + +@pytest.mark.parametrize( + "value,expected", + [ + ("example.org", "https://example.org"), + ("https://example.org", "https://example.org"), + ("https://example.org?a=1&b=2", "https://example.org?a=1&b=2"), + ("example.org#a=3;b=3", "https://example.org#a=3;b=3"), + ("https://foo_bar.example.com/", "https://foo_bar.example.com/"), + ("https://example.xn--p1ai", "https://example.xn--p1ai"), + ], +) +def test_https_url_valid(value, expected): + class Model(BaseModel): + v: HttpsUrl + + assert str(Model(v=value).v) == expected + + +@pytest.mark.parametrize( + "value", + [ + "ftp://example.com/", + "http://example.com/", + "x" * 2084, + ], +) +def test_https_url_invalid(value): + class Model(BaseModel): + v: HttpsUrl + + with pytest.raises(ValidationError) as exc_info: + Model(v=value) + assert len(exc_info.value.errors()) == 1, exc_info.value.errors() + + +@pytest.mark.parametrize( + "value,expected", + [ + ("example.org", "example.org"), + ("https://example.org", "example.org"), + ("localhost:8000", "localhost:8000"), + ("http://localhost:8000", "localhost:8000"), + ("postgres://user:pass@localhost:5432/app", "user:pass@localhost:5432/app"), + ( + "postgresql+psycopg2://postgres:postgres@localhost:5432/hatch", + "postgres:postgres@localhost:5432/hatch", + ), + ("http://123.45.67.8:8329/", "123.45.67.8:8329/"), + ("http://[2001:db8::ff00:42]:8329", "[2001:db8::ff00:42]:8329"), + ("http://example.org/path?query#fragment", "example.org/path?query#fragment"), + ], +) +def test_no_scheme_url_valid(value, expected): + class Model(BaseModel): + v: NoSchemeUrl + + assert str(Model(v=value).v) == expected + + +@pytest.mark.parametrize( + "value", + [ + "http://??", + "https://example.org more", + "$https://example.org", + "../icons/logo.gif", + "http://2001:db8::ff00:42:8329", + "http://[192.168.1.1]:8329", + "..", + "/rando/", + "http://example.com:99999", + ], +) +def test_no_scheme_url_invalid(value): + class Model(BaseModel): + v: NoSchemeUrl + + with pytest.raises(ValidationError) as exc_info: + Model(v=value) + assert len(exc_info.value.errors()) == 1, exc_info.value.errors() From e2bd55bae1f4162ef524de41ca40d8ae3a7cec94 Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Sun, 24 Nov 2024 14:21:27 +0100 Subject: [PATCH 28/35] add multihost url validation --- .../app/(keep)/providers/form-validation.ts | 80 ++++++++---- keep-ui/app/(keep)/providers/providers.tsx | 2 + keep/validation/fields.py | 118 +++++++++++++++++- tests/test_provider_validation_fields.py | 87 ++++++++++++- 4 files changed, 262 insertions(+), 25 deletions(-) diff --git a/keep-ui/app/(keep)/providers/form-validation.ts b/keep-ui/app/(keep)/providers/form-validation.ts index 163028a40..4c5da0e9d 100644 --- a/keep-ui/app/(keep)/providers/form-validation.ts +++ b/keep-ui/app/(keep)/providers/form-validation.ts @@ -6,6 +6,7 @@ type URLOptions = { requireTld: boolean; requireProtocol: boolean; requirePort: boolean; + alllowMultihost: boolean; validateLength: boolean; maxLength: number; }; @@ -17,6 +18,7 @@ const defaultURLOptions: URLOptions = { requireTld: false, requireProtocol: true, requirePort: false, + alllowMultihost: false, validateLength: true, maxLength: 2 ** 16, }; @@ -37,6 +39,7 @@ const missingPortError = error("A URL with a port number is required"); const portError = error("Invalid port number"); const hostError = error("Invalid URL host"); const hostWildcardError = error("Wildcard in URL host is not allowed"); +const multihostError = error("Multiple hosts are not allowed."); const tldError = error( "URL must contain a valid TLD e.g .com, .io, .dev, .net" ); @@ -105,6 +108,35 @@ function isIP(str: string) { return validation.success; } +function validateHost(hostname: string, opts: URLOptions): ValidatorRes { + let host: string; + let port: number; + let portStr: string = ""; + let split: string[]; + + // extract ipv6 & port + const wrapped_ipv6 = /^\[([^\]]+)\](?::([0-9]+))?$/; + const ipv6Match = hostname.match(wrapped_ipv6); + if (ipv6Match) { + host = ipv6Match[1]; + portStr = ipv6Match[2]; + } else { + split = hostname.split(":"); + host = split.shift() ?? ""; + if (split.length) portStr = split.join(":"); + } + + if (portStr.length) { + port = parseInt(portStr, 10); + if (Number.isNaN(port)) return urlError; + if (port <= 0 || port > 65_535) return portError; + } else if (opts.requirePort) return missingPortError; + + if (!host) return hostError; + if (isIP(host)) return { success: true }; + return isFQDN(host, opts); +} + function isURL(str: string, options?: Partial): ValidatorRes { const opts = mergeOptions(defaultURLOptions, options); @@ -114,9 +146,6 @@ function isURL(str: string, options?: Partial): ValidatorRes { } let url = str; - let host: string; - let port: number; - let portStr: string = ""; let split: string[]; split = url.split("#"); @@ -153,27 +182,17 @@ function isURL(str: string, options?: Partial): ValidatorRes { } const hostname = split.join("@"); - // extract ipv6 & port - const wrapped_ipv6 = /^\[([^\]]+)\](?::([0-9]+))?$/; - const ipv6Match = hostname.match(wrapped_ipv6); - if (ipv6Match) { - host = ipv6Match[1]; - portStr = ipv6Match[2]; - } else { - split = hostname.split(":"); - host = split.shift() ?? ""; - if (split.length) portStr = split.join(":"); + // validate multihost + split = hostname.split(","); + if (split.length > 1 && !opts.alllowMultihost) return multihostError; + if (split.length > 1) { + for (const host of split) { + const res = validateHost(host, opts); + if (!res.success) return res; + } + return { success: true }; } - - if (portStr.length) { - port = parseInt(portStr, 10); - if (Number.isNaN(port)) return urlError; - if (port <= 0 || port > 65_535) return portError; - } else if (opts.requirePort) return missingPortError; - - if (!host) return hostError; - if (isIP(host)) return { success: true }; - return isFQDN(host, opts); + return validateHost(hostname, opts); } const required_error = "This field is required"; @@ -262,6 +281,21 @@ export function getZodSchema(fields: Provider["config"], installed: boolean) { return [field, schema]; } + if (config.validation === "multihost_url") { + const baseSchema = getBaseUrlSchema({ alllowMultihost: true }); + const schema = config.required ? baseSchema : baseSchema.optional(); + return [field, schema]; + } + + if (config.validation === "no_scheme_multihost_url") { + const baseSchema = getBaseUrlSchema({ + alllowMultihost: true, + requireProtocol: false, + }); + const schema = config.required ? baseSchema : baseSchema.optional(); + return [field, schema]; + } + if (config.validation === "tld") { const baseSchema = z .string({ required_error }) diff --git a/keep-ui/app/(keep)/providers/providers.tsx b/keep-ui/app/(keep)/providers/providers.tsx index a0edd24cb..defd0db33 100644 --- a/keep-ui/app/(keep)/providers/providers.tsx +++ b/keep-ui/app/(keep)/providers/providers.tsx @@ -7,6 +7,8 @@ export interface ProviderAuthConfig { | "any_http_url" | "https_url" | "no_scheme_url" + | "multihost_url" + | "no_scheme_multihost_url" | "port" | "tld"; required?: boolean; diff --git a/keep/validation/fields.py b/keep/validation/fields.py index aa3e5b237..596152a30 100644 --- a/keep/validation/fields.py +++ b/keep/validation/fields.py @@ -1,7 +1,7 @@ from typing import Optional from pydantic import AnyUrl, HttpUrl, conint, errors -from pydantic.networks import Parts +from pydantic.networks import MultiHostDsn, Parts UrlPort = conint(ge=1, le=65_535) @@ -45,3 +45,119 @@ def validate_parts(cls, parts: Parts, validate_port: bool = True) -> Parts: raise errors.UrlUserInfoError() return parts + + +class MultiHostUrl(MultiHostDsn): + @classmethod + def build( + cls, + *, + scheme: str, + user: Optional[str] = None, + password: Optional[str] = None, + host: Optional[str] = None, + port: Optional[str] = None, + path: Optional[str] = None, + query: Optional[str] = None, + fragment: Optional[str] = None, + **_kwargs: str, + ) -> str: + hosts = _kwargs.get("hosts") + if host is not None and hosts is None: + return super().build( + scheme=scheme, + user=user, + password=password, + host=host, + port=port, + path=path, + query=query, + fragment=fragment, + **_kwargs, + ) + urls = [ + cls._build_single_url( + position=-1 if len(hosts) - idx == 1 else idx, + scheme=scheme, + user=user, + password=password, + host=hp["host"] + (hp["tld"] if hp["host_type"] == "domain" else ""), + port=hp["port"], + path=path, + query=query, + fragment=fragment, + **_kwargs, + ) + for (idx, hp) in enumerate(hosts) + ] + return ",".join(urls) + + @classmethod + def _build_single_url( + cls, + *, + position: int, + scheme: str, + user: Optional[str] = None, + password: Optional[str] = None, + host: str, + port: Optional[str] = None, + path: Optional[str] = None, + query: Optional[str] = None, + fragment: Optional[str] = None, + **_kwargs: str, + ) -> str: + parts = Parts( + scheme=scheme, + user=user, + password=password, + host=host, + port=port, + path=path, + query=query, + fragment=fragment, + **_kwargs, # type: ignore[misc] + ) + + url = "" + if position == 0: + url = scheme + "://" + if user: + url += user + if password: + url += ":" + password + if user or password: + url += "@" + + url += host + if port and ( + "port" not in cls.hidden_parts + or cls.get_default_parts(parts).get("port") != port + ): + url += ":" + port + + if position == -1: + if path: + url += path + if query: + url += "?" + query + if fragment: + url += "#" + fragment + return url + + +class NoSchemeMultiHostUrl(MultiHostUrl): + def __new__(cls, url: Optional[str], **kwargs) -> object: + _url = cls.build(**kwargs) if url is None else url + _url = _url.split("://")[1] if "://" in _url else _url + return super().__new__(cls, _url, **kwargs) + + @classmethod + def validate_parts(cls, parts: Parts, validate_port: bool = True) -> Parts: + """ + Remove validation for url scheme, port & user. + """ + scheme = parts["scheme"] + parts["scheme"] = "" if scheme is None else scheme + + return parts diff --git a/tests/test_provider_validation_fields.py b/tests/test_provider_validation_fields.py index 6f314ef64..9846b5e97 100644 --- a/tests/test_provider_validation_fields.py +++ b/tests/test_provider_validation_fields.py @@ -1,7 +1,12 @@ import pytest from pydantic import BaseModel, ValidationError -from keep.validation.fields import HttpsUrl, NoSchemeUrl +from keep.validation.fields import ( + HttpsUrl, + MultiHostUrl, + NoSchemeMultiHostUrl, + NoSchemeUrl, +) @pytest.mark.parametrize( @@ -84,3 +89,83 @@ class Model(BaseModel): with pytest.raises(ValidationError) as exc_info: Model(v=value) assert len(exc_info.value.errors()) == 1, exc_info.value.errors() + + +@pytest.mark.parametrize( + "value", + [ + "http://localhost:5000", + "http://localhost:5000,localhost:2222", + "https://user:pass@localhost:4321,localhost:3000/app", + "http://123.45.67.8:8329/,113.45.67.8:9309/", + "http://[2001:db8::ff00:42]:8329,[2001:db8::ff00:42]:5000", + "ampq://broker.com,en.broker.com/app", + "postgres://user:pass@host1.db.net:4321,host2.db.net:6432/app", + "mongodb://user:pass@host1.db.net:4321,host2.db.net:6432/app?query#fragment", + ], +) +def test_multihost_url_valid(value): + class Model(BaseModel): + v: MultiHostUrl + + assert str(Model(v=value).v) == value + + +@pytest.mark.parametrize( + "value", + [ + "localhost:5000,localhost:2222", + "broker.com,en.broker.com/app", + "http://[192.168.1.1]:8329,[192.168.1.2]:8421", + "user:pass@host1.db.net:4321,host2.db.net:6432/app?query#fragment", + ], +) +def test_multihost_url_invalid(value): + class Model(BaseModel): + v: MultiHostUrl + + with pytest.raises(ValidationError) as exc_info: + Model(v=value) + assert len(exc_info.value.errors()) == 1, exc_info.value.errors() + + +@pytest.mark.parametrize( + "value,expected", + [ + ("http://localhost:5000,localhost:2222", "localhost:5000,localhost:2222"), + ("localhost:5000,localhost:2222", "localhost:5000,localhost:2222"), + ( + "https://user:pass@localhost:4321,localhost:3000/app", + "user:pass@localhost:4321,localhost:3000/app", + ), + ( + "postgres://user:pass@host1.db.net:4321,host2.db.net:6432/app?query#fragment", + "user:pass@host1.db.net:4321,host2.db.net:6432/app?query#fragment", + ), + ], +) +def test_no_scheme_multihost_url_valid(value, expected): + class Model(BaseModel): + v: NoSchemeMultiHostUrl + + assert str(Model(v=value).v) == expected + + +@pytest.mark.parametrize( + "value", + [ + "http://??, localhost:5000", + "../icons/logo.gif", + "http://[192.168.1.1]:8329", + "..", + "/rando/", + "http://example.com:99999", + ], +) +def test_no_scheme_multihost_url_invalid(value): + class Model(BaseModel): + v: NoSchemeMultiHostUrl + + with pytest.raises(ValidationError) as exc_info: + Model(v=value) + assert len(exc_info.value.errors()) == 1, exc_info.value.errors() From adc20b393a13b84a78082ead0528915c58b420b0 Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Sun, 24 Nov 2024 16:35:43 +0100 Subject: [PATCH 29/35] add multihost validation tp kafka & mongo --- keep/providers/kafka_provider/kafka_provider.py | 7 ++++--- keep/providers/mongodb_provider/mongodb_provider.py | 7 ++++--- tests/e2e_tests/test_end_to_end.py | 1 - 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/keep/providers/kafka_provider/kafka_provider.py b/keep/providers/kafka_provider/kafka_provider.py index af2a9bcac..f95fd35ed 100644 --- a/keep/providers/kafka_provider/kafka_provider.py +++ b/keep/providers/kafka_provider/kafka_provider.py @@ -14,6 +14,7 @@ from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig, ProviderScope from keep.providers.providers_factory import ProvidersFactory +from keep.validation.fields import NoSchemeMultiHostUrl @pydantic.dataclasses.dataclass @@ -22,12 +23,12 @@ class KafkaProviderAuthConfig: Kafka authentication configuration. """ - host: pydantic.AnyUrl = dataclasses.field( + host: NoSchemeMultiHostUrl = dataclasses.field( metadata={ "required": True, "description": "Kafka host", - "hint": "e.g. https://kafka:9092", - "validation": "any_url" + "hint": "e.g. localhost:9092 or localhost:9092,localhost:8093", + "validation": "no_scheme_multihost_url" }, ) topic: str = dataclasses.field( diff --git a/keep/providers/mongodb_provider/mongodb_provider.py b/keep/providers/mongodb_provider/mongodb_provider.py index 6cfc71abe..50252e101 100644 --- a/keep/providers/mongodb_provider/mongodb_provider.py +++ b/keep/providers/mongodb_provider/mongodb_provider.py @@ -13,16 +13,17 @@ from keep.exceptions.provider_config_exception import ProviderConfigException from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig, ProviderScope +from keep.validation.fields import MultiHostUrl @pydantic.dataclasses.dataclass class MongodbProviderAuthConfig: - host: pydantic.AnyUrl = dataclasses.field( + host: MultiHostUrl = dataclasses.field( metadata={ "required": True, "description": "Mongo host_uri", - "hint": "any valid mongo host_uri like mongodb://host:port, user:paassword@host:port?authSource", - "validation": "any_url", + "hint": "mongodb+srv://host:port, mongodb://host1:port1,host2:port2?authSource", + "validation": "multihost_url", } ) username: str = dataclasses.field( diff --git a/tests/e2e_tests/test_end_to_end.py b/tests/e2e_tests/test_end_to_end.py index 903f1a6c5..6b922d011 100644 --- a/tests/e2e_tests/test_end_to_end.py +++ b/tests/e2e_tests/test_end_to_end.py @@ -24,7 +24,6 @@ import os import random - # Adding a new test: # 1. Manually: # - Create a new test function. From 5d5462b18e89a98deba148b0d05889b7a178ef3a Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Tue, 26 Nov 2024 11:08:45 +0100 Subject: [PATCH 30/35] update mysql config validation --- keep/providers/mysql_provider/mysql_provider.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/keep/providers/mysql_provider/mysql_provider.py b/keep/providers/mysql_provider/mysql_provider.py index d07831901..e0d3f9d1c 100644 --- a/keep/providers/mysql_provider/mysql_provider.py +++ b/keep/providers/mysql_provider/mysql_provider.py @@ -11,6 +11,7 @@ from keep.contextmanager.contextmanager import ContextManager from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig, ProviderScope +from keep.validation.fields import NoSchemeUrl @pydantic.dataclasses.dataclass @@ -21,11 +22,11 @@ class MysqlProviderAuthConfig: password: str = dataclasses.field( metadata={"required": True, "description": "MySQL password", "sensitive": True} ) - host: pydantic.AnyUrl = dataclasses.field( + host: NoSchemeUrl = dataclasses.field( metadata={ "required": True, "description": "MySQL hostname", - "validation": "any_url", + "validation": "no_scheme_url", } ) database: str | None = dataclasses.field( From 5f9be5a12dcca03155de813ceb8c7b2ac9de8bb7 Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Wed, 27 Nov 2024 08:45:35 +0100 Subject: [PATCH 31/35] change validation logic from on submit to on input --- .../app/(keep)/providers/provider-form.tsx | 19 +++++- tests/e2e_tests/test_end_to_end.py | 58 +++++++++---------- 2 files changed, 43 insertions(+), 34 deletions(-) diff --git a/keep-ui/app/(keep)/providers/provider-form.tsx b/keep-ui/app/(keep)/providers/provider-form.tsx index cb687fb4f..c12bbf68f 100644 --- a/keep-ui/app/(keep)/providers/provider-form.tsx +++ b/keep-ui/app/(keep)/providers/provider-form.tsx @@ -319,7 +319,14 @@ const ProviderForm = ({ }); } - if (Object.keys(inputErrors).includes(key) && value !== "") { + if ( + value == undefined || + typeof value === "boolean" || + (typeof value === "object" && value instanceof File === false) + ) + return; + + if (validate({ [key]: value })) { const updatedInputErrors = { ...inputErrors }; delete updatedInputErrors[key]; setInputErrors(updatedInputErrors); @@ -334,8 +341,14 @@ const ProviderForm = ({ })); }; - function validate() { - const validation = zodSchema.safeParse(formValues); + function validate(data?: ProviderFormData) { + let schema = zodSchema; + if (data) { + schema = zodSchema.pick( + Object.fromEntries(Object.keys(data).map((field) => [field, true])) + ); + } + const validation = schema.safeParse(data ?? formValues); if (validation.success) return true; const errors: InputErrors = {}; Object.entries(validation.error.format()).forEach(([field, err]) => { diff --git a/tests/e2e_tests/test_end_to_end.py b/tests/e2e_tests/test_end_to_end.py index 6b922d011..784e96315 100644 --- a/tests/e2e_tests/test_end_to_end.py +++ b/tests/e2e_tests/test_end_to_end.py @@ -36,6 +36,7 @@ import sys from datetime import datetime +import pytest from playwright.sync_api import expect # Running the tests in GitHub Actions: @@ -45,6 +46,20 @@ # os.environ["PLAYWRIGHT_HEADLESS"] = "false" +@pytest.fixture(scope="session") +def browserx(): + from playwright.sync_api import sync_playwright + + headless = False + with sync_playwright() as p: + browser = p.chromium.launch(headless=headless) + context = browser.new_context() + page = context.new_page() + page.set_default_timeout(5000) + yield page + context.close() + browser.close() + def setup_console_listener(page, log_entries): """Set up console listener to capture logs.""" page.on( @@ -129,7 +144,6 @@ def test_insert_new_alert(browser): # browser is actually a page object save_failure_artifacts(browser, log_entries) raise - def test_providers_page_is_accessible(browser): """ Test to check if the providers page is accessible @@ -178,14 +192,15 @@ def test_provider_validation(browser): """ Test field validation for provider fields. """ - browser.goto( - "http://localhost:3000/signin?callbackUrl=http%3A%2F%2Flocalhost%3A3000%2Fproviders" - ) + # browser = browserx + browser.goto("http://localhost:3000/signin") + # browser.goto("http://localhost:3000/providers") # using Kibana Provider - browser.goto("http://localhost:3000/providers") + browser.get_by_role("link", name="Providers").click() browser.locator("button:has-text('Kibana'):has-text('alert')").click() # test required fields connect_btn = browser.get_by_role("button", name="Connect", exact=True) + cancel_btn = browser.get_by_role("button", name="Cancel", exact=True) error_msg = browser.locator("p.tremor-TextInput-errorMessage") connect_btn.click() expect(error_msg).to_have_count(3) @@ -193,83 +208,67 @@ def test_provider_validation(browser): browser.get_by_placeholder("Enter provider name").fill("random name") browser.get_by_placeholder("Enter api_key").fill("random api key") browser.get_by_placeholder("Enter kibana_host").fill("invalid url") - connect_btn.click() expect(error_msg).to_have_count(1) browser.get_by_placeholder("Enter kibana_host").fill("http://localhost") - connect_btn.click() expect(error_msg).to_be_hidden() browser.get_by_placeholder("Enter kibana_host").fill( "https://keep.kb.us-central1.gcp.cloud.es.io" ) - connect_btn.click() expect(error_msg).to_be_hidden() # test `port` field validation browser.get_by_placeholder("Enter kibana_port").fill("invalid port") - connect_btn.click() expect(error_msg).to_have_count(1) browser.get_by_placeholder("Enter kibana_port").fill("0") - connect_btn.click() expect(error_msg).to_have_count(1) browser.get_by_placeholder("Enter kibana_port").fill("65_536") - connect_btn.click() expect(error_msg).to_have_count(1) browser.get_by_placeholder("Enter kibana_port").fill("9243") - connect_btn.click() expect(error_msg).to_be_hidden() + cancel_btn.click() # using Teams Provider - browser.goto("http://localhost:3000/providers") browser.locator("button:has-text('Teams'):has-text('messaging')").click() # test `https_url` field validation browser.get_by_placeholder("Enter provider name").fill("random name") browser.get_by_placeholder("Enter webhook_url").fill("random url") - connect_btn.click() expect(error_msg).to_have_count(1) browser.get_by_placeholder("Enter webhook_url").fill("http://localhost") - connect_btn.click() expect(error_msg).to_have_count(1) browser.get_by_placeholder("Enter webhook_url").fill("http://example.com") - connect_btn.click() expect(error_msg).to_have_count(1) browser.get_by_placeholder("Enter webhook_url").fill("https://example.com") - connect_btn.click() expect(error_msg).to_be_hidden() + cancel_btn.click() # using Site24x7 Provider - browser.goto("http://localhost:3000/providers") browser.locator("button:has-text('Site24x7'):has-text('alert')").click() # test `tld` field validation browser.get_by_placeholder("Enter provider name").fill("random name") browser.get_by_placeholder("Enter zohoRefreshToken").fill("random") browser.get_by_placeholder("Enter zohoClientId").fill("random") browser.get_by_placeholder("Enter zohoClientSecret").fill("random") - browser.get_by_placeholder("Enter zohoAccountTLD").fill("") - connect_btn.click() - expect(error_msg).to_have_count(1) browser.get_by_placeholder("Enter zohoAccountTLD").fill("random") - connect_btn.click() + expect(error_msg).to_have_count(1) + browser.get_by_placeholder("Enter zohoAccountTLD").fill("") expect(error_msg).to_have_count(1) browser.get_by_placeholder("Enter zohoAccountTLD").fill(".com") - connect_btn.click() expect(error_msg).to_be_hidden() + cancel_btn.click() # using MongoDB Provider - browser.goto("http://localhost:3000/providers") browser.locator("button:has-text('MongoDB'):has-text('data')").click() # test `any_url` field validation browser.get_by_placeholder("Enter provider name").fill("random name") browser.get_by_placeholder("Enter host").fill("random") - connect_btn.click() expect(error_msg).to_have_count(1) browser.get_by_placeholder("Enter host").fill("host.com:5000") - connect_btn.click() expect(error_msg).to_have_count(1) browser.get_by_placeholder("Enter host").fill("mongodb://host.com:3000") - connect_btn.click() expect(error_msg).to_be_hidden() + cancel_btn.click() # using Postgres provider - browser.goto("http://localhost:3000/providers") + browser.get_by_role("link", name="Providers").click() browser.locator("button:has-text('PostgreSQL'):has-text('data')").click() # test `no_scheme_url` field validation # - on the frontend: url with/without scheme validates. @@ -278,11 +277,8 @@ def test_provider_validation(browser): browser.get_by_placeholder("Enter username").fill("username") browser.get_by_placeholder("Enter password").fill("password") browser.get_by_placeholder("Enter host").fill("*.") - connect_btn.click() expect(error_msg).to_have_count(1) browser.get_by_placeholder("Enter host").fill("localhost:5000") - connect_btn.click() expect(error_msg).to_be_hidden() browser.get_by_placeholder("Enter host").fill("https://host.com:3000") - connect_btn.click() expect(error_msg).to_be_hidden() From 71472ca5910e9d41281aec1bfd969c364b913d6f Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Wed, 27 Nov 2024 18:09:32 +0100 Subject: [PATCH 32/35] cleanup validation logic & tests --- .../app/(keep)/providers/form-validation.ts | 29 +++-- tests/e2e_tests/test_end_to_end.py | 106 +++++++++--------- 2 files changed, 71 insertions(+), 64 deletions(-) diff --git a/keep-ui/app/(keep)/providers/form-validation.ts b/keep-ui/app/(keep)/providers/form-validation.ts index 4c5da0e9d..cf1bf5e1b 100644 --- a/keep-ui/app/(keep)/providers/form-validation.ts +++ b/keep-ui/app/(keep)/providers/form-validation.ts @@ -35,25 +35,28 @@ const error = (msg: string) => ({ success: false, msg }); const urlError = error("Please provide a valid URL"); const protocolError = error("A valid URL protocol is required"); const relProtocolError = error("A protocol-relavie URL is not allowed"); +const multiProtocolError = error("URL cannot have more than one protocol"); const missingPortError = error("A URL with a port number is required"); const portError = error("Invalid port number"); const hostError = error("Invalid URL host"); const hostWildcardError = error("Wildcard in URL host is not allowed"); -const multihostError = error("Multiple hosts are not allowed."); +const multihostError = error("Multiple hosts are not allowed"); +const multihostProtocolError = error("Invalid multihost protocol"); const tldError = error( "URL must contain a valid TLD e.g .com, .io, .dev, .net" ); -function getProtocolError(opts: URLOptions["protocols"]) { - if (opts.length === 1) - return error(`A URL with \`${opts[0]}\` protocol is required`); - if (opts.length === 2) +function getProtocolError(protocols: URLOptions["protocols"]) { + if (protocols.length === 0) return protocolError; + if (protocols.length === 1) + return error(`A URL with \`${protocols[0]}\` protocol is required`); + if (protocols.length === 2) return error( - `A URL with \`${opts[0]}\` or \`${opts[1]}\` protocol is required` + `A URL with \`${protocols[0]}\` or \`${protocols[1]}\` protocol is required` ); - const lst = opts.length - 1; + const lst = protocols.length - 1; const wrap = (acc: string, p: string) => acc + `\`${p}\``; - const optsStr = opts.reduce( + const optsStr = protocols.reduce( (acc, p, i) => i === lst ? wrap(acc, p) @@ -158,14 +161,16 @@ function isURL(str: string, options?: Partial): ValidatorRes { // extract protocol & validate split = url.split("://"); + if (split.length > 2) return multiProtocolError; if (split.length > 1) { - const protocol = split?.shift()?.toLowerCase() ?? ""; + const protocol = split.shift()?.toLowerCase() ?? ""; if (opts.protocols.length && opts.protocols.indexOf(protocol) === -1) return getProtocolError(opts.protocols); - } else if (opts.requireProtocol && opts.protocols.length) { + if (protocol.includes(",")) return multihostProtocolError; + url = split.join("://"); + } else if (opts.requireProtocol) { return getProtocolError(opts.protocols); - } else if (split.length > 2 || opts.requireProtocol) return protocolError; - url = split.join("://"); + } split = url.split("/"); url = split.shift() ?? ""; diff --git a/tests/e2e_tests/test_end_to_end.py b/tests/e2e_tests/test_end_to_end.py index 784e96315..c27b77fcf 100644 --- a/tests/e2e_tests/test_end_to_end.py +++ b/tests/e2e_tests/test_end_to_end.py @@ -36,7 +36,6 @@ import sys from datetime import datetime -import pytest from playwright.sync_api import expect # Running the tests in GitHub Actions: @@ -46,20 +45,6 @@ # os.environ["PLAYWRIGHT_HEADLESS"] = "false" -@pytest.fixture(scope="session") -def browserx(): - from playwright.sync_api import sync_playwright - - headless = False - with sync_playwright() as p: - browser = p.chromium.launch(headless=headless) - context = browser.new_context() - page = context.new_page() - page.set_default_timeout(5000) - yield page - context.close() - browser.close() - def setup_console_listener(page, log_entries): """Set up console listener to capture logs.""" page.on( @@ -194,7 +179,6 @@ def test_provider_validation(browser): """ # browser = browserx browser.goto("http://localhost:3000/signin") - # browser.goto("http://localhost:3000/providers") # using Kibana Provider browser.get_by_role("link", name="Providers").click() browser.locator("button:has-text('Kibana'):has-text('alert')").click() @@ -205,80 +189,98 @@ def test_provider_validation(browser): connect_btn.click() expect(error_msg).to_have_count(3) # test `any_http_url` field validation - browser.get_by_placeholder("Enter provider name").fill("random name") - browser.get_by_placeholder("Enter api_key").fill("random api key") - browser.get_by_placeholder("Enter kibana_host").fill("invalid url") + host_input = browser.get_by_placeholder("Enter kibana_host") + host_input.fill("invalid url") expect(error_msg).to_have_count(1) - browser.get_by_placeholder("Enter kibana_host").fill("http://localhost") + host_input.fill("http://localhost") expect(error_msg).to_be_hidden() - browser.get_by_placeholder("Enter kibana_host").fill( - "https://keep.kb.us-central1.gcp.cloud.es.io" - ) + host_input.fill( "https://keep.kb.us-central1.gcp.cloud.es.io") expect(error_msg).to_be_hidden() # test `port` field validation - browser.get_by_placeholder("Enter kibana_port").fill("invalid port") + port_input = browser.get_by_placeholder("Enter kibana_port") + port_input.fill("invalid port") expect(error_msg).to_have_count(1) - browser.get_by_placeholder("Enter kibana_port").fill("0") + port_input.fill("0") expect(error_msg).to_have_count(1) - browser.get_by_placeholder("Enter kibana_port").fill("65_536") + port_input.fill("65_536") expect(error_msg).to_have_count(1) - browser.get_by_placeholder("Enter kibana_port").fill("9243") + port_input.fill("9243") expect(error_msg).to_be_hidden() cancel_btn.click() # using Teams Provider browser.locator("button:has-text('Teams'):has-text('messaging')").click() # test `https_url` field validation - browser.get_by_placeholder("Enter provider name").fill("random name") - browser.get_by_placeholder("Enter webhook_url").fill("random url") + url_input = browser.get_by_placeholder("Enter webhook_url") + url_input.fill("random url") + expect(error_msg).to_have_count(1) + url_input.fill("http://localhost") expect(error_msg).to_have_count(1) - browser.get_by_placeholder("Enter webhook_url").fill("http://localhost") + url_input.fill("http://example.com") expect(error_msg).to_have_count(1) - browser.get_by_placeholder("Enter webhook_url").fill("http://example.com") + url_input.fill("https://example.c") expect(error_msg).to_have_count(1) - browser.get_by_placeholder("Enter webhook_url").fill("https://example.com") + url_input.fill("https://example.com") expect(error_msg).to_be_hidden() cancel_btn.click() # using Site24x7 Provider browser.locator("button:has-text('Site24x7'):has-text('alert')").click() # test `tld` field validation - browser.get_by_placeholder("Enter provider name").fill("random name") - browser.get_by_placeholder("Enter zohoRefreshToken").fill("random") - browser.get_by_placeholder("Enter zohoClientId").fill("random") - browser.get_by_placeholder("Enter zohoClientSecret").fill("random") - browser.get_by_placeholder("Enter zohoAccountTLD").fill("random") + tld_input = browser.get_by_placeholder("Enter zohoAccountTLD") + tld_input.fill("random") expect(error_msg).to_have_count(1) - browser.get_by_placeholder("Enter zohoAccountTLD").fill("") + tld_input.fill("") expect(error_msg).to_have_count(1) - browser.get_by_placeholder("Enter zohoAccountTLD").fill(".com") + tld_input.fill(".com") expect(error_msg).to_be_hidden() cancel_btn.click() # using MongoDB Provider browser.locator("button:has-text('MongoDB'):has-text('data')").click() - # test `any_url` field validation - browser.get_by_placeholder("Enter provider name").fill("random name") - browser.get_by_placeholder("Enter host").fill("random") + # test `multihost_url` field validation + host_input = browser.get_by_placeholder("Enter host") + host_input.fill("random") + expect(error_msg).to_have_count(1) + host_input.fill("host.com:5000") + expect(error_msg).to_have_count(1) + host_input.fill("host1.com:5000,host2.com:3000") + expect(error_msg).to_have_count(1) + host_input.fill("mongodb://host1.com:5000,mongodb+srv://host2.com:3000") + expect(error_msg).to_have_count(1) + host_input.fill("mongodb://host.com:3000") + expect(error_msg).to_be_hidden() + host_input.fill("mongodb://localhost:3000,localhost:5000") + expect(error_msg).to_be_hidden() + cancel_btn.click() + + # using Kafka Provider + browser.locator("button:has-text('Kafka'):has-text('queue')").click() + # test `no_scheme_multihost_url` field validation + host_input = browser.get_by_placeholder("Enter host") + host_input.fill("*.") expect(error_msg).to_have_count(1) - browser.get_by_placeholder("Enter host").fill("host.com:5000") + host_input.fill("host.com:5000") + expect(error_msg).to_be_hidden() + host_input.fill("host1.com:5000,host2.com:3000") + expect(error_msg).to_be_hidden() + host_input.fill("http://host1.com:5000,https://host2.com:3000") expect(error_msg).to_have_count(1) - browser.get_by_placeholder("Enter host").fill("mongodb://host.com:3000") + host_input.fill("http://host.com:3000") + expect(error_msg).to_be_hidden() + host_input.fill("mongodb://localhost:3000,localhost:5000") expect(error_msg).to_be_hidden() cancel_btn.click() + # using Postgres provider browser.get_by_role("link", name="Providers").click() browser.locator("button:has-text('PostgreSQL'):has-text('data')").click() # test `no_scheme_url` field validation - # - on the frontend: url with/without scheme validates. - # - on the backend: scheme is removed during validation. - browser.get_by_placeholder("Enter provider name").fill("random name") - browser.get_by_placeholder("Enter username").fill("username") - browser.get_by_placeholder("Enter password").fill("password") - browser.get_by_placeholder("Enter host").fill("*.") + host_input = browser.get_by_placeholder("Enter host") + host_input.fill("*.") expect(error_msg).to_have_count(1) - browser.get_by_placeholder("Enter host").fill("localhost:5000") + host_input.fill("localhost:5000") expect(error_msg).to_be_hidden() - browser.get_by_placeholder("Enter host").fill("https://host.com:3000") + host_input.fill("https://host.com:3000") expect(error_msg).to_be_hidden() From f4fe709dbe1330487f6179780bf9df2919902ec9 Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Thu, 28 Nov 2024 12:05:13 +0100 Subject: [PATCH 33/35] update client-side validation logic --- keep-ui/app/(keep)/providers/provider-form.tsx | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/keep-ui/app/(keep)/providers/provider-form.tsx b/keep-ui/app/(keep)/providers/provider-form.tsx index a234f7709..a93c249b3 100644 --- a/keep-ui/app/(keep)/providers/provider-form.tsx +++ b/keep-ui/app/(keep)/providers/provider-form.tsx @@ -318,13 +318,15 @@ const ProviderForm = ({ } if ( - value == undefined || typeof value === "boolean" || (typeof value === "object" && value instanceof File === false) ) return; - const isValid = validate({ [key]: value }); + const isValid = validate({ + [key]: + typeof value === "string" && value.length === 0 ? undefined : value, + }); if (isValid) { const updatedInputErrors = { ...inputErrors }; delete updatedInputErrors[key]; @@ -355,7 +357,7 @@ const ProviderForm = ({ ? (errors[field] = err._errors[0]) : null; }); - setInputErrors(errors); + setInputErrors((prev) => ({ ...prev, ...errors })); return false; } From 12ed2c7d72cd4ba5af0b9345db16b718f84ee308 Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Thu, 28 Nov 2024 13:51:57 +0100 Subject: [PATCH 34/35] fix validation & tests bugs --- keep/providers/prometheus_provider/prometheus_provider.py | 1 + tests/e2e_tests/test_end_to_end.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/keep/providers/prometheus_provider/prometheus_provider.py b/keep/providers/prometheus_provider/prometheus_provider.py index f165f3a59..cb3c392a7 100644 --- a/keep/providers/prometheus_provider/prometheus_provider.py +++ b/keep/providers/prometheus_provider/prometheus_provider.py @@ -23,6 +23,7 @@ class PrometheusProviderAuthConfig: "required": True, "description": "Prometheus server URL", "hint": "https://prometheus-us-central1.grafana.net/api/prom", + "validation": "any_http_url" } ) username: str = dataclasses.field( diff --git a/tests/e2e_tests/test_end_to_end.py b/tests/e2e_tests/test_end_to_end.py index c27b77fcf..91fad01bc 100644 --- a/tests/e2e_tests/test_end_to_end.py +++ b/tests/e2e_tests/test_end_to_end.py @@ -188,7 +188,9 @@ def test_provider_validation(browser): error_msg = browser.locator("p.tremor-TextInput-errorMessage") connect_btn.click() expect(error_msg).to_have_count(3) + cancel_btn.click() # test `any_http_url` field validation + browser.locator("button:has-text('Kibana'):has-text('alert')").click() host_input = browser.get_by_placeholder("Enter kibana_host") host_input.fill("invalid url") expect(error_msg).to_have_count(1) From 303a9608cc3274cffebda19e2a0e338ef3b330b1 Mon Sep 17 00:00:00 2001 From: theedigerati <39467790+theedigerati@users.noreply.github.com> Date: Fri, 29 Nov 2024 01:49:26 +0100 Subject: [PATCH 35/35] code review updates --- keep/providers/jira_provider/jira_provider.py | 10 +- .../kibana_provider/kibana_provider.py | 2 +- .../openobserve_provider.py | 2 +- .../splunk_provider/splunk_provider.py | 6 + .../squadcast_provider/squadcast_provider.py | 10 +- tests/e2e_tests/test_end_to_end.py | 228 +++++++++--------- 6 files changed, 140 insertions(+), 118 deletions(-) diff --git a/keep/providers/jira_provider/jira_provider.py b/keep/providers/jira_provider/jira_provider.py index 7b008461f..60695066e 100644 --- a/keep/providers/jira_provider/jira_provider.py +++ b/keep/providers/jira_provider/jira_provider.py @@ -162,12 +162,10 @@ def validate_config(self): @property def jira_host(self): - host = ( - self.authentication_config.host - if self.authentication_config.host.startswith("https://") - else f"https://{self.authentication_config.host}" - ) - return host + if self._host: + return self._host + self._host = str(self.authentication_config.host) + return self._host def dispose(self): """ diff --git a/keep/providers/kibana_provider/kibana_provider.py b/keep/providers/kibana_provider/kibana_provider.py index ff061bc89..f522e4cf9 100644 --- a/keep/providers/kibana_provider/kibana_provider.py +++ b/keep/providers/kibana_provider/kibana_provider.py @@ -445,7 +445,7 @@ def validate_config(self): if self.is_installed or self.is_provisioned: host = self.config.authentication['kibana_host'] if not (host.startswith("http://") or host.startswith("https://")): - scheme = "http://" if "localhost" in host else "https://" + scheme = "http://" if ("localhost" in host or "127.0.0.1" in host) else "https://" self.config.authentication['kibana_host'] = scheme + host self.authentication_config = KibanaProviderAuthConfig( diff --git a/keep/providers/openobserve_provider/openobserve_provider.py b/keep/providers/openobserve_provider/openobserve_provider.py index fc37cf7a5..04c166a88 100644 --- a/keep/providers/openobserve_provider/openobserve_provider.py +++ b/keep/providers/openobserve_provider/openobserve_provider.py @@ -111,7 +111,7 @@ def validate_config(self): if self.is_installed or self.is_provisioned: host = self.config.authentication['openObserveHost'] if not (host.startswith("http://") or host.startswith("https://")): - scheme = "http://" if "localhost" in host else "https://" + scheme = "http://" if ("localhost" in host or "127.0.0.1" in host) else "https://" self.config.authentication['openObserveHost'] = scheme + host self.authentication_config = OpenobserveProviderAuthConfig( diff --git a/keep/providers/splunk_provider/splunk_provider.py b/keep/providers/splunk_provider/splunk_provider.py index 8f13e09c0..29bee7936 100644 --- a/keep/providers/splunk_provider/splunk_provider.py +++ b/keep/providers/splunk_provider/splunk_provider.py @@ -120,6 +120,12 @@ def validate_scopes(self) -> dict[str, bool | str]: ) self.logger.debug("Connected to Splunk", extra={"service": service}) + if not self.authentication_config.verify: + self.logger.warning( + "SSL verification is disabled - connection is not secure", + extra={"host": self.authentication_config.host} + ) + if len(service.users) > 1: self.logger.warning( "Splunk provider has more than one user", diff --git a/keep/providers/squadcast_provider/squadcast_provider.py b/keep/providers/squadcast_provider/squadcast_provider.py index ab0945504..eac756a8d 100644 --- a/keep/providers/squadcast_provider/squadcast_provider.py +++ b/keep/providers/squadcast_provider/squadcast_provider.py @@ -137,7 +137,15 @@ def _create_incidents( ) # append body to additional_json we are doing this way because we don't want to override the core body fields - body = json.dumps({**json.loads(additional_json), **json.loads(body)}) + try: + additional_fields = json.loads(additional_json) if additional_json else {} + core_fields = json.loads(body) + body = json.dumps({**additional_fields, **core_fields}) + except json.JSONDecodeError as e: + raise ProviderConfigException( + f"Invalid additional_json format: {str(e)}", + provider_id=self.provider_id + ) return requests.post( self.authentication_config.webhook_url, data=body, headers=headers diff --git a/tests/e2e_tests/test_end_to_end.py b/tests/e2e_tests/test_end_to_end.py index 91fad01bc..86a509902 100644 --- a/tests/e2e_tests/test_end_to_end.py +++ b/tests/e2e_tests/test_end_to_end.py @@ -177,112 +177,122 @@ def test_provider_validation(browser): """ Test field validation for provider fields. """ - # browser = browserx - browser.goto("http://localhost:3000/signin") - # using Kibana Provider - browser.get_by_role("link", name="Providers").click() - browser.locator("button:has-text('Kibana'):has-text('alert')").click() - # test required fields - connect_btn = browser.get_by_role("button", name="Connect", exact=True) - cancel_btn = browser.get_by_role("button", name="Cancel", exact=True) - error_msg = browser.locator("p.tremor-TextInput-errorMessage") - connect_btn.click() - expect(error_msg).to_have_count(3) - cancel_btn.click() - # test `any_http_url` field validation - browser.locator("button:has-text('Kibana'):has-text('alert')").click() - host_input = browser.get_by_placeholder("Enter kibana_host") - host_input.fill("invalid url") - expect(error_msg).to_have_count(1) - host_input.fill("http://localhost") - expect(error_msg).to_be_hidden() - host_input.fill( "https://keep.kb.us-central1.gcp.cloud.es.io") - expect(error_msg).to_be_hidden() - # test `port` field validation - port_input = browser.get_by_placeholder("Enter kibana_port") - port_input.fill("invalid port") - expect(error_msg).to_have_count(1) - port_input.fill("0") - expect(error_msg).to_have_count(1) - port_input.fill("65_536") - expect(error_msg).to_have_count(1) - port_input.fill("9243") - expect(error_msg).to_be_hidden() - cancel_btn.click() - - # using Teams Provider - browser.locator("button:has-text('Teams'):has-text('messaging')").click() - # test `https_url` field validation - url_input = browser.get_by_placeholder("Enter webhook_url") - url_input.fill("random url") - expect(error_msg).to_have_count(1) - url_input.fill("http://localhost") - expect(error_msg).to_have_count(1) - url_input.fill("http://example.com") - expect(error_msg).to_have_count(1) - url_input.fill("https://example.c") - expect(error_msg).to_have_count(1) - url_input.fill("https://example.com") - expect(error_msg).to_be_hidden() - cancel_btn.click() - - # using Site24x7 Provider - browser.locator("button:has-text('Site24x7'):has-text('alert')").click() - # test `tld` field validation - tld_input = browser.get_by_placeholder("Enter zohoAccountTLD") - tld_input.fill("random") - expect(error_msg).to_have_count(1) - tld_input.fill("") - expect(error_msg).to_have_count(1) - tld_input.fill(".com") - expect(error_msg).to_be_hidden() - cancel_btn.click() - - # using MongoDB Provider - browser.locator("button:has-text('MongoDB'):has-text('data')").click() - # test `multihost_url` field validation - host_input = browser.get_by_placeholder("Enter host") - host_input.fill("random") - expect(error_msg).to_have_count(1) - host_input.fill("host.com:5000") - expect(error_msg).to_have_count(1) - host_input.fill("host1.com:5000,host2.com:3000") - expect(error_msg).to_have_count(1) - host_input.fill("mongodb://host1.com:5000,mongodb+srv://host2.com:3000") - expect(error_msg).to_have_count(1) - host_input.fill("mongodb://host.com:3000") - expect(error_msg).to_be_hidden() - host_input.fill("mongodb://localhost:3000,localhost:5000") - expect(error_msg).to_be_hidden() - cancel_btn.click() - - # using Kafka Provider - browser.locator("button:has-text('Kafka'):has-text('queue')").click() - # test `no_scheme_multihost_url` field validation - host_input = browser.get_by_placeholder("Enter host") - host_input.fill("*.") - expect(error_msg).to_have_count(1) - host_input.fill("host.com:5000") - expect(error_msg).to_be_hidden() - host_input.fill("host1.com:5000,host2.com:3000") - expect(error_msg).to_be_hidden() - host_input.fill("http://host1.com:5000,https://host2.com:3000") - expect(error_msg).to_have_count(1) - host_input.fill("http://host.com:3000") - expect(error_msg).to_be_hidden() - host_input.fill("mongodb://localhost:3000,localhost:5000") - expect(error_msg).to_be_hidden() - cancel_btn.click() - - - # using Postgres provider - browser.get_by_role("link", name="Providers").click() - browser.locator("button:has-text('PostgreSQL'):has-text('data')").click() - # test `no_scheme_url` field validation - host_input = browser.get_by_placeholder("Enter host") - host_input.fill("*.") - expect(error_msg).to_have_count(1) - host_input.fill("localhost:5000") - expect(error_msg).to_be_hidden() - host_input.fill("https://host.com:3000") - expect(error_msg).to_be_hidden() + try: + browser.goto("http://localhost:3000/signin") + # using Kibana Provider + browser.get_by_role("link", name="Providers").click() + browser.locator("button:has-text('Kibana'):has-text('alert')").click() + # test required fields + connect_btn = browser.get_by_role("button", name="Connect", exact=True) + cancel_btn = browser.get_by_role("button", name="Cancel", exact=True) + error_msg = browser.locator("p.tremor-TextInput-errorMessage") + connect_btn.click() + expect(error_msg).to_have_count(3) + cancel_btn.click() + # test `any_http_url` field validation + browser.locator("button:has-text('Kibana'):has-text('alert')").click() + host_input = browser.get_by_placeholder("Enter kibana_host") + host_input.fill("invalid url") + expect(error_msg).to_have_count(1) + host_input.fill("http://localhost") + expect(error_msg).to_be_hidden() + host_input.fill( "https://keep.kb.us-central1.gcp.cloud.es.io") + expect(error_msg).to_be_hidden() + # test `port` field validation + port_input = browser.get_by_placeholder("Enter kibana_port") + port_input.fill("invalid port") + expect(error_msg).to_have_count(1) + port_input.fill("0") + expect(error_msg).to_have_count(1) + port_input.fill("65_536") + expect(error_msg).to_have_count(1) + port_input.fill("9243") + expect(error_msg).to_be_hidden() + cancel_btn.click() + + # using Teams Provider + browser.locator("button:has-text('Teams'):has-text('messaging')").click() + # test `https_url` field validation + url_input = browser.get_by_placeholder("Enter webhook_url") + url_input.fill("random url") + expect(error_msg).to_have_count(1) + url_input.fill("http://localhost") + expect(error_msg).to_have_count(1) + url_input.fill("http://example.com") + expect(error_msg).to_have_count(1) + url_input.fill("https://example.c") + expect(error_msg).to_have_count(1) + url_input.fill("https://example.com") + expect(error_msg).to_be_hidden() + cancel_btn.click() + + # using Site24x7 Provider + browser.locator("button:has-text('Site24x7'):has-text('alert')").click() + # test `tld` field validation + tld_input = browser.get_by_placeholder("Enter zohoAccountTLD") + tld_input.fill("random") + expect(error_msg).to_have_count(1) + tld_input.fill("") + expect(error_msg).to_have_count(1) + tld_input.fill(".com") + expect(error_msg).to_be_hidden() + cancel_btn.click() + + # using MongoDB Provider + browser.locator("button:has-text('MongoDB'):has-text('data')").click() + # test `multihost_url` field validation + host_input = browser.get_by_placeholder("Enter host") + host_input.fill("random") + expect(error_msg).to_have_count(1) + host_input.fill("host.com:5000") + expect(error_msg).to_have_count(1) + host_input.fill("host1.com:5000,host2.com:3000") + expect(error_msg).to_have_count(1) + host_input.fill("mongodb://host1.com:5000,mongodb+srv://host2.com:3000") + expect(error_msg).to_have_count(1) + host_input.fill("mongodb://host.com:3000") + expect(error_msg).to_be_hidden() + host_input.fill("mongodb://localhost:3000,localhost:5000") + expect(error_msg).to_be_hidden() + cancel_btn.click() + + # using Kafka Provider + browser.locator("button:has-text('Kafka'):has-text('queue')").click() + # test `no_scheme_multihost_url` field validation + host_input = browser.get_by_placeholder("Enter host") + host_input.fill("*.") + expect(error_msg).to_have_count(1) + host_input.fill("host.com:5000") + expect(error_msg).to_be_hidden() + host_input.fill("host1.com:5000,host2.com:3000") + expect(error_msg).to_be_hidden() + host_input.fill("http://host1.com:5000,https://host2.com:3000") + expect(error_msg).to_have_count(1) + host_input.fill("http://host.com:3000") + expect(error_msg).to_be_hidden() + host_input.fill("mongodb://localhost:3000,localhost:5000") + expect(error_msg).to_be_hidden() + cancel_btn.click() + + # using Postgres provider + browser.get_by_role("link", name="Providers").click() + browser.locator("button:has-text('PostgreSQL'):has-text('data')").click() + # test `no_scheme_url` field validation + host_input = browser.get_by_placeholder("Enter host") + host_input.fill("*.") + expect(error_msg).to_have_count(1) + host_input.fill("localhost:5000") + expect(error_msg).to_be_hidden() + host_input.fill("https://host.com:3000") + expect(error_msg).to_be_hidden() + except Exception: + current_test_name = ( + "playwright_dump_" + + os.path.basename(__file__)[:-3] + + "_" + + sys._getframe().f_code.co_name + ) + browser.screenshot(path=current_test_name + ".png") + with open(current_test_name + ".html", "w") as f: + f.write(browser.content()) + raise