Skip to content

Commit

Permalink
Bitwarden Security Upgrade (#900)
Browse files Browse the repository at this point in the history
  • Loading branch information
ykeremy authored Oct 2, 2024
1 parent 4f6feae commit 36135a6
Show file tree
Hide file tree
Showing 9 changed files with 124 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Add bitwarden details to organizations
Revision ID: 6c90d565076b
Revises: c5848cc524b1
Create Date: 2024-10-02 22:12:34.959165+00:00
"""

from typing import Sequence, Union

import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
# `revision` is this migration's own ID and `down_revision` is the migration it
# revises; both match the "Revision ID" / "Revises" lines in the module docstring.
revision: str = "6c90d565076b"
down_revision: Union[str, None] = "c5848cc524b1"
# No branch label or cross-migration dependency is declared for this migration.
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Add the nullable Bitwarden columns to the ``organizations`` table.

    ``bw_organization_id`` is a string column and ``bw_collection_ids`` is a
    JSON column; both are nullable and are added in that order.
    """
    bitwarden_columns = (
        sa.Column("bw_organization_id", sa.String(), nullable=True),
        sa.Column("bw_collection_ids", sa.JSON(), nullable=True),
    )
    for column in bitwarden_columns:
        op.add_column("organizations", column)


def downgrade() -> None:
    """Drop the Bitwarden columns from the ``organizations`` table.

    Columns are removed in the reverse of the order ``upgrade()`` adds them:
    ``bw_collection_ids`` first, then ``bw_organization_id``.
    """
    for column_name in ("bw_collection_ids", "bw_organization_id"):
        op.drop_column("organizations", column_name)
9 changes: 9 additions & 0 deletions skyvern/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,15 @@ def __init__(self, message: str) -> None:
super().__init__(f"Error syncing Bitwarden: {message}")


class BitwardenAccessDeniedError(BitwardenBaseError):
    """Raised when the current organization may not use the requested Bitwarden collection."""

    def __init__(self) -> None:
        # Adjacent string literals are joined at compile time, so the message is
        # one line with single spaces no matter how this block is indented.
        # A backslash continuation inside a single literal would instead copy
        # the leading whitespace of each continued line into the message.
        super().__init__(
            "Current organization does not have access to the specified Bitwarden collection. "
            "Contact Skyvern support to enable access. This is a security layer on top of Bitwarden, "
            "Skyvern team needs to let your Skyvern account access the Bitwarden collection."
        )


class UnknownElementTreeFormat(SkyvernException):
def __init__(self, fmt: str) -> None:
super().__init__(f"Unknown element tree format {fmt}")
Expand Down
2 changes: 2 additions & 0 deletions skyvern/forge/sdk/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ class OrganizationModel(Base):
max_steps_per_run = Column(Integer, nullable=True)
max_retries_per_step = Column(Integer, nullable=True)
domain = Column(String, nullable=True, index=True)
bw_organization_id = Column(String, nullable=True, default=None)
bw_collection_ids = Column(JSON, nullable=True, default=None)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(
DateTime,
Expand Down
2 changes: 2 additions & 0 deletions skyvern/forge/sdk/db/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ def convert_to_organization(org_model: OrganizationModel) -> Organization:
max_steps_per_run=org_model.max_steps_per_run,
max_retries_per_step=org_model.max_retries_per_step,
domain=org_model.domain,
bw_organization_id=org_model.bw_organization_id,
bw_collection_ids=org_model.bw_collection_ids,
created_at=org_model.created_at,
modified_at=org_model.modified_at,
)
Expand Down
6 changes: 6 additions & 0 deletions skyvern/forge/sdk/executor/async_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,14 @@ async def execute_workflow(
"Executing workflow using background task executor",
workflow_run_id=workflow_run_id,
)

organization = await app.DATABASE.get_organization(organization_id)
if organization is None:
raise OrganizationNotFound(organization_id)

background_tasks.add_task(
app.WORKFLOW_SERVICE.execute_workflow,
workflow_run_id=workflow_run_id,
api_key=api_key,
organization=organization,
)
2 changes: 2 additions & 0 deletions skyvern/forge/sdk/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ class Organization(BaseModel):
max_steps_per_run: int | None = None
max_retries_per_step: int | None = None
domain: str | None = None
bw_organization_id: str | None = None
bw_collection_ids: list[str] | None = None

created_at: datetime
modified_at: datetime
Expand Down
49 changes: 48 additions & 1 deletion skyvern/forge/sdk/services/bitwarden.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from skyvern.config import settings
from skyvern.exceptions import (
BitwardenAccessDeniedError,
BitwardenListItemsError,
BitwardenLoginError,
BitwardenLogoutError,
Expand All @@ -29,6 +30,9 @@ def is_valid_email(email: str | None) -> bool:


class BitwardenConstants(StrEnum):
BW_ORGANIZATION_ID = "BW_ORGANIZATION_ID"
BW_COLLECTION_IDS = "BW_COLLECTION_IDS"

CLIENT_ID = "BW_CLIENT_ID"
CLIENT_SECRET = "BW_CLIENT_SECRET"
MASTER_PASSWORD = "BW_MASTER_PASSWORD"
Expand Down Expand Up @@ -79,6 +83,8 @@ async def get_secret_value_from_url(
client_id: str,
client_secret: str,
master_password: str,
bw_organization_id: str | None,
bw_collection_ids: list[str] | None,
url: str,
collection_id: str | None = None,
remaining_retries: int = settings.BITWARDEN_MAX_RETRIES,
Expand All @@ -94,6 +100,8 @@ async def get_secret_value_from_url(
client_id=client_id,
client_secret=client_secret,
master_password=master_password,
bw_organization_id=bw_organization_id,
bw_collection_ids=bw_collection_ids,
url=url,
collection_id=collection_id,
)
Expand All @@ -109,6 +117,8 @@ async def get_secret_value_from_url(
client_id=client_id,
client_secret=client_secret,
master_password=master_password,
bw_organization_id=bw_organization_id,
bw_collection_ids=bw_collection_ids,
url=url,
collection_id=collection_id,
remaining_retries=remaining_retries,
Expand All @@ -122,12 +132,16 @@ async def _get_secret_value_from_url(
client_id: str,
client_secret: str,
master_password: str,
bw_organization_id: str | None,
bw_collection_ids: list[str] | None,
url: str,
collection_id: str | None = None,
) -> dict[str, str]:
"""
Get the secret value from the Bitwarden CLI.
"""
if not bw_organization_id and bw_collection_ids and collection_id not in bw_collection_ids:
raise BitwardenAccessDeniedError()
try:
BitwardenService.login(client_id, client_secret)
BitwardenService.sync()
Expand All @@ -144,7 +158,13 @@ async def _get_secret_value_from_url(
"--session",
session_key,
]
if collection_id:
if bw_organization_id:
LOG.info(
"Organization ID is provided, filtering items by organization ID",
bw_organization_id=bw_organization_id,
)
list_command.extend(["--organizationid", bw_organization_id])
elif collection_id:
LOG.info("Collection ID is provided, filtering items by collection ID", collection_id=collection_id)
list_command.extend(["--collectionid", collection_id])
items_result = BitwardenService.run_command(list_command)
Expand All @@ -158,11 +178,26 @@ async def _get_secret_value_from_url(
except json.JSONDecodeError:
raise BitwardenListItemsError("Failed to parse items JSON. Output: " + items_result.stdout)

# Since Bitwarden can't AND multiple filters, we only use organization id in the list command
# but we still need to filter the items by collection id here
if bw_organization_id and collection_id:
filtered_items = []
for item in items:
if "collectionIds" in item and collection_id in item["collectionIds"]:
filtered_items.append(item)
items = filtered_items

if not items:
collection_id_str = f" in collection with ID: {collection_id}" if collection_id else ""
raise BitwardenListItemsError(f"No items found in Bitwarden for URL: {url}{collection_id_str}")

# TODO (kerem): To make this more robust, we need to store the item id of the totp login item
# and use it here to get the TOTP code for that specific item
totp_command = ["bw", "get", "totp", url, "--session", session_key]
if bw_organization_id:
# We need to add this filter because the TOTP command fails if there are multiple results
# For now, we require that the bitwarden organization id has only one totp login item for the domain
totp_command.extend(["--organizationid", bw_organization_id])
totp_result = BitwardenService.run_command(totp_command)

if totp_result.stderr and "Event post failed" not in totp_result.stderr:
Expand Down Expand Up @@ -208,6 +243,8 @@ async def get_sensitive_information_from_identity(
client_id: str,
client_secret: str,
master_password: str,
bw_organization_id: str | None,
bw_collection_ids: list[str] | None,
collection_id: str,
identity_key: str,
identity_fields: list[str],
Expand All @@ -224,6 +261,8 @@ async def get_sensitive_information_from_identity(
client_id=client_id,
client_secret=client_secret,
master_password=master_password,
bw_organization_id=bw_organization_id,
bw_collection_ids=bw_collection_ids,
collection_id=collection_id,
identity_key=identity_key,
identity_fields=identity_fields,
Expand All @@ -240,6 +279,8 @@ async def get_sensitive_information_from_identity(
client_id=client_id,
client_secret=client_secret,
master_password=master_password,
bw_organization_id=bw_organization_id,
bw_collection_ids=bw_collection_ids,
collection_id=collection_id,
identity_key=identity_key,
identity_fields=identity_fields,
Expand All @@ -257,10 +298,14 @@ async def _get_sensitive_information_from_identity(
collection_id: str,
identity_key: str,
identity_fields: list[str],
bw_organization_id: str | None,
bw_collection_ids: list[str] | None,
) -> dict[str, str]:
"""
Get the sensitive information from the Bitwarden CLI.
"""
if not bw_organization_id and bw_collection_ids and collection_id not in bw_collection_ids:
raise BitwardenAccessDeniedError()
try:
BitwardenService.login(client_id, client_secret)
BitwardenService.sync()
Expand All @@ -278,6 +323,8 @@ async def _get_sensitive_information_from_identity(
"--collectionid",
collection_id,
]
if bw_organization_id:
list_command.extend(["--organizationid", bw_organization_id])
items_result = BitwardenService.run_command(list_command)

# Parse the items and extract sensitive information
Expand Down
20 changes: 18 additions & 2 deletions skyvern/forge/sdk/workflow/context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from skyvern.exceptions import BitwardenBaseError, WorkflowRunContextNotInitialized
from skyvern.forge.sdk.api.aws import AsyncAWSClient
from skyvern.forge.sdk.models import Organization
from skyvern.forge.sdk.services.bitwarden import BitwardenConstants, BitwardenService
from skyvern.forge.sdk.workflow.exceptions import OutputParameterKeyCollisionError
from skyvern.forge.sdk.workflow.models.parameter import (
Expand Down Expand Up @@ -106,6 +107,8 @@ async def get_secrets_from_password_manager(self) -> dict[str, Any]:
client_secret=self.secrets[BitwardenConstants.CLIENT_SECRET],
client_id=self.secrets[BitwardenConstants.CLIENT_ID],
master_password=self.secrets[BitwardenConstants.MASTER_PASSWORD],
bw_organization_id=self.secrets[BitwardenConstants.BW_ORGANIZATION_ID],
bw_collection_ids=self.secrets[BitwardenConstants.BW_COLLECTION_IDS],
)
return secret_credentials

Expand All @@ -117,6 +120,7 @@ async def register_parameter_value(
self,
aws_client: AsyncAWSClient,
parameter: PARAMETER_TYPE,
organization: Organization,
) -> None:
if parameter.parameter_type == ParameterType.WORKFLOW:
LOG.error(f"Workflow parameters are set while initializing context manager. Parameter key: {parameter.key}")
Expand Down Expand Up @@ -165,10 +169,14 @@ async def register_parameter_value(
client_id,
client_secret,
master_password,
organization.bw_organization_id,
organization.bw_collection_ids,
url,
collection_id=collection_id,
)
if secret_credentials:
self.secrets[BitwardenConstants.BW_ORGANIZATION_ID] = organization.bw_organization_id
self.secrets[BitwardenConstants.BW_COLLECTION_IDS] = organization.bw_collection_ids
self.secrets[BitwardenConstants.URL] = url
self.secrets[BitwardenConstants.CLIENT_SECRET] = client_secret
self.secrets[BitwardenConstants.CLIENT_ID] = client_id
Expand Down Expand Up @@ -223,11 +231,15 @@ async def register_parameter_value(
client_id,
client_secret,
master_password,
organization.bw_organization_id,
organization.bw_collection_ids,
collection_id,
bitwarden_identity_key,
parameter.bitwarden_identity_fields,
)
if sensitive_values:
self.secrets[BitwardenConstants.BW_ORGANIZATION_ID] = organization.bw_organization_id
self.secrets[BitwardenConstants.BW_COLLECTION_IDS] = organization.bw_collection_ids
self.secrets[BitwardenConstants.IDENTITY_KEY] = bitwarden_identity_key
self.secrets[BitwardenConstants.CLIENT_SECRET] = client_secret
self.secrets[BitwardenConstants.CLIENT_ID] = client_id
Expand Down Expand Up @@ -333,6 +345,7 @@ async def register_block_parameters(
self,
aws_client: AsyncAWSClient,
parameters: list[PARAMETER_TYPE],
organization: Organization,
) -> None:
# Sort the parameters so that ContextParameter and BitwardenLoginCredentialParameter are processed last
# ContextParameter should be processed at the end since it requires the source parameter to be set
Expand Down Expand Up @@ -369,7 +382,7 @@ async def register_block_parameters(
)

self.parameters[parameter.key] = parameter
await self.register_parameter_value(aws_client, parameter)
await self.register_parameter_value(aws_client, parameter, organization)


class WorkflowContextManager:
Expand Down Expand Up @@ -410,6 +423,9 @@ async def register_block_parameters_for_workflow_run(
self,
workflow_run_id: str,
parameters: list[PARAMETER_TYPE],
organization: Organization,
) -> None:
self._validate_workflow_run_context(workflow_run_id)
await self.workflow_run_contexts[workflow_run_id].register_block_parameters(self.aws_client, parameters)
await self.workflow_run_contexts[workflow_run_id].register_block_parameters(
self.aws_client, parameters, organization
)
7 changes: 4 additions & 3 deletions skyvern/forge/sdk/workflow/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.security import generate_skyvern_signature
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.models import Organization, Step
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus
from skyvern.forge.sdk.workflow.exceptions import (
ContextParameterSourceNotDefined,
Expand Down Expand Up @@ -150,9 +150,10 @@ async def execute_workflow(
self,
workflow_run_id: str,
api_key: str,
organization_id: str | None = None,
organization: Organization,
) -> WorkflowRun:
"""Execute a workflow."""
organization_id = organization.organization_id
workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id)
workflow = await self.get_workflow(workflow_id=workflow_run.workflow_id, organization_id=organization_id)

Expand Down Expand Up @@ -181,7 +182,7 @@ async def execute_workflow(
try:
parameters = block.get_all_parameters(workflow_run_id)
await app.WORKFLOW_CONTEXT_MANAGER.register_block_parameters_for_workflow_run(
workflow_run_id, parameters
workflow_run_id, parameters, organization
)
LOG.info(
f"Executing root block {block.block_type} at index {block_idx} for workflow run {workflow_run_id}",
Expand Down

0 comments on commit 36135a6

Please sign in to comment.