Skip to content

Commit

Permalink
Auth fix + Registration Clarity (#3590)
Browse files Browse the repository at this point in the history
* clarify auth flow

* k

* nit

* k

* fix typing
  • Loading branch information
pablonyx authored Jan 6, 2025
1 parent e100a5e commit c8090ab
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 10 deletions.
1 change: 1 addition & 0 deletions backend/ee/onyx/server/tenants/billing.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscr
"""
Send a request to the control service to register the number of users for a tenant.
"""

if not STRIPE_PRICE_ID:
raise Exception("STRIPE_PRICE_ID is not set")

Expand Down
11 changes: 7 additions & 4 deletions backend/onyx/auth/email_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,24 @@ def send_email(


def send_user_email_invite(user_email: str, current_user: User) -> None:
subject = "Invitation to Join Onyx Workspace"
subject = "Invitation to Join Onyx Organization"
body = dedent(
f"""\
Hello,
You have been invited to join a workspace on Onyx.
You have been invited to join an organization on Onyx.
To join the workspace, please visit the following link:
To join the organization, please visit the following link:
{WEB_DOMAIN}/auth/login
{WEB_DOMAIN}/auth/signup?email={user_email}
You'll be asked to set a password or login with Google to complete your registration.
Best regards,
The Onyx Team
"""
)

send_email(user_email, subject, body, current_user.email)


Expand Down
4 changes: 0 additions & 4 deletions backend/onyx/auth/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
from httpx_oauth.oauth2 import BaseOAuth2
from httpx_oauth.oauth2 import OAuth2Token
from pydantic import BaseModel
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

from onyx.auth.api_key import get_hashed_api_key_from_request
Expand Down Expand Up @@ -396,11 +395,9 @@ async def oauth_callback(

# Explicitly set the Postgres schema for this session to ensure
# OAuth account creation happens in the correct tenant schema
await db_session.execute(text(f'SET search_path = "{tenant_id}"'))

# Add OAuth account
await self.user_db.add_oauth_account(user, oauth_account_dict)

await self.on_after_register(user, request)

else:
Expand All @@ -419,7 +416,6 @@ async def oauth_callback(

# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
# re-authenticate that frequently, so by default this is disabled

if expires_at and TRACK_EXTERNAL_IDP_EXPIRY:
oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
await self.user_db.update(
Expand Down
16 changes: 15 additions & 1 deletion backend/onyx/db/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,9 +370,23 @@ async def get_async_session_with_tenant(
bind=engine, expire_on_commit=False, class_=AsyncSession
) # type: ignore

async def _set_search_path(session: AsyncSession, tenant_id: str) -> None:
await session.execute(text(f'SET search_path = "{tenant_id}"'))

async with async_session_factory() as session:
# Register an event listener that is called whenever a new transaction starts
@event.listens_for(session.sync_session, "after_begin")
def after_begin(session_: Any, transaction: Any, connection: Any) -> None:
# Because the event is sync, we can't directly await here.
# Instead we queue up an asyncio task to ensures
# the next statement sets the search_path
session_.do_orm_execute = lambda state: connection.exec_driver_sql(
f'SET search_path = "{tenant_id}"'
)

try:
await session.execute(text(f'SET search_path = "{tenant_id}"'))
await _set_search_path(session, tenant_id)

if POSTGRES_IDLE_SESSIONS_TIMEOUT:
await session.execute(
text(
Expand Down
4 changes: 3 additions & 1 deletion web/src/app/auth/login/EmailPasswordForm.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@ export function EmailPasswordForm({
shouldVerify,
referralSource,
nextUrl,
defaultEmail,
}: {
isSignup?: boolean;
shouldVerify?: boolean;
referralSource?: string;
nextUrl?: string | null;
defaultEmail?: string | null;
}) {
const { user } = useUser();
const { popup, setPopup } = usePopup();
Expand All @@ -34,7 +36,7 @@ export function EmailPasswordForm({
{popup}
<Formik
initialValues={{
email: "",
email: defaultEmail || "",
password: "",
}}
validationSchema={Yup.object().shape({
Expand Down
5 changes: 5 additions & 0 deletions web/src/app/auth/signup/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ const Page = async (props: {
? searchParams?.next[0]
: searchParams?.next || null;

const defaultEmail = Array.isArray(searchParams?.email)
? searchParams?.email[0]
: searchParams?.email || null;

// catch cases where the backend is completely unreachable here
// without try / catch, will just raise an exception and the page
// will not render
Expand Down Expand Up @@ -93,6 +97,7 @@ const Page = async (props: {
isSignup
shouldVerify={authTypeMetadata?.requiresVerification}
nextUrl={nextUrl}
defaultEmail={defaultEmail}
/>
</div>
</>
Expand Down

0 comments on commit c8090ab

Please sign in to comment.