diff --git a/.github/workflows/docker-build-push-backend-container-on-tag.yml b/.github/workflows/docker-build-push-backend-container-on-tag.yml index ef33750c271..162c0221308 100644 --- a/.github/workflows/docker-build-push-backend-container-on-tag.yml +++ b/.github/workflows/docker-build-push-backend-container-on-tag.yml @@ -6,7 +6,7 @@ on: - "*" env: - REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'danswer/danswer-backend-cloud' || 'danswer/danswer-backend' }} + REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }} LATEST_TAG: ${{ contains(github.ref_name, 'latest') }} jobs: @@ -44,7 +44,7 @@ jobs: ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }} ${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }} build-args: | - DANSWER_VERSION=${{ github.ref_name }} + ONYX_VERSION=${{ github.ref_name }} # trivy has their own rate limiting issues causing this action to flake # we worked around it by hardcoding to different db repos in env @@ -57,7 +57,7 @@ jobs: TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2" TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1" with: - # To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend + # To run locally: trivy image --severity HIGH,CRITICAL onyxdotapp/onyx-backend image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }} severity: "CRITICAL,HIGH" trivyignores: ./backend/.trivyignore diff --git a/.github/workflows/docker-build-push-cloud-web-container-on-tag.yml b/.github/workflows/docker-build-push-cloud-web-container-on-tag.yml index 45cd5093a0c..99caf6392a0 100644 --- a/.github/workflows/docker-build-push-cloud-web-container-on-tag.yml +++ b/.github/workflows/docker-build-push-cloud-web-container-on-tag.yml @@ -7,7 +7,7 @@ on: - "*" env: - REGISTRY_IMAGE: danswer/danswer-web-server-cloud + REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud LATEST_TAG: ${{ contains(github.ref_name, 'latest') }} jobs: @@ -60,7 +60,7 @@ jobs: platforms: ${{ matrix.platform }} push: true build-args: | - DANSWER_VERSION=${{ github.ref_name }} + ONYX_VERSION=${{ github.ref_name }} NEXT_PUBLIC_CLOUD_ENABLED=true NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }} NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }} diff --git a/.github/workflows/docker-build-push-model-server-container-on-tag.yml b/.github/workflows/docker-build-push-model-server-container-on-tag.yml index 3e0445ab04a..7df47c416ce 100644 --- a/.github/workflows/docker-build-push-model-server-container-on-tag.yml +++ b/.github/workflows/docker-build-push-model-server-container-on-tag.yml @@ -6,20 +6,70 @@ on: - "*" env: - REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'danswer/danswer-model-server-cloud' || 'danswer/danswer-model-server' }} + REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }} LATEST_TAG: ${{ contains(github.ref_name, 'latest') }} + DOCKER_BUILDKIT: 1 + BUILDKIT_PROGRESS: plain jobs: - build-and-push: - # See https://runs-on.com/runners/linux/ - runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"] + build-amd64: + runs-on: + [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-amd64"] + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: System Info + run: | + df -h + free -h + docker system prune -af --volumes + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + with: + driver-opts: | + image=moby/buildkit:latest + network=host + + - name: Login to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_TOKEN }} + + - name: Build and Push AMD64 + uses: docker/build-push-action@v5 + with: + context: ./backend + file: ./backend/Dockerfile.model_server + platforms: linux/amd64 + push: true + tags: ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 + build-args: | + DANSWER_VERSION=${{ github.ref_name }} + outputs: type=registry + provenance: false + build-arm64: + runs-on: + [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-arm64"] steps: - name: Checkout code uses: actions/checkout@v4 + - name: System Info + run: | + df -h + free -h + docker system prune -af --volumes + - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 + with: + driver-opts: | + image=moby/buildkit:latest + network=host - name: Login to Docker Hub uses: docker/login-action@v3 @@ -27,29 +77,47 @@ jobs: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_TOKEN }} - - name: Model Server Image Docker Build and Push + - name: Build and Push ARM64 uses: docker/build-push-action@v5 with: context: ./backend file: ./backend/Dockerfile.model_server - platforms: linux/amd64,linux/arm64 + platforms: linux/arm64 push: true - tags: | - ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }} - ${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }} + tags: ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64 build-args: | DANSWER_VERSION=${{ github.ref_name }} + outputs: type=registry + provenance: false + + merge-and-scan: + needs: [build-amd64, build-arm64] + runs-on: ubuntu-latest + steps: + - name: Login to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_TOKEN }} + + - name: Create and Push Multi-arch Manifest + run: | + docker buildx create --use + docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }} \ + ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \ + ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64 + if [[ "${{ env.LATEST_TAG }}" == "true" ]]; then + docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:latest \ + ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \ + ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64 + fi - # trivy has their own rate limiting issues causing this action to flake - # we worked around it by hardcoding to different db repos in env - # can re-enable when they figure it out - # https://github.com/aquasecurity/trivy/discussions/7538 - # https://github.com/aquasecurity/trivy-action/issues/389 - name: Run Trivy vulnerability scanner uses: aquasecurity/trivy-action@master env: TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2" TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1" with: - image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }} + image-ref: docker.io/onyxdotapp/onyx-model-server:${{ github.ref_name }} severity: "CRITICAL,HIGH" + timeout: "10m" diff --git a/.github/workflows/docker-build-push-web-container-on-tag.yml b/.github/workflows/docker-build-push-web-container-on-tag.yml index 4f1fd804969..b7f4a5dbc68 100644 --- a/.github/workflows/docker-build-push-web-container-on-tag.yml +++ b/.github/workflows/docker-build-push-web-container-on-tag.yml @@ -3,12 +3,12 @@ name: Build and Push Web Image on Tag on: push: tags: - - '*' + - "*" env: - REGISTRY_IMAGE: danswer/danswer-web-server + REGISTRY_IMAGE: onyxdotapp/onyx-web-server LATEST_TAG: ${{ contains(github.ref_name, 'latest') }} - + jobs: build: runs-on: @@ -27,11 +27,11 @@ jobs: - name: Prepare run: | platform=${{ matrix.platform }} - echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV - + echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV + - name: Checkout uses: actions/checkout@v4 - + - name: Docker meta id: meta uses: docker/metadata-action@v5 @@ -40,16 +40,16 @@ jobs: tags: | type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }} type=raw,value=${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }} - + - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - + - name: Login to Docker Hub uses: docker/login-action@v3 with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_TOKEN }} - + - name: Build and push by digest id: build uses: docker/build-push-action@v5 @@ -59,18 +59,18 @@ jobs: platforms: ${{ matrix.platform }} push: true build-args: | - DANSWER_VERSION=${{ github.ref_name }} - # needed due to weird interactions with the builds for different platforms + ONYX_VERSION=${{ github.ref_name }} + # needed due to weird interactions with the builds for different platforms no-cache: true labels: ${{ steps.meta.outputs.labels }} outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true - + - name: Export digest run: | mkdir -p /tmp/digests digest="${{ steps.build.outputs.digest }}" - touch "/tmp/digests/${digest#sha256:}" - + touch "/tmp/digests/${digest#sha256:}" + - name: Upload digest uses: actions/upload-artifact@v4 with: @@ -90,42 +90,42 @@ jobs: path: /tmp/digests pattern: digests-* merge-multiple: true - + - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - + - name: Docker meta id: meta uses: docker/metadata-action@v5 with: images: ${{ env.REGISTRY_IMAGE }} - + - name: Login to Docker Hub uses: docker/login-action@v3 with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_TOKEN }} - + - name: Create manifest list and push working-directory: /tmp/digests run: | docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \ - $(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *) - + $(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *) + - name: Inspect image run: | docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.meta.outputs.version }} - # trivy has their own rate limiting issues causing this action to flake - # we worked around it by hardcoding to different db repos in env - # can re-enable when they figure it out - # https://github.com/aquasecurity/trivy/discussions/7538 - # https://github.com/aquasecurity/trivy-action/issues/389 + # trivy has their own rate limiting issues causing this action to flake + # we worked around it by hardcoding to different db repos in env + # can re-enable when they figure it out + # https://github.com/aquasecurity/trivy/discussions/7538 + # https://github.com/aquasecurity/trivy-action/issues/389 - name: Run Trivy vulnerability scanner uses: aquasecurity/trivy-action@master env: - TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2' - TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1' + TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2" + TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1" with: image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }} - severity: 'CRITICAL,HIGH' + severity: "CRITICAL,HIGH" diff --git a/.github/workflows/docker-tag-latest.yml b/.github/workflows/docker-tag-latest.yml index e2c7c30f31e..fd0c07e5ba7 100644 --- a/.github/workflows/docker-tag-latest.yml +++ b/.github/workflows/docker-tag-latest.yml @@ -7,31 +7,31 @@ on: workflow_dispatch: inputs: version: - description: 'The version (ie v0.0.1) to tag as latest' + description: "The version (ie v0.0.1) to tag as latest" required: true jobs: tag: # See https://runs-on.com/runners/linux/ # use a lower powered instance since this just does i/o to docker hub - runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"] + runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}"] steps: - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v1 + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v1 - - name: Login to Docker Hub - uses: docker/login-action@v1 - with: - username: ${{ secrets.DOCKER_USERNAME }} - password: ${{ secrets.DOCKER_TOKEN }} + - name: Login to Docker Hub + uses: docker/login-action@v1 + with: + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_TOKEN }} - - name: Enable Docker CLI experimental features - run: echo "DOCKER_CLI_EXPERIMENTAL=enabled" >> $GITHUB_ENV + - name: Enable Docker CLI experimental features + run: echo "DOCKER_CLI_EXPERIMENTAL=enabled" >> $GITHUB_ENV - - name: Pull, Tag and Push Web Server Image - run: | - docker buildx imagetools create -t danswer/danswer-web-server:latest danswer/danswer-web-server:${{ github.event.inputs.version }} + - name: Pull, Tag and Push Web Server Image + run: | + docker buildx imagetools create -t onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:${{ github.event.inputs.version }} - - name: Pull, Tag and Push API Server Image - run: | - docker buildx imagetools create -t danswer/danswer-backend:latest danswer/danswer-backend:${{ github.event.inputs.version }} + - name: Pull, Tag and Push API Server Image + run: | + docker buildx imagetools create -t onyxdotapp/onyx-backend:latest onyxdotapp/onyx-backend:${{ github.event.inputs.version }} diff --git a/.github/workflows/hotfix-release-branches.yml b/.github/workflows/hotfix-release-branches.yml index 0e921f8d694..6e14fa8269e 100644 --- a/.github/workflows/hotfix-release-branches.yml +++ b/.github/workflows/hotfix-release-branches.yml @@ -8,43 +8,42 @@ on: workflow_dispatch: inputs: hotfix_commit: - description: 'Hotfix commit hash' + description: "Hotfix commit hash" required: true hotfix_suffix: - description: 'Hotfix branch suffix (e.g. hotfix/v0.8-{suffix})' + description: "Hotfix branch suffix (e.g. hotfix/v0.8-{suffix})" required: true release_branch_pattern: - description: 'Release branch pattern (regex)' + description: "Release branch pattern (regex)" required: true - default: 'release/.*' + default: "release/.*" auto_merge: - description: 'Automatically merge the hotfix PRs' + description: "Automatically merge the hotfix PRs" required: true type: choice - default: 'true' + default: "true" options: - true - false - + jobs: hotfix_release_branches: permissions: write-all # See https://runs-on.com/runners/linux/ # use a lower powered instance since this just does i/o to docker hub - runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"] + runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}"] steps: - # needs RKUO_DEPLOY_KEY for write access to merge PR's - name: Checkout Repository uses: actions/checkout@v4 with: ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}" fetch-depth: 0 - + - name: Set up Git user run: | git config user.name "Richard Kuo [bot]" - git config user.email "rkuo[bot]@danswer.ai" + git config user.email "rkuo[bot]@onyx.app" - name: Fetch All Branches run: | @@ -62,10 +61,10 @@ jobs: echo "No release branches found matching pattern '${{ github.event.inputs.release_branch_pattern }}'." exit 1 fi - + echo "Found release branches:" echo "$BRANCHES" - + # Join the branches into a single line separated by commas BRANCHES_JOINED=$(echo "$BRANCHES" | tr '\n' ',' | sed 's/,$//') @@ -169,4 +168,4 @@ jobs: echo "Failed to merge pull request #$PR_NUMBER." fi fi - done \ No newline at end of file + done diff --git a/.github/workflows/pr-backport-autotrigger.yml b/.github/workflows/pr-backport-autotrigger.yml index 273f00a5c5a..8b0906916e3 100644 --- a/.github/workflows/pr-backport-autotrigger.yml +++ b/.github/workflows/pr-backport-autotrigger.yml @@ -4,7 +4,7 @@ name: Backport on Merge on: pull_request: - types: [closed] # Later we check for merge so only PRs that go in can get backported + types: [closed] # Later we check for merge so only PRs that go in can get backported permissions: contents: write @@ -26,9 +26,9 @@ jobs: - name: Set up Git user run: | git config user.name "Richard Kuo [bot]" - git config user.email "rkuo[bot]@danswer.ai" + git config user.email "rkuo[bot]@onyx.app" git fetch --prune - + - name: Check for Backport Checkbox id: checkbox-check run: | @@ -51,14 +51,14 @@ jobs: # Fetch latest tags for beta and stable LATEST_BETA_TAG=$(git tag -l "v[0-9]*.[0-9]*.[0-9]*-beta.[0-9]*" | grep -E "^v[0-9]+\.[0-9]+\.[0-9]+-beta\.[0-9]+$" | grep -v -- "-cloud" | sort -Vr | head -n 1) LATEST_STABLE_TAG=$(git tag -l "v[0-9]*.[0-9]*.[0-9]*" | grep -E "^v[0-9]+\.[0-9]+\.[0-9]+$" | sort -Vr | head -n 1) - + # Handle case where no beta tags exist if [[ -z "$LATEST_BETA_TAG" ]]; then NEW_BETA_TAG="v1.0.0-beta.1" else NEW_BETA_TAG=$(echo $LATEST_BETA_TAG | awk -F '[.-]' '{print $1 "." $2 "." $3 "-beta." ($NF+1)}') fi - + # Increment latest stable tag NEW_STABLE_TAG=$(echo $LATEST_STABLE_TAG | awk -F '.' '{print $1 "." $2 "." ($3+1)}') echo "latest_beta_tag=$LATEST_BETA_TAG" >> $GITHUB_OUTPUT @@ -80,10 +80,10 @@ jobs: run: | set -e echo "Backporting to beta ${{ steps.list-branches.outputs.beta }} and stable ${{ steps.list-branches.outputs.stable }}" - + # Echo the merge commit SHA echo "Merge commit SHA: ${{ github.event.pull_request.merge_commit_sha }}" - + # Fetch all history for all branches and tags git fetch --prune @@ -98,7 +98,7 @@ jobs: echo "Cherry-pick to beta failed due to conflicts." exit 1 } - + # Create new beta branch/tag git tag ${{ steps.list-branches.outputs.new_beta_tag }} # Push the changes and tag to the beta branch using PAT @@ -110,13 +110,13 @@ jobs: echo "Last 5 commits on stable branch:" git log -n 5 --pretty=format:"%H" echo "" # Newline for formatting - + # Cherry-pick the merge commit from the merged PR git cherry-pick -m 1 ${{ github.event.pull_request.merge_commit_sha }} || { echo "Cherry-pick to stable failed due to conflicts." exit 1 } - + # Create new stable branch/tag git tag ${{ steps.list-branches.outputs.new_stable_tag }} # Push the changes and tag to the stable branch using PAT diff --git a/.github/workflows/pr-chromatic-tests.yml b/.github/workflows/pr-chromatic-tests.yml index 5d8b29ed572..1ebb7598116 100644 --- a/.github/workflows/pr-chromatic-tests.yml +++ b/.github/workflows/pr-chromatic-tests.yml @@ -14,18 +14,24 @@ jobs: name: Playwright Tests # See https://runs-on.com/runners/linux/ - runs-on: [runs-on,runner=8cpu-linux-x64,ram=16,"run-id=${{ github.run_id }}"] + runs-on: + [ + runs-on, + runner=32cpu-linux-x64, + disk=large, + "run-id=${{ github.run_id }}", + ] steps: - name: Checkout code uses: actions/checkout@v4 with: fetch-depth: 0 - + - name: Set up Python uses: actions/setup-python@v5 with: - python-version: '3.11' - cache: 'pip' + python-version: "3.11" + cache: "pip" cache-dependency-path: | backend/requirements/default.txt backend/requirements/dev.txt @@ -35,7 +41,7 @@ jobs: pip install --retries 5 --timeout 30 -r backend/requirements/default.txt pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt - + - name: Setup node uses: actions/setup-node@v4 with: @@ -48,7 +54,7 @@ jobs: - name: Install playwright browsers working-directory: ./web run: npx playwright install --with-deps - + - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 @@ -60,13 +66,13 @@ jobs: # tag every docker image with "test" so that we can spin up the correct set # of images during testing - + # we use the runs-on cache for docker builds # in conjunction with runs-on runners, it has better speed and unlimited caching # https://runs-on.com/caching/s3-cache-for-github-actions/ # https://runs-on.com/caching/docker/ # https://github.com/moby/buildkit#s3-cache-experimental - + # images are built and run locally for testing purposes. Not pushed. - name: Build Web Docker image @@ -75,7 +81,7 @@ jobs: context: ./web file: ./web/Dockerfile platforms: linux/amd64 - tags: danswer/danswer-web-server:test + tags: onyxdotapp/onyx-web-server:test push: false load: true cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/web-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }} @@ -87,7 +93,7 @@ jobs: context: ./backend file: ./backend/Dockerfile platforms: linux/amd64 - tags: danswer/danswer-backend:test + tags: onyxdotapp/onyx-backend:test push: false load: true cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }} @@ -99,7 +105,7 @@ jobs: context: ./backend file: ./backend/Dockerfile.model_server platforms: linux/amd64 - tags: danswer/danswer-model-server:test + tags: onyxdotapp/onyx-model-server:test push: false load: true cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }} @@ -110,6 +116,7 @@ jobs: cd deployment/docker_compose ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \ AUTH_TYPE=basic \ + GEN_AI_API_KEY=${{ secrets.OPENAI_API_KEY }} \ REQUIRE_EMAIL_VERIFICATION=false \ DISABLE_TELEMETRY=true \ IMAGE_TAG=test \ @@ -119,12 +126,12 @@ jobs: - name: Wait for service to be ready run: | echo "Starting wait-for-service script..." - + docker logs -f danswer-stack-api_server-1 & start_time=$(date +%s) timeout=300 # 5 minutes in seconds - + while true; do current_time=$(date +%s) elapsed_time=$((current_time - start_time)) @@ -152,7 +159,7 @@ jobs: - name: Run pytest playwright test init working-directory: ./backend - env: + env: PYTEST_IGNORE_SKIP: true run: pytest -s tests/integration/tests/playwright/test_playwright.py @@ -168,7 +175,7 @@ jobs: name: test-results path: ./web/test-results retention-days: 30 - + # save before stopping the containers so the logs can be captured - name: Save Docker logs if: success() || failure() @@ -176,7 +183,7 @@ jobs: cd deployment/docker_compose docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log mv docker-compose.log ${{ github.workspace }}/docker-compose.log - + - name: Upload logs if: success() || failure() uses: actions/upload-artifact@v4 @@ -191,35 +198,41 @@ jobs: chromatic-tests: name: Chromatic Tests - + needs: playwright-tests - runs-on: [runs-on,runner=8cpu-linux-x64,ram=16,"run-id=${{ github.run_id }}"] + runs-on: + [ + runs-on, + runner=32cpu-linux-x64, + disk=large, + "run-id=${{ github.run_id }}", + ] steps: - name: Checkout code uses: actions/checkout@v4 with: fetch-depth: 0 - + - name: Setup node uses: actions/setup-node@v4 with: node-version: 22 - + - name: Install node dependencies working-directory: ./web run: npm ci - + - name: Download Playwright test results uses: actions/download-artifact@v4 with: name: test-results path: ./web/test-results - + - name: Run Chromatic uses: chromaui/action@latest with: playwright: true projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }} workingDir: ./web - env: + env: CHROMATIC_ARCHIVE_LOCATION: ./test-results diff --git a/.github/workflows/pr-integration-tests.yml b/.github/workflows/pr-integration-tests.yml index f2dc97e75da..f0004c4e256 100644 --- a/.github/workflows/pr-integration-tests.yml +++ b/.github/workflows/pr-integration-tests.yml @@ -8,7 +8,7 @@ on: pull_request: branches: - main - - 'release/**' + - "release/**" env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} @@ -16,11 +16,11 @@ env: CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }} CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }} CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }} - + jobs: integration-tests: # See https://runs-on.com/runners/linux/ - runs-on: [runs-on,runner=8cpu-linux-x64,ram=16,"run-id=${{ github.run_id }}"] + runs-on: [runs-on, runner=32cpu-linux-x64, "run-id=${{ github.run_id }}"] steps: - name: Checkout code uses: actions/checkout@v4 @@ -36,21 +36,21 @@ jobs: # tag every docker image with "test" so that we can spin up the correct set # of images during testing - + # We don't need to build the Web Docker image since it's not yet used - # in the integration tests. We have a separate action to verify that it builds + # in the integration tests. We have a separate action to verify that it builds # successfully. - name: Pull Web Docker image run: | - docker pull danswer/danswer-web-server:latest - docker tag danswer/danswer-web-server:latest danswer/danswer-web-server:test + docker pull onyxdotapp/onyx-web-server:latest + docker tag onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:test # we use the runs-on cache for docker builds # in conjunction with runs-on runners, it has better speed and unlimited caching # https://runs-on.com/caching/s3-cache-for-github-actions/ # https://runs-on.com/caching/docker/ # https://github.com/moby/buildkit#s3-cache-experimental - + # images are built and run locally for testing purposes. Not pushed. - name: Build Backend Docker image uses: ./.github/actions/custom-build-and-push @@ -58,7 +58,7 @@ jobs: context: ./backend file: ./backend/Dockerfile platforms: linux/amd64 - tags: danswer/danswer-backend:test + tags: onyxdotapp/onyx-backend:test push: false load: true cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }} @@ -70,19 +70,19 @@ jobs: context: ./backend file: ./backend/Dockerfile.model_server platforms: linux/amd64 - tags: danswer/danswer-model-server:test + tags: onyxdotapp/onyx-model-server:test push: false load: true cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }} cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max - + - name: Build integration test Docker image uses: ./.github/actions/custom-build-and-push with: context: ./backend file: ./backend/tests/integration/Dockerfile platforms: linux/amd64 - tags: danswer/danswer-integration:test + tags: onyxdotapp/onyx-integration:test push: false load: true cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }} @@ -119,7 +119,7 @@ jobs: -e TEST_WEB_HOSTNAME=test-runner \ -e AUTH_TYPE=cloud \ -e MULTI_TENANT=true \ - danswer/danswer-integration:test \ + onyxdotapp/onyx-integration:test \ /app/tests/integration/multitenant_tests continue-on-error: true id: run_multitenant_tests @@ -131,15 +131,14 @@ jobs: exit 1 else echo "All integration tests passed successfully." - fi + fi - name: Stop multi-tenant Docker containers run: | cd deployment/docker_compose docker compose -f docker-compose.dev.yml -p danswer-stack down -v - - - name: Start Docker containers + - name: Start Docker containers run: | cd deployment/docker_compose ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \ @@ -153,12 +152,12 @@ jobs: - name: Wait for service to be ready run: | echo "Starting wait-for-service script..." - + docker logs -f danswer-stack-api_server-1 & start_time=$(date +%s) timeout=300 # 5 minutes in seconds - + while true; do current_time=$(date +%s) elapsed_time=$((current_time - start_time)) @@ -202,7 +201,7 @@ jobs: -e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \ -e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \ -e TEST_WEB_HOSTNAME=test-runner \ - danswer/danswer-integration:test \ + onyxdotapp/onyx-integration:test \ /app/tests/integration/tests \ /app/tests/integration/connector_job_tests continue-on-error: true @@ -229,7 +228,7 @@ jobs: run: | cd deployment/docker_compose docker compose -f docker-compose.dev.yml -p danswer-stack down -v - + - name: Upload logs if: success() || failure() uses: actions/upload-artifact@v4 diff --git a/.github/workflows/pr-python-connector-tests.yml b/.github/workflows/pr-python-connector-tests.yml index 6e122860ee9..e8720adaf2f 100644 --- a/.github/workflows/pr-python-connector-tests.yml +++ b/.github/workflows/pr-python-connector-tests.yml @@ -24,6 +24,8 @@ env: GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }} GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR }} GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }} + # Slab + SLAB_BOT_TOKEN: ${{ secrets.SLAB_BOT_TOKEN }} jobs: connectors-check: diff --git a/.github/workflows/tag-nightly.yml b/.github/workflows/tag-nightly.yml index 50bb20808a3..61d1818f4d4 100644 --- a/.github/workflows/tag-nightly.yml +++ b/.github/workflows/tag-nightly.yml @@ -2,53 +2,52 @@ name: Nightly Tag Push on: schedule: - - cron: '0 10 * * *' # Runs every day at 2 AM PST / 3 AM PDT / 10 AM UTC + - cron: "0 10 * * *" # Runs every day at 2 AM PST / 3 AM PDT / 10 AM UTC permissions: - contents: write # Allows pushing tags to the repository + contents: write # Allows pushing tags to the repository jobs: create-and-push-tag: - runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"] + runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}"] steps: - # actions using GITHUB_TOKEN cannot trigger another workflow, but we do want this to trigger docker pushes - # see https://github.com/orgs/community/discussions/27028#discussioncomment-3254367 for the workaround we - # implement here which needs an actual user's deploy key - - name: Checkout code - uses: actions/checkout@v4 - with: - ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}" - - - name: Set up Git user - run: | - git config user.name "Richard Kuo [bot]" - git config user.email "rkuo[bot]@danswer.ai" - - - name: Check for existing nightly tag - id: check_tag - run: | - if git tag --points-at HEAD --list "nightly-latest*" | grep -q .; then - echo "A tag starting with 'nightly-latest' already exists on HEAD." - echo "tag_exists=true" >> $GITHUB_OUTPUT - else - echo "No tag starting with 'nightly-latest' exists on HEAD." - echo "tag_exists=false" >> $GITHUB_OUTPUT - fi - - # don't tag again if HEAD already has a nightly-latest tag on it - - name: Create Nightly Tag - if: steps.check_tag.outputs.tag_exists == 'false' - env: - DATE: ${{ github.run_id }} - run: | - TAG_NAME="nightly-latest-$(date +'%Y%m%d')" - echo "Creating tag: $TAG_NAME" - git tag $TAG_NAME - - - name: Push Tag - if: steps.check_tag.outputs.tag_exists == 'false' - run: | - TAG_NAME="nightly-latest-$(date +'%Y%m%d')" - git push origin $TAG_NAME - \ No newline at end of file + # actions using GITHUB_TOKEN cannot trigger another workflow, but we do want this to trigger docker pushes + # see https://github.com/orgs/community/discussions/27028#discussioncomment-3254367 for the workaround we + # implement here which needs an actual user's deploy key + - name: Checkout code + uses: actions/checkout@v4 + with: + ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}" + + - name: Set up Git user + run: | + git config user.name "Richard Kuo [bot]" + git config user.email "rkuo[bot]@onyx.app" + + - name: Check for existing nightly tag + id: check_tag + run: | + if git tag --points-at HEAD --list "nightly-latest*" | grep -q .; then + echo "A tag starting with 'nightly-latest' already exists on HEAD." + echo "tag_exists=true" >> $GITHUB_OUTPUT + else + echo "No tag starting with 'nightly-latest' exists on HEAD." + echo "tag_exists=false" >> $GITHUB_OUTPUT + fi + + # don't tag again if HEAD already has a nightly-latest tag on it + - name: Create Nightly Tag + if: steps.check_tag.outputs.tag_exists == 'false' + env: + DATE: ${{ github.run_id }} + run: | + TAG_NAME="nightly-latest-$(date +'%Y%m%d')" + echo "Creating tag: $TAG_NAME" + git tag $TAG_NAME + + - name: Push Tag + if: steps.check_tag.outputs.tag_exists == 'false' + run: | + TAG_NAME="nightly-latest-$(date +'%Y%m%d')" + git push origin $TAG_NAME diff --git a/README.md b/README.md index 1f9fbef5b2f..0b7f87ceaa4 100644 --- a/README.md +++ b/README.md @@ -1,48 +1,48 @@ - +

- +

-

Open Source Gen-AI Chat + Unified Search.

+

Open Source Gen-AI + Enterprise Search.

- + Documentation - + Slack Discord - + License

-[Danswer](https://www.danswer.ai/) is the AI Assistant connected to your company's docs, apps, and people. -Danswer provides a Chat interface and plugs into any LLM of your choice. Danswer can be deployed anywhere and for any +[Onyx](https://www.onyx.app/) (Formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people. +Onyx provides a Chat interface and plugs into any LLM of your choice. Onyx can be deployed anywhere and for any scale - on a laptop, on-premise, or to cloud. Since you own the deployment, your user data and chats are fully in your -own control. Danswer is MIT licensed and designed to be modular and easily extensible. The system also comes fully ready +own control. Onyx is dual Licensed with most of it under MIT license and designed to be modular and easily extensible. The system also comes fully ready for production usage with user authentication, role management (admin/basic users), chat persistence, and a UI for -configuring Personas (AI Assistants) and their Prompts. +configuring AI Assistants. -Danswer also serves as a Unified Search across all common workplace tools such as Slack, Google Drive, Confluence, etc. -By combining LLMs and team specific knowledge, Danswer becomes a subject matter expert for the team. Imagine ChatGPT if +Onyx also serves as a Enterprise Search across all common workplace tools such as Slack, Google Drive, Confluence, etc. +By combining LLMs and team specific knowledge, Onyx becomes a subject matter expert for the team. Imagine ChatGPT if it had access to your team's unique knowledge! It enables questions such as "A customer wants feature X, is this already supported?" or "Where's the pull request for feature Y?"

Usage

-Danswer Web App: +Onyx Web App: https://github.com/danswer-ai/danswer/assets/32520769/563be14c-9304-47b5-bf0a-9049c2b6f410 -Or, plug Danswer into your existing Slack workflows (more integrations to come 😁): +Or, plug Onyx into your existing Slack workflows (more integrations to come 😁): https://github.com/danswer-ai/danswer/assets/25087905/3e19739b-d178-4371-9a38-011430bdec1b @@ -52,16 +52,16 @@ For more details on the Admin UI to manage connectors and users, check out our ## Deployment -Danswer can easily be run locally (even on a laptop) or deployed on a virtual machine with a single -`docker compose` command. Checkout our [docs](https://docs.danswer.dev/quickstart) to learn more. +Onyx can easily be run locally (even on a laptop) or deployed on a virtual machine with a single +`docker compose` command. Checkout our [docs](https://docs.onyx.app/quickstart) to learn more. -We also have built-in support for deployment on Kubernetes. Files for that can be found [here](https://github.com/danswer-ai/danswer/tree/main/deployment/kubernetes). +We also have built-in support for deployment on Kubernetes. Files for that can be found [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment/kubernetes). ## 💃 Main Features * Chat UI with the ability to select documents to chat with. * Create custom AI Assistants with different prompts and backing knowledge sets. -* Connect Danswer with LLM of your choice (self-host for a fully airgapped solution). +* Connect Onyx with LLM of your choice (self-host for a fully airgapped solution). * Document Search + AI Answers for natural language queries. * Connectors to all common workplace tools like Google Drive, Confluence, Slack, etc. * Slack integration to get answers and search results directly in Slack. @@ -75,12 +75,12 @@ We also have built-in support for deployment on Kubernetes. Files for that can b * Organizational understanding and ability to locate and suggest experts from your team. -## Other Notable Benefits of Danswer +## Other Notable Benefits of Onyx * User Authentication with document level access management. * Best in class Hybrid Search across all sources (BM-25 + prefix aware embedding models). * Admin Dashboard to configure connectors, document-sets, access, etc. * Custom deep learning models + learn from user feedback. -* Easy deployment and ability to host Danswer anywhere of your choosing. +* Easy deployment and ability to host Onyx anywhere of your choosing. ## 🔌 Connectors @@ -108,10 +108,10 @@ Efficiently pulls the latest changes from: ## 📚 Editions -There are two editions of Danswer: +There are two editions of Onyx: - * Danswer Community Edition (CE) is available freely under the MIT Expat license. This version has ALL the core features discussed above. This is the version of Danswer you will get if you follow the Deployment guide above. - * Danswer Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations. Specifically, this includes: + * Onyx Community Edition (CE) is available freely under the MIT Expat license. This version has ALL the core features discussed above. This is the version of Onyx you will get if you follow the Deployment guide above. + * Onyx Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations. Specifically, this includes: * Single Sign-On (SSO), with support for both SAML and OIDC * Role-based access control * Document permission inheritance from connected sources @@ -119,24 +119,24 @@ There are two editions of Danswer: * Whitelabeling * API key authentication * Encryption of secrets - * Any many more! Checkout [our website](https://www.danswer.ai/) for the latest. + * Any many more! Checkout [our website](https://www.onyx.app/) for the latest. -To try the Danswer Enterprise Edition: +To try the Onyx Enterprise Edition: - 1. Checkout our [Cloud product](https://app.danswer.ai/signup). - 2. For self-hosting, contact us at [founders@danswer.ai](mailto:founders@danswer.ai) or book a call with us on our [Cal](https://cal.com/team/danswer/founders). + 1. Checkout our [Cloud product](https://cloud.onyx.app/signup). + 2. For self-hosting, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/danswer/founders). ## 💡 Contributing Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details. ## ⭐Star History -[![Star History Chart](https://api.star-history.com/svg?repos=danswer-ai/danswer&type=Date)](https://star-history.com/#danswer-ai/danswer&Date) +[![Star History Chart](https://api.star-history.com/svg?repos=onyx-dot-app/onyx&type=Date)](https://star-history.com/#onyx-dot-app/onyx&Date) ## ✨Contributors - - contributors + + contributors

diff --git a/backend/Dockerfile b/backend/Dockerfile index 2f8de6e7996..d77b4e8737e 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -73,6 +73,7 @@ RUN apt-get update && \ rm -rf /var/lib/apt/lists/* && \ rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key + # Pre-downloading models for setups with limited egress RUN python -c "from tokenizers import Tokenizer; \ Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')" diff --git a/backend/alembic/env.py b/backend/alembic/env.py index 019ea94b836..6f9ecdbfced 100644 --- a/backend/alembic/env.py +++ b/backend/alembic/env.py @@ -1,5 +1,5 @@ from sqlalchemy.engine.base import Connection -from typing import Any +from typing import Literal import asyncio from logging.config import fileConfig import logging @@ -8,6 +8,7 @@ from sqlalchemy import pool from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.sql import text +from sqlalchemy.sql.schema import SchemaItem from shared_configs.configs import MULTI_TENANT from danswer.db.engine import build_connection_string @@ -35,7 +36,18 @@ def include_object( - object: Any, name: str, type_: str, reflected: bool, compare_to: Any + object: SchemaItem, + name: str | None, + type_: Literal[ + "schema", + "table", + "column", + "index", + "unique_constraint", + "foreign_key_constraint", + ], + reflected: bool, + compare_to: SchemaItem | None, ) -> bool: """ Determines whether a database object should be included in migrations. diff --git a/backend/alembic/versions/9f696734098f_combine_search_and_chat.py b/backend/alembic/versions/9f696734098f_combine_search_and_chat.py new file mode 100644 index 00000000000..65dbdece086 --- /dev/null +++ b/backend/alembic/versions/9f696734098f_combine_search_and_chat.py @@ -0,0 +1,36 @@ +"""Combine Search and Chat + +Revision ID: 9f696734098f +Revises: a8c2065484e6 +Create Date: 2024-11-27 15:32:19.694972 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = "9f696734098f" +down_revision = "a8c2065484e6" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.alter_column("chat_session", "description", nullable=True) + op.drop_column("chat_session", "one_shot") + op.drop_column("slack_channel_config", "response_type") + + +def downgrade() -> None: + op.execute("UPDATE chat_session SET description = '' WHERE description IS NULL") + op.alter_column("chat_session", "description", nullable=False) + op.add_column( + "chat_session", + sa.Column("one_shot", sa.Boolean(), nullable=False, server_default=sa.false()), + ) + op.add_column( + "slack_channel_config", + sa.Column( + "response_type", sa.String(), nullable=False, server_default="citations" + ), + ) diff --git a/backend/alembic/versions/a8c2065484e6_add_auto_scroll_to_user_model.py b/backend/alembic/versions/a8c2065484e6_add_auto_scroll_to_user_model.py new file mode 100644 index 00000000000..6cce89f3fb1 --- /dev/null +++ b/backend/alembic/versions/a8c2065484e6_add_auto_scroll_to_user_model.py @@ -0,0 +1,27 @@ +"""add auto scroll to user model + +Revision ID: a8c2065484e6 +Revises: abe7378b8217 +Create Date: 2024-11-22 17:34:09.690295 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "a8c2065484e6" +down_revision = "abe7378b8217" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "user", + sa.Column("auto_scroll", sa.Boolean(), nullable=True, server_default=None), + ) + + +def downgrade() -> None: + op.drop_column("user", "auto_scroll") diff --git a/backend/alembic/versions/bf7a81109301_delete_input_prompts.py b/backend/alembic/versions/bf7a81109301_delete_input_prompts.py new file mode 100644 index 00000000000..7aa3faf3277 --- /dev/null +++ b/backend/alembic/versions/bf7a81109301_delete_input_prompts.py @@ -0,0 +1,57 @@ +"""delete_input_prompts + +Revision ID: bf7a81109301 +Revises: f7a894b06d02 +Create Date: 2024-12-09 12:00:49.884228 + +""" +from alembic import op +import sqlalchemy as sa +import fastapi_users_db_sqlalchemy + + +# revision identifiers, used by Alembic. +revision = "bf7a81109301" +down_revision = "f7a894b06d02" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.drop_table("inputprompt__user") + op.drop_table("inputprompt") + + +def downgrade() -> None: + op.create_table( + "inputprompt", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("prompt", sa.String(), nullable=False), + sa.Column("content", sa.String(), nullable=False), + sa.Column("active", sa.Boolean(), nullable=False), + sa.Column("is_public", sa.Boolean(), nullable=False), + sa.Column( + "user_id", + fastapi_users_db_sqlalchemy.generics.GUID(), + nullable=True, + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "inputprompt__user", + sa.Column("input_prompt_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["input_prompt_id"], + ["inputprompt.id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["inputprompt.id"], + ), + sa.PrimaryKeyConstraint("input_prompt_id", "user_id"), + ) diff --git a/backend/alembic/versions/f7a894b06d02_non_nullbale_slack_bot_id_in_channel_.py b/backend/alembic/versions/f7a894b06d02_non_nullbale_slack_bot_id_in_channel_.py new file mode 100644 index 00000000000..370ab1f0d84 --- /dev/null +++ b/backend/alembic/versions/f7a894b06d02_non_nullbale_slack_bot_id_in_channel_.py @@ -0,0 +1,40 @@ +"""non-nullbale slack bot id in channel config + +Revision ID: f7a894b06d02 +Revises: 9f696734098f +Create Date: 2024-12-06 12:55:42.845723 + +""" + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "f7a894b06d02" +down_revision = "9f696734098f" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Delete all rows with null slack_bot_id + op.execute("DELETE FROM slack_channel_config WHERE slack_bot_id IS NULL") + + # Make slack_bot_id non-nullable + op.alter_column( + "slack_channel_config", + "slack_bot_id", + existing_type=sa.Integer(), + nullable=False, + ) + + +def downgrade() -> None: + # Make slack_bot_id nullable again + op.alter_column( + "slack_channel_config", + "slack_bot_id", + existing_type=sa.Integer(), + nullable=True, + ) diff --git a/backend/alembic_tenants/env.py b/backend/alembic_tenants/env.py index f0f1178ce09..506dbda0313 100644 --- a/backend/alembic_tenants/env.py +++ b/backend/alembic_tenants/env.py @@ -1,5 +1,6 @@ import asyncio from logging.config import fileConfig +from typing import Literal from sqlalchemy import pool from sqlalchemy.engine import Connection @@ -37,8 +38,15 @@ def include_object( object: SchemaItem, - name: str, - type_: str, + name: str | None, + type_: Literal[ + "schema", + "table", + "column", + "index", + "unique_constraint", + "foreign_key_constraint", + ], reflected: bool, compare_to: SchemaItem | None, ) -> bool: diff --git a/backend/danswer/access/models.py b/backend/danswer/access/models.py index 126648eb41e..11cf57638f2 100644 --- a/backend/danswer/access/models.py +++ b/backend/danswer/access/models.py @@ -18,6 +18,11 @@ class ExternalAccess: @dataclass(frozen=True) class DocExternalAccess: + """ + This is just a class to wrap the external access and the document ID + together. It's used for syncing document permissions to Redis. + """ + external_access: ExternalAccess # The document ID doc_id: str diff --git a/backend/danswer/auth/api_key.py b/backend/danswer/auth/api_key.py index aef557960f6..4931a9037ca 100644 --- a/backend/danswer/auth/api_key.py +++ b/backend/danswer/auth/api_key.py @@ -1,3 +1,4 @@ +import hashlib import secrets import uuid from urllib.parse import quote @@ -18,7 +19,8 @@ # organizations like the Internet Engineering Task Force (IETF). _API_KEY_HEADER_ALTERNATIVE_NAME = "X-Danswer-Authorization" _BEARER_PREFIX = "Bearer " -_API_KEY_PREFIX = "dn_" +_API_KEY_PREFIX = "on_" +_DEPRECATED_API_KEY_PREFIX = "dn_" _API_KEY_LEN = 192 @@ -52,7 +54,9 @@ def extract_tenant_from_api_key_header(request: Request) -> str | None: api_key = raw_api_key_header[len(_BEARER_PREFIX) :].strip() - if not api_key.startswith(_API_KEY_PREFIX): + if not api_key.startswith(_API_KEY_PREFIX) and not api_key.startswith( + _DEPRECATED_API_KEY_PREFIX + ): return None parts = api_key[len(_API_KEY_PREFIX) :].split(".", 1) @@ -63,10 +67,19 @@ def extract_tenant_from_api_key_header(request: Request) -> str | None: return unquote(tenant_id) if tenant_id else None +def _deprecated_hash_api_key(api_key: str) -> str: + return sha256_crypt.hash(api_key, salt="", rounds=API_KEY_HASH_ROUNDS) + + def hash_api_key(api_key: str) -> str: # NOTE: no salt is needed, as the API key is randomly generated # and overlaps are impossible - return sha256_crypt.hash(api_key, salt="", rounds=API_KEY_HASH_ROUNDS) + if api_key.startswith(_API_KEY_PREFIX): + return hashlib.sha256(api_key.encode("utf-8")).hexdigest() + elif api_key.startswith(_DEPRECATED_API_KEY_PREFIX): + return _deprecated_hash_api_key(api_key) + else: + raise ValueError(f"Invalid API key prefix: {api_key[:3]}") def build_displayable_api_key(api_key: str) -> str: diff --git a/backend/danswer/auth/invited_users.py b/backend/danswer/auth/invited_users.py index fb30332afd9..ff3a8cce95e 100644 --- a/backend/danswer/auth/invited_users.py +++ b/backend/danswer/auth/invited_users.py @@ -9,7 +9,6 @@ def get_invited_users() -> list[str]: try: store = get_kv_store() - return cast(list, store.load(KV_USER_STORE_KEY)) except KvKeyNotFoundError: return list() diff --git a/backend/danswer/auth/noauth_user.py b/backend/danswer/auth/noauth_user.py index 9eb589dbb25..c7e11cd452a 100644 --- a/backend/danswer/auth/noauth_user.py +++ b/backend/danswer/auth/noauth_user.py @@ -23,7 +23,9 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences: ) return UserPreferences(**preferences_data) except KvKeyNotFoundError: - return UserPreferences(chosen_assistants=None, default_model=None) + return UserPreferences( + chosen_assistants=None, default_model=None, auto_scroll=True + ) def fetch_no_auth_user(store: KeyValueStore) -> UserInfo: diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index ca7d69b24be..b205dee2b11 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -58,7 +58,6 @@ from danswer.auth.schemas import UserUpdate from danswer.configs.app_configs import AUTH_TYPE from danswer.configs.app_configs import DISABLE_AUTH -from danswer.configs.app_configs import DISABLE_VERIFICATION from danswer.configs.app_configs import EMAIL_FROM from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS @@ -87,6 +86,7 @@ from danswer.db.models import OAuthAccount from danswer.db.models import User from danswer.db.users import get_user_by_email +from danswer.server.utils import BasicAuthenticationError from danswer.utils.logger import setup_logger from danswer.utils.telemetry import optional_telemetry from danswer.utils.telemetry import RecordType @@ -100,11 +100,6 @@ logger = setup_logger() -class BasicAuthenticationError(HTTPException): - def __init__(self, detail: str): - super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail) - - def is_user_admin(user: User | None) -> bool: if AUTH_TYPE == AuthType.DISABLED: return True @@ -137,11 +132,12 @@ def get_display_email(email: str | None, space_less: bool = False) -> str: def user_needs_to_be_verified() -> bool: - # all other auth types besides basic should require users to be - # verified - return not DISABLE_VERIFICATION and ( - AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION - ) + if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD: + return REQUIRE_EMAIL_VERIFICATION + + # For other auth types, if the user is authenticated it's assumed that + # the user is already verified via the external IDP + return False def verify_email_is_invited(email: str) -> None: diff --git a/backend/danswer/background/celery/apps/app_base.py b/backend/danswer/background/celery/apps/app_base.py index d041ce0d2bc..a92f0c742bd 100644 --- a/backend/danswer/background/celery/apps/app_base.py +++ b/backend/danswer/background/celery/apps/app_base.py @@ -11,6 +11,7 @@ from celery.states import READY_STATES from celery.utils.log import get_task_logger from celery.worker import strategy # type: ignore +from redis.lock import Lock as RedisLock from sentry_sdk.integrations.celery import CeleryIntegration from sqlalchemy import text from sqlalchemy.orm import Session @@ -332,16 +333,16 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None: return logger.info("Releasing primary worker lock.") - lock = sender.primary_worker_lock + lock: RedisLock = sender.primary_worker_lock try: if lock.owned(): try: lock.release() sender.primary_worker_lock = None - except Exception as e: - logger.error(f"Failed to release primary worker lock: {e}") - except Exception as e: - logger.error(f"Failed to check if primary worker lock is owned: {e}") + except Exception: + logger.exception("Failed to release primary worker lock") + except Exception: + logger.exception("Failed to check if primary worker lock is owned") def on_setup_logging( diff --git a/backend/danswer/background/celery/apps/primary.py b/backend/danswer/background/celery/apps/primary.py index 69c398b0394..1c54e7b0903 100644 --- a/backend/danswer/background/celery/apps/primary.py +++ b/backend/danswer/background/celery/apps/primary.py @@ -11,6 +11,7 @@ from celery.signals import worker_init from celery.signals import worker_ready from celery.signals import worker_shutdown +from redis.lock import Lock as RedisLock import danswer.background.celery.apps.app_base as app_base from danswer.background.celery.apps.app_base import task_logger @@ -39,7 +40,6 @@ from danswer.utils.logger import setup_logger from shared_configs.configs import MULTI_TENANT - logger = setup_logger() celery_app = Celery(__name__) @@ -117,9 +117,13 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None: # it is planned to use this lock to enforce singleton behavior on the primary # worker, since the primary worker does redis cleanup on startup, but this isn't # implemented yet. - lock = r.lock( + + # set thread_local=False since we don't control what thread the periodic task might + # reacquire the lock with + lock: RedisLock = r.lock( DanswerRedisLocks.PRIMARY_WORKER, timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT, + thread_local=False, ) logger.info("Primary worker lock: Acquire starting.") @@ -228,7 +232,7 @@ def run_periodic_task(self, worker: Any) -> None: if not hasattr(worker, "primary_worker_lock"): return - lock = worker.primary_worker_lock + lock: RedisLock = worker.primary_worker_lock r = get_redis_client(tenant_id=None) diff --git a/backend/danswer/background/celery/tasks/beat_schedule.py b/backend/danswer/background/celery/tasks/beat_schedule.py index 3b18f8931e4..6d65bb01654 100644 --- a/backend/danswer/background/celery/tasks/beat_schedule.py +++ b/backend/danswer/background/celery/tasks/beat_schedule.py @@ -2,54 +2,55 @@ from typing import Any from danswer.configs.constants import DanswerCeleryPriority +from danswer.configs.constants import DanswerCeleryTask tasks_to_schedule = [ { "name": "check-for-vespa-sync", - "task": "check_for_vespa_sync_task", + "task": DanswerCeleryTask.CHECK_FOR_VESPA_SYNC_TASK, "schedule": timedelta(seconds=20), "options": {"priority": DanswerCeleryPriority.HIGH}, }, { "name": "check-for-connector-deletion", - "task": "check_for_connector_deletion_task", + "task": DanswerCeleryTask.CHECK_FOR_CONNECTOR_DELETION, "schedule": timedelta(seconds=20), "options": {"priority": DanswerCeleryPriority.HIGH}, }, { "name": "check-for-indexing", - "task": "check_for_indexing", + "task": DanswerCeleryTask.CHECK_FOR_INDEXING, "schedule": timedelta(seconds=15), "options": {"priority": DanswerCeleryPriority.HIGH}, }, { "name": "check-for-prune", - "task": "check_for_pruning", + "task": DanswerCeleryTask.CHECK_FOR_PRUNING, "schedule": timedelta(seconds=15), "options": {"priority": DanswerCeleryPriority.HIGH}, }, { "name": "kombu-message-cleanup", - "task": "kombu_message_cleanup_task", + "task": DanswerCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK, "schedule": timedelta(seconds=3600), "options": {"priority": DanswerCeleryPriority.LOWEST}, }, { "name": "monitor-vespa-sync", - "task": "monitor_vespa_sync", + "task": DanswerCeleryTask.MONITOR_VESPA_SYNC, "schedule": timedelta(seconds=5), "options": {"priority": DanswerCeleryPriority.HIGH}, }, { "name": "check-for-doc-permissions-sync", - "task": "check_for_doc_permissions_sync", + "task": DanswerCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC, "schedule": timedelta(seconds=30), "options": {"priority": DanswerCeleryPriority.HIGH}, }, { "name": "check-for-external-group-sync", - "task": "check_for_external_group_sync", + "task": DanswerCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC, "schedule": timedelta(seconds=20), "options": {"priority": DanswerCeleryPriority.HIGH}, }, diff --git a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py index 9413dd97854..d0298f2dd6a 100644 --- a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py +++ b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py @@ -5,13 +5,13 @@ from celery import shared_task from celery import Task from celery.exceptions import SoftTimeLimitExceeded -from redis import Redis from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session from danswer.background.celery.apps.app_base import task_logger from danswer.configs.app_configs import JOB_TIMEOUT from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT +from danswer.configs.constants import DanswerCeleryTask from danswer.configs.constants import DanswerRedisLocks from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id from danswer.db.connector_credential_pair import get_connector_credential_pairs @@ -29,7 +29,7 @@ class TaskDependencyError(RuntimeError): @shared_task( - name="check_for_connector_deletion_task", + name=DanswerCeleryTask.CHECK_FOR_CONNECTOR_DELETION, soft_time_limit=JOB_TIMEOUT, trail=False, bind=True, @@ -37,7 +37,7 @@ class TaskDependencyError(RuntimeError): def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None: r = get_redis_client(tenant_id=tenant_id) - lock_beat = r.lock( + lock_beat: RedisLock = r.lock( DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK, timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT, ) @@ -60,7 +60,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N redis_connector = RedisConnector(tenant_id, cc_pair_id) try: try_generate_document_cc_pair_cleanup_tasks( - self.app, cc_pair_id, db_session, r, lock_beat, tenant_id + self.app, cc_pair_id, db_session, lock_beat, tenant_id ) except TaskDependencyError as e: # this means we wanted to start deleting but dependent tasks were running @@ -86,7 +86,6 @@ def try_generate_document_cc_pair_cleanup_tasks( app: Celery, cc_pair_id: int, db_session: Session, - r: Redis, lock_beat: RedisLock, tenant_id: str | None, ) -> int | None: diff --git a/backend/danswer/background/celery/tasks/doc_permission_syncing/tasks.py b/backend/danswer/background/celery/tasks/doc_permission_syncing/tasks.py index 6a5761a7428..e95fbac05d6 100644 --- a/backend/danswer/background/celery/tasks/doc_permission_syncing/tasks.py +++ b/backend/danswer/background/celery/tasks/doc_permission_syncing/tasks.py @@ -8,6 +8,7 @@ from celery import Task from celery.exceptions import SoftTimeLimitExceeded from redis import Redis +from redis.lock import Lock as RedisLock from danswer.access.models import DocExternalAccess from danswer.background.celery.apps.app_base import task_logger @@ -17,9 +18,11 @@ from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX from danswer.configs.constants import DanswerCeleryPriority from danswer.configs.constants import DanswerCeleryQueues +from danswer.configs.constants import DanswerCeleryTask from danswer.configs.constants import DanswerRedisLocks from danswer.configs.constants import DocumentSource from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id +from danswer.db.document import upsert_document_by_connector_credential_pair from danswer.db.engine import get_session_with_tenant from danswer.db.enums import AccessType from danswer.db.enums import ConnectorCredentialPairStatus @@ -27,7 +30,7 @@ from danswer.db.users import batch_add_ext_perm_user_if_not_exists from danswer.redis.redis_connector import RedisConnector from danswer.redis.redis_connector_doc_perm_sync import ( - RedisConnectorPermissionSyncData, + RedisConnectorPermissionSyncPayload, ) from danswer.redis.redis_pool import get_redis_client from danswer.utils.logger import doc_permission_sync_ctx @@ -81,7 +84,7 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b @shared_task( - name="check_for_doc_permissions_sync", + name=DanswerCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC, soft_time_limit=JOB_TIMEOUT, bind=True, ) @@ -138,7 +141,7 @@ def try_creating_permissions_sync_task( LOCK_TIMEOUT = 30 - lock = r.lock( + lock: RedisLock = r.lock( DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks", timeout=LOCK_TIMEOUT, ) @@ -162,8 +165,8 @@ def try_creating_permissions_sync_task( custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}" - app.send_task( - "connector_permission_sync_generator_task", + result = app.send_task( + DanswerCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK, kwargs=dict( cc_pair_id=cc_pair_id, tenant_id=tenant_id, @@ -174,8 +177,8 @@ def try_creating_permissions_sync_task( ) # set a basic fence to start - payload = RedisConnectorPermissionSyncData( - started=None, + payload = RedisConnectorPermissionSyncPayload( + started=None, celery_task_id=result.id ) redis_connector.permissions.set_fence(payload) @@ -190,7 +193,7 @@ def try_creating_permissions_sync_task( @shared_task( - name="connector_permission_sync_generator_task", + name=DanswerCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK, acks_late=False, soft_time_limit=JOB_TIMEOUT, track_started=True, @@ -216,7 +219,7 @@ def connector_permission_sync_generator_task( r = get_redis_client(tenant_id=tenant_id) - lock = r.lock( + lock: RedisLock = r.lock( DanswerRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX + f"_{redis_connector.id}", timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT, @@ -241,13 +244,17 @@ def connector_permission_sync_generator_task( doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type) if doc_sync_func is None: - raise ValueError(f"No doc sync func found for {source_type}") + raise ValueError( + f"No doc sync func found for {source_type} with cc_pair={cc_pair_id}" + ) - logger.info(f"Syncing docs for {source_type}") + logger.info(f"Syncing docs for {source_type} with cc_pair={cc_pair_id}") - payload = RedisConnectorPermissionSyncData( - started=datetime.now(timezone.utc), - ) + payload = redis_connector.permissions.payload + if not payload: + raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}") + + payload.started = datetime.now(timezone.utc) redis_connector.permissions.set_fence(payload) document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair) @@ -256,7 +263,12 @@ def connector_permission_sync_generator_task( f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}" ) tasks_generated = redis_connector.permissions.generate_tasks( - self.app, lock, document_external_accesses, source_type + celery_app=self.app, + lock=lock, + new_permissions=document_external_accesses, + source_string=source_type, + connector_id=cc_pair.connector.id, + credential_id=cc_pair.credential.id, ) if tasks_generated is None: return None @@ -281,7 +293,7 @@ def connector_permission_sync_generator_task( @shared_task( - name="update_external_document_permissions_task", + name=DanswerCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK, soft_time_limit=LIGHT_SOFT_TIME_LIMIT, time_limit=LIGHT_TIME_LIMIT, max_retries=DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES, @@ -292,6 +304,8 @@ def update_external_document_permissions_task( tenant_id: str | None, serialized_doc_external_access: dict, source_string: str, + connector_id: int, + credential_id: int, ) -> bool: document_external_access = DocExternalAccess.from_dict( serialized_doc_external_access @@ -300,18 +314,28 @@ def update_external_document_permissions_task( external_access = document_external_access.external_access try: with get_session_with_tenant(tenant_id) as db_session: - # Then we build the update requests to update vespa + # Add the users to the DB if they don't exist batch_add_ext_perm_user_if_not_exists( db_session=db_session, emails=list(external_access.external_user_emails), ) - upsert_document_external_perms( + # Then we upsert the document's external permissions in postgres + created_new_doc = upsert_document_external_perms( db_session=db_session, doc_id=doc_id, external_access=external_access, source_type=DocumentSource(source_string), ) + if created_new_doc: + # If a new document was created, we associate it with the cc_pair + upsert_document_by_connector_credential_pair( + db_session=db_session, + connector_id=connector_id, + credential_id=credential_id, + document_ids=[doc_id], + ) + logger.debug( f"Successfully synced postgres document permissions for {doc_id}" ) diff --git a/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py b/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py index 61ceae4e463..c8ac5f870d3 100644 --- a/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py +++ b/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py @@ -8,6 +8,7 @@ from celery import Task from celery.exceptions import SoftTimeLimitExceeded from redis import Redis +from redis.lock import Lock as RedisLock from danswer.background.celery.apps.app_base import task_logger from danswer.configs.app_configs import JOB_TIMEOUT @@ -16,6 +17,7 @@ from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX from danswer.configs.constants import DanswerCeleryPriority from danswer.configs.constants import DanswerCeleryQueues +from danswer.configs.constants import DanswerCeleryTask from danswer.configs.constants import DanswerRedisLocks from danswer.db.connector import mark_cc_pair_as_external_group_synced from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id @@ -24,13 +26,20 @@ from danswer.db.enums import ConnectorCredentialPairStatus from danswer.db.models import ConnectorCredentialPair from danswer.redis.redis_connector import RedisConnector +from danswer.redis.redis_connector_ext_group_sync import ( + RedisConnectorExternalGroupSyncPayload, +) from danswer.redis.redis_pool import get_redis_client from danswer.utils.logger import setup_logger from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs +from ee.danswer.db.connector_credential_pair import get_cc_pairs_by_source from ee.danswer.db.external_perm import ExternalUserGroup from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair from ee.danswer.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIODS from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP +from ee.danswer.external_permissions.sync_params import ( + GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC, +) logger = setup_logger() @@ -49,7 +58,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool: if cc_pair.access_type != AccessType.SYNC: return False - # skip pruning if not active + # skip external group sync if not active if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE: return False @@ -81,7 +90,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool: @shared_task( - name="check_for_external_group_sync", + name=DanswerCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC, soft_time_limit=JOB_TIMEOUT, bind=True, ) @@ -102,12 +111,28 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None: with get_session_with_tenant(tenant_id) as db_session: cc_pairs = get_all_auto_sync_cc_pairs(db_session) + # We only want to sync one cc_pair per source type in + # GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC + for source in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC: + # These are ordered by cc_pair id so the first one is the one we want + cc_pairs_to_dedupe = get_cc_pairs_by_source( + db_session, source, only_sync=True + ) + # We only want to sync one cc_pair per source type + # in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC so we dedupe here + for cc_pair_to_remove in cc_pairs_to_dedupe[1:]: + cc_pairs = [ + cc_pair + for cc_pair in cc_pairs + if cc_pair.id != cc_pair_to_remove.id + ] + for cc_pair in cc_pairs: if _is_external_group_sync_due(cc_pair): cc_pair_ids_to_sync.append(cc_pair.id) for cc_pair_id in cc_pair_ids_to_sync: - tasks_created = try_creating_permissions_sync_task( + tasks_created = try_creating_external_group_sync_task( self.app, cc_pair_id, r, tenant_id ) if not tasks_created: @@ -125,7 +150,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None: lock_beat.release() -def try_creating_permissions_sync_task( +def try_creating_external_group_sync_task( app: Celery, cc_pair_id: int, r: Redis, @@ -156,8 +181,8 @@ def try_creating_permissions_sync_task( custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}" - _ = app.send_task( - "connector_external_group_sync_generator_task", + result = app.send_task( + DanswerCeleryTask.CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK, kwargs=dict( cc_pair_id=cc_pair_id, tenant_id=tenant_id, @@ -166,8 +191,13 @@ def try_creating_permissions_sync_task( task_id=custom_task_id, priority=DanswerCeleryPriority.HIGH, ) - # set a basic fence to start - redis_connector.external_group_sync.set_fence(True) + + payload = RedisConnectorExternalGroupSyncPayload( + started=datetime.now(timezone.utc), + celery_task_id=result.id, + ) + + redis_connector.external_group_sync.set_fence(payload) except Exception: task_logger.exception( @@ -182,7 +212,7 @@ def try_creating_permissions_sync_task( @shared_task( - name="connector_external_group_sync_generator_task", + name=DanswerCeleryTask.CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK, acks_late=False, soft_time_limit=JOB_TIMEOUT, track_started=True, @@ -195,7 +225,7 @@ def connector_external_group_sync_generator_task( tenant_id: str | None, ) -> None: """ - Permission sync task that handles document permission syncing for a given connector credential pair + Permission sync task that handles external group syncing for a given connector credential pair This task assumes that the task has already been properly fenced """ @@ -203,7 +233,7 @@ def connector_external_group_sync_generator_task( r = get_redis_client(tenant_id=tenant_id) - lock = r.lock( + lock: RedisLock = r.lock( DanswerRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX + f"_{redis_connector.id}", timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT, @@ -228,9 +258,13 @@ def connector_external_group_sync_generator_task( ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type) if ext_group_sync_func is None: - raise ValueError(f"No external group sync func found for {source_type}") + raise ValueError( + f"No external group sync func found for {source_type} for cc_pair: {cc_pair_id}" + ) - logger.info(f"Syncing docs for {source_type}") + logger.info( + f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}" + ) external_user_groups: list[ExternalUserGroup] = ext_group_sync_func(cc_pair) @@ -249,7 +283,6 @@ def connector_external_group_sync_generator_task( ) mark_cc_pair_as_external_group_synced(db_session, cc_pair.id) - except Exception as e: task_logger.exception( f"Failed to run external group sync: cc_pair={cc_pair_id}" @@ -260,6 +293,6 @@ def connector_external_group_sync_generator_task( raise e finally: # we always want to clear the fence after the task is done or failed so it doesn't get stuck - redis_connector.external_group_sync.set_fence(False) + redis_connector.external_group_sync.set_fence(None) if lock.owned(): lock.release() diff --git a/backend/danswer/background/celery/tasks/indexing/tasks.py b/backend/danswer/background/celery/tasks/indexing/tasks.py index dc4f89f5a31..e4b05f9479d 100644 --- a/backend/danswer/background/celery/tasks/indexing/tasks.py +++ b/backend/danswer/background/celery/tasks/indexing/tasks.py @@ -22,6 +22,7 @@ from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX from danswer.configs.constants import DanswerCeleryPriority from danswer.configs.constants import DanswerCeleryQueues +from danswer.configs.constants import DanswerCeleryTask from danswer.configs.constants import DanswerRedisLocks from danswer.configs.constants import DocumentSource from danswer.db.connector import mark_ccpair_with_indexing_trigger @@ -38,12 +39,13 @@ from danswer.db.index_attempt import get_all_index_attempts_by_status from danswer.db.index_attempt import get_index_attempt from danswer.db.index_attempt import get_last_attempt_for_cc_pair +from danswer.db.index_attempt import mark_attempt_canceled from danswer.db.index_attempt import mark_attempt_failed from danswer.db.models import ConnectorCredentialPair from danswer.db.models import IndexAttempt from danswer.db.models import SearchSettings +from danswer.db.search_settings import get_active_search_settings from danswer.db.search_settings import get_current_search_settings -from danswer.db.search_settings import get_secondary_search_settings from danswer.db.swap_index import check_index_swap from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface from danswer.natural_language_processing.search_nlp_models import EmbeddingModel @@ -154,7 +156,7 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: Redis) -> list[int]: @shared_task( - name="check_for_indexing", + name=DanswerCeleryTask.CHECK_FOR_INDEXING, soft_time_limit=300, bind=True, ) @@ -208,17 +210,10 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None: redis_connector = RedisConnector(tenant_id, cc_pair_id) with get_session_with_tenant(tenant_id) as db_session: - # Get the primary search settings - primary_search_settings = get_current_search_settings(db_session) - search_settings = [primary_search_settings] - - # Check for secondary search settings - secondary_search_settings = get_secondary_search_settings(db_session) - if secondary_search_settings is not None: - # If secondary settings exist, add them to the list - search_settings.append(secondary_search_settings) - - for search_settings_instance in search_settings: + search_settings_list: list[SearchSettings] = get_active_search_settings( + db_session + ) + for search_settings_instance in search_settings_list: redis_connector_index = redis_connector.new_index( search_settings_instance.id ) @@ -236,7 +231,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None: ) search_settings_primary = False - if search_settings_instance.id == primary_search_settings.id: + if search_settings_instance.id == search_settings_list[0].id: search_settings_primary = True if not _should_index( @@ -244,13 +239,13 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None: last_index=last_attempt, search_settings_instance=search_settings_instance, search_settings_primary=search_settings_primary, - secondary_index_building=len(search_settings) > 1, + secondary_index_building=len(search_settings_list) > 1, db_session=db_session, ): continue reindex = False - if search_settings_instance.id == primary_search_settings.id: + if search_settings_instance.id == search_settings_list[0].id: # the indexing trigger is only checked and cleared with the primary search settings if cc_pair.indexing_trigger is not None: if cc_pair.indexing_trigger == IndexingMode.REINDEX: @@ -283,7 +278,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None: f"Connector indexing queued: " f"index_attempt={attempt_id} " f"cc_pair={cc_pair.id} " - f"search_settings={search_settings_instance.id} " + f"search_settings={search_settings_instance.id}" ) tasks_created += 1 @@ -491,7 +486,7 @@ def try_creating_indexing_task( # when the task is sent, we have yet to finish setting up the fence # therefore, the task must contain code that blocks until the fence is ready result = celery_app.send_task( - "connector_indexing_proxy_task", + DanswerCeleryTask.CONNECTOR_INDEXING_PROXY_TASK, kwargs=dict( index_attempt_id=index_attempt_id, cc_pair_id=cc_pair.id, @@ -528,8 +523,14 @@ def try_creating_indexing_task( return index_attempt_id -@shared_task(name="connector_indexing_proxy_task", acks_late=False, track_started=True) +@shared_task( + name=DanswerCeleryTask.CONNECTOR_INDEXING_PROXY_TASK, + bind=True, + acks_late=False, + track_started=True, +) def connector_indexing_proxy_task( + self: Task, index_attempt_id: int, cc_pair_id: int, search_settings_id: int, @@ -542,6 +543,10 @@ def connector_indexing_proxy_task( f"cc_pair={cc_pair_id} " f"search_settings={search_settings_id}" ) + + if not self.request.id: + task_logger.error("self.request.id is None!") + client = SimpleJobClient() job = client.submit( @@ -570,29 +575,80 @@ def connector_indexing_proxy_task( f"search_settings={search_settings_id}" ) + redis_connector = RedisConnector(tenant_id, cc_pair_id) + redis_connector_index = redis_connector.new_index(search_settings_id) + while True: - sleep(10) + sleep(5) - # do nothing for ongoing jobs that haven't been stopped - if not job.done(): - with get_session_with_tenant(tenant_id) as db_session: - index_attempt = get_index_attempt( - db_session=db_session, index_attempt_id=index_attempt_id + if self.request.id and redis_connector_index.terminating(self.request.id): + task_logger.warning( + "Indexing watchdog - termination signal detected: " + f"attempt={index_attempt_id} " + f"tenant={tenant_id} " + f"cc_pair={cc_pair_id} " + f"search_settings={search_settings_id}" + ) + + try: + with get_session_with_tenant(tenant_id) as db_session: + mark_attempt_canceled( + index_attempt_id, + db_session, + "Connector termination signal detected", + ) + except Exception: + # if the DB exceptions, we'll just get an unfriendly failure message + # in the UI instead of the cancellation message + logger.exception( + "Indexing watchdog - transient exception marking index attempt as canceled: " + f"attempt={index_attempt_id} " + f"tenant={tenant_id} " + f"cc_pair={cc_pair_id} " + f"search_settings={search_settings_id}" ) - if not index_attempt: - continue + job.cancel() - if not index_attempt.is_finished(): - continue + break + + if not job.done(): + # if the spawned task is still running, restart the check once again + # if the index attempt is not in a finished status + try: + with get_session_with_tenant(tenant_id) as db_session: + index_attempt = get_index_attempt( + db_session=db_session, index_attempt_id=index_attempt_id + ) + + if not index_attempt: + continue + + if not index_attempt.is_finished(): + continue + except Exception: + # if the DB exceptioned, just restart the check. + # polling the index attempt status doesn't need to be strongly consistent + logger.exception( + "Indexing watchdog - transient exception looking up index attempt: " + f"attempt={index_attempt_id} " + f"tenant={tenant_id} " + f"cc_pair={cc_pair_id} " + f"search_settings={search_settings_id}" + ) + continue if job.status == "error": + exit_code: int | None = None + if job.process: + exit_code = job.process.exitcode task_logger.error( - f"Indexing watchdog - spawned task exceptioned: " + "Indexing watchdog - spawned task exceptioned: " f"attempt={index_attempt_id} " f"tenant={tenant_id} " f"cc_pair={cc_pair_id} " f"search_settings={search_settings_id} " + f"exit_code={exit_code} " f"error={job.exception()}" ) @@ -736,9 +792,12 @@ def connector_indexing_task( ) break + # set thread_local=False since we don't control what thread the indexing/pruning + # might run our callback with lock: RedisLock = r.lock( redis_connector_index.generator_lock_key, timeout=CELERY_INDEXING_LOCK_TIMEOUT, + thread_local=False, ) acquired = lock.acquire(blocking=False) diff --git a/backend/danswer/background/celery/tasks/periodic/tasks.py b/backend/danswer/background/celery/tasks/periodic/tasks.py index 20baa7c52fa..efef013f5e4 100644 --- a/backend/danswer/background/celery/tasks/periodic/tasks.py +++ b/backend/danswer/background/celery/tasks/periodic/tasks.py @@ -13,12 +13,13 @@ from danswer.background.celery.apps.app_base import task_logger from danswer.configs.app_configs import JOB_TIMEOUT +from danswer.configs.constants import DanswerCeleryTask from danswer.configs.constants import PostgresAdvisoryLocks from danswer.db.engine import get_session_with_tenant @shared_task( - name="kombu_message_cleanup_task", + name=DanswerCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK, soft_time_limit=JOB_TIMEOUT, bind=True, base=AbortableTask, diff --git a/backend/danswer/background/celery/tasks/pruning/tasks.py b/backend/danswer/background/celery/tasks/pruning/tasks.py index 67b781f228f..5497f1211a3 100644 --- a/backend/danswer/background/celery/tasks/pruning/tasks.py +++ b/backend/danswer/background/celery/tasks/pruning/tasks.py @@ -8,6 +8,7 @@ from celery import Task from celery.exceptions import SoftTimeLimitExceeded from redis import Redis +from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session from danswer.background.celery.apps.app_base import task_logger @@ -20,6 +21,7 @@ from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX from danswer.configs.constants import DanswerCeleryPriority from danswer.configs.constants import DanswerCeleryQueues +from danswer.configs.constants import DanswerCeleryTask from danswer.configs.constants import DanswerRedisLocks from danswer.connectors.factory import instantiate_connector from danswer.connectors.models import InputType @@ -75,7 +77,7 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool: @shared_task( - name="check_for_pruning", + name=DanswerCeleryTask.CHECK_FOR_PRUNING, soft_time_limit=JOB_TIMEOUT, bind=True, ) @@ -184,7 +186,7 @@ def try_creating_prune_generator_task( custom_task_id = f"{redis_connector.prune.generator_task_key}_{uuid4()}" celery_app.send_task( - "connector_pruning_generator_task", + DanswerCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK, kwargs=dict( cc_pair_id=cc_pair.id, connector_id=cc_pair.connector_id, @@ -209,7 +211,7 @@ def try_creating_prune_generator_task( @shared_task( - name="connector_pruning_generator_task", + name=DanswerCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK, acks_late=False, soft_time_limit=JOB_TIMEOUT, track_started=True, @@ -238,9 +240,12 @@ def connector_pruning_generator_task( r = get_redis_client(tenant_id=tenant_id) - lock = r.lock( + # set thread_local=False since we don't control what thread the indexing/pruning + # might run our callback with + lock: RedisLock = r.lock( DanswerRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.id}", timeout=CELERY_PRUNING_LOCK_TIMEOUT, + thread_local=False, ) acquired = lock.acquire(blocking=False) diff --git a/backend/danswer/background/celery/tasks/shared/tasks.py b/backend/danswer/background/celery/tasks/shared/tasks.py index 2719a4d0665..2212046c3e9 100644 --- a/backend/danswer/background/celery/tasks/shared/tasks.py +++ b/backend/danswer/background/celery/tasks/shared/tasks.py @@ -9,6 +9,7 @@ from danswer.access.access import get_access_for_document from danswer.background.celery.apps.app_base import task_logger from danswer.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex +from danswer.configs.constants import DanswerCeleryTask from danswer.db.document import delete_document_by_connector_credential_pair__no_commit from danswer.db.document import delete_documents_complete__no_commit from danswer.db.document import get_document @@ -31,7 +32,7 @@ @shared_task( - name="document_by_cc_pair_cleanup_task", + name=DanswerCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK, soft_time_limit=LIGHT_SOFT_TIME_LIMIT, time_limit=LIGHT_TIME_LIMIT, max_retries=DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES, diff --git a/backend/danswer/background/celery/tasks/vespa/tasks.py b/backend/danswer/background/celery/tasks/vespa/tasks.py index 08c1b6b8bc5..f43b4b56620 100644 --- a/backend/danswer/background/celery/tasks/vespa/tasks.py +++ b/backend/danswer/background/celery/tasks/vespa/tasks.py @@ -25,6 +25,7 @@ from danswer.configs.app_configs import JOB_TIMEOUT from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT from danswer.configs.constants import DanswerCeleryQueues +from danswer.configs.constants import DanswerCeleryTask from danswer.configs.constants import DanswerRedisLocks from danswer.db.connector import fetch_connector_by_id from danswer.db.connector import mark_cc_pair_as_permissions_synced @@ -46,6 +47,7 @@ from danswer.db.document_set import get_document_set_by_id from danswer.db.document_set import mark_document_set_as_synced from danswer.db.engine import get_session_with_tenant +from danswer.db.enums import IndexingStatus from danswer.db.index_attempt import delete_index_attempts from danswer.db.index_attempt import get_index_attempt from danswer.db.index_attempt import mark_attempt_failed @@ -58,7 +60,7 @@ from danswer.redis.redis_connector_delete import RedisConnectorDelete from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync from danswer.redis.redis_connector_doc_perm_sync import ( - RedisConnectorPermissionSyncData, + RedisConnectorPermissionSyncPayload, ) from danswer.redis.redis_connector_index import RedisConnectorIndex from danswer.redis.redis_connector_prune import RedisConnectorPrune @@ -79,7 +81,7 @@ # celery auto associates tasks created inside another task, # which bloats the result metadata considerably. trail=False prevents this. @shared_task( - name="check_for_vespa_sync_task", + name=DanswerCeleryTask.CHECK_FOR_VESPA_SYNC_TASK, soft_time_limit=JOB_TIMEOUT, trail=False, bind=True, @@ -588,7 +590,7 @@ def monitor_ccpair_permissions_taskset( if remaining > 0: return - payload: RedisConnectorPermissionSyncData | None = ( + payload: RedisConnectorPermissionSyncPayload | None = ( redis_connector.permissions.payload ) start_time: datetime | None = payload.started if payload else None @@ -596,9 +598,7 @@ def monitor_ccpair_permissions_taskset( mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time) task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}") - redis_connector.permissions.taskset_clear() - redis_connector.permissions.generator_clear() - redis_connector.permissions.set_fence(None) + redis_connector.permissions.reset() def monitor_ccpair_indexing_taskset( @@ -655,33 +655,52 @@ def monitor_ccpair_indexing_taskset( # outer = result.state in READY state status_int = redis_connector_index.get_completion() if status_int is None: # inner signal not set ... possible error - result_state = result.state + task_state = result.state if ( - result_state in READY_STATES + task_state in READY_STATES ): # outer signal in terminal state ... possible error # Now double check! if redis_connector_index.get_completion() is None: # inner signal still not set (and cannot change when outer result_state is READY) # Task is finished but generator complete isn't set. # We have a problem! Worker may have crashed. + task_result = str(result.result) + task_traceback = str(result.traceback) msg = ( f"Connector indexing aborted or exceptioned: " f"attempt={payload.index_attempt_id} " f"celery_task={payload.celery_task_id} " - f"result_state={result_state} " f"cc_pair={cc_pair_id} " f"search_settings={search_settings_id} " - f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}" + f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} " + f"result.state={task_state} " + f"result.result={task_result} " + f"result.traceback={task_traceback}" ) task_logger.warning(msg) - index_attempt = get_index_attempt(db_session, payload.index_attempt_id) - if index_attempt: - mark_attempt_failed( - index_attempt_id=payload.index_attempt_id, - db_session=db_session, - failure_reason=msg, + try: + index_attempt = get_index_attempt( + db_session, payload.index_attempt_id + ) + if index_attempt: + if ( + index_attempt.status != IndexingStatus.CANCELED + and index_attempt.status != IndexingStatus.FAILED + ): + mark_attempt_failed( + index_attempt_id=payload.index_attempt_id, + db_session=db_session, + failure_reason=msg, + ) + except Exception: + task_logger.exception( + "monitor_ccpair_indexing_taskset - transient exception marking index attempt as failed: " + f"attempt={payload.index_attempt_id} " + f"tenant={tenant_id} " + f"cc_pair={cc_pair_id} " + f"search_settings={search_settings_id}" ) redis_connector_index.reset() @@ -692,6 +711,7 @@ def monitor_ccpair_indexing_taskset( task_logger.info( f"Connector indexing finished: cc_pair={cc_pair_id} " f"search_settings={search_settings_id} " + f"progress={progress} " f"status={status_enum.name} " f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}" ) @@ -699,38 +719,8 @@ def monitor_ccpair_indexing_taskset( redis_connector_index.reset() -# def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]: -# """Gets a list of unfenced index attempts. Should not be possible, so we'd typically -# want to clean them up. - -# Unfenced = attempt not in terminal state and fence does not exist. -# """ -# unfenced_attempts: list[int] = [] - -# # do some cleanup before clearing fences -# # check the db for any outstanding index attempts -# attempts: list[IndexAttempt] = [] -# attempts.extend( -# get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session) -# ) -# attempts.extend( -# get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session) -# ) - -# for attempt in attempts: -# # if attempts exist in the db but we don't detect them in redis, mark them as failed -# fence_key = RedisConnectorIndex.fence_key_with_ids( -# attempt.connector_credential_pair_id, attempt.search_settings_id -# ) -# if r.exists(fence_key): -# continue - -# unfenced_attempts.append(attempt.id) - -# return unfenced_attempts - -@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True) +@shared_task(name=DanswerCeleryTask.MONITOR_VESPA_SYNC, soft_time_limit=300, bind=True) def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool: """This is a celery beat task that monitors and finalizes metadata sync tasksets. It scans for fence values and then gets the counts of any associated tasksets. @@ -755,7 +745,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool: # print current queue lengths r_celery = self.app.broker_connection().channel().client # type: ignore - n_celery = celery_get_queue_length("celery", r) + n_celery = celery_get_queue_length("celery", r_celery) n_indexing = celery_get_queue_length( DanswerCeleryQueues.CONNECTOR_INDEXING, r_celery ) @@ -841,7 +831,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool: @shared_task( - name="vespa_metadata_sync_task", + name=DanswerCeleryTask.VESPA_METADATA_SYNC_TASK, bind=True, soft_time_limit=LIGHT_SOFT_TIME_LIMIT, time_limit=LIGHT_TIME_LIMIT, diff --git a/backend/danswer/background/indexing/job_client.py b/backend/danswer/background/indexing/job_client.py index 602ec4294c0..a31b7b3119e 100644 --- a/backend/danswer/background/indexing/job_client.py +++ b/backend/danswer/background/indexing/job_client.py @@ -82,7 +82,7 @@ def status(self) -> JobStatusType: return "running" elif self.process.exitcode is None: return "cancelled" - elif self.process.exitcode > 0: + elif self.process.exitcode != 0: return "error" else: return "finished" @@ -123,7 +123,8 @@ def submit(self, func: Callable, *args: Any, pure: bool = True) -> SimpleJob | N self._cleanup_completed_jobs() if len(self.jobs) >= self.n_workers: logger.debug( - f"No available workers to run job. Currently running '{len(self.jobs)}' jobs, with a limit of '{self.n_workers}'." + f"No available workers to run job. " + f"Currently running '{len(self.jobs)}' jobs, with a limit of '{self.n_workers}'." ) return None diff --git a/backend/danswer/background/indexing/run_indexing.py b/backend/danswer/background/indexing/run_indexing.py index 699e4682caa..40ed778f033 100644 --- a/backend/danswer/background/indexing/run_indexing.py +++ b/backend/danswer/background/indexing/run_indexing.py @@ -19,6 +19,7 @@ from danswer.db.connector_credential_pair import update_connector_credential_pair from danswer.db.engine import get_session_with_tenant from danswer.db.enums import ConnectorCredentialPairStatus +from danswer.db.index_attempt import mark_attempt_canceled from danswer.db.index_attempt import mark_attempt_failed from danswer.db.index_attempt import mark_attempt_partially_succeeded from danswer.db.index_attempt import mark_attempt_succeeded @@ -87,6 +88,10 @@ def _get_connector_runner( ) +class ConnectorStopSignal(Exception): + """A custom exception used to signal a stop in processing.""" + + def _run_indexing( db_session: Session, index_attempt: IndexAttempt, @@ -208,9 +213,7 @@ def _run_indexing( # contents still need to be initially pulled. if callback: if callback.should_stop(): - raise RuntimeError( - "_run_indexing: Connector stop signal detected" - ) + raise ConnectorStopSignal("Connector stop signal detected") # TODO: should we move this into the above callback instead? db_session.refresh(db_cc_pair) @@ -304,26 +307,16 @@ def _run_indexing( ) except Exception as e: logger.exception( - f"Connector run ran into exception after elapsed time: {time.time() - start_time} seconds" + f"Connector run exceptioned after elapsed time: {time.time() - start_time} seconds" ) - # Only mark the attempt as a complete failure if this is the first indexing window. - # Otherwise, some progress was made - the next run will not start from the beginning. - # In this case, it is not accurate to mark it as a failure. When the next run begins, - # if that fails immediately, it will be marked as a failure. - # - # NOTE: if the connector is manually disabled, we should mark it as a failure regardless - # to give better clarity in the UI, as the next run will never happen. - if ( - ind == 0 - or not db_cc_pair.status.is_active() - or index_attempt.status != IndexingStatus.IN_PROGRESS - ): - mark_attempt_failed( + + if isinstance(e, ConnectorStopSignal): + mark_attempt_canceled( index_attempt.id, db_session, - failure_reason=str(e), - full_exception_trace=traceback.format_exc(), + reason=str(e), ) + if is_primary: update_connector_credential_pair( db_session=db_session, @@ -335,6 +328,37 @@ def _run_indexing( if INDEXING_TRACER_INTERVAL > 0: tracer.stop() raise e + else: + # Only mark the attempt as a complete failure if this is the first indexing window. + # Otherwise, some progress was made - the next run will not start from the beginning. + # In this case, it is not accurate to mark it as a failure. When the next run begins, + # if that fails immediately, it will be marked as a failure. + # + # NOTE: if the connector is manually disabled, we should mark it as a failure regardless + # to give better clarity in the UI, as the next run will never happen. + if ( + ind == 0 + or not db_cc_pair.status.is_active() + or index_attempt.status != IndexingStatus.IN_PROGRESS + ): + mark_attempt_failed( + index_attempt.id, + db_session, + failure_reason=str(e), + full_exception_trace=traceback.format_exc(), + ) + + if is_primary: + update_connector_credential_pair( + db_session=db_session, + connector_id=db_connector.id, + credential_id=db_credential.id, + net_docs=net_doc_change, + ) + + if INDEXING_TRACER_INTERVAL > 0: + tracer.stop() + raise e # break => similar to success case. As mentioned above, if the next run fails for the same # reason it will then be marked as a failure diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/chat/answer.py similarity index 86% rename from backend/danswer/llm/answering/answer.py rename to backend/danswer/chat/answer.py index 466a953ee2b..50aec821bd8 100644 --- a/backend/danswer/llm/answering/answer.py +++ b/backend/danswer/chat/answer.py @@ -6,33 +6,27 @@ from langchain_core.messages import AIMessageChunk from langchain_core.messages import ToolCall +from danswer.chat.llm_response_handler import LLMResponseHandlerManager from danswer.chat.models import AnswerQuestionPossibleReturn +from danswer.chat.models import AnswerStyleConfig from danswer.chat.models import CitationInfo from danswer.chat.models import DanswerAnswerPiece -from danswer.file_store.utils import InMemoryChatFile -from danswer.llm.answering.llm_response_handler import LLMCall -from danswer.llm.answering.llm_response_handler import LLMResponseHandlerManager -from danswer.llm.answering.models import AnswerStyleConfig -from danswer.llm.answering.models import PreviousMessage -from danswer.llm.answering.models import PromptConfig -from danswer.llm.answering.prompts.build import AnswerPromptBuilder -from danswer.llm.answering.prompts.build import default_build_system_message -from danswer.llm.answering.prompts.build import default_build_user_message -from danswer.llm.answering.stream_processing.answer_response_handler import ( - AnswerResponseHandler, -) -from danswer.llm.answering.stream_processing.answer_response_handler import ( +from danswer.chat.models import PromptConfig +from danswer.chat.prompt_builder.build import AnswerPromptBuilder +from danswer.chat.prompt_builder.build import default_build_system_message +from danswer.chat.prompt_builder.build import default_build_user_message +from danswer.chat.prompt_builder.build import LLMCall +from danswer.chat.stream_processing.answer_response_handler import ( CitationResponseHandler, ) -from danswer.llm.answering.stream_processing.answer_response_handler import ( +from danswer.chat.stream_processing.answer_response_handler import ( DummyAnswerResponseHandler, ) -from danswer.llm.answering.stream_processing.answer_response_handler import ( - QuotesResponseHandler, -) -from danswer.llm.answering.stream_processing.utils import map_document_id_order -from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler +from danswer.chat.stream_processing.utils import map_document_id_order +from danswer.chat.tool_handling.tool_response_handler import ToolResponseHandler +from danswer.file_store.utils import InMemoryChatFile from danswer.llm.interfaces import LLM +from danswer.llm.models import PreviousMessage from danswer.natural_language_processing.utils import get_tokenizer from danswer.tools.force import ForceUseTool from danswer.tools.models import ToolResponse @@ -213,20 +207,28 @@ def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream: # + figure out what the next LLM call should be tool_call_handler = ToolResponseHandler(current_llm_call.tools) - search_result = SearchTool.get_search_result(current_llm_call) or [] - - answer_handler: AnswerResponseHandler - if self.answer_style_config.citation_config: - answer_handler = CitationResponseHandler( - context_docs=search_result, - doc_id_to_rank_map=map_document_id_order(search_result), - ) - elif self.answer_style_config.quotes_config: - answer_handler = QuotesResponseHandler( - context_docs=search_result, - ) - else: - raise ValueError("No answer style config provided") + search_result, displayed_search_results_map = SearchTool.get_search_result( + current_llm_call + ) or ([], {}) + + # Quotes are no longer supported + # answer_handler: AnswerResponseHandler + # if self.answer_style_config.citation_config: + # answer_handler = CitationResponseHandler( + # context_docs=search_result, + # doc_id_to_rank_map=map_document_id_order(search_result), + # ) + # elif self.answer_style_config.quotes_config: + # answer_handler = QuotesResponseHandler( + # context_docs=search_result, + # ) + # else: + # raise ValueError("No answer style config provided") + answer_handler = CitationResponseHandler( + context_docs=search_result, + doc_id_to_rank_map=map_document_id_order(search_result), + display_doc_order_dict=displayed_search_results_map, + ) response_handler_manager = LLMResponseHandlerManager( tool_call_handler, answer_handler, self.is_cancelled diff --git a/backend/danswer/chat/chat_utils.py b/backend/danswer/chat/chat_utils.py index 5e42ae23f5a..eb63c68754d 100644 --- a/backend/danswer/chat/chat_utils.py +++ b/backend/danswer/chat/chat_utils.py @@ -2,20 +2,79 @@ from typing import cast from uuid import UUID +from fastapi import HTTPException from fastapi.datastructures import Headers from sqlalchemy.orm import Session +from danswer.auth.users import is_user_admin from danswer.chat.models import CitationInfo from danswer.chat.models import LlmDoc +from danswer.chat.models import PersonaOverrideConfig +from danswer.chat.models import ThreadMessage +from danswer.configs.constants import DEFAULT_PERSONA_ID +from danswer.configs.constants import MessageType from danswer.context.search.models import InferenceSection +from danswer.context.search.models import RerankingDetails +from danswer.context.search.models import RetrievalDetails +from danswer.db.chat import create_chat_session from danswer.db.chat import get_chat_messages_by_session +from danswer.db.llm import fetch_existing_doc_sets +from danswer.db.llm import fetch_existing_tools from danswer.db.models import ChatMessage -from danswer.llm.answering.models import PreviousMessage +from danswer.db.models import Persona +from danswer.db.models import Prompt +from danswer.db.models import Tool +from danswer.db.models import User +from danswer.db.persona import get_prompts_by_ids +from danswer.llm.models import PreviousMessage +from danswer.natural_language_processing.utils import BaseTokenizer +from danswer.server.query_and_chat.models import CreateChatMessageRequest +from danswer.tools.tool_implementations.custom.custom_tool import ( + build_custom_tools_from_openapi_schema_and_headers, +) from danswer.utils.logger import setup_logger logger = setup_logger() +def prepare_chat_message_request( + message_text: str, + user: User | None, + persona_id: int | None, + # Does the question need to have a persona override + persona_override_config: PersonaOverrideConfig | None, + prompt: Prompt | None, + message_ts_to_respond_to: str | None, + retrieval_details: RetrievalDetails | None, + rerank_settings: RerankingDetails | None, + db_session: Session, +) -> CreateChatMessageRequest: + # Typically used for one shot flows like SlackBot or non-chat API endpoint use cases + new_chat_session = create_chat_session( + db_session=db_session, + description=None, + user_id=user.id if user else None, + # If using an override, this id will be ignored later on + persona_id=persona_id or DEFAULT_PERSONA_ID, + danswerbot_flow=True, + slack_thread_id=message_ts_to_respond_to, + ) + + return CreateChatMessageRequest( + chat_session_id=new_chat_session.id, + parent_message_id=None, # It's a standalone chat session each time + message=message_text, + file_descriptors=[], # Currently SlackBot/answer api do not support files in the context + prompt_id=prompt.id if prompt else None, + # Can always override the persona for the single query, if it's a normal persona + # then it will be treated the same + persona_override_config=persona_override_config, + search_doc_ids=None, + retrieval_options=retrieval_details, + rerank_settings=rerank_settings, + ) + + def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDoc: return LlmDoc( document_id=inference_section.center_chunk.document_id, @@ -31,9 +90,49 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo if inference_section.center_chunk.source_links else None, source_links=inference_section.center_chunk.source_links, + match_highlights=inference_section.center_chunk.match_highlights, ) +def combine_message_thread( + messages: list[ThreadMessage], + max_tokens: int | None, + llm_tokenizer: BaseTokenizer, +) -> str: + """Used to create a single combined message context from threads""" + if not messages: + return "" + + message_strs: list[str] = [] + total_token_count = 0 + + for message in reversed(messages): + if message.role == MessageType.USER: + role_str = message.role.value.upper() + if message.sender: + role_str += " " + message.sender + else: + # Since other messages might have the user identifying information + # better to use Unknown for symmetry + role_str += " Unknown" + else: + role_str = message.role.value.upper() + + msg_str = f"{role_str}:\n{message.message}" + message_token_count = len(llm_tokenizer.encode(msg_str)) + + if ( + max_tokens is not None + and total_token_count + message_token_count > max_tokens + ): + break + + message_strs.insert(0, msg_str) + total_token_count += message_token_count + + return "\n\n".join(message_strs) + + def create_chat_chain( chat_session_id: UUID, db_session: Session, @@ -196,3 +295,71 @@ def extract_headers( if lowercase_key in headers: extracted_headers[lowercase_key] = headers[lowercase_key] return extracted_headers + + +def create_temporary_persona( + persona_config: PersonaOverrideConfig, db_session: Session, user: User | None = None +) -> Persona: + if not is_user_admin(user): + raise HTTPException( + status_code=403, + detail="User is not authorized to create a persona in one shot queries", + ) + + """Create a temporary Persona object from the provided configuration.""" + persona = Persona( + name=persona_config.name, + description=persona_config.description, + num_chunks=persona_config.num_chunks, + llm_relevance_filter=persona_config.llm_relevance_filter, + llm_filter_extraction=persona_config.llm_filter_extraction, + recency_bias=persona_config.recency_bias, + llm_model_provider_override=persona_config.llm_model_provider_override, + llm_model_version_override=persona_config.llm_model_version_override, + ) + + if persona_config.prompts: + persona.prompts = [ + Prompt( + name=p.name, + description=p.description, + system_prompt=p.system_prompt, + task_prompt=p.task_prompt, + include_citations=p.include_citations, + datetime_aware=p.datetime_aware, + ) + for p in persona_config.prompts + ] + elif persona_config.prompt_ids: + persona.prompts = get_prompts_by_ids( + db_session=db_session, prompt_ids=persona_config.prompt_ids + ) + + persona.tools = [] + if persona_config.custom_tools_openapi: + for schema in persona_config.custom_tools_openapi: + tools = cast( + list[Tool], + build_custom_tools_from_openapi_schema_and_headers(schema), + ) + persona.tools.extend(tools) + + if persona_config.tools: + tool_ids = [tool.id for tool in persona_config.tools] + persona.tools.extend( + fetch_existing_tools(db_session=db_session, tool_ids=tool_ids) + ) + + if persona_config.tool_ids: + persona.tools.extend( + fetch_existing_tools( + db_session=db_session, tool_ids=persona_config.tool_ids + ) + ) + + fetched_docs = fetch_existing_doc_sets( + db_session=db_session, doc_ids=persona_config.document_set_ids + ) + persona.document_sets = fetched_docs + + return persona diff --git a/backend/danswer/llm/answering/llm_response_handler.py b/backend/danswer/chat/llm_response_handler.py similarity index 52% rename from backend/danswer/llm/answering/llm_response_handler.py rename to backend/danswer/chat/llm_response_handler.py index f8426844244..ee3d3f930bb 100644 --- a/backend/danswer/llm/answering/llm_response_handler.py +++ b/backend/danswer/chat/llm_response_handler.py @@ -1,60 +1,22 @@ from collections.abc import Callable from collections.abc import Generator from collections.abc import Iterator -from typing import TYPE_CHECKING from langchain_core.messages import BaseMessage -from pydantic.v1 import BaseModel as BaseModel__v1 -from danswer.chat.models import CitationInfo -from danswer.chat.models import DanswerAnswerPiece -from danswer.chat.models import DanswerQuotes +from danswer.chat.models import ResponsePart from danswer.chat.models import StreamStopInfo from danswer.chat.models import StreamStopReason -from danswer.file_store.models import InMemoryChatFile -from danswer.llm.answering.prompts.build import AnswerPromptBuilder -from danswer.tools.force import ForceUseTool -from danswer.tools.models import ToolCallFinalResult -from danswer.tools.models import ToolCallKickoff -from danswer.tools.models import ToolResponse -from danswer.tools.tool import Tool - - -if TYPE_CHECKING: - from danswer.llm.answering.stream_processing.answer_response_handler import ( - AnswerResponseHandler, - ) - from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler - - -ResponsePart = ( - DanswerAnswerPiece - | CitationInfo - | DanswerQuotes - | ToolCallKickoff - | ToolResponse - | ToolCallFinalResult - | StreamStopInfo -) - - -class LLMCall(BaseModel__v1): - prompt_builder: AnswerPromptBuilder - tools: list[Tool] - force_use_tool: ForceUseTool - files: list[InMemoryChatFile] - tool_call_info: list[ToolCallKickoff | ToolResponse | ToolCallFinalResult] - using_tool_calling_llm: bool - - class Config: - arbitrary_types_allowed = True +from danswer.chat.prompt_builder.build import LLMCall +from danswer.chat.stream_processing.answer_response_handler import AnswerResponseHandler +from danswer.chat.tool_handling.tool_response_handler import ToolResponseHandler class LLMResponseHandlerManager: def __init__( self, - tool_handler: "ToolResponseHandler", - answer_handler: "AnswerResponseHandler", + tool_handler: ToolResponseHandler, + answer_handler: AnswerResponseHandler, is_cancelled: Callable[[], bool], ): self.tool_handler = tool_handler diff --git a/backend/danswer/chat/models.py b/backend/danswer/chat/models.py index 3852029c47b..213a5ed74a5 100644 --- a/backend/danswer/chat/models.py +++ b/backend/danswer/chat/models.py @@ -1,17 +1,30 @@ +from collections.abc import Callable from collections.abc import Iterator from datetime import datetime from enum import Enum from typing import Any +from typing import TYPE_CHECKING from pydantic import BaseModel +from pydantic import ConfigDict +from pydantic import Field +from pydantic import model_validator from danswer.configs.constants import DocumentSource +from danswer.configs.constants import MessageType from danswer.context.search.enums import QueryFlow +from danswer.context.search.enums import RecencyBiasSetting from danswer.context.search.enums import SearchType from danswer.context.search.models import RetrievalDocs -from danswer.context.search.models import SearchResponse +from danswer.llm.override_models import PromptOverride +from danswer.tools.models import ToolCallFinalResult +from danswer.tools.models import ToolCallKickoff +from danswer.tools.models import ToolResponse from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType +if TYPE_CHECKING: + from danswer.db.models import Prompt + class LlmDoc(BaseModel): """This contains the minimal set information for the LLM portion including citations""" @@ -25,6 +38,7 @@ class LlmDoc(BaseModel): updated_at: datetime | None link: str | None source_links: dict[int, str] | None + match_highlights: list[str] | None # First chunk of info for streaming QA @@ -117,20 +131,6 @@ class StreamingError(BaseModel): stack_trace: str | None = None -class DanswerQuote(BaseModel): - # This is during inference so everything is a string by this point - quote: str - document_id: str - link: str | None - source_type: str - semantic_identifier: str - blurb: str - - -class DanswerQuotes(BaseModel): - quotes: list[DanswerQuote] - - class DanswerContext(BaseModel): content: str document_id: str @@ -146,14 +146,20 @@ class DanswerAnswer(BaseModel): answer: str | None -class QAResponse(SearchResponse, DanswerAnswer): - quotes: list[DanswerQuote] | None - contexts: list[DanswerContexts] | None - predicted_flow: QueryFlow - predicted_search: SearchType - eval_res_valid: bool | None = None +class ThreadMessage(BaseModel): + message: str + sender: str | None = None + role: MessageType = MessageType.USER + + +class ChatDanswerBotResponse(BaseModel): + answer: str | None = None + citations: list[CitationInfo] | None = None + docs: QADocsResponse | None = None llm_selected_doc_indices: list[int] | None = None error_msg: str | None = None + chat_message_id: int | None = None + answer_valid: bool = True # Reflexion result, default True if Reflexion not run class FileChatDisplay(BaseModel): @@ -165,9 +171,41 @@ class CustomToolResponse(BaseModel): tool_name: str +class ToolConfig(BaseModel): + id: int + + +class PromptOverrideConfig(BaseModel): + name: str + description: str = "" + system_prompt: str + task_prompt: str = "" + include_citations: bool = True + datetime_aware: bool = True + + +class PersonaOverrideConfig(BaseModel): + name: str + description: str + search_type: SearchType = SearchType.SEMANTIC + num_chunks: float | None = None + llm_relevance_filter: bool = False + llm_filter_extraction: bool = False + recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO + llm_model_provider_override: str | None = None + llm_model_version_override: str | None = None + + prompts: list[PromptOverrideConfig] = Field(default_factory=list) + prompt_ids: list[int] = Field(default_factory=list) + + document_set_ids: list[int] = Field(default_factory=list) + tools: list[ToolConfig] = Field(default_factory=list) + tool_ids: list[int] = Field(default_factory=list) + custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list) + + AnswerQuestionPossibleReturn = ( DanswerAnswerPiece - | DanswerQuotes | CitationInfo | DanswerContexts | FileChatDisplay @@ -183,3 +221,109 @@ class CustomToolResponse(BaseModel): class LLMMetricsContainer(BaseModel): prompt_tokens: int response_tokens: int + + +StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn] + + +class DocumentPruningConfig(BaseModel): + max_chunks: int | None = None + max_window_percentage: float | None = None + max_tokens: int | None = None + # different pruning behavior is expected when the + # user manually selects documents they want to chat with + # e.g. we don't want to truncate each document to be no more + # than one chunk long + is_manually_selected_docs: bool = False + # If user specifies to include additional context Chunks for each match, then different pruning + # is used. As many Sections as possible are included, and the last Section is truncated + # If this is false, all of the Sections are truncated if they are longer than the expected Chunk size. + # Sections are often expected to be longer than the maximum Chunk size but Chunks should not be. + use_sections: bool = True + # If using tools, then we need to consider the tool length + tool_num_tokens: int = 0 + # If using a tool message to represent the docs, then we have to JSON serialize + # the document content, which adds to the token count. + using_tool_message: bool = False + + +class ContextualPruningConfig(DocumentPruningConfig): + num_chunk_multiple: int + + @classmethod + def from_doc_pruning_config( + cls, num_chunk_multiple: int, doc_pruning_config: DocumentPruningConfig + ) -> "ContextualPruningConfig": + return cls(num_chunk_multiple=num_chunk_multiple, **doc_pruning_config.dict()) + + +class CitationConfig(BaseModel): + all_docs_useful: bool = False + + +class QuotesConfig(BaseModel): + pass + + +class AnswerStyleConfig(BaseModel): + citation_config: CitationConfig | None = None + quotes_config: QuotesConfig | None = None + document_pruning_config: DocumentPruningConfig = Field( + default_factory=DocumentPruningConfig + ) + # forces the LLM to return a structured response, see + # https://platform.openai.com/docs/guides/structured-outputs/introduction + # right now, only used by the simple chat API + structured_response_format: dict | None = None + + @model_validator(mode="after") + def check_quotes_and_citation(self) -> "AnswerStyleConfig": + if self.citation_config is None and self.quotes_config is None: + raise ValueError( + "One of `citation_config` or `quotes_config` must be provided" + ) + + if self.citation_config is not None and self.quotes_config is not None: + raise ValueError( + "Only one of `citation_config` or `quotes_config` must be provided" + ) + + return self + + +class PromptConfig(BaseModel): + """Final representation of the Prompt configuration passed + into the `Answer` object.""" + + system_prompt: str + task_prompt: str + datetime_aware: bool + include_citations: bool + + @classmethod + def from_model( + cls, model: "Prompt", prompt_override: PromptOverride | None = None + ) -> "PromptConfig": + override_system_prompt = ( + prompt_override.system_prompt if prompt_override else None + ) + override_task_prompt = prompt_override.task_prompt if prompt_override else None + + return cls( + system_prompt=override_system_prompt or model.system_prompt, + task_prompt=override_task_prompt or model.task_prompt, + datetime_aware=model.datetime_aware, + include_citations=model.include_citations, + ) + + model_config = ConfigDict(frozen=True) + + +ResponsePart = ( + DanswerAnswerPiece + | CitationInfo + | ToolCallKickoff + | ToolResponse + | ToolCallFinalResult + | StreamStopInfo +) diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 9048e21d610..e7eab659830 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -6,16 +6,24 @@ from sqlalchemy.orm import Session +from danswer.chat.answer import Answer from danswer.chat.chat_utils import create_chat_chain +from danswer.chat.chat_utils import create_temporary_persona from danswer.chat.models import AllCitations +from danswer.chat.models import AnswerStyleConfig +from danswer.chat.models import ChatDanswerBotResponse +from danswer.chat.models import CitationConfig from danswer.chat.models import CitationInfo from danswer.chat.models import CustomToolResponse from danswer.chat.models import DanswerAnswerPiece +from danswer.chat.models import DanswerContexts +from danswer.chat.models import DocumentPruningConfig from danswer.chat.models import FileChatDisplay from danswer.chat.models import FinalUsedContextDocsResponse from danswer.chat.models import LLMRelevanceFilterResponse from danswer.chat.models import MessageResponseIDInfo from danswer.chat.models import MessageSpecificCitations +from danswer.chat.models import PromptConfig from danswer.chat.models import QADocsResponse from danswer.chat.models import StreamingError from danswer.chat.models import StreamStopInfo @@ -54,16 +62,11 @@ from danswer.file_store.models import ChatFileType from danswer.file_store.models import FileDescriptor from danswer.file_store.utils import load_all_chat_files -from danswer.file_store.utils import save_files_from_urls -from danswer.llm.answering.answer import Answer -from danswer.llm.answering.models import AnswerStyleConfig -from danswer.llm.answering.models import CitationConfig -from danswer.llm.answering.models import DocumentPruningConfig -from danswer.llm.answering.models import PreviousMessage -from danswer.llm.answering.models import PromptConfig +from danswer.file_store.utils import save_files from danswer.llm.exceptions import GenAIDisabledException from danswer.llm.factory import get_llms_for_persona from danswer.llm.factory import get_main_llm_from_tuple +from danswer.llm.models import PreviousMessage from danswer.llm.utils import litellm_exception_to_error_msg from danswer.natural_language_processing.utils import get_tokenizer from danswer.server.query_and_chat.models import ChatMessageDetail @@ -102,6 +105,7 @@ from danswer.tools.tool_implementations.search.search_tool import ( FINAL_CONTEXT_DOCUMENTS_ID, ) +from danswer.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID from danswer.tools.tool_implementations.search.search_tool import ( SEARCH_RESPONSE_SUMMARY_ID, ) @@ -113,7 +117,10 @@ from danswer.tools.tool_runner import ToolCallFinalResult from danswer.utils.logger import setup_logger from danswer.utils.long_term_log import LongTermLogger +from danswer.utils.timing import log_function_time from danswer.utils.timing import log_generator_function_time +from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR + logger = setup_logger() @@ -256,6 +263,7 @@ def _get_force_search_settings( ChatPacket = ( StreamingError | QADocsResponse + | DanswerContexts | LLMRelevanceFilterResponse | FinalUsedContextDocsResponse | ChatMessageDetail @@ -286,6 +294,8 @@ def stream_chat_message_objects( custom_tool_additional_headers: dict[str, str] | None = None, is_connected: Callable[[], bool] | None = None, enforce_chat_session_id_for_search_docs: bool = True, + bypass_acl: bool = False, + include_contexts: bool = False, ) -> ChatPacketStream: """Streams in order: 1. [conditional] Retrieved documents if a search needs to be run @@ -293,6 +303,7 @@ def stream_chat_message_objects( 3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails 4. [always] Details on the final AI response message that is created """ + tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() use_existing_user_message = new_msg_req.use_existing_user_message existing_assistant_message_id = new_msg_req.existing_assistant_message_id @@ -322,17 +333,31 @@ def stream_chat_message_objects( metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)} ) - # use alternate persona if alternative assistant id is passed in if alternate_assistant_id is not None: + # Allows users to specify a temporary persona (assistant) in the chat session + # this takes highest priority since it's user specified persona = get_persona_by_id( alternate_assistant_id, user=user, db_session=db_session, is_for_edit=False, ) + elif new_msg_req.persona_override_config: + # Certain endpoints allow users to specify arbitrary persona settings + # this should never conflict with the alternate_assistant_id + persona = persona = create_temporary_persona( + db_session=db_session, + persona_config=new_msg_req.persona_override_config, + user=user, + ) else: persona = chat_session.persona + if not persona: + raise RuntimeError("No persona specified or found for chat session") + + # If a prompt override is specified via the API, use that with highest priority + # but for saving it, we are just mapping it to an existing prompt prompt_id = new_msg_req.prompt_id if prompt_id is None and persona.prompts: prompt_id = sorted(persona.prompts, key=lambda x: x.id)[-1].id @@ -555,19 +580,34 @@ def stream_chat_message_objects( reserved_message_id=reserved_message_id, ) - if not final_msg.prompt: - raise RuntimeError("No Prompt found") - - prompt_config = ( - PromptConfig.from_model( + prompt_override = new_msg_req.prompt_override or chat_session.prompt_override + if new_msg_req.persona_override_config: + prompt_config = PromptConfig( + system_prompt=new_msg_req.persona_override_config.prompts[ + 0 + ].system_prompt, + task_prompt=new_msg_req.persona_override_config.prompts[0].task_prompt, + datetime_aware=new_msg_req.persona_override_config.prompts[ + 0 + ].datetime_aware, + include_citations=new_msg_req.persona_override_config.prompts[ + 0 + ].include_citations, + ) + elif prompt_override: + if not final_msg.prompt: + raise ValueError( + "Prompt override cannot be applied, no base prompt found." + ) + prompt_config = PromptConfig.from_model( final_msg.prompt, - prompt_override=( - new_msg_req.prompt_override or chat_session.prompt_override - ), + prompt_override=prompt_override, ) - if not persona - else PromptConfig.from_model(persona.prompts[0]) - ) + elif final_msg.prompt: + prompt_config = PromptConfig.from_model(final_msg.prompt) + else: + prompt_config = PromptConfig.from_model(persona.prompts[0]) + answer_style_config = AnswerStyleConfig( citation_config=CitationConfig( all_docs_useful=selected_db_search_docs is not None @@ -587,11 +627,13 @@ def stream_chat_message_objects( answer_style_config=answer_style_config, document_pruning_config=document_pruning_config, retrieval_options=retrieval_options or RetrievalDetails(), + rerank_settings=new_msg_req.rerank_settings, selected_sections=selected_sections, chunks_above=new_msg_req.chunks_above, chunks_below=new_msg_req.chunks_below, full_doc=new_msg_req.full_doc, latest_query_files=latest_query_files, + bypass_acl=bypass_acl, ), internet_search_tool_config=InternetSearchToolConfig( answer_style_config=answer_style_config, @@ -605,6 +647,7 @@ def stream_chat_message_objects( additional_headers=custom_tool_additional_headers, ), ) + tools: list[Tool] = [] for tool_list in tool_dict.values(): tools.extend(tool_list) @@ -637,7 +680,8 @@ def stream_chat_message_objects( reference_db_search_docs = None qa_docs_response = None - ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images + # any files to associate with the AI message e.g. dall-e generated images + ai_message_files = [] dropped_indices = None tool_result = None @@ -692,8 +736,14 @@ def stream_chat_message_objects( list[ImageGenerationResponse], packet.response ) - file_ids = save_files_from_urls( - [img.url for img in img_generation_response] + file_ids = save_files( + urls=[img.url for img in img_generation_response if img.url], + base64_files=[ + img.image_data + for img in img_generation_response + if img.image_data + ], + tenant_id=tenant_id, ) ai_message_files = [ FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE) @@ -719,15 +769,19 @@ def stream_chat_message_objects( or custom_tool_response.response_type == "csv" ): file_ids = custom_tool_response.tool_result.file_ids - ai_message_files = [ - FileDescriptor( - id=str(file_id), - type=ChatFileType.IMAGE - if custom_tool_response.response_type == "image" - else ChatFileType.CSV, - ) - for file_id in file_ids - ] + ai_message_files.extend( + [ + FileDescriptor( + id=str(file_id), + type=( + ChatFileType.IMAGE + if custom_tool_response.response_type == "image" + else ChatFileType.CSV + ), + ) + for file_id in file_ids + ] + ) yield FileChatDisplay( file_ids=[str(file_id) for file_id in file_ids] ) @@ -736,6 +790,8 @@ def stream_chat_message_objects( response=custom_tool_response.tool_result, tool_name=custom_tool_response.tool_name, ) + elif packet.id == SEARCH_DOC_CONTENT_ID and include_contexts: + yield cast(DanswerContexts, packet.response) elif isinstance(packet, StreamStopInfo): pass @@ -775,7 +831,8 @@ def stream_chat_message_objects( citations_list=answer.citations, db_docs=reference_db_search_docs, ) - yield AllCitations(citations=answer.citations) + if not answer.is_cancelled(): + yield AllCitations(citations=answer.citations) # Saving Gen AI answer and responding with message info tool_name_to_tool_id: dict[str, int] = {} @@ -844,3 +901,30 @@ def stream_chat_message( ) for obj in objects: yield get_json_line(obj.model_dump()) + + +@log_function_time() +def gather_stream_for_slack( + packets: ChatPacketStream, +) -> ChatDanswerBotResponse: + response = ChatDanswerBotResponse() + + answer = "" + for packet in packets: + if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece: + answer += packet.answer_piece + elif isinstance(packet, QADocsResponse): + response.docs = packet + elif isinstance(packet, StreamingError): + response.error_msg = packet.error + elif isinstance(packet, ChatMessageDetail): + response.chat_message_id = packet.message_id + elif isinstance(packet, LLMRelevanceFilterResponse): + response.llm_selected_doc_indices = packet.llm_selected_doc_indices + elif isinstance(packet, AllCitations): + response.citations = packet.citations + + if answer: + response.answer = answer + + return response diff --git a/backend/danswer/llm/answering/prompts/build.py b/backend/danswer/chat/prompt_builder/build.py similarity index 85% rename from backend/danswer/llm/answering/prompts/build.py rename to backend/danswer/chat/prompt_builder/build.py index 29b5100735d..9f0d0ab0ec6 100644 --- a/backend/danswer/llm/answering/prompts/build.py +++ b/backend/danswer/chat/prompt_builder/build.py @@ -4,20 +4,26 @@ from langchain_core.messages import BaseMessage from langchain_core.messages import HumanMessage from langchain_core.messages import SystemMessage +from pydantic.v1 import BaseModel as BaseModel__v1 +from danswer.chat.models import PromptConfig +from danswer.chat.prompt_builder.citations_prompt import compute_max_llm_input_tokens +from danswer.chat.prompt_builder.utils import translate_history_to_basemessages from danswer.file_store.models import InMemoryChatFile -from danswer.llm.answering.models import PreviousMessage -from danswer.llm.answering.models import PromptConfig -from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input_tokens from danswer.llm.interfaces import LLMConfig +from danswer.llm.models import PreviousMessage from danswer.llm.utils import build_content_with_imgs from danswer.llm.utils import check_message_tokens from danswer.llm.utils import message_to_prompt_and_imgs -from danswer.llm.utils import translate_history_to_basemessages from danswer.natural_language_processing.utils import get_tokenizer from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT from danswer.prompts.prompt_utils import add_date_time_to_prompt from danswer.prompts.prompt_utils import drop_messages_history_overflow +from danswer.tools.force import ForceUseTool +from danswer.tools.models import ToolCallFinalResult +from danswer.tools.models import ToolCallKickoff +from danswer.tools.models import ToolResponse +from danswer.tools.tool import Tool def default_build_system_message( @@ -141,3 +147,15 @@ def build(self) -> list[BaseMessage]: return drop_messages_history_overflow( final_messages_with_tokens, self.max_tokens ) + + +class LLMCall(BaseModel__v1): + prompt_builder: AnswerPromptBuilder + tools: list[Tool] + force_use_tool: ForceUseTool + files: list[InMemoryChatFile] + tool_call_info: list[ToolCallKickoff | ToolResponse | ToolCallFinalResult] + using_tool_calling_llm: bool + + class Config: + arbitrary_types_allowed = True diff --git a/backend/danswer/llm/answering/prompts/citations_prompt.py b/backend/danswer/chat/prompt_builder/citations_prompt.py similarity index 99% rename from backend/danswer/llm/answering/prompts/citations_prompt.py rename to backend/danswer/chat/prompt_builder/citations_prompt.py index 1ff48432b86..a49dd25ae92 100644 --- a/backend/danswer/llm/answering/prompts/citations_prompt.py +++ b/backend/danswer/chat/prompt_builder/citations_prompt.py @@ -2,12 +2,12 @@ from langchain.schema.messages import SystemMessage from danswer.chat.models import LlmDoc +from danswer.chat.models import PromptConfig from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS from danswer.context.search.models import InferenceChunk from danswer.db.models import Persona from danswer.db.persona import get_default_prompt__read_only from danswer.db.search_settings import get_multilingual_expansion -from danswer.llm.answering.models import PromptConfig from danswer.llm.factory import get_llms_for_persona from danswer.llm.factory import get_main_llm_from_tuple from danswer.llm.interfaces import LLMConfig diff --git a/backend/danswer/llm/answering/prompts/quotes_prompt.py b/backend/danswer/chat/prompt_builder/quotes_prompt.py similarity index 97% rename from backend/danswer/llm/answering/prompts/quotes_prompt.py rename to backend/danswer/chat/prompt_builder/quotes_prompt.py index 00f22f9e7df..fa51b571e4d 100644 --- a/backend/danswer/llm/answering/prompts/quotes_prompt.py +++ b/backend/danswer/chat/prompt_builder/quotes_prompt.py @@ -1,10 +1,10 @@ from langchain.schema.messages import HumanMessage from danswer.chat.models import LlmDoc +from danswer.chat.models import PromptConfig from danswer.configs.chat_configs import LANGUAGE_HINT from danswer.context.search.models import InferenceChunk from danswer.db.search_settings import get_multilingual_expansion -from danswer.llm.answering.models import PromptConfig from danswer.llm.utils import message_to_prompt_and_imgs from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK diff --git a/backend/danswer/chat/prompt_builder/utils.py b/backend/danswer/chat/prompt_builder/utils.py new file mode 100644 index 00000000000..6383be5345e --- /dev/null +++ b/backend/danswer/chat/prompt_builder/utils.py @@ -0,0 +1,62 @@ +from langchain.schema.messages import AIMessage +from langchain.schema.messages import BaseMessage +from langchain.schema.messages import HumanMessage + +from danswer.configs.constants import MessageType +from danswer.db.models import ChatMessage +from danswer.file_store.models import InMemoryChatFile +from danswer.llm.models import PreviousMessage +from danswer.llm.utils import build_content_with_imgs +from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT +from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT + + +def build_dummy_prompt( + system_prompt: str, task_prompt: str, retrieval_disabled: bool +) -> str: + if retrieval_disabled: + return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format( + user_query="", + system_prompt=system_prompt, + task_prompt=task_prompt, + ).strip() + + return PARAMATERIZED_PROMPT.format( + context_docs_str="", + user_query="", + system_prompt=system_prompt, + task_prompt=task_prompt, + ).strip() + + +def translate_danswer_msg_to_langchain( + msg: ChatMessage | PreviousMessage, +) -> BaseMessage: + files: list[InMemoryChatFile] = [] + + # If the message is a `ChatMessage`, it doesn't have the downloaded files + # attached. Just ignore them for now. + if not isinstance(msg, ChatMessage): + files = msg.files + content = build_content_with_imgs(msg.message, files, message_type=msg.message_type) + + if msg.message_type == MessageType.SYSTEM: + raise ValueError("System messages are not currently part of history") + if msg.message_type == MessageType.ASSISTANT: + return AIMessage(content=content) + if msg.message_type == MessageType.USER: + return HumanMessage(content=content) + + raise ValueError(f"New message type {msg.message_type} not handled") + + +def translate_history_to_basemessages( + history: list[ChatMessage] | list["PreviousMessage"], +) -> tuple[list[BaseMessage], list[int]]: + history_basemessages = [ + translate_danswer_msg_to_langchain(msg) + for msg in history + if msg.token_count != 0 + ] + history_token_counts = [msg.token_count for msg in history if msg.token_count != 0] + return history_basemessages, history_token_counts diff --git a/backend/danswer/llm/answering/prune_and_merge.py b/backend/danswer/chat/prune_and_merge.py similarity index 98% rename from backend/danswer/llm/answering/prune_and_merge.py rename to backend/danswer/chat/prune_and_merge.py index 21ea2226d97..0085793f88c 100644 --- a/backend/danswer/llm/answering/prune_and_merge.py +++ b/backend/danswer/chat/prune_and_merge.py @@ -5,16 +5,16 @@ from pydantic import BaseModel +from danswer.chat.models import ContextualPruningConfig from danswer.chat.models import ( LlmDoc, ) +from danswer.chat.models import PromptConfig +from danswer.chat.prompt_builder.citations_prompt import compute_max_document_tokens from danswer.configs.constants import IGNORE_FOR_QA from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE from danswer.context.search.models import InferenceChunk from danswer.context.search.models import InferenceSection -from danswer.llm.answering.models import ContextualPruningConfig -from danswer.llm.answering.models import PromptConfig -from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens from danswer.llm.interfaces import LLMConfig from danswer.natural_language_processing.utils import get_tokenizer from danswer.natural_language_processing.utils import tokenizer_trim_content diff --git a/backend/danswer/llm/answering/stream_processing/answer_response_handler.py b/backend/danswer/chat/stream_processing/answer_response_handler.py similarity index 60% rename from backend/danswer/llm/answering/stream_processing/answer_response_handler.py rename to backend/danswer/chat/stream_processing/answer_response_handler.py index edb0c500a28..a10f46be5f5 100644 --- a/backend/danswer/llm/answering/stream_processing/answer_response_handler.py +++ b/backend/danswer/chat/stream_processing/answer_response_handler.py @@ -3,16 +3,11 @@ from langchain_core.messages import BaseMessage +from danswer.chat.llm_response_handler import ResponsePart from danswer.chat.models import CitationInfo from danswer.chat.models import LlmDoc -from danswer.llm.answering.llm_response_handler import ResponsePart -from danswer.llm.answering.stream_processing.citation_processing import ( - CitationProcessor, -) -from danswer.llm.answering.stream_processing.quotes_processing import ( - QuotesProcessor, -) -from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping +from danswer.chat.stream_processing.citation_processing import CitationProcessor +from danswer.chat.stream_processing.utils import DocumentIdOrderMapping from danswer.utils.logger import setup_logger logger = setup_logger() @@ -40,13 +35,18 @@ def handle_response_part( class CitationResponseHandler(AnswerResponseHandler): def __init__( - self, context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping + self, + context_docs: list[LlmDoc], + doc_id_to_rank_map: DocumentIdOrderMapping, + display_doc_order_dict: dict[str, int], ): self.context_docs = context_docs self.doc_id_to_rank_map = doc_id_to_rank_map + self.display_doc_order_dict = display_doc_order_dict self.citation_processor = CitationProcessor( context_docs=self.context_docs, doc_id_to_rank_map=self.doc_id_to_rank_map, + display_doc_order_dict=self.display_doc_order_dict, ) self.processed_text = "" self.citations: list[CitationInfo] = [] @@ -70,28 +70,29 @@ def handle_response_part( yield from self.citation_processor.process_token(content) -class QuotesResponseHandler(AnswerResponseHandler): - def __init__( - self, - context_docs: list[LlmDoc], - is_json_prompt: bool = True, - ): - self.quotes_processor = QuotesProcessor( - context_docs=context_docs, - is_json_prompt=is_json_prompt, - ) - - def handle_response_part( - self, - response_item: BaseMessage | None, - previous_response_items: list[BaseMessage], - ) -> Generator[ResponsePart, None, None]: - if response_item is None: - yield from self.quotes_processor.process_token(None) - return - - content = ( - response_item.content if isinstance(response_item.content, str) else "" - ) - - yield from self.quotes_processor.process_token(content) +# No longer in use, remove later +# class QuotesResponseHandler(AnswerResponseHandler): +# def __init__( +# self, +# context_docs: list[LlmDoc], +# is_json_prompt: bool = True, +# ): +# self.quotes_processor = QuotesProcessor( +# context_docs=context_docs, +# is_json_prompt=is_json_prompt, +# ) + +# def handle_response_part( +# self, +# response_item: BaseMessage | None, +# previous_response_items: list[BaseMessage], +# ) -> Generator[ResponsePart, None, None]: +# if response_item is None: +# yield from self.quotes_processor.process_token(None) +# return + +# content = ( +# response_item.content if isinstance(response_item.content, str) else "" +# ) + +# yield from self.quotes_processor.process_token(content) diff --git a/backend/danswer/llm/answering/stream_processing/citation_processing.py b/backend/danswer/chat/stream_processing/citation_processing.py similarity index 79% rename from backend/danswer/llm/answering/stream_processing/citation_processing.py rename to backend/danswer/chat/stream_processing/citation_processing.py index 950ad207878..8966303faff 100644 --- a/backend/danswer/llm/answering/stream_processing/citation_processing.py +++ b/backend/danswer/chat/stream_processing/citation_processing.py @@ -4,8 +4,8 @@ from danswer.chat.models import CitationInfo from danswer.chat.models import DanswerAnswerPiece from danswer.chat.models import LlmDoc +from danswer.chat.stream_processing.utils import DocumentIdOrderMapping from danswer.configs.chat_configs import STOP_STREAM_PAT -from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping from danswer.prompts.constants import TRIPLE_BACKTICK from danswer.utils.logger import setup_logger @@ -22,12 +22,16 @@ def __init__( self, context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping, + display_doc_order_dict: dict[str, int], stop_stream: str | None = STOP_STREAM_PAT, ): self.context_docs = context_docs self.doc_id_to_rank_map = doc_id_to_rank_map self.stop_stream = stop_stream self.order_mapping = doc_id_to_rank_map.order_mapping + self.display_doc_order_dict = ( + display_doc_order_dict # original order of docs to displayed to user + ) self.llm_out = "" self.max_citation_num = len(context_docs) self.citation_order: list[int] = [] @@ -67,9 +71,9 @@ def process_token( if piece_that_comes_after == "\n" and in_code_block(self.llm_out): self.curr_segment = self.curr_segment.replace("```", "```plaintext") - citation_pattern = r"\[(\d+)\]" + citation_pattern = r"\[(\d+)\]|\[\[(\d+)\]\]" # [1], [[1]], etc. citations_found = list(re.finditer(citation_pattern, self.curr_segment)) - possible_citation_pattern = r"(\[\d*$)" # [1, [, etc + possible_citation_pattern = r"(\[+\d*$)" # [1, [, [[, [[2, etc. possible_citation_found = re.search( possible_citation_pattern, self.curr_segment ) @@ -77,13 +81,15 @@ def process_token( if len(citations_found) == 0 and len(self.llm_out) - self.past_cite_count > 5: self.current_citations = [] - result = "" # Initialize result here + result = "" if citations_found and not in_code_block(self.llm_out): last_citation_end = 0 length_to_add = 0 while len(citations_found) > 0: citation = citations_found.pop(0) - numerical_value = int(citation.group(1)) + numerical_value = int( + next(group for group in citation.groups() if group is not None) + ) if 1 <= numerical_value <= self.max_citation_num: context_llm_doc = self.context_docs[numerical_value - 1] @@ -96,6 +102,18 @@ def process_token( self.citation_order.index(real_citation_num) + 1 ) + # get the value that was displayed to user, should always + # be in the display_doc_order_dict. But check anyways + if context_llm_doc.document_id in self.display_doc_order_dict: + displayed_citation_num = self.display_doc_order_dict[ + context_llm_doc.document_id + ] + else: + displayed_citation_num = real_citation_num + logger.warning( + f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead." + ) + # Skip consecutive citations of the same work if target_citation_num in self.current_citations: start, end = citation.span() @@ -116,6 +134,7 @@ def process_token( doc_id = int(match.group(1)) context_llm_doc = self.context_docs[doc_id - 1] yield CitationInfo( + # stay with the original for now (order of LLM cites) citation_num=target_citation_num, document_id=context_llm_doc.document_id, ) @@ -131,29 +150,24 @@ def process_token( link = context_llm_doc.link - # Replace the citation in the current segment - start, end = citation.span() - self.curr_segment = ( - self.curr_segment[: start + length_to_add] - + f"[{target_citation_num}]" - + self.curr_segment[end + length_to_add :] - ) - self.past_cite_count = len(self.llm_out) self.current_citations.append(target_citation_num) if target_citation_num not in self.cited_inds: self.cited_inds.add(target_citation_num) yield CitationInfo( + # stay with the original for now (order of LLM cites) citation_num=target_citation_num, document_id=context_llm_doc.document_id, ) + start, end = citation.span() if link: prev_length = len(self.curr_segment) self.curr_segment = ( self.curr_segment[: start + length_to_add] - + f"[[{target_citation_num}]]({link})" + + f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user + # + f"[[{target_citation_num}]]({link})" + self.curr_segment[end + length_to_add :] ) length_to_add += len(self.curr_segment) - prev_length @@ -161,7 +175,8 @@ def process_token( prev_length = len(self.curr_segment) self.curr_segment = ( self.curr_segment[: start + length_to_add] - + f"[[{target_citation_num}]]()" + + f"[[{displayed_citation_num}]]()" # use the value that was displayed to user + # + f"[[{target_citation_num}]]()" + self.curr_segment[end + length_to_add :] ) length_to_add += len(self.curr_segment) - prev_length diff --git a/backend/danswer/llm/answering/stream_processing/quotes_processing.py b/backend/danswer/chat/stream_processing/quotes_processing.py similarity index 97% rename from backend/danswer/llm/answering/stream_processing/quotes_processing.py rename to backend/danswer/chat/stream_processing/quotes_processing.py index 1f1afc1aaba..306901ca396 100644 --- a/backend/danswer/llm/answering/stream_processing/quotes_processing.py +++ b/backend/danswer/chat/stream_processing/quotes_processing.py @@ -1,3 +1,4 @@ +# THIS IS NO LONGER IN USE import math import re from collections.abc import Generator @@ -5,11 +6,10 @@ from typing import Optional import regex +from pydantic import BaseModel from danswer.chat.models import DanswerAnswer from danswer.chat.models import DanswerAnswerPiece -from danswer.chat.models import DanswerQuote -from danswer.chat.models import DanswerQuotes from danswer.chat.models import LlmDoc from danswer.configs.chat_configs import QUOTE_ALLOWED_ERROR_PERCENT from danswer.context.search.models import InferenceChunk @@ -26,6 +26,20 @@ answer_pattern = re.compile(r'{\s*"answer"\s*:\s*"', re.IGNORECASE) +class DanswerQuote(BaseModel): + # This is during inference so everything is a string by this point + quote: str + document_id: str + link: str | None + source_type: str + semantic_identifier: str + blurb: str + + +class DanswerQuotes(BaseModel): + quotes: list[DanswerQuote] + + def _extract_answer_quotes_freeform( answer_raw: str, ) -> tuple[Optional[str], Optional[list[str]]]: diff --git a/backend/danswer/llm/answering/stream_processing/utils.py b/backend/danswer/chat/stream_processing/utils.py similarity index 100% rename from backend/danswer/llm/answering/stream_processing/utils.py rename to backend/danswer/chat/stream_processing/utils.py diff --git a/backend/danswer/llm/answering/tool/tool_response_handler.py b/backend/danswer/chat/tool_handling/tool_response_handler.py similarity index 98% rename from backend/danswer/llm/answering/tool/tool_response_handler.py rename to backend/danswer/chat/tool_handling/tool_response_handler.py index db35663c487..5438aa2255e 100644 --- a/backend/danswer/llm/answering/tool/tool_response_handler.py +++ b/backend/danswer/chat/tool_handling/tool_response_handler.py @@ -4,8 +4,8 @@ from langchain_core.messages import BaseMessage from langchain_core.messages import ToolCall -from danswer.llm.answering.llm_response_handler import LLMCall -from danswer.llm.answering.llm_response_handler import ResponsePart +from danswer.chat.models import ResponsePart +from danswer.chat.prompt_builder.build import LLMCall from danswer.llm.interfaces import LLM from danswer.tools.force import ForceUseTool from danswer.tools.message import build_tool_message diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index b96598ba7db..7f745669eca 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -43,9 +43,6 @@ AUTH_TYPE = AuthType((os.environ.get("AUTH_TYPE") or AuthType.DISABLED.value).lower()) DISABLE_AUTH = AUTH_TYPE == AuthType.DISABLED -# Necessary for cloud integration tests -DISABLE_VERIFICATION = os.environ.get("DISABLE_VERIFICATION", "").lower() == "true" - # Encryption key secret is used to encrypt connector credentials, api keys, and other sensitive # information. This provides an extra layer of security on top of Postgres access controls # and is available in Danswer EE @@ -84,7 +81,14 @@ or "" ) +# for future OAuth connector support +# OAUTH_CONFLUENCE_CLIENT_ID = os.environ.get("OAUTH_CONFLUENCE_CLIENT_ID", "") +# OAUTH_CONFLUENCE_CLIENT_SECRET = os.environ.get("OAUTH_CONFLUENCE_CLIENT_SECRET", "") +# OAUTH_JIRA_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLIENT_ID", "") +# OAUTH_JIRA_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLIENT_SECRET", "") + USER_AUTH_SECRET = os.environ.get("USER_AUTH_SECRET", "") + # for basic auth REQUIRE_EMAIL_VERIFICATION = ( os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true" @@ -118,6 +122,8 @@ VESPA_CONFIG_SERVER_HOST = os.environ.get("VESPA_CONFIG_SERVER_HOST") or VESPA_HOST VESPA_PORT = os.environ.get("VESPA_PORT") or "8081" VESPA_TENANT_PORT = os.environ.get("VESPA_TENANT_PORT") or "19071" +# the number of times to try and connect to vespa on startup before giving up +VESPA_NUM_ATTEMPTS_ON_STARTUP = int(os.environ.get("NUM_RETRIES_ON_STARTUP") or 10) VESPA_CLOUD_URL = os.environ.get("VESPA_CLOUD_URL", "") @@ -308,6 +314,22 @@ os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD", 200_000) ) +# Due to breakages in the confluence API, the timezone offset must be specified client side +# to match the user's specified timezone. + +# The current state of affairs: +# CQL queries are parsed in the user's timezone and cannot be specified in UTC +# no API retrieves the user's timezone +# All data is returned in UTC, so we can't derive the user's timezone from that + +# https://community.developer.atlassian.com/t/confluence-cloud-time-zone-get-via-rest-api/35954/16 +# https://jira.atlassian.com/browse/CONFCLOUD-69670 + +# enter as a floating point offset from UTC in hours (-24 < val < 24) +# this will be applied globally, so it probably makes sense to transition this to per +# connector as some point. +CONFLUENCE_TIMEZONE_OFFSET = float(os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", 0.0)) + JIRA_CONNECTOR_LABELS_TO_SKIP = [ ignored_tag for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",") @@ -326,6 +348,12 @@ os.environ.get("GITLAB_CONNECTOR_INCLUDE_CODE_FILES", "").lower() == "true" ) +# Egnyte specific configs +EGNYTE_LOCALHOST_OVERRIDE = os.getenv("EGNYTE_LOCALHOST_OVERRIDE") +EGNYTE_BASE_DOMAIN = os.getenv("EGNYTE_DOMAIN") +EGNYTE_CLIENT_ID = os.getenv("EGNYTE_CLIENT_ID") +EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET") + DASK_JOB_CLIENT_ENABLED = ( os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true" ) @@ -389,21 +417,28 @@ # We don't want the metadata to overwhelm the actual contents of the chunk SKIP_METADATA_IN_CHUNK = os.environ.get("SKIP_METADATA_IN_CHUNK", "").lower() == "true" # Timeout to wait for job's last update before killing it, in hours -CLEANUP_INDEXING_JOBS_TIMEOUT = int(os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT", 3)) +CLEANUP_INDEXING_JOBS_TIMEOUT = int( + os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT") or 3 +) # The indexer will warn in the logs whenver a document exceeds this threshold (in bytes) INDEXING_SIZE_WARNING_THRESHOLD = int( - os.environ.get("INDEXING_SIZE_WARNING_THRESHOLD", 100 * 1024 * 1024) + os.environ.get("INDEXING_SIZE_WARNING_THRESHOLD") or 100 * 1024 * 1024 ) # during indexing, will log verbose memory diff stats every x batches and at the end. # 0 disables this behavior and is the default. -INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL", 0)) +INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL") or 0) # During an indexing attempt, specifies the number of batches which are allowed to # exception without aborting the attempt. -INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT", 0)) +INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT") or 0) +# Maximum file size in a document to be indexed +MAX_DOCUMENT_CHARS = int(os.environ.get("MAX_DOCUMENT_CHARS") or 5_000_000) +MAX_FILE_SIZE_BYTES = int( + os.environ.get("MAX_FILE_SIZE_BYTES") or 2 * 1024 * 1024 * 1024 +) # 2GB in bytes ##### # Miscellaneous @@ -493,10 +528,6 @@ # JWT configuration JWT_ALGORITHM = "HS256" -# Super Users -SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]')) -SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key") - ##### # API Key Configs @@ -510,3 +541,6 @@ POD_NAME = os.environ.get("POD_NAME") POD_NAMESPACE = os.environ.get("POD_NAMESPACE") + + +DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true" diff --git a/backend/danswer/configs/chat_configs.py b/backend/danswer/configs/chat_configs.py index 2d72bed0f5a..88ff301a99e 100644 --- a/backend/danswer/configs/chat_configs.py +++ b/backend/danswer/configs/chat_configs.py @@ -3,7 +3,6 @@ PROMPTS_YAML = "./danswer/seeding/prompts.yaml" PERSONAS_YAML = "./danswer/seeding/personas.yaml" -INPUT_PROMPT_YAML = "./danswer/seeding/input_prompts.yaml" NUM_RETURNED_HITS = 50 # Used for LLM filtering and reranking diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index e6facc587d7..b9b5f7deb26 100644 --- a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -31,6 +31,8 @@ "You can still use Danswer as a search engine." ) +DEFAULT_PERSONA_ID = 0 + # Postgres connection constants for application_name POSTGRES_WEB_APP_NAME = "web" POSTGRES_INDEXER_APP_NAME = "indexer" @@ -130,6 +132,7 @@ class DocumentSource(str, Enum): NOT_APPLICABLE = "not_applicable" FRESHDESK = "freshdesk" FIREFLIES = "fireflies" + EGNYTE = "egnyte" DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE] @@ -259,6 +262,32 @@ class DanswerCeleryPriority(int, Enum): LOWEST = auto() +class DanswerCeleryTask: + CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task" + CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task" + CHECK_FOR_INDEXING = "check_for_indexing" + CHECK_FOR_PRUNING = "check_for_pruning" + CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync" + CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync" + MONITOR_VESPA_SYNC = "monitor_vespa_sync" + KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task" + CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = ( + "connector_permission_sync_generator_task" + ) + UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK = ( + "update_external_document_permissions_task" + ) + CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK = ( + "connector_external_group_sync_generator_task" + ) + CONNECTOR_INDEXING_PROXY_TASK = "connector_indexing_proxy_task" + CONNECTOR_PRUNING_GENERATOR_TASK = "connector_pruning_generator_task" + DOCUMENT_BY_CC_PAIR_CLEANUP_TASK = "document_by_cc_pair_cleanup_task" + VESPA_METADATA_SYNC_TASK = "vespa_metadata_sync_task" + CHECK_TTL_MANAGEMENT_TASK = "check_ttl_management_task" + AUTOGENERATE_USAGE_REPORT_TASK = "autogenerate_usage_report_task" + + REDIS_SOCKET_KEEPALIVE_OPTIONS = {} REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPINTVL] = 15 REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPCNT] = 3 diff --git a/backend/danswer/configs/danswerbot_configs.py b/backend/danswer/configs/danswerbot_configs.py index 3fca9bc78b3..7a7e5f41377 100644 --- a/backend/danswer/configs/danswerbot_configs.py +++ b/backend/danswer/configs/danswerbot_configs.py @@ -4,11 +4,8 @@ # Danswer Slack Bot Configs ##### DANSWER_BOT_NUM_RETRIES = int(os.environ.get("DANSWER_BOT_NUM_RETRIES", "5")) -DANSWER_BOT_ANSWER_GENERATION_TIMEOUT = int( - os.environ.get("DANSWER_BOT_ANSWER_GENERATION_TIMEOUT", "90") -) # How much of the available input context can be used for thread context -DANSWER_BOT_TARGET_CHUNK_PERCENTAGE = 512 * 2 / 3072 +MAX_THREAD_CONTEXT_PERCENTAGE = 512 * 2 / 3072 # Number of docs to display in "Reference Documents" DANSWER_BOT_NUM_DOCS_TO_DISPLAY = int( os.environ.get("DANSWER_BOT_NUM_DOCS_TO_DISPLAY", "5") @@ -47,17 +44,6 @@ DANSWER_BOT_RESPOND_EVERY_CHANNEL = ( os.environ.get("DANSWER_BOT_RESPOND_EVERY_CHANNEL", "").lower() == "true" ) -# Add a second LLM call post Answer to verify if the Answer is valid -# Throws out answers that don't directly or fully answer the user query -# This is the default for all DanswerBot channels unless the channel is configured individually -# Set/unset by "Hide Non Answers" -ENABLE_DANSWERBOT_REFLEXION = ( - os.environ.get("ENABLE_DANSWERBOT_REFLEXION", "").lower() == "true" -) -# Currently not support chain of thought, probably will add back later -DANSWER_BOT_DISABLE_COT = True -# if set, will default DanswerBot to use quotes and reference documents -DANSWER_BOT_USE_QUOTES = os.environ.get("DANSWER_BOT_USE_QUOTES", "").lower() == "true" # Maximum Questions Per Minute, Default Uncapped DANSWER_BOT_MAX_QPM = int(os.environ.get("DANSWER_BOT_MAX_QPM") or 0) or None diff --git a/backend/danswer/configs/model_configs.py b/backend/danswer/configs/model_configs.py index 0618bf5f684..b71762a4c88 100644 --- a/backend/danswer/configs/model_configs.py +++ b/backend/danswer/configs/model_configs.py @@ -70,7 +70,9 @@ ) # Typically, GenAI models nowadays are at least 4K tokens -GEN_AI_MODEL_FALLBACK_MAX_TOKENS = 4096 +GEN_AI_MODEL_FALLBACK_MAX_TOKENS = int( + os.environ.get("GEN_AI_MODEL_FALLBACK_MAX_TOKENS") or 4096 +) # Number of tokens from chat history to include at maximum # 3000 should be enough context regardless of use, no need to include as much as possible diff --git a/backend/danswer/configs/tool_configs.py b/backend/danswer/configs/tool_configs.py index 3170cb31ff9..9e143301494 100644 --- a/backend/danswer/configs/tool_configs.py +++ b/backend/danswer/configs/tool_configs.py @@ -2,6 +2,8 @@ import os +IMAGE_GENERATION_OUTPUT_FORMAT = os.environ.get("IMAGE_GENERATION_OUTPUT_FORMAT", "url") + # if specified, will pass through request headers to the call to API calls made by custom tools CUSTOM_TOOL_PASS_THROUGH_HEADERS: list[str] | None = None _CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW = os.environ.get( diff --git a/backend/danswer/connectors/README.md b/backend/danswer/connectors/README.md index bb7f5a5fe4f..5a0fb1b2aef 100644 --- a/backend/danswer/connectors/README.md +++ b/backend/danswer/connectors/README.md @@ -11,11 +11,16 @@ Connectors come in 3 different flows: - Load Connector: - Bulk indexes documents to reflect a point in time. This type of connector generally works by either pulling all documents via a connector's API or loads the documents from some sort of a dump file. -- Poll connector: +- Poll Connector: - Incrementally updates documents based on a provided time range. It is used by the background job to pull the latest changes and additions since the last round of polling. This connector helps keep the document index up to date without needing to fetch/embed/index every document which would be too slow to do frequently on large sets of documents. +- Slim Connector: + - This connector should be a lighter weight method of checking all documents in the source to see if they still exist. + - This connector should be identical to the Poll or Load Connector except that it only fetches the IDs of the documents, not the documents themselves. + - This is used by our pruning job which removes old documents from the index. + - The optional start and end datetimes can be ignored. - Event Based connectors: - Connectors that listen to events and update documents accordingly. - Currently not used by the background job, this exists for future design purposes. @@ -26,8 +31,14 @@ Refer to [interfaces.py](https://github.com/danswer-ai/danswer/blob/main/backend and this first contributor created Pull Request for a new connector (Shoutout to Dan Brown): [Reference Pull Request](https://github.com/danswer-ai/danswer/pull/139) +For implementing a Slim Connector, refer to the comments in this PR: +[Slim Connector PR](https://github.com/danswer-ai/danswer/pull/3303/files) + +All new connectors should have tests added to the `backend/tests/daily/connectors` directory. Refer to the above PR for an example of adding tests for a new connector. + + #### Implementing the new Connector -The connector must subclass one or more of LoadConnector, PollConnector, or EventConnector. +The connector must subclass one or more of LoadConnector, PollConnector, SlimConnector, or EventConnector. The `__init__` should take arguments for configuring what documents the connector will and where it finds those documents. For example, if you have a wiki site, it may include the configuration for the team, topic, folder, etc. of diff --git a/backend/danswer/connectors/confluence/connector.py b/backend/danswer/connectors/confluence/connector.py index 1f76404891c..ef23dece68f 100644 --- a/backend/danswer/connectors/confluence/connector.py +++ b/backend/danswer/connectors/confluence/connector.py @@ -1,4 +1,5 @@ from datetime import datetime +from datetime import timedelta from datetime import timezone from typing import Any from urllib.parse import quote @@ -6,6 +7,7 @@ from atlassian import Confluence # type: ignore from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP +from danswer.configs.app_configs import CONFLUENCE_TIMEZONE_OFFSET from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE from danswer.configs.app_configs import INDEX_BATCH_SIZE from danswer.configs.constants import DocumentSource @@ -15,6 +17,7 @@ from danswer.connectors.confluence.utils import build_confluence_document_id from danswer.connectors.confluence.utils import datetime_from_string from danswer.connectors.confluence.utils import extract_text_from_confluence_html +from danswer.connectors.confluence.utils import validate_attachment_filetype from danswer.connectors.interfaces import GenerateDocumentsOutput from danswer.connectors.interfaces import GenerateSlimDocumentOutput from danswer.connectors.interfaces import LoadConnector @@ -53,7 +56,7 @@ "restrictions.read.restrictions.group", ] -_SLIM_DOC_BATCH_SIZE = 1000 +_SLIM_DOC_BATCH_SIZE = 5000 class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector): @@ -71,6 +74,7 @@ def __init__( # skip it. This is generally used to avoid indexing extra sensitive # pages. labels_to_skip: list[str] = CONFLUENCE_CONNECTOR_LABELS_TO_SKIP, + timezone_offset: float = CONFLUENCE_TIMEZONE_OFFSET, ) -> None: self.batch_size = batch_size self.continue_on_failure = continue_on_failure @@ -106,6 +110,8 @@ def __init__( ) self.cql_label_filter = f" and label not in ({comma_separated_labels})" + self.timezone: timezone = timezone(offset=timedelta(hours=timezone_offset)) + @property def confluence_client(self) -> OnyxConfluence: if self._confluence_client is None: @@ -221,12 +227,14 @@ def _fetch_document_batches(self) -> GenerateDocumentsOutput: confluence_page_ids: list[str] = [] page_query = self.cql_page_query + self.cql_label_filter + self.cql_time_filter + logger.debug(f"page_query: {page_query}") # Fetch pages as Documents for page in self.confluence_client.paginated_cql_retrieval( cql=page_query, expand=",".join(_PAGE_EXPANSION_FIELDS), limit=self.batch_size, ): + logger.debug(f"_fetch_document_batches: {page['id']}") confluence_page_ids.append(page["id"]) doc = self._convert_object_to_document(page) if doc is not None: @@ -259,10 +267,10 @@ def load_from_state(self) -> GenerateDocumentsOutput: def poll_source(self, start: float, end: float) -> GenerateDocumentsOutput: # Add time filters - formatted_start_time = datetime.fromtimestamp(start, tz=timezone.utc).strftime( + formatted_start_time = datetime.fromtimestamp(start, tz=self.timezone).strftime( "%Y-%m-%d %H:%M" ) - formatted_end_time = datetime.fromtimestamp(end, tz=timezone.utc).strftime( + formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime( "%Y-%m-%d %H:%M" ) self.cql_time_filter = f" and lastmodified >= '{formatted_start_time}'" @@ -286,9 +294,11 @@ def retrieve_all_slim_documents( ): # If the page has restrictions, add them to the perm_sync_data # These will be used by doc_sync.py to sync permissions - perm_sync_data = { - "restrictions": page.get("restrictions", {}), - "space_key": page.get("space", {}).get("key"), + page_restrictions = page.get("restrictions") + page_space_key = page.get("space", {}).get("key") + page_perm_sync_data = { + "restrictions": page_restrictions or {}, + "space_key": page_space_key, } doc_metadata_list.append( @@ -298,7 +308,7 @@ def retrieve_all_slim_documents( page["_links"]["webui"], self.is_cloud, ), - perm_sync_data=perm_sync_data, + perm_sync_data=page_perm_sync_data, ) ) attachment_cql = f"type=attachment and container='{page['id']}'" @@ -308,6 +318,21 @@ def retrieve_all_slim_documents( expand=restrictions_expand, limit=_SLIM_DOC_BATCH_SIZE, ): + if not validate_attachment_filetype(attachment): + continue + attachment_restrictions = attachment.get("restrictions") + if not attachment_restrictions: + attachment_restrictions = page_restrictions + + attachment_space_key = attachment.get("space", {}).get("key") + if not attachment_space_key: + attachment_space_key = page_space_key + + attachment_perm_sync_data = { + "restrictions": attachment_restrictions or {}, + "space_key": attachment_space_key, + } + doc_metadata_list.append( SlimDocument( id=build_confluence_document_id( @@ -315,8 +340,11 @@ def retrieve_all_slim_documents( attachment["_links"]["webui"], self.is_cloud, ), - perm_sync_data=perm_sync_data, + perm_sync_data=attachment_perm_sync_data, ) ) - yield doc_metadata_list - doc_metadata_list = [] + if len(doc_metadata_list) > _SLIM_DOC_BATCH_SIZE: + yield doc_metadata_list[:_SLIM_DOC_BATCH_SIZE] + doc_metadata_list = doc_metadata_list[_SLIM_DOC_BATCH_SIZE:] + + yield doc_metadata_list diff --git a/backend/danswer/connectors/confluence/onyx_confluence.py b/backend/danswer/connectors/confluence/onyx_confluence.py index b70ae65c4e8..64130a42641 100644 --- a/backend/danswer/connectors/confluence/onyx_confluence.py +++ b/backend/danswer/connectors/confluence/onyx_confluence.py @@ -118,7 +118,7 @@ def wrapped_call(*args: list[Any], **kwargs: Any) -> Any: return cast(F, wrapped_call) -_DEFAULT_PAGINATION_LIMIT = 100 +_DEFAULT_PAGINATION_LIMIT = 1000 class OnyxConfluence(Confluence): @@ -132,6 +132,32 @@ def __init__(self, url: str, *args: Any, **kwargs: Any) -> None: super(OnyxConfluence, self).__init__(url, *args, **kwargs) self._wrap_methods() + def get_current_user(self, expand: str | None = None) -> Any: + """ + Implements a method that isn't in the third party client. + + Get information about the current user + :param expand: OPTIONAL expand for get status of user. + Possible param is "status". Results are "Active, Deactivated" + :return: Returns the user details + """ + + from atlassian.errors import ApiPermissionError # type:ignore + + url = "rest/api/user/current" + params = {} + if expand: + params["expand"] = expand + try: + response = self.get(url, params=params) + except HTTPError as e: + if e.response.status_code == 403: + raise ApiPermissionError( + "The calling user does not have permission", reason=e + ) + raise + return response + def _wrap_methods(self) -> None: """ For each attribute that is callable (i.e., a method) and doesn't start with an underscore, @@ -305,6 +331,13 @@ def _validate_connector_configuration( ) spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1) + # uncomment the following for testing + # the following is an attempt to retrieve the user's timezone + # Unfornately, all data is returned in UTC regardless of the user's time zone + # even tho CQL parses incoming times based on the user's time zone + # space_key = spaces["results"][0]["key"] + # space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space") + if not spaces: raise RuntimeError( f"No spaces found at {wiki_base}! " diff --git a/backend/danswer/connectors/confluence/utils.py b/backend/danswer/connectors/confluence/utils.py index e6ac0308a3a..991d03e6571 100644 --- a/backend/danswer/connectors/confluence/utils.py +++ b/backend/danswer/connectors/confluence/utils.py @@ -32,7 +32,11 @@ def get_user_email_from_username__server( response = confluence_client.get_mobile_parameters(user_name) email = response.get("email") except Exception: - email = None + # For now, we'll just return a string that indicates failure + # We may want to revert to returning None in the future + # email = None + email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}" + logger.warning(f"failed to get confluence email for {user_name}") _USER_EMAIL_CACHE[user_name] = email return _USER_EMAIL_CACHE[user_name] @@ -173,19 +177,23 @@ def extract_text_from_confluence_html( return format_document_soup(soup) -def attachment_to_content( - confluence_client: OnyxConfluence, - attachment: dict[str, Any], -) -> str | None: - """If it returns None, assume that we should skip this attachment.""" - if attachment["metadata"]["mediaType"] in [ +def validate_attachment_filetype(attachment: dict[str, Any]) -> bool: + return attachment["metadata"]["mediaType"] not in [ "image/jpeg", "image/png", "image/gif", "image/svg+xml", "video/mp4", "video/quicktime", - ]: + ] + + +def attachment_to_content( + confluence_client: OnyxConfluence, + attachment: dict[str, Any], +) -> str | None: + """If it returns None, assume that we should skip this attachment.""" + if not validate_attachment_filetype(attachment): return None download_link = confluence_client.url + attachment["_links"]["download"] @@ -241,7 +249,7 @@ def build_confluence_document_id( return f"{base_url}{content_url}" -def extract_referenced_attachment_names(page_text: str) -> list[str]: +def _extract_referenced_attachment_names(page_text: str) -> list[str]: """Parse a Confluence html page to generate a list of current attachments in use diff --git a/backend/danswer/connectors/egnyte/connector.py b/backend/danswer/connectors/egnyte/connector.py new file mode 100644 index 00000000000..73285644c16 --- /dev/null +++ b/backend/danswer/connectors/egnyte/connector.py @@ -0,0 +1,384 @@ +import io +import os +from collections.abc import Generator +from datetime import datetime +from datetime import timezone +from logging import Logger +from typing import Any +from typing import cast +from typing import IO + +import requests +from retry import retry + +from danswer.configs.app_configs import EGNYTE_BASE_DOMAIN +from danswer.configs.app_configs import EGNYTE_CLIENT_ID +from danswer.configs.app_configs import EGNYTE_CLIENT_SECRET +from danswer.configs.app_configs import EGNYTE_LOCALHOST_OVERRIDE +from danswer.configs.app_configs import INDEX_BATCH_SIZE +from danswer.configs.constants import DocumentSource +from danswer.connectors.interfaces import GenerateDocumentsOutput +from danswer.connectors.interfaces import LoadConnector +from danswer.connectors.interfaces import OAuthConnector +from danswer.connectors.interfaces import PollConnector +from danswer.connectors.interfaces import SecondsSinceUnixEpoch +from danswer.connectors.models import BasicExpertInfo +from danswer.connectors.models import ConnectorMissingCredentialError +from danswer.connectors.models import Document +from danswer.connectors.models import Section +from danswer.file_processing.extract_file_text import detect_encoding +from danswer.file_processing.extract_file_text import extract_file_text +from danswer.file_processing.extract_file_text import get_file_ext +from danswer.file_processing.extract_file_text import is_text_file_extension +from danswer.file_processing.extract_file_text import is_valid_file_ext +from danswer.file_processing.extract_file_text import read_text_file +from danswer.utils.logger import setup_logger + + +logger = setup_logger() + +_EGNYTE_API_BASE = "https://{domain}.egnyte.com/pubapi/v1" +_EGNYTE_APP_BASE = "https://{domain}.egnyte.com" +_TIMEOUT = 60 + + +def _request_with_retries( + method: str, + url: str, + data: dict[str, Any] | None = None, + headers: dict[str, Any] | None = None, + params: dict[str, Any] | None = None, + timeout: int = _TIMEOUT, + stream: bool = False, + tries: int = 8, + delay: float = 1, + backoff: float = 2, +) -> requests.Response: + @retry(tries=tries, delay=delay, backoff=backoff, logger=cast(Logger, logger)) + def _make_request() -> requests.Response: + response = requests.request( + method, + url, + data=data, + headers=headers, + params=params, + timeout=timeout, + stream=stream, + ) + try: + response.raise_for_status() + except requests.exceptions.HTTPError as e: + if e.response.status_code != 403: + logger.exception( + f"Failed to call Egnyte API.\n" + f"URL: {url}\n" + f"Headers: {headers}\n" + f"Data: {data}\n" + f"Params: {params}" + ) + raise e + return response + + return _make_request() + + +def _parse_last_modified(last_modified: str) -> datetime: + return datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z").replace( + tzinfo=timezone.utc + ) + + +def _process_egnyte_file( + file_metadata: dict[str, Any], + file_content: IO, + base_url: str, + folder_path: str | None = None, +) -> Document | None: + """Process an Egnyte file into a Document object + + Args: + file_data: The file data from Egnyte API + file_content: The raw content of the file in bytes + base_url: The base URL for the Egnyte instance + folder_path: Optional folder path to filter results + """ + # Skip if file path doesn't match folder path filter + if folder_path and not file_metadata["path"].startswith(folder_path): + raise ValueError( + f"File path {file_metadata['path']} does not match folder path {folder_path}" + ) + + file_name = file_metadata["name"] + extension = get_file_ext(file_name) + if not is_valid_file_ext(extension): + logger.warning(f"Skipping file '{file_name}' with extension '{extension}'") + return None + + # Extract text content based on file type + if is_text_file_extension(file_name): + encoding = detect_encoding(file_content) + file_content_raw, file_metadata = read_text_file( + file_content, encoding=encoding, ignore_danswer_metadata=False + ) + else: + file_content_raw = extract_file_text( + file=file_content, + file_name=file_name, + break_on_unprocessable=True, + ) + + # Build the web URL for the file + web_url = f"{base_url}/navigate/file/{file_metadata['group_id']}" + + # Create document metadata + metadata: dict[str, str | list[str]] = { + "file_path": file_metadata["path"], + "last_modified": file_metadata.get("last_modified", ""), + } + + # Add lock info if present + if lock_info := file_metadata.get("lock_info"): + metadata[ + "lock_owner" + ] = f"{lock_info.get('first_name', '')} {lock_info.get('last_name', '')}" + + # Create the document owners + primary_owner = None + if uploaded_by := file_metadata.get("uploaded_by"): + primary_owner = BasicExpertInfo( + email=uploaded_by, # Using username as email since that's what we have + ) + + # Create the document + return Document( + id=f"egnyte-{file_metadata['entry_id']}", + sections=[Section(text=file_content_raw.strip(), link=web_url)], + source=DocumentSource.EGNYTE, + semantic_identifier=file_name, + metadata=metadata, + doc_updated_at=( + _parse_last_modified(file_metadata["last_modified"]) + if "last_modified" in file_metadata + else None + ), + primary_owners=[primary_owner] if primary_owner else None, + ) + + +class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector): + def __init__( + self, + folder_path: str | None = None, + batch_size: int = INDEX_BATCH_SIZE, + ) -> None: + self.domain = "" # will always be set in `load_credentials` + self.folder_path = folder_path or "" # Root folder if not specified + self.batch_size = batch_size + self.access_token: str | None = None + + @classmethod + def oauth_id(cls) -> DocumentSource: + return DocumentSource.EGNYTE + + @classmethod + def oauth_authorization_url(cls, base_domain: str, state: str) -> str: + if not EGNYTE_CLIENT_ID: + raise ValueError("EGNYTE_CLIENT_ID environment variable must be set") + if not EGNYTE_BASE_DOMAIN: + raise ValueError("EGNYTE_DOMAIN environment variable must be set") + + if EGNYTE_LOCALHOST_OVERRIDE: + base_domain = EGNYTE_LOCALHOST_OVERRIDE + + callback_uri = f"{base_domain.strip('/')}/connector/oauth/callback/egnyte" + return ( + f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token" + f"?client_id={EGNYTE_CLIENT_ID}" + f"&redirect_uri={callback_uri}" + f"&scope=Egnyte.filesystem" + f"&state={state}" + f"&response_type=code" + ) + + @classmethod + def oauth_code_to_token(cls, code: str) -> dict[str, Any]: + if not EGNYTE_CLIENT_ID: + raise ValueError("EGNYTE_CLIENT_ID environment variable must be set") + if not EGNYTE_CLIENT_SECRET: + raise ValueError("EGNYTE_CLIENT_SECRET environment variable must be set") + if not EGNYTE_BASE_DOMAIN: + raise ValueError("EGNYTE_DOMAIN environment variable must be set") + + # Exchange code for token + url = f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token" + data = { + "client_id": EGNYTE_CLIENT_ID, + "client_secret": EGNYTE_CLIENT_SECRET, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": f"{EGNYTE_LOCALHOST_OVERRIDE or ''}/connector/oauth/callback/egnyte", + "scope": "Egnyte.filesystem", + } + headers = {"Content-Type": "application/x-www-form-urlencoded"} + + response = _request_with_retries( + method="POST", + url=url, + data=data, + headers=headers, + # try a lot faster since this is a realtime flow + backoff=0, + delay=0.1, + ) + if not response.ok: + raise RuntimeError(f"Failed to exchange code for token: {response.text}") + + token_data = response.json() + return { + "domain": EGNYTE_BASE_DOMAIN, + "access_token": token_data["access_token"], + } + + def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: + self.domain = credentials["domain"] + self.access_token = credentials["access_token"] + return None + + def _get_files_list( + self, + path: str, + ) -> list[dict[str, Any]]: + if not self.access_token or not self.domain: + raise ConnectorMissingCredentialError("Egnyte") + + headers = { + "Authorization": f"Bearer {self.access_token}", + } + + params: dict[str, Any] = { + "list_content": True, + } + + url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{path or ''}" + response = _request_with_retries( + method="GET", url=url, headers=headers, params=params, timeout=_TIMEOUT + ) + if not response.ok: + raise RuntimeError(f"Failed to fetch files from Egnyte: {response.text}") + + data = response.json() + all_files: list[dict[str, Any]] = [] + + # Add files from current directory + all_files.extend(data.get("files", [])) + + # Recursively traverse folders + for item in data.get("folders", []): + all_files.extend(self._get_files_list(item["path"])) + + return all_files + + def _filter_files( + self, + files: list[dict[str, Any]], + start_time: datetime | None = None, + end_time: datetime | None = None, + ) -> list[dict[str, Any]]: + filtered_files = [] + for file in files: + if file["is_folder"]: + continue + + file_modified = _parse_last_modified(file["last_modified"]) + if start_time and file_modified < start_time: + continue + if end_time and file_modified > end_time: + continue + + filtered_files.append(file) + + return filtered_files + + def _process_files( + self, + start_time: datetime | None = None, + end_time: datetime | None = None, + ) -> Generator[list[Document], None, None]: + files = self._get_files_list(self.folder_path) + files = self._filter_files(files, start_time, end_time) + + current_batch: list[Document] = [] + for file in files: + try: + # Set up request with streaming enabled + headers = { + "Authorization": f"Bearer {self.access_token}", + } + url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{file['path']}" + response = _request_with_retries( + method="GET", + url=url, + headers=headers, + timeout=_TIMEOUT, + stream=True, + ) + + if not response.ok: + logger.error( + f"Failed to fetch file content: {file['path']} (status code: {response.status_code})" + ) + continue + + # Stream the response content into a BytesIO buffer + buffer = io.BytesIO() + for chunk in response.iter_content(chunk_size=8192): + if chunk: + buffer.write(chunk) + + # Reset buffer's position to the start + buffer.seek(0) + + # Process the streamed file content + doc = _process_egnyte_file( + file_metadata=file, + file_content=buffer, + base_url=_EGNYTE_APP_BASE.format(domain=self.domain), + folder_path=self.folder_path, + ) + + if doc is not None: + current_batch.append(doc) + + if len(current_batch) >= self.batch_size: + yield current_batch + current_batch = [] + + except Exception: + logger.exception(f"Failed to process file {file['path']}") + continue + + if current_batch: + yield current_batch + + def load_from_state(self) -> GenerateDocumentsOutput: + yield from self._process_files() + + def poll_source( + self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch + ) -> GenerateDocumentsOutput: + start_time = datetime.fromtimestamp(start, tz=timezone.utc) + end_time = datetime.fromtimestamp(end, tz=timezone.utc) + + yield from self._process_files(start_time=start_time, end_time=end_time) + + +if __name__ == "__main__": + connector = EgnyteConnector() + connector.load_credentials( + { + "domain": os.environ["EGNYTE_DOMAIN"], + "access_token": os.environ["EGNYTE_ACCESS_TOKEN"], + } + ) + document_batches = connector.load_from_state() + print(next(document_batches)) diff --git a/backend/danswer/connectors/factory.py b/backend/danswer/connectors/factory.py index 40f926b31d1..241d5ed81ce 100644 --- a/backend/danswer/connectors/factory.py +++ b/backend/danswer/connectors/factory.py @@ -15,6 +15,7 @@ from danswer.connectors.discourse.connector import DiscourseConnector from danswer.connectors.document360.connector import Document360Connector from danswer.connectors.dropbox.connector import DropboxConnector +from danswer.connectors.egnyte.connector import EgnyteConnector from danswer.connectors.file.connector import LocalFileConnector from danswer.connectors.fireflies.connector import FirefliesConnector from danswer.connectors.freshdesk.connector import FreshdeskConnector @@ -40,7 +41,6 @@ from danswer.connectors.sharepoint.connector import SharepointConnector from danswer.connectors.slab.connector import SlabConnector from danswer.connectors.slack.connector import SlackPollConnector -from danswer.connectors.slack.load_connector import SlackLoadConnector from danswer.connectors.teams.connector import TeamsConnector from danswer.connectors.web.connector import WebConnector from danswer.connectors.wikipedia.connector import WikipediaConnector @@ -63,7 +63,6 @@ def identify_connector_class( DocumentSource.WEB: WebConnector, DocumentSource.FILE: LocalFileConnector, DocumentSource.SLACK: { - InputType.LOAD_STATE: SlackLoadConnector, InputType.POLL: SlackPollConnector, InputType.SLIM_RETRIEVAL: SlackPollConnector, }, @@ -103,6 +102,7 @@ def identify_connector_class( DocumentSource.XENFORO: XenforoConnector, DocumentSource.FRESHDESK: FreshdeskConnector, DocumentSource.FIREFLIES: FirefliesConnector, + DocumentSource.EGNYTE: EgnyteConnector, } connector_by_source = connector_map.get(source, {}) diff --git a/backend/danswer/connectors/file/connector.py b/backend/danswer/connectors/file/connector.py index b263354822f..70b7219f65a 100644 --- a/backend/danswer/connectors/file/connector.py +++ b/backend/danswer/connectors/file/connector.py @@ -17,11 +17,11 @@ from danswer.connectors.models import Document from danswer.connectors.models import Section from danswer.db.engine import get_session_with_tenant -from danswer.file_processing.extract_file_text import check_file_ext_is_valid from danswer.file_processing.extract_file_text import detect_encoding from danswer.file_processing.extract_file_text import extract_file_text from danswer.file_processing.extract_file_text import get_file_ext from danswer.file_processing.extract_file_text import is_text_file_extension +from danswer.file_processing.extract_file_text import is_valid_file_ext from danswer.file_processing.extract_file_text import load_files_from_zip from danswer.file_processing.extract_file_text import read_pdf_file from danswer.file_processing.extract_file_text import read_text_file @@ -50,7 +50,7 @@ def _read_files_and_metadata( file_content, ignore_dirs=True ): yield os.path.join(directory_path, file_info.filename), file, metadata - elif check_file_ext_is_valid(extension): + elif is_valid_file_ext(extension): yield file_name, file_content, metadata else: logger.warning(f"Skipping file '{file_name}' with extension '{extension}'") @@ -63,7 +63,7 @@ def _process_file( pdf_pass: str | None = None, ) -> list[Document]: extension = get_file_ext(file_name) - if not check_file_ext_is_valid(extension): + if not is_valid_file_ext(extension): logger.warning(f"Skipping file '{file_name}' with extension '{extension}'") return [] diff --git a/backend/danswer/connectors/google_drive/connector.py b/backend/danswer/connectors/google_drive/connector.py index ad929eb0905..1b03c703dbb 100644 --- a/backend/danswer/connectors/google_drive/connector.py +++ b/backend/danswer/connectors/google_drive/connector.py @@ -4,11 +4,13 @@ from concurrent.futures import ThreadPoolExecutor from functools import partial from typing import Any +from typing import cast from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore from danswer.configs.app_configs import INDEX_BATCH_SIZE +from danswer.configs.app_configs import MAX_FILE_SIZE_BYTES from danswer.configs.constants import DocumentSource from danswer.connectors.google_drive.doc_conversion import build_slim_document from danswer.connectors.google_drive.doc_conversion import ( @@ -452,12 +454,14 @@ def _fetch_drive_items( if isinstance(self.creds, ServiceAccountCredentials) else self._manage_oauth_retrieval ) - return retrieval_method( + drive_files = retrieval_method( is_slim=is_slim, start=start, end=end, ) + return drive_files + def _extract_docs_from_google_drive( self, start: SecondsSinceUnixEpoch | None = None, @@ -473,6 +477,15 @@ def _extract_docs_from_google_drive( files_to_process = [] # Gather the files into batches to be processed in parallel for file in self._fetch_drive_items(is_slim=False, start=start, end=end): + if ( + file.get("size") + and int(cast(str, file.get("size"))) > MAX_FILE_SIZE_BYTES + ): + logger.warning( + f"Skipping file {file.get('name', 'Unknown')} as it is too large: {file.get('size')} bytes" + ) + continue + files_to_process.append(file) if len(files_to_process) >= LARGE_BATCH_SIZE: yield from _process_files_batch( diff --git a/backend/danswer/connectors/google_drive/file_retrieval.py b/backend/danswer/connectors/google_drive/file_retrieval.py index 962d531b076..9b9b17a8c27 100644 --- a/backend/danswer/connectors/google_drive/file_retrieval.py +++ b/backend/danswer/connectors/google_drive/file_retrieval.py @@ -16,7 +16,7 @@ FILE_FIELDS = ( "nextPageToken, files(mimeType, id, name, permissions, modifiedTime, webViewLink, " - "shortcutDetails, owners(emailAddress))" + "shortcutDetails, owners(emailAddress), size)" ) SLIM_FILE_FIELDS = ( "nextPageToken, files(mimeType, id, name, permissions(emailAddress, type), " diff --git a/backend/danswer/connectors/interfaces.py b/backend/danswer/connectors/interfaces.py index c53b3de5f2f..3ab447a7a88 100644 --- a/backend/danswer/connectors/interfaces.py +++ b/backend/danswer/connectors/interfaces.py @@ -2,6 +2,7 @@ from collections.abc import Iterator from typing import Any +from danswer.configs.constants import DocumentSource from danswer.connectors.models import Document from danswer.connectors.models import SlimDocument @@ -64,6 +65,23 @@ def retrieve_all_slim_documents( raise NotImplementedError +class OAuthConnector(BaseConnector): + @classmethod + @abc.abstractmethod + def oauth_id(cls) -> DocumentSource: + raise NotImplementedError + + @classmethod + @abc.abstractmethod + def oauth_authorization_url(cls, base_domain: str, state: str) -> str: + raise NotImplementedError + + @classmethod + @abc.abstractmethod + def oauth_code_to_token(cls, code: str) -> dict[str, Any]: + raise NotImplementedError + + # Event driven class EventConnector(BaseConnector): @abc.abstractmethod diff --git a/backend/danswer/connectors/linear/connector.py b/backend/danswer/connectors/linear/connector.py index 22b769562d1..c6da61555bd 100644 --- a/backend/danswer/connectors/linear/connector.py +++ b/backend/danswer/connectors/linear/connector.py @@ -132,7 +132,6 @@ def _process_issues( branchName customerTicketCount description - descriptionData comments { nodes { url @@ -215,5 +214,6 @@ def poll_source( if __name__ == "__main__": connector = LinearConnector() connector.load_credentials({"linear_api_key": os.environ["LINEAR_API_KEY"]}) + document_batches = connector.load_from_state() print(next(document_batches)) diff --git a/backend/danswer/connectors/slab/connector.py b/backend/danswer/connectors/slab/connector.py index ae76332838b..f60fb8cb6ed 100644 --- a/backend/danswer/connectors/slab/connector.py +++ b/backend/danswer/connectors/slab/connector.py @@ -12,12 +12,15 @@ from danswer.configs.app_configs import INDEX_BATCH_SIZE from danswer.configs.constants import DocumentSource from danswer.connectors.interfaces import GenerateDocumentsOutput +from danswer.connectors.interfaces import GenerateSlimDocumentOutput from danswer.connectors.interfaces import LoadConnector from danswer.connectors.interfaces import PollConnector from danswer.connectors.interfaces import SecondsSinceUnixEpoch +from danswer.connectors.interfaces import SlimConnector from danswer.connectors.models import ConnectorMissingCredentialError from danswer.connectors.models import Document from danswer.connectors.models import Section +from danswer.connectors.models import SlimDocument from danswer.utils.logger import setup_logger @@ -28,6 +31,8 @@ SLAB_GRAPHQL_MAX_TRIES = 10 SLAB_API_URL = "https://api.slab.com/v1/graphql" +_SLIM_BATCH_SIZE = 1000 + def run_graphql_request( graphql_query: dict, bot_token: str, max_tries: int = SLAB_GRAPHQL_MAX_TRIES @@ -158,21 +163,26 @@ def get_slab_url_from_title_id(base_url: str, title: str, page_id: str) -> str: return urljoin(urljoin(base_url, "posts/"), url_id) -class SlabConnector(LoadConnector, PollConnector): +class SlabConnector(LoadConnector, PollConnector, SlimConnector): def __init__( self, base_url: str, batch_size: int = INDEX_BATCH_SIZE, - slab_bot_token: str | None = None, ) -> None: self.base_url = base_url self.batch_size = batch_size - self.slab_bot_token = slab_bot_token + self._slab_bot_token: str | None = None def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: - self.slab_bot_token = credentials["slab_bot_token"] + self._slab_bot_token = credentials["slab_bot_token"] return None + @property + def slab_bot_token(self) -> str: + if self._slab_bot_token is None: + raise ConnectorMissingCredentialError("Slab") + return self._slab_bot_token + def _iterate_posts( self, time_filter: Callable[[datetime], bool] | None = None ) -> GenerateDocumentsOutput: @@ -227,3 +237,21 @@ def poll_source( yield from self._iterate_posts( time_filter=lambda t: start_time <= t <= end_time ) + + def retrieve_all_slim_documents( + self, + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, + ) -> GenerateSlimDocumentOutput: + slim_doc_batch: list[SlimDocument] = [] + for post_id in get_all_post_ids(self.slab_bot_token): + slim_doc_batch.append( + SlimDocument( + id=post_id, + ) + ) + if len(slim_doc_batch) >= _SLIM_BATCH_SIZE: + yield slim_doc_batch + slim_doc_batch = [] + if slim_doc_batch: + yield slim_doc_batch diff --git a/backend/danswer/connectors/slack/connector.py b/backend/danswer/connectors/slack/connector.py index 22ace603bd4..9135be77758 100644 --- a/backend/danswer/connectors/slack/connector.py +++ b/backend/danswer/connectors/slack/connector.py @@ -134,7 +134,6 @@ def get_latest_message_time(thread: ThreadType) -> datetime: def thread_to_doc( - workspace: str, channel: ChannelType, thread: ThreadType, slack_cleaner: SlackTextCleaner, @@ -171,15 +170,15 @@ def thread_to_doc( else first_message ) - doc_sem_id = f"{initial_sender_name} in #{channel['name']}: {snippet}" + doc_sem_id = f"{initial_sender_name} in #{channel['name']}: {snippet}".replace( + "\n", " " + ) return Document( id=f"{channel_id}__{thread[0]['ts']}", sections=[ Section( - link=get_message_link( - event=m, workspace=workspace, channel_id=channel_id - ), + link=get_message_link(event=m, client=client, channel_id=channel_id), text=slack_cleaner.index_clean(cast(str, m["text"])), ) for m in thread @@ -263,7 +262,6 @@ def filter_channels( def _get_all_docs( client: WebClient, - workspace: str, channels: list[str] | None = None, channel_name_regex_enabled: bool = False, oldest: str | None = None, @@ -310,7 +308,6 @@ def _get_all_docs( if filtered_thread: channel_docs += 1 yield thread_to_doc( - workspace=workspace, channel=channel, thread=filtered_thread, slack_cleaner=slack_cleaner, @@ -373,14 +370,12 @@ def _get_all_doc_ids( class SlackPollConnector(PollConnector, SlimConnector): def __init__( self, - workspace: str, channels: list[str] | None = None, # if specified, will treat the specified channel strings as # regexes, and will only index channels that fully match the regexes channel_regex_enabled: bool = False, batch_size: int = INDEX_BATCH_SIZE, ) -> None: - self.workspace = workspace self.channels = channels self.channel_regex_enabled = channel_regex_enabled self.batch_size = batch_size @@ -414,7 +409,6 @@ def poll_source( documents: list[Document] = [] for document in _get_all_docs( client=self.client, - workspace=self.workspace, channels=self.channels, channel_name_regex_enabled=self.channel_regex_enabled, # NOTE: need to impute to `None` instead of using 0.0, since Slack will @@ -438,7 +432,6 @@ def poll_source( slack_channel = os.environ.get("SLACK_CHANNEL") connector = SlackPollConnector( - workspace=os.environ["SLACK_WORKSPACE"], channels=[slack_channel] if slack_channel else None, ) connector.load_credentials({"slack_bot_token": os.environ["SLACK_BOT_TOKEN"]}) diff --git a/backend/danswer/connectors/slack/load_connector.py b/backend/danswer/connectors/slack/load_connector.py deleted file mode 100644 index 7350ac6284d..00000000000 --- a/backend/danswer/connectors/slack/load_connector.py +++ /dev/null @@ -1,140 +0,0 @@ -import json -import os -from datetime import datetime -from datetime import timezone -from pathlib import Path -from typing import Any -from typing import cast - -from danswer.configs.app_configs import INDEX_BATCH_SIZE -from danswer.configs.constants import DocumentSource -from danswer.connectors.interfaces import GenerateDocumentsOutput -from danswer.connectors.interfaces import LoadConnector -from danswer.connectors.models import Document -from danswer.connectors.models import Section -from danswer.connectors.slack.connector import filter_channels -from danswer.connectors.slack.utils import get_message_link -from danswer.utils.logger import setup_logger - - -logger = setup_logger() - - -def get_event_time(event: dict[str, Any]) -> datetime | None: - ts = event.get("ts") - if not ts: - return None - return datetime.fromtimestamp(float(ts), tz=timezone.utc) - - -class SlackLoadConnector(LoadConnector): - # WARNING: DEPRECATED, DO NOT USE - def __init__( - self, - workspace: str, - export_path_str: str, - channels: list[str] | None = None, - # if specified, will treat the specified channel strings as - # regexes, and will only index channels that fully match the regexes - channel_regex_enabled: bool = False, - batch_size: int = INDEX_BATCH_SIZE, - ) -> None: - self.workspace = workspace - self.channels = channels - self.channel_regex_enabled = channel_regex_enabled - self.export_path_str = export_path_str - self.batch_size = batch_size - - def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: - if credentials: - logger.warning("Unexpected credentials provided for Slack Load Connector") - return None - - @staticmethod - def _process_batch_event( - slack_event: dict[str, Any], - channel: dict[str, Any], - matching_doc: Document | None, - workspace: str, - ) -> Document | None: - if ( - slack_event["type"] == "message" - and slack_event.get("subtype") != "channel_join" - ): - if matching_doc: - return Document( - id=matching_doc.id, - sections=matching_doc.sections - + [ - Section( - link=get_message_link( - event=slack_event, - workspace=workspace, - channel_id=channel["id"], - ), - text=slack_event["text"], - ) - ], - source=matching_doc.source, - semantic_identifier=matching_doc.semantic_identifier, - title="", # slack docs don't really have a "title" - doc_updated_at=get_event_time(slack_event), - metadata=matching_doc.metadata, - ) - - return Document( - id=slack_event["ts"], - sections=[ - Section( - link=get_message_link( - event=slack_event, - workspace=workspace, - channel_id=channel["id"], - ), - text=slack_event["text"], - ) - ], - source=DocumentSource.SLACK, - semantic_identifier=channel["name"], - title="", # slack docs don't really have a "title" - doc_updated_at=get_event_time(slack_event), - metadata={}, - ) - - return None - - def load_from_state(self) -> GenerateDocumentsOutput: - export_path = Path(self.export_path_str) - - with open(export_path / "channels.json") as f: - all_channels = json.load(f) - - filtered_channels = filter_channels( - all_channels, self.channels, self.channel_regex_enabled - ) - - document_batch: dict[str, Document] = {} - for channel_info in filtered_channels: - channel_dir_path = export_path / cast(str, channel_info["name"]) - channel_file_paths = [ - channel_dir_path / file_name - for file_name in os.listdir(channel_dir_path) - ] - for path in channel_file_paths: - with open(path) as f: - events = cast(list[dict[str, Any]], json.load(f)) - for slack_event in events: - doc = self._process_batch_event( - slack_event=slack_event, - channel=channel_info, - matching_doc=document_batch.get( - slack_event.get("thread_ts", "") - ), - workspace=self.workspace, - ) - if doc: - document_batch[doc.id] = doc - if len(document_batch) >= self.batch_size: - yield list(document_batch.values()) - - yield list(document_batch.values()) diff --git a/backend/danswer/connectors/slack/utils.py b/backend/danswer/connectors/slack/utils.py index 78bc42a0926..62ac749c166 100644 --- a/backend/danswer/connectors/slack/utils.py +++ b/backend/danswer/connectors/slack/utils.py @@ -2,6 +2,7 @@ import time from collections.abc import Callable from collections.abc import Generator +from functools import lru_cache from functools import wraps from typing import Any from typing import cast @@ -21,19 +22,21 @@ _SLACK_LIMIT = 900 +@lru_cache() +def get_base_url(token: str) -> str: + """Retrieve and cache the base URL of the Slack workspace based on the client token.""" + client = WebClient(token=token) + return client.auth_test()["url"] + + def get_message_link( - event: dict[str, Any], workspace: str, channel_id: str | None = None + event: dict[str, Any], client: WebClient, channel_id: str | None = None ) -> str: - channel_id = channel_id or cast( - str, event["channel"] - ) # channel must either be present in the event or passed in - message_ts = cast(str, event["ts"]) - message_ts_without_dot = message_ts.replace(".", "") - thread_ts = cast(str | None, event.get("thread_ts")) - return ( - f"https://{workspace}.slack.com/archives/{channel_id}/p{message_ts_without_dot}" - + (f"?thread_ts={thread_ts}" if thread_ts else "") - ) + channel_id = channel_id or event["channel"] + message_ts = event["ts"] + response = client.chat_getPermalink(channel=channel_id, message_ts=message_ts) + permalink = response["permalink"] + return permalink def _make_slack_api_call_logged( diff --git a/backend/danswer/connectors/teams/connector.py b/backend/danswer/connectors/teams/connector.py index 3b9340878ff..847eb059205 100644 --- a/backend/danswer/connectors/teams/connector.py +++ b/backend/danswer/connectors/teams/connector.py @@ -33,7 +33,7 @@ def get_created_datetime(chat_message: ChatMessage) -> datetime: def _extract_channel_members(channel: Channel) -> list[BasicExpertInfo]: channel_members_list: list[BasicExpertInfo] = [] - members = channel.members.get().execute_query() + members = channel.members.get().execute_query_retry() for member in members: channel_members_list.append(BasicExpertInfo(display_name=member.display_name)) return channel_members_list @@ -51,7 +51,7 @@ def _get_threads_from_channel( end = end.replace(tzinfo=timezone.utc) query = channel.messages.get() - base_messages: list[ChatMessage] = query.execute_query() + base_messages: list[ChatMessage] = query.execute_query_retry() threads: list[list[ChatMessage]] = [] for base_message in base_messages: @@ -65,7 +65,7 @@ def _get_threads_from_channel( continue reply_query = base_message.replies.get_all() - replies = reply_query.execute_query() + replies = reply_query.execute_query_retry() # start a list containing the base message and its replies thread: list[ChatMessage] = [base_message] @@ -82,7 +82,7 @@ def _get_channels_from_teams( channels_list: list[Channel] = [] for team in teams: query = team.channels.get() - channels = query.execute_query() + channels = query.execute_query_retry() channels_list.extend(channels) return channels_list @@ -210,7 +210,7 @@ def _get_all_teams(self) -> list[Team]: teams_list: list[Team] = [] - teams = self.graph_client.teams.get().execute_query() + teams = self.graph_client.teams.get().execute_query_retry() if len(self.requested_team_list) > 0: adjusted_request_strings = [ @@ -234,14 +234,25 @@ def _fetch_from_teams( raise ConnectorMissingCredentialError("Teams") teams = self._get_all_teams() + logger.debug(f"Found available teams: {[str(t) for t in teams]}") + if not teams: + msg = "No teams found." + logger.error(msg) + raise ValueError(msg) channels = _get_channels_from_teams( teams=teams, ) + logger.debug(f"Found available channels: {[c.id for c in channels]}") + if not channels: + msg = "No channels found." + logger.error(msg) + raise ValueError(msg) # goes over channels, converts them into Document objects and then yields them in batches doc_batch: list[Document] = [] for channel in channels: + logger.debug(f"Fetching threads from channel: {channel.id}") thread_list = _get_threads_from_channel(channel, start=start, end=end) for thread in thread_list: converted_doc = _convert_thread_to_document(channel, thread) @@ -259,8 +270,8 @@ def load_from_state(self) -> GenerateDocumentsOutput: def poll_source( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch ) -> GenerateDocumentsOutput: - start_datetime = datetime.utcfromtimestamp(start) - end_datetime = datetime.utcfromtimestamp(end) + start_datetime = datetime.fromtimestamp(start, timezone.utc) + end_datetime = datetime.fromtimestamp(end, timezone.utc) return self._fetch_from_teams(start=start_datetime, end=end_datetime) diff --git a/backend/danswer/context/search/pipeline.py b/backend/danswer/context/search/pipeline.py index 21c518348e7..52748514003 100644 --- a/backend/danswer/context/search/pipeline.py +++ b/backend/danswer/context/search/pipeline.py @@ -5,7 +5,11 @@ from sqlalchemy.orm import Session +from danswer.chat.models import PromptConfig from danswer.chat.models import SectionRelevancePiece +from danswer.chat.prune_and_merge import _merge_sections +from danswer.chat.prune_and_merge import ChunkRange +from danswer.chat.prune_and_merge import merge_chunk_intervals from danswer.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE from danswer.context.search.enums import LLMEvaluationType from danswer.context.search.enums import QueryFlow @@ -27,10 +31,6 @@ from danswer.db.search_settings import get_current_search_settings from danswer.document_index.factory import get_default_document_index from danswer.document_index.interfaces import VespaChunkRequest -from danswer.llm.answering.models import PromptConfig -from danswer.llm.answering.prune_and_merge import _merge_sections -from danswer.llm.answering.prune_and_merge import ChunkRange -from danswer.llm.answering.prune_and_merge import merge_chunk_intervals from danswer.llm.interfaces import LLM from danswer.secondary_llm_flows.agentic_evaluation import evaluate_inference_section from danswer.utils.logger import setup_logger diff --git a/backend/danswer/danswerbot/slack/blocks.py b/backend/danswer/danswerbot/slack/blocks.py index 1f689157452..34ec92e7daa 100644 --- a/backend/danswer/danswerbot/slack/blocks.py +++ b/backend/danswer/danswerbot/slack/blocks.py @@ -16,24 +16,31 @@ from slack_sdk.models.blocks.basic_components import MarkdownTextObject from slack_sdk.models.blocks.block_elements import ImageElement -from danswer.chat.models import DanswerQuote +from danswer.chat.models import ChatDanswerBotResponse from danswer.configs.app_configs import DISABLE_GENERATIVE_AI +from danswer.configs.app_configs import WEB_DOMAIN from danswer.configs.constants import DocumentSource from danswer.configs.constants import SearchFeedbackType from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_DOCS_TO_DISPLAY from danswer.context.search.models import SavedSearchDoc +from danswer.danswerbot.slack.constants import CONTINUE_IN_WEB_UI_ACTION_ID from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID from danswer.danswerbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID +from danswer.danswerbot.slack.formatting import format_slack_message from danswer.danswerbot.slack.icons import source_to_github_img_link +from danswer.danswerbot.slack.models import SlackMessageInfo +from danswer.danswerbot.slack.utils import build_continue_in_web_ui_id from danswer.danswerbot.slack.utils import build_feedback_id from danswer.danswerbot.slack.utils import remove_slack_text_interactions from danswer.danswerbot.slack.utils import translate_vespa_highlight_to_slack +from danswer.db.chat import get_chat_session_by_message_id +from danswer.db.engine import get_session_with_tenant +from danswer.db.models import ChannelConfig from danswer.utils.text_processing import decode_escapes -from danswer.utils.text_processing import replace_whitespaces_w_space _MAX_BLURB_LEN = 45 @@ -101,12 +108,12 @@ def _split_text(text: str, limit: int = 3000) -> list[str]: return chunks -def clean_markdown_link_text(text: str) -> str: +def _clean_markdown_link_text(text: str) -> str: # Remove any newlines within the text return text.replace("\n", " ").strip() -def build_qa_feedback_block( +def _build_qa_feedback_block( message_id: int, feedback_reminder_id: str | None = None ) -> Block: return ActionsBlock( @@ -115,7 +122,6 @@ def build_qa_feedback_block( ButtonElement( action_id=LIKE_BLOCK_ACTION_ID, text="👍 Helpful", - style="primary", value=feedback_reminder_id, ), ButtonElement( @@ -155,7 +161,7 @@ def get_document_feedback_blocks() -> Block: ) -def build_doc_feedback_block( +def _build_doc_feedback_block( message_id: int, document_id: str, document_rank: int, @@ -182,7 +188,7 @@ def get_restate_blocks( ] -def build_documents_blocks( +def _build_documents_blocks( documents: list[SavedSearchDoc], message_id: int | None, num_docs_to_display: int = DANSWER_BOT_NUM_DOCS_TO_DISPLAY, @@ -198,7 +204,8 @@ def build_documents_blocks( continue seen_docs_identifiers.add(d.document_id) - doc_sem_id = d.semantic_identifier + # Strip newlines from the semantic identifier for Slackbot formatting + doc_sem_id = d.semantic_identifier.replace("\n", " ") if d.source_type == DocumentSource.SLACK.value: doc_sem_id = "#" + doc_sem_id @@ -223,7 +230,7 @@ def build_documents_blocks( feedback: ButtonElement | dict = {} if message_id is not None: - feedback = build_doc_feedback_block( + feedback = _build_doc_feedback_block( message_id=message_id, document_id=d.document_id, document_rank=rank, @@ -241,7 +248,7 @@ def build_documents_blocks( return section_blocks -def build_sources_blocks( +def _build_sources_blocks( cited_documents: list[tuple[int, SavedSearchDoc]], num_docs_to_display: int = DANSWER_BOT_NUM_DOCS_TO_DISPLAY, ) -> list[Block]: @@ -286,7 +293,7 @@ def build_sources_blocks( + ([days_ago_str] if days_ago_str else []) ) - document_title = clean_markdown_link_text(doc_sem_id) + document_title = _clean_markdown_link_text(doc_sem_id) img_link = source_to_github_img_link(d.source_type) section_blocks.append( @@ -317,106 +324,105 @@ def build_sources_blocks( return section_blocks -def build_quotes_block( - quotes: list[DanswerQuote], +def _priority_ordered_documents_blocks( + answer: ChatDanswerBotResponse, ) -> list[Block]: - quote_lines: list[str] = [] - doc_to_quotes: dict[str, list[str]] = {} - doc_to_link: dict[str, str] = {} - doc_to_sem_id: dict[str, str] = {} - for q in quotes: - quote = q.quote - doc_id = q.document_id - doc_link = q.link - doc_name = q.semantic_identifier - if doc_link and doc_name and doc_id and quote: - if doc_id not in doc_to_quotes: - doc_to_quotes[doc_id] = [quote] - doc_to_link[doc_id] = doc_link - doc_to_sem_id[doc_id] = ( - doc_name - if q.source_type != DocumentSource.SLACK.value - else "#" + doc_name - ) - else: - doc_to_quotes[doc_id].append(quote) + docs_response = answer.docs if answer.docs else None + top_docs = docs_response.top_documents if docs_response else [] + llm_doc_inds = answer.llm_selected_doc_indices or [] + llm_docs = [top_docs[i] for i in llm_doc_inds] + remaining_docs = [ + doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds + ] + priority_ordered_docs = llm_docs + remaining_docs + if not priority_ordered_docs: + return [] - for doc_id, quote_strs in doc_to_quotes.items(): - quotes_str_clean = [ - replace_whitespaces_w_space(q_str).strip() for q_str in quote_strs - ] - longest_quotes = sorted(quotes_str_clean, key=len, reverse=True)[:5] - single_quote_str = "\n".join([f"```{q_str}```" for q_str in longest_quotes]) - link = doc_to_link[doc_id] - sem_id = doc_to_sem_id[doc_id] - quote_lines.append( - f"<{link}|{sem_id}>:\n{remove_slack_text_interactions(single_quote_str)}" - ) + document_blocks = _build_documents_blocks( + documents=priority_ordered_docs, + message_id=answer.chat_message_id, + ) + if document_blocks: + document_blocks = [DividerBlock()] + document_blocks + return document_blocks - if not doc_to_quotes: - return [] - return [SectionBlock(text="*Relevant Snippets*\n" + "\n".join(quote_lines))] +def _build_citations_blocks( + answer: ChatDanswerBotResponse, +) -> list[Block]: + docs_response = answer.docs if answer.docs else None + top_docs = docs_response.top_documents if docs_response else [] + citations = answer.citations or [] + cited_docs = [] + for citation in citations: + matching_doc = next( + (d for d in top_docs if d.document_id == citation.document_id), + None, + ) + if matching_doc: + cited_docs.append((citation.citation_num, matching_doc)) + + cited_docs.sort() + citations_block = _build_sources_blocks(cited_documents=cited_docs) + return citations_block -def build_qa_response_blocks( - message_id: int | None, - answer: str | None, - quotes: list[DanswerQuote] | None, - source_filters: list[DocumentSource] | None, - time_cutoff: datetime | None, - favor_recent: bool, - skip_quotes: bool = False, +def _build_qa_response_blocks( + answer: ChatDanswerBotResponse, process_message_for_citations: bool = False, - skip_ai_feedback: bool = False, - feedback_reminder_id: str | None = None, ) -> list[Block]: + retrieval_info = answer.docs + if not retrieval_info: + # This should not happen, even with no docs retrieved, there is still info returned + raise RuntimeError("Failed to retrieve docs, cannot answer question.") + + formatted_answer = format_slack_message(answer.answer) if answer.answer else None + if DISABLE_GENERATIVE_AI: return [] - quotes_blocks: list[Block] = [] - filter_block: Block | None = None - if time_cutoff or favor_recent or source_filters: + if ( + retrieval_info.applied_time_cutoff + or retrieval_info.recency_bias_multiplier > 1 + or retrieval_info.applied_source_filters + ): filter_text = "Filters: " - if source_filters: - sources_str = ", ".join([s.value for s in source_filters]) + if retrieval_info.applied_source_filters: + sources_str = ", ".join( + [s.value for s in retrieval_info.applied_source_filters] + ) filter_text += f"`Sources in [{sources_str}]`" - if time_cutoff or favor_recent: + if ( + retrieval_info.applied_time_cutoff + or retrieval_info.recency_bias_multiplier > 1 + ): filter_text += " and " - if time_cutoff is not None: - time_str = time_cutoff.strftime("%b %d, %Y") + if retrieval_info.applied_time_cutoff is not None: + time_str = retrieval_info.applied_time_cutoff.strftime("%b %d, %Y") filter_text += f"`Docs Updated >= {time_str}` " - if favor_recent: - if time_cutoff is not None: + if retrieval_info.recency_bias_multiplier > 1: + if retrieval_info.applied_time_cutoff is not None: filter_text += "+ " filter_text += "`Prioritize Recently Updated Docs`" filter_block = SectionBlock(text=f"_{filter_text}_") - if not answer: + if not formatted_answer: answer_blocks = [ SectionBlock( text="Sorry, I was unable to find an answer, but I did find some potentially relevant docs 🤓" ) ] else: - answer_processed = decode_escapes(remove_slack_text_interactions(answer)) + answer_processed = decode_escapes( + remove_slack_text_interactions(formatted_answer) + ) if process_message_for_citations: answer_processed = _process_citations_for_slack(answer_processed) answer_blocks = [ SectionBlock(text=text) for text in _split_text(answer_processed) ] - if quotes: - quotes_blocks = build_quotes_block(quotes) - - # if no quotes OR `build_quotes_block()` did not give back any blocks - if not quotes_blocks: - quotes_blocks = [ - SectionBlock( - text="*Warning*: no sources were quoted for this answer, so it may be unreliable 😔" - ) - ] response_blocks: list[Block] = [] @@ -425,20 +431,34 @@ def build_qa_response_blocks( response_blocks.extend(answer_blocks) - if message_id is not None and not skip_ai_feedback: - response_blocks.append( - build_qa_feedback_block( - message_id=message_id, feedback_reminder_id=feedback_reminder_id - ) - ) + return response_blocks - if not skip_quotes: - response_blocks.extend(quotes_blocks) - return response_blocks +def _build_continue_in_web_ui_block( + tenant_id: str | None, + message_id: int | None, +) -> Block: + if message_id is None: + raise ValueError("No message id provided to build continue in web ui block") + with get_session_with_tenant(tenant_id) as db_session: + chat_session = get_chat_session_by_message_id( + db_session=db_session, + message_id=message_id, + ) + return ActionsBlock( + block_id=build_continue_in_web_ui_id(message_id), + elements=[ + ButtonElement( + action_id=CONTINUE_IN_WEB_UI_ACTION_ID, + text="Continue Chat in Danswer!", + style="primary", + url=f"{WEB_DOMAIN}/chat?slackChatId={chat_session.id}", + ), + ], + ) -def build_follow_up_block(message_id: int | None) -> ActionsBlock: +def _build_follow_up_block(message_id: int | None) -> ActionsBlock: return ActionsBlock( block_id=build_feedback_id(message_id) if message_id is not None else None, elements=[ @@ -483,3 +503,75 @@ def build_follow_up_resolved_blocks( ] ) return [text_block, button_block] + + +def build_slack_response_blocks( + answer: ChatDanswerBotResponse, + tenant_id: str | None, + message_info: SlackMessageInfo, + channel_conf: ChannelConfig | None, + use_citations: bool, + feedback_reminder_id: str | None, + skip_ai_feedback: bool = False, +) -> list[Block]: + """ + This function is a top level function that builds all the blocks for the Slack response. + It also handles combining all the blocks together. + """ + # If called with the DanswerBot slash command, the question is lost so we have to reshow it + restate_question_block = get_restate_blocks( + message_info.thread_messages[-1].message, message_info.is_bot_msg + ) + + answer_blocks = _build_qa_response_blocks( + answer=answer, + process_message_for_citations=use_citations, + ) + + web_follow_up_block = [] + if channel_conf and channel_conf.get("show_continue_in_web_ui"): + web_follow_up_block.append( + _build_continue_in_web_ui_block( + tenant_id=tenant_id, + message_id=answer.chat_message_id, + ) + ) + + follow_up_block = [] + if channel_conf and channel_conf.get("follow_up_tags") is not None: + follow_up_block.append( + _build_follow_up_block(message_id=answer.chat_message_id) + ) + + ai_feedback_block = [] + if answer.chat_message_id is not None and not skip_ai_feedback: + ai_feedback_block.append( + _build_qa_feedback_block( + message_id=answer.chat_message_id, + feedback_reminder_id=feedback_reminder_id, + ) + ) + + citations_blocks = [] + document_blocks = [] + if use_citations and answer.citations: + citations_blocks = _build_citations_blocks(answer) + else: + document_blocks = _priority_ordered_documents_blocks(answer) + + citations_divider = [DividerBlock()] if citations_blocks else [] + buttons_divider = [DividerBlock()] if web_follow_up_block or follow_up_block else [] + + all_blocks = ( + restate_question_block + + answer_blocks + + ai_feedback_block + + citations_divider + + citations_blocks + + document_blocks + + buttons_divider + + web_follow_up_block + + follow_up_block + ) + + return all_blocks diff --git a/backend/danswer/danswerbot/slack/constants.py b/backend/danswer/danswerbot/slack/constants.py index cf2b38032c3..6a5b3ed43ed 100644 --- a/backend/danswer/danswerbot/slack/constants.py +++ b/backend/danswer/danswerbot/slack/constants.py @@ -2,6 +2,7 @@ LIKE_BLOCK_ACTION_ID = "feedback-like" DISLIKE_BLOCK_ACTION_ID = "feedback-dislike" +CONTINUE_IN_WEB_UI_ACTION_ID = "continue-in-web-ui" FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID = "feedback-doc-button" IMMEDIATE_RESOLVED_BUTTON_ACTION_ID = "immediate-resolved-button" FOLLOWUP_BUTTON_ACTION_ID = "followup-button" diff --git a/backend/danswer/danswerbot/slack/handlers/handle_buttons.py b/backend/danswer/danswerbot/slack/handlers/handle_buttons.py index ec423979941..9335b96874f 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_buttons.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_buttons.py @@ -28,7 +28,7 @@ from danswer.danswerbot.slack.utils import build_feedback_id from danswer.danswerbot.slack.utils import decompose_action_id from danswer.danswerbot.slack.utils import fetch_group_ids_from_names -from danswer.danswerbot.slack.utils import fetch_user_ids_from_emails +from danswer.danswerbot.slack.utils import fetch_slack_user_ids_from_emails from danswer.danswerbot.slack.utils import get_channel_name_from_id from danswer.danswerbot.slack.utils import get_feedback_visibility from danswer.danswerbot.slack.utils import read_slack_thread @@ -267,7 +267,7 @@ def handle_followup_button( tag_names = slack_channel_config.channel_config.get("follow_up_tags") remaining = None if tag_names: - tag_ids, remaining = fetch_user_ids_from_emails( + tag_ids, remaining = fetch_slack_user_ids_from_emails( tag_names, client.web_client ) if remaining: diff --git a/backend/danswer/danswerbot/slack/handlers/handle_message.py b/backend/danswer/danswerbot/slack/handlers/handle_message.py index 6bec83def4b..1f19d0a70a6 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_message.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_message.py @@ -13,7 +13,7 @@ handle_standard_answers, ) from danswer.danswerbot.slack.models import SlackMessageInfo -from danswer.danswerbot.slack.utils import fetch_user_ids_from_emails +from danswer.danswerbot.slack.utils import fetch_slack_user_ids_from_emails from danswer.danswerbot.slack.utils import fetch_user_ids_from_groups from danswer.danswerbot.slack.utils import respond_in_thread from danswer.danswerbot.slack.utils import slack_usage_report @@ -184,7 +184,7 @@ def handle_message( send_to: list[str] | None = None missing_users: list[str] | None = None if respond_member_group_list: - send_to, missing_ids = fetch_user_ids_from_emails( + send_to, missing_ids = fetch_slack_user_ids_from_emails( respond_member_group_list, client ) diff --git a/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py b/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py index 3d5f013dca8..e1e3673e8ec 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py @@ -1,60 +1,43 @@ import functools from collections.abc import Callable from typing import Any -from typing import cast from typing import Optional from typing import TypeVar from retry import retry from slack_sdk import WebClient -from slack_sdk.models.blocks import DividerBlock from slack_sdk.models.blocks import SectionBlock +from danswer.chat.chat_utils import prepare_chat_message_request +from danswer.chat.models import ChatDanswerBotResponse +from danswer.chat.process_message import gather_stream_for_slack +from danswer.chat.process_message import stream_chat_message_objects from danswer.configs.app_configs import DISABLE_GENERATIVE_AI -from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT -from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT +from danswer.configs.constants import DEFAULT_PERSONA_ID from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER from danswer.configs.danswerbot_configs import DANSWER_BOT_DISPLAY_ERROR_MSGS from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES -from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE -from danswer.configs.danswerbot_configs import DANSWER_BOT_USE_QUOTES from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI -from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION +from danswer.configs.danswerbot_configs import MAX_THREAD_CONTEXT_PERCENTAGE from danswer.context.search.enums import OptionalSearchSetting from danswer.context.search.models import BaseFilters -from danswer.context.search.models import RerankingDetails from danswer.context.search.models import RetrievalDetails -from danswer.danswerbot.slack.blocks import build_documents_blocks -from danswer.danswerbot.slack.blocks import build_follow_up_block -from danswer.danswerbot.slack.blocks import build_qa_response_blocks -from danswer.danswerbot.slack.blocks import build_sources_blocks -from danswer.danswerbot.slack.blocks import get_restate_blocks -from danswer.danswerbot.slack.formatting import format_slack_message +from danswer.danswerbot.slack.blocks import build_slack_response_blocks from danswer.danswerbot.slack.handlers.utils import send_team_member_message +from danswer.danswerbot.slack.handlers.utils import slackify_message_thread from danswer.danswerbot.slack.models import SlackMessageInfo from danswer.danswerbot.slack.utils import respond_in_thread from danswer.danswerbot.slack.utils import SlackRateLimiter from danswer.danswerbot.slack.utils import update_emote_react from danswer.db.engine import get_session_with_tenant -from danswer.db.models import Persona -from danswer.db.models import SlackBotResponseType from danswer.db.models import SlackChannelConfig -from danswer.db.persona import fetch_persona_by_id -from danswer.db.search_settings import get_current_search_settings +from danswer.db.models import User +from danswer.db.persona import get_persona_by_id from danswer.db.users import get_user_by_email -from danswer.llm.answering.prompts.citations_prompt import ( - compute_max_document_tokens_for_persona, -) -from danswer.llm.factory import get_llms_for_persona -from danswer.llm.utils import check_number_of_tokens -from danswer.llm.utils import get_max_input_tokens -from danswer.one_shot_answer.answer_question import get_search_answer -from danswer.one_shot_answer.models import DirectQARequest -from danswer.one_shot_answer.models import OneShotQAResponse +from danswer.server.query_and_chat.models import CreateChatMessageRequest from danswer.utils.logger import DanswerLoggingAdapter - srl = SlackRateLimiter() RT = TypeVar("RT") # return type @@ -89,16 +72,14 @@ def handle_regular_answer( feedback_reminder_id: str | None, tenant_id: str | None, num_retries: int = DANSWER_BOT_NUM_RETRIES, - answer_generation_timeout: int = DANSWER_BOT_ANSWER_GENERATION_TIMEOUT, - thread_context_percent: float = DANSWER_BOT_TARGET_CHUNK_PERCENTAGE, + thread_context_percent: float = MAX_THREAD_CONTEXT_PERCENTAGE, should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS, disable_docs_only_answer: bool = DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER, - disable_cot: bool = DANSWER_BOT_DISABLE_COT, - reflexion: bool = ENABLE_DANSWERBOT_REFLEXION, ) -> bool: channel_conf = slack_channel_config.channel_config if slack_channel_config else None messages = message_info.thread_messages + message_ts_to_respond_to = message_info.msg_to_respond is_bot_msg = message_info.is_bot_msg user = None @@ -108,9 +89,18 @@ def handle_regular_answer( user = get_user_by_email(message_info.email, db_session) document_set_names: list[str] | None = None - persona = slack_channel_config.persona if slack_channel_config else None prompt = None - if persona: + # If no persona is specified, use the default search based persona + # This way slack flow always has a persona + persona = slack_channel_config.persona if slack_channel_config else None + if not persona: + with get_session_with_tenant(tenant_id) as db_session: + persona = get_persona_by_id(DEFAULT_PERSONA_ID, user, db_session) + document_set_names = [ + document_set.name for document_set in persona.document_sets + ] + prompt = persona.prompts[0] if persona.prompts else None + else: document_set_names = [ document_set.name for document_set in persona.document_sets ] @@ -118,6 +108,26 @@ def handle_regular_answer( should_respond_even_with_no_docs = persona.num_chunks == 0 if persona else False + # TODO: Add in support for Slack to truncate messages based on max LLM context + # llm, _ = get_llms_for_persona(persona) + + # llm_tokenizer = get_tokenizer( + # model_name=llm.config.model_name, + # provider_type=llm.config.model_provider, + # ) + + # # In cases of threads, split the available tokens between docs and thread context + # input_tokens = get_max_input_tokens( + # model_name=llm.config.model_name, + # model_provider=llm.config.model_provider, + # ) + # max_history_tokens = int(input_tokens * thread_context_percent) + # combined_message = combine_message_thread( + # messages, max_tokens=max_history_tokens, llm_tokenizer=llm_tokenizer + # ) + + combined_message = slackify_message_thread(messages) + bypass_acl = False if ( slack_channel_config @@ -128,13 +138,6 @@ def handle_regular_answer( # with non-public document sets bypass_acl = True - # figure out if we want to use citations or quotes - use_citations = ( - not DANSWER_BOT_USE_QUOTES - if slack_channel_config is None - else slack_channel_config.response_type == SlackBotResponseType.CITATIONS - ) - if not message_ts_to_respond_to and not is_bot_msg: # if the message is not "/danswer" command, then it should have a message ts to respond to raise RuntimeError( @@ -147,75 +150,23 @@ def handle_regular_answer( backoff=2, ) @rate_limits(client=client, channel=channel, thread_ts=message_ts_to_respond_to) - def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | None: - max_document_tokens: int | None = None - max_history_tokens: int | None = None - + def _get_slack_answer( + new_message_request: CreateChatMessageRequest, danswer_user: User | None + ) -> ChatDanswerBotResponse: with get_session_with_tenant(tenant_id) as db_session: - if len(new_message_request.messages) > 1: - if new_message_request.persona_config: - raise RuntimeError("Slack bot does not support persona config") - elif new_message_request.persona_id is not None: - persona = cast( - Persona, - fetch_persona_by_id( - db_session, - new_message_request.persona_id, - user=None, - get_editable=False, - ), - ) - else: - raise RuntimeError( - "No persona id provided, this should never happen." - ) - - llm, _ = get_llms_for_persona(persona) - - # In cases of threads, split the available tokens between docs and thread context - input_tokens = get_max_input_tokens( - model_name=llm.config.model_name, - model_provider=llm.config.model_provider, - ) - max_history_tokens = int(input_tokens * thread_context_percent) - - remaining_tokens = input_tokens - max_history_tokens - - query_text = new_message_request.messages[0].message - if persona: - max_document_tokens = compute_max_document_tokens_for_persona( - persona=persona, - actual_user_input=query_text, - max_llm_token_override=remaining_tokens, - ) - else: - max_document_tokens = ( - remaining_tokens - - 512 # Needs to be more than any of the QA prompts - - check_number_of_tokens(query_text) - ) - - if DISABLE_GENERATIVE_AI: - return None - - # This also handles creating the query event in postgres - answer = get_search_answer( - query_req=new_message_request, - user=user, - max_document_tokens=max_document_tokens, - max_history_tokens=max_history_tokens, + packets = stream_chat_message_objects( + new_msg_req=new_message_request, + user=danswer_user, db_session=db_session, - answer_generation_timeout=answer_generation_timeout, - enable_reflexion=reflexion, bypass_acl=bypass_acl, - use_citations=use_citations, - danswerbot_flow=True, ) - if not answer.error_msg: - return answer - else: - raise RuntimeError(answer.error_msg) + answer = gather_stream_for_slack(packets) + + if answer.error_msg: + raise RuntimeError(answer.error_msg) + + return answer try: # By leaving time_cutoff and favor_recent as None, and setting enable_auto_detect_filters @@ -245,26 +196,24 @@ def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | Non enable_auto_detect_filters=auto_detect_filters, ) - # Always apply reranking settings if it exists, this is the non-streaming flow with get_session_with_tenant(tenant_id) as db_session: - saved_search_settings = get_current_search_settings(db_session) - - # This includes throwing out answer via reflexion - answer = _get_answer( - DirectQARequest( - messages=messages, - multilingual_query_expansion=saved_search_settings.multilingual_expansion - if saved_search_settings - else None, - prompt_id=prompt.id if prompt else None, - persona_id=persona.id if persona is not None else 0, - retrieval_options=retrieval_details, - chain_of_thought=not disable_cot, - rerank_settings=RerankingDetails.from_db_model(saved_search_settings) - if saved_search_settings - else None, + answer_request = prepare_chat_message_request( + message_text=combined_message, + user=user, + persona_id=persona.id, + # This is not used in the Slack flow, only in the answer API + persona_override_config=None, + prompt=prompt, + message_ts_to_respond_to=message_ts_to_respond_to, + retrieval_details=retrieval_details, + rerank_settings=None, # Rerank customization supported in Slack flow + db_session=db_session, ) + + answer = _get_slack_answer( + new_message_request=answer_request, danswer_user=user ) + except Exception as e: logger.exception( f"Unable to process message - did not successfully answer " @@ -365,7 +314,7 @@ def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | Non top_docs = retrieval_info.top_documents if not top_docs and not should_respond_even_with_no_docs: logger.error( - f"Unable to answer question: '{answer.rephrase}' - no documents found" + f"Unable to answer question: '{combined_message}' - no documents found" ) # Optionally, respond in thread with the error message # Used primarily for debugging purposes @@ -386,18 +335,18 @@ def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | Non ) return True - only_respond_with_citations_or_quotes = ( + only_respond_if_citations = ( channel_conf and "well_answered_postfilter" in channel_conf.get("answer_filters", []) ) - has_citations_or_quotes = bool(answer.citations or answer.quotes) + if ( - only_respond_with_citations_or_quotes - and not has_citations_or_quotes + only_respond_if_citations + and not answer.citations and not message_info.bypass_filters ): logger.error( - f"Unable to find citations or quotes to answer: '{answer.rephrase}' - not answering!" + f"Unable to find citations to answer: '{answer.answer}' - not answering!" ) # Optionally, respond in thread with the error message # Used primarily for debugging purposes @@ -411,67 +360,22 @@ def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | Non ) return True - # If called with the DanswerBot slash command, the question is lost so we have to reshow it - restate_question_block = get_restate_blocks(messages[-1].message, is_bot_msg) - formatted_answer = format_slack_message(answer.answer) if answer.answer else None - - answer_blocks = build_qa_response_blocks( - message_id=answer.chat_message_id, - answer=formatted_answer, - quotes=answer.quotes.quotes if answer.quotes else None, - source_filters=retrieval_info.applied_source_filters, - time_cutoff=retrieval_info.applied_time_cutoff, - favor_recent=retrieval_info.recency_bias_multiplier > 1, - # currently Personas don't support quotes - # if citations are enabled, also don't use quotes - skip_quotes=persona is not None or use_citations, - process_message_for_citations=use_citations, + all_blocks = build_slack_response_blocks( + tenant_id=tenant_id, + message_info=message_info, + answer=answer, + channel_conf=channel_conf, + use_citations=True, # No longer supporting quotes feedback_reminder_id=feedback_reminder_id, ) - # Get the chunks fed to the LLM only, then fill with other docs - llm_doc_inds = answer.llm_selected_doc_indices or [] - llm_docs = [top_docs[i] for i in llm_doc_inds] - remaining_docs = [ - doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds - ] - priority_ordered_docs = llm_docs + remaining_docs - - document_blocks = [] - citations_block = [] - # if citations are enabled, only show cited documents - if use_citations: - citations = answer.citations or [] - cited_docs = [] - for citation in citations: - matching_doc = next( - (d for d in top_docs if d.document_id == citation.document_id), - None, - ) - if matching_doc: - cited_docs.append((citation.citation_num, matching_doc)) - - cited_docs.sort() - citations_block = build_sources_blocks(cited_documents=cited_docs) - elif priority_ordered_docs: - document_blocks = build_documents_blocks( - documents=priority_ordered_docs, - message_id=answer.chat_message_id, - ) - document_blocks = [DividerBlock()] + document_blocks - - all_blocks = ( - restate_question_block + answer_blocks + citations_block + document_blocks - ) - - if channel_conf and channel_conf.get("follow_up_tags") is not None: - all_blocks.append(build_follow_up_block(message_id=answer.chat_message_id)) - try: respond_in_thread( client=client, channel=channel, - receiver_ids=receiver_ids, + receiver_ids=[message_info.sender] + if message_info.is_bot_msg and message_info.sender + else receiver_ids, text="Hello! Danswer has some results for you!", blocks=all_blocks, thread_ts=message_ts_to_respond_to, diff --git a/backend/danswer/danswerbot/slack/handlers/utils.py b/backend/danswer/danswerbot/slack/handlers/utils.py index 296b7b90d41..d34e8455df7 100644 --- a/backend/danswer/danswerbot/slack/handlers/utils.py +++ b/backend/danswer/danswerbot/slack/handlers/utils.py @@ -1,8 +1,33 @@ from slack_sdk import WebClient +from danswer.chat.models import ThreadMessage +from danswer.configs.constants import MessageType from danswer.danswerbot.slack.utils import respond_in_thread +def slackify_message_thread(messages: list[ThreadMessage]) -> str: + # Note: this does not handle extremely long threads, every message will be included + # with weaker LLMs, this could cause issues with exceeeding the token limit + if not messages: + return "" + + message_strs: list[str] = [] + for message in messages: + if message.role == MessageType.USER: + message_text = ( + f"{message.sender or 'Unknown User'} said in Slack:\n{message.message}" + ) + elif message.role == MessageType.ASSISTANT: + message_text = f"AI said in Slack:\n{message.message}" + else: + message_text = ( + f"{message.role.value.upper()} said in Slack:\n{message.message}" + ) + message_strs.append(message_text) + + return "\n\n".join(message_strs) + + def send_team_member_message( client: WebClient, channel: str, diff --git a/backend/danswer/danswerbot/slack/listener.py b/backend/danswer/danswerbot/slack/listener.py index 5f6cabb3406..b19c89d8576 100644 --- a/backend/danswer/danswerbot/slack/listener.py +++ b/backend/danswer/danswerbot/slack/listener.py @@ -19,6 +19,8 @@ from slack_sdk.socket_mode.response import SocketModeResponse from sqlalchemy.orm import Session +from danswer.chat.models import ThreadMessage +from danswer.configs.app_configs import DEV_MODE from danswer.configs.app_configs import POD_NAME from danswer.configs.app_configs import POD_NAMESPACE from danswer.configs.constants import DanswerRedisLocks @@ -74,7 +76,6 @@ from danswer.key_value_store.interface import KvKeyNotFoundError from danswer.natural_language_processing.search_nlp_models import EmbeddingModel from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder -from danswer.one_shot_answer.models import ThreadMessage from danswer.redis.redis_pool import get_redis_client from danswer.server.manage.models import SlackBotTokens from danswer.utils.logger import setup_logger @@ -250,7 +251,7 @@ def acquire_tenants(self) -> None: nx=True, ex=TENANT_LOCK_EXPIRATION, ) - if not acquired: + if not acquired and not DEV_MODE: logger.debug(f"Another pod holds the lock for tenant {tenant_id}") continue diff --git a/backend/danswer/danswerbot/slack/models.py b/backend/danswer/danswerbot/slack/models.py index 6394eab562d..ef03cc0544d 100644 --- a/backend/danswer/danswerbot/slack/models.py +++ b/backend/danswer/danswerbot/slack/models.py @@ -1,6 +1,6 @@ from pydantic import BaseModel -from danswer.one_shot_answer.models import ThreadMessage +from danswer.chat.models import ThreadMessage class SlackMessageInfo(BaseModel): diff --git a/backend/danswer/danswerbot/slack/utils.py b/backend/danswer/danswerbot/slack/utils.py index e19ce8b688c..147356b76d1 100644 --- a/backend/danswer/danswerbot/slack/utils.py +++ b/backend/danswer/danswerbot/slack/utils.py @@ -3,14 +3,15 @@ import re import string import time +import uuid from typing import Any from typing import cast -from typing import Optional from retry import retry from slack_sdk import WebClient from slack_sdk.errors import SlackApiError from slack_sdk.models.blocks import Block +from slack_sdk.models.blocks import SectionBlock from slack_sdk.models.metadata import Metadata from slack_sdk.socket_mode import SocketModeClient @@ -30,13 +31,13 @@ from danswer.connectors.slack.utils import make_slack_api_rate_limited from danswer.connectors.slack.utils import SlackTextCleaner from danswer.danswerbot.slack.constants import FeedbackVisibility +from danswer.danswerbot.slack.models import ThreadMessage from danswer.db.engine import get_session_with_tenant from danswer.db.users import get_user_by_email from danswer.llm.exceptions import GenAIDisabledException from danswer.llm.factory import get_default_llms from danswer.llm.utils import dict_based_prompt_to_langchain_prompt from danswer.llm.utils import message_to_string -from danswer.one_shot_answer.models import ThreadMessage from danswer.prompts.miscellaneous_prompts import SLACK_LANGUAGE_REPHRASE_PROMPT from danswer.utils.logger import setup_logger from danswer.utils.telemetry import optional_telemetry @@ -140,6 +141,40 @@ def remove_danswer_bot_tag(message_str: str, client: WebClient) -> str: return re.sub(rf"<@{bot_tag_id}>\s", "", message_str) +def _check_for_url_in_block(block: Block) -> bool: + """ + Check if the block has a key that contains "url" in it + """ + block_dict = block.to_dict() + + def check_dict_for_url(d: dict) -> bool: + for key, value in d.items(): + if "url" in key.lower(): + return True + if isinstance(value, dict): + if check_dict_for_url(value): + return True + elif isinstance(value, list): + for item in value: + if isinstance(item, dict) and check_dict_for_url(item): + return True + return False + + return check_dict_for_url(block_dict) + + +def _build_error_block(error_message: str) -> Block: + """ + Build an error block to display in slack so that the user can see + the error without completely breaking + """ + display_text = ( + "There was an error displaying all of the Onyx answers." + f" Please let an admin or an onyx developer know. Error: {error_message}" + ) + return SectionBlock(text=display_text) + + @retry( tries=DANSWER_BOT_NUM_RETRIES, delay=0.25, @@ -162,24 +197,9 @@ def respond_in_thread( message_ids: list[str] = [] if not receiver_ids: slack_call = make_slack_api_rate_limited(client.chat_postMessage) - response = slack_call( - channel=channel, - text=text, - blocks=blocks, - thread_ts=thread_ts, - metadata=metadata, - unfurl_links=unfurl, - unfurl_media=unfurl, - ) - if not response.get("ok"): - raise RuntimeError(f"Failed to post message: {response}") - message_ids.append(response["message_ts"]) - else: - slack_call = make_slack_api_rate_limited(client.chat_postEphemeral) - for receiver in receiver_ids: + try: response = slack_call( channel=channel, - user=receiver, text=text, blocks=blocks, thread_ts=thread_ts, @@ -187,8 +207,68 @@ def respond_in_thread( unfurl_links=unfurl, unfurl_media=unfurl, ) - if not response.get("ok"): - raise RuntimeError(f"Failed to post message: {response}") + except Exception as e: + logger.warning(f"Failed to post message: {e} \n blocks: {blocks}") + logger.warning("Trying again without blocks that have urls") + + if not blocks: + raise e + + blocks_without_urls = [ + block for block in blocks if not _check_for_url_in_block(block) + ] + blocks_without_urls.append(_build_error_block(str(e))) + + # Try again wtihout blocks containing url + response = slack_call( + channel=channel, + text=text, + blocks=blocks_without_urls, + thread_ts=thread_ts, + metadata=metadata, + unfurl_links=unfurl, + unfurl_media=unfurl, + ) + + message_ids.append(response["message_ts"]) + else: + slack_call = make_slack_api_rate_limited(client.chat_postEphemeral) + for receiver in receiver_ids: + try: + response = slack_call( + channel=channel, + user=receiver, + text=text, + blocks=blocks, + thread_ts=thread_ts, + metadata=metadata, + unfurl_links=unfurl, + unfurl_media=unfurl, + ) + except Exception as e: + logger.warning(f"Failed to post message: {e} \n blocks: {blocks}") + logger.warning("Trying again without blocks that have urls") + + if not blocks: + raise e + + blocks_without_urls = [ + block for block in blocks if not _check_for_url_in_block(block) + ] + blocks_without_urls.append(_build_error_block(str(e))) + + # Try again wtihout blocks containing url + response = slack_call( + channel=channel, + user=receiver, + text=text, + blocks=blocks_without_urls, + thread_ts=thread_ts, + metadata=metadata, + unfurl_links=unfurl, + unfurl_media=unfurl, + ) + message_ids.append(response["message_ts"]) return message_ids @@ -216,6 +296,13 @@ def build_feedback_id( return unique_prefix + ID_SEPARATOR + feedback_id +def build_continue_in_web_ui_id( + message_id: int, +) -> str: + unique_prefix = str(uuid.uuid4())[:10] + return unique_prefix + ID_SEPARATOR + str(message_id) + + def decompose_action_id(feedback_id: str) -> tuple[int, str | None, int | None]: """Decompose into query_id, document_id, document_rank, see above function""" try: @@ -313,7 +400,7 @@ def get_channel_name_from_id( raise e -def fetch_user_ids_from_emails( +def fetch_slack_user_ids_from_emails( user_emails: list[str], client: WebClient ) -> tuple[list[str], list[str]]: user_ids: list[str] = [] @@ -522,7 +609,7 @@ def refill(self) -> None: self.last_reset_time = time.time() def notify( - self, client: WebClient, channel: str, position: int, thread_ts: Optional[str] + self, client: WebClient, channel: str, position: int, thread_ts: str | None ) -> None: respond_in_thread( client=client, diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index a76fcccdd8d..f0849645fe4 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -3,6 +3,7 @@ from datetime import timedelta from uuid import UUID +from fastapi import HTTPException from sqlalchemy import delete from sqlalchemy import desc from sqlalchemy import func @@ -30,6 +31,7 @@ from danswer.db.models import SearchDoc as DBSearchDoc from danswer.db.models import ToolCall from danswer.db.models import User +from danswer.db.persona import get_best_persona_id_for_user from danswer.db.pg_file_store import delete_lobj_by_name from danswer.file_store.models import FileDescriptor from danswer.llm.override_models import LLMOverride @@ -143,16 +145,10 @@ def get_chat_sessions_by_user( user_id: UUID | None, deleted: bool | None, db_session: Session, - only_one_shot: bool = False, limit: int = 50, ) -> list[ChatSession]: stmt = select(ChatSession).where(ChatSession.user_id == user_id) - if only_one_shot: - stmt = stmt.where(ChatSession.one_shot.is_(True)) - else: - stmt = stmt.where(ChatSession.one_shot.is_(False)) - stmt = stmt.order_by(desc(ChatSession.time_created)) if deleted is not None: @@ -224,12 +220,11 @@ def delete_messages_and_files_from_chat_session( def create_chat_session( db_session: Session, - description: str, + description: str | None, user_id: UUID | None, persona_id: int | None, # Can be none if temporary persona is used llm_override: LLMOverride | None = None, prompt_override: PromptOverride | None = None, - one_shot: bool = False, danswerbot_flow: bool = False, slack_thread_id: str | None = None, ) -> ChatSession: @@ -239,7 +234,6 @@ def create_chat_session( description=description, llm_override=llm_override, prompt_override=prompt_override, - one_shot=one_shot, danswerbot_flow=danswerbot_flow, slack_thread_id=slack_thread_id, ) @@ -250,6 +244,48 @@ def create_chat_session( return chat_session +def duplicate_chat_session_for_user_from_slack( + db_session: Session, + user: User | None, + chat_session_id: UUID, +) -> ChatSession: + """ + This takes a chat session id for a session in Slack and: + - Creates a new chat session in the DB + - Tries to copy the persona from the original chat session + (if it is available to the user clicking the button) + - Sets the user to the given user (if provided) + """ + chat_session = get_chat_session_by_id( + chat_session_id=chat_session_id, + user_id=None, # Ignore user permissions for this + db_session=db_session, + ) + if not chat_session: + raise HTTPException(status_code=400, detail="Invalid Chat Session ID provided") + + # This enforces permissions and sets a default + new_persona_id = get_best_persona_id_for_user( + db_session=db_session, + user=user, + persona_id=chat_session.persona_id, + ) + + return create_chat_session( + db_session=db_session, + user_id=user.id if user else None, + persona_id=new_persona_id, + # Set this to empty string so the frontend will force a rename + description="", + llm_override=chat_session.llm_override, + prompt_override=chat_session.prompt_override, + # Chat is in UI now so this is false + danswerbot_flow=False, + # Maybe we want this in the future to track if it was created from Slack + slack_thread_id=None, + ) + + def update_chat_session( db_session: Session, user_id: UUID | None, @@ -336,6 +372,28 @@ def get_chat_message( return chat_message +def get_chat_session_by_message_id( + db_session: Session, + message_id: int, +) -> ChatSession: + """ + Should only be used for Slack + Get the chat session associated with a specific message ID + Note: this ignores permission checks. + """ + stmt = select(ChatMessage).where(ChatMessage.id == message_id) + + result = db_session.execute(stmt) + chat_message = result.scalar_one_or_none() + + if chat_message is None: + raise ValueError( + f"Unable to find chat session associated with message ID: {message_id}" + ) + + return chat_message.chat_session + + def get_chat_messages_by_sessions( chat_session_ids: list[UUID], user_id: UUID | None, @@ -355,6 +413,44 @@ def get_chat_messages_by_sessions( return db_session.execute(stmt).scalars().all() +def add_chats_to_session_from_slack_thread( + db_session: Session, + slack_chat_session_id: UUID, + new_chat_session_id: UUID, +) -> None: + new_root_message = get_or_create_root_message( + chat_session_id=new_chat_session_id, + db_session=db_session, + ) + + for chat_message in get_chat_messages_by_sessions( + chat_session_ids=[slack_chat_session_id], + user_id=None, # Ignore user permissions for this + db_session=db_session, + skip_permission_check=True, + ): + if chat_message.message_type == MessageType.SYSTEM: + continue + # Duplicate the message + new_root_message = create_new_chat_message( + db_session=db_session, + chat_session_id=new_chat_session_id, + parent_message=new_root_message, + message=chat_message.message, + files=chat_message.files, + rephrased_query=chat_message.rephrased_query, + error=chat_message.error, + citations=chat_message.citations, + reference_docs=chat_message.search_docs, + tool_call=chat_message.tool_call, + prompt_id=chat_message.prompt_id, + token_count=chat_message.token_count, + message_type=chat_message.message_type, + alternate_assistant_id=chat_message.alternate_assistant_id, + overridden_model=chat_message.overridden_model, + ) + + def get_search_docs_for_chat_message( chat_message_id: int, db_session: Session ) -> list[SearchDoc]: diff --git a/backend/danswer/db/connector_credential_pair.py b/backend/danswer/db/connector_credential_pair.py index 2cc96f6fa63..26730d1178f 100644 --- a/backend/danswer/db/connector_credential_pair.py +++ b/backend/danswer/db/connector_credential_pair.py @@ -324,8 +324,11 @@ def associate_default_cc_pair(db_session: Session) -> None: def _relate_groups_to_cc_pair__no_commit( db_session: Session, cc_pair_id: int, - user_group_ids: list[int], + user_group_ids: list[int] | None = None, ) -> None: + if not user_group_ids: + return + for group_id in user_group_ids: db_session.add( UserGroup__ConnectorCredentialPair( @@ -402,12 +405,11 @@ def add_credential_to_connector( db_session.flush() # make sure the association has an id db_session.refresh(association) - if groups and access_type != AccessType.SYNC: - _relate_groups_to_cc_pair__no_commit( - db_session=db_session, - cc_pair_id=association.id, - user_group_ids=groups, - ) + _relate_groups_to_cc_pair__no_commit( + db_session=db_session, + cc_pair_id=association.id, + user_group_ids=groups, + ) db_session.commit() diff --git a/backend/danswer/db/credentials.py b/backend/danswer/db/credentials.py index 4a146c5c5f4..3ee165b34d0 100644 --- a/backend/danswer/db/credentials.py +++ b/backend/danswer/db/credentials.py @@ -20,7 +20,6 @@ from danswer.db.models import User from danswer.db.models import User__UserGroup from danswer.server.documents.models import CredentialBase -from danswer.server.documents.models import CredentialDataUpdateRequest from danswer.utils.logger import setup_logger @@ -248,7 +247,6 @@ def create_credential( ) db_session.commit() - return credential @@ -263,7 +261,8 @@ def _cleanup_credential__user_group_relationships__no_commit( def alter_credential( credential_id: int, - credential_data: CredentialDataUpdateRequest, + name: str, + credential_json: dict[str, Any], user: User, db_session: Session, ) -> Credential | None: @@ -273,11 +272,13 @@ def alter_credential( if credential is None: return None - credential.name = credential_data.name + credential.name = name - # Update only the keys present in credential_data.credential_json - for key, value in credential_data.credential_json.items(): - credential.credential_json[key] = value + # Assign a new dictionary to credential.credential_json + credential.credential_json = { + **credential.credential_json, + **credential_json, + } credential.user_id = user.id if user is not None else None db_session.commit() @@ -310,8 +311,8 @@ def update_credential_json( credential = fetch_credential_by_id(credential_id, user, db_session) if credential is None: return None - credential.credential_json = credential_json + credential.credential_json = credential_json db_session.commit() return credential diff --git a/backend/danswer/db/engine.py b/backend/danswer/db/engine.py index 5d4753e136a..8ad8eca7a0f 100644 --- a/backend/danswer/db/engine.py +++ b/backend/danswer/db/engine.py @@ -37,6 +37,7 @@ from danswer.configs.app_configs import POSTGRES_USER from danswer.configs.app_configs import USER_AUTH_SECRET from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME +from danswer.server.utils import BasicAuthenticationError from danswer.utils.logger import setup_logger from shared_configs.configs import MULTI_TENANT from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA @@ -426,7 +427,9 @@ def get_session() -> Generator[Session, None, None]: """Generate a database session with the appropriate tenant schema set.""" tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT: - raise HTTPException(status_code=401, detail="User must authenticate") + raise BasicAuthenticationError( + detail="User must authenticate", + ) engine = get_sqlalchemy_engine() diff --git a/backend/danswer/db/index_attempt.py b/backend/danswer/db/index_attempt.py index 06bbee10559..2c8ccd99aa0 100644 --- a/backend/danswer/db/index_attempt.py +++ b/backend/danswer/db/index_attempt.py @@ -522,12 +522,16 @@ def expire_index_attempts( search_settings_id: int, db_session: Session, ) -> None: - delete_query = ( - delete(IndexAttempt) + not_started_query = ( + update(IndexAttempt) .where(IndexAttempt.search_settings_id == search_settings_id) .where(IndexAttempt.status == IndexingStatus.NOT_STARTED) + .values( + status=IndexingStatus.CANCELED, + error_msg="Canceled, likely due to model swap", + ) ) - db_session.execute(delete_query) + db_session.execute(not_started_query) update_query = ( update(IndexAttempt) @@ -549,9 +553,14 @@ def cancel_indexing_attempts_for_ccpair( include_secondary_index: bool = False, ) -> None: stmt = ( - delete(IndexAttempt) + update(IndexAttempt) .where(IndexAttempt.connector_credential_pair_id == cc_pair_id) .where(IndexAttempt.status == IndexingStatus.NOT_STARTED) + .values( + status=IndexingStatus.CANCELED, + error_msg="Canceled by user", + time_started=datetime.now(timezone.utc), + ) ) if not include_secondary_index: diff --git a/backend/danswer/db/input_prompt.py b/backend/danswer/db/input_prompt.py deleted file mode 100644 index efa54d986a1..00000000000 --- a/backend/danswer/db/input_prompt.py +++ /dev/null @@ -1,202 +0,0 @@ -from uuid import UUID - -from fastapi import HTTPException -from sqlalchemy import select -from sqlalchemy.orm import Session - -from danswer.db.models import InputPrompt -from danswer.db.models import User -from danswer.server.features.input_prompt.models import InputPromptSnapshot -from danswer.server.manage.models import UserInfo -from danswer.utils.logger import setup_logger - - -logger = setup_logger() - - -def insert_input_prompt_if_not_exists( - user: User | None, - input_prompt_id: int | None, - prompt: str, - content: str, - active: bool, - is_public: bool, - db_session: Session, - commit: bool = True, -) -> InputPrompt: - if input_prompt_id is not None: - input_prompt = ( - db_session.query(InputPrompt).filter_by(id=input_prompt_id).first() - ) - else: - query = db_session.query(InputPrompt).filter(InputPrompt.prompt == prompt) - if user: - query = query.filter(InputPrompt.user_id == user.id) - else: - query = query.filter(InputPrompt.user_id.is_(None)) - input_prompt = query.first() - - if input_prompt is None: - input_prompt = InputPrompt( - id=input_prompt_id, - prompt=prompt, - content=content, - active=active, - is_public=is_public or user is None, - user_id=user.id if user else None, - ) - db_session.add(input_prompt) - - if commit: - db_session.commit() - - return input_prompt - - -def insert_input_prompt( - prompt: str, - content: str, - is_public: bool, - user: User | None, - db_session: Session, -) -> InputPrompt: - input_prompt = InputPrompt( - prompt=prompt, - content=content, - active=True, - is_public=is_public or user is None, - user_id=user.id if user is not None else None, - ) - db_session.add(input_prompt) - db_session.commit() - - return input_prompt - - -def update_input_prompt( - user: User | None, - input_prompt_id: int, - prompt: str, - content: str, - active: bool, - db_session: Session, -) -> InputPrompt: - input_prompt = db_session.scalar( - select(InputPrompt).where(InputPrompt.id == input_prompt_id) - ) - if input_prompt is None: - raise ValueError(f"No input prompt with id {input_prompt_id}") - - if not validate_user_prompt_authorization(user, input_prompt): - raise HTTPException(status_code=401, detail="You don't own this prompt") - - input_prompt.prompt = prompt - input_prompt.content = content - input_prompt.active = active - - db_session.commit() - return input_prompt - - -def validate_user_prompt_authorization( - user: User | None, input_prompt: InputPrompt -) -> bool: - prompt = InputPromptSnapshot.from_model(input_prompt=input_prompt) - - if prompt.user_id is not None: - if user is None: - return False - - user_details = UserInfo.from_model(user) - if str(user_details.id) != str(prompt.user_id): - return False - return True - - -def remove_public_input_prompt(input_prompt_id: int, db_session: Session) -> None: - input_prompt = db_session.scalar( - select(InputPrompt).where(InputPrompt.id == input_prompt_id) - ) - - if input_prompt is None: - raise ValueError(f"No input prompt with id {input_prompt_id}") - - if not input_prompt.is_public: - raise HTTPException(status_code=400, detail="This prompt is not public") - - db_session.delete(input_prompt) - db_session.commit() - - -def remove_input_prompt( - user: User | None, input_prompt_id: int, db_session: Session -) -> None: - input_prompt = db_session.scalar( - select(InputPrompt).where(InputPrompt.id == input_prompt_id) - ) - if input_prompt is None: - raise ValueError(f"No input prompt with id {input_prompt_id}") - - if input_prompt.is_public: - raise HTTPException( - status_code=400, detail="Cannot delete public prompts with this method" - ) - - if not validate_user_prompt_authorization(user, input_prompt): - raise HTTPException(status_code=401, detail="You do not own this prompt") - - db_session.delete(input_prompt) - db_session.commit() - - -def fetch_input_prompt_by_id( - id: int, user_id: UUID | None, db_session: Session -) -> InputPrompt: - query = select(InputPrompt).where(InputPrompt.id == id) - - if user_id: - query = query.where( - (InputPrompt.user_id == user_id) | (InputPrompt.user_id is None) - ) - else: - # If no user_id is provided, only fetch prompts without a user_id (aka public) - query = query.where(InputPrompt.user_id == None) # noqa - - result = db_session.scalar(query) - - if result is None: - raise HTTPException(422, "No input prompt found") - - return result - - -def fetch_public_input_prompts( - db_session: Session, -) -> list[InputPrompt]: - query = select(InputPrompt).where(InputPrompt.is_public) - return list(db_session.scalars(query).all()) - - -def fetch_input_prompts_by_user( - db_session: Session, - user_id: UUID | None, - active: bool | None = None, - include_public: bool = False, -) -> list[InputPrompt]: - query = select(InputPrompt) - - if user_id is not None: - if include_public: - query = query.where( - (InputPrompt.user_id == user_id) | InputPrompt.is_public - ) - else: - query = query.where(InputPrompt.user_id == user_id) - - elif include_public: - query = query.where(InputPrompt.is_public) - - if active is not None: - query = query.where(InputPrompt.active == active) - - return list(db_session.scalars(query).all()) diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 3cae55a9c66..2fd40dfce4e 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -1,6 +1,5 @@ import datetime import json -from enum import Enum as PyEnum from typing import Any from typing import Literal from typing import NotRequired @@ -126,6 +125,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base): # if specified, controls the assistants that are shown to the user + their order # if not specified, all assistants are shown + auto_scroll: Mapped[bool] = mapped_column(Boolean, default=True) chosen_assistants: Mapped[list[int] | None] = mapped_column( postgresql.JSONB(), nullable=True, default=None ) @@ -159,9 +159,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base): ) prompts: Mapped[list["Prompt"]] = relationship("Prompt", back_populates="user") - input_prompts: Mapped[list["InputPrompt"]] = relationship( - "InputPrompt", back_populates="user" - ) # Personas owned by this user personas: Mapped[list["Persona"]] = relationship("Persona", back_populates="user") @@ -178,31 +175,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base): ) -class InputPrompt(Base): - __tablename__ = "inputprompt" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - prompt: Mapped[str] = mapped_column(String) - content: Mapped[str] = mapped_column(String) - active: Mapped[bool] = mapped_column(Boolean) - user: Mapped[User | None] = relationship("User", back_populates="input_prompts") - is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) - user_id: Mapped[UUID | None] = mapped_column( - ForeignKey("user.id", ondelete="CASCADE"), nullable=True - ) - - -class InputPrompt__User(Base): - __tablename__ = "inputprompt__user" - - input_prompt_id: Mapped[int] = mapped_column( - ForeignKey("inputprompt.id"), primary_key=True - ) - user_id: Mapped[UUID | None] = mapped_column( - ForeignKey("inputprompt.id"), primary_key=True - ) - - class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base): pass @@ -596,6 +568,25 @@ class Connector(Base): list["DocumentByConnectorCredentialPair"] ] = relationship("DocumentByConnectorCredentialPair", back_populates="connector") + # synchronize this validation logic with RefreshFrequencySchema etc on front end + # until we have a centralized validation schema + + # TODO(rkuo): experiment with SQLAlchemy validators rather than manual checks + # https://docs.sqlalchemy.org/en/20/orm/mapped_attributes.html + def validate_refresh_freq(self) -> None: + if self.refresh_freq is not None: + if self.refresh_freq < 60: + raise ValueError( + "refresh_freq must be greater than or equal to 60 seconds." + ) + + def validate_prune_freq(self) -> None: + if self.prune_freq is not None: + if self.prune_freq < 86400: + raise ValueError( + "prune_freq must be greater than or equal to 86400 seconds." + ) + class Credential(Base): __tablename__ = "credential" @@ -963,9 +954,8 @@ class ChatSession(Base): persona_id: Mapped[int | None] = mapped_column( ForeignKey("persona.id"), nullable=True ) - description: Mapped[str] = mapped_column(Text) - # One-shot direct answering, currently the two types of chats are not mixed - one_shot: Mapped[bool] = mapped_column(Boolean, default=False) + description: Mapped[str | None] = mapped_column(Text, nullable=True) + # This chat created by DanswerBot danswerbot_flow: Mapped[bool] = mapped_column(Boolean, default=False) # Only ever set to True if system is set to not hard-delete chats deleted: Mapped[bool] = mapped_column(Boolean, default=False) @@ -1484,18 +1474,16 @@ class ChannelConfig(TypedDict): # If None then no follow up # If empty list, follow up with no tags follow_up_tags: NotRequired[list[str]] - - -class SlackBotResponseType(str, PyEnum): - QUOTES = "quotes" - CITATIONS = "citations" + show_continue_in_web_ui: NotRequired[bool] # defaults to False class SlackChannelConfig(Base): __tablename__ = "slack_channel_config" id: Mapped[int] = mapped_column(primary_key=True) - slack_bot_id: Mapped[int] = mapped_column(ForeignKey("slack_bot.id"), nullable=True) + slack_bot_id: Mapped[int] = mapped_column( + ForeignKey("slack_bot.id"), nullable=False + ) persona_id: Mapped[int | None] = mapped_column( ForeignKey("persona.id"), nullable=True ) @@ -1503,9 +1491,6 @@ class SlackChannelConfig(Base): channel_config: Mapped[ChannelConfig] = mapped_column( postgresql.JSONB(), nullable=False ) - response_type: Mapped[SlackBotResponseType] = mapped_column( - Enum(SlackBotResponseType, native_enum=False), nullable=False - ) enable_auto_filters: Mapped[bool] = mapped_column( Boolean, nullable=False, default=False @@ -1536,6 +1521,7 @@ class SlackBot(Base): slack_channel_configs: Mapped[list[SlackChannelConfig]] = relationship( "SlackChannelConfig", back_populates="slack_bot", + cascade="all, delete-orphan", ) diff --git a/backend/danswer/db/persona.py b/backend/danswer/db/persona.py index 0710c399811..ee97885b376 100644 --- a/backend/danswer/db/persona.py +++ b/backend/danswer/db/persona.py @@ -113,6 +113,31 @@ def fetch_persona_by_id( return persona +def get_best_persona_id_for_user( + db_session: Session, user: User | None, persona_id: int | None = None +) -> int | None: + if persona_id is not None: + stmt = select(Persona).where(Persona.id == persona_id).distinct() + stmt = _add_user_filters( + stmt=stmt, + user=user, + # We don't want to filter by editable here, we just want to see if the + # persona is usable by the user + get_editable=False, + ) + persona = db_session.scalars(stmt).one_or_none() + if persona: + return persona.id + + # If the persona is not found, or the slack bot is using doc sets instead of personas, + # we need to find the best persona for the user + # This is the persona with the highest display priority that the user has access to + stmt = select(Persona).order_by(Persona.display_priority.desc()).distinct() + stmt = _add_user_filters(stmt=stmt, user=user, get_editable=True) + persona = db_session.scalars(stmt).one_or_none() + return persona.id if persona else None + + def _get_persona_by_name( persona_name: str, user: User | None, db_session: Session ) -> Persona | None: @@ -390,9 +415,6 @@ def upsert_prompt( return prompt -# NOTE: This operation cannot update persona configuration options that -# are core to the persona, such as its display priority and -# whether or not the assistant is a built-in / default assistant def upsert_persona( user: User | None, name: str, @@ -424,10 +446,16 @@ def upsert_persona( chunks_above: int = CONTEXT_CHUNKS_ABOVE, chunks_below: int = CONTEXT_CHUNKS_BELOW, ) -> Persona: + """ + NOTE: This operation cannot update persona configuration options that + are core to the persona, such as its display priority and + whether or not the assistant is a built-in / default assistant + """ + if persona_id is not None: - persona = db_session.query(Persona).filter_by(id=persona_id).first() + existing_persona = db_session.query(Persona).filter_by(id=persona_id).first() else: - persona = _get_persona_by_name( + existing_persona = _get_persona_by_name( persona_name=name, user=user, db_session=db_session ) @@ -453,57 +481,78 @@ def upsert_persona( prompts = None if prompt_ids is not None: prompts = db_session.query(Prompt).filter(Prompt.id.in_(prompt_ids)).all() - if not prompts and prompt_ids: - raise ValueError("prompts not found") + + if prompts is not None and len(prompts) == 0: + raise ValueError( + f"Invalid Persona config, no valid prompts " + f"specified. Specified IDs were: '{prompt_ids}'" + ) # ensure all specified tools are valid if tools: validate_persona_tools(tools) - if persona: - if persona.builtin_persona and not builtin_persona: + if existing_persona: + # Built-in personas can only be updated through YAML configuration. + # This ensures that core system personas are not modified unintentionally. + if existing_persona.builtin_persona and not builtin_persona: raise ValueError("Cannot update builtin persona with non-builtin.") # this checks if the user has permission to edit the persona - persona = fetch_persona_by_id( - db_session=db_session, persona_id=persona.id, user=user, get_editable=True + # will raise an Exception if the user does not have permission + existing_persona = fetch_persona_by_id( + db_session=db_session, + persona_id=existing_persona.id, + user=user, + get_editable=True, ) - persona.name = name - persona.description = description - persona.num_chunks = num_chunks - persona.chunks_above = chunks_above - persona.chunks_below = chunks_below - persona.llm_relevance_filter = llm_relevance_filter - persona.llm_filter_extraction = llm_filter_extraction - persona.recency_bias = recency_bias - persona.llm_model_provider_override = llm_model_provider_override - persona.llm_model_version_override = llm_model_version_override - persona.starter_messages = starter_messages - persona.deleted = False # Un-delete if previously deleted - persona.is_public = is_public - persona.icon_color = icon_color - persona.icon_shape = icon_shape + # The following update excludes `default`, `built-in`, and display priority. + # Display priority is handled separately in the `display-priority` endpoint. + # `default` and `built-in` properties can only be set when creating a persona. + existing_persona.name = name + existing_persona.description = description + existing_persona.num_chunks = num_chunks + existing_persona.chunks_above = chunks_above + existing_persona.chunks_below = chunks_below + existing_persona.llm_relevance_filter = llm_relevance_filter + existing_persona.llm_filter_extraction = llm_filter_extraction + existing_persona.recency_bias = recency_bias + existing_persona.llm_model_provider_override = llm_model_provider_override + existing_persona.llm_model_version_override = llm_model_version_override + existing_persona.starter_messages = starter_messages + existing_persona.deleted = False # Un-delete if previously deleted + existing_persona.is_public = is_public + existing_persona.icon_color = icon_color + existing_persona.icon_shape = icon_shape if remove_image or uploaded_image_id: - persona.uploaded_image_id = uploaded_image_id - persona.is_visible = is_visible - persona.search_start_date = search_start_date - persona.category_id = category_id + existing_persona.uploaded_image_id = uploaded_image_id + existing_persona.is_visible = is_visible + existing_persona.search_start_date = search_start_date + existing_persona.category_id = category_id # Do not delete any associations manually added unless # a new updated list is provided if document_sets is not None: - persona.document_sets.clear() - persona.document_sets = document_sets or [] + existing_persona.document_sets.clear() + existing_persona.document_sets = document_sets or [] if prompts is not None: - persona.prompts.clear() - persona.prompts = prompts or [] + existing_persona.prompts.clear() + existing_persona.prompts = prompts if tools is not None: - persona.tools = tools or [] + existing_persona.tools = tools or [] + + persona = existing_persona else: - persona = Persona( + if not prompts: + raise ValueError( + "Invalid Persona config. " + "Must specify at least one prompt for a new persona." + ) + + new_persona = Persona( id=persona_id, user_id=user.id if user else None, is_public=is_public, @@ -516,7 +565,7 @@ def upsert_persona( llm_filter_extraction=llm_filter_extraction, recency_bias=recency_bias, builtin_persona=builtin_persona, - prompts=prompts or [], + prompts=prompts, document_sets=document_sets or [], llm_model_provider_override=llm_model_provider_override, llm_model_version_override=llm_model_version_override, @@ -531,8 +580,8 @@ def upsert_persona( is_default_persona=is_default_persona, category_id=category_id, ) - db_session.add(persona) - + db_session.add(new_persona) + persona = new_persona if commit: db_session.commit() else: diff --git a/backend/danswer/db/search_settings.py b/backend/danswer/db/search_settings.py index 4f437eaae53..1134b326a76 100644 --- a/backend/danswer/db/search_settings.py +++ b/backend/danswer/db/search_settings.py @@ -143,6 +143,25 @@ def get_secondary_search_settings(db_session: Session) -> SearchSettings | None: return latest_settings +def get_active_search_settings(db_session: Session) -> list[SearchSettings]: + """Returns active search settings. The first entry will always be the current search + settings. If there are new search settings that are being migrated to, those will be + the second entry.""" + search_settings_list: list[SearchSettings] = [] + + # Get the primary search settings + primary_search_settings = get_current_search_settings(db_session) + search_settings_list.append(primary_search_settings) + + # Check for secondary search settings + secondary_search_settings = get_secondary_search_settings(db_session) + if secondary_search_settings is not None: + # If secondary settings exist, add them to the list + search_settings_list.append(secondary_search_settings) + + return search_settings_list + + def get_all_search_settings(db_session: Session) -> list[SearchSettings]: query = select(SearchSettings).order_by(SearchSettings.id.desc()) result = db_session.execute(query) diff --git a/backend/danswer/db/slack_channel_config.py b/backend/danswer/db/slack_channel_config.py index 00e5965120a..d41d74c31c6 100644 --- a/backend/danswer/db/slack_channel_config.py +++ b/backend/danswer/db/slack_channel_config.py @@ -10,7 +10,6 @@ from danswer.db.models import ChannelConfig from danswer.db.models import Persona from danswer.db.models import Persona__DocumentSet -from danswer.db.models import SlackBotResponseType from danswer.db.models import SlackChannelConfig from danswer.db.models import User from danswer.db.persona import get_default_prompt @@ -83,7 +82,6 @@ def insert_slack_channel_config( slack_bot_id: int, persona_id: int | None, channel_config: ChannelConfig, - response_type: SlackBotResponseType, standard_answer_category_ids: list[int], enable_auto_filters: bool, ) -> SlackChannelConfig: @@ -115,7 +113,6 @@ def insert_slack_channel_config( slack_bot_id=slack_bot_id, persona_id=persona_id, channel_config=channel_config, - response_type=response_type, standard_answer_categories=existing_standard_answer_categories, enable_auto_filters=enable_auto_filters, ) @@ -130,7 +127,6 @@ def update_slack_channel_config( slack_channel_config_id: int, persona_id: int | None, channel_config: ChannelConfig, - response_type: SlackBotResponseType, standard_answer_category_ids: list[int], enable_auto_filters: bool, ) -> SlackChannelConfig: @@ -170,7 +166,6 @@ def update_slack_channel_config( # will encounter `violates foreign key constraint` errors slack_channel_config.persona_id = persona_id slack_channel_config.channel_config = channel_config - slack_channel_config.response_type = response_type slack_channel_config.standard_answer_categories = list( existing_standard_answer_categories ) diff --git a/backend/danswer/document_index/vespa/app_config/schemas/danswer_chunk.sd b/backend/danswer/document_index/vespa/app_config/schemas/danswer_chunk.sd index e712266fa08..8789a0534e7 100644 --- a/backend/danswer/document_index/vespa/app_config/schemas/danswer_chunk.sd +++ b/backend/danswer/document_index/vespa/app_config/schemas/danswer_chunk.sd @@ -4,6 +4,8 @@ schema DANSWER_CHUNK_NAME { # Not to be confused with the UUID generated for this chunk which is called documentid by default field document_id type string { indexing: summary | attribute + attribute: fast-search + rank: filter } field chunk_id type int { indexing: summary | attribute diff --git a/backend/danswer/file_processing/extract_file_text.py b/backend/danswer/file_processing/extract_file_text.py index 9effad5b4e0..58016e80d63 100644 --- a/backend/danswer/file_processing/extract_file_text.py +++ b/backend/danswer/file_processing/extract_file_text.py @@ -6,6 +6,7 @@ from collections.abc import Callable from collections.abc import Iterator from email.parser import Parser as EmailParser +from io import BytesIO from pathlib import Path from typing import Any from typing import Dict @@ -15,13 +16,17 @@ import docx # type: ignore import openpyxl # type: ignore import pptx # type: ignore +from docx import Document +from fastapi import UploadFile from pypdf import PdfReader from pypdf.errors import PdfStreamError from danswer.configs.constants import DANSWER_METADATA_FILENAME +from danswer.configs.constants import FileOrigin from danswer.file_processing.html_utils import parse_html_page_basic from danswer.file_processing.unstructured import get_unstructured_api_key from danswer.file_processing.unstructured import unstructured_to_text +from danswer.file_store.file_store import FileStore from danswer.utils.logger import setup_logger logger = setup_logger() @@ -65,7 +70,7 @@ def get_file_ext(file_path_or_name: str | Path) -> str: return extension -def check_file_ext_is_valid(ext: str) -> bool: +def is_valid_file_ext(ext: str) -> bool: return ext in VALID_FILE_EXTENSIONS @@ -295,7 +300,7 @@ def pptx_to_text(file: IO[Any]) -> str: def xlsx_to_text(file: IO[Any]) -> str: - workbook = openpyxl.load_workbook(file) + workbook = openpyxl.load_workbook(file, read_only=True) text_content = [] for sheet in workbook.worksheets: sheet_string = "\n".join( @@ -359,7 +364,7 @@ def extract_file_text( elif file_name is not None: final_extension = get_file_ext(file_name) - if check_file_ext_is_valid(final_extension): + if is_valid_file_ext(final_extension): return extension_to_function.get(final_extension, file_io_to_text)(file) # Either the file somehow has no name or the extension is not one that we recognize @@ -375,3 +380,35 @@ def extract_file_text( ) from e logger.warning(f"Failed to process file {file_name or 'Unknown'}: {str(e)}") return "" + + +def convert_docx_to_txt( + file: UploadFile, file_store: FileStore, file_path: str +) -> None: + file.file.seek(0) + docx_content = file.file.read() + doc = Document(BytesIO(docx_content)) + + # Extract text from the document + full_text = [] + for para in doc.paragraphs: + full_text.append(para.text) + + # Join the extracted text + text_content = "\n".join(full_text) + + txt_file_path = docx_to_txt_filename(file_path) + file_store.save_file( + file_name=txt_file_path, + content=BytesIO(text_content.encode("utf-8")), + display_name=file.filename, + file_origin=FileOrigin.CONNECTOR, + file_type="text/plain", + ) + + +def docx_to_txt_filename(file_path: str) -> str: + """ + Convert a .docx file path to its corresponding .txt file path. + """ + return file_path.rsplit(".", 1)[0] + ".txt" diff --git a/backend/danswer/file_store/file_store.py b/backend/danswer/file_store/file_store.py index 9bc4c41d361..e57b9222a1b 100644 --- a/backend/danswer/file_store/file_store.py +++ b/backend/danswer/file_store/file_store.py @@ -59,6 +59,12 @@ def read_file( Contents of the file and metadata dict """ + @abstractmethod + def read_file_record(self, file_name: str) -> PGFileStore: + """ + Read the file record by the name + """ + @abstractmethod def delete_file(self, file_name: str) -> None: """ diff --git a/backend/danswer/file_store/utils.py b/backend/danswer/file_store/utils.py index e9eea2c262d..978bb92e6be 100644 --- a/backend/danswer/file_store/utils.py +++ b/backend/danswer/file_store/utils.py @@ -1,6 +1,6 @@ +import base64 from collections.abc import Callable from io import BytesIO -from typing import Any from typing import cast from uuid import uuid4 @@ -13,8 +13,8 @@ from danswer.file_store.file_store import get_default_file_store from danswer.file_store.models import FileDescriptor from danswer.file_store.models import InMemoryChatFile +from danswer.utils.b64 import get_image_type from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel -from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR def load_chat_file( @@ -75,11 +75,58 @@ def save_file_from_url(url: str, tenant_id: str) -> str: return unique_id -def save_files_from_urls(urls: list[str]) -> list[str]: - tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() +def save_file_from_base64(base64_string: str, tenant_id: str) -> str: + with get_session_with_tenant(tenant_id) as db_session: + unique_id = str(uuid4()) + file_store = get_default_file_store(db_session) + file_store.save_file( + file_name=unique_id, + content=BytesIO(base64.b64decode(base64_string)), + display_name="GeneratedImage", + file_origin=FileOrigin.CHAT_IMAGE_GEN, + file_type=get_image_type(base64_string), + ) + return unique_id - funcs: list[tuple[Callable[..., Any], tuple[Any, ...]]] = [ - (save_file_from_url, (url, tenant_id)) for url in urls + +def save_file( + tenant_id: str, + url: str | None = None, + base64_data: str | None = None, +) -> str: + """Save a file from either a URL or base64 encoded string. + + Args: + tenant_id: The tenant ID to save the file under + url: URL to download file from + base64_data: Base64 encoded file data + + Returns: + The unique ID of the saved file + + Raises: + ValueError: If neither url nor base64_data is provided, or if both are provided + """ + if url is not None and base64_data is not None: + raise ValueError("Cannot specify both url and base64_data") + + if url is not None: + return save_file_from_url(url, tenant_id) + elif base64_data is not None: + return save_file_from_base64(base64_data, tenant_id) + else: + raise ValueError("Must specify either url or base64_data") + + +def save_files(urls: list[str], base64_files: list[str], tenant_id: str) -> list[str]: + # NOTE: be explicit about typing so that if we change things, we get notified + funcs: list[ + tuple[ + Callable[[str, str | None, str | None], str], + tuple[str, str | None, str | None], + ] + ] = [(save_file, (tenant_id, url, None)) for url in urls] + [ + (save_file, (tenant_id, None, base64_file)) for base64_file in base64_files ] - # Must pass in tenant_id here, since this is called by multithreading + return run_functions_tuples_in_parallel(funcs) diff --git a/backend/danswer/indexing/indexing_pipeline.py b/backend/danswer/indexing/indexing_pipeline.py index b1ee8f4d944..bace61cec80 100644 --- a/backend/danswer/indexing/indexing_pipeline.py +++ b/backend/danswer/indexing/indexing_pipeline.py @@ -1,4 +1,5 @@ import traceback +from collections.abc import Callable from functools import partial from http import HTTPStatus from typing import Protocol @@ -12,6 +13,7 @@ from danswer.access.models import DocumentAccess from danswer.configs.app_configs import ENABLE_MULTIPASS_INDEXING from danswer.configs.app_configs import INDEXING_EXCEPTION_LIMIT +from danswer.configs.app_configs import MAX_DOCUMENT_CHARS from danswer.configs.constants import DEFAULT_BOOST from danswer.connectors.cross_connector_utils.miscellaneous_utils import ( get_experts_stores_representations, @@ -202,40 +204,13 @@ def index_doc_batch_with_handler( def index_doc_batch_prepare( - document_batch: list[Document], + documents: list[Document], index_attempt_metadata: IndexAttemptMetadata, db_session: Session, ignore_time_skip: bool = False, ) -> DocumentBatchPrepareContext | None: """Sets up the documents in the relational DB (source of truth) for permissions, metadata, etc. This preceeds indexing it into the actual document index.""" - documents: list[Document] = [] - for document in document_batch: - empty_contents = not any(section.text.strip() for section in document.sections) - if ( - (not document.title or not document.title.strip()) - and not document.semantic_identifier.strip() - and empty_contents - ): - # Skip documents that have neither title nor content - # If the document doesn't have either, then there is no useful information in it - # This is again verified later in the pipeline after chunking but at that point there should - # already be no documents that are empty. - logger.warning( - f"Skipping document with ID {document.id} as it has neither title nor content." - ) - continue - - if document.title is not None and not document.title.strip() and empty_contents: - # The title is explicitly empty ("" and not None) and the document is empty - # so when building the chunk text representation, it will be empty and unuseable - logger.warning( - f"Skipping document with ID {document.id} as the chunks will be empty." - ) - continue - - documents.append(document) - # Create a trimmed list of docs that don't have a newer updated at # Shortcuts the time-consuming flow on connector index retries document_ids: list[str] = [document.id for document in documents] @@ -282,17 +257,64 @@ def index_doc_batch_prepare( ) +def filter_documents(document_batch: list[Document]) -> list[Document]: + documents: list[Document] = [] + for document in document_batch: + empty_contents = not any(section.text.strip() for section in document.sections) + if ( + (not document.title or not document.title.strip()) + and not document.semantic_identifier.strip() + and empty_contents + ): + # Skip documents that have neither title nor content + # If the document doesn't have either, then there is no useful information in it + # This is again verified later in the pipeline after chunking but at that point there should + # already be no documents that are empty. + logger.warning( + f"Skipping document with ID {document.id} as it has neither title nor content." + ) + continue + + if document.title is not None and not document.title.strip() and empty_contents: + # The title is explicitly empty ("" and not None) and the document is empty + # so when building the chunk text representation, it will be empty and unuseable + logger.warning( + f"Skipping document with ID {document.id} as the chunks will be empty." + ) + continue + + section_chars = sum(len(section.text) for section in document.sections) + if ( + MAX_DOCUMENT_CHARS + and len(document.title or document.semantic_identifier) + section_chars + > MAX_DOCUMENT_CHARS + ): + # Skip documents that are too long, later on there are more memory intensive steps done on the text + # and the container will run out of memory and crash. Several other checks are included upstream but + # those are at the connector level so a catchall is still needed. + # Assumption here is that files that are that long, are generated files and not the type users + # generally care for. + logger.warning( + f"Skipping document with ID {document.id} as it is too long." + ) + continue + + documents.append(document) + return documents + + @log_function_time(debug_only=True) def index_doc_batch( *, + document_batch: list[Document], chunker: Chunker, embedder: IndexingEmbedder, document_index: DocumentIndex, - document_batch: list[Document], index_attempt_metadata: IndexAttemptMetadata, db_session: Session, ignore_time_skip: bool = False, tenant_id: str | None = None, + filter_fnc: Callable[[list[Document]], list[Document]] = filter_documents, ) -> tuple[int, int]: """Takes different pieces of the indexing pipeline and applies it to a batch of documents Note that the documents should already be batched at this point so that it does not inflate the @@ -309,8 +331,11 @@ def index_doc_batch( is_public=False, ) + logger.debug("Filtering Documents") + filtered_documents = filter_fnc(document_batch) + ctx = index_doc_batch_prepare( - document_batch=document_batch, + documents=filtered_documents, index_attempt_metadata=index_attempt_metadata, ignore_time_skip=ignore_time_skip, db_session=db_session, diff --git a/backend/danswer/llm/answering/models.py b/backend/danswer/llm/answering/models.py deleted file mode 100644 index 03f72a0968c..00000000000 --- a/backend/danswer/llm/answering/models.py +++ /dev/null @@ -1,163 +0,0 @@ -from collections.abc import Callable -from collections.abc import Iterator -from typing import TYPE_CHECKING - -from langchain.schema.messages import AIMessage -from langchain.schema.messages import BaseMessage -from langchain.schema.messages import HumanMessage -from langchain.schema.messages import SystemMessage -from pydantic import BaseModel -from pydantic import ConfigDict -from pydantic import Field -from pydantic import model_validator - -from danswer.chat.models import AnswerQuestionStreamReturn -from danswer.configs.constants import MessageType -from danswer.file_store.models import InMemoryChatFile -from danswer.llm.override_models import PromptOverride -from danswer.llm.utils import build_content_with_imgs -from danswer.tools.models import ToolCallFinalResult - -if TYPE_CHECKING: - from danswer.db.models import ChatMessage - from danswer.db.models import Prompt - - -StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn] - - -class PreviousMessage(BaseModel): - """Simplified version of `ChatMessage`""" - - message: str - token_count: int - message_type: MessageType - files: list[InMemoryChatFile] - tool_call: ToolCallFinalResult | None - - @classmethod - def from_chat_message( - cls, chat_message: "ChatMessage", available_files: list[InMemoryChatFile] - ) -> "PreviousMessage": - message_file_ids = ( - [file["id"] for file in chat_message.files] if chat_message.files else [] - ) - return cls( - message=chat_message.message, - token_count=chat_message.token_count, - message_type=chat_message.message_type, - files=[ - file - for file in available_files - if str(file.file_id) in message_file_ids - ], - tool_call=ToolCallFinalResult( - tool_name=chat_message.tool_call.tool_name, - tool_args=chat_message.tool_call.tool_arguments, - tool_result=chat_message.tool_call.tool_result, - ) - if chat_message.tool_call - else None, - ) - - def to_langchain_msg(self) -> BaseMessage: - content = build_content_with_imgs(self.message, self.files) - if self.message_type == MessageType.USER: - return HumanMessage(content=content) - elif self.message_type == MessageType.ASSISTANT: - return AIMessage(content=content) - else: - return SystemMessage(content=content) - - -class DocumentPruningConfig(BaseModel): - max_chunks: int | None = None - max_window_percentage: float | None = None - max_tokens: int | None = None - # different pruning behavior is expected when the - # user manually selects documents they want to chat with - # e.g. we don't want to truncate each document to be no more - # than one chunk long - is_manually_selected_docs: bool = False - # If user specifies to include additional context Chunks for each match, then different pruning - # is used. As many Sections as possible are included, and the last Section is truncated - # If this is false, all of the Sections are truncated if they are longer than the expected Chunk size. - # Sections are often expected to be longer than the maximum Chunk size but Chunks should not be. - use_sections: bool = True - # If using tools, then we need to consider the tool length - tool_num_tokens: int = 0 - # If using a tool message to represent the docs, then we have to JSON serialize - # the document content, which adds to the token count. - using_tool_message: bool = False - - -class ContextualPruningConfig(DocumentPruningConfig): - num_chunk_multiple: int - - @classmethod - def from_doc_pruning_config( - cls, num_chunk_multiple: int, doc_pruning_config: DocumentPruningConfig - ) -> "ContextualPruningConfig": - return cls(num_chunk_multiple=num_chunk_multiple, **doc_pruning_config.dict()) - - -class CitationConfig(BaseModel): - all_docs_useful: bool = False - - -class QuotesConfig(BaseModel): - pass - - -class AnswerStyleConfig(BaseModel): - citation_config: CitationConfig | None = None - quotes_config: QuotesConfig | None = None - document_pruning_config: DocumentPruningConfig = Field( - default_factory=DocumentPruningConfig - ) - # forces the LLM to return a structured response, see - # https://platform.openai.com/docs/guides/structured-outputs/introduction - # right now, only used by the simple chat API - structured_response_format: dict | None = None - - @model_validator(mode="after") - def check_quotes_and_citation(self) -> "AnswerStyleConfig": - if self.citation_config is None and self.quotes_config is None: - raise ValueError( - "One of `citation_config` or `quotes_config` must be provided" - ) - - if self.citation_config is not None and self.quotes_config is not None: - raise ValueError( - "Only one of `citation_config` or `quotes_config` must be provided" - ) - - return self - - -class PromptConfig(BaseModel): - """Final representation of the Prompt configuration passed - into the `Answer` object.""" - - system_prompt: str - task_prompt: str - datetime_aware: bool - include_citations: bool - - @classmethod - def from_model( - cls, model: "Prompt", prompt_override: PromptOverride | None = None - ) -> "PromptConfig": - override_system_prompt = ( - prompt_override.system_prompt if prompt_override else None - ) - override_task_prompt = prompt_override.task_prompt if prompt_override else None - - return cls( - system_prompt=override_system_prompt or model.system_prompt, - task_prompt=override_task_prompt or model.task_prompt, - datetime_aware=model.datetime_aware, - include_citations=model.include_citations, - ) - - model_config = ConfigDict(frozen=True) diff --git a/backend/danswer/llm/answering/prompts/utils.py b/backend/danswer/llm/answering/prompts/utils.py deleted file mode 100644 index bcc8b891815..00000000000 --- a/backend/danswer/llm/answering/prompts/utils.py +++ /dev/null @@ -1,20 +0,0 @@ -from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT -from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT - - -def build_dummy_prompt( - system_prompt: str, task_prompt: str, retrieval_disabled: bool -) -> str: - if retrieval_disabled: - return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format( - user_query="", - system_prompt=system_prompt, - task_prompt=task_prompt, - ).strip() - - return PARAMATERIZED_PROMPT.format( - context_docs_str="", - user_query="", - system_prompt=system_prompt, - task_prompt=task_prompt, - ).strip() diff --git a/backend/danswer/llm/chat_llm.py b/backend/danswer/llm/chat_llm.py index 031fcd7163a..88b8f0396d5 100644 --- a/backend/danswer/llm/chat_llm.py +++ b/backend/danswer/llm/chat_llm.py @@ -26,7 +26,9 @@ from langchain_core.prompt_values import PromptValue from danswer.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS -from danswer.configs.model_configs import DISABLE_LITELLM_STREAMING +from danswer.configs.model_configs import ( + DISABLE_LITELLM_STREAMING, +) from danswer.configs.model_configs import GEN_AI_TEMPERATURE from danswer.configs.model_configs import LITELLM_EXTRA_BODY from danswer.llm.interfaces import LLM @@ -161,7 +163,9 @@ def _convert_delta_to_message_chunk( if role == "user": return HumanMessageChunk(content=content) - elif role == "assistant": + # NOTE: if tool calls are present, then it's an assistant. + # In Ollama, the role will be None for tool-calls + elif role == "assistant" or tool_calls: if tool_calls: tool_call = tool_calls[0] tool_name = tool_call.function.name or (curr_msg and curr_msg.name) or "" @@ -236,6 +240,7 @@ def __init__( custom_config: dict[str, str] | None = None, extra_headers: dict[str, str] | None = None, extra_body: dict | None = LITELLM_EXTRA_BODY, + model_kwargs: dict[str, Any] | None = None, long_term_logger: LongTermLogger | None = None, ): self._timeout = timeout @@ -263,12 +268,16 @@ def __init__( # NOTE: have to set these as environment variables for Litellm since # not all are able to passed in but they always support them set as env - # variables + # variables. We'll also try passing them in, since litellm just ignores + # addtional kwargs (and some kwargs MUST be passed in rather than set as + # env variables) if custom_config: for k, v in custom_config.items(): os.environ[k] = v - model_kwargs: dict[str, Any] = {} + model_kwargs = model_kwargs or {} + if custom_config: + model_kwargs.update(custom_config) if extra_headers: model_kwargs.update({"extra_headers": extra_headers}) if extra_body: diff --git a/backend/danswer/llm/factory.py b/backend/danswer/llm/factory.py index 9a2ae66d396..0b688a0cfcb 100644 --- a/backend/danswer/llm/factory.py +++ b/backend/danswer/llm/factory.py @@ -1,5 +1,9 @@ +from typing import Any + +from danswer.chat.models import PersonaOverrideConfig from danswer.configs.app_configs import DISABLE_GENERATIVE_AI from danswer.configs.chat_configs import QA_TIMEOUT +from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS from danswer.configs.model_configs import GEN_AI_TEMPERATURE from danswer.db.engine import get_session_context_manager from danswer.db.llm import fetch_default_provider @@ -10,8 +14,20 @@ from danswer.llm.interfaces import LLM from danswer.llm.override_models import LLMOverride from danswer.utils.headers import build_llm_extra_headers +from danswer.utils.logger import setup_logger from danswer.utils.long_term_log import LongTermLogger +logger = setup_logger() + + +def _build_extra_model_kwargs(provider: str) -> dict[str, Any]: + """Ollama requires us to specify the max context window. + + For now, just using the GEN_AI_MODEL_FALLBACK_MAX_TOKENS value. + TODO: allow model-specific values to be configured via the UI. + """ + return {"num_ctx": GEN_AI_MODEL_FALLBACK_MAX_TOKENS} if provider == "ollama" else {} + def get_main_llm_from_tuple( llms: tuple[LLM, LLM], @@ -20,11 +36,15 @@ def get_main_llm_from_tuple( def get_llms_for_persona( - persona: Persona, + persona: Persona | PersonaOverrideConfig | None, llm_override: LLMOverride | None = None, additional_headers: dict[str, str] | None = None, long_term_logger: LongTermLogger | None = None, ) -> tuple[LLM, LLM]: + if persona is None: + logger.warning("No persona provided, using default LLMs") + return get_default_llms() + model_provider_override = llm_override.model_provider if llm_override else None model_version_override = llm_override.model_version if llm_override else None temperature_override = llm_override.temperature if llm_override else None @@ -59,6 +79,7 @@ def _create_llm(model: str) -> LLM: api_base=llm_provider.api_base, api_version=llm_provider.api_version, custom_config=llm_provider.custom_config, + temperature=temperature_override, additional_headers=additional_headers, long_term_logger=long_term_logger, ) @@ -116,11 +137,13 @@ def get_llm( api_base: str | None = None, api_version: str | None = None, custom_config: dict[str, str] | None = None, - temperature: float = GEN_AI_TEMPERATURE, + temperature: float | None = None, timeout: int = QA_TIMEOUT, additional_headers: dict[str, str] | None = None, long_term_logger: LongTermLogger | None = None, ) -> LLM: + if temperature is None: + temperature = GEN_AI_TEMPERATURE return DefaultMultiLLM( model_provider=provider, model_name=model, @@ -132,5 +155,6 @@ def get_llm( temperature=temperature, custom_config=custom_config, extra_headers=build_llm_extra_headers(additional_headers), + model_kwargs=_build_extra_model_kwargs(provider), long_term_logger=long_term_logger, ) diff --git a/backend/danswer/llm/models.py b/backend/danswer/llm/models.py new file mode 100644 index 00000000000..182fc97fb26 --- /dev/null +++ b/backend/danswer/llm/models.py @@ -0,0 +1,59 @@ +from typing import TYPE_CHECKING + +from langchain.schema.messages import AIMessage +from langchain.schema.messages import BaseMessage +from langchain.schema.messages import HumanMessage +from langchain.schema.messages import SystemMessage +from pydantic import BaseModel + +from danswer.configs.constants import MessageType +from danswer.file_store.models import InMemoryChatFile +from danswer.llm.utils import build_content_with_imgs +from danswer.tools.models import ToolCallFinalResult + +if TYPE_CHECKING: + from danswer.db.models import ChatMessage + + +class PreviousMessage(BaseModel): + """Simplified version of `ChatMessage`""" + + message: str + token_count: int + message_type: MessageType + files: list[InMemoryChatFile] + tool_call: ToolCallFinalResult | None + + @classmethod + def from_chat_message( + cls, chat_message: "ChatMessage", available_files: list[InMemoryChatFile] + ) -> "PreviousMessage": + message_file_ids = ( + [file["id"] for file in chat_message.files] if chat_message.files else [] + ) + return cls( + message=chat_message.message, + token_count=chat_message.token_count, + message_type=chat_message.message_type, + files=[ + file + for file in available_files + if str(file.file_id) in message_file_ids + ], + tool_call=ToolCallFinalResult( + tool_name=chat_message.tool_call.tool_name, + tool_args=chat_message.tool_call.tool_arguments, + tool_result=chat_message.tool_call.tool_result, + ) + if chat_message.tool_call + else None, + ) + + def to_langchain_msg(self) -> BaseMessage: + content = build_content_with_imgs(self.message, self.files) + if self.message_type == MessageType.USER: + return HumanMessage(content=content) + elif self.message_type == MessageType.ASSISTANT: + return AIMessage(content=content) + else: + return SystemMessage(content=content) diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py index 58d7f5d5dee..55758d40e57 100644 --- a/backend/danswer/llm/utils.py +++ b/backend/danswer/llm/utils.py @@ -1,14 +1,11 @@ -import io +import copy import json from collections.abc import Callable from collections.abc import Iterator from typing import Any from typing import cast -from typing import TYPE_CHECKING -from typing import Union import litellm # type: ignore -import pandas as pd import tiktoken from langchain.prompts.base import StringPromptValue from langchain.prompts.chat import ChatPromptValue @@ -35,17 +32,15 @@ from danswer.configs.model_configs import GEN_AI_MAX_TOKENS from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS from danswer.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS -from danswer.db.models import ChatMessage from danswer.file_store.models import ChatFileType from danswer.file_store.models import InMemoryChatFile from danswer.llm.interfaces import LLM from danswer.prompts.constants import CODE_BLOCK_PAT +from danswer.utils.b64 import get_image_type +from danswer.utils.b64 import get_image_type_from_bytes from danswer.utils.logger import setup_logger from shared_configs.configs import LOG_LEVEL -if TYPE_CHECKING: - from danswer.llm.answering.models import PreviousMessage - logger = setup_logger() @@ -103,92 +98,39 @@ def litellm_exception_to_error_msg( return error_msg -def translate_danswer_msg_to_langchain( - msg: Union[ChatMessage, "PreviousMessage"], -) -> BaseMessage: - files: list[InMemoryChatFile] = [] - - # If the message is a `ChatMessage`, it doesn't have the downloaded files - # attached. Just ignore them for now. - if not isinstance(msg, ChatMessage): - files = msg.files - content = build_content_with_imgs(msg.message, files, message_type=msg.message_type) - - if msg.message_type == MessageType.SYSTEM: - raise ValueError("System messages are not currently part of history") - if msg.message_type == MessageType.ASSISTANT: - return AIMessage(content=content) - if msg.message_type == MessageType.USER: - return HumanMessage(content=content) - - raise ValueError(f"New message type {msg.message_type} not handled") - - -def translate_history_to_basemessages( - history: list[ChatMessage] | list["PreviousMessage"], -) -> tuple[list[BaseMessage], list[int]]: - history_basemessages = [ - translate_danswer_msg_to_langchain(msg) - for msg in history - if msg.token_count != 0 - ] - history_token_counts = [msg.token_count for msg in history if msg.token_count != 0] - return history_basemessages, history_token_counts - - -# Processes CSV files to show the first 5 rows and max_columns (default 40) columns -def _process_csv_file(file: InMemoryChatFile, max_columns: int = 40) -> str: - df = pd.read_csv(io.StringIO(file.content.decode("utf-8"))) - - csv_preview = df.head().to_string(max_cols=max_columns) - - file_name_section = ( - f"CSV FILE NAME: {file.filename}\n" - if file.filename - else "CSV FILE (NO NAME PROVIDED):\n" - ) - return f"{file_name_section}{CODE_BLOCK_PAT.format(csv_preview)}\n\n\n" - - def _build_content( message: str, files: list[InMemoryChatFile] | None = None, ) -> str: """Applies all non-image files.""" - text_files = ( - [file for file in files if file.file_type == ChatFileType.PLAIN_TEXT] - if files - else None - ) + if not files: + return message - csv_files = ( - [file for file in files if file.file_type == ChatFileType.CSV] - if files - else None - ) + text_files = [ + file + for file in files + if file.file_type in (ChatFileType.PLAIN_TEXT, ChatFileType.CSV) + ] - if not text_files and not csv_files: + if not text_files: return message final_message_with_files = "FILES:\n\n" - for file in text_files or []: + for file in text_files: file_content = file.content.decode("utf-8") file_name_section = f"DOCUMENT: {file.filename}\n" if file.filename else "" final_message_with_files += ( f"{file_name_section}{CODE_BLOCK_PAT.format(file_content.strip())}\n\n\n" ) - for file in csv_files or []: - final_message_with_files += _process_csv_file(file) - - final_message_with_files += message - return final_message_with_files + return final_message_with_files + message def build_content_with_imgs( message: str, files: list[InMemoryChatFile] | None = None, img_urls: list[str] | None = None, + b64_imgs: list[str] | None = None, message_type: MessageType = MessageType.USER, ) -> str | list[str | dict[str, Any]]: # matching Langchain's BaseMessage content type files = files or [] @@ -201,6 +143,7 @@ def build_content_with_imgs( ) img_urls = img_urls or [] + b64_imgs = b64_imgs or [] message_main_content = _build_content(message, files) @@ -219,11 +162,22 @@ def build_content_with_imgs( { "type": "image_url", "image_url": { - "url": f"data:image/jpeg;base64,{file.to_base64()}", + "url": ( + f"data:{get_image_type_from_bytes(file.content)};" + f"base64,{file.to_base64()}" + ), + }, + } + for file in img_files + ] + + [ + { + "type": "image_url", + "image_url": { + "url": f"data:{get_image_type(b64_img)};base64,{b64_img}", }, } - for file in files - if file.file_type == "image" + for b64_img in b64_imgs ] + [ { @@ -385,6 +339,62 @@ def test_llm(llm: LLM) -> str | None: return error_msg +def get_model_map() -> dict: + starting_map = copy.deepcopy(cast(dict, litellm.model_cost)) + + # NOTE: we could add additional models here in the future, + # but for now there is no point. Ollama allows the user to + # to specify their desired max context window, and it's + # unlikely to be standard across users even for the same model + # (it heavily depends on their hardware). For now, we'll just + # rely on GEN_AI_MODEL_FALLBACK_MAX_TOKENS to cover this. + # for model_name in [ + # "llama3.2", + # "llama3.2:1b", + # "llama3.2:3b", + # "llama3.2:11b", + # "llama3.2:90b", + # ]: + # starting_map[f"ollama/{model_name}"] = { + # "max_tokens": 128000, + # "max_input_tokens": 128000, + # "max_output_tokens": 128000, + # } + + return starting_map + + +def _strip_extra_provider_from_model_name(model_name: str) -> str: + return model_name.split("/")[1] if "/" in model_name else model_name + + +def _strip_colon_from_model_name(model_name: str) -> str: + return ":".join(model_name.split(":")[:-1]) if ":" in model_name else model_name + + +def _find_model_obj( + model_map: dict, provider: str, model_names: list[str | None] +) -> dict | None: + # Filter out None values and deduplicate model names + filtered_model_names = [name for name in model_names if name] + + # First try all model names with provider prefix + for model_name in filtered_model_names: + model_obj = model_map.get(f"{provider}/{model_name}") + if model_obj: + logger.debug(f"Using model object for {provider}/{model_name}") + return model_obj + + # Then try all model names without provider prefix + for model_name in filtered_model_names: + model_obj = model_map.get(model_name) + if model_obj: + logger.debug(f"Using model object for {model_name}") + return model_obj + + return None + + def get_llm_max_tokens( model_map: dict, model_name: str, @@ -404,22 +414,22 @@ def get_llm_max_tokens( return GEN_AI_MAX_TOKENS try: - model_obj = model_map.get(f"{model_provider}/{model_name}") - if model_obj: - logger.debug(f"Using model object for {model_provider}/{model_name}") - - if not model_obj: - model_obj = model_map.get(model_name) - if model_obj: - logger.debug(f"Using model object for {model_name}") - - if not model_obj: - model_name_split = model_name.split("/") - if len(model_name_split) > 1: - model_obj = model_map.get(model_name_split[1]) - if model_obj: - logger.debug(f"Using model object for {model_name_split[1]}") - + extra_provider_stripped_model_name = _strip_extra_provider_from_model_name( + model_name + ) + model_obj = _find_model_obj( + model_map, + model_provider, + [ + model_name, + # Remove leading extra provider. Usually for cases where user has a + # customer model proxy which appends another prefix + extra_provider_stripped_model_name, + # remove :XXXX from the end, if present. Needed for ollama. + _strip_colon_from_model_name(model_name), + _strip_colon_from_model_name(extra_provider_stripped_model_name), + ], + ) if not model_obj: raise RuntimeError( f"No litellm entry found for {model_provider}/{model_name}" @@ -495,7 +505,7 @@ def get_max_input_tokens( # `model_cost` dict is a named public interface: # https://litellm.vercel.app/docs/completion/token_usage#7-model_cost # model_map is litellm.model_cost - litellm_model_map = litellm.model_cost + litellm_model_map = get_model_map() input_toks = ( get_llm_max_tokens( diff --git a/backend/danswer/main.py b/backend/danswer/main.py index aa16fa20c5f..cfc490ddb47 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -25,7 +25,7 @@ from danswer.auth.schemas import UserRead from danswer.auth.schemas import UserUpdate from danswer.auth.users import auth_backend -from danswer.auth.users import BasicAuthenticationError +from danswer.auth.users import create_danswer_oauth_router from danswer.auth.users import fastapi_users from danswer.configs.app_configs import APP_API_PREFIX from danswer.configs.app_configs import APP_HOST @@ -44,6 +44,7 @@ from danswer.configs.constants import POSTGRES_WEB_APP_NAME from danswer.db.engine import SqlEngine from danswer.db.engine import warm_up_connections +from danswer.server.api_key.api import router as api_key_router from danswer.server.auth_check import check_router_auth from danswer.server.danswer_api.ingestion import router as danswer_api_router from danswer.server.documents.cc_pair import router as cc_pair_router @@ -51,12 +52,9 @@ from danswer.server.documents.credential import router as credential_router from danswer.server.documents.document import router as document_router from danswer.server.documents.indexing import router as indexing_router +from danswer.server.documents.standard_oauth import router as oauth_router from danswer.server.features.document_set.api import router as document_set_router from danswer.server.features.folder.api import router as folder_router -from danswer.server.features.input_prompt.api import ( - admin_router as admin_input_prompt_router, -) -from danswer.server.features.input_prompt.api import basic_router as input_prompt_router from danswer.server.features.notifications.api import router as notification_router from danswer.server.features.persona.api import admin_router as admin_persona_router from danswer.server.features.persona.api import basic_router as persona_router @@ -90,6 +88,7 @@ from danswer.server.token_rate_limits.api import ( router as token_rate_limit_settings_router, ) +from danswer.server.utils import BasicAuthenticationError from danswer.setup import setup_danswer from danswer.setup import setup_multitenant_danswer from danswer.utils.logger import setup_logger @@ -205,7 +204,7 @@ def log_http_error(_: Request, exc: Exception) -> JSONResponse: if isinstance(exc, BasicAuthenticationError): # For BasicAuthenticationError, just log a brief message without stack trace (almost always spam) - logger.error(f"Authentication failed: {str(exc)}") + logger.warning(f"Authentication failed: {str(exc)}") elif status_code >= 400: error_msg = f"{str(exc)}\n" @@ -258,8 +257,6 @@ def get_application() -> FastAPI: ) include_router_with_global_prefix_prepended(application, persona_router) include_router_with_global_prefix_prepended(application, admin_persona_router) - include_router_with_global_prefix_prepended(application, input_prompt_router) - include_router_with_global_prefix_prepended(application, admin_input_prompt_router) include_router_with_global_prefix_prepended(application, notification_router) include_router_with_global_prefix_prepended(application, prompt_router) include_router_with_global_prefix_prepended(application, tool_router) @@ -281,6 +278,8 @@ def get_application() -> FastAPI: application, get_full_openai_assistants_api_router() ) include_router_with_global_prefix_prepended(application, long_term_logs_router) + include_router_with_global_prefix_prepended(application, api_key_router) + include_router_with_global_prefix_prepended(application, oauth_router) include_router_with_global_prefix_prepended(application, eea_config_router) @@ -326,7 +325,7 @@ def get_application() -> FastAPI: oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET) include_router_with_global_prefix_prepended( application, - fastapi_users.get_oauth_router( + create_danswer_oauth_router( oauth_client, auth_backend, USER_AUTH_SECRET, diff --git a/backend/danswer/natural_language_processing/exceptions.py b/backend/danswer/natural_language_processing/exceptions.py new file mode 100644 index 00000000000..5ca112f64ea --- /dev/null +++ b/backend/danswer/natural_language_processing/exceptions.py @@ -0,0 +1,4 @@ +class ModelServerRateLimitError(Exception): + """ + Exception raised for rate limiting errors from the model server. + """ diff --git a/backend/danswer/natural_language_processing/search_nlp_models.py b/backend/danswer/natural_language_processing/search_nlp_models.py index f90f3bdec42..e0097f3e181 100644 --- a/backend/danswer/natural_language_processing/search_nlp_models.py +++ b/backend/danswer/natural_language_processing/search_nlp_models.py @@ -6,6 +6,9 @@ import requests from httpx import HTTPError +from requests import JSONDecodeError +from requests import RequestException +from requests import Response from retry import retry from danswer.configs.app_configs import LARGE_CHUNK_RATIO @@ -16,6 +19,9 @@ from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE from danswer.db.models import SearchSettings from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface +from danswer.natural_language_processing.exceptions import ( + ModelServerRateLimitError, +) from danswer.natural_language_processing.utils import get_tokenizer from danswer.natural_language_processing.utils import tokenizer_trim_content from danswer.utils.logger import setup_logger @@ -101,28 +107,43 @@ def __init__( self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed" def _make_model_server_request(self, embed_request: EmbedRequest) -> EmbedResponse: - def _make_request() -> EmbedResponse: + def _make_request() -> Response: response = requests.post( self.embed_server_endpoint, json=embed_request.model_dump() ) - try: - response.raise_for_status() - except requests.HTTPError as e: - try: - error_detail = response.json().get("detail", str(e)) - except Exception: - error_detail = response.text - raise HTTPError(f"HTTP error occurred: {error_detail}") from e - except requests.RequestException as e: - raise HTTPError(f"Request failed: {str(e)}") from e + # signify that this is a rate limit error + if response.status_code == 429: + raise ModelServerRateLimitError(response.text) - return EmbedResponse(**response.json()) + response.raise_for_status() + return response + + final_make_request_func = _make_request - # only perform retries for the non-realtime embedding of passages (e.g. for indexing) + # if the text type is a passage, add some default + # retries + handling for rate limiting if embed_request.text_type == EmbedTextType.PASSAGE: - return retry(tries=3, delay=5)(_make_request)() - else: - return _make_request() + final_make_request_func = retry( + tries=3, + delay=5, + exceptions=(RequestException, ValueError, JSONDecodeError), + )(final_make_request_func) + # use 10 second delay as per Azure suggestion + final_make_request_func = retry( + tries=10, delay=10, exceptions=ModelServerRateLimitError + )(final_make_request_func) + + try: + response = final_make_request_func() + return EmbedResponse(**response.json()) + except requests.HTTPError as e: + try: + error_detail = response.json().get("detail", str(e)) + except Exception: + error_detail = response.text + raise HTTPError(f"HTTP error occurred: {error_detail}") from e + except requests.RequestException as e: + raise HTTPError(f"Request failed: {str(e)}") from e def _batch_encode_texts( self, diff --git a/backend/danswer/natural_language_processing/utils.py b/backend/danswer/natural_language_processing/utils.py index daa63c580b9..0f0da04e085 100644 --- a/backend/danswer/natural_language_processing/utils.py +++ b/backend/danswer/natural_language_processing/utils.py @@ -134,7 +134,7 @@ def _try_initialize_tokenizer( return tokenizer except Exception as hf_error: logger.warning( - f"Error initializing HuggingFaceTokenizer for {model_name}: {hf_error}" + f"Failed to initialize HuggingFaceTokenizer for {model_name}: {hf_error}" ) # If both initializations fail, return None diff --git a/backend/danswer/one_shot_answer/__init__.py b/backend/danswer/one_shot_answer/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py deleted file mode 100644 index 9f8ce99231b..00000000000 --- a/backend/danswer/one_shot_answer/answer_question.py +++ /dev/null @@ -1,446 +0,0 @@ -from collections.abc import Callable -from collections.abc import Iterator -from typing import cast - -from sqlalchemy.orm import Session - -from danswer.chat.chat_utils import reorganize_citations -from danswer.chat.models import CitationInfo -from danswer.chat.models import DanswerAnswerPiece -from danswer.chat.models import DanswerContexts -from danswer.chat.models import DanswerQuotes -from danswer.chat.models import DocumentRelevance -from danswer.chat.models import LLMRelevanceFilterResponse -from danswer.chat.models import QADocsResponse -from danswer.chat.models import RelevanceAnalysis -from danswer.chat.models import StreamingError -from danswer.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE -from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT -from danswer.configs.chat_configs import QA_TIMEOUT -from danswer.configs.constants import MessageType -from danswer.context.search.enums import LLMEvaluationType -from danswer.context.search.models import RerankMetricsContainer -from danswer.context.search.models import RetrievalMetricsContainer -from danswer.context.search.utils import chunks_or_sections_to_search_docs -from danswer.context.search.utils import dedupe_documents -from danswer.db.chat import create_chat_session -from danswer.db.chat import create_db_search_doc -from danswer.db.chat import create_new_chat_message -from danswer.db.chat import get_or_create_root_message -from danswer.db.chat import translate_db_message_to_chat_message_detail -from danswer.db.chat import translate_db_search_doc_to_server_search_doc -from danswer.db.chat import update_search_docs_table_with_relevance -from danswer.db.engine import get_session_context_manager -from danswer.db.models import Persona -from danswer.db.models import User -from danswer.db.persona import get_prompt_by_id -from danswer.llm.answering.answer import Answer -from danswer.llm.answering.models import AnswerStyleConfig -from danswer.llm.answering.models import CitationConfig -from danswer.llm.answering.models import DocumentPruningConfig -from danswer.llm.answering.models import PromptConfig -from danswer.llm.answering.models import QuotesConfig -from danswer.llm.factory import get_llms_for_persona -from danswer.llm.factory import get_main_llm_from_tuple -from danswer.natural_language_processing.utils import get_tokenizer -from danswer.one_shot_answer.models import DirectQARequest -from danswer.one_shot_answer.models import OneShotQAResponse -from danswer.one_shot_answer.models import QueryRephrase -from danswer.one_shot_answer.qa_utils import combine_message_thread -from danswer.secondary_llm_flows.answer_validation import get_answer_validity -from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase -from danswer.server.query_and_chat.models import ChatMessageDetail -from danswer.server.utils import get_json_line -from danswer.tools.force import ForceUseTool -from danswer.tools.models import ToolResponse -from danswer.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID -from danswer.tools.tool_implementations.search.search_tool import ( - SEARCH_RESPONSE_SUMMARY_ID, -) -from danswer.tools.tool_implementations.search.search_tool import SearchResponseSummary -from danswer.tools.tool_implementations.search.search_tool import SearchTool -from danswer.tools.tool_implementations.search.search_tool import ( - SECTION_RELEVANCE_LIST_ID, -) -from danswer.tools.tool_runner import ToolCallKickoff -from danswer.utils.logger import setup_logger -from danswer.utils.long_term_log import LongTermLogger -from danswer.utils.timing import log_generator_function_time -from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop - -logger = setup_logger() - -AnswerObjectIterator = Iterator[ - QueryRephrase - | QADocsResponse - | LLMRelevanceFilterResponse - | DanswerAnswerPiece - | DanswerQuotes - | DanswerContexts - | StreamingError - | ChatMessageDetail - | CitationInfo - | ToolCallKickoff - | DocumentRelevance -] - - -def stream_answer_objects( - query_req: DirectQARequest, - user: User | None, - # These need to be passed in because in Web UI one shot flow, - # we can have much more document as there is no history. - # For Slack flow, we need to save more tokens for the thread context - max_document_tokens: int | None, - max_history_tokens: int | None, - db_session: Session, - # Needed to translate persona num_chunks to tokens to the LLM - default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT, - timeout: int = QA_TIMEOUT, - bypass_acl: bool = False, - use_citations: bool = False, - danswerbot_flow: bool = False, - retrieval_metrics_callback: ( - Callable[[RetrievalMetricsContainer], None] | None - ) = None, - rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, -) -> AnswerObjectIterator: - """Streams in order: - 1. [always] Retrieved documents, stops flow if nothing is found - 2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on - 3. [always] A set of streamed DanswerAnswerPiece and DanswerQuotes at the end - or an error anywhere along the line if something fails - 4. [always] Details on the final AI response message that is created - """ - user_id = user.id if user is not None else None - query_msg = query_req.messages[-1] - history = query_req.messages[:-1] - - chat_session = create_chat_session( - db_session=db_session, - description="", # One shot queries don't need naming as it's never displayed - user_id=user_id, - persona_id=query_req.persona_id, - one_shot=True, - danswerbot_flow=danswerbot_flow, - ) - - # permanent "log" store, used primarily for debugging - long_term_logger = LongTermLogger( - metadata={"user_id": str(user_id), "chat_session_id": str(chat_session.id)} - ) - - temporary_persona: Persona | None = None - - if query_req.persona_config is not None: - temporary_persona = fetch_ee_implementation_or_noop( - "danswer.server.query_and_chat.utils", "create_temporary_persona", None - )(db_session=db_session, persona_config=query_req.persona_config, user=user) - - persona = temporary_persona if temporary_persona else chat_session.persona - - try: - llm, fast_llm = get_llms_for_persona( - persona=persona, long_term_logger=long_term_logger - ) - except ValueError as e: - logger.error( - f"Failed to initialize LLMs for persona '{persona.name}': {str(e)}" - ) - if "No LLM provider" in str(e): - raise ValueError( - "Please configure a Generative AI model to use this feature." - ) from e - raise ValueError( - "Failed to initialize the AI model. Please check your configuration and try again." - ) from e - - llm_tokenizer = get_tokenizer( - model_name=llm.config.model_name, - provider_type=llm.config.model_provider, - ) - - # Create a chat session which will just store the root message, the query, and the AI response - root_message = get_or_create_root_message( - chat_session_id=chat_session.id, db_session=db_session - ) - - history_str = combine_message_thread( - messages=history, - max_tokens=max_history_tokens, - llm_tokenizer=llm_tokenizer, - ) - - rephrased_query = query_req.query_override or thread_based_query_rephrase( - user_query=query_msg.message, - history_str=history_str, - ) - - # Given back ahead of the documents for latency reasons - # In chat flow it's given back along with the documents - yield QueryRephrase(rephrased_query=rephrased_query) - - prompt = None - if query_req.prompt_id is not None: - # NOTE: let the user access any prompt as long as the Persona is shared - # with them - prompt = get_prompt_by_id( - prompt_id=query_req.prompt_id, user=None, db_session=db_session - ) - if prompt is None: - if not persona.prompts: - raise RuntimeError( - "Persona does not have any prompts - this should never happen" - ) - prompt = persona.prompts[0] - - # Create the first User query message - new_user_message = create_new_chat_message( - chat_session_id=chat_session.id, - parent_message=root_message, - prompt_id=query_req.prompt_id, - message=query_msg.message, - token_count=len(llm_tokenizer.encode(query_msg.message)), - message_type=MessageType.USER, - db_session=db_session, - commit=True, - ) - - prompt_config = PromptConfig.from_model(prompt) - document_pruning_config = DocumentPruningConfig( - max_chunks=int( - persona.num_chunks if persona.num_chunks is not None else default_num_chunks - ), - max_tokens=max_document_tokens, - ) - - answer_config = AnswerStyleConfig( - citation_config=CitationConfig() if use_citations else None, - quotes_config=QuotesConfig() if not use_citations else None, - document_pruning_config=document_pruning_config, - ) - - search_tool = SearchTool( - db_session=db_session, - user=user, - evaluation_type=( - LLMEvaluationType.SKIP - if DISABLE_LLM_DOC_RELEVANCE - else query_req.evaluation_type - ), - persona=persona, - retrieval_options=query_req.retrieval_options, - prompt_config=prompt_config, - llm=llm, - fast_llm=fast_llm, - pruning_config=document_pruning_config, - answer_style_config=answer_config, - bypass_acl=bypass_acl, - chunks_above=query_req.chunks_above, - chunks_below=query_req.chunks_below, - full_doc=query_req.full_doc, - ) - - answer = Answer( - question=query_msg.message, - answer_style_config=answer_config, - prompt_config=PromptConfig.from_model(prompt), - llm=get_main_llm_from_tuple( - get_llms_for_persona(persona=persona, long_term_logger=long_term_logger) - ), - single_message_history=history_str, - tools=[search_tool] if search_tool else [], - force_use_tool=( - ForceUseTool( - tool_name=search_tool.name, - args={"query": rephrased_query}, - force_use=True, - ) - ), - # for now, don't use tool calling for this flow, as we haven't - # tested quotes with tool calling too much yet - skip_explicit_tool_calling=True, - return_contexts=query_req.return_contexts, - skip_gen_ai_answer_generation=query_req.skip_gen_ai_answer_generation, - ) - # won't be any FileChatDisplay responses since that tool is never passed in - for packet in cast(AnswerObjectIterator, answer.processed_streamed_output): - # for one-shot flow, don't currently do anything with these - if isinstance(packet, ToolResponse): - # (likely fine that it comes after the initial creation of the search docs) - if packet.id == SEARCH_RESPONSE_SUMMARY_ID: - search_response_summary = cast(SearchResponseSummary, packet.response) - - top_docs = chunks_or_sections_to_search_docs( - search_response_summary.top_sections - ) - - # Deduping happens at the last step to avoid harming quality by dropping content early on - deduped_docs = top_docs - if query_req.retrieval_options.dedupe_docs: - deduped_docs, dropped_inds = dedupe_documents(top_docs) - - reference_db_search_docs = [ - create_db_search_doc(server_search_doc=doc, db_session=db_session) - for doc in deduped_docs - ] - - response_docs = [ - translate_db_search_doc_to_server_search_doc(db_search_doc) - for db_search_doc in reference_db_search_docs - ] - - initial_response = QADocsResponse( - rephrased_query=rephrased_query, - top_documents=response_docs, - predicted_flow=search_response_summary.predicted_flow, - predicted_search=search_response_summary.predicted_search, - applied_source_filters=search_response_summary.final_filters.source_type, - applied_time_cutoff=search_response_summary.final_filters.time_cutoff, - recency_bias_multiplier=search_response_summary.recency_bias_multiplier, - ) - - yield initial_response - - elif packet.id == SEARCH_DOC_CONTENT_ID: - yield packet.response - - elif packet.id == SECTION_RELEVANCE_LIST_ID: - document_based_response = {} - - if packet.response is not None: - for evaluation in packet.response: - document_based_response[ - evaluation.document_id - ] = RelevanceAnalysis( - relevant=evaluation.relevant, content=evaluation.content - ) - - evaluation_response = DocumentRelevance( - relevance_summaries=document_based_response - ) - if reference_db_search_docs is not None: - update_search_docs_table_with_relevance( - db_session=db_session, - reference_db_search_docs=reference_db_search_docs, - relevance_summary=evaluation_response, - ) - yield evaluation_response - - else: - yield packet - - # Saving Gen AI answer and responding with message info - gen_ai_response_message = create_new_chat_message( - chat_session_id=chat_session.id, - parent_message=new_user_message, - prompt_id=query_req.prompt_id, - message=answer.llm_answer, - token_count=len(llm_tokenizer.encode(answer.llm_answer)), - message_type=MessageType.ASSISTANT, - error=None, - reference_docs=reference_db_search_docs, - db_session=db_session, - commit=True, - ) - - msg_detail_response = translate_db_message_to_chat_message_detail( - gen_ai_response_message - ) - yield msg_detail_response - - -@log_generator_function_time() -def stream_search_answer( - query_req: DirectQARequest, - user: User | None, - max_document_tokens: int | None, - max_history_tokens: int | None, -) -> Iterator[str]: - with get_session_context_manager() as session: - objects = stream_answer_objects( - query_req=query_req, - user=user, - max_document_tokens=max_document_tokens, - max_history_tokens=max_history_tokens, - db_session=session, - ) - for obj in objects: - yield get_json_line(obj.model_dump()) - - -def get_search_answer( - query_req: DirectQARequest, - user: User | None, - max_document_tokens: int | None, - max_history_tokens: int | None, - db_session: Session, - answer_generation_timeout: int = QA_TIMEOUT, - enable_reflexion: bool = False, - bypass_acl: bool = False, - use_citations: bool = False, - danswerbot_flow: bool = False, - retrieval_metrics_callback: ( - Callable[[RetrievalMetricsContainer], None] | None - ) = None, - rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, -) -> OneShotQAResponse: - """Collects the streamed one shot answer responses into a single object""" - qa_response = OneShotQAResponse() - - results = stream_answer_objects( - query_req=query_req, - user=user, - max_document_tokens=max_document_tokens, - max_history_tokens=max_history_tokens, - db_session=db_session, - bypass_acl=bypass_acl, - use_citations=use_citations, - danswerbot_flow=danswerbot_flow, - timeout=answer_generation_timeout, - retrieval_metrics_callback=retrieval_metrics_callback, - rerank_metrics_callback=rerank_metrics_callback, - ) - - answer = "" - for packet in results: - if isinstance(packet, QueryRephrase): - qa_response.rephrase = packet.rephrased_query - if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece: - answer += packet.answer_piece - elif isinstance(packet, QADocsResponse): - qa_response.docs = packet - elif isinstance(packet, LLMRelevanceFilterResponse): - qa_response.llm_selected_doc_indices = packet.llm_selected_doc_indices - elif isinstance(packet, DanswerQuotes): - qa_response.quotes = packet - elif isinstance(packet, CitationInfo): - if qa_response.citations: - qa_response.citations.append(packet) - else: - qa_response.citations = [packet] - elif isinstance(packet, DanswerContexts): - qa_response.contexts = packet - elif isinstance(packet, StreamingError): - qa_response.error_msg = packet.error - elif isinstance(packet, ChatMessageDetail): - qa_response.chat_message_id = packet.message_id - - if answer: - qa_response.answer = answer - - if enable_reflexion: - # Because follow up messages are explicitly tagged, we don't need to verify the answer - if len(query_req.messages) == 1: - first_query = query_req.messages[0].message - qa_response.answer_valid = get_answer_validity(first_query, answer) - else: - qa_response.answer_valid = True - - if use_citations and qa_response.answer and qa_response.citations: - # Reorganize citation nums to be in the same order as the answer - qa_response.answer, qa_response.citations = reorganize_citations( - qa_response.answer, qa_response.citations - ) - - return qa_response diff --git a/backend/danswer/one_shot_answer/models.py b/backend/danswer/one_shot_answer/models.py deleted file mode 100644 index 630c7b5cab4..00000000000 --- a/backend/danswer/one_shot_answer/models.py +++ /dev/null @@ -1,114 +0,0 @@ -from typing import Any - -from pydantic import BaseModel -from pydantic import Field -from pydantic import model_validator - -from danswer.chat.models import CitationInfo -from danswer.chat.models import DanswerContexts -from danswer.chat.models import DanswerQuotes -from danswer.chat.models import QADocsResponse -from danswer.configs.constants import MessageType -from danswer.context.search.enums import LLMEvaluationType -from danswer.context.search.enums import RecencyBiasSetting -from danswer.context.search.enums import SearchType -from danswer.context.search.models import ChunkContext -from danswer.context.search.models import RerankingDetails -from danswer.context.search.models import RetrievalDetails - - -class QueryRephrase(BaseModel): - rephrased_query: str - - -class ThreadMessage(BaseModel): - message: str - sender: str | None = None - role: MessageType = MessageType.USER - - -class PromptConfig(BaseModel): - name: str - description: str = "" - system_prompt: str - task_prompt: str = "" - include_citations: bool = True - datetime_aware: bool = True - - -class ToolConfig(BaseModel): - id: int - - -class PersonaConfig(BaseModel): - name: str - description: str - search_type: SearchType = SearchType.SEMANTIC - num_chunks: float | None = None - llm_relevance_filter: bool = False - llm_filter_extraction: bool = False - recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO - llm_model_provider_override: str | None = None - llm_model_version_override: str | None = None - - prompts: list[PromptConfig] = Field(default_factory=list) - prompt_ids: list[int] = Field(default_factory=list) - - document_set_ids: list[int] = Field(default_factory=list) - tools: list[ToolConfig] = Field(default_factory=list) - tool_ids: list[int] = Field(default_factory=list) - custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list) - - -class DirectQARequest(ChunkContext): - persona_config: PersonaConfig | None = None - persona_id: int | None = None - - messages: list[ThreadMessage] - prompt_id: int | None = None - multilingual_query_expansion: list[str] | None = None - retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails) - rerank_settings: RerankingDetails | None = None - evaluation_type: LLMEvaluationType = LLMEvaluationType.UNSPECIFIED - - chain_of_thought: bool = False - return_contexts: bool = False - - # allows the caller to specify the exact search query they want to use - # can be used if the message sent to the LLM / query should not be the same - # will also disable Thread-based Rewording if specified - query_override: str | None = None - - # If True, skips generative an AI response to the search query - skip_gen_ai_answer_generation: bool = False - - @model_validator(mode="after") - def check_persona_fields(self) -> "DirectQARequest": - if (self.persona_config is None) == (self.persona_id is None): - raise ValueError("Exactly one of persona_config or persona_id must be set") - return self - - @model_validator(mode="after") - def check_chain_of_thought_and_prompt_id(self) -> "DirectQARequest": - if self.chain_of_thought and self.prompt_id is not None: - raise ValueError( - "If chain_of_thought is True, prompt_id must be None" - "The chain of thought prompt is only for question " - "answering and does not accept customizing." - ) - - return self - - -class OneShotQAResponse(BaseModel): - # This is built piece by piece, any of these can be None as the flow could break - answer: str | None = None - rephrase: str | None = None - quotes: DanswerQuotes | None = None - citations: list[CitationInfo] | None = None - docs: QADocsResponse | None = None - llm_selected_doc_indices: list[int] | None = None - error_msg: str | None = None - answer_valid: bool = True # Reflexion result, default True if Reflexion not run - chat_message_id: int | None = None - contexts: DanswerContexts | None = None diff --git a/backend/danswer/one_shot_answer/qa_utils.py b/backend/danswer/one_shot_answer/qa_utils.py deleted file mode 100644 index 6fbad99eff1..00000000000 --- a/backend/danswer/one_shot_answer/qa_utils.py +++ /dev/null @@ -1,53 +0,0 @@ -from collections.abc import Generator - -from danswer.configs.constants import MessageType -from danswer.natural_language_processing.utils import BaseTokenizer -from danswer.one_shot_answer.models import ThreadMessage -from danswer.utils.logger import setup_logger - -logger = setup_logger() - - -def simulate_streaming_response(model_out: str) -> Generator[str, None, None]: - """Mock streaming by generating the passed in model output, character by character""" - for token in model_out: - yield token - - -def combine_message_thread( - messages: list[ThreadMessage], - max_tokens: int | None, - llm_tokenizer: BaseTokenizer, -) -> str: - """Used to create a single combined message context from threads""" - if not messages: - return "" - - message_strs: list[str] = [] - total_token_count = 0 - - for message in reversed(messages): - if message.role == MessageType.USER: - role_str = message.role.value.upper() - if message.sender: - role_str += " " + message.sender - else: - # Since other messages might have the user identifying information - # better to use Unknown for symmetry - role_str += " Unknown" - else: - role_str = message.role.value.upper() - - msg_str = f"{role_str}:\n{message.message}" - message_token_count = len(llm_tokenizer.encode(msg_str)) - - if ( - max_tokens is not None - and total_token_count + message_token_count > max_tokens - ): - break - - message_strs.insert(0, msg_str) - total_token_count += message_token_count - - return "\n\n".join(message_strs) diff --git a/backend/danswer/prompts/prompt_utils.py b/backend/danswer/prompts/prompt_utils.py index e8d77206660..0bfabd6e105 100644 --- a/backend/danswer/prompts/prompt_utils.py +++ b/backend/danswer/prompts/prompt_utils.py @@ -5,11 +5,11 @@ from langchain_core.messages import BaseMessage from danswer.chat.models import LlmDoc +from danswer.chat.models import PromptConfig from danswer.configs.chat_configs import LANGUAGE_HINT from danswer.configs.constants import DocumentSource from danswer.context.search.models import InferenceChunk from danswer.db.models import Prompt -from danswer.llm.answering.models import PromptConfig from danswer.prompts.chat_prompts import ADDITIONAL_INFO from danswer.prompts.chat_prompts import CITATION_REMINDER from danswer.prompts.constants import CODE_BLOCK_PAT diff --git a/backend/danswer/redis/redis_connector.py b/backend/danswer/redis/redis_connector.py index 8b52a2fd811..8d82fc11943 100644 --- a/backend/danswer/redis/redis_connector.py +++ b/backend/danswer/redis/redis_connector.py @@ -1,5 +1,8 @@ +import time + import redis +from danswer.db.models import SearchSettings from danswer.redis.redis_connector_delete import RedisConnectorDelete from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync @@ -31,6 +34,44 @@ def new_index(self, search_settings_id: int) -> RedisConnectorIndex: self.tenant_id, self.id, search_settings_id, self.redis ) + def wait_for_indexing_termination( + self, + search_settings_list: list[SearchSettings], + timeout: float = 15.0, + ) -> bool: + """ + Returns True if all indexing for the given redis connector is finished within the given timeout. + Returns False if the timeout is exceeded + + This check does not guarantee that current indexings being terminated + won't get restarted midflight + """ + + finished = False + + start = time.monotonic() + + while True: + still_indexing = False + for search_settings in search_settings_list: + redis_connector_index = self.new_index(search_settings.id) + if redis_connector_index.fenced: + still_indexing = True + break + + if not still_indexing: + finished = True + break + + now = time.monotonic() + if now - start > timeout: + break + + time.sleep(1) + continue + + return finished + @staticmethod def get_id_from_fence_key(key: str) -> str | None: """ diff --git a/backend/danswer/redis/redis_connector_credential_pair.py b/backend/danswer/redis/redis_connector_credential_pair.py index 7ed09d76a2d..46d9d2cf1d3 100644 --- a/backend/danswer/redis/redis_connector_credential_pair.py +++ b/backend/danswer/redis/redis_connector_credential_pair.py @@ -10,6 +10,7 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT from danswer.configs.constants import DanswerCeleryPriority from danswer.configs.constants import DanswerCeleryQueues +from danswer.configs.constants import DanswerCeleryTask from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id from danswer.db.document import ( construct_document_select_for_connector_credential_pair_by_needs_sync, @@ -104,7 +105,7 @@ def generate_tasks( # Priority on sync's triggered by new indexing should be medium result = celery_app.send_task( - "vespa_metadata_sync_task", + DanswerCeleryTask.VESPA_METADATA_SYNC_TASK, kwargs=dict(document_id=doc.id, tenant_id=tenant_id), queue=DanswerCeleryQueues.VESPA_METADATA_SYNC, task_id=custom_task_id, diff --git a/backend/danswer/redis/redis_connector_delete.py b/backend/danswer/redis/redis_connector_delete.py index 1b7a440b2e5..4ab42ee65a7 100644 --- a/backend/danswer/redis/redis_connector_delete.py +++ b/backend/danswer/redis/redis_connector_delete.py @@ -12,6 +12,7 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT from danswer.configs.constants import DanswerCeleryPriority from danswer.configs.constants import DanswerCeleryQueues +from danswer.configs.constants import DanswerCeleryTask from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id from danswer.db.document import construct_document_select_for_connector_credential_pair from danswer.db.models import Document as DbDocument @@ -114,7 +115,7 @@ def generate_tasks( # Priority on sync's triggered by new indexing should be medium result = celery_app.send_task( - "document_by_cc_pair_cleanup_task", + DanswerCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK, kwargs=dict( document_id=doc.id, connector_id=cc_pair.connector_id, diff --git a/backend/danswer/redis/redis_connector_doc_perm_sync.py b/backend/danswer/redis/redis_connector_doc_perm_sync.py index d9c3cd814ff..f14d761f709 100644 --- a/backend/danswer/redis/redis_connector_doc_perm_sync.py +++ b/backend/danswer/redis/redis_connector_doc_perm_sync.py @@ -12,10 +12,12 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT from danswer.configs.constants import DanswerCeleryPriority from danswer.configs.constants import DanswerCeleryQueues +from danswer.configs.constants import DanswerCeleryTask -class RedisConnectorPermissionSyncData(BaseModel): +class RedisConnectorPermissionSyncPayload(BaseModel): started: datetime | None + celery_task_id: str | None class RedisConnectorPermissionSync: @@ -78,14 +80,14 @@ def fenced(self) -> bool: return False @property - def payload(self) -> RedisConnectorPermissionSyncData | None: + def payload(self) -> RedisConnectorPermissionSyncPayload | None: # read related data and evaluate/print task progress fence_bytes = cast(bytes, self.redis.get(self.fence_key)) if fence_bytes is None: return None fence_str = fence_bytes.decode("utf-8") - payload = RedisConnectorPermissionSyncData.model_validate_json( + payload = RedisConnectorPermissionSyncPayload.model_validate_json( cast(str, fence_str) ) @@ -93,7 +95,7 @@ def payload(self) -> RedisConnectorPermissionSyncData | None: def set_fence( self, - payload: RedisConnectorPermissionSyncData | None, + payload: RedisConnectorPermissionSyncPayload | None, ) -> None: if not payload: self.redis.delete(self.fence_key) @@ -131,6 +133,8 @@ def generate_tasks( lock: RedisLock | None, new_permissions: list[DocExternalAccess], source_string: str, + connector_id: int, + credential_id: int, ) -> int | None: last_lock_time = time.monotonic() async_results = [] @@ -148,11 +152,13 @@ def generate_tasks( self.redis.sadd(self.taskset_key, custom_task_id) result = celery_app.send_task( - "update_external_document_permissions_task", + DanswerCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK, kwargs=dict( tenant_id=self.tenant_id, serialized_doc_external_access=doc_perm.to_dict(), source_string=source_string, + connector_id=connector_id, + credential_id=credential_id, ), queue=DanswerCeleryQueues.DOC_PERMISSIONS_UPSERT, task_id=custom_task_id, @@ -162,6 +168,12 @@ def generate_tasks( return len(async_results) + def reset(self) -> None: + self.redis.delete(self.generator_progress_key) + self.redis.delete(self.generator_complete_key) + self.redis.delete(self.taskset_key) + self.redis.delete(self.fence_key) + @staticmethod def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None: taskset_key = f"{RedisConnectorPermissionSync.TASKSET_PREFIX}_{id}" diff --git a/backend/danswer/redis/redis_connector_ext_group_sync.py b/backend/danswer/redis/redis_connector_ext_group_sync.py index 631845648c3..bbe539c3954 100644 --- a/backend/danswer/redis/redis_connector_ext_group_sync.py +++ b/backend/danswer/redis/redis_connector_ext_group_sync.py @@ -1,11 +1,18 @@ +from datetime import datetime from typing import cast import redis from celery import Celery +from pydantic import BaseModel from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session +class RedisConnectorExternalGroupSyncPayload(BaseModel): + started: datetime | None + celery_task_id: str | None + + class RedisConnectorExternalGroupSync: """Manages interactions with redis for external group syncing tasks. Should only be accessed through RedisConnector.""" @@ -68,12 +75,29 @@ def fenced(self) -> bool: return False - def set_fence(self, value: bool) -> None: - if not value: + @property + def payload(self) -> RedisConnectorExternalGroupSyncPayload | None: + # read related data and evaluate/print task progress + fence_bytes = cast(bytes, self.redis.get(self.fence_key)) + if fence_bytes is None: + return None + + fence_str = fence_bytes.decode("utf-8") + payload = RedisConnectorExternalGroupSyncPayload.model_validate_json( + cast(str, fence_str) + ) + + return payload + + def set_fence( + self, + payload: RedisConnectorExternalGroupSyncPayload | None, + ) -> None: + if not payload: self.redis.delete(self.fence_key) return - self.redis.set(self.fence_key, 0) + self.redis.set(self.fence_key, payload.model_dump_json()) @property def generator_complete(self) -> int | None: diff --git a/backend/danswer/redis/redis_connector_index.py b/backend/danswer/redis/redis_connector_index.py index 10fd3667fda..40b194af03e 100644 --- a/backend/danswer/redis/redis_connector_index.py +++ b/backend/danswer/redis/redis_connector_index.py @@ -29,6 +29,8 @@ class RedisConnectorIndex: GENERATOR_LOCK_PREFIX = "da_lock:indexing" + TERMINATE_PREFIX = PREFIX + "_terminate" # connectorindexing_terminate + def __init__( self, tenant_id: str | None, @@ -51,6 +53,7 @@ def __init__( self.generator_lock_key = ( f"{self.GENERATOR_LOCK_PREFIX}_{id}/{search_settings_id}" ) + self.terminate_key = f"{self.TERMINATE_PREFIX}_{id}/{search_settings_id}" @classmethod def fence_key_with_ids(cls, cc_pair_id: int, search_settings_id: int) -> str: @@ -92,6 +95,18 @@ def set_fence( self.redis.set(self.fence_key, payload.model_dump_json()) + def terminating(self, celery_task_id: str) -> bool: + if self.redis.exists(f"{self.terminate_key}_{celery_task_id}"): + return True + + return False + + def set_terminate(self, celery_task_id: str) -> None: + """This sets a signal. It does not block!""" + # We shouldn't need very long to terminate the spawned task. + # 10 minute TTL is good. + self.redis.set(f"{self.terminate_key}_{celery_task_id}", 0, ex=600) + def set_generator_complete(self, payload: int | None) -> None: if not payload: self.redis.delete(self.generator_complete_key) diff --git a/backend/danswer/redis/redis_connector_prune.py b/backend/danswer/redis/redis_connector_prune.py index f8e6f372619..9739d2f9832 100644 --- a/backend/danswer/redis/redis_connector_prune.py +++ b/backend/danswer/redis/redis_connector_prune.py @@ -10,6 +10,7 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT from danswer.configs.constants import DanswerCeleryPriority from danswer.configs.constants import DanswerCeleryQueues +from danswer.configs.constants import DanswerCeleryTask from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id @@ -134,7 +135,7 @@ def generate_tasks( # Priority on sync's triggered by new indexing should be medium result = celery_app.send_task( - "document_by_cc_pair_cleanup_task", + DanswerCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK, kwargs=dict( document_id=doc_id, connector_id=cc_pair.connector_id, diff --git a/backend/danswer/redis/redis_document_set.py b/backend/danswer/redis/redis_document_set.py index 879d955eb88..ff92c30a4e5 100644 --- a/backend/danswer/redis/redis_document_set.py +++ b/backend/danswer/redis/redis_document_set.py @@ -11,6 +11,7 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT from danswer.configs.constants import DanswerCeleryPriority from danswer.configs.constants import DanswerCeleryQueues +from danswer.configs.constants import DanswerCeleryTask from danswer.db.document_set import construct_document_select_by_docset from danswer.redis.redis_object_helper import RedisObjectHelper @@ -76,7 +77,7 @@ def generate_tasks( redis_client.sadd(self.taskset_key, custom_task_id) result = celery_app.send_task( - "vespa_metadata_sync_task", + DanswerCeleryTask.VESPA_METADATA_SYNC_TASK, kwargs=dict(document_id=doc.id, tenant_id=tenant_id), queue=DanswerCeleryQueues.VESPA_METADATA_SYNC, task_id=custom_task_id, diff --git a/backend/danswer/redis/redis_usergroup.py b/backend/danswer/redis/redis_usergroup.py index 7c49b9c7fb8..83bd8859632 100644 --- a/backend/danswer/redis/redis_usergroup.py +++ b/backend/danswer/redis/redis_usergroup.py @@ -11,6 +11,7 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT from danswer.configs.constants import DanswerCeleryPriority from danswer.configs.constants import DanswerCeleryQueues +from danswer.configs.constants import DanswerCeleryTask from danswer.redis.redis_object_helper import RedisObjectHelper from danswer.utils.variable_functionality import fetch_versioned_implementation from danswer.utils.variable_functionality import global_version @@ -89,7 +90,7 @@ def generate_tasks( redis_client.sadd(self.taskset_key, custom_task_id) result = celery_app.send_task( - "vespa_metadata_sync_task", + DanswerCeleryTask.VESPA_METADATA_SYNC_TASK, kwargs=dict(document_id=doc.id, tenant_id=tenant_id), queue=DanswerCeleryQueues.VESPA_METADATA_SYNC, task_id=custom_task_id, diff --git a/backend/danswer/secondary_llm_flows/choose_search.py b/backend/danswer/secondary_llm_flows/choose_search.py index 5016cf055bc..36539dd4a7d 100644 --- a/backend/danswer/secondary_llm_flows/choose_search.py +++ b/backend/danswer/secondary_llm_flows/choose_search.py @@ -3,14 +3,14 @@ from langchain.schema import SystemMessage from danswer.chat.chat_utils import combine_message_chain +from danswer.chat.prompt_builder.utils import translate_danswer_msg_to_langchain from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF from danswer.db.models import ChatMessage -from danswer.llm.answering.models import PreviousMessage from danswer.llm.interfaces import LLM +from danswer.llm.models import PreviousMessage from danswer.llm.utils import dict_based_prompt_to_langchain_prompt from danswer.llm.utils import message_to_string -from danswer.llm.utils import translate_danswer_msg_to_langchain from danswer.prompts.chat_prompts import AGGRESSIVE_SEARCH_TEMPLATE from danswer.prompts.chat_prompts import NO_SEARCH from danswer.prompts.chat_prompts import REQUIRE_SEARCH_HINT diff --git a/backend/danswer/secondary_llm_flows/query_expansion.py b/backend/danswer/secondary_llm_flows/query_expansion.py index 585af00bdc1..07f187e5b4f 100644 --- a/backend/danswer/secondary_llm_flows/query_expansion.py +++ b/backend/danswer/secondary_llm_flows/query_expansion.py @@ -4,10 +4,10 @@ from danswer.configs.chat_configs import DISABLE_LLM_QUERY_REPHRASE from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF from danswer.db.models import ChatMessage -from danswer.llm.answering.models import PreviousMessage from danswer.llm.exceptions import GenAIDisabledException from danswer.llm.factory import get_default_llms from danswer.llm.interfaces import LLM +from danswer.llm.models import PreviousMessage from danswer.llm.utils import dict_based_prompt_to_langchain_prompt from danswer.llm.utils import message_to_string from danswer.prompts.chat_prompts import HISTORY_QUERY_REPHRASE diff --git a/backend/danswer/seeding/input_prompts.yaml b/backend/danswer/seeding/input_prompts.yaml deleted file mode 100644 index cc7dbe78ea1..00000000000 --- a/backend/danswer/seeding/input_prompts.yaml +++ /dev/null @@ -1,24 +0,0 @@ -input_prompts: - - id: -5 - prompt: "Elaborate" - content: "Elaborate on the above, give me a more in depth explanation." - active: true - is_public: true - - - id: -4 - prompt: "Reword" - content: "Help me rewrite the following politely and concisely for professional communication:\n" - active: true - is_public: true - - - id: -3 - prompt: "Email" - content: "Write a professional email for me including a subject line, signature, etc. Template the parts that need editing with [ ]. The email should cover the following points:\n" - active: true - is_public: true - - - id: -2 - prompt: "Debug" - content: "Provide step-by-step troubleshooting instructions for the following issue:\n" - active: true - is_public: true diff --git a/backend/danswer/seeding/load_docs.py b/backend/danswer/seeding/load_docs.py index 1567f7f6bbb..5fe591423f0 100644 --- a/backend/danswer/seeding/load_docs.py +++ b/backend/danswer/seeding/load_docs.py @@ -196,7 +196,7 @@ def seed_initial_documents( docs, chunks = _create_indexable_chunks(processed_docs, tenant_id) index_doc_batch_prepare( - document_batch=docs, + documents=docs, index_attempt_metadata=IndexAttemptMetadata( connector_id=connector_id, credential_id=PUBLIC_CREDENTIAL_ID, diff --git a/backend/danswer/seeding/load_yamls.py b/backend/danswer/seeding/load_yamls.py index 0046352679c..6efa1efd368 100644 --- a/backend/danswer/seeding/load_yamls.py +++ b/backend/danswer/seeding/load_yamls.py @@ -1,13 +1,11 @@ import yaml from sqlalchemy.orm import Session -from danswer.configs.chat_configs import INPUT_PROMPT_YAML from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT from danswer.configs.chat_configs import PERSONAS_YAML from danswer.configs.chat_configs import PROMPTS_YAML from danswer.context.search.enums import RecencyBiasSetting from danswer.db.document_set import get_or_create_document_set_by_name -from danswer.db.input_prompt import insert_input_prompt_if_not_exists from danswer.db.models import DocumentSet as DocumentSetDBModel from danswer.db.models import Persona from danswer.db.models import Prompt as PromptDBModel @@ -79,8 +77,12 @@ def load_personas_from_yaml( if prompts: prompt_ids = [prompt.id for prompt in prompts if prompt is not None] + if not prompt_ids: + raise ValueError("Invalid Persona config, no prompts exist") + p_id = persona.get("id") tool_ids = [] + if persona.get("image_generation"): image_gen_tool = ( db_session.query(ToolDBModel) @@ -122,36 +124,17 @@ def load_personas_from_yaml( tool_ids=tool_ids, builtin_persona=True, is_public=True, - display_priority=existing_persona.display_priority - if existing_persona is not None - else persona.get("display_priority"), - is_visible=existing_persona.is_visible - if existing_persona is not None - else persona.get("is_visible"), - db_session=db_session, - ) - - -def load_input_prompts_from_yaml( - db_session: Session, input_prompts_yaml: str = INPUT_PROMPT_YAML -) -> None: - with open(input_prompts_yaml, "r") as file: - data = yaml.safe_load(file) - - all_input_prompts = data.get("input_prompts", []) - for input_prompt in all_input_prompts: - # If these prompts are deleted (which is a hard delete in the DB), on server startup - # they will be recreated, but the user can always just deactivate them, just a light inconvenience - - insert_input_prompt_if_not_exists( - user=None, - input_prompt_id=input_prompt.get("id"), - prompt=input_prompt["prompt"], - content=input_prompt["content"], - is_public=input_prompt["is_public"], - active=input_prompt.get("active", True), + display_priority=( + existing_persona.display_priority + if existing_persona is not None + else persona.get("display_priority") + ), + is_visible=( + existing_persona.is_visible + if existing_persona is not None + else persona.get("is_visible") + ), db_session=db_session, - commit=True, ) @@ -159,8 +142,6 @@ def load_chat_yamls( db_session: Session, prompt_yaml: str = PROMPTS_YAML, personas_yaml: str = PERSONAS_YAML, - input_prompts_yaml: str = INPUT_PROMPT_YAML, ) -> None: load_prompts_from_yaml(db_session, prompt_yaml) load_personas_from_yaml(db_session, personas_yaml) - load_input_prompts_from_yaml(db_session, input_prompts_yaml) diff --git a/backend/danswer/seeding/personas.yaml b/backend/danswer/seeding/personas.yaml index bafaf2d788c..e628b32e6f7 100644 --- a/backend/danswer/seeding/personas.yaml +++ b/backend/danswer/seeding/personas.yaml @@ -5,7 +5,7 @@ personas: # this is for DanswerBot to use when tagged in a non-configured channel # Careful setting specific IDs, this won't autoincrement the next ID value for postgres - id: 0 - name: "Knowledge" + name: "Search" description: > Assistant with access to documents from your Connected Sources. # Default Prompt objects attached to the persona, see prompts.yaml diff --git a/backend/danswer/server/documents/cc_pair.py b/backend/danswer/server/documents/cc_pair.py index 55808ebcee7..88c812b19e2 100644 --- a/backend/danswer/server/documents/cc_pair.py +++ b/backend/danswer/server/documents/cc_pair.py @@ -6,6 +6,7 @@ from fastapi import Depends from fastapi import HTTPException from fastapi import Query +from fastapi.responses import JSONResponse from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session @@ -32,16 +33,17 @@ from danswer.db.engine import get_session from danswer.db.enums import AccessType from danswer.db.enums import ConnectorCredentialPairStatus -from danswer.db.index_attempt import cancel_indexing_attempts_for_ccpair -from danswer.db.index_attempt import cancel_indexing_attempts_past_model from danswer.db.index_attempt import count_index_attempts_for_connector from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id +from danswer.db.models import SearchSettings from danswer.db.models import User +from danswer.db.search_settings import get_active_search_settings from danswer.db.search_settings import get_current_search_settings from danswer.redis.redis_connector import RedisConnector from danswer.redis.redis_pool import get_redis_client from danswer.server.documents.models import CCPairFullInfo +from danswer.server.documents.models import CCPropertyUpdateRequest from danswer.server.documents.models import CCStatusUpdateRequest from danswer.server.documents.models import ConnectorCredentialPairIdentifier from danswer.server.documents.models import ConnectorCredentialPairMetadata @@ -158,7 +160,19 @@ def update_cc_pair_status( status_update_request: CCStatusUpdateRequest, user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), -) -> None: + tenant_id: str | None = Depends(get_current_tenant_id), +) -> JSONResponse: + """This method may wait up to 30 seconds if pausing the connector due to the need to + terminate tasks in progress. Tasks are not guaranteed to terminate within the + timeout. + + Returns HTTPStatus.OK if everything finished. + Returns HTTPStatus.ACCEPTED if the connector is being paused, but background tasks + did not finish within the timeout. + """ + WAIT_TIMEOUT = 15.0 + still_terminating = False + cc_pair = get_connector_credential_pair_from_id( cc_pair_id=cc_pair_id, db_session=db_session, @@ -173,9 +187,72 @@ def update_cc_pair_status( ) if status_update_request.status == ConnectorCredentialPairStatus.PAUSED: - cancel_indexing_attempts_for_ccpair(cc_pair_id, db_session) + search_settings_list: list[SearchSettings] = get_active_search_settings( + db_session + ) - cancel_indexing_attempts_past_model(db_session) + redis_connector = RedisConnector(tenant_id, cc_pair_id) + + try: + redis_connector.stop.set_fence(True) + while True: + logger.debug( + f"Wait for indexing soft termination starting: cc_pair={cc_pair_id}" + ) + wait_succeeded = redis_connector.wait_for_indexing_termination( + search_settings_list, WAIT_TIMEOUT + ) + if wait_succeeded: + logger.debug( + f"Wait for indexing soft termination succeeded: cc_pair={cc_pair_id}" + ) + break + + logger.debug( + "Wait for indexing soft termination timed out. " + f"Moving to hard termination: cc_pair={cc_pair_id} timeout={WAIT_TIMEOUT:.2f}" + ) + + for search_settings in search_settings_list: + redis_connector_index = redis_connector.new_index( + search_settings.id + ) + if not redis_connector_index.fenced: + continue + + index_payload = redis_connector_index.payload + if not index_payload: + continue + + if not index_payload.celery_task_id: + continue + + # Revoke the task to prevent it from running + primary_app.control.revoke(index_payload.celery_task_id) + + # If it is running, then signaling for termination will get the + # watchdog thread to kill the spawned task + redis_connector_index.set_terminate(index_payload.celery_task_id) + + logger.debug( + f"Wait for indexing hard termination starting: cc_pair={cc_pair_id}" + ) + wait_succeeded = redis_connector.wait_for_indexing_termination( + search_settings_list, WAIT_TIMEOUT + ) + if wait_succeeded: + logger.debug( + f"Wait for indexing hard termination succeeded: cc_pair={cc_pair_id}" + ) + break + + logger.debug( + f"Wait for indexing hard termination timed out: cc_pair={cc_pair_id}" + ) + still_terminating = True + break + finally: + redis_connector.stop.set_fence(False) update_connector_credential_pair_from_id( db_session=db_session, @@ -185,6 +262,18 @@ def update_cc_pair_status( db_session.commit() + if still_terminating: + return JSONResponse( + status_code=HTTPStatus.ACCEPTED, + content={ + "message": "Request accepted, background task termination still in progress" + }, + ) + + return JSONResponse( + status_code=HTTPStatus.OK, content={"message": str(HTTPStatus.OK)} + ) + @router.put("/admin/cc-pair/{cc_pair_id}/name") def update_cc_pair_name( @@ -215,6 +304,46 @@ def update_cc_pair_name( raise HTTPException(status_code=400, detail="Name must be unique") +@router.put("/admin/cc-pair/{cc_pair_id}/property") +def update_cc_pair_property( + cc_pair_id: int, + update_request: CCPropertyUpdateRequest, # in seconds + user: User | None = Depends(current_curator_or_admin_user), + db_session: Session = Depends(get_session), +) -> StatusResponse[int]: + cc_pair = get_connector_credential_pair_from_id( + cc_pair_id=cc_pair_id, + db_session=db_session, + user=user, + get_editable=True, + ) + if not cc_pair: + raise HTTPException( + status_code=400, detail="CC Pair not found for current user's permissions" + ) + + # Can we centralize logic for updating connector properties + # so that we don't need to manually validate everywhere? + if update_request.name == "refresh_frequency": + cc_pair.connector.refresh_freq = int(update_request.value) + cc_pair.connector.validate_refresh_freq() + db_session.commit() + + msg = "Refresh frequency updated successfully" + elif update_request.name == "pruning_frequency": + cc_pair.connector.prune_freq = int(update_request.value) + cc_pair.connector.validate_prune_freq() + db_session.commit() + + msg = "Pruning frequency updated successfully" + else: + raise HTTPException( + status_code=400, detail=f"Property name {update_request.name} is not valid." + ) + + return StatusResponse(success=True, message=msg, data=cc_pair_id) + + @router.get("/admin/cc-pair/{cc_pair_id}/last_pruned") def get_cc_pair_last_pruned( cc_pair_id: int, @@ -267,9 +396,9 @@ def prune_cc_pair( ) logger.info( - f"Pruning cc_pair: cc_pair_id={cc_pair_id} " - f"connector_id={cc_pair.connector_id} " - f"credential_id={cc_pair.credential_id} " + f"Pruning cc_pair: cc_pair={cc_pair_id} " + f"connector={cc_pair.connector_id} " + f"credential={cc_pair.credential_id} " f"{cc_pair.connector.name} connector." ) tasks_created = try_creating_prune_generator_task( diff --git a/backend/danswer/server/documents/connector.py b/backend/danswer/server/documents/connector.py index f29cb42f151..e7cf00ba6d2 100644 --- a/backend/danswer/server/documents/connector.py +++ b/backend/danswer/server/documents/connector.py @@ -20,6 +20,7 @@ from danswer.background.celery.versioned_apps.primary import app as primary_app from danswer.configs.app_configs import ENABLED_CONNECTOR_TYPES from danswer.configs.constants import DanswerCeleryPriority +from danswer.configs.constants import DanswerCeleryTask from danswer.configs.constants import DocumentSource from danswer.configs.constants import FileOrigin from danswer.connectors.google_utils.google_auth import ( @@ -85,6 +86,7 @@ from danswer.db.models import User from danswer.db.search_settings import get_current_search_settings from danswer.db.search_settings import get_secondary_search_settings +from danswer.file_processing.extract_file_text import convert_docx_to_txt from danswer.file_store.file_store import get_default_file_store from danswer.key_value_store.interface import KvKeyNotFoundError from danswer.redis.redis_connector import RedisConnector @@ -392,6 +394,12 @@ def upload_files( file_origin=FileOrigin.CONNECTOR, file_type=file.content_type or "text/plain", ) + + if file.content_type and file.content_type.startswith( + "application/vnd.openxmlformats-officedocument.wordprocessingml.document" + ): + convert_docx_to_txt(file, file_store, file_path) + except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) return FileUploadResponse(file_paths=deduped_file_paths) @@ -867,7 +875,7 @@ def connector_run_once( # run the beat task to pick up the triggers immediately primary_app.send_task( - "check_for_indexing", + DanswerCeleryTask.CHECK_FOR_INDEXING, priority=DanswerCeleryPriority.HIGH, kwargs={"tenant_id": tenant_id}, ) diff --git a/backend/danswer/server/documents/credential.py b/backend/danswer/server/documents/credential.py index 602ca27ee5c..160664e55c0 100644 --- a/backend/danswer/server/documents/credential.py +++ b/backend/danswer/server/documents/credential.py @@ -181,7 +181,13 @@ def update_credential_data( user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> CredentialBase: - credential = alter_credential(credential_id, credential_update, user, db_session) + credential = alter_credential( + credential_id, + credential_update.name, + credential_update.credential_json, + user, + db_session, + ) if credential is None: raise HTTPException( diff --git a/backend/danswer/server/documents/models.py b/backend/danswer/server/documents/models.py index 7b523d929ec..2c4f509444f 100644 --- a/backend/danswer/server/documents/models.py +++ b/backend/danswer/server/documents/models.py @@ -364,6 +364,11 @@ class RunConnectorRequest(BaseModel): from_beginning: bool = False +class CCPropertyUpdateRequest(BaseModel): + name: str + value: str + + """Connectors Models""" diff --git a/backend/danswer/server/documents/standard_oauth.py b/backend/danswer/server/documents/standard_oauth.py new file mode 100644 index 00000000000..ddc85761914 --- /dev/null +++ b/backend/danswer/server/documents/standard_oauth.py @@ -0,0 +1,142 @@ +import uuid +from typing import Annotated +from typing import cast + +from fastapi import APIRouter +from fastapi import Depends +from fastapi import HTTPException +from fastapi import Query +from pydantic import BaseModel +from sqlalchemy.orm import Session + +from danswer.auth.users import current_user +from danswer.configs.app_configs import WEB_DOMAIN +from danswer.configs.constants import DocumentSource +from danswer.connectors.interfaces import OAuthConnector +from danswer.db.credentials import create_credential +from danswer.db.engine import get_current_tenant_id +from danswer.db.engine import get_session +from danswer.db.models import User +from danswer.redis.redis_pool import get_redis_client +from danswer.server.documents.models import CredentialBase +from danswer.utils.logger import setup_logger +from danswer.utils.subclasses import find_all_subclasses_in_dir + +logger = setup_logger() + +router = APIRouter(prefix="/connector/oauth") + +_OAUTH_STATE_KEY_FMT = "oauth_state:{state}" +_OAUTH_STATE_EXPIRATION_SECONDS = 10 * 60 # 10 minutes + +# Cache for OAuth connectors, populated at module load time +_OAUTH_CONNECTORS: dict[DocumentSource, type[OAuthConnector]] = {} + + +def _discover_oauth_connectors() -> dict[DocumentSource, type[OAuthConnector]]: + """Walk through the connectors package to find all OAuthConnector implementations""" + global _OAUTH_CONNECTORS + if _OAUTH_CONNECTORS: # Return cached connectors if already discovered + return _OAUTH_CONNECTORS + + oauth_connectors = find_all_subclasses_in_dir( + cast(type[OAuthConnector], OAuthConnector), "danswer.connectors" + ) + + _OAUTH_CONNECTORS = {cls.oauth_id(): cls for cls in oauth_connectors} + return _OAUTH_CONNECTORS + + +# Discover OAuth connectors at module load time +_discover_oauth_connectors() + + +class AuthorizeResponse(BaseModel): + redirect_url: str + + +@router.get("/authorize/{source}") +def oauth_authorize( + source: DocumentSource, + desired_return_url: Annotated[str | None, Query()] = None, + _: User = Depends(current_user), + tenant_id: str | None = Depends(get_current_tenant_id), +) -> AuthorizeResponse: + """Initiates the OAuth flow by redirecting to the provider's auth page""" + oauth_connectors = _discover_oauth_connectors() + + if source not in oauth_connectors: + raise HTTPException(status_code=400, detail=f"Unknown OAuth source: {source}") + + connector_cls = oauth_connectors[source] + base_url = WEB_DOMAIN + + # store state in redis + if not desired_return_url: + desired_return_url = f"{base_url}/admin/connectors/{source}?step=0" + redis_client = get_redis_client(tenant_id=tenant_id) + state = str(uuid.uuid4()) + redis_client.set( + _OAUTH_STATE_KEY_FMT.format(state=state), + desired_return_url, + ex=_OAUTH_STATE_EXPIRATION_SECONDS, + ) + + return AuthorizeResponse( + redirect_url=connector_cls.oauth_authorization_url(base_url, state) + ) + + +class CallbackResponse(BaseModel): + redirect_url: str + + +@router.get("/callback/{source}") +def oauth_callback( + source: DocumentSource, + code: Annotated[str, Query()], + state: Annotated[str, Query()], + db_session: Session = Depends(get_session), + user: User = Depends(current_user), + tenant_id: str | None = Depends(get_current_tenant_id), +) -> CallbackResponse: + """Handles the OAuth callback and exchanges the code for tokens""" + oauth_connectors = _discover_oauth_connectors() + + if source not in oauth_connectors: + raise HTTPException(status_code=400, detail=f"Unknown OAuth source: {source}") + + connector_cls = oauth_connectors[source] + + # get state from redis + redis_client = get_redis_client(tenant_id=tenant_id) + original_url_bytes = cast( + bytes, redis_client.get(_OAUTH_STATE_KEY_FMT.format(state=state)) + ) + if not original_url_bytes: + raise HTTPException(status_code=400, detail="Invalid OAuth state") + original_url = original_url_bytes.decode("utf-8") + + token_info = connector_cls.oauth_code_to_token(code) + + # Create a new credential with the token info + credential_data = CredentialBase( + credential_json=token_info, + admin_public=True, # Or based on some logic/parameter + source=source, + name=f"{source.title()} OAuth Credential", + ) + + credential = create_credential( + credential_data=credential_data, + user=user, + db_session=db_session, + ) + + return CallbackResponse( + redirect_url=( + f"{original_url}?credentialId={credential.id}" + if "?" not in original_url + else f"{original_url}&credentialId={credential.id}" + ) + ) diff --git a/backend/danswer/server/features/input_prompt/__init__.py b/backend/danswer/server/features/input_prompt/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/backend/danswer/server/features/input_prompt/api.py b/backend/danswer/server/features/input_prompt/api.py deleted file mode 100644 index 58eecd0093d..00000000000 --- a/backend/danswer/server/features/input_prompt/api.py +++ /dev/null @@ -1,134 +0,0 @@ -from fastapi import APIRouter -from fastapi import Depends -from fastapi import HTTPException -from sqlalchemy.orm import Session - -from danswer.auth.users import current_admin_user -from danswer.auth.users import current_user -from danswer.db.engine import get_session -from danswer.db.input_prompt import fetch_input_prompt_by_id -from danswer.db.input_prompt import fetch_input_prompts_by_user -from danswer.db.input_prompt import fetch_public_input_prompts -from danswer.db.input_prompt import insert_input_prompt -from danswer.db.input_prompt import remove_input_prompt -from danswer.db.input_prompt import remove_public_input_prompt -from danswer.db.input_prompt import update_input_prompt -from danswer.db.models import User -from danswer.server.features.input_prompt.models import CreateInputPromptRequest -from danswer.server.features.input_prompt.models import InputPromptSnapshot -from danswer.server.features.input_prompt.models import UpdateInputPromptRequest -from danswer.utils.logger import setup_logger - -logger = setup_logger() - -basic_router = APIRouter(prefix="/input_prompt") -admin_router = APIRouter(prefix="/admin/input_prompt") - - -@basic_router.get("") -def list_input_prompts( - user: User | None = Depends(current_user), - include_public: bool = False, - db_session: Session = Depends(get_session), -) -> list[InputPromptSnapshot]: - user_prompts = fetch_input_prompts_by_user( - user_id=user.id if user is not None else None, - db_session=db_session, - include_public=include_public, - ) - return [InputPromptSnapshot.from_model(prompt) for prompt in user_prompts] - - -@basic_router.get("/{input_prompt_id}") -def get_input_prompt( - input_prompt_id: int, - user: User | None = Depends(current_user), - db_session: Session = Depends(get_session), -) -> InputPromptSnapshot: - input_prompt = fetch_input_prompt_by_id( - id=input_prompt_id, - user_id=user.id if user is not None else None, - db_session=db_session, - ) - return InputPromptSnapshot.from_model(input_prompt=input_prompt) - - -@basic_router.post("") -def create_input_prompt( - create_input_prompt_request: CreateInputPromptRequest, - user: User | None = Depends(current_user), - db_session: Session = Depends(get_session), -) -> InputPromptSnapshot: - input_prompt = insert_input_prompt( - prompt=create_input_prompt_request.prompt, - content=create_input_prompt_request.content, - is_public=create_input_prompt_request.is_public, - user=user, - db_session=db_session, - ) - return InputPromptSnapshot.from_model(input_prompt) - - -@basic_router.patch("/{input_prompt_id}") -def patch_input_prompt( - input_prompt_id: int, - update_input_prompt_request: UpdateInputPromptRequest, - user: User | None = Depends(current_user), - db_session: Session = Depends(get_session), -) -> InputPromptSnapshot: - try: - updated_input_prompt = update_input_prompt( - user=user, - input_prompt_id=input_prompt_id, - prompt=update_input_prompt_request.prompt, - content=update_input_prompt_request.content, - active=update_input_prompt_request.active, - db_session=db_session, - ) - except ValueError as e: - error_msg = "Error occurred while updated input prompt" - logger.warn(f"{error_msg}. Stack trace: {e}") - raise HTTPException(status_code=404, detail=error_msg) - - return InputPromptSnapshot.from_model(updated_input_prompt) - - -@basic_router.delete("/{input_prompt_id}") -def delete_input_prompt( - input_prompt_id: int, - user: User | None = Depends(current_user), - db_session: Session = Depends(get_session), -) -> None: - try: - remove_input_prompt(user, input_prompt_id, db_session) - - except ValueError as e: - error_msg = "Error occurred while deleting input prompt" - logger.warn(f"{error_msg}. Stack trace: {e}") - raise HTTPException(status_code=404, detail=error_msg) - - -@admin_router.delete("/{input_prompt_id}") -def delete_public_input_prompt( - input_prompt_id: int, - _: User | None = Depends(current_admin_user), - db_session: Session = Depends(get_session), -) -> None: - try: - remove_public_input_prompt(input_prompt_id, db_session) - - except ValueError as e: - error_msg = "Error occurred while deleting input prompt" - logger.warn(f"{error_msg}. Stack trace: {e}") - raise HTTPException(status_code=404, detail=error_msg) - - -@admin_router.get("") -def list_public_input_prompts( - _: User | None = Depends(current_admin_user), - db_session: Session = Depends(get_session), -) -> list[InputPromptSnapshot]: - user_prompts = fetch_public_input_prompts( - db_session=db_session, - ) - return [InputPromptSnapshot.from_model(prompt) for prompt in user_prompts] diff --git a/backend/danswer/server/features/input_prompt/models.py b/backend/danswer/server/features/input_prompt/models.py deleted file mode 100644 index 21ce2ba4e5b..00000000000 --- a/backend/danswer/server/features/input_prompt/models.py +++ /dev/null @@ -1,47 +0,0 @@ -from uuid import UUID - -from pydantic import BaseModel - -from danswer.db.models import InputPrompt -from danswer.utils.logger import setup_logger - -logger = setup_logger() - - -class CreateInputPromptRequest(BaseModel): - prompt: str - content: str - is_public: bool - - -class UpdateInputPromptRequest(BaseModel): - prompt: str - content: str - active: bool - - -class InputPromptResponse(BaseModel): - id: int - prompt: str - content: str - active: bool - - -class InputPromptSnapshot(BaseModel): - id: int - prompt: str - content: str - active: bool - user_id: UUID | None - is_public: bool - - @classmethod - def from_model(cls, input_prompt: InputPrompt) -> "InputPromptSnapshot": - return InputPromptSnapshot( - id=input_prompt.id, - prompt=input_prompt.prompt, - content=input_prompt.content, - active=input_prompt.active, - user_id=input_prompt.user_id, - is_public=input_prompt.is_public, - ) diff --git a/backend/danswer/server/features/persona/api.py b/backend/danswer/server/features/persona/api.py index fd092fb90ef..f6cb3a2d296 100644 --- a/backend/danswer/server/features/persona/api.py +++ b/backend/danswer/server/features/persona/api.py @@ -13,6 +13,7 @@ from danswer.auth.users import current_curator_or_admin_user from danswer.auth.users import current_limited_user from danswer.auth.users import current_user +from danswer.chat.prompt_builder.utils import build_dummy_prompt from danswer.configs.constants import FileOrigin from danswer.configs.constants import NotificationType from danswer.db.engine import get_session @@ -33,7 +34,6 @@ from danswer.db.persona import update_persona_visibility from danswer.file_store.file_store import get_default_file_store from danswer.file_store.models import ChatFileType -from danswer.llm.answering.prompts.utils import build_dummy_prompt from danswer.server.features.persona.models import CreatePersonaRequest from danswer.server.features.persona.models import ImageGenerationToolStatus from danswer.server.features.persona.models import PersonaCategoryCreate diff --git a/backend/danswer/server/manage/administrative.py b/backend/danswer/server/manage/administrative.py index 1ceeb776abc..cbf744500d4 100644 --- a/backend/danswer/server/manage/administrative.py +++ b/backend/danswer/server/manage/administrative.py @@ -13,6 +13,7 @@ from danswer.background.celery.versioned_apps.primary import app as primary_app from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ from danswer.configs.constants import DanswerCeleryPriority +from danswer.configs.constants import DanswerCeleryTask from danswer.configs.constants import DocumentSource from danswer.configs.constants import KV_GEN_AI_KEY_CHECK_TIME from danswer.db.connector_credential_pair import get_connector_credential_pair @@ -199,7 +200,7 @@ def create_deletion_attempt_for_connector_id( # run the beat task to pick up this deletion from the db immediately primary_app.send_task( - "check_for_connector_deletion_task", + DanswerCeleryTask.CHECK_FOR_CONNECTOR_DELETION, priority=DanswerCeleryPriority.HIGH, kwargs={"tenant_id": tenant_id}, ) diff --git a/backend/danswer/server/manage/models.py b/backend/danswer/server/manage/models.py index 74a3a774e21..0e37fc89191 100644 --- a/backend/danswer/server/manage/models.py +++ b/backend/danswer/server/manage/models.py @@ -1,4 +1,5 @@ from datetime import datetime +from enum import Enum from typing import TYPE_CHECKING from pydantic import BaseModel @@ -15,7 +16,6 @@ from danswer.db.models import AllowedAnswerFilters from danswer.db.models import ChannelConfig from danswer.db.models import SlackBot as SlackAppModel -from danswer.db.models import SlackBotResponseType from danswer.db.models import SlackChannelConfig as SlackChannelConfigModel from danswer.db.models import User from danswer.server.features.persona.models import PersonaSnapshot @@ -45,6 +45,7 @@ class UserPreferences(BaseModel): visible_assistants: list[int] = [] recent_assistants: list[int] | None = None default_model: str | None = None + auto_scroll: bool | None = None class UserInfo(BaseModel): @@ -79,6 +80,7 @@ def from_model( role=user.role, preferences=( UserPreferences( + auto_scroll=user.auto_scroll, chosen_assistants=user.chosen_assistants, default_model=user.default_model, hidden_assistants=user.hidden_assistants, @@ -128,6 +130,10 @@ class HiddenUpdateRequest(BaseModel): hidden: bool +class AutoScrollRequest(BaseModel): + auto_scroll: bool | None + + class SlackBotCreationRequest(BaseModel): name: str enabled: bool @@ -142,6 +148,12 @@ class SlackBotTokens(BaseModel): model_config = ConfigDict(frozen=True) +# TODO No longer in use, remove later +class SlackBotResponseType(str, Enum): + QUOTES = "quotes" + CITATIONS = "citations" + + class SlackChannelConfigCreationRequest(BaseModel): slack_bot_id: int # currently, a persona is created for each Slack channel config @@ -156,6 +168,7 @@ class SlackChannelConfigCreationRequest(BaseModel): channel_name: str respond_tag_only: bool = False respond_to_bots: bool = False + show_continue_in_web_ui: bool = False enable_auto_filters: bool = False # If no team members, assume respond in the channel to everyone respond_member_group_list: list[str] = Field(default_factory=list) @@ -190,7 +203,6 @@ class SlackChannelConfig(BaseModel): id: int persona: PersonaSnapshot | None channel_config: ChannelConfig - response_type: SlackBotResponseType # XXX this is going away soon standard_answer_categories: list[StandardAnswerCategory] enable_auto_filters: bool @@ -210,7 +222,6 @@ def from_model( else None ), channel_config=slack_channel_config_model.channel_config, - response_type=slack_channel_config_model.response_type, # XXX this is going away soon standard_answer_categories=[ StandardAnswerCategory.from_model(standard_answer_category_model) @@ -255,5 +266,7 @@ class FullModelVersionResponse(BaseModel): class AllUsersResponse(BaseModel): accepted: list[FullUserSnapshot] invited: list[InvitedUserSnapshot] + slack_users: list[FullUserSnapshot] accepted_pages: int invited_pages: int + slack_users_pages: int diff --git a/backend/danswer/server/manage/slack_bot.py b/backend/danswer/server/manage/slack_bot.py index 036f2fca0dd..10c5a1ac236 100644 --- a/backend/danswer/server/manage/slack_bot.py +++ b/backend/danswer/server/manage/slack_bot.py @@ -80,6 +80,10 @@ def _form_channel_config( if follow_up_tags is not None: channel_config["follow_up_tags"] = follow_up_tags + channel_config[ + "show_continue_in_web_ui" + ] = slack_channel_config_creation_request.show_continue_in_web_ui + channel_config[ "respond_to_bots" ] = slack_channel_config_creation_request.respond_to_bots @@ -114,7 +118,6 @@ def create_slack_channel_config( slack_bot_id=slack_channel_config_creation_request.slack_bot_id, persona_id=persona_id, channel_config=channel_config, - response_type=slack_channel_config_creation_request.response_type, standard_answer_category_ids=slack_channel_config_creation_request.standard_answer_categories, db_session=db_session, enable_auto_filters=slack_channel_config_creation_request.enable_auto_filters, @@ -178,7 +181,6 @@ def patch_slack_channel_config( slack_channel_config_id=slack_channel_config_id, persona_id=persona_id, channel_config=channel_config, - response_type=slack_channel_config_creation_request.response_type, standard_answer_category_ids=slack_channel_config_creation_request.standard_answer_categories, enable_auto_filters=slack_channel_config_creation_request.enable_auto_filters, ) diff --git a/backend/danswer/server/manage/users.py b/backend/danswer/server/manage/users.py index 5e4197aaf5c..75fd9dfe3a8 100644 --- a/backend/danswer/server/manage/users.py +++ b/backend/danswer/server/manage/users.py @@ -26,7 +26,6 @@ from danswer.auth.noauth_user import set_no_auth_user_preferences from danswer.auth.schemas import UserRole from danswer.auth.schemas import UserStatus -from danswer.auth.users import BasicAuthenticationError from danswer.auth.users import current_admin_user from danswer.auth.users import current_curator_or_admin_user from danswer.auth.users import current_user @@ -34,7 +33,6 @@ from danswer.configs.app_configs import AUTH_TYPE from danswer.configs.app_configs import ENABLE_EMAIL_INVITES from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS -from danswer.configs.app_configs import SUPER_USERS from danswer.configs.app_configs import VALID_EMAIL_DOMAINS from danswer.configs.constants import AuthType from danswer.db.api_key import is_api_key_email_address @@ -52,6 +50,7 @@ from danswer.db.users import validate_user_role_update from danswer.key_value_store.factory import get_kv_store from danswer.server.manage.models import AllUsersResponse +from danswer.server.manage.models import AutoScrollRequest from danswer.server.manage.models import UserByEmail from danswer.server.manage.models import UserInfo from danswer.server.manage.models import UserPreferences @@ -60,9 +59,11 @@ from danswer.server.models import FullUserSnapshot from danswer.server.models import InvitedUserSnapshot from danswer.server.models import MinimalUserSnapshot +from danswer.server.utils import BasicAuthenticationError from danswer.server.utils import send_user_email_invite from danswer.utils.logger import setup_logger from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop +from ee.danswer.configs.app_configs import SUPER_USERS from shared_configs.configs import MULTI_TENANT logger = setup_logger() @@ -118,6 +119,7 @@ def set_user_role( def list_all_users( q: str | None = None, accepted_page: int | None = None, + slack_users_page: int | None = None, invited_page: int | None = None, user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), @@ -130,7 +132,12 @@ def list_all_users( for user in list_users(db_session, email_filter_string=q) if not is_api_key_email_address(user.email) ] - accepted_emails = {user.email for user in users} + + slack_users = [user for user in users if user.role == UserRole.SLACK_USER] + accepted_users = [user for user in users if user.role != UserRole.SLACK_USER] + + accepted_emails = {user.email for user in accepted_users} + slack_users_emails = {user.email for user in slack_users} invited_emails = get_invited_users() if q: invited_emails = [ @@ -138,10 +145,11 @@ def list_all_users( ] accepted_count = len(accepted_emails) + slack_users_count = len(slack_users_emails) invited_count = len(invited_emails) # If any of q, accepted_page, or invited_page is None, return all users - if accepted_page is None or invited_page is None: + if accepted_page is None or invited_page is None or slack_users_page is None: return AllUsersResponse( accepted=[ FullUserSnapshot( @@ -152,11 +160,23 @@ def list_all_users( UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED ), ) - for user in users + for user in accepted_users + ], + slack_users=[ + FullUserSnapshot( + id=user.id, + email=user.email, + role=user.role, + status=( + UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED + ), + ) + for user in slack_users ], invited=[InvitedUserSnapshot(email=email) for email in invited_emails], accepted_pages=1, invited_pages=1, + slack_users_pages=1, ) # Otherwise, return paginated results @@ -168,13 +188,27 @@ def list_all_users( role=user.role, status=UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED, ) - for user in users + for user in accepted_users ][accepted_page * USERS_PAGE_SIZE : (accepted_page + 1) * USERS_PAGE_SIZE], + slack_users=[ + FullUserSnapshot( + id=user.id, + email=user.email, + role=user.role, + status=UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED, + ) + for user in slack_users + ][ + slack_users_page + * USERS_PAGE_SIZE : (slack_users_page + 1) + * USERS_PAGE_SIZE + ], invited=[InvitedUserSnapshot(email=email) for email in invited_emails][ invited_page * USERS_PAGE_SIZE : (invited_page + 1) * USERS_PAGE_SIZE ], accepted_pages=accepted_count // USERS_PAGE_SIZE + 1, invited_pages=invited_count // USERS_PAGE_SIZE + 1, + slack_users_pages=slack_users_count // USERS_PAGE_SIZE + 1, ) @@ -193,11 +227,11 @@ def bulk_invite_users( ) tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() - normalized_emails = [] + new_invited_emails = [] try: for email in emails: email_info = validate_email(email) - normalized_emails.append(email_info.normalized) # type: ignore + new_invited_emails.append(email_info.normalized) except (EmailUndeliverableError, EmailNotValidError) as e: raise HTTPException( @@ -209,7 +243,7 @@ def bulk_invite_users( try: fetch_ee_implementation_or_noop( "danswer.server.tenants.provisioning", "add_users_to_tenant", None - )(normalized_emails, tenant_id) + )(new_invited_emails, tenant_id) except IntegrityError as e: if isinstance(e.orig, UniqueViolation): @@ -223,7 +257,7 @@ def bulk_invite_users( initial_invited_users = get_invited_users() - all_emails = list(set(normalized_emails) | set(initial_invited_users)) + all_emails = list(set(new_invited_emails) | set(initial_invited_users)) number_of_invited_users = write_invited_users(all_emails) if not MULTI_TENANT: @@ -235,7 +269,7 @@ def bulk_invite_users( )(CURRENT_TENANT_ID_CONTEXTVAR.get(), get_total_users_count(db_session)) if ENABLE_EMAIL_INVITES: try: - for email in all_emails: + for email in new_invited_emails: send_user_email_invite(email, current_user) except Exception as e: logger.error(f"Error sending email invite to invited users: {e}") @@ -249,7 +283,7 @@ def bulk_invite_users( write_invited_users(initial_invited_users) # Reset to original state fetch_ee_implementation_or_noop( "danswer.server.tenants.user_mapping", "remove_users_from_tenant", None - )(normalized_emails, tenant_id) + )(new_invited_emails, tenant_id) raise e @@ -497,7 +531,6 @@ def verify_user_logged_in( return fetch_no_auth_user(store) raise BasicAuthenticationError(detail="User Not Authenticated") - if user.oidc_expiry and user.oidc_expiry < datetime.now(timezone.utc): raise BasicAuthenticationError( detail="Access denied. User's OIDC token has expired.", @@ -581,6 +614,30 @@ def update_user_recent_assistants( db_session.commit() +@router.patch("/auto-scroll") +def update_user_auto_scroll( + request: AutoScrollRequest, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> None: + if user is None: + if AUTH_TYPE == AuthType.DISABLED: + store = get_kv_store() + no_auth_user = fetch_no_auth_user(store) + no_auth_user.preferences.auto_scroll = request.auto_scroll + set_no_auth_user_preferences(store, no_auth_user.preferences) + return + else: + raise RuntimeError("This should never happen") + + db_session.execute( + update(User) + .where(User.id == user.id) # type: ignore + .values(auto_scroll=request.auto_scroll) + ) + db_session.commit() + + @router.patch("/user/default-model") def update_user_default_model( request: ChosenDefaultModelRequest, diff --git a/backend/danswer/server/openai_assistants_api/runs_api.py b/backend/danswer/server/openai_assistants_api/runs_api.py index 44bfaa3aca4..74fa6fa0497 100644 --- a/backend/danswer/server/openai_assistants_api/runs_api.py +++ b/backend/danswer/server/openai_assistants_api/runs_api.py @@ -109,6 +109,7 @@ def process_run_in_background( prompt_id=chat_session.persona.prompts[0].id, search_doc_ids=None, retrieval_options=search_tool_retrieval_details, # Adjust as needed + rerank_settings=None, query_override=None, regenerate=None, llm_override=None, diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index c4728336c86..6d62a7bfc66 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -1,6 +1,7 @@ import asyncio import io import json +import os import uuid from collections.abc import Callable from collections.abc import Generator @@ -23,13 +24,18 @@ from danswer.chat.chat_utils import create_chat_chain from danswer.chat.chat_utils import extract_headers from danswer.chat.process_message import stream_chat_message +from danswer.chat.prompt_builder.citations_prompt import ( + compute_max_document_tokens_for_persona, +) from danswer.configs.app_configs import WEB_DOMAIN from danswer.configs.constants import FileOrigin from danswer.configs.constants import MessageType from danswer.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS +from danswer.db.chat import add_chats_to_session_from_slack_thread from danswer.db.chat import create_chat_session from danswer.db.chat import create_new_chat_message from danswer.db.chat import delete_chat_session +from danswer.db.chat import duplicate_chat_session_for_user_from_slack from danswer.db.chat import get_chat_message from danswer.db.chat import get_chat_messages_by_session from danswer.db.chat import get_chat_session_by_id @@ -45,13 +51,11 @@ from danswer.db.persona import get_persona_by_id from danswer.document_index.document_index_utils import get_both_index_names from danswer.document_index.factory import get_default_document_index +from danswer.file_processing.extract_file_text import docx_to_txt_filename from danswer.file_processing.extract_file_text import extract_file_text from danswer.file_store.file_store import get_default_file_store from danswer.file_store.models import ChatFileType from danswer.file_store.models import FileDescriptor -from danswer.llm.answering.prompts.citations_prompt import ( - compute_max_document_tokens_for_persona, -) from danswer.llm.exceptions import GenAIDisabledException from danswer.llm.factory import get_default_llms from danswer.llm.factory import get_llms_for_persona @@ -532,6 +536,38 @@ def seed_chat( ) +class SeedChatFromSlackRequest(BaseModel): + chat_session_id: UUID + + +class SeedChatFromSlackResponse(BaseModel): + redirect_url: str + + +@router.post("/seed-chat-session-from-slack") +def seed_chat_from_slack( + chat_seed_request: SeedChatFromSlackRequest, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> SeedChatFromSlackResponse: + slack_chat_session_id = chat_seed_request.chat_session_id + new_chat_session = duplicate_chat_session_for_user_from_slack( + db_session=db_session, + user=user, + chat_session_id=slack_chat_session_id, + ) + + add_chats_to_session_from_slack_thread( + db_session=db_session, + slack_chat_session_id=slack_chat_session_id, + new_chat_session_id=new_chat_session.id, + ) + + return SeedChatFromSlackResponse( + redirect_url=f"{WEB_DOMAIN}/chat?chatId={new_chat_session.id}" + ) + + """File upload""" @@ -673,14 +709,30 @@ def upload_files_for_chat( } -@router.get("/file/{file_id}") +@router.get("/file/{file_id:path}") def fetch_chat_file( file_id: str, db_session: Session = Depends(get_session), _: User | None = Depends(current_user), ) -> Response: file_store = get_default_file_store(db_session) + file_record = file_store.read_file_record(file_id) + if not file_record: + raise HTTPException(status_code=404, detail="File not found") + + original_file_name = file_record.display_name + if file_record.file_type.startswith( + "application/vnd.openxmlformats-officedocument.wordprocessingml.document" + ): + # Check if a converted text file exists for .docx files + txt_file_name = docx_to_txt_filename(original_file_name) + txt_file_id = os.path.join(os.path.dirname(file_id), txt_file_name) + txt_file_record = file_store.read_file_record(txt_file_id) + if txt_file_record: + file_record = txt_file_record + file_id = txt_file_id + + media_type = file_record.file_type file_io = file_store.read_file(file_id, mode="b") - # NOTE: specifying "image/jpeg" here, but it still works for pngs - # TODO: do this properly - return Response(content=file_io.read(), media_type="image/jpeg") + + return StreamingResponse(file_io, media_type=media_type) diff --git a/backend/danswer/server/query_and_chat/models.py b/backend/danswer/server/query_and_chat/models.py index ae6e651fff1..34ef556daff 100644 --- a/backend/danswer/server/query_and_chat/models.py +++ b/backend/danswer/server/query_and_chat/models.py @@ -1,16 +1,19 @@ from datetime import datetime from typing import Any +from typing import TYPE_CHECKING from uuid import UUID from pydantic import BaseModel from pydantic import model_validator +from danswer.chat.models import PersonaOverrideConfig from danswer.chat.models import RetrievalDocs from danswer.configs.constants import DocumentSource from danswer.configs.constants import MessageType from danswer.configs.constants import SearchFeedbackType from danswer.context.search.models import BaseFilters from danswer.context.search.models import ChunkContext +from danswer.context.search.models import RerankingDetails from danswer.context.search.models import RetrievalDetails from danswer.context.search.models import SearchDoc from danswer.context.search.models import Tag @@ -20,6 +23,9 @@ from danswer.llm.override_models import PromptOverride from danswer.tools.models import ToolCallFinalResult +if TYPE_CHECKING: + pass + class SourceTag(Tag): source: DocumentSource @@ -79,6 +85,7 @@ class CreateChatMessageRequest(ChunkContext): message: str # Files that we should attach to this message file_descriptors: list[FileDescriptor] + # If no prompt provided, uses the largest prompt of the chat session # but really this should be explicitly specified, only in the simplified APIs is this inferred # Use prompt_id 0 to use the system default prompt which is Answer-Question @@ -86,6 +93,8 @@ class CreateChatMessageRequest(ChunkContext): # If search_doc_ids provided, then retrieval options are unused search_doc_ids: list[int] | None retrieval_options: RetrievalDetails | None + # Useable via the APIs but not recommended for most flows + rerank_settings: RerankingDetails | None = None # allows the caller to specify the exact search query they want to use # will disable Query Rewording if specified query_override: str | None = None @@ -101,6 +110,10 @@ class CreateChatMessageRequest(ChunkContext): # allow user to specify an alternate assistnat alternate_assistant_id: int | None = None + # This takes the priority over the prompt_override + # This won't be a type that's passed in directly from the API + persona_override_config: PersonaOverrideConfig | None = None + # used for seeded chats to kick off the generation of an AI answer use_existing_user_message: bool = False @@ -144,7 +157,7 @@ class RenameChatSessionResponse(BaseModel): class ChatSessionDetails(BaseModel): id: UUID - name: str + name: str | None persona_id: int | None = None time_created: str shared_status: ChatSessionSharedStatus @@ -197,14 +210,14 @@ def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: class SearchSessionDetailResponse(BaseModel): search_session_id: UUID - description: str + description: str | None documents: list[SearchDoc] messages: list[ChatMessageDetail] class ChatSessionDetailResponse(BaseModel): chat_session_id: UUID - description: str + description: str | None persona_id: int | None = None persona_name: str | None messages: list[ChatMessageDetail] diff --git a/backend/danswer/server/query_and_chat/query_backend.py b/backend/danswer/server/query_and_chat/query_backend.py index f07d98f0aa9..65e7889dd39 100644 --- a/backend/danswer/server/query_and_chat/query_backend.py +++ b/backend/danswer/server/query_and_chat/query_backend.py @@ -1,15 +1,11 @@ -import json -from collections.abc import Generator from uuid import UUID from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException -from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session from danswer.auth.users import current_curator_or_admin_user -from danswer.auth.users import current_limited_user from danswer.auth.users import current_user from danswer.configs.constants import DocumentSource from danswer.configs.constants import MessageType @@ -32,8 +28,6 @@ from danswer.db.tag import find_tags from danswer.document_index.factory import get_default_document_index from danswer.document_index.vespa.index import VespaIndex -from danswer.one_shot_answer.answer_question import stream_search_answer -from danswer.one_shot_answer.models import DirectQARequest from danswer.server.query_and_chat.models import AdminSearchRequest from danswer.server.query_and_chat.models import AdminSearchResponse from danswer.server.query_and_chat.models import ChatSessionDetails @@ -41,7 +35,6 @@ from danswer.server.query_and_chat.models import SearchSessionDetailResponse from danswer.server.query_and_chat.models import SourceTag from danswer.server.query_and_chat.models import TagResponse -from danswer.server.query_and_chat.token_limit import check_token_rate_limits from danswer.utils.logger import setup_logger logger = setup_logger() @@ -140,7 +133,7 @@ def get_user_search_sessions( try: search_sessions = get_chat_sessions_by_user( - user_id=user_id, deleted=False, db_session=db_session, only_one_shot=True + user_id=user_id, deleted=False, db_session=db_session ) except ValueError: raise HTTPException( @@ -229,29 +222,3 @@ def get_search_session( ], ) return response - - -@basic_router.post("/stream-answer-with-quote") -def get_answer_with_quote( - query_request: DirectQARequest, - user: User = Depends(current_limited_user), - _: None = Depends(check_token_rate_limits), -) -> StreamingResponse: - query = query_request.messages[0].message - - logger.notice(f"Received query for one shot answer with quotes: {query}") - - def stream_generator() -> Generator[str, None, None]: - try: - for packet in stream_search_answer( - query_req=query_request, - user=user, - max_document_tokens=None, - max_history_tokens=0, - ): - yield json.dumps(packet) if isinstance(packet, dict) else packet - except Exception as e: - logger.exception("Error in search answer streaming") - yield json.dumps({"error": str(e)}) - - return StreamingResponse(stream_generator(), media_type="application/json") diff --git a/backend/danswer/server/settings/api.py b/backend/danswer/server/settings/api.py index 4f598a18353..c453c2fb51e 100644 --- a/backend/danswer/server/settings/api.py +++ b/backend/danswer/server/settings/api.py @@ -2,7 +2,6 @@ from fastapi import APIRouter from fastapi import Depends -from fastapi import HTTPException from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import Session @@ -38,10 +37,6 @@ def put_settings( settings: Settings, _: User | None = Depends(current_admin_user) ) -> None: - try: - settings.check_validity() - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) store_settings(settings) diff --git a/backend/danswer/server/settings/models.py b/backend/danswer/server/settings/models.py index af93595501d..55571536d4e 100644 --- a/backend/danswer/server/settings/models.py +++ b/backend/danswer/server/settings/models.py @@ -41,33 +41,10 @@ def from_model(cls, notif: NotificationDBModel) -> "Notification": class Settings(BaseModel): """General settings""" - chat_page_enabled: bool = True - search_page_enabled: bool = True - default_page: PageType = PageType.SEARCH maximum_chat_retention_days: int | None = None gpu_enabled: bool | None = None product_gating: GatingType = GatingType.NONE - def check_validity(self) -> None: - chat_page_enabled = self.chat_page_enabled - search_page_enabled = self.search_page_enabled - default_page = self.default_page - - if chat_page_enabled is False and search_page_enabled is False: - raise ValueError( - "One of `search_page_enabled` and `chat_page_enabled` must be True." - ) - - if default_page == PageType.CHAT and chat_page_enabled is False: - raise ValueError( - "The default page cannot be 'chat' if the chat page is disabled." - ) - - if default_page == PageType.SEARCH and search_page_enabled is False: - raise ValueError( - "The default page cannot be 'search' if the search page is disabled." - ) - class UserSettings(Settings): notifications: list[Notification] diff --git a/backend/danswer/server/utils.py b/backend/danswer/server/utils.py index 68e6dc8d0b8..f59066f9c72 100644 --- a/backend/danswer/server/utils.py +++ b/backend/danswer/server/utils.py @@ -6,6 +6,9 @@ from textwrap import dedent from typing import Any +from fastapi import HTTPException +from fastapi import status + from danswer.configs.app_configs import SMTP_PASS from danswer.configs.app_configs import SMTP_PORT from danswer.configs.app_configs import SMTP_SERVER @@ -14,6 +17,11 @@ from danswer.db.models import User +class BasicAuthenticationError(HTTPException): + def __init__(self, detail: str): + super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail) + + class DateTimeEncoder(json.JSONEncoder): """Custom JSON encoder that converts datetime objects to ISO format strings.""" diff --git a/backend/danswer/setup.py b/backend/danswer/setup.py index 99173821a45..9571ac28926 100644 --- a/backend/danswer/setup.py +++ b/backend/danswer/setup.py @@ -4,6 +4,7 @@ from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP from danswer.configs.app_configs import MANAGED_VESPA +from danswer.configs.app_configs import VESPA_NUM_ATTEMPTS_ON_STARTUP from danswer.configs.constants import KV_REINDEX_KEY from danswer.configs.constants import KV_SEARCH_SETTINGS from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION @@ -38,7 +39,6 @@ from danswer.natural_language_processing.search_nlp_models import EmbeddingModel from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder -from danswer.seeding.load_docs import seed_initial_documents from danswer.seeding.load_yamls import load_chat_yamls from danswer.server.manage.llm.models import LLMProviderUpsertRequest from danswer.server.settings.store import load_settings @@ -150,7 +150,7 @@ def setup_danswer( # update multipass indexing setting based on GPU availability update_default_multipass_indexing(db_session) - seed_initial_documents(db_session, tenant_id, cohere_enabled) + # seed_initial_documents(db_session, tenant_id, cohere_enabled) def translate_saved_search_settings(db_session: Session) -> None: @@ -221,13 +221,13 @@ def setup_vespa( document_index: DocumentIndex, index_setting: IndexingSetting, secondary_index_setting: IndexingSetting | None, + num_attempts: int = VESPA_NUM_ATTEMPTS_ON_STARTUP, ) -> bool: # Vespa startup is a bit slow, so give it a few seconds WAIT_SECONDS = 5 - VESPA_ATTEMPTS = 5 - for x in range(VESPA_ATTEMPTS): + for x in range(num_attempts): try: - logger.notice(f"Setting up Vespa (attempt {x+1}/{VESPA_ATTEMPTS})...") + logger.notice(f"Setting up Vespa (attempt {x+1}/{num_attempts})...") document_index.ensure_indices_exist( index_embedding_dim=index_setting.model_dim, secondary_index_embedding_dim=secondary_index_setting.model_dim @@ -244,7 +244,7 @@ def setup_vespa( time.sleep(WAIT_SECONDS) logger.error( - f"Vespa setup did not succeed. Attempt limit reached. ({VESPA_ATTEMPTS})" + f"Vespa setup did not succeed. Attempt limit reached. ({num_attempts})" ) return False @@ -254,13 +254,14 @@ def setup_postgres(db_session: Session) -> None: create_initial_public_credential(db_session) create_initial_default_connector(db_session) associate_default_cc_pair(db_session) - - logger.notice("Loading default Prompts and Personas") delete_old_default_personas(db_session) - load_chat_yamls(db_session) logger.notice("Loading built-in tools") load_builtin_tools(db_session) + + logger.notice("Loading default Prompts and Personas") + load_chat_yamls(db_session) + refresh_built_in_tools_cache(db_session) auto_add_search_tool_to_personas(db_session) diff --git a/backend/danswer/tools/base_tool.py b/backend/danswer/tools/base_tool.py index 73902504462..ebacf687aab 100644 --- a/backend/danswer/tools/base_tool.py +++ b/backend/danswer/tools/base_tool.py @@ -7,7 +7,7 @@ from danswer.tools.tool import Tool if TYPE_CHECKING: - from danswer.llm.answering.prompts.build import AnswerPromptBuilder + from danswer.chat.prompt_builder.build import AnswerPromptBuilder from danswer.tools.tool_implementations.custom.custom_tool import ( CustomToolCallSummary, ) diff --git a/backend/danswer/tools/tool.py b/backend/danswer/tools/tool.py index 6fc9251a18a..210a8028645 100644 --- a/backend/danswer/tools/tool.py +++ b/backend/danswer/tools/tool.py @@ -3,13 +3,13 @@ from typing import Any from typing import TYPE_CHECKING -from danswer.llm.answering.models import PreviousMessage from danswer.llm.interfaces import LLM +from danswer.llm.models import PreviousMessage from danswer.utils.special_types import JSON_ro if TYPE_CHECKING: - from danswer.llm.answering.prompts.build import AnswerPromptBuilder + from danswer.chat.prompt_builder.build import AnswerPromptBuilder from danswer.tools.message import ToolCallSummary from danswer.tools.models import ToolResponse diff --git a/backend/danswer/tools/tool_constructor.py b/backend/danswer/tools/tool_constructor.py index a8fb5706dc2..6f371793551 100644 --- a/backend/danswer/tools/tool_constructor.py +++ b/backend/danswer/tools/tool_constructor.py @@ -5,6 +5,10 @@ from pydantic import Field from sqlalchemy.orm import Session +from danswer.chat.models import AnswerStyleConfig +from danswer.chat.models import CitationConfig +from danswer.chat.models import DocumentPruningConfig +from danswer.chat.models import PromptConfig from danswer.configs.app_configs import AZURE_DALLE_API_BASE from danswer.configs.app_configs import AZURE_DALLE_API_KEY from danswer.configs.app_configs import AZURE_DALLE_API_VERSION @@ -13,15 +17,12 @@ from danswer.configs.model_configs import GEN_AI_TEMPERATURE from danswer.context.search.enums import LLMEvaluationType from danswer.context.search.models import InferenceSection +from danswer.context.search.models import RerankingDetails from danswer.context.search.models import RetrievalDetails from danswer.db.llm import fetch_existing_llm_providers from danswer.db.models import Persona from danswer.db.models import User from danswer.file_store.models import InMemoryChatFile -from danswer.llm.answering.models import AnswerStyleConfig -from danswer.llm.answering.models import CitationConfig -from danswer.llm.answering.models import DocumentPruningConfig -from danswer.llm.answering.models import PromptConfig from danswer.llm.interfaces import LLM from danswer.llm.interfaces import LLMConfig from danswer.natural_language_processing.utils import get_tokenizer @@ -102,11 +103,14 @@ class SearchToolConfig(BaseModel): default_factory=DocumentPruningConfig ) retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails) + rerank_settings: RerankingDetails | None = None selected_sections: list[InferenceSection] | None = None chunks_above: int = 0 chunks_below: int = 0 full_doc: bool = False latest_query_files: list[InMemoryChatFile] | None = None + # Use with care, should only be used for DanswerBot in channels with multiple users + bypass_acl: bool = False class InternetSearchToolConfig(BaseModel): @@ -170,6 +174,8 @@ def construct_tools( if persona.llm_relevance_filter else LLMEvaluationType.SKIP ), + rerank_settings=search_tool_config.rerank_settings, + bypass_acl=search_tool_config.bypass_acl, ) tool_dict[db_tool_model.id] = [search_tool] diff --git a/backend/danswer/tools/tool_implementations/custom/custom_tool.py b/backend/danswer/tools/tool_implementations/custom/custom_tool.py index c25d61b3cf3..b874a2164a5 100644 --- a/backend/danswer/tools/tool_implementations/custom/custom_tool.py +++ b/backend/danswer/tools/tool_implementations/custom/custom_tool.py @@ -15,14 +15,14 @@ from pydantic import BaseModel from requests import JSONDecodeError +from danswer.chat.prompt_builder.build import AnswerPromptBuilder from danswer.configs.constants import FileOrigin from danswer.db.engine import get_session_with_default_tenant from danswer.file_store.file_store import get_default_file_store from danswer.file_store.models import ChatFileType from danswer.file_store.models import InMemoryChatFile -from danswer.llm.answering.models import PreviousMessage -from danswer.llm.answering.prompts.build import AnswerPromptBuilder from danswer.llm.interfaces import LLM +from danswer.llm.models import PreviousMessage from danswer.tools.base_tool import BaseTool from danswer.tools.message import ToolCallSummary from danswer.tools.models import CHAT_SESSION_ID_PLACEHOLDER diff --git a/backend/danswer/tools/tool_implementations/images/image_generation_tool.py b/backend/danswer/tools/tool_implementations/images/image_generation_tool.py index 70763fc7896..d8d3d754316 100644 --- a/backend/danswer/tools/tool_implementations/images/image_generation_tool.py +++ b/backend/danswer/tools/tool_implementations/images/image_generation_tool.py @@ -4,14 +4,16 @@ from typing import Any from typing import cast +import requests from litellm import image_generation # type: ignore from pydantic import BaseModel from danswer.chat.chat_utils import combine_message_chain +from danswer.chat.prompt_builder.build import AnswerPromptBuilder from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF -from danswer.llm.answering.models import PreviousMessage -from danswer.llm.answering.prompts.build import AnswerPromptBuilder +from danswer.configs.tool_configs import IMAGE_GENERATION_OUTPUT_FORMAT from danswer.llm.interfaces import LLM +from danswer.llm.models import PreviousMessage from danswer.llm.utils import build_content_with_imgs from danswer.llm.utils import message_to_string from danswer.prompts.constants import GENERAL_SEP_PAT @@ -56,9 +58,18 @@ """.strip() +class ImageFormat(str, Enum): + URL = "url" + BASE64 = "b64_json" + + +_DEFAULT_OUTPUT_FORMAT = ImageFormat(IMAGE_GENERATION_OUTPUT_FORMAT) + + class ImageGenerationResponse(BaseModel): revised_prompt: str - url: str + url: str | None + image_data: str | None class ImageShape(str, Enum): @@ -80,6 +91,7 @@ def __init__( model: str = "dall-e-3", num_imgs: int = 2, additional_headers: dict[str, str] | None = None, + output_format: ImageFormat = _DEFAULT_OUTPUT_FORMAT, ) -> None: self.api_key = api_key self.api_base = api_base @@ -89,6 +101,7 @@ def __init__( self.num_imgs = num_imgs self.additional_headers = additional_headers + self.output_format = output_format @property def name(self) -> str: @@ -168,7 +181,7 @@ def build_tool_message_content( ) return build_content_with_imgs( - json.dumps( + message=json.dumps( [ { "revised_prompt": image_generation.revised_prompt, @@ -177,13 +190,10 @@ def build_tool_message_content( for image_generation in image_generations ] ), - # NOTE: we can't pass in the image URLs here, since OpenAI doesn't allow - # Tool messages to contain images - # img_urls=[image_generation.url for image_generation in image_generations], ) def _generate_image( - self, prompt: str, shape: ImageShape + self, prompt: str, shape: ImageShape, format: ImageFormat ) -> ImageGenerationResponse: if shape == ImageShape.LANDSCAPE: size = "1792x1024" @@ -197,20 +207,32 @@ def _generate_image( prompt=prompt, model=self.model, api_key=self.api_key, - # need to pass in None rather than empty str api_base=self.api_base or None, api_version=self.api_version or None, size=size, n=1, + response_format=format, extra_headers=build_llm_extra_headers(self.additional_headers), ) + + if format == ImageFormat.URL: + url = response.data[0]["url"] + image_data = None + else: + url = None + image_data = response.data[0]["b64_json"] + return ImageGenerationResponse( revised_prompt=response.data[0]["revised_prompt"], - url=response.data[0]["url"], + url=url, + image_data=image_data, ) + except requests.RequestException as e: + logger.error(f"Error fetching or converting image: {e}") + raise ValueError("Failed to fetch or convert the generated image") except Exception as e: - logger.debug(f"Error occured during image generation: {e}") + logger.debug(f"Error occurred during image generation: {e}") error_message = str(e) if "OpenAIException" in str(type(e)): @@ -235,9 +257,8 @@ def _generate_image( def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: prompt = cast(str, kwargs["prompt"]) shape = ImageShape(kwargs.get("shape", ImageShape.SQUARE)) + format = self.output_format - # dalle3 only supports 1 image at a time, which is why we have to - # parallelize this via threading results = cast( list[ImageGenerationResponse], run_functions_tuples_in_parallel( @@ -247,6 +268,7 @@ def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: ( prompt, shape, + format, ), ) for _ in range(self.num_imgs) @@ -288,11 +310,17 @@ def build_next_prompt( if img_generation_response is None: raise ValueError("No image generation response found") - img_urls = [img.url for img in img_generation_response] + img_urls = [img.url for img in img_generation_response if img.url is not None] + b64_imgs = [ + img.image_data + for img in img_generation_response + if img.image_data is not None + ] prompt_builder.update_user_prompt( build_image_generation_user_prompt( query=prompt_builder.get_user_message_content(), img_urls=img_urls, + b64_imgs=b64_imgs, ) ) diff --git a/backend/danswer/tools/tool_implementations/images/prompt.py b/backend/danswer/tools/tool_implementations/images/prompt.py index bb729bfcd1c..e5f11ba62d1 100644 --- a/backend/danswer/tools/tool_implementations/images/prompt.py +++ b/backend/danswer/tools/tool_implementations/images/prompt.py @@ -11,11 +11,14 @@ def build_image_generation_user_prompt( - query: str, img_urls: list[str] | None = None + query: str, + img_urls: list[str] | None = None, + b64_imgs: list[str] | None = None, ) -> HumanMessage: return HumanMessage( content=build_content_with_imgs( message=IMG_GENERATION_SUMMARY_PROMPT.format(query=query).strip(), + b64_imgs=b64_imgs, img_urls=img_urls, ) ) diff --git a/backend/danswer/tools/tool_implementations/internet_search/internet_search_tool.py b/backend/danswer/tools/tool_implementations/internet_search/internet_search_tool.py index fd59b08abe1..cdd52f7633d 100644 --- a/backend/danswer/tools/tool_implementations/internet_search/internet_search_tool.py +++ b/backend/danswer/tools/tool_implementations/internet_search/internet_search_tool.py @@ -7,15 +7,15 @@ import httpx from danswer.chat.chat_utils import combine_message_chain +from danswer.chat.models import AnswerStyleConfig from danswer.chat.models import LlmDoc +from danswer.chat.models import PromptConfig +from danswer.chat.prompt_builder.build import AnswerPromptBuilder from danswer.configs.constants import DocumentSource from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF from danswer.context.search.models import SearchDoc -from danswer.llm.answering.models import AnswerStyleConfig -from danswer.llm.answering.models import PreviousMessage -from danswer.llm.answering.models import PromptConfig -from danswer.llm.answering.prompts.build import AnswerPromptBuilder from danswer.llm.interfaces import LLM +from danswer.llm.models import PreviousMessage from danswer.llm.utils import message_to_string from danswer.prompts.chat_prompts import INTERNET_SEARCH_QUERY_REPHRASE from danswer.prompts.constants import GENERAL_SEP_PAT @@ -77,6 +77,7 @@ def llm_doc_from_internet_search_result(result: InternetSearchResult) -> LlmDoc: updated_at=datetime.now(), link=result.link, source_links={0: result.link}, + match_highlights=[], ) diff --git a/backend/danswer/tools/tool_implementations/search/search_tool.py b/backend/danswer/tools/tool_implementations/search/search_tool.py index 0a7be7e3885..a0c686bd6cf 100644 --- a/backend/danswer/tools/tool_implementations/search/search_tool.py +++ b/backend/danswer/tools/tool_implementations/search/search_tool.py @@ -7,10 +7,19 @@ from sqlalchemy.orm import Session from danswer.chat.chat_utils import llm_doc_from_inference_section +from danswer.chat.llm_response_handler import LLMCall +from danswer.chat.models import AnswerStyleConfig +from danswer.chat.models import ContextualPruningConfig from danswer.chat.models import DanswerContext from danswer.chat.models import DanswerContexts +from danswer.chat.models import DocumentPruningConfig from danswer.chat.models import LlmDoc +from danswer.chat.models import PromptConfig from danswer.chat.models import SectionRelevancePiece +from danswer.chat.prompt_builder.build import AnswerPromptBuilder +from danswer.chat.prompt_builder.citations_prompt import compute_max_llm_input_tokens +from danswer.chat.prune_and_merge import prune_and_merge_sections +from danswer.chat.prune_and_merge import prune_sections from danswer.configs.chat_configs import CONTEXT_CHUNKS_ABOVE from danswer.configs.chat_configs import CONTEXT_CHUNKS_BELOW from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS @@ -19,22 +28,14 @@ from danswer.context.search.enums import SearchType from danswer.context.search.models import IndexFilters from danswer.context.search.models import InferenceSection +from danswer.context.search.models import RerankingDetails from danswer.context.search.models import RetrievalDetails from danswer.context.search.models import SearchRequest from danswer.context.search.pipeline import SearchPipeline from danswer.db.models import Persona from danswer.db.models import User -from danswer.llm.answering.llm_response_handler import LLMCall -from danswer.llm.answering.models import AnswerStyleConfig -from danswer.llm.answering.models import ContextualPruningConfig -from danswer.llm.answering.models import DocumentPruningConfig -from danswer.llm.answering.models import PreviousMessage -from danswer.llm.answering.models import PromptConfig -from danswer.llm.answering.prompts.build import AnswerPromptBuilder -from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input_tokens -from danswer.llm.answering.prune_and_merge import prune_and_merge_sections -from danswer.llm.answering.prune_and_merge import prune_sections from danswer.llm.interfaces import LLM +from danswer.llm.models import PreviousMessage from danswer.secondary_llm_flows.choose_search import check_if_need_search from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase from danswer.tools.message import ToolCallSummary @@ -47,6 +48,9 @@ from danswer.tools.tool_implementations.search_like_tool_utils import ( FINAL_CONTEXT_DOCUMENTS_ID, ) +from danswer.tools.tool_implementations.search_like_tool_utils import ( + ORIGINAL_CONTEXT_DOCUMENTS_ID, +) from danswer.utils.logger import setup_logger from danswer.utils.special_types import JSON_ro @@ -103,6 +107,7 @@ def __init__( chunks_below: int | None = None, full_doc: bool = False, bypass_acl: bool = False, + rerank_settings: RerankingDetails | None = None, ) -> None: self.user = user self.persona = persona @@ -118,6 +123,9 @@ def __init__( self.bypass_acl = bypass_acl self.db_session = db_session + # Only used via API + self.rerank_settings = rerank_settings + self.chunks_above = ( chunks_above if chunks_above is not None @@ -292,6 +300,7 @@ def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: self.retrieval_options.offset if self.retrieval_options else None ), limit=self.retrieval_options.limit if self.retrieval_options else None, + rerank_settings=self.rerank_settings, chunks_above=self.chunks_above, chunks_below=self.chunks_below, full_doc=self.full_doc, @@ -385,15 +394,35 @@ def build_next_prompt( """Other utility functions""" @classmethod - def get_search_result(cls, llm_call: LLMCall) -> list[LlmDoc] | None: + def get_search_result( + cls, llm_call: LLMCall + ) -> tuple[list[LlmDoc], dict[str, int]] | None: + """ + Returns the final search results and a map of docs to their original search rank (which is what is displayed to user) + """ if not llm_call.tool_call_info: return None + final_search_results = [] + doc_id_to_original_search_rank_map = {} + for yield_item in llm_call.tool_call_info: if ( isinstance(yield_item, ToolResponse) and yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID ): - return cast(list[LlmDoc], yield_item.response) - - return None + final_search_results = cast(list[LlmDoc], yield_item.response) + elif ( + isinstance(yield_item, ToolResponse) + and yield_item.id == ORIGINAL_CONTEXT_DOCUMENTS_ID + ): + search_contexts = yield_item.response.contexts + original_doc_search_rank = 1 + for idx, doc in enumerate(search_contexts): + if doc.document_id not in doc_id_to_original_search_rank_map: + doc_id_to_original_search_rank_map[ + doc.document_id + ] = original_doc_search_rank + original_doc_search_rank += 1 + + return final_search_results, doc_id_to_original_search_rank_map diff --git a/backend/danswer/tools/tool_implementations/search_like_tool_utils.py b/backend/danswer/tools/tool_implementations/search_like_tool_utils.py index 55890188d7e..7edb22fc144 100644 --- a/backend/danswer/tools/tool_implementations/search_like_tool_utils.py +++ b/backend/danswer/tools/tool_implementations/search_like_tool_utils.py @@ -2,19 +2,20 @@ from langchain_core.messages import HumanMessage +from danswer.chat.models import AnswerStyleConfig from danswer.chat.models import LlmDoc -from danswer.llm.answering.models import AnswerStyleConfig -from danswer.llm.answering.models import PromptConfig -from danswer.llm.answering.prompts.build import AnswerPromptBuilder -from danswer.llm.answering.prompts.citations_prompt import ( +from danswer.chat.models import PromptConfig +from danswer.chat.prompt_builder.build import AnswerPromptBuilder +from danswer.chat.prompt_builder.citations_prompt import ( build_citations_system_message, ) -from danswer.llm.answering.prompts.citations_prompt import build_citations_user_message -from danswer.llm.answering.prompts.quotes_prompt import build_quotes_user_message +from danswer.chat.prompt_builder.citations_prompt import build_citations_user_message +from danswer.chat.prompt_builder.quotes_prompt import build_quotes_user_message from danswer.tools.message import ToolCallSummary from danswer.tools.models import ToolResponse +ORIGINAL_CONTEXT_DOCUMENTS_ID = "search_doc_content" FINAL_CONTEXT_DOCUMENTS_ID = "final_context_documents" diff --git a/backend/danswer/tools/tool_runner.py b/backend/danswer/tools/tool_runner.py index fb3eb8b9932..55ae7022ef5 100644 --- a/backend/danswer/tools/tool_runner.py +++ b/backend/danswer/tools/tool_runner.py @@ -2,8 +2,8 @@ from collections.abc import Generator from typing import Any -from danswer.llm.answering.models import PreviousMessage from danswer.llm.interfaces import LLM +from danswer.llm.models import PreviousMessage from danswer.tools.models import ToolCallFinalResult from danswer.tools.models import ToolCallKickoff from danswer.tools.models import ToolResponse diff --git a/backend/danswer/tools/tool_selection.py b/backend/danswer/tools/tool_selection.py index dc8d697c2ad..f9fbaf9c064 100644 --- a/backend/danswer/tools/tool_selection.py +++ b/backend/danswer/tools/tool_selection.py @@ -3,8 +3,8 @@ from danswer.chat.chat_utils import combine_message_chain from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF -from danswer.llm.answering.models import PreviousMessage from danswer.llm.interfaces import LLM +from danswer.llm.models import PreviousMessage from danswer.llm.utils import message_to_string from danswer.prompts.constants import GENERAL_SEP_PAT from danswer.tools.tool import Tool diff --git a/backend/danswer/utils/b64.py b/backend/danswer/utils/b64.py new file mode 100644 index 00000000000..05a915814ad --- /dev/null +++ b/backend/danswer/utils/b64.py @@ -0,0 +1,25 @@ +import base64 + + +def get_image_type_from_bytes(raw_b64_bytes: bytes) -> str: + magic_number = raw_b64_bytes[:4] + + if magic_number.startswith(b"\x89PNG"): + mime_type = "image/png" + elif magic_number.startswith(b"\xFF\xD8"): + mime_type = "image/jpeg" + elif magic_number.startswith(b"GIF8"): + mime_type = "image/gif" + elif magic_number.startswith(b"RIFF") and raw_b64_bytes[8:12] == b"WEBP": + mime_type = "image/webp" + else: + raise ValueError( + "Unsupported image format - only PNG, JPEG, " "GIF, and WEBP are supported." + ) + + return mime_type + + +def get_image_type(raw_b64_string: str) -> str: + binary_data = base64.b64decode(raw_b64_string) + return get_image_type_from_bytes(binary_data) diff --git a/backend/danswer/utils/subclasses.py b/backend/danswer/utils/subclasses.py new file mode 100644 index 00000000000..72408f98b08 --- /dev/null +++ b/backend/danswer/utils/subclasses.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +import importlib +import os +import pkgutil +import sys +from types import ModuleType +from typing import List +from typing import Type +from typing import TypeVar + +T = TypeVar("T") + + +def import_all_modules_from_dir(dir_path: str) -> List[ModuleType]: + """ + Imports all modules found in the given directory and its subdirectories, + returning a list of imported module objects. + """ + dir_path = os.path.abspath(dir_path) + + if dir_path not in sys.path: + sys.path.insert(0, dir_path) + + imported_modules: List[ModuleType] = [] + + for _, package_name, _ in pkgutil.walk_packages([dir_path]): + try: + module = importlib.import_module(package_name) + imported_modules.append(module) + except Exception as e: + # Handle or log exceptions as needed + print(f"Could not import {package_name}: {e}") + + return imported_modules + + +def all_subclasses(cls: Type[T]) -> List[Type[T]]: + """ + Recursively find all subclasses of the given class. + """ + direct_subs = cls.__subclasses__() + result: List[Type[T]] = [] + for subclass in direct_subs: + result.append(subclass) + # Extend the result by recursively calling all_subclasses + result.extend(all_subclasses(subclass)) + return result + + +def find_all_subclasses_in_dir(parent_class: Type[T], directory: str) -> List[Type[T]]: + """ + Imports all modules from the given directory (and subdirectories), + then returns all classes that are subclasses of parent_class. + + :param parent_class: The class to find subclasses of. + :param directory: The directory to search for subclasses. + :return: A list of all subclasses of parent_class found in the directory. + """ + # First import all modules to ensure classes are loaded into memory + import_all_modules_from_dir(directory) + + # Gather all subclasses of the given parent class + subclasses = all_subclasses(parent_class) + return subclasses + + +# Example usage: +if __name__ == "__main__": + + class Animal: + pass + + # Suppose "mymodules" contains files that define classes inheriting from Animal + found_subclasses = find_all_subclasses_in_dir(Animal, "mymodules") + for sc in found_subclasses: + print("Found subclass:", sc.__name__) diff --git a/backend/ee/danswer/auth/users.py b/backend/ee/danswer/auth/users.py index aab88efa8e4..3d44acc5ec2 100644 --- a/backend/ee/danswer/auth/users.py +++ b/backend/ee/danswer/auth/users.py @@ -1,23 +1,72 @@ +from functools import lru_cache + +import requests from fastapi import Depends from fastapi import HTTPException from fastapi import Request from fastapi import status +from jwt import decode as jwt_decode +from jwt import InvalidTokenError +from jwt import PyJWTError +from sqlalchemy import func +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from danswer.auth.users import current_admin_user from danswer.configs.app_configs import AUTH_TYPE -from danswer.configs.app_configs import SUPER_CLOUD_API_KEY -from danswer.configs.app_configs import SUPER_USERS from danswer.configs.constants import AuthType from danswer.db.models import User from danswer.utils.logger import setup_logger +from ee.danswer.configs.app_configs import JWT_PUBLIC_KEY_URL +from ee.danswer.configs.app_configs import SUPER_CLOUD_API_KEY +from ee.danswer.configs.app_configs import SUPER_USERS from ee.danswer.db.saml import get_saml_account from ee.danswer.server.seeding import get_seed_config from ee.danswer.utils.secrets import extract_hashed_cookie + logger = setup_logger() +@lru_cache() +def get_public_key() -> str | None: + if JWT_PUBLIC_KEY_URL is None: + logger.error("JWT_PUBLIC_KEY_URL is not set") + return None + + response = requests.get(JWT_PUBLIC_KEY_URL) + response.raise_for_status() + return response.text + + +async def verify_jwt_token(token: str, async_db_session: AsyncSession) -> User | None: + try: + public_key_pem = get_public_key() + if public_key_pem is None: + logger.error("Failed to retrieve public key") + return None + + payload = jwt_decode( + token, + public_key_pem, + algorithms=["RS256"], + audience=None, + ) + email = payload.get("email") + if email: + result = await async_db_session.execute( + select(User).where(func.lower(User.email) == func.lower(email)) + ) + return result.scalars().first() + except InvalidTokenError: + logger.error("Invalid JWT token") + get_public_key.cache_clear() + except PyJWTError as e: + logger.error(f"JWT decoding error: {str(e)}") + get_public_key.cache_clear() + return None + + def verify_auth_setting() -> None: # All the Auth flows are valid for EE version logger.notice(f"Using Auth Type: {AUTH_TYPE.value}") @@ -38,6 +87,13 @@ async def optional_user_( ) user = saml_account.user if saml_account else None + # If user is still None, check for JWT in Authorization header + if user is None and JWT_PUBLIC_KEY_URL is not None: + auth_header = request.headers.get("Authorization") + if auth_header and auth_header.startswith("Bearer "): + token = auth_header[len("Bearer ") :].strip() + user = await verify_jwt_token(token, async_db_session) + return user diff --git a/backend/ee/danswer/background/celery/tasks/beat_schedule.py b/backend/ee/danswer/background/celery/tasks/beat_schedule.py index 86680e60c7f..4444d73544f 100644 --- a/backend/ee/danswer/background/celery/tasks/beat_schedule.py +++ b/backend/ee/danswer/background/celery/tasks/beat_schedule.py @@ -4,16 +4,17 @@ from danswer.background.celery.tasks.beat_schedule import ( tasks_to_schedule as base_tasks_to_schedule, ) +from danswer.configs.constants import DanswerCeleryTask ee_tasks_to_schedule = [ { "name": "autogenerate_usage_report", - "task": "autogenerate_usage_report_task", + "task": DanswerCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK, "schedule": timedelta(days=30), # TODO: change this to config flag }, { "name": "check-ttl-management", - "task": "check_ttl_management_task", + "task": DanswerCeleryTask.CHECK_TTL_MANAGEMENT_TASK, "schedule": timedelta(hours=1), }, ] diff --git a/backend/ee/danswer/chat/process_message.py b/backend/ee/danswer/chat/process_message.py new file mode 100644 index 00000000000..e28ef97e2ce --- /dev/null +++ b/backend/ee/danswer/chat/process_message.py @@ -0,0 +1,41 @@ +from danswer.chat.models import AllCitations +from danswer.chat.models import DanswerAnswerPiece +from danswer.chat.models import DanswerContexts +from danswer.chat.models import LLMRelevanceFilterResponse +from danswer.chat.models import QADocsResponse +from danswer.chat.models import StreamingError +from danswer.chat.process_message import ChatPacketStream +from danswer.server.query_and_chat.models import ChatMessageDetail +from danswer.utils.timing import log_function_time +from ee.danswer.server.query_and_chat.models import OneShotQAResponse + + +@log_function_time() +def gather_stream_for_answer_api( + packets: ChatPacketStream, +) -> OneShotQAResponse: + response = OneShotQAResponse() + + answer = "" + for packet in packets: + if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece: + answer += packet.answer_piece + elif isinstance(packet, QADocsResponse): + response.docs = packet + # Extraneous, provided for backwards compatibility + response.rephrase = packet.rephrased_query + elif isinstance(packet, StreamingError): + response.error_msg = packet.error + elif isinstance(packet, ChatMessageDetail): + response.chat_message_id = packet.message_id + elif isinstance(packet, LLMRelevanceFilterResponse): + response.llm_selected_doc_indices = packet.llm_selected_doc_indices + elif isinstance(packet, AllCitations): + response.citations = packet.citations + elif isinstance(packet, DanswerContexts): + response.contexts = packet + + if answer: + response.answer = answer + + return response diff --git a/backend/ee/danswer/configs/app_configs.py b/backend/ee/danswer/configs/app_configs.py index 7e1ade5f3a2..057922dc246 100644 --- a/backend/ee/danswer/configs/app_configs.py +++ b/backend/ee/danswer/configs/app_configs.py @@ -1,3 +1,4 @@ +import json import os # Applicable for OIDC Auth @@ -10,6 +11,14 @@ ##### # Auto Permission Sync ##### +# In seconds, default is 5 minutes +CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY = int( + os.environ.get("CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60 +) +# In seconds, default is 5 minutes +CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY = int( + os.environ.get("CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60 +) NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2) @@ -19,3 +28,14 @@ OPENAI_DEFAULT_API_KEY = os.environ.get("OPENAI_DEFAULT_API_KEY") ANTHROPIC_DEFAULT_API_KEY = os.environ.get("ANTHROPIC_DEFAULT_API_KEY") COHERE_DEFAULT_API_KEY = os.environ.get("COHERE_DEFAULT_API_KEY") + +# JWT Public Key URL +JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None) + + +# Super Users +SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]')) +SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key") + +OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "") +OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "") diff --git a/backend/ee/danswer/danswerbot/slack/handlers/handle_standard_answers.py b/backend/ee/danswer/danswerbot/slack/handlers/handle_standard_answers.py index e0995acc334..6c29f9f38a8 100644 --- a/backend/ee/danswer/danswerbot/slack/handlers/handle_standard_answers.py +++ b/backend/ee/danswer/danswerbot/slack/handlers/handle_standard_answers.py @@ -155,7 +155,6 @@ def _handle_standard_answers( else 0, danswerbot_flow=True, slack_thread_id=slack_thread_id, - one_shot=True, ) root_message = get_or_create_root_message( diff --git a/backend/ee/danswer/db/analytics.py b/backend/ee/danswer/db/analytics.py index e0eff7850e4..8d27af06899 100644 --- a/backend/ee/danswer/db/analytics.py +++ b/backend/ee/danswer/db/analytics.py @@ -170,3 +170,67 @@ def fetch_danswerbot_analytics( ) return results + + +def fetch_persona_message_analytics( + db_session: Session, + persona_id: int, + start: datetime.datetime, + end: datetime.datetime, +) -> list[tuple[int, datetime.date]]: + """Gets the daily message counts for a specific persona within the given time range.""" + query = ( + select( + func.count(ChatMessage.id), + cast(ChatMessage.time_sent, Date), + ) + .join( + ChatSession, + ChatMessage.chat_session_id == ChatSession.id, + ) + .where( + or_( + ChatMessage.alternate_assistant_id == persona_id, + ChatSession.persona_id == persona_id, + ), + ChatMessage.time_sent >= start, + ChatMessage.time_sent <= end, + ChatMessage.message_type == MessageType.ASSISTANT, + ) + .group_by(cast(ChatMessage.time_sent, Date)) + .order_by(cast(ChatMessage.time_sent, Date)) + ) + + return [tuple(row) for row in db_session.execute(query).all()] + + +def fetch_persona_unique_users( + db_session: Session, + persona_id: int, + start: datetime.datetime, + end: datetime.datetime, +) -> list[tuple[int, datetime.date]]: + """Gets the daily unique user counts for a specific persona within the given time range.""" + query = ( + select( + func.count(func.distinct(ChatSession.user_id)), + cast(ChatMessage.time_sent, Date), + ) + .join( + ChatSession, + ChatMessage.chat_session_id == ChatSession.id, + ) + .where( + or_( + ChatMessage.alternate_assistant_id == persona_id, + ChatSession.persona_id == persona_id, + ), + ChatMessage.time_sent >= start, + ChatMessage.time_sent <= end, + ChatMessage.message_type == MessageType.ASSISTANT, + ) + .group_by(cast(ChatMessage.time_sent, Date)) + .order_by(cast(ChatMessage.time_sent, Date)) + ) + + return [tuple(row) for row in db_session.execute(query).all()] diff --git a/backend/ee/danswer/db/connector_credential_pair.py b/backend/ee/danswer/db/connector_credential_pair.py index bb91c0de74f..fea6caba61b 100644 --- a/backend/ee/danswer/db/connector_credential_pair.py +++ b/backend/ee/danswer/db/connector_credential_pair.py @@ -37,10 +37,15 @@ def get_cc_pairs_by_source( source_type: DocumentSource, only_sync: bool, ) -> list[ConnectorCredentialPair]: + """ + Get all cc_pairs for a given source type (and optionally only sync) + result is sorted by cc_pair id + """ query = ( db_session.query(ConnectorCredentialPair) .join(ConnectorCredentialPair.connector) .filter(Connector.source == source_type) + .order_by(ConnectorCredentialPair.id) ) if only_sync: diff --git a/backend/ee/danswer/db/document.py b/backend/ee/danswer/db/document.py index e061db6c75b..ec1d5741314 100644 --- a/backend/ee/danswer/db/document.py +++ b/backend/ee/danswer/db/document.py @@ -55,9 +55,10 @@ def upsert_document_external_perms( doc_id: str, external_access: ExternalAccess, source_type: DocumentSource, -) -> None: +) -> bool: """ - This sets the permissions for a document in postgres. + This sets the permissions for a document in postgres. Returns True if the + a new document was created, False otherwise. NOTE: this will replace any existing external access, it will not do a union """ document = db_session.scalars( @@ -85,7 +86,7 @@ def upsert_document_external_perms( ) db_session.add(document) db_session.commit() - return + return True # If the document exists, we need to check if the external access has changed if ( @@ -98,3 +99,5 @@ def upsert_document_external_perms( document.is_public = external_access.is_public document.last_modified = datetime.now(timezone.utc) db_session.commit() + + return False diff --git a/backend/ee/danswer/db/external_perm.py b/backend/ee/danswer/db/external_perm.py index 5411d3c8d34..7121130e3eb 100644 --- a/backend/ee/danswer/db/external_perm.py +++ b/backend/ee/danswer/db/external_perm.py @@ -10,6 +10,9 @@ from danswer.configs.constants import DocumentSource from danswer.db.models import User__ExternalUserGroupId from danswer.db.users import batch_add_ext_perm_user_if_not_exists +from danswer.utils.logger import setup_logger + +logger = setup_logger() class ExternalUserGroup(BaseModel): @@ -73,7 +76,13 @@ def replace_user__ext_group_for_cc_pair( new_external_permissions = [] for external_group in group_defs: for user_email in external_group.user_emails: - user_id = email_id_map[user_email] + user_id = email_id_map.get(user_email.lower()) + if user_id is None: + logger.warning( + f"User in group {external_group.id}" + f" with email {user_email} not found" + ) + continue new_external_permissions.append( User__ExternalUserGroupId( user_id=user_id, diff --git a/backend/ee/danswer/db/usage_export.py b/backend/ee/danswer/db/usage_export.py index 074e1ae7d6d..0958c624e33 100644 --- a/backend/ee/danswer/db/usage_export.py +++ b/backend/ee/danswer/db/usage_export.py @@ -33,12 +33,7 @@ def get_empty_chat_messages_entries__paginated( message_skeletons: list[ChatMessageSkeleton] = [] for chat_session in chat_sessions: - if chat_session.one_shot: - flow_type = FlowType.SEARCH - elif chat_session.danswerbot_flow: - flow_type = FlowType.SLACK - else: - flow_type = FlowType.CHAT + flow_type = FlowType.SLACK if chat_session.danswerbot_flow else FlowType.CHAT for message in chat_session.messages: # Only count user messages diff --git a/backend/ee/danswer/db/user_group.py b/backend/ee/danswer/db/user_group.py index ba9e3440497..187f7c7b901 100644 --- a/backend/ee/danswer/db/user_group.py +++ b/backend/ee/danswer/db/user_group.py @@ -11,6 +11,7 @@ from sqlalchemy.orm import Session from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id +from danswer.db.enums import AccessType from danswer.db.enums import ConnectorCredentialPairStatus from danswer.db.models import ConnectorCredentialPair from danswer.db.models import Credential__UserGroup @@ -298,6 +299,11 @@ def fetch_user_groups_for_documents( db_session: Session, document_ids: list[str], ) -> Sequence[tuple[str, list[str]]]: + """ + Fetches all user groups that have access to the given documents. + + NOTE: this doesn't include groups if the cc_pair is access type SYNC + """ stmt = ( select(Document.id, func.array_agg(UserGroup.name)) .join( @@ -306,7 +312,11 @@ def fetch_user_groups_for_documents( ) .join( ConnectorCredentialPair, - ConnectorCredentialPair.id == UserGroup__ConnectorCredentialPair.cc_pair_id, + and_( + ConnectorCredentialPair.id + == UserGroup__ConnectorCredentialPair.cc_pair_id, + ConnectorCredentialPair.access_type != AccessType.SYNC, + ), ) .join( DocumentByConnectorCredentialPair, diff --git a/backend/ee/danswer/external_permissions/confluence/doc_sync.py b/backend/ee/danswer/external_permissions/confluence/doc_sync.py index 57e8e2b226d..94f02409375 100644 --- a/backend/ee/danswer/external_permissions/confluence/doc_sync.py +++ b/backend/ee/danswer/external_permissions/confluence/doc_sync.py @@ -97,6 +97,7 @@ def _get_space_permissions( confluence_client: OnyxConfluence, is_cloud: bool, ) -> dict[str, ExternalAccess]: + logger.debug("Getting space permissions") # Gets all the spaces in the Confluence instance all_space_keys = [] start = 0 @@ -113,6 +114,7 @@ def _get_space_permissions( start += len(spaces_batch.get("results", [])) # Gets the permissions for each space + logger.debug(f"Got {len(all_space_keys)} spaces from confluence") space_permissions_by_space_key: dict[str, ExternalAccess] = {} for space_key in all_space_keys: if is_cloud: @@ -193,6 +195,7 @@ def _fetch_all_page_restrictions_for_space( confluence_client: OnyxConfluence, slim_docs: list[SlimDocument], space_permissions_by_space_key: dict[str, ExternalAccess], + is_cloud: bool, ) -> list[DocExternalAccess]: """ For all pages, if a page has restrictions, then use those restrictions. @@ -220,28 +223,52 @@ def _fetch_all_page_restrictions_for_space( continue space_key = slim_doc.perm_sync_data.get("space_key") - if space_permissions := space_permissions_by_space_key.get(space_key): - # If there are no restrictions, then use the space's restrictions - document_restrictions.append( - DocExternalAccess( - doc_id=slim_doc.id, - external_access=space_permissions, - ) + if not (space_permissions := space_permissions_by_space_key.get(space_key)): + logger.debug( + f"Individually fetching space permissions for space {space_key}" ) - if ( - not space_permissions.is_public - and not space_permissions.external_user_emails - and not space_permissions.external_user_group_ids - ): + try: + # If the space permissions are not in the cache, then fetch them + if is_cloud: + retrieved_space_permissions = _get_cloud_space_permissions( + confluence_client=confluence_client, space_key=space_key + ) + else: + retrieved_space_permissions = _get_server_space_permissions( + confluence_client=confluence_client, space_key=space_key + ) + space_permissions_by_space_key[space_key] = retrieved_space_permissions + space_permissions = retrieved_space_permissions + except Exception as e: logger.warning( - f"Permissions are empty for document: {slim_doc.id}\n" - "This means space permissions are may be wrong for" - f" Space key: {space_key}" + f"Error fetching space permissions for space {space_key}: {e}" ) + + if not space_permissions: + logger.warning( + f"No permissions found for document {slim_doc.id} in space {space_key}" + ) continue - logger.warning(f"No permissions found for document {slim_doc.id}") + # If there are no restrictions, then use the space's restrictions + document_restrictions.append( + DocExternalAccess( + doc_id=slim_doc.id, + external_access=space_permissions, + ) + ) + if ( + not space_permissions.is_public + and not space_permissions.external_user_emails + and not space_permissions.external_user_group_ids + ): + logger.warning( + f"Permissions are empty for document: {slim_doc.id}\n" + "This means space permissions are may be wrong for" + f" Space key: {space_key}" + ) + logger.debug("Finished fetching all page restrictions for space") return document_restrictions @@ -254,27 +281,29 @@ def confluence_doc_sync( it in postgres so that when it gets created later, the permissions are already populated """ + logger.debug("Starting confluence doc sync") confluence_connector = ConfluenceConnector( **cc_pair.connector.connector_specific_config ) confluence_connector.load_credentials(cc_pair.credential.credential_json) - if confluence_connector.confluence_client is None: - raise ValueError("Failed to load credentials") - confluence_client = confluence_connector.confluence_client is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False) space_permissions_by_space_key = _get_space_permissions( - confluence_client=confluence_client, + confluence_client=confluence_connector.confluence_client, is_cloud=is_cloud, ) slim_docs = [] + logger.debug("Fetching all slim documents from confluence") for doc_batch in confluence_connector.retrieve_all_slim_documents(): + logger.debug(f"Got {len(doc_batch)} slim documents from confluence") slim_docs.extend(doc_batch) + logger.debug("Fetching all page restrictions for space") return _fetch_all_page_restrictions_for_space( - confluence_client=confluence_client, + confluence_client=confluence_connector.confluence_client, slim_docs=slim_docs, space_permissions_by_space_key=space_permissions_by_space_key, + is_cloud=is_cloud, ) diff --git a/backend/ee/danswer/external_permissions/confluence/group_sync.py b/backend/ee/danswer/external_permissions/confluence/group_sync.py index 383bc3c5d94..fd613445a2a 100644 --- a/backend/ee/danswer/external_permissions/confluence/group_sync.py +++ b/backend/ee/danswer/external_permissions/confluence/group_sync.py @@ -15,7 +15,10 @@ def _build_group_member_email_map( ) -> dict[str, set[str]]: group_member_emails: dict[str, set[str]] = {} for user_result in confluence_client.paginated_cql_user_retrieval(): - user = user_result["user"] + user = user_result.get("user", {}) + if not user: + logger.warning(f"user result missing user field: {user_result}") + continue email = user.get("email") if not email: # This field is only present in Confluence Server diff --git a/backend/ee/danswer/external_permissions/sync_params.py b/backend/ee/danswer/external_permissions/sync_params.py index c00090d748d..3dc4e46b9a1 100644 --- a/backend/ee/danswer/external_permissions/sync_params.py +++ b/backend/ee/danswer/external_permissions/sync_params.py @@ -3,6 +3,8 @@ from danswer.access.models import DocExternalAccess from danswer.configs.constants import DocumentSource from danswer.db.models import ConnectorCredentialPair +from ee.danswer.configs.app_configs import CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY +from ee.danswer.configs.app_configs import CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY from ee.danswer.db.external_perm import ExternalUserGroup from ee.danswer.external_permissions.confluence.doc_sync import confluence_doc_sync from ee.danswer.external_permissions.confluence.group_sync import confluence_group_sync @@ -48,18 +50,23 @@ } +GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC: set[DocumentSource] = { + DocumentSource.CONFLUENCE, +} + + # If nothing is specified here, we run the doc_sync every time the celery beat runs DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = { # Polling is not supported so we fetch all doc permissions every 5 minutes - DocumentSource.CONFLUENCE: 5 * 60, + DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY, DocumentSource.SLACK: 5 * 60, } # If nothing is specified here, we run the doc_sync every time the celery beat runs EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = { - # Polling is not supported so we fetch all group permissions every 60 seconds - DocumentSource.GOOGLE_DRIVE: 60, - DocumentSource.CONFLUENCE: 60, + # Polling is not supported so we fetch all group permissions every 30 minutes + DocumentSource.GOOGLE_DRIVE: 5 * 60, + DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY, } diff --git a/backend/ee/danswer/main.py b/backend/ee/danswer/main.py index 96655af2acd..c1e5977706d 100644 --- a/backend/ee/danswer/main.py +++ b/backend/ee/danswer/main.py @@ -13,7 +13,6 @@ from danswer.configs.constants import AuthType from danswer.main import get_application as get_application_base from danswer.main import include_router_with_global_prefix_prepended -from danswer.server.api_key.api import router as api_key_router from danswer.utils.logger import setup_logger from danswer.utils.variable_functionality import global_version from ee.danswer.configs.app_configs import OPENID_CONFIG_URL @@ -27,6 +26,7 @@ ) from ee.danswer.server.manage.standard_answer import router as standard_answer_router from ee.danswer.server.middleware.tenant_tracking import add_tenant_id_middleware +from ee.danswer.server.oauth import router as oauth_router from ee.danswer.server.query_and_chat.chat_backend import ( router as chat_router, ) @@ -116,12 +116,12 @@ def get_application() -> FastAPI: # Analytics endpoints include_router_with_global_prefix_prepended(application, analytics_router) include_router_with_global_prefix_prepended(application, query_history_router) - # Api key management - include_router_with_global_prefix_prepended(application, api_key_router) # EE only backend APIs include_router_with_global_prefix_prepended(application, query_router) include_router_with_global_prefix_prepended(application, chat_router) include_router_with_global_prefix_prepended(application, standard_answer_router) + include_router_with_global_prefix_prepended(application, oauth_router) + # Enterprise-only global settings include_router_with_global_prefix_prepended( application, enterprise_settings_admin_router diff --git a/backend/ee/danswer/server/analytics/api.py b/backend/ee/danswer/server/analytics/api.py index f79199323f5..2963dc2134c 100644 --- a/backend/ee/danswer/server/analytics/api.py +++ b/backend/ee/danswer/server/analytics/api.py @@ -11,11 +11,16 @@ from danswer.db.models import User from ee.danswer.db.analytics import fetch_danswerbot_analytics from ee.danswer.db.analytics import fetch_per_user_query_analytics +from ee.danswer.db.analytics import fetch_persona_message_analytics +from ee.danswer.db.analytics import fetch_persona_unique_users from ee.danswer.db.analytics import fetch_query_analytics router = APIRouter(prefix="/analytics") +_DEFAULT_LOOKBACK_DAYS = 30 + + class QueryAnalyticsResponse(BaseModel): total_queries: int total_likes: int @@ -33,7 +38,7 @@ def get_query_analytics( daily_query_usage_info = fetch_query_analytics( start=start or ( - datetime.datetime.utcnow() - datetime.timedelta(days=30) + datetime.datetime.utcnow() - datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS) ), # default is 30d lookback end=end or datetime.datetime.utcnow(), db_session=db_session, @@ -64,7 +69,7 @@ def get_user_analytics( daily_query_usage_info_per_user = fetch_per_user_query_analytics( start=start or ( - datetime.datetime.utcnow() - datetime.timedelta(days=30) + datetime.datetime.utcnow() - datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS) ), # default is 30d lookback end=end or datetime.datetime.utcnow(), db_session=db_session, @@ -98,7 +103,7 @@ def get_danswerbot_analytics( daily_danswerbot_info = fetch_danswerbot_analytics( start=start or ( - datetime.datetime.utcnow() - datetime.timedelta(days=30) + datetime.datetime.utcnow() - datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS) ), # default is 30d lookback end=end or datetime.datetime.utcnow(), db_session=db_session, @@ -115,3 +120,74 @@ def get_danswerbot_analytics( ] return resolution_results + + +class PersonaMessageAnalyticsResponse(BaseModel): + total_messages: int + date: datetime.date + persona_id: int + + +@router.get("/admin/persona/messages") +def get_persona_messages( + persona_id: int, + start: datetime.datetime | None = None, + end: datetime.datetime | None = None, + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> list[PersonaMessageAnalyticsResponse]: + """Fetch daily message counts for a single persona within the given time range.""" + start = start or ( + datetime.datetime.utcnow() - datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS) + ) + end = end or datetime.datetime.utcnow() + + persona_message_counts = [] + for count, date in fetch_persona_message_analytics( + db_session=db_session, + persona_id=persona_id, + start=start, + end=end, + ): + persona_message_counts.append( + PersonaMessageAnalyticsResponse( + total_messages=count, + date=date, + persona_id=persona_id, + ) + ) + + return persona_message_counts + + +class PersonaUniqueUsersResponse(BaseModel): + unique_users: int + date: datetime.date + persona_id: int + + +@router.get("/admin/persona/unique-users") +def get_persona_unique_users( + persona_id: int, + start: datetime.datetime, + end: datetime.datetime, + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> list[PersonaUniqueUsersResponse]: + """Get unique users per day for a single persona.""" + unique_user_counts = [] + daily_counts = fetch_persona_unique_users( + db_session=db_session, + persona_id=persona_id, + start=start, + end=end, + ) + for count, date in daily_counts: + unique_user_counts.append( + PersonaUniqueUsersResponse( + unique_users=count, + date=date, + persona_id=persona_id, + ) + ) + return unique_user_counts diff --git a/backend/ee/danswer/server/enterprise_settings/api.py b/backend/ee/danswer/server/enterprise_settings/api.py index 385adcf689e..272d8bf9369 100644 --- a/backend/ee/danswer/server/enterprise_settings/api.py +++ b/backend/ee/danswer/server/enterprise_settings/api.py @@ -113,10 +113,6 @@ async def refresh_access_token( def put_settings( settings: EnterpriseSettings, _: User | None = Depends(current_admin_user) ) -> None: - try: - settings.check_validity() - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) store_settings(settings) diff --git a/backend/ee/danswer/server/oauth.py b/backend/ee/danswer/server/oauth.py new file mode 100644 index 00000000000..8a39f1ec58e --- /dev/null +++ b/backend/ee/danswer/server/oauth.py @@ -0,0 +1,423 @@ +import base64 +import uuid +from typing import cast + +import requests +from fastapi import APIRouter +from fastapi import Depends +from fastapi import HTTPException +from fastapi.responses import JSONResponse +from pydantic import BaseModel +from sqlalchemy.orm import Session + +from danswer.auth.users import current_user +from danswer.configs.app_configs import WEB_DOMAIN +from danswer.configs.constants import DocumentSource +from danswer.db.credentials import create_credential +from danswer.db.engine import get_current_tenant_id +from danswer.db.engine import get_session +from danswer.db.models import User +from danswer.redis.redis_pool import get_redis_client +from danswer.server.documents.models import CredentialBase +from danswer.utils.logger import setup_logger +from ee.danswer.configs.app_configs import OAUTH_SLACK_CLIENT_ID +from ee.danswer.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET + + +logger = setup_logger() + +router = APIRouter(prefix="/oauth") + + +class SlackOAuth: + # https://knock.app/blog/how-to-authenticate-users-in-slack-using-oauth + # Example: https://api.slack.com/authentication/oauth-v2#exchanging + + class OAuthSession(BaseModel): + """Stored in redis to be looked up on callback""" + + email: str + redirect_on_success: str | None # Where to send the user if OAuth flow succeeds + + CLIENT_ID = OAUTH_SLACK_CLIENT_ID + CLIENT_SECRET = OAUTH_SLACK_CLIENT_SECRET + + TOKEN_URL = "https://slack.com/api/oauth.v2.access" + + # SCOPE is per https://docs.danswer.dev/connectors/slack + BOT_SCOPE = ( + "channels:history," + "channels:read," + "groups:history," + "groups:read," + "channels:join," + "im:history," + "users:read," + "users:read.email," + "usergroups:read" + ) + + REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/slack/oauth/callback" + DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}" + + @classmethod + def generate_oauth_url(cls, state: str) -> str: + url = ( + f"https://slack.com/oauth/v2/authorize" + f"?client_id={cls.CLIENT_ID}" + f"&redirect_uri={cls.REDIRECT_URI}" + f"&scope={cls.BOT_SCOPE}" + f"&state={state}" + ) + return url + + @classmethod + def generate_dev_oauth_url(cls, state: str) -> str: + """dev mode workaround for localhost testing + - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https + """ + + url = ( + f"https://slack.com/oauth/v2/authorize" + f"?client_id={cls.CLIENT_ID}" + f"&redirect_uri={cls.DEV_REDIRECT_URI}" + f"&scope={cls.BOT_SCOPE}" + f"&state={state}" + ) + return url + + @classmethod + def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str: + """Temporary state to store in redis. to be looked up on auth response. + Returns a json string. + """ + session = SlackOAuth.OAuthSession( + email=email, redirect_on_success=redirect_on_success + ) + return session.model_dump_json() + + @classmethod + def parse_session(cls, session_json: str) -> OAuthSession: + session = SlackOAuth.OAuthSession.model_validate_json(session_json) + return session + + +# Work in progress +# class ConfluenceCloudOAuth: +# """work in progress""" + +# # https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/ + +# class OAuthSession(BaseModel): +# """Stored in redis to be looked up on callback""" + +# email: str +# redirect_on_success: str | None # Where to send the user if OAuth flow succeeds + +# CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID +# CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET +# TOKEN_URL = "https://auth.atlassian.com/oauth/token" + +# # All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/ +# CONFLUENCE_OAUTH_SCOPE = ( +# "read:confluence-props%20" +# "read:confluence-content.all%20" +# "read:confluence-content.summary%20" +# "read:confluence-content.permission%20" +# "read:confluence-user%20" +# "read:confluence-groups%20" +# "readonly:content.attachment:confluence" +# ) + +# REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback" +# DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}" + +# # eventually for Confluence Data Center +# # oauth_url = ( +# # f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}" +# # f"&scope={CONFLUENCE_OAUTH_SCOPE_2}" +# # f"&redirect_uri={redirectme_uri}" +# # ) + +# @classmethod +# def generate_oauth_url(cls, state: str) -> str: +# return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state) + +# @classmethod +# def generate_dev_oauth_url(cls, state: str) -> str: +# """dev mode workaround for localhost testing +# - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https +# """ +# return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state) + +# @classmethod +# def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str: +# url = ( +# "https://auth.atlassian.com/authorize" +# f"?audience=api.atlassian.com" +# f"&client_id={cls.CLIENT_ID}" +# f"&redirect_uri={redirect_uri}" +# f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}" +# f"&state={state}" +# "&response_type=code" +# "&prompt=consent" +# ) +# return url + +# @classmethod +# def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str: +# """Temporary state to store in redis. to be looked up on auth response. +# Returns a json string. +# """ +# session = ConfluenceCloudOAuth.OAuthSession( +# email=email, redirect_on_success=redirect_on_success +# ) +# return session.model_dump_json() + +# @classmethod +# def parse_session(cls, session_json: str) -> SlackOAuth.OAuthSession: +# session = SlackOAuth.OAuthSession.model_validate_json(session_json) +# return session + + +@router.post("/prepare-authorization-request") +def prepare_authorization_request( + connector: DocumentSource, + redirect_on_success: str | None, + user: User = Depends(current_user), + tenant_id: str | None = Depends(get_current_tenant_id), +) -> JSONResponse: + """Used by the frontend to generate the url for the user's browser during auth request. + + Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/ + """ + + oauth_uuid = uuid.uuid4() + oauth_uuid_str = str(oauth_uuid) + oauth_state = ( + base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8") + ) + + if connector == DocumentSource.SLACK: + oauth_url = SlackOAuth.generate_oauth_url(oauth_state) + session = SlackOAuth.session_dump_json( + email=user.email, redirect_on_success=redirect_on_success + ) + # elif connector == DocumentSource.CONFLUENCE: + # oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state) + # session = ConfluenceCloudOAuth.session_dump_json( + # email=user.email, redirect_on_success=redirect_on_success + # ) + # elif connector == DocumentSource.JIRA: + # oauth_url = JiraCloudOAuth.generate_dev_oauth_url(oauth_state) + # elif connector == DocumentSource.GOOGLE_DRIVE: + # oauth_url = GoogleDriveOAuth.generate_dev_oauth_url(oauth_state) + else: + oauth_url = None + + if not oauth_url: + raise HTTPException( + status_code=404, + detail=f"The document source type {connector} does not have OAuth implemented", + ) + + r = get_redis_client(tenant_id=tenant_id) + + # 10 min is the max we want an oauth flow to be valid + r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600) + + return JSONResponse(content={"url": oauth_url}) + + +@router.post("/connector/slack/callback") +def handle_slack_oauth_callback( + code: str, + state: str, + user: User = Depends(current_user), + db_session: Session = Depends(get_session), + tenant_id: str | None = Depends(get_current_tenant_id), +) -> JSONResponse: + if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET: + raise HTTPException( + status_code=500, + detail="Slack client ID or client secret is not configured.", + ) + + r = get_redis_client(tenant_id=tenant_id) + + # recover the state + padded_state = state + "=" * ( + -len(state) % 4 + ) # Add padding back (Base64 decoding requires padding) + uuid_bytes = base64.urlsafe_b64decode( + padded_state + ) # Decode the Base64 string back to bytes + + # Convert bytes back to a UUID + oauth_uuid = uuid.UUID(bytes=uuid_bytes) + oauth_uuid_str = str(oauth_uuid) + + r_key = f"da_oauth:{oauth_uuid_str}" + + session_json_bytes = cast(bytes, r.get(r_key)) + if not session_json_bytes: + raise HTTPException( + status_code=400, + detail=f"Slack OAuth failed - OAuth state key not found: key={r_key}", + ) + + session_json = session_json_bytes.decode("utf-8") + try: + session = SlackOAuth.parse_session(session_json) + + # Exchange the authorization code for an access token + response = requests.post( + SlackOAuth.TOKEN_URL, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + data={ + "client_id": SlackOAuth.CLIENT_ID, + "client_secret": SlackOAuth.CLIENT_SECRET, + "code": code, + "redirect_uri": SlackOAuth.REDIRECT_URI, + }, + ) + + response_data = response.json() + + if not response_data.get("ok"): + raise HTTPException( + status_code=400, + detail=f"Slack OAuth failed: {response_data.get('error')}", + ) + + # Extract token and team information + access_token: str = response_data.get("access_token") + team_id: str = response_data.get("team", {}).get("id") + authed_user_id: str = response_data.get("authed_user", {}).get("id") + + credential_info = CredentialBase( + credential_json={"slack_bot_token": access_token}, + admin_public=True, + source=DocumentSource.SLACK, + name="Slack OAuth", + ) + + create_credential(credential_info, user, db_session) + except Exception as e: + return JSONResponse( + status_code=500, + content={ + "success": False, + "message": f"An error occurred during Slack OAuth: {str(e)}", + }, + ) + finally: + r.delete(r_key) + + # return the result + return JSONResponse( + content={ + "success": True, + "message": "Slack OAuth completed successfully.", + "team_id": team_id, + "authed_user_id": authed_user_id, + "redirect_on_success": session.redirect_on_success, + } + ) + + +# Work in progress +# @router.post("/connector/confluence/callback") +# def handle_confluence_oauth_callback( +# code: str, +# state: str, +# user: User = Depends(current_user), +# db_session: Session = Depends(get_session), +# tenant_id: str | None = Depends(get_current_tenant_id), +# ) -> JSONResponse: +# if not ConfluenceCloudOAuth.CLIENT_ID or not ConfluenceCloudOAuth.CLIENT_SECRET: +# raise HTTPException( +# status_code=500, +# detail="Confluence client ID or client secret is not configured." +# ) + +# r = get_redis_client(tenant_id=tenant_id) + +# # recover the state +# padded_state = state + '=' * (-len(state) % 4) # Add padding back (Base64 decoding requires padding) +# uuid_bytes = base64.urlsafe_b64decode(padded_state) # Decode the Base64 string back to bytes + +# # Convert bytes back to a UUID +# oauth_uuid = uuid.UUID(bytes=uuid_bytes) +# oauth_uuid_str = str(oauth_uuid) + +# r_key = f"da_oauth:{oauth_uuid_str}" + +# result = r.get(r_key) +# if not result: +# raise HTTPException( +# status_code=400, +# detail=f"Confluence OAuth failed - OAuth state key not found: key={r_key}" +# ) + +# try: +# session = ConfluenceCloudOAuth.parse_session(result) + +# # Exchange the authorization code for an access token +# response = requests.post( +# ConfluenceCloudOAuth.TOKEN_URL, +# headers={"Content-Type": "application/x-www-form-urlencoded"}, +# data={ +# "client_id": ConfluenceCloudOAuth.CLIENT_ID, +# "client_secret": ConfluenceCloudOAuth.CLIENT_SECRET, +# "code": code, +# "redirect_uri": ConfluenceCloudOAuth.DEV_REDIRECT_URI, +# }, +# ) + +# response_data = response.json() + +# if not response_data.get("ok"): +# raise HTTPException( +# status_code=400, +# detail=f"ConfluenceCloudOAuth OAuth failed: {response_data.get('error')}" +# ) + +# # Extract token and team information +# access_token: str = response_data.get("access_token") +# team_id: str = response_data.get("team", {}).get("id") +# authed_user_id: str = response_data.get("authed_user", {}).get("id") + +# credential_info = CredentialBase( +# credential_json={"slack_bot_token": access_token}, +# admin_public=True, +# source=DocumentSource.CONFLUENCE, +# name="Confluence OAuth", +# ) + +# logger.info(f"Slack access token: {access_token}") + +# credential = create_credential(credential_info, user, db_session) + +# logger.info(f"new_credential_id={credential.id}") +# except Exception as e: +# return JSONResponse( +# status_code=500, +# content={ +# "success": False, +# "message": f"An error occurred during Slack OAuth: {str(e)}", +# }, +# ) +# finally: +# r.delete(r_key) + +# # return the result +# return JSONResponse( +# content={ +# "success": True, +# "message": "Slack OAuth completed successfully.", +# "team_id": team_id, +# "authed_user_id": authed_user_id, +# "redirect_on_success": session.redirect_on_success, +# } +# ) diff --git a/backend/ee/danswer/server/query_and_chat/chat_backend.py b/backend/ee/danswer/server/query_and_chat/chat_backend.py index ef707cbfb24..0122077fca5 100644 --- a/backend/ee/danswer/server/query_and_chat/chat_backend.py +++ b/backend/ee/danswer/server/query_and_chat/chat_backend.py @@ -6,6 +6,7 @@ from sqlalchemy.orm import Session from danswer.auth.users import current_user +from danswer.chat.chat_utils import combine_message_thread from danswer.chat.chat_utils import create_chat_chain from danswer.chat.models import AllCitations from danswer.chat.models import DanswerAnswerPiece @@ -16,8 +17,8 @@ from danswer.chat.models import StreamingError from danswer.chat.process_message import ChatPacketStream from danswer.chat.process_message import stream_chat_message_objects +from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE from danswer.configs.constants import MessageType -from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE from danswer.context.search.models import OptionalSearchSetting from danswer.context.search.models import RetrievalDetails from danswer.context.search.models import SavedSearchDoc @@ -29,7 +30,6 @@ from danswer.llm.factory import get_llms_for_persona from danswer.llm.utils import get_max_input_tokens from danswer.natural_language_processing.utils import get_tokenizer -from danswer.one_shot_answer.qa_utils import combine_message_thread from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase from danswer.server.query_and_chat.models import ChatMessageDetail from danswer.server.query_and_chat.models import CreateChatMessageRequest @@ -171,6 +171,8 @@ def handle_simplified_chat_message( prompt_id=None, search_doc_ids=chat_message_req.search_doc_ids, retrieval_options=retrieval_options, + # Simple API does not support reranking, hide complexity from user + rerank_settings=None, query_override=chat_message_req.query_override, # Currently only applies to search flow not chat chunks_above=0, @@ -232,7 +234,6 @@ def handle_send_message_simple_with_history( description="handle_send_message_simple_with_history", user_id=user_id, persona_id=req.persona_id, - one_shot=False, ) llm, _ = get_llms_for_persona(persona=chat_session.persona) @@ -245,7 +246,7 @@ def handle_send_message_simple_with_history( input_tokens = get_max_input_tokens( model_name=llm.config.model_name, model_provider=llm.config.model_provider ) - max_history_tokens = int(input_tokens * DANSWER_BOT_TARGET_CHUNK_PERCENTAGE) + max_history_tokens = int(input_tokens * CHAT_TARGET_CHUNK_PERCENTAGE) # Every chat Session begins with an empty root message root_message = get_or_create_root_message( @@ -293,6 +294,8 @@ def handle_send_message_simple_with_history( prompt_id=req.prompt_id, search_doc_ids=req.search_doc_ids, retrieval_options=retrieval_options, + # Simple API does not support reranking, hide complexity from user + rerank_settings=None, query_override=rephrased_query, chunks_above=0, chunks_below=0, diff --git a/backend/ee/danswer/server/query_and_chat/models.py b/backend/ee/danswer/server/query_and_chat/models.py index 1fd37a21145..101b2848cdb 100644 --- a/backend/ee/danswer/server/query_and_chat/models.py +++ b/backend/ee/danswer/server/query_and_chat/models.py @@ -2,7 +2,13 @@ from pydantic import BaseModel from pydantic import Field +from pydantic import model_validator +from danswer.chat.models import CitationInfo +from danswer.chat.models import DanswerContexts +from danswer.chat.models import PersonaOverrideConfig +from danswer.chat.models import QADocsResponse +from danswer.chat.models import ThreadMessage from danswer.configs.constants import DocumentSource from danswer.context.search.enums import LLMEvaluationType from danswer.context.search.enums import SearchType @@ -10,7 +16,6 @@ from danswer.context.search.models import RerankingDetails from danswer.context.search.models import RetrievalDetails from danswer.context.search.models import SavedSearchDoc -from danswer.one_shot_answer.models import ThreadMessage from ee.danswer.server.manage.models import StandardAnswer @@ -96,3 +101,48 @@ class ChatBasicResponse(BaseModel): # TODO: deprecate both of these simple_search_docs: list[SimpleDoc] | None = None llm_chunks_indices: list[int] | None = None + + +class OneShotQARequest(ChunkContext): + # Supports simplier APIs that don't deal with chat histories or message edits + # Easier APIs to work with for developers + persona_override_config: PersonaOverrideConfig | None = None + persona_id: int | None = None + + messages: list[ThreadMessage] + prompt_id: int | None = None + retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails) + rerank_settings: RerankingDetails | None = None + return_contexts: bool = False + + # allows the caller to specify the exact search query they want to use + # can be used if the message sent to the LLM / query should not be the same + # will also disable Thread-based Rewording if specified + query_override: str | None = None + + # If True, skips generative an AI response to the search query + skip_gen_ai_answer_generation: bool = False + + @model_validator(mode="after") + def check_persona_fields(self) -> "OneShotQARequest": + if self.persona_override_config is None and self.persona_id is None: + raise ValueError("Exactly one of persona_config or persona_id must be set") + elif self.persona_override_config is not None and ( + self.persona_id is not None or self.prompt_id is not None + ): + raise ValueError( + "If persona_override_config is set, persona_id and prompt_id cannot be set" + ) + return self + + +class OneShotQAResponse(BaseModel): + # This is built piece by piece, any of these can be None as the flow could break + answer: str | None = None + rephrase: str | None = None + citations: list[CitationInfo] | None = None + docs: QADocsResponse | None = None + llm_selected_doc_indices: list[int] | None = None + error_msg: str | None = None + chat_message_id: int | None = None + contexts: DanswerContexts | None = None diff --git a/backend/ee/danswer/server/query_and_chat/query_backend.py b/backend/ee/danswer/server/query_and_chat/query_backend.py index 0b380d5d3f7..16e4b4ebc2e 100644 --- a/backend/ee/danswer/server/query_and_chat/query_backend.py +++ b/backend/ee/danswer/server/query_and_chat/query_backend.py @@ -1,38 +1,47 @@ +import json +from collections.abc import Generator + from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException +from fastapi.responses import StreamingResponse from pydantic import BaseModel from sqlalchemy.orm import Session from danswer.auth.users import current_user -from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE +from danswer.chat.chat_utils import combine_message_thread +from danswer.chat.chat_utils import prepare_chat_message_request +from danswer.chat.models import PersonaOverrideConfig +from danswer.chat.process_message import ChatPacketStream +from danswer.chat.process_message import stream_chat_message_objects +from danswer.configs.danswerbot_configs import MAX_THREAD_CONTEXT_PERCENTAGE from danswer.context.search.models import SavedSearchDocWithContent from danswer.context.search.models import SearchRequest from danswer.context.search.pipeline import SearchPipeline from danswer.context.search.utils import dedupe_documents from danswer.context.search.utils import drop_llm_indices from danswer.context.search.utils import relevant_sections_to_indices +from danswer.db.chat import get_prompt_by_id from danswer.db.engine import get_session +from danswer.db.models import Persona from danswer.db.models import User from danswer.db.persona import get_persona_by_id -from danswer.llm.answering.prompts.citations_prompt import ( - compute_max_document_tokens_for_persona, -) from danswer.llm.factory import get_default_llms from danswer.llm.factory import get_llms_for_persona from danswer.llm.factory import get_main_llm_from_tuple from danswer.llm.utils import get_max_input_tokens -from danswer.one_shot_answer.answer_question import get_search_answer -from danswer.one_shot_answer.models import DirectQARequest -from danswer.one_shot_answer.models import OneShotQAResponse +from danswer.natural_language_processing.utils import get_tokenizer +from danswer.server.utils import get_json_line from danswer.utils.logger import setup_logger +from ee.danswer.chat.process_message import gather_stream_for_answer_api from ee.danswer.danswerbot.slack.handlers.handle_standard_answers import ( oneoff_standard_answers, ) from ee.danswer.server.query_and_chat.models import DocumentSearchRequest +from ee.danswer.server.query_and_chat.models import OneShotQARequest +from ee.danswer.server.query_and_chat.models import OneShotQAResponse from ee.danswer.server.query_and_chat.models import StandardAnswerRequest from ee.danswer.server.query_and_chat.models import StandardAnswerResponse -from ee.danswer.server.query_and_chat.utils import create_temporary_persona logger = setup_logger() @@ -125,58 +134,115 @@ def handle_search_request( return DocumentSearchResponse(top_documents=deduped_docs, llm_indices=llm_indices) -@basic_router.post("/answer-with-quote") -def get_answer_with_quote( - query_request: DirectQARequest, +def get_answer_stream( + query_request: OneShotQARequest, user: User | None = Depends(current_user), db_session: Session = Depends(get_session), -) -> OneShotQAResponse: +) -> ChatPacketStream: query = query_request.messages[0].message - logger.notice(f"Received query for one shot answer API with quotes: {query}") + logger.notice(f"Received query for Answer API: {query}") - if query_request.persona_config is not None: - new_persona = create_temporary_persona( - db_session=db_session, - persona_config=query_request.persona_config, + if ( + query_request.persona_override_config is None + and query_request.persona_id is None + ): + raise KeyError("Must provide persona ID or Persona Config") + + prompt = None + if query_request.prompt_id is not None: + prompt = get_prompt_by_id( + prompt_id=query_request.prompt_id, user=user, + db_session=db_session, ) - persona = new_persona + persona_info: Persona | PersonaOverrideConfig | None = None + if query_request.persona_override_config is not None: + persona_info = query_request.persona_override_config elif query_request.persona_id is not None: - persona = get_persona_by_id( + persona_info = get_persona_by_id( persona_id=query_request.persona_id, user=user, db_session=db_session, is_for_edit=False, ) - else: - raise KeyError("Must provide persona ID or Persona Config") - llm = get_main_llm_from_tuple( - get_default_llms() if not persona else get_llms_for_persona(persona) + llm = get_main_llm_from_tuple(get_llms_for_persona(persona_info)) + + llm_tokenizer = get_tokenizer( + model_name=llm.config.model_name, + provider_type=llm.config.model_provider, ) + input_tokens = get_max_input_tokens( model_name=llm.config.model_name, model_provider=llm.config.model_provider ) - max_history_tokens = int(input_tokens * DANSWER_BOT_TARGET_CHUNK_PERCENTAGE) + max_history_tokens = int(input_tokens * MAX_THREAD_CONTEXT_PERCENTAGE) - remaining_tokens = input_tokens - max_history_tokens + combined_message = combine_message_thread( + messages=query_request.messages, + max_tokens=max_history_tokens, + llm_tokenizer=llm_tokenizer, + ) - max_document_tokens = compute_max_document_tokens_for_persona( - persona=persona, - actual_user_input=query, - max_llm_token_override=remaining_tokens, + # Also creates a new chat session + request = prepare_chat_message_request( + message_text=combined_message, + user=user, + persona_id=query_request.persona_id, + persona_override_config=query_request.persona_override_config, + prompt=prompt, + message_ts_to_respond_to=None, + retrieval_details=query_request.retrieval_options, + rerank_settings=query_request.rerank_settings, + db_session=db_session, ) - answer_details = get_search_answer( - query_req=query_request, + packets = stream_chat_message_objects( + new_msg_req=request, user=user, - max_document_tokens=max_document_tokens, - max_history_tokens=max_history_tokens, db_session=db_session, + include_contexts=query_request.return_contexts, ) - return answer_details + return packets + + +@basic_router.post("/answer-with-citation") +def get_answer_with_citation( + request: OneShotQARequest, + db_session: Session = Depends(get_session), + user: User | None = Depends(current_user), +) -> OneShotQAResponse: + try: + packets = get_answer_stream(request, user, db_session) + answer = gather_stream_for_answer_api(packets) + + if answer.error_msg: + raise RuntimeError(answer.error_msg) + + return answer + except Exception as e: + logger.error(f"Error in get_answer_with_citation: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="An internal server error occurred") + + +@basic_router.post("/stream-answer-with-citation") +def stream_answer_with_citation( + request: OneShotQARequest, + db_session: Session = Depends(get_session), + user: User | None = Depends(current_user), +) -> StreamingResponse: + def stream_generator() -> Generator[str, None, None]: + try: + for packet in get_answer_stream(request, user, db_session): + serialized = get_json_line(packet.model_dump()) + yield serialized + except Exception as e: + logger.exception("Error in answer streaming") + yield json.dumps({"error": str(e)}) + + return StreamingResponse(stream_generator(), media_type="application/json") @basic_router.get("/standard-answer") diff --git a/backend/ee/danswer/server/query_and_chat/utils.py b/backend/ee/danswer/server/query_and_chat/utils.py deleted file mode 100644 index be5507b01c2..00000000000 --- a/backend/ee/danswer/server/query_and_chat/utils.py +++ /dev/null @@ -1,85 +0,0 @@ -from typing import cast - -from fastapi import HTTPException -from sqlalchemy.orm import Session - -from danswer.auth.users import is_user_admin -from danswer.db.llm import fetch_existing_doc_sets -from danswer.db.llm import fetch_existing_tools -from danswer.db.models import Persona -from danswer.db.models import Prompt -from danswer.db.models import Tool -from danswer.db.models import User -from danswer.db.persona import get_prompts_by_ids -from danswer.one_shot_answer.models import PersonaConfig -from danswer.tools.tool_implementations.custom.custom_tool import ( - build_custom_tools_from_openapi_schema_and_headers, -) - - -def create_temporary_persona( - persona_config: PersonaConfig, db_session: Session, user: User | None = None -) -> Persona: - if not is_user_admin(user): - raise HTTPException( - status_code=403, - detail="User is not authorized to create a persona in one shot queries", - ) - - """Create a temporary Persona object from the provided configuration.""" - persona = Persona( - name=persona_config.name, - description=persona_config.description, - num_chunks=persona_config.num_chunks, - llm_relevance_filter=persona_config.llm_relevance_filter, - llm_filter_extraction=persona_config.llm_filter_extraction, - recency_bias=persona_config.recency_bias, - llm_model_provider_override=persona_config.llm_model_provider_override, - llm_model_version_override=persona_config.llm_model_version_override, - ) - - if persona_config.prompts: - persona.prompts = [ - Prompt( - name=p.name, - description=p.description, - system_prompt=p.system_prompt, - task_prompt=p.task_prompt, - include_citations=p.include_citations, - datetime_aware=p.datetime_aware, - ) - for p in persona_config.prompts - ] - elif persona_config.prompt_ids: - persona.prompts = get_prompts_by_ids( - db_session=db_session, prompt_ids=persona_config.prompt_ids - ) - - persona.tools = [] - if persona_config.custom_tools_openapi: - for schema in persona_config.custom_tools_openapi: - tools = cast( - list[Tool], - build_custom_tools_from_openapi_schema_and_headers(schema), - ) - persona.tools.extend(tools) - - if persona_config.tools: - tool_ids = [tool.id for tool in persona_config.tools] - persona.tools.extend( - fetch_existing_tools(db_session=db_session, tool_ids=tool_ids) - ) - - if persona_config.tool_ids: - persona.tools.extend( - fetch_existing_tools( - db_session=db_session, tool_ids=persona_config.tool_ids - ) - ) - - fetched_docs = fetch_existing_doc_sets( - db_session=db_session, doc_ids=persona_config.document_set_ids - ) - persona.document_sets = fetched_docs - - return persona diff --git a/backend/ee/danswer/server/query_history/api.py b/backend/ee/danswer/server/query_history/api.py index df6175cf271..0a15013dd65 100644 --- a/backend/ee/danswer/server/query_history/api.py +++ b/backend/ee/danswer/server/query_history/api.py @@ -179,13 +179,7 @@ def to_json(self) -> dict[str, str | None]: def determine_flow_type(chat_session: ChatSession) -> SessionType: - return ( - SessionType.SLACK - if chat_session.danswerbot_flow - else SessionType.SEARCH - if chat_session.one_shot - else SessionType.CHAT - ) + return SessionType.SLACK if chat_session.danswerbot_flow else SessionType.CHAT def fetch_and_process_chat_session_history_minimal( diff --git a/backend/ee/danswer/server/reporting/usage_export_models.py b/backend/ee/danswer/server/reporting/usage_export_models.py index 21cd104e862..efaee7378fa 100644 --- a/backend/ee/danswer/server/reporting/usage_export_models.py +++ b/backend/ee/danswer/server/reporting/usage_export_models.py @@ -9,7 +9,6 @@ class FlowType(str, Enum): CHAT = "chat" - SEARCH = "search" SLACK = "slack" diff --git a/backend/ee/danswer/server/seeding.py b/backend/ee/danswer/server/seeding.py index 7aa87379221..f1081fe5f37 100644 --- a/backend/ee/danswer/server/seeding.py +++ b/backend/ee/danswer/server/seeding.py @@ -132,13 +132,18 @@ def _seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) -> if personas: logger.notice("Seeding Personas") for persona in personas: + if not persona.prompt_ids: + raise ValueError( + f"Invalid Persona with name {persona.name}; no prompts exist" + ) + upsert_persona( user=None, # Seeding is done as admin name=persona.name, description=persona.description, - num_chunks=persona.num_chunks - if persona.num_chunks is not None - else 0.0, + num_chunks=( + persona.num_chunks if persona.num_chunks is not None else 0.0 + ), llm_relevance_filter=persona.llm_relevance_filter, llm_filter_extraction=persona.llm_filter_extraction, recency_bias=RecencyBiasSetting.AUTO, @@ -157,7 +162,6 @@ def _seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) -> def _seed_settings(settings: Settings) -> None: logger.notice("Seeding Settings") try: - settings.check_validity() store_base_settings(settings) logger.notice("Successfully seeded Settings") except ValueError as e: diff --git a/backend/model_server/encoders.py b/backend/model_server/encoders.py index 003953cb29a..ef04c0a7f05 100644 --- a/backend/model_server/encoders.py +++ b/backend/model_server/encoders.py @@ -1,4 +1,6 @@ +import asyncio import json +from types import TracebackType from typing import cast from typing import Optional @@ -6,11 +8,12 @@ import openai import vertexai # type: ignore import voyageai # type: ignore -from cohere import Client as CohereClient +from cohere import AsyncClient as CohereAsyncClient from fastapi import APIRouter from fastapi import HTTPException from google.oauth2 import service_account # type: ignore -from litellm import embedding +from litellm import aembedding +from litellm.exceptions import RateLimitError from retry import retry from sentence_transformers import CrossEncoder # type: ignore from sentence_transformers import SentenceTransformer # type: ignore @@ -62,22 +65,31 @@ def __init__( provider: EmbeddingProvider, api_url: str | None = None, api_version: str | None = None, + timeout: int = API_BASED_EMBEDDING_TIMEOUT, ) -> None: self.provider = provider self.api_key = api_key self.api_url = api_url self.api_version = api_version + self.timeout = timeout + self.http_client = httpx.AsyncClient(timeout=timeout) + self._closed = False - def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]: + async def _embed_openai( + self, texts: list[str], model: str | None + ) -> list[Embedding]: if not model: model = DEFAULT_OPENAI_MODEL - client = openai.OpenAI(api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT) + # Use the OpenAI specific timeout for this one + client = openai.AsyncOpenAI( + api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT + ) final_embeddings: list[Embedding] = [] try: for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN): - response = client.embeddings.create(input=text_batch, model=model) + response = await client.embeddings.create(input=text_batch, model=model) final_embeddings.extend( [embedding.embedding for embedding in response.data] ) @@ -92,19 +104,19 @@ def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]: logger.error(error_string) raise RuntimeError(error_string) - def _embed_cohere( + async def _embed_cohere( self, texts: list[str], model: str | None, embedding_type: str ) -> list[Embedding]: if not model: model = DEFAULT_COHERE_MODEL - client = CohereClient(api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT) + client = CohereAsyncClient(api_key=self.api_key) final_embeddings: list[Embedding] = [] for text_batch in batch_list(texts, _COHERE_MAX_INPUT_LEN): # Does not use the same tokenizer as the Danswer API server but it's approximately the same # empirically it's only off by a very few tokens so it's not a big deal - response = client.embed( + response = await client.embed( texts=text_batch, model=model, input_type=embedding_type, @@ -113,26 +125,29 @@ def _embed_cohere( final_embeddings.extend(cast(list[Embedding], response.embeddings)) return final_embeddings - def _embed_voyage( + async def _embed_voyage( self, texts: list[str], model: str | None, embedding_type: str ) -> list[Embedding]: if not model: model = DEFAULT_VOYAGE_MODEL - client = voyageai.Client( + client = voyageai.AsyncClient( api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT ) - response = client.embed( - texts, + response = await client.embed( + texts=texts, model=model, input_type=embedding_type, truncation=True, ) + return response.embeddings - def _embed_azure(self, texts: list[str], model: str | None) -> list[Embedding]: - response = embedding( + async def _embed_azure( + self, texts: list[str], model: str | None + ) -> list[Embedding]: + response = await aembedding( model=model, input=texts, timeout=API_BASED_EMBEDDING_TIMEOUT, @@ -141,10 +156,9 @@ def _embed_azure(self, texts: list[str], model: str | None) -> list[Embedding]: api_version=self.api_version, ) embeddings = [embedding["embedding"] for embedding in response.data] - return embeddings - def _embed_vertex( + async def _embed_vertex( self, texts: list[str], model: str | None, embedding_type: str ) -> list[Embedding]: if not model: @@ -157,7 +171,7 @@ def _embed_vertex( vertexai.init(project=project_id, credentials=credentials) client = TextEmbeddingModel.from_pretrained(model) - embeddings = client.get_embeddings( + embeddings = await client.get_embeddings_async( [ TextEmbeddingInput( text, @@ -165,11 +179,11 @@ def _embed_vertex( ) for text in texts ], - auto_truncate=True, # Also this is default + auto_truncate=True, # This is the default ) return [embedding.values for embedding in embeddings] - def _embed_litellm_proxy( + async def _embed_litellm_proxy( self, texts: list[str], model_name: str | None ) -> list[Embedding]: if not model_name: @@ -182,22 +196,20 @@ def _embed_litellm_proxy( {} if not self.api_key else {"Authorization": f"Bearer {self.api_key}"} ) - with httpx.Client() as client: - response = client.post( - self.api_url, - json={ - "model": model_name, - "input": texts, - }, - headers=headers, - timeout=API_BASED_EMBEDDING_TIMEOUT, - ) - response.raise_for_status() - result = response.json() - return [embedding["embedding"] for embedding in result["data"]] + response = await self.http_client.post( + self.api_url, + json={ + "model": model_name, + "input": texts, + }, + headers=headers, + ) + response.raise_for_status() + result = response.json() + return [embedding["embedding"] for embedding in result["data"]] @retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY) - def embed( + async def embed( self, *, texts: list[str], @@ -205,28 +217,22 @@ def embed( model_name: str | None = None, deployment_name: str | None = None, ) -> list[Embedding]: - try: - if self.provider == EmbeddingProvider.OPENAI: - return self._embed_openai(texts, model_name) - elif self.provider == EmbeddingProvider.AZURE: - return self._embed_azure(texts, f"azure/{deployment_name}") - elif self.provider == EmbeddingProvider.LITELLM: - return self._embed_litellm_proxy(texts, model_name) - - embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type) - if self.provider == EmbeddingProvider.COHERE: - return self._embed_cohere(texts, model_name, embedding_type) - elif self.provider == EmbeddingProvider.VOYAGE: - return self._embed_voyage(texts, model_name, embedding_type) - elif self.provider == EmbeddingProvider.GOOGLE: - return self._embed_vertex(texts, model_name, embedding_type) - else: - raise ValueError(f"Unsupported provider: {self.provider}") - except Exception as e: - raise HTTPException( - status_code=500, - detail=f"Error embedding text with {self.provider}: {str(e)}", - ) + if self.provider == EmbeddingProvider.OPENAI: + return await self._embed_openai(texts, model_name) + elif self.provider == EmbeddingProvider.AZURE: + return await self._embed_azure(texts, f"azure/{deployment_name}") + elif self.provider == EmbeddingProvider.LITELLM: + return await self._embed_litellm_proxy(texts, model_name) + + embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type) + if self.provider == EmbeddingProvider.COHERE: + return await self._embed_cohere(texts, model_name, embedding_type) + elif self.provider == EmbeddingProvider.VOYAGE: + return await self._embed_voyage(texts, model_name, embedding_type) + elif self.provider == EmbeddingProvider.GOOGLE: + return await self._embed_vertex(texts, model_name, embedding_type) + else: + raise ValueError(f"Unsupported provider: {self.provider}") @staticmethod def create( @@ -238,6 +244,30 @@ def create( logger.debug(f"Creating Embedding instance for provider: {provider}") return CloudEmbedding(api_key, provider, api_url, api_version) + async def aclose(self) -> None: + """Explicitly close the client.""" + if not self._closed: + await self.http_client.aclose() + self._closed = True + + async def __aenter__(self) -> "CloudEmbedding": + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self.aclose() + + def __del__(self) -> None: + """Finalizer to warn about unclosed clients.""" + if not self._closed: + logger.warning( + "CloudEmbedding was not properly closed. Use 'async with' or call aclose()" + ) + def get_embedding_model( model_name: str, @@ -247,9 +277,6 @@ def get_embedding_model( global _GLOBAL_MODELS_DICT # A dictionary to store models - if _GLOBAL_MODELS_DICT is None: - _GLOBAL_MODELS_DICT = {} - if model_name not in _GLOBAL_MODELS_DICT: logger.notice(f"Loading {model_name}") # Some model architectures that aren't built into the Transformers or Sentence @@ -280,7 +307,7 @@ def get_local_reranking_model( @simple_log_function_time() -def embed_text( +async def embed_text( texts: list[str], text_type: EmbedTextType, model_name: str | None, @@ -316,18 +343,18 @@ def embed_text( "Cloud models take an explicit text type instead." ) - cloud_model = CloudEmbedding( + async with CloudEmbedding( api_key=api_key, provider=provider_type, api_url=api_url, api_version=api_version, - ) - embeddings = cloud_model.embed( - texts=texts, - model_name=model_name, - deployment_name=deployment_name, - text_type=text_type, - ) + ) as cloud_model: + embeddings = await cloud_model.embed( + texts=texts, + model_name=model_name, + deployment_name=deployment_name, + text_type=text_type, + ) if any(embedding is None for embedding in embeddings): error_message = "Embeddings contain None values\n" @@ -343,8 +370,12 @@ def embed_text( local_model = get_embedding_model( model_name=model_name, max_context_length=max_context_length ) - embeddings_vectors = local_model.encode( - prefixed_texts, normalize_embeddings=normalize_embeddings + # Run CPU-bound embedding in a thread pool + embeddings_vectors = await asyncio.get_event_loop().run_in_executor( + None, + lambda: local_model.encode( + prefixed_texts, normalize_embeddings=normalize_embeddings + ), ) embeddings = [ embedding if isinstance(embedding, list) else embedding.tolist() @@ -362,27 +393,31 @@ def embed_text( @simple_log_function_time() -def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]: +async def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]: cross_encoder = get_local_reranking_model(model_name) - return cross_encoder.predict([(query, doc) for doc in docs]).tolist() # type: ignore + # Run CPU-bound reranking in a thread pool + return await asyncio.get_event_loop().run_in_executor( + None, + lambda: cross_encoder.predict([(query, doc) for doc in docs]).tolist(), # type: ignore + ) -def cohere_rerank( +async def cohere_rerank( query: str, docs: list[str], model_name: str, api_key: str ) -> list[float]: - cohere_client = CohereClient(api_key=api_key) - response = cohere_client.rerank(query=query, documents=docs, model=model_name) + cohere_client = CohereAsyncClient(api_key=api_key) + response = await cohere_client.rerank(query=query, documents=docs, model=model_name) results = response.results sorted_results = sorted(results, key=lambda item: item.index) return [result.relevance_score for result in sorted_results] -def litellm_rerank( +async def litellm_rerank( query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None ) -> list[float]: headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"} - with httpx.Client() as client: - response = client.post( + async with httpx.AsyncClient() as client: + response = await client.post( api_url, json={ "model": model_name, @@ -416,7 +451,7 @@ async def process_embed_request( else: prefix = None - embeddings = embed_text( + embeddings = await embed_text( texts=embed_request.texts, model_name=embed_request.model_name, deployment_name=embed_request.deployment_name, @@ -430,6 +465,11 @@ async def process_embed_request( prefix=prefix, ) return EmbedResponse(embeddings=embeddings) + except RateLimitError as e: + raise HTTPException( + status_code=429, + detail=str(e), + ) except Exception as e: exception_detail = f"Error during embedding process:\n{str(e)}" logger.exception(exception_detail) @@ -451,7 +491,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons try: if rerank_request.provider_type is None: - sim_scores = local_rerank( + sim_scores = await local_rerank( query=rerank_request.query, docs=rerank_request.documents, model_name=rerank_request.model_name, @@ -461,7 +501,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons if rerank_request.api_url is None: raise ValueError("API URL is required for LiteLLM reranking.") - sim_scores = litellm_rerank( + sim_scores = await litellm_rerank( query=rerank_request.query, docs=rerank_request.documents, api_url=rerank_request.api_url, @@ -474,7 +514,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons elif rerank_request.provider_type == RerankerProvider.COHERE: if rerank_request.api_key is None: raise RuntimeError("Cohere Rerank Requires an API Key") - sim_scores = cohere_rerank( + sim_scores = await cohere_rerank( query=rerank_request.query, docs=rerank_request.documents, model_name=rerank_request.model_name, diff --git a/backend/model_server/management_endpoints.py b/backend/model_server/management_endpoints.py index 56640a2fa73..4c6387e0708 100644 --- a/backend/model_server/management_endpoints.py +++ b/backend/model_server/management_endpoints.py @@ -6,12 +6,12 @@ @router.get("/health") -def healthcheck() -> Response: +async def healthcheck() -> Response: return Response(status_code=200) @router.get("/gpu-status") -def gpu_status() -> dict[str, bool | str]: +async def gpu_status() -> dict[str, bool | str]: if torch.cuda.is_available(): return {"gpu_available": True, "type": "cuda"} elif torch.backends.mps.is_available(): diff --git a/backend/model_server/utils.py b/backend/model_server/utils.py index 0c2d6bac5dc..86192b031f6 100644 --- a/backend/model_server/utils.py +++ b/backend/model_server/utils.py @@ -1,3 +1,4 @@ +import asyncio import time from collections.abc import Callable from collections.abc import Generator @@ -21,21 +22,39 @@ def simple_log_function_time( include_args: bool = False, ) -> Callable[[F], F]: def decorator(func: F) -> F: - @wraps(func) - def wrapped_func(*args: Any, **kwargs: Any) -> Any: - start_time = time.time() - result = func(*args, **kwargs) - elapsed_time_str = str(time.time() - start_time) - log_name = func_name or func.__name__ - args_str = f" args={args} kwargs={kwargs}" if include_args else "" - final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds" - if debug_only: - logger.debug(final_log) - else: - logger.notice(final_log) - - return result - - return cast(F, wrapped_func) + if asyncio.iscoroutinefunction(func): + + @wraps(func) + async def wrapped_async_func(*args: Any, **kwargs: Any) -> Any: + start_time = time.time() + result = await func(*args, **kwargs) + elapsed_time_str = str(time.time() - start_time) + log_name = func_name or func.__name__ + args_str = f" args={args} kwargs={kwargs}" if include_args else "" + final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds" + if debug_only: + logger.debug(final_log) + else: + logger.notice(final_log) + return result + + return cast(F, wrapped_async_func) + else: + + @wraps(func) + def wrapped_sync_func(*args: Any, **kwargs: Any) -> Any: + start_time = time.time() + result = func(*args, **kwargs) + elapsed_time_str = str(time.time() - start_time) + log_name = func_name or func.__name__ + args_str = f" args={args} kwargs={kwargs}" if include_args else "" + final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds" + if debug_only: + logger.debug(final_log) + else: + logger.notice(final_log) + return result + + return cast(F, wrapped_sync_func) return decorator diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt index 1cfa8818813..9d9fd782104 100644 --- a/backend/requirements/default.txt +++ b/backend/requirements/default.txt @@ -29,7 +29,7 @@ trafilatura==1.12.2 langchain==0.1.17 langchain-core==0.1.50 langchain-text-splitters==0.0.1 -litellm==1.50.2 +litellm==1.54.1 lxml==5.3.0 lxml_html_clean==0.2.2 llama-index==0.9.45 @@ -38,7 +38,7 @@ msal==1.28.0 nltk==3.8.1 Office365-REST-Python-Client==2.5.9 oauthlib==3.2.2 -openai==1.52.2 +openai==1.55.3 openpyxl==3.1.2 playwright==1.41.2 psutil==5.9.5 diff --git a/backend/requirements/dev.txt b/backend/requirements/dev.txt index 27304dbef37..a89b8db674d 100644 --- a/backend/requirements/dev.txt +++ b/backend/requirements/dev.txt @@ -1,30 +1,34 @@ black==23.3.0 +boto3-stubs[s3]==1.34.133 celery-types==0.19.0 +cohere==5.6.1 +google-cloud-aiplatform==1.58.0 +lxml==5.3.0 +lxml_html_clean==0.2.2 mypy-extensions==1.0.0 mypy==1.8.0 +pandas-stubs==2.2.3.241009 +pandas==2.2.3 pre-commit==3.2.2 +pytest-asyncio==0.22.0 pytest==7.4.4 reorder-python-imports==3.9.0 ruff==0.0.286 -types-PyYAML==6.0.12.11 +sentence-transformers==2.6.1 +trafilatura==1.12.2 types-beautifulsoup4==4.12.0.3 types-html5lib==1.1.11.13 types-oauthlib==3.2.0.9 -types-setuptools==68.0.0.3 -types-Pillow==10.2.0.20240822 types-passlib==1.7.7.20240106 +types-Pillow==10.2.0.20240822 types-psutil==5.9.5.17 types-psycopg2==2.9.21.10 types-python-dateutil==2.8.19.13 types-pytz==2023.3.1.1 +types-PyYAML==6.0.12.11 types-regex==2023.3.23.1 types-requests==2.28.11.17 types-retry==0.9.9.3 +types-setuptools==68.0.0.3 types-urllib3==1.26.25.11 -trafilatura==1.12.2 -lxml==5.3.0 -lxml_html_clean==0.2.2 -boto3-stubs[s3]==1.34.133 -pandas==2.2.3 -pandas-stubs==2.2.3.241009 -cohere==5.6.1 \ No newline at end of file +voyageai==0.2.3 diff --git a/backend/requirements/model_server.txt b/backend/requirements/model_server.txt index 4803dc64eb6..531382cb4b1 100644 --- a/backend/requirements/model_server.txt +++ b/backend/requirements/model_server.txt @@ -12,5 +12,5 @@ torch==2.2.0 transformers==4.39.2 uvicorn==0.21.1 voyageai==0.2.3 -litellm==1.50.2 +litellm==1.54.1 sentry-sdk[fastapi,celery,starlette]==2.14.0 \ No newline at end of file diff --git a/backend/scripts/chat_loadtest.py b/backend/scripts/chat_loadtest.py new file mode 100644 index 00000000000..e34cd23e6e0 --- /dev/null +++ b/backend/scripts/chat_loadtest.py @@ -0,0 +1,226 @@ +"""Basic Usage: + +python scripts/chat_loadtest.py --api-key --url /api + +to run from the container itself, copy this file in and run: + +python chat_loadtest.py --api-key --url localhost:8080 + +For more options, checkout the bottom of the file. +""" +import argparse +import asyncio +import logging +import statistics +import time +from collections.abc import AsyncGenerator +from dataclasses import dataclass +from logging import getLogger +from uuid import UUID + +import aiohttp + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler()], +) + +logger = getLogger(__name__) + + +@dataclass +class ChatMetrics: + session_id: UUID + total_time: float + first_doc_time: float + first_answer_time: float + tokens_per_second: float + total_tokens: int + + +class ChatLoadTester: + def __init__( + self, + base_url: str, + api_key: str | None, + num_concurrent: int, + messages_per_session: int, + ): + self.base_url = base_url + self.headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} + self.num_concurrent = num_concurrent + self.messages_per_session = messages_per_session + self.metrics: list[ChatMetrics] = [] + + async def create_chat_session(self, session: aiohttp.ClientSession) -> str: + """Create a new chat session""" + async with session.post( + f"{self.base_url}/chat/create-chat-session", + headers=self.headers, + json={"persona_id": 0, "description": "Load Test"}, + ) as response: + response.raise_for_status() + data = await response.json() + return data["chat_session_id"] + + async def process_stream( + self, response: aiohttp.ClientResponse + ) -> AsyncGenerator[str, None]: + """Process the SSE stream from the chat response""" + async for chunk in response.content: + chunk_str = chunk.decode() + yield chunk_str + + async def send_message( + self, + session: aiohttp.ClientSession, + chat_session_id: str, + message: str, + parent_message_id: int | None = None, + ) -> ChatMetrics: + """Send a message and measure performance metrics""" + start_time = time.time() + first_doc_time = None + first_answer_time = None + token_count = 0 + + async with session.post( + f"{self.base_url}/chat/send-message", + headers=self.headers, + json={ + "chat_session_id": chat_session_id, + "message": message, + "parent_message_id": parent_message_id, + "prompt_id": None, + "retrieval_options": { + "run_search": "always", + "real_time": True, + }, + "file_descriptors": [], + "search_doc_ids": [], + }, + ) as response: + response.raise_for_status() + + async for chunk in self.process_stream(response): + if "tool_name" in chunk and "run_search" in chunk: + if first_doc_time is None: + first_doc_time = time.time() - start_time + + if "answer_piece" in chunk: + if first_answer_time is None: + first_answer_time = time.time() - start_time + token_count += 1 + + total_time = time.time() - start_time + tokens_per_second = token_count / total_time if total_time > 0 else 0 + + return ChatMetrics( + session_id=UUID(chat_session_id), + total_time=total_time, + first_doc_time=first_doc_time or 0, + first_answer_time=first_answer_time or 0, + tokens_per_second=tokens_per_second, + total_tokens=token_count, + ) + + async def run_chat_session(self) -> None: + """Run a complete chat session with multiple messages""" + async with aiohttp.ClientSession() as session: + try: + chat_session_id = await self.create_chat_session(session) + messages = [ + "Tell me about the key features of the product", + "How does the search functionality work?", + "What are the deployment options?", + "Can you explain the security features?", + "What integrations are available?", + ] + + parent_message_id = None + for i in range(self.messages_per_session): + message = messages[i % len(messages)] + metrics = await self.send_message( + session, chat_session_id, message, parent_message_id + ) + self.metrics.append(metrics) + parent_message_id = metrics.total_tokens # Simplified for example + + except Exception as e: + logger.error(f"Error in chat session: {e}") + + async def run_load_test(self) -> None: + """Run multiple concurrent chat sessions""" + start_time = time.time() + tasks = [self.run_chat_session() for _ in range(self.num_concurrent)] + await asyncio.gather(*tasks) + total_time = time.time() - start_time + + self.print_results(total_time) + + def print_results(self, total_time: float) -> None: + """Print load test results and metrics""" + logger.info("\n=== Load Test Results ===") + logger.info(f"Total Time: {total_time:.2f} seconds") + logger.info(f"Concurrent Sessions: {self.num_concurrent}") + logger.info(f"Messages per Session: {self.messages_per_session}") + logger.info(f"Total Messages: {len(self.metrics)}") + + if self.metrics: + avg_response_time = statistics.mean(m.total_time for m in self.metrics) + avg_first_doc = statistics.mean(m.first_doc_time for m in self.metrics) + avg_first_answer = statistics.mean( + m.first_answer_time for m in self.metrics + ) + avg_tokens_per_sec = statistics.mean( + m.tokens_per_second for m in self.metrics + ) + + logger.info(f"\nAverage Response Time: {avg_response_time:.2f} seconds") + logger.info(f"Average Time to Documents: {avg_first_doc:.2f} seconds") + logger.info(f"Average Time to First Answer: {avg_first_answer:.2f} seconds") + logger.info(f"Average Tokens/Second: {avg_tokens_per_sec:.2f}") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Chat Load Testing Tool") + parser.add_argument( + "--url", + type=str, + default="http://localhost:3000/api", + help="Danswer URL", + ) + parser.add_argument( + "--api-key", + type=str, + help="Danswer Basic/Admin Level API key", + ) + parser.add_argument( + "--concurrent", + type=int, + default=10, + help="Number of concurrent chat sessions", + ) + parser.add_argument( + "--messages", + type=int, + default=1, + help="Number of messages per chat session", + ) + + args = parser.parse_args() + + load_tester = ChatLoadTester( + base_url=args.url, + api_key=args.api_key, + num_concurrent=args.concurrent, + messages_per_session=args.messages, + ) + + asyncio.run(load_tester.run_load_test()) + + +if __name__ == "__main__": + main() diff --git a/backend/scripts/orphan_doc_cleanup_script.py b/backend/scripts/orphan_doc_cleanup_script.py new file mode 100644 index 00000000000..4007123ca3c --- /dev/null +++ b/backend/scripts/orphan_doc_cleanup_script.py @@ -0,0 +1,79 @@ +import os +import sys + +from sqlalchemy import text +from sqlalchemy.orm import Session + +# makes it so `PYTHONPATH=.` is not required when running this script +parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.append(parent_dir) + +from danswer.db.engine import get_session_context_manager # noqa: E402 +from danswer.db.document import delete_documents_complete__no_commit # noqa: E402 +from danswer.db.search_settings import get_current_search_settings # noqa: E402 +from danswer.document_index.vespa.index import VespaIndex # noqa: E402 +from danswer.background.celery.tasks.shared.RetryDocumentIndex import ( # noqa: E402 + RetryDocumentIndex, +) + + +def _get_orphaned_document_ids(db_session: Session) -> list[str]: + """Get document IDs that don't have any entries in document_by_connector_credential_pair""" + query = text( + """ + SELECT d.id + FROM document d + LEFT JOIN document_by_connector_credential_pair dbcc ON d.id = dbcc.id + WHERE dbcc.id IS NULL + """ + ) + orphaned_ids = [doc_id[0] for doc_id in db_session.execute(query)] + print(f"Found {len(orphaned_ids)} orphaned documents") + return orphaned_ids + + +def main() -> None: + with get_session_context_manager() as db_session: + # Get orphaned document IDs + orphaned_ids = _get_orphaned_document_ids(db_session) + if not orphaned_ids: + print("No orphaned documents found") + return + + # Setup Vespa index + search_settings = get_current_search_settings(db_session) + index_name = search_settings.index_name + vespa_index = VespaIndex(index_name=index_name, secondary_index_name=None) + retry_index = RetryDocumentIndex(vespa_index) + + # Delete chunks from Vespa first + print("Deleting orphaned document chunks from Vespa") + successfully_vespa_deleted_doc_ids = [] + for doc_id in orphaned_ids: + try: + chunks_deleted = retry_index.delete_single(doc_id) + successfully_vespa_deleted_doc_ids.append(doc_id) + if chunks_deleted > 0: + print(f"Deleted {chunks_deleted} chunks for document {doc_id}") + except Exception as e: + print( + f"Error deleting document {doc_id} in Vespa and will not delete from Postgres: {e}" + ) + + # Delete documents from Postgres + print("Deleting orphaned documents from Postgres") + try: + delete_documents_complete__no_commit( + db_session, successfully_vespa_deleted_doc_ids + ) + db_session.commit() + except Exception as e: + print(f"Error deleting documents from Postgres: {e}") + + print( + f"Successfully cleaned up {len(successfully_vespa_deleted_doc_ids)} orphaned documents" + ) + + +if __name__ == "__main__": + main() diff --git a/backend/shared_configs/configs.py b/backend/shared_configs/configs.py index 2f558629def..0bc34c65d2a 100644 --- a/backend/shared_configs/configs.py +++ b/backend/shared_configs/configs.py @@ -61,7 +61,7 @@ # Enable generating persistent log files for local dev environments DEV_LOGGING_ENABLED = os.environ.get("DEV_LOGGING_ENABLED", "").lower() == "true" # notset, debug, info, notice, warning, error, or critical -LOG_LEVEL = os.environ.get("LOG_LEVEL", "notice") +LOG_LEVEL = os.environ.get("LOG_LEVEL", "info") # Timeout for API-based embedding models # NOTE: does not apply for Google VertexAI, since the python client doesn't @@ -163,47 +163,92 @@ async def async_return_default_schema(*args: Any, **kwargs: Any) -> str: dim=1024, index_name="danswer_chunk_cohere_embed_english_v3_0", ), + SupportedEmbeddingModel( + name="cohere/embed-english-v3.0", + dim=1024, + index_name="danswer_chunk_embed_english_v3_0", + ), SupportedEmbeddingModel( name="cohere/embed-english-light-v3.0", dim=384, index_name="danswer_chunk_cohere_embed_english_light_v3_0", ), + SupportedEmbeddingModel( + name="cohere/embed-english-light-v3.0", + dim=384, + index_name="danswer_chunk_embed_english_light_v3_0", + ), SupportedEmbeddingModel( name="openai/text-embedding-3-large", dim=3072, index_name="danswer_chunk_openai_text_embedding_3_large", ), + SupportedEmbeddingModel( + name="openai/text-embedding-3-large", + dim=3072, + index_name="danswer_chunk_text_embedding_3_large", + ), SupportedEmbeddingModel( name="openai/text-embedding-3-small", dim=1536, index_name="danswer_chunk_openai_text_embedding_3_small", ), + SupportedEmbeddingModel( + name="openai/text-embedding-3-small", + dim=1536, + index_name="danswer_chunk_text_embedding_3_small", + ), SupportedEmbeddingModel( name="google/text-embedding-004", dim=768, index_name="danswer_chunk_google_text_embedding_004", ), + SupportedEmbeddingModel( + name="google/text-embedding-004", + dim=768, + index_name="danswer_chunk_text_embedding_004", + ), SupportedEmbeddingModel( name="google/textembedding-gecko@003", dim=768, index_name="danswer_chunk_google_textembedding_gecko_003", ), + SupportedEmbeddingModel( + name="google/textembedding-gecko@003", + dim=768, + index_name="danswer_chunk_textembedding_gecko_003", + ), SupportedEmbeddingModel( name="voyage/voyage-large-2-instruct", dim=1024, index_name="danswer_chunk_voyage_large_2_instruct", ), + SupportedEmbeddingModel( + name="voyage/voyage-large-2-instruct", + dim=1024, + index_name="danswer_chunk_large_2_instruct", + ), SupportedEmbeddingModel( name="voyage/voyage-light-2-instruct", dim=384, index_name="danswer_chunk_voyage_light_2_instruct", ), + SupportedEmbeddingModel( + name="voyage/voyage-light-2-instruct", + dim=384, + index_name="danswer_chunk_light_2_instruct", + ), # Self-hosted models SupportedEmbeddingModel( name="nomic-ai/nomic-embed-text-v1", dim=768, index_name="danswer_chunk_nomic_ai_nomic_embed_text_v1", ), + SupportedEmbeddingModel( + name="nomic-ai/nomic-embed-text-v1", + dim=768, + index_name="danswer_chunk_nomic_embed_text_v1", + ), SupportedEmbeddingModel( name="intfloat/e5-base-v2", dim=768, diff --git a/backend/tests/daily/connectors/confluence/test_confluence_permissions_basic.py b/backend/tests/daily/connectors/confluence/test_confluence_permissions_basic.py new file mode 100644 index 00000000000..35d2da61cf1 --- /dev/null +++ b/backend/tests/daily/connectors/confluence/test_confluence_permissions_basic.py @@ -0,0 +1,39 @@ +import os + +import pytest + +from danswer.connectors.confluence.connector import ConfluenceConnector + + +@pytest.fixture +def confluence_connector() -> ConfluenceConnector: + connector = ConfluenceConnector( + wiki_base="https://danswerai.atlassian.net", + is_cloud=True, + ) + connector.load_credentials( + { + "confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"], + "confluence_username": os.environ["CONFLUENCE_USER_NAME"], + } + ) + return connector + + +# This should never fail because even if the docs in the cloud change, +# the full doc ids retrieved should always be a subset of the slim doc ids +def test_confluence_connector_permissions( + confluence_connector: ConfluenceConnector, +) -> None: + # Get all doc IDs from the full connector + all_full_doc_ids = set() + for doc_batch in confluence_connector.load_from_state(): + all_full_doc_ids.update([doc.id for doc in doc_batch]) + + # Get all doc IDs from the slim connector + all_slim_doc_ids = set() + for slim_doc_batch in confluence_connector.retrieve_all_slim_documents(): + all_slim_doc_ids.update([doc.id for doc in slim_doc_batch]) + + # The set of full doc IDs should be always be a subset of the slim doc IDs + assert all_full_doc_ids.issubset(all_slim_doc_ids) diff --git a/backend/tests/daily/connectors/slab/test_slab_connector.py b/backend/tests/daily/connectors/slab/test_slab_connector.py new file mode 100644 index 00000000000..d3902cd0a0c --- /dev/null +++ b/backend/tests/daily/connectors/slab/test_slab_connector.py @@ -0,0 +1,88 @@ +import json +import os +import time +from pathlib import Path + +import pytest + +from danswer.configs.constants import DocumentSource +from danswer.connectors.models import Document +from danswer.connectors.slab.connector import SlabConnector + + +def load_test_data(file_name: str = "test_slab_data.json") -> dict[str, str]: + current_dir = Path(__file__).parent + with open(current_dir / file_name, "r") as f: + return json.load(f) + + +@pytest.fixture +def slab_connector() -> SlabConnector: + connector = SlabConnector( + base_url="https://onyx-test.slab.com/", + ) + connector.load_credentials( + { + "slab_bot_token": os.environ["SLAB_BOT_TOKEN"], + } + ) + return connector + + +@pytest.mark.xfail( + reason=( + "Need a test account with a slab subscription to run this test." + "Trial only lasts 14 days." + ) +) +def test_slab_connector_basic(slab_connector: SlabConnector) -> None: + all_docs: list[Document] = [] + target_test_doc_id = "jcp6cohu" + target_test_doc: Document | None = None + for doc_batch in slab_connector.poll_source(0, time.time()): + for doc in doc_batch: + all_docs.append(doc) + if doc.id == target_test_doc_id: + target_test_doc = doc + + assert len(all_docs) == 6 + assert target_test_doc is not None + + desired_test_data = load_test_data() + assert ( + target_test_doc.semantic_identifier == desired_test_data["semantic_identifier"] + ) + assert target_test_doc.source == DocumentSource.SLAB + assert target_test_doc.metadata == {} + assert target_test_doc.primary_owners is None + assert target_test_doc.secondary_owners is None + assert target_test_doc.title is None + assert target_test_doc.from_ingestion_api is False + assert target_test_doc.additional_info is None + + assert len(target_test_doc.sections) == 1 + section = target_test_doc.sections[0] + # Need to replace the weird apostrophe with a normal one + assert section.text.replace("\u2019", "'") == desired_test_data["section_text"] + assert section.link == desired_test_data["link"] + + +@pytest.mark.xfail( + reason=( + "Need a test account with a slab subscription to run this test." + "Trial only lasts 14 days." + ) +) +def test_slab_connector_slim(slab_connector: SlabConnector) -> None: + # Get all doc IDs from the full connector + all_full_doc_ids = set() + for doc_batch in slab_connector.load_from_state(): + all_full_doc_ids.update([doc.id for doc in doc_batch]) + + # Get all doc IDs from the slim connector + all_slim_doc_ids = set() + for slim_doc_batch in slab_connector.retrieve_all_slim_documents(): + all_slim_doc_ids.update([doc.id for doc in slim_doc_batch]) + + # The set of full doc IDs should be always be a subset of the slim doc IDs + assert all_full_doc_ids.issubset(all_slim_doc_ids) diff --git a/backend/tests/daily/connectors/slab/test_slab_data.json b/backend/tests/daily/connectors/slab/test_slab_data.json new file mode 100644 index 00000000000..26c7cf91037 --- /dev/null +++ b/backend/tests/daily/connectors/slab/test_slab_data.json @@ -0,0 +1,5 @@ +{ + "section_text": "Learn about Posts\nWelcome\nThis is a post, where you can edit, share, and collaborate in real time with your team. We'd love to show you how it works!\nReading and editing\nClick the mode button to toggle between read and edit modes. You can only make changes to a post when editing.\nOrganize your posts\nWhen in edit mode, you can add topics to a post, which will keep it organized for the right 👀 to see.\nSmart mentions\nMentions are references to users, posts, topics and third party tools that show details on hover. Paste in a link for automatic conversion.\nLook back in time\nYou are ready to begin writing. You can always bring back this tour in the help menu.\nGreat job!\nYou are ready to begin writing. You can always bring back this tour in the help menu.\n\n", + "link": "https://onyx-test.slab.com/posts/learn-about-posts-jcp6cohu", + "semantic_identifier": "Learn about Posts" +} \ No newline at end of file diff --git a/backend/tests/daily/embedding/test_embeddings.py b/backend/tests/daily/embedding/test_embeddings.py index 10a1dd850f6..7182510214f 100644 --- a/backend/tests/daily/embedding/test_embeddings.py +++ b/backend/tests/daily/embedding/test_embeddings.py @@ -7,6 +7,7 @@ from shared_configs.model_server_models import EmbeddingProvider VALID_SAMPLE = ["hi", "hello my name is bob", "woah there!!!. 😃"] +VALID_LONG_SAMPLE = ["hi " * 999] # openai limit is 2048, cohere is supposed to be 96 but in practice that doesn't # seem to be true TOO_LONG_SAMPLE = ["a"] * 2500 @@ -99,3 +100,42 @@ def local_nomic_embedding_model() -> EmbeddingModel: def test_local_nomic_embedding(local_nomic_embedding_model: EmbeddingModel) -> None: _run_embeddings(VALID_SAMPLE, local_nomic_embedding_model, 768) _run_embeddings(TOO_LONG_SAMPLE, local_nomic_embedding_model, 768) + + +@pytest.fixture +def azure_embedding_model() -> EmbeddingModel: + return EmbeddingModel( + server_host="localhost", + server_port=9000, + model_name="text-embedding-3-large", + normalize=True, + query_prefix=None, + passage_prefix=None, + api_key=os.getenv("AZURE_API_KEY"), + provider_type=EmbeddingProvider.AZURE, + api_url=os.getenv("AZURE_API_URL"), + ) + + +# NOTE (chris): this test doesn't work, and I do not know why +# def test_azure_embedding_model_rate_limit(azure_embedding_model: EmbeddingModel): +# """NOTE: this test relies on a very low rate limit for the Azure API + +# this test only being run once in a 1 minute window""" +# # VALID_LONG_SAMPLE is 999 tokens, so the second call should run into rate +# # limits assuming the limit is 1000 tokens per minute +# result = azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY) +# assert len(result) == 1 +# assert len(result[0]) == 1536 + +# # this should fail +# with pytest.raises(ModelServerRateLimitError): +# azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY) +# azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY) +# azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY) + +# # this should succeed, since passage requests retry up to 10 times +# start = time.time() +# result = azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.PASSAGE) +# assert len(result) == 1 +# assert len(result[0]) == 1536 +# assert time.time() - start > 30 # make sure we waited, even though we hit rate limits diff --git a/backend/tests/integration/common_utils/managers/cc_pair.py b/backend/tests/integration/common_utils/managers/cc_pair.py index b37822d3496..d32e100563b 100644 --- a/backend/tests/integration/common_utils/managers/cc_pair.py +++ b/backend/tests/integration/common_utils/managers/cc_pair.py @@ -240,7 +240,85 @@ def run_once( result.raise_for_status() @staticmethod - def wait_for_indexing( + def wait_for_indexing_inactive( + cc_pair: DATestCCPair, + timeout: float = MAX_DELAY, + user_performing_action: DATestUser | None = None, + ) -> None: + """wait for the number of docs to be indexed on the connector. + This is used to test pausing a connector in the middle of indexing and + terminating that indexing.""" + print(f"Indexing wait for inactive starting: cc_pair={cc_pair.id}") + start = time.monotonic() + while True: + fetched_cc_pairs = CCPairManager.get_indexing_statuses( + user_performing_action + ) + for fetched_cc_pair in fetched_cc_pairs: + if fetched_cc_pair.cc_pair_id != cc_pair.id: + continue + + if fetched_cc_pair.in_progress: + continue + + print(f"Indexing is inactive: cc_pair={cc_pair.id}") + return + + elapsed = time.monotonic() - start + if elapsed > timeout: + raise TimeoutError( + f"Indexing wait for inactive timed out: cc_pair={cc_pair.id} timeout={timeout}s" + ) + + print( + f"Indexing wait for inactive still waiting: cc_pair={cc_pair.id} elapsed={elapsed:.2f} timeout={timeout}s" + ) + time.sleep(5) + + @staticmethod + def wait_for_indexing_in_progress( + cc_pair: DATestCCPair, + timeout: float = MAX_DELAY, + num_docs: int = 16, + user_performing_action: DATestUser | None = None, + ) -> None: + """wait for the number of docs to be indexed on the connector. + This is used to test pausing a connector in the middle of indexing and + terminating that indexing.""" + start = time.monotonic() + while True: + fetched_cc_pairs = CCPairManager.get_indexing_statuses( + user_performing_action + ) + for fetched_cc_pair in fetched_cc_pairs: + if fetched_cc_pair.cc_pair_id != cc_pair.id: + continue + + if not fetched_cc_pair.in_progress: + continue + + if fetched_cc_pair.docs_indexed >= num_docs: + print( + "Indexed at least the requested number of docs: " + f"cc_pair={cc_pair.id} " + f"docs_indexed={fetched_cc_pair.docs_indexed} " + f"num_docs={num_docs}" + ) + return + + elapsed = time.monotonic() - start + if elapsed > timeout: + raise TimeoutError( + f"Indexing in progress wait timed out: cc_pair={cc_pair.id} timeout={timeout}s" + ) + + print( + f"Indexing in progress waiting: cc_pair={cc_pair.id} elapsed={elapsed:.2f} timeout={timeout}s" + ) + time.sleep(5) + + @staticmethod + def wait_for_indexing_completion( cc_pair: DATestCCPair, after: datetime, timeout: float = MAX_DELAY, diff --git a/backend/tests/integration/common_utils/managers/chat.py b/backend/tests/integration/common_utils/managers/chat.py index 106aa26a791..d8a35b2b31f 100644 --- a/backend/tests/integration/common_utils/managers/chat.py +++ b/backend/tests/integration/common_utils/managers/chat.py @@ -8,8 +8,6 @@ from danswer.file_store.models import FileDescriptor from danswer.llm.override_models import LLMOverride from danswer.llm.override_models import PromptOverride -from danswer.one_shot_answer.models import DirectQARequest -from danswer.one_shot_answer.models import ThreadMessage from danswer.server.query_and_chat.models import ChatSessionCreationRequest from danswer.server.query_and_chat.models import CreateChatMessageRequest from tests.integration.common_utils.constants import API_SERVER_URL @@ -68,6 +66,7 @@ def send_message( prompt_id=prompt_id, search_doc_ids=search_doc_ids or [], retrieval_options=retrieval_options, + rerank_settings=None, # Can be added if needed query_override=query_override, regenerate=regenerate, llm_override=llm_override, @@ -87,30 +86,6 @@ def send_message( return ChatSessionManager.analyze_response(response) - @staticmethod - def get_answer_with_quote( - persona_id: int, - message: str, - user_performing_action: DATestUser | None = None, - ) -> StreamedResponse: - direct_qa_request = DirectQARequest( - messages=[ThreadMessage(message=message)], - prompt_id=None, - persona_id=persona_id, - ) - - response = requests.post( - f"{API_SERVER_URL}/query/stream-answer-with-quote", - json=direct_qa_request.model_dump(), - headers=user_performing_action.headers - if user_performing_action - else GENERAL_HEADERS, - stream=True, - ) - response.raise_for_status() - - return ChatSessionManager.analyze_response(response) - @staticmethod def analyze_response(response: Response) -> StreamedResponse: response_data = [ diff --git a/backend/tests/integration/common_utils/managers/file.py b/backend/tests/integration/common_utils/managers/file.py new file mode 100644 index 00000000000..461874f7ec5 --- /dev/null +++ b/backend/tests/integration/common_utils/managers/file.py @@ -0,0 +1,62 @@ +import mimetypes +from typing import cast +from typing import IO +from typing import List +from typing import Tuple + +import requests + +from danswer.file_store.models import FileDescriptor +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.test_models import DATestUser + + +class FileManager: + @staticmethod + def upload_files( + files: List[Tuple[str, IO]], + user_performing_action: DATestUser | None = None, + ) -> Tuple[List[FileDescriptor], str]: + headers = ( + user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS + ) + headers.pop("Content-Type", None) + + files_param = [] + for filename, file_obj in files: + mime_type, _ = mimetypes.guess_type(filename) + if mime_type is None: + mime_type = "application/octet-stream" + files_param.append(("files", (filename, file_obj, mime_type))) + + response = requests.post( + f"{API_SERVER_URL}/chat/file", + files=files_param, + headers=headers, + ) + + if not response.ok: + return ( + cast(List[FileDescriptor], []), + f"Failed to upload files - {response.json().get('detail', 'Unknown error')}", + ) + + response_json = response.json() + return response_json.get("files", cast(List[FileDescriptor], [])), "" + + @staticmethod + def fetch_uploaded_file( + file_id: str, + user_performing_action: DATestUser | None = None, + ) -> bytes: + response = requests.get( + f"{API_SERVER_URL}/chat/file/{file_id}", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return response.content diff --git a/backend/tests/integration/common_utils/managers/persona.py b/backend/tests/integration/common_utils/managers/persona.py index de2d9db25c1..e5392dfb68b 100644 --- a/backend/tests/integration/common_utils/managers/persona.py +++ b/backend/tests/integration/common_utils/managers/persona.py @@ -42,7 +42,7 @@ def create( "is_public": is_public, "llm_filter_extraction": llm_filter_extraction, "recency_bias": recency_bias, - "prompt_ids": prompt_ids or [], + "prompt_ids": prompt_ids or [0], "document_set_ids": document_set_ids or [], "tool_ids": tool_ids or [], "llm_model_provider_override": llm_model_provider_override, diff --git a/backend/tests/integration/common_utils/managers/tenant.py b/backend/tests/integration/common_utils/managers/tenant.py index fc411018df7..c25a1b2ec6e 100644 --- a/backend/tests/integration/common_utils/managers/tenant.py +++ b/backend/tests/integration/common_utils/managers/tenant.py @@ -69,8 +69,10 @@ def get_all_users( return AllUsersResponse( accepted=[FullUserSnapshot(**user) for user in data["accepted"]], invited=[InvitedUserSnapshot(**user) for user in data["invited"]], + slack_users=[FullUserSnapshot(**user) for user in data["slack_users"]], accepted_pages=data["accepted_pages"], invited_pages=data["invited_pages"], + slack_users_pages=data["slack_users_pages"], ) @staticmethod diff --git a/backend/tests/integration/common_utils/managers/user.py b/backend/tests/integration/common_utils/managers/user.py index 43286c6a716..26cb29cdffb 100644 --- a/backend/tests/integration/common_utils/managers/user.py +++ b/backend/tests/integration/common_utils/managers/user.py @@ -130,8 +130,10 @@ def verify( all_users = AllUsersResponse( accepted=[FullUserSnapshot(**user) for user in data["accepted"]], invited=[InvitedUserSnapshot(**user) for user in data["invited"]], + slack_users=[FullUserSnapshot(**user) for user in data["slack_users"]], accepted_pages=data["accepted_pages"], invited_pages=data["invited_pages"], + slack_users_pages=data["slack_users_pages"], ) for accepted_user in all_users.accepted: if accepted_user.email == user.email and accepted_user.id == user.id: diff --git a/backend/tests/integration/connector_job_tests/slack/test_permission_sync.py b/backend/tests/integration/connector_job_tests/slack/test_permission_sync.py index 6c0c5908cd1..6396d7ca058 100644 --- a/backend/tests/integration/connector_job_tests/slack/test_permission_sync.py +++ b/backend/tests/integration/connector_job_tests/slack/test_permission_sync.py @@ -3,6 +3,8 @@ from datetime import timezone from typing import Any +import pytest + from danswer.connectors.models import InputType from danswer.db.enums import AccessType from danswer.server.documents.models import DocumentSource @@ -14,6 +16,7 @@ ) from tests.integration.common_utils.managers.llm_provider import LLMProviderManager from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.managers.user_group import UserGroupManager from tests.integration.common_utils.test_models import DATestCCPair from tests.integration.common_utils.test_models import DATestConnector from tests.integration.common_utils.test_models import DATestCredential @@ -22,7 +25,7 @@ from tests.integration.connector_job_tests.slack.slack_api_utils import SlackManager -# @pytest.mark.xfail(reason="flaky - see DAN-789 for example", strict=False) +@pytest.mark.xfail(reason="flaky - see DAN-789 for example", strict=False) def test_slack_permission_sync( reset: None, vespa_client: vespa_fixture, @@ -64,7 +67,6 @@ def test_slack_permission_sync( input_type=InputType.POLL, source=DocumentSource.SLACK, connector_specific_config={ - "workspace": "onyx-test-workspace", "channels": [public_channel["name"], private_channel["name"]], }, access_type=AccessType.SYNC, @@ -77,7 +79,7 @@ def test_slack_permission_sync( access_type=AccessType.SYNC, user_performing_action=admin_user, ) - CCPairManager.wait_for_indexing( + CCPairManager.wait_for_indexing_completion( cc_pair=cc_pair, after=before, user_performing_action=admin_user, @@ -112,7 +114,7 @@ def test_slack_permission_sync( # Run indexing before = datetime.now(timezone.utc) CCPairManager.run_once(cc_pair, admin_user) - CCPairManager.wait_for_indexing( + CCPairManager.wait_for_indexing_completion( cc_pair=cc_pair, after=before, user_performing_action=admin_user, @@ -215,3 +217,123 @@ def test_slack_permission_sync( # Ensure test_user_1 can only see messages from the public channel assert public_message in danswer_doc_message_strings assert private_message not in danswer_doc_message_strings + + +def test_slack_group_permission_sync( + reset: None, + vespa_client: vespa_fixture, + slack_test_setup: tuple[dict[str, Any], dict[str, Any]], +) -> None: + """ + This test ensures that permission sync overrides danswer group access. + """ + public_channel, private_channel = slack_test_setup + + # Creating an admin user (first user created is automatically an admin) + admin_user: DATestUser = UserManager.create( + email="admin@onyx-test.com", + ) + + # Creating a non-admin user + test_user_1: DATestUser = UserManager.create( + email="test_user_1@onyx-test.com", + ) + + # Create a user group and adding the non-admin user to it + user_group = UserGroupManager.create( + name="test_group", + user_ids=[test_user_1.id], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group], + user_performing_action=admin_user, + ) + + slack_client = SlackManager.get_slack_client(os.environ["SLACK_BOT_TOKEN"]) + email_id_map = SlackManager.build_slack_user_email_id_map(slack_client) + admin_user_id = email_id_map[admin_user.email] + + LLMProviderManager.create(user_performing_action=admin_user) + + # Add only admin to the private channel + SlackManager.set_channel_members( + slack_client=slack_client, + admin_user_id=admin_user_id, + channel=private_channel, + user_ids=[admin_user_id], + ) + + before = datetime.now(timezone.utc) + credential = CredentialManager.create( + source=DocumentSource.SLACK, + credential_json={ + "slack_bot_token": os.environ["SLACK_BOT_TOKEN"], + }, + user_performing_action=admin_user, + ) + + # Create connector with sync access and assign it to the user group + connector = ConnectorManager.create( + name="Slack", + input_type=InputType.POLL, + source=DocumentSource.SLACK, + connector_specific_config={ + "channels": [private_channel["name"]], + }, + access_type=AccessType.SYNC, + groups=[user_group.id], + user_performing_action=admin_user, + ) + + cc_pair = CCPairManager.create( + credential_id=credential.id, + connector_id=connector.id, + access_type=AccessType.SYNC, + user_performing_action=admin_user, + groups=[user_group.id], + ) + + # Add a test message to the private channel + private_message = "This is a secret message: 987654" + SlackManager.add_message_to_channel( + slack_client=slack_client, + channel=private_channel, + message=private_message, + ) + + # Run indexing + CCPairManager.run_once(cc_pair, admin_user) + CCPairManager.wait_for_indexing_completion( + cc_pair=cc_pair, + after=before, + user_performing_action=admin_user, + ) + + # Run permission sync + CCPairManager.sync( + cc_pair=cc_pair, + user_performing_action=admin_user, + ) + CCPairManager.wait_for_sync( + cc_pair=cc_pair, + after=before, + number_of_updated_docs=1, + user_performing_action=admin_user, + ) + + # Verify admin can see the message + admin_docs = DocumentSearchManager.search_documents( + query="secret message", + user_performing_action=admin_user, + ) + assert private_message in admin_docs + + # Verify test_user_1 cannot see the message despite being in the group + # (Slack permissions should take precedence) + user_1_docs = DocumentSearchManager.search_documents( + query="secret message", + user_performing_action=test_user_1, + ) + assert private_message not in user_1_docs diff --git a/backend/tests/integration/connector_job_tests/slack/test_prune.py b/backend/tests/integration/connector_job_tests/slack/test_prune.py index 2dfc3d0750f..774cf39e2ed 100644 --- a/backend/tests/integration/connector_job_tests/slack/test_prune.py +++ b/backend/tests/integration/connector_job_tests/slack/test_prune.py @@ -61,7 +61,6 @@ def test_slack_prune( input_type=InputType.POLL, source=DocumentSource.SLACK, connector_specific_config={ - "workspace": "onyx-test-workspace", "channels": [public_channel["name"], private_channel["name"]], }, access_type=AccessType.PUBLIC, @@ -74,7 +73,7 @@ def test_slack_prune( access_type=AccessType.SYNC, user_performing_action=admin_user, ) - CCPairManager.wait_for_indexing( + CCPairManager.wait_for_indexing_completion( cc_pair=cc_pair, after=before, user_performing_action=admin_user, @@ -113,7 +112,7 @@ def test_slack_prune( # Run indexing before = datetime.now(timezone.utc) CCPairManager.run_once(cc_pair, admin_user) - CCPairManager.wait_for_indexing( + CCPairManager.wait_for_indexing_completion( cc_pair=cc_pair, after=before, user_performing_action=admin_user, diff --git a/backend/tests/integration/tests/api_key/test_api_key.py b/backend/tests/integration/tests/api_key/test_api_key.py index bd0618b962d..34023d897a5 100644 --- a/backend/tests/integration/tests/api_key/test_api_key.py +++ b/backend/tests/integration/tests/api_key/test_api_key.py @@ -27,13 +27,6 @@ def test_limited(reset: None) -> None: ) assert response.status_code == 200 - # test basic endpoints - response = requests.get( - f"{API_SERVER_URL}/input_prompt", - headers=api_key.headers, - ) - assert response.status_code == 403 - # test admin endpoints response = requests.get( f"{API_SERVER_URL}/admin/api-key", diff --git a/backend/tests/integration/tests/connector/test_connector_creation.py b/backend/tests/integration/tests/connector/test_connector_creation.py index acfafe9436d..61085c5a5d2 100644 --- a/backend/tests/integration/tests/connector/test_connector_creation.py +++ b/backend/tests/integration/tests/connector/test_connector_creation.py @@ -58,7 +58,7 @@ def test_overlapping_connector_creation(reset: None) -> None: user_performing_action=admin_user, ) - CCPairManager.wait_for_indexing( + CCPairManager.wait_for_indexing_completion( cc_pair_1, now, timeout=120, user_performing_action=admin_user ) @@ -71,7 +71,7 @@ def test_overlapping_connector_creation(reset: None) -> None: user_performing_action=admin_user, ) - CCPairManager.wait_for_indexing( + CCPairManager.wait_for_indexing_completion( cc_pair_2, now, timeout=120, user_performing_action=admin_user ) @@ -82,3 +82,48 @@ def test_overlapping_connector_creation(reset: None) -> None: assert info_2 assert info_1.num_docs_indexed == info_2.num_docs_indexed + + +def test_connector_pause_while_indexing(reset: None) -> None: + """Tests that we can pause a connector while indexing is in progress and that + tasks end early or abort as a result. + + TODO: This does not specifically test for soft or hard termination code paths. + Design specific tests for those use cases. + """ + admin_user: DATestUser = UserManager.create(name="admin_user") + + config = { + "wiki_base": os.environ["CONFLUENCE_TEST_SPACE_URL"], + "space": "", + "is_cloud": True, + "page_id": "", + } + + credential = { + "confluence_username": os.environ["CONFLUENCE_USER_NAME"], + "confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"], + } + + # store the time before we create the connector so that we know after + # when the indexing should have started + datetime.now(timezone.utc) + + # create connector + cc_pair_1 = CCPairManager.create_from_scratch( + source=DocumentSource.CONFLUENCE, + connector_specific_config=config, + credential_json=credential, + user_performing_action=admin_user, + ) + + CCPairManager.wait_for_indexing_in_progress( + cc_pair_1, timeout=60, num_docs=16, user_performing_action=admin_user + ) + + CCPairManager.pause_cc_pair(cc_pair_1, user_performing_action=admin_user) + + CCPairManager.wait_for_indexing_inactive( + cc_pair_1, timeout=60, user_performing_action=admin_user + ) + return diff --git a/backend/tests/integration/tests/personas/test_persona_categories.py b/backend/tests/integration/tests/personas/test_persona_categories.py index fdd0e645814..1ac2d3b3000 100644 --- a/backend/tests/integration/tests/personas/test_persona_categories.py +++ b/backend/tests/integration/tests/personas/test_persona_categories.py @@ -44,6 +44,7 @@ def test_persona_category_management(reset: None) -> None: category=updated_persona_category, user_performing_action=regular_user, ) + assert exc_info.value.response is not None assert exc_info.value.response.status_code == 403 assert PersonaCategoryManager.verify( diff --git a/backend/tests/integration/tests/pruning/test_pruning.py b/backend/tests/integration/tests/pruning/test_pruning.py index 9d9a41c7069..beb1e8efbe9 100644 --- a/backend/tests/integration/tests/pruning/test_pruning.py +++ b/backend/tests/integration/tests/pruning/test_pruning.py @@ -135,7 +135,7 @@ def test_web_pruning(reset: None, vespa_client: vespa_fixture) -> None: user_performing_action=admin_user, ) - CCPairManager.wait_for_indexing( + CCPairManager.wait_for_indexing_completion( cc_pair_1, now, timeout=60, user_performing_action=admin_user ) diff --git a/backend/tests/integration/tests/streaming_endpoints/test_answer_stream.py b/backend/tests/integration/tests/streaming_endpoints/test_answer_stream.py deleted file mode 100644 index 3eb982ef228..00000000000 --- a/backend/tests/integration/tests/streaming_endpoints/test_answer_stream.py +++ /dev/null @@ -1,25 +0,0 @@ -from tests.integration.common_utils.managers.chat import ChatSessionManager -from tests.integration.common_utils.managers.llm_provider import LLMProviderManager -from tests.integration.common_utils.managers.user import UserManager -from tests.integration.common_utils.test_models import DATestUser - - -def test_send_message_simple_with_history(reset: None) -> None: - admin_user: DATestUser = UserManager.create(name="admin_user") - LLMProviderManager.create(user_performing_action=admin_user) - - test_chat_session = ChatSessionManager.create(user_performing_action=admin_user) - - response = ChatSessionManager.get_answer_with_quote( - persona_id=test_chat_session.persona_id, - message="Hello, this is a test.", - user_performing_action=admin_user, - ) - - assert ( - response.tool_name is not None - ), "Tool name should be specified (always search)" - assert ( - response.relevance_summaries is not None - ), "Relevance summaries should be present for all search streams" - assert len(response.full_message) > 0, "Response message should not be empty" diff --git a/backend/tests/regression/answer_quality/api_utils.py b/backend/tests/regression/answer_quality/api_utils.py index 28406c061b8..c37d650788a 100644 --- a/backend/tests/regression/answer_quality/api_utils.py +++ b/backend/tests/regression/answer_quality/api_utils.py @@ -1,16 +1,16 @@ import requests from retry import retry +from danswer.chat.models import ThreadMessage from danswer.configs.constants import DocumentSource from danswer.configs.constants import MessageType from danswer.connectors.models import InputType +from danswer.context.search.enums import OptionalSearchSetting from danswer.context.search.models import IndexFilters -from danswer.context.search.models import OptionalSearchSetting from danswer.context.search.models import RetrievalDetails from danswer.db.enums import IndexingStatus -from danswer.one_shot_answer.models import DirectQARequest -from danswer.one_shot_answer.models import ThreadMessage from danswer.server.documents.models import ConnectorBase +from ee.danswer.server.query_and_chat.models import OneShotQARequest from tests.regression.answer_quality.cli_utils import get_api_server_host_port GENERAL_HEADERS = {"Content-Type": "application/json"} @@ -37,7 +37,7 @@ def get_answer_from_query( messages = [ThreadMessage(message=query, sender=None, role=MessageType.USER)] - new_message_request = DirectQARequest( + new_message_request = OneShotQARequest( messages=messages, prompt_id=0, persona_id=0, @@ -47,12 +47,11 @@ def get_answer_from_query( filters=filters, enable_auto_detect_filters=False, ), - chain_of_thought=False, return_contexts=True, skip_gen_ai_answer_generation=only_retrieve_docs, ) - url = _api_url_builder(env_name, "/query/answer-with-quote/") + url = _api_url_builder(env_name, "/query/answer-with-citation/") headers = { "Content-Type": "application/json", } diff --git a/backend/tests/unit/danswer/llm/answering/conftest.py b/backend/tests/unit/danswer/chat/conftest.py similarity index 92% rename from backend/tests/unit/danswer/llm/answering/conftest.py rename to backend/tests/unit/danswer/chat/conftest.py index a0077b53917..aed94d8fc49 100644 --- a/backend/tests/unit/danswer/llm/answering/conftest.py +++ b/backend/tests/unit/danswer/chat/conftest.py @@ -5,12 +5,12 @@ import pytest from langchain_core.messages import SystemMessage +from danswer.chat.models import AnswerStyleConfig +from danswer.chat.models import CitationConfig from danswer.chat.models import LlmDoc +from danswer.chat.models import PromptConfig +from danswer.chat.prompt_builder.build import AnswerPromptBuilder from danswer.configs.constants import DocumentSource -from danswer.llm.answering.models import AnswerStyleConfig -from danswer.llm.answering.models import CitationConfig -from danswer.llm.answering.models import PromptConfig -from danswer.llm.answering.prompts.build import AnswerPromptBuilder from danswer.llm.interfaces import LLMConfig from danswer.tools.models import ToolResponse from danswer.tools.tool_implementations.search.search_tool import SearchTool @@ -64,6 +64,7 @@ def mock_search_results() -> list[LlmDoc]: updated_at=datetime(2023, 1, 1), link="https://example.com/doc1", source_links={0: "https://example.com/doc1"}, + match_highlights=[], ), LlmDoc( content="Search result 2", @@ -75,6 +76,7 @@ def mock_search_results() -> list[LlmDoc]: updated_at=datetime(2023, 1, 2), link="https://example.com/doc2", source_links={0: "https://example.com/doc2"}, + match_highlights=[], ), ] diff --git a/backend/tests/unit/danswer/llm/answering/stream_processing/test_citation_processing.py b/backend/tests/unit/danswer/chat/stream_processing/test_citation_processing.py similarity index 92% rename from backend/tests/unit/danswer/llm/answering/stream_processing/test_citation_processing.py rename to backend/tests/unit/danswer/chat/stream_processing/test_citation_processing.py index 13e6fd73b5a..178240c7176 100644 --- a/backend/tests/unit/danswer/llm/answering/stream_processing/test_citation_processing.py +++ b/backend/tests/unit/danswer/chat/stream_processing/test_citation_processing.py @@ -5,11 +5,9 @@ from danswer.chat.models import CitationInfo from danswer.chat.models import DanswerAnswerPiece from danswer.chat.models import LlmDoc +from danswer.chat.stream_processing.citation_processing import CitationProcessor +from danswer.chat.stream_processing.utils import DocumentIdOrderMapping from danswer.configs.constants import DocumentSource -from danswer.llm.answering.stream_processing.citation_processing import ( - CitationProcessor, -) -from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping """ @@ -46,6 +44,7 @@ updated_at=datetime.now(), link=f"https://{int(id/2)}.com" if int(id / 2) % 2 == 0 else None, source_links={0: "https://mintlify.com/docs/settings/broken-links"}, + match_highlights=[], ) for id in range(10) ] @@ -73,8 +72,10 @@ def process_text( processor = CitationProcessor( context_docs=mock_docs, doc_id_to_rank_map=mapping, + display_doc_order_dict=mock_doc_id_to_rank_map, stop_stream=None, ) + result: list[DanswerAnswerPiece | CitationInfo] = [] for token in tokens: result.extend(processor.process_token(token)) @@ -87,6 +88,7 @@ def process_text( final_answer_text += piece.answer_piece or "" elif isinstance(piece, CitationInfo): citations.append(piece) + return final_answer_text, citations @@ -385,6 +387,16 @@ def process_text( "Here is some text[[1]](https://0.com). Some other text", ["doc_0"], ), + # ['To', ' set', ' up', ' D', 'answer', ',', ' if', ' you', ' are', ' running', ' it', ' yourself', ' and', + # ' need', ' access', ' to', ' certain', ' features', ' like', ' auto', '-sync', 'ing', ' document', + # '-level', ' access', ' permissions', ',', ' you', ' should', ' reach', ' out', ' to', ' the', ' D', + # 'answer', ' team', ' to', ' receive', ' access', ' [[', '4', ']].', ''] + ( + "Unique tokens with double brackets and a single token that ends the citation and has characters after it.", + ["... to receive access", " [[", "1", "]].", ""], + "... to receive access [[1]](https://0.com).", + ["doc_0"], + ), ], ) def test_citation_extraction( diff --git a/backend/tests/unit/danswer/chat/stream_processing/test_citation_substitution.py b/backend/tests/unit/danswer/chat/stream_processing/test_citation_substitution.py new file mode 100644 index 00000000000..841d76a3247 --- /dev/null +++ b/backend/tests/unit/danswer/chat/stream_processing/test_citation_substitution.py @@ -0,0 +1,132 @@ +from datetime import datetime + +import pytest + +from danswer.chat.models import CitationInfo +from danswer.chat.models import DanswerAnswerPiece +from danswer.chat.models import LlmDoc +from danswer.chat.stream_processing.citation_processing import CitationProcessor +from danswer.chat.stream_processing.utils import DocumentIdOrderMapping +from danswer.configs.constants import DocumentSource + + +""" +This module contains tests for the citation extraction functionality in Danswer, +specifically the substitution of the number of document cited in the UI. (The LLM +will see the sources post re-ranking and relevance check, the UI before these steps.) +This module is a derivative of test_citation_processing.py. + +The tests focusses specifically on the substitution of the number of document cited in the UI. + +Key components: +- mock_docs: A list of mock LlmDoc objects used for testing. +- mock_doc_mapping: A dictionary mapping document IDs to their initial ranks. +- mock_doc_mapping_rerank: A dictionary mapping document IDs to their ranks after re-ranking/relevance check. +- process_text: A helper function that simulates the citation extraction process. +- test_citation_extraction: A parametrized test function covering various citation scenarios. + +To add new test cases: +1. Add a new tuple to the @pytest.mark.parametrize decorator of test_citation_extraction. +2. Each tuple should contain: + - A descriptive test name (string) + - Input tokens (list of strings) + - Expected output text (string) + - Expected citations (list of document IDs) +""" + + +mock_docs = [ + LlmDoc( + document_id=f"doc_{int(id/2)}", + content="Document is a doc", + blurb=f"Document #{id}", + semantic_identifier=f"Doc {id}", + source_type=DocumentSource.WEB, + metadata={}, + updated_at=datetime.now(), + link=f"https://{int(id/2)}.com" if int(id / 2) % 2 == 0 else None, + source_links={0: "https://mintlify.com/docs/settings/broken-links"}, + match_highlights=[], + ) + for id in range(10) +] + +mock_doc_mapping = { + "doc_0": 1, + "doc_1": 2, + "doc_2": 3, + "doc_3": 4, + "doc_4": 5, + "doc_5": 6, +} + +mock_doc_mapping_rerank = { + "doc_0": 2, + "doc_1": 1, + "doc_2": 4, + "doc_3": 3, + "doc_4": 6, + "doc_5": 5, +} + + +@pytest.fixture +def mock_data() -> tuple[list[LlmDoc], dict[str, int]]: + return mock_docs, mock_doc_mapping + + +def process_text( + tokens: list[str], mock_data: tuple[list[LlmDoc], dict[str, int]] +) -> tuple[str, list[CitationInfo]]: + mock_docs, mock_doc_id_to_rank_map = mock_data + mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map) + processor = CitationProcessor( + context_docs=mock_docs, + doc_id_to_rank_map=mapping, + display_doc_order_dict=mock_doc_mapping_rerank, + stop_stream=None, + ) + + result: list[DanswerAnswerPiece | CitationInfo] = [] + for token in tokens: + result.extend(processor.process_token(token)) + result.extend(processor.process_token(None)) + + final_answer_text = "" + citations = [] + for piece in result: + if isinstance(piece, DanswerAnswerPiece): + final_answer_text += piece.answer_piece or "" + elif isinstance(piece, CitationInfo): + citations.append(piece) + + return final_answer_text, citations + + +@pytest.mark.parametrize( + "test_name, input_tokens, expected_text, expected_citations", + [ + ( + "Single citation", + ["Gro", "wth! [", "1", "]", "."], + "Growth! [[2]](https://0.com).", + ["doc_0"], + ), + ], +) +def test_citation_substitution( + mock_data: tuple[list[LlmDoc], dict[str, int]], + test_name: str, + input_tokens: list[str], + expected_text: str, + expected_citations: list[str], +) -> None: + final_answer_text, citations = process_text(input_tokens, mock_data) + assert ( + final_answer_text.strip() == expected_text.strip() + ), f"Test '{test_name}' failed: Final answer text does not match expected output." + assert [ + citation.document_id for citation in citations + ] == expected_citations, ( + f"Test '{test_name}' failed: Citations do not match expected output." + ) diff --git a/backend/tests/unit/danswer/direct_qa/test_qa_utils.py b/backend/tests/unit/danswer/chat/stream_processing/test_quotes_processing.py similarity index 97% rename from backend/tests/unit/danswer/direct_qa/test_qa_utils.py rename to backend/tests/unit/danswer/chat/stream_processing/test_quotes_processing.py index bcbd76f4e12..7cb969ab7a6 100644 --- a/backend/tests/unit/danswer/direct_qa/test_qa_utils.py +++ b/backend/tests/unit/danswer/chat/stream_processing/test_quotes_processing.py @@ -2,14 +2,10 @@ import pytest +from danswer.chat.stream_processing.quotes_processing import match_quotes_to_docs +from danswer.chat.stream_processing.quotes_processing import separate_answer_quotes from danswer.configs.constants import DocumentSource from danswer.context.search.models import InferenceChunk -from danswer.llm.answering.stream_processing.quotes_processing import ( - match_quotes_to_docs, -) -from danswer.llm.answering.stream_processing.quotes_processing import ( - separate_answer_quotes, -) def test_passed_in_quotes() -> None: diff --git a/backend/tests/unit/danswer/llm/answering/test_answer.py b/backend/tests/unit/danswer/chat/test_answer.py similarity index 77% rename from backend/tests/unit/danswer/llm/answering/test_answer.py rename to backend/tests/unit/danswer/chat/test_answer.py index bbe1559a4f0..14bbec65434 100644 --- a/backend/tests/unit/danswer/llm/answering/test_answer.py +++ b/backend/tests/unit/danswer/chat/test_answer.py @@ -11,24 +11,21 @@ from langchain_core.messages import ToolCall from langchain_core.messages import ToolCallChunk +from danswer.chat.answer import Answer +from danswer.chat.models import AnswerStyleConfig from danswer.chat.models import CitationInfo from danswer.chat.models import DanswerAnswerPiece -from danswer.chat.models import DanswerQuote -from danswer.chat.models import DanswerQuotes from danswer.chat.models import LlmDoc +from danswer.chat.models import PromptConfig from danswer.chat.models import StreamStopInfo from danswer.chat.models import StreamStopReason -from danswer.llm.answering.answer import Answer -from danswer.llm.answering.models import AnswerStyleConfig -from danswer.llm.answering.models import PromptConfig -from danswer.llm.answering.models import QuotesConfig from danswer.llm.interfaces import LLM from danswer.tools.force import ForceUseTool from danswer.tools.models import ToolCallFinalResult from danswer.tools.models import ToolCallKickoff from danswer.tools.models import ToolResponse -from tests.unit.danswer.llm.answering.conftest import DEFAULT_SEARCH_ARGS -from tests.unit.danswer.llm.answering.conftest import QUERY +from tests.unit.danswer.chat.conftest import DEFAULT_SEARCH_ARGS +from tests.unit.danswer.chat.conftest import QUERY @pytest.fixture @@ -284,90 +281,6 @@ def test_answer_with_search_no_tool_calling( mock_search_tool.run.assert_called_once() -def test_answer_with_search_call_quotes_enabled( - answer_instance: Answer, - mock_search_results: list[LlmDoc], - mock_search_tool: MagicMock, -) -> None: - answer_instance.tools = [mock_search_tool] - answer_instance.force_use_tool = ForceUseTool( - force_use=False, tool_name="", args=None - ) - answer_instance.answer_style_config.citation_config = None - answer_instance.answer_style_config.quotes_config = QuotesConfig() - - # Set up the LLM mock to return search results and then an answer - mock_llm = cast(Mock, answer_instance.llm) - - tool_call_chunk = AIMessageChunk(content="") - tool_call_chunk.tool_calls = [ - ToolCall( - id="search", - name="search", - args=DEFAULT_SEARCH_ARGS, - ) - ] - tool_call_chunk.tool_call_chunks = [ - ToolCallChunk( - id="search", - name="search", - args=json.dumps(DEFAULT_SEARCH_ARGS), - index=0, - ) - ] - - # needs to be short due to the "anti-hallucination" check in QuotesProcessor - answer_content = "z" - quote_content = mock_search_results[0].content - mock_llm.stream.side_effect = [ - [tool_call_chunk], - [ - AIMessageChunk( - content=( - '{"answer": "' - + answer_content - + '", "quotes": ["' - + quote_content - + '"]}' - ) - ), - ], - ] - - # Process the output - output = list(answer_instance.processed_streamed_output) - - # Assertions - assert len(output) == 5 - assert output[0] == ToolCallKickoff( - tool_name="search", tool_args=DEFAULT_SEARCH_ARGS - ) - assert output[1] == ToolResponse( - id="final_context_documents", - response=mock_search_results, - ) - assert output[2] == ToolCallFinalResult( - tool_name="search", - tool_args=DEFAULT_SEARCH_ARGS, - tool_result=[json.loads(doc.model_dump_json()) for doc in mock_search_results], - ) - assert output[3] == DanswerAnswerPiece(answer_piece=answer_content) - assert output[4] == DanswerQuotes( - quotes=[ - DanswerQuote( - quote=quote_content, - document_id=mock_search_results[0].document_id, - link=mock_search_results[0].link, - source_type=mock_search_results[0].source_type, - semantic_identifier=mock_search_results[0].semantic_identifier, - blurb=mock_search_results[0].blurb, - ) - ] - ) - - assert answer_instance.llm_answer == answer_content - - def test_is_cancelled(answer_instance: Answer) -> None: # Set up the LLM mock to return multiple chunks mock_llm = Mock() diff --git a/backend/tests/unit/danswer/llm/answering/test_prune_and_merge.py b/backend/tests/unit/danswer/chat/test_prune_and_merge.py similarity index 99% rename from backend/tests/unit/danswer/llm/answering/test_prune_and_merge.py rename to backend/tests/unit/danswer/chat/test_prune_and_merge.py index c71d9109007..2741a56526d 100644 --- a/backend/tests/unit/danswer/llm/answering/test_prune_and_merge.py +++ b/backend/tests/unit/danswer/chat/test_prune_and_merge.py @@ -1,9 +1,9 @@ import pytest +from danswer.chat.prune_and_merge import _merge_sections from danswer.configs.constants import DocumentSource from danswer.context.search.models import InferenceChunk from danswer.context.search.models import InferenceSection -from danswer.llm.answering.prune_and_merge import _merge_sections # This large test accounts for all of the following: diff --git a/backend/tests/unit/danswer/llm/answering/test_skip_gen_ai.py b/backend/tests/unit/danswer/chat/test_skip_gen_ai.py similarity index 93% rename from backend/tests/unit/danswer/llm/answering/test_skip_gen_ai.py rename to backend/tests/unit/danswer/chat/test_skip_gen_ai.py index 7bd4a498bd7..772ec52a6ca 100644 --- a/backend/tests/unit/danswer/llm/answering/test_skip_gen_ai.py +++ b/backend/tests/unit/danswer/chat/test_skip_gen_ai.py @@ -5,10 +5,10 @@ import pytest from pytest_mock import MockerFixture -from danswer.llm.answering.answer import Answer -from danswer.llm.answering.models import AnswerStyleConfig -from danswer.llm.answering.models import PromptConfig -from danswer.one_shot_answer.answer_question import AnswerObjectIterator +from danswer.chat.answer import Answer +from danswer.chat.answer import AnswerStream +from danswer.chat.models import AnswerStyleConfig +from danswer.chat.models import PromptConfig from danswer.tools.force import ForceUseTool from danswer.tools.tool_implementations.search.search_tool import SearchTool from tests.regression.answer_quality.run_qa import _process_and_write_query_results @@ -60,7 +60,7 @@ def test_skip_gen_ai_answer_generation_flag( skip_gen_ai_answer_generation=skip_gen_ai_answer_generation, ) count = 0 - for _ in cast(AnswerObjectIterator, answer.processed_streamed_output): + for _ in cast(AnswerStream, answer.processed_streamed_output): count += 1 assert count == 3 if skip_gen_ai_answer_generation else 4 if not skip_gen_ai_answer_generation: diff --git a/backend/tests/unit/danswer/indexing/test_indexing_pipeline.py b/backend/tests/unit/danswer/indexing/test_indexing_pipeline.py new file mode 100644 index 00000000000..612535f67ed --- /dev/null +++ b/backend/tests/unit/danswer/indexing/test_indexing_pipeline.py @@ -0,0 +1,120 @@ +from typing import List + +from danswer.configs.app_configs import MAX_DOCUMENT_CHARS +from danswer.connectors.models import Document +from danswer.connectors.models import DocumentSource +from danswer.connectors.models import Section +from danswer.indexing.indexing_pipeline import filter_documents + + +def create_test_document( + doc_id: str = "test_id", + title: str | None = "Test Title", + semantic_id: str = "test_semantic_id", + sections: List[Section] | None = None, +) -> Document: + if sections is None: + sections = [Section(text="Test content", link="test_link")] + return Document( + id=doc_id, + title=title, + semantic_identifier=semantic_id, + sections=sections, + source=DocumentSource.FILE, + metadata={}, + ) + + +def test_filter_documents_empty_title_and_content() -> None: + doc = create_test_document( + title="", semantic_id="", sections=[Section(text="", link="test_link")] + ) + result = filter_documents([doc]) + assert len(result) == 0 + + +def test_filter_documents_empty_title_with_content() -> None: + doc = create_test_document( + title="", sections=[Section(text="Valid content", link="test_link")] + ) + result = filter_documents([doc]) + assert len(result) == 1 + assert result[0].id == "test_id" + + +def test_filter_documents_empty_content_with_title() -> None: + doc = create_test_document( + title="Valid Title", sections=[Section(text="", link="test_link")] + ) + result = filter_documents([doc]) + assert len(result) == 1 + assert result[0].id == "test_id" + + +def test_filter_documents_exceeding_max_chars() -> None: + if not MAX_DOCUMENT_CHARS: # Skip if no max chars configured + return + long_text = "a" * (MAX_DOCUMENT_CHARS + 1) + doc = create_test_document(sections=[Section(text=long_text, link="test_link")]) + result = filter_documents([doc]) + assert len(result) == 0 + + +def test_filter_documents_valid_document() -> None: + doc = create_test_document( + title="Valid Title", sections=[Section(text="Valid content", link="test_link")] + ) + result = filter_documents([doc]) + assert len(result) == 1 + assert result[0].id == "test_id" + assert result[0].title == "Valid Title" + + +def test_filter_documents_whitespace_only() -> None: + doc = create_test_document( + title=" ", semantic_id=" ", sections=[Section(text=" ", link="test_link")] + ) + result = filter_documents([doc]) + assert len(result) == 0 + + +def test_filter_documents_semantic_id_no_title() -> None: + doc = create_test_document( + title=None, + semantic_id="Valid Semantic ID", + sections=[Section(text="Valid content", link="test_link")], + ) + result = filter_documents([doc]) + assert len(result) == 1 + assert result[0].semantic_identifier == "Valid Semantic ID" + + +def test_filter_documents_multiple_sections() -> None: + doc = create_test_document( + sections=[ + Section(text="Content 1", link="test_link"), + Section(text="Content 2", link="test_link"), + Section(text="Content 3", link="test_link"), + ] + ) + result = filter_documents([doc]) + assert len(result) == 1 + assert len(result[0].sections) == 3 + + +def test_filter_documents_multiple_documents() -> None: + docs = [ + create_test_document(doc_id="1", title="Title 1"), + create_test_document( + doc_id="2", title="", sections=[Section(text="", link="test_link")] + ), # Should be filtered + create_test_document(doc_id="3", title="Title 3"), + ] + result = filter_documents(docs) + assert len(result) == 2 + assert {doc.id for doc in result} == {"1", "3"} + + +def test_filter_documents_empty_batch() -> None: + result = filter_documents([]) + assert len(result) == 0 diff --git a/backend/tests/unit/danswer/llm/answering/stream_processing/test_quote_processing.py b/backend/tests/unit/danswer/llm/answering/stream_processing/test_quote_processing.py deleted file mode 100644 index 390d838043a..00000000000 --- a/backend/tests/unit/danswer/llm/answering/stream_processing/test_quote_processing.py +++ /dev/null @@ -1,351 +0,0 @@ -import json -from datetime import datetime - -from danswer.chat.models import DanswerAnswerPiece -from danswer.chat.models import DanswerQuotes -from danswer.chat.models import LlmDoc -from danswer.configs.constants import DocumentSource -from danswer.llm.answering.stream_processing.quotes_processing import ( - QuotesProcessor, -) - -mock_docs = [ - LlmDoc( - document_id=f"doc_{int(id/2)}", - content="Document is a doc", - blurb=f"Document #{id}", - semantic_identifier=f"Doc {id}", - source_type=DocumentSource.WEB, - metadata={}, - updated_at=datetime.now(), - link=f"https://{int(id/2)}.com" if int(id / 2) % 2 == 0 else None, - source_links={0: "https://mintlify.com/docs/settings/broken-links"}, - ) - for id in range(10) -] - - -def _process_tokens( - processor: QuotesProcessor, tokens: list[str] -) -> tuple[str, list[str]]: - """Process a list of tokens and return the answer and quotes. - - Args: - processor: QuotesProcessor instance - tokens: List of tokens to process - - Returns: - Tuple of (answer_text, list_of_quotes) - """ - answer = "" - quotes: list[str] = [] - - # need to add a None to the end to simulate the end of the stream - for token in tokens + [None]: - for output in processor.process_token(token): - if isinstance(output, DanswerAnswerPiece): - if output.answer_piece: - answer += output.answer_piece - elif isinstance(output, DanswerQuotes): - quotes.extend(q.quote for q in output.quotes) - - return answer, quotes - - -def test_process_model_tokens_answer() -> None: - tokens_with_quotes = [ - "{", - "\n ", - '"answer": "Yes', - ", Danswer allows", - " customized prompts. This", - " feature", - " is currently being", - " developed and implemente", - "d to", - " improve", - " the accuracy", - " of", - " Language", - " Models (", - "LL", - "Ms) for", - " different", - " companies", - ".", - " The custom", - "ized prompts feature", - " woul", - "d allow users to ad", - "d person", - "alized prom", - "pts through", - " an", - " interface or", - " metho", - "d,", - " which would then be used to", - " train", - " the LLM.", - " This enhancement", - " aims to make", - " Danswer more", - " adaptable to", - " different", - " business", - " contexts", - " by", - " tail", - "oring it", - " to the specific language", - " an", - "d terminology", - " used within", - " a", - " company.", - " Additionally", - ",", - " Danswer already", - " supports creating", - " custom AI", - " Assistants with", - " different", - " prom", - "pts and backing", - " knowledge", - " sets", - ",", - " which", - " is", - " a form", - " of prompt", - " customization. However, it", - "'s important to nLogging Details LiteLLM-Success Call: Noneote that some", - " aspects", - " of prompt", - " customization,", - " such as for", - " Sl", - "ack", - "b", - "ots, may", - " still", - " be in", - " development or have", - ' limitations.",', - '\n "quotes": [', - '\n "We', - " woul", - "d like to ad", - "d customized prompts for", - " different", - " companies to improve the accuracy of", - " Language", - " Model", - " (LLM)", - '.",\n "A', - " new", - " feature that", - " allows users to add personalize", - "d prompts.", - " This would involve", - " creating", - " an interface or method for", - " users to input", - " their", - " own", - " prom", - "pts,", - " which would then be used to", - ' train the LLM.",', - '\n "Create', - " custom AI Assistants with", - " different prompts and backing knowledge", - ' sets.",', - '\n "This', - " PR", - " fixes", - " https", - "://github.com/dan", - "swer-ai/dan", - "swer/issues/1", - "584", - " by", - " setting", - " the system", - " default", - " prompt for", - " sl", - "ackbots const", - "rained by", - " ", - "document sets", - ".", - " It", - " probably", - " isn", - "'t ideal", - " -", - " it", - " might", - " be pref", - "erable to be", - " able to select", - " a prompt for", - " the", - " slackbot from", - " the", - " admin", - " panel", - " -", - " but it sol", - "ves the immediate problem", - " of", - " the slack", - " listener", - " cr", - "ashing when", - " configure", - "d this", - ' way."\n ]', - "\n}", - "", - ] - - processor = QuotesProcessor(context_docs=mock_docs) - answer, quotes = _process_tokens(processor, tokens_with_quotes) - - s_json = "".join(tokens_with_quotes) - j = json.loads(s_json) - expected_answer = j["answer"] - assert expected_answer == answer - # NOTE: no quotes, since the docs don't match the quotes - assert len(quotes) == 0 - - -def test_simple_json_answer() -> None: - tokens = [ - "```", - "json", - "\n", - "{", - '"answer": "This is a simple ', - "answer.", - '",\n"', - 'quotes": []', - "\n}", - "\n", - "```", - ] - processor = QuotesProcessor(context_docs=mock_docs) - answer, quotes = _process_tokens(processor, tokens) - - assert "This is a simple answer." == answer - assert len(quotes) == 0 - - -def test_json_answer_with_quotes() -> None: - tokens = [ - "```", - "json", - "\n", - "{", - '"answer": "This ', - "is a ", - "split ", - "answer.", - '",\n"', - 'quotes": []', - "\n}", - "\n", - "```", - ] - processor = QuotesProcessor(context_docs=mock_docs) - answer, quotes = _process_tokens(processor, tokens) - - assert "This is a split answer." == answer - assert len(quotes) == 0 - - -def test_json_answer_with_quotes_one_chunk() -> None: - tokens = ['```json\n{"answer": "z",\n"quotes": ["Document"]\n}\n```'] - processor = QuotesProcessor(context_docs=mock_docs) - answer, quotes = _process_tokens(processor, tokens) - - assert "z" == answer - assert len(quotes) == 1 - assert quotes[0] == "Document" - - -def test_json_answer_split_tokens() -> None: - tokens = [ - "```", - "json", - "\n", - "{", - '\n"', - 'answer": "This ', - "is a ", - "split ", - "answer.", - '",\n"', - 'quotes": []', - "\n}", - "\n", - "```", - ] - processor = QuotesProcessor(context_docs=mock_docs) - answer, quotes = _process_tokens(processor, tokens) - - assert "This is a split answer." == answer - assert len(quotes) == 0 - - -def test_lengthy_prefixed_json_with_quotes() -> None: - tokens = [ - "This is my response in json\n\n", - "```", - "json", - "\n", - "{", - '"answer": "This is a simple ', - "answer.", - '",\n"', - 'quotes": ["Document"]', - "\n}", - "\n", - "```", - ] - processor = QuotesProcessor(context_docs=mock_docs) - answer, quotes = _process_tokens(processor, tokens) - - assert "This is a simple answer." == answer - assert len(quotes) == 1 - assert quotes[0] == "Document" - - -def test_json_with_lengthy_prefix_and_quotes() -> None: - tokens = [ - "*** Based on the provided documents, there does not appear to be any information ", - "directly relevant to answering which documents are my favorite. ", - "The documents seem to be focused on describing the Danswer product ", - "and its features/use cases. Since I do not have personal preferences ", - "for documents, I will provide a general response:\n\n", - "```", - "json", - "\n", - "{", - '"answer": "This is a simple ', - "answer.", - '",\n"', - 'quotes": ["Document"]', - "\n}", - "\n", - "```", - ] - processor = QuotesProcessor(context_docs=mock_docs) - answer, quotes = _process_tokens(processor, tokens) - - assert "This is a simple answer." == answer - assert len(quotes) == 1 - assert quotes[0] == "Document" diff --git a/backend/tests/unit/model_server/test_embedding.py b/backend/tests/unit/model_server/test_embedding.py new file mode 100644 index 00000000000..6781ab27aa0 --- /dev/null +++ b/backend/tests/unit/model_server/test_embedding.py @@ -0,0 +1,198 @@ +import asyncio +import time +from collections.abc import AsyncGenerator +from typing import Any +from typing import List +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest +from httpx import AsyncClient +from litellm.exceptions import RateLimitError + +from model_server.encoders import CloudEmbedding +from model_server.encoders import embed_text +from model_server.encoders import local_rerank +from model_server.encoders import process_embed_request +from shared_configs.enums import EmbeddingProvider +from shared_configs.enums import EmbedTextType +from shared_configs.model_server_models import EmbedRequest + + +@pytest.fixture +async def mock_http_client() -> AsyncGenerator[AsyncMock, None]: + with patch("httpx.AsyncClient") as mock: + client = AsyncMock(spec=AsyncClient) + mock.return_value = client + client.post = AsyncMock() + async with client as c: + yield c + + +@pytest.fixture +def sample_embeddings() -> List[List[float]]: + return [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + + +@pytest.mark.asyncio +async def test_cloud_embedding_context_manager() -> None: + async with CloudEmbedding("fake-key", EmbeddingProvider.OPENAI) as embedding: + assert not embedding._closed + assert embedding._closed + + +@pytest.mark.asyncio +async def test_cloud_embedding_explicit_close() -> None: + embedding = CloudEmbedding("fake-key", EmbeddingProvider.OPENAI) + assert not embedding._closed + await embedding.aclose() + assert embedding._closed + + +@pytest.mark.asyncio +async def test_openai_embedding( + mock_http_client: AsyncMock, sample_embeddings: List[List[float]] +) -> None: + with patch("openai.AsyncOpenAI") as mock_openai: + mock_client = AsyncMock() + mock_openai.return_value = mock_client + + mock_response = MagicMock() + mock_response.data = [MagicMock(embedding=emb) for emb in sample_embeddings] + mock_client.embeddings.create = AsyncMock(return_value=mock_response) + + embedding = CloudEmbedding("fake-key", EmbeddingProvider.OPENAI) + result = await embedding._embed_openai( + ["test1", "test2"], "text-embedding-ada-002" + ) + + assert result == sample_embeddings + mock_client.embeddings.create.assert_called_once() + + +@pytest.mark.asyncio +async def test_embed_text_cloud_provider() -> None: + with patch("model_server.encoders.CloudEmbedding.embed") as mock_embed: + mock_embed.return_value = [[0.1, 0.2], [0.3, 0.4]] + mock_embed.side_effect = AsyncMock(return_value=[[0.1, 0.2], [0.3, 0.4]]) + + result = await embed_text( + texts=["test1", "test2"], + text_type=EmbedTextType.QUERY, + model_name="fake-model", + deployment_name=None, + max_context_length=512, + normalize_embeddings=True, + api_key="fake-key", + provider_type=EmbeddingProvider.OPENAI, + prefix=None, + api_url=None, + api_version=None, + ) + + assert result == [[0.1, 0.2], [0.3, 0.4]] + mock_embed.assert_called_once() + + +@pytest.mark.asyncio +async def test_embed_text_local_model() -> None: + with patch("model_server.encoders.get_embedding_model") as mock_get_model: + mock_model = MagicMock() + mock_model.encode.return_value = [[0.1, 0.2], [0.3, 0.4]] + mock_get_model.return_value = mock_model + + result = await embed_text( + texts=["test1", "test2"], + text_type=EmbedTextType.QUERY, + model_name="fake-local-model", + deployment_name=None, + max_context_length=512, + normalize_embeddings=True, + api_key=None, + provider_type=None, + prefix=None, + api_url=None, + api_version=None, + ) + + assert result == [[0.1, 0.2], [0.3, 0.4]] + mock_model.encode.assert_called_once() + + +@pytest.mark.asyncio +async def test_local_rerank() -> None: + with patch("model_server.encoders.get_local_reranking_model") as mock_get_model: + mock_model = MagicMock() + mock_array = MagicMock() + mock_array.tolist.return_value = [0.8, 0.6] + mock_model.predict.return_value = mock_array + mock_get_model.return_value = mock_model + + result = await local_rerank( + query="test query", docs=["doc1", "doc2"], model_name="fake-rerank-model" + ) + + assert result == [0.8, 0.6] + mock_model.predict.assert_called_once() + + +@pytest.mark.asyncio +async def test_rate_limit_handling() -> None: + with patch("model_server.encoders.CloudEmbedding.embed") as mock_embed: + mock_embed.side_effect = RateLimitError( + "Rate limit exceeded", llm_provider="openai", model="fake-model" + ) + + with pytest.raises(RateLimitError): + await embed_text( + texts=["test"], + text_type=EmbedTextType.QUERY, + model_name="fake-model", + deployment_name=None, + max_context_length=512, + normalize_embeddings=True, + api_key="fake-key", + provider_type=EmbeddingProvider.OPENAI, + prefix=None, + api_url=None, + api_version=None, + ) + + +@pytest.mark.asyncio +async def test_concurrent_embeddings() -> None: + def mock_encode(*args: Any, **kwargs: Any) -> List[List[float]]: + time.sleep(5) + return [[0.1, 0.2, 0.3]] + + test_req = EmbedRequest( + texts=["test"], + model_name="'nomic-ai/nomic-embed-text-v1'", + deployment_name=None, + max_context_length=512, + normalize_embeddings=True, + api_key=None, + provider_type=None, + text_type=EmbedTextType.QUERY, + manual_query_prefix=None, + manual_passage_prefix=None, + api_url=None, + api_version=None, + ) + + with patch("model_server.encoders.get_embedding_model") as mock_get_model: + mock_model = MagicMock() + mock_model.encode = mock_encode + mock_get_model.return_value = mock_model + start_time = time.time() + + tasks = [process_embed_request(test_req) for _ in range(5)] + await asyncio.gather(*tasks) + + end_time = time.time() + + # 5 * 5 seconds = 25 seconds, this test ensures that the embeddings are at least yielding the thread + # However, the developer may still introduce unnecessary blocking above the mock and this test will + # still pass as long as it's less than (7 - 5) / 5 seconds + assert end_time - start_time < 7 diff --git a/ct.yaml b/ct.yaml index f568ef5d52b..cec4478c850 100644 --- a/ct.yaml +++ b/ct.yaml @@ -6,7 +6,7 @@ chart-dirs: # must be kept in sync with Chart.yaml chart-repos: - - vespa=https://danswer-ai.github.io/vespa-helm-charts + - vespa=https://onyx-dot-app.github.io/vespa-helm-charts - postgresql=https://charts.bitnami.com/bitnami helm-extra-args: --debug --timeout 600s diff --git a/deployment/cloud_kubernetes/workers/beat.yaml b/deployment/cloud_kubernetes/workers/beat.yaml index ecd5a121900..cfe7f79cd91 100644 --- a/deployment/cloud_kubernetes/workers/beat.yaml +++ b/deployment/cloud_kubernetes/workers/beat.yaml @@ -14,7 +14,7 @@ spec: spec: containers: - name: celery-beat - image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.10 + image: danswer/danswer-backend-cloud:v0.14.0-cloud.beta.4 imagePullPolicy: IfNotPresent command: [ diff --git a/deployment/cloud_kubernetes/workers/heavy_worker.yaml b/deployment/cloud_kubernetes/workers/heavy_worker.yaml index 3a4ce1a3805..349ebb4f21e 100644 --- a/deployment/cloud_kubernetes/workers/heavy_worker.yaml +++ b/deployment/cloud_kubernetes/workers/heavy_worker.yaml @@ -14,7 +14,7 @@ spec: spec: containers: - name: celery-worker-heavy - image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.12 + image: danswer/danswer-backend-cloud:v0.14.0-cloud.beta.4 imagePullPolicy: IfNotPresent command: [ diff --git a/deployment/cloud_kubernetes/workers/indexing_worker.yaml b/deployment/cloud_kubernetes/workers/indexing_worker.yaml index 36ce0da1400..443bf236e00 100644 --- a/deployment/cloud_kubernetes/workers/indexing_worker.yaml +++ b/deployment/cloud_kubernetes/workers/indexing_worker.yaml @@ -14,7 +14,7 @@ spec: spec: containers: - name: celery-worker-indexing - image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.12 + image: danswer/danswer-backend-cloud:v0.14.0-cloud.beta.4 imagePullPolicy: IfNotPresent command: [ diff --git a/deployment/cloud_kubernetes/workers/light_worker.yaml b/deployment/cloud_kubernetes/workers/light_worker.yaml index 171aa284fa7..eaa2ad7b92c 100644 --- a/deployment/cloud_kubernetes/workers/light_worker.yaml +++ b/deployment/cloud_kubernetes/workers/light_worker.yaml @@ -14,7 +14,7 @@ spec: spec: containers: - name: celery-worker-light - image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.12 + image: danswer/danswer-backend-cloud:v0.14.0-cloud.beta.4 imagePullPolicy: IfNotPresent command: [ diff --git a/deployment/cloud_kubernetes/workers/primary.yaml b/deployment/cloud_kubernetes/workers/primary.yaml index 3f30eeb0a26..ae365b1516c 100644 --- a/deployment/cloud_kubernetes/workers/primary.yaml +++ b/deployment/cloud_kubernetes/workers/primary.yaml @@ -14,7 +14,7 @@ spec: spec: containers: - name: celery-worker-primary - image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.12 + image: danswer/danswer-backend-cloud:v0.14.0-cloud.beta.4 imagePullPolicy: IfNotPresent command: [ diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml index 10108093096..19991de2d37 100644 --- a/deployment/docker_compose/docker-compose.dev.yml +++ b/deployment/docker_compose/docker-compose.dev.yml @@ -130,6 +130,7 @@ services: restart: always environment: - ENCRYPTION_KEY_SECRET=${ENCRYPTION_KEY_SECRET:-} + - JWT_PUBLIC_KEY_URL=${JWT_PUBLIC_KEY_URL:-} # used for JWT authentication of users via API # Gen AI Settings (Needed by DanswerBot) - GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-} - QA_TIMEOUT=${QA_TIMEOUT:-} @@ -182,6 +183,13 @@ services: - GONG_CONNECTOR_START_TIME=${GONG_CONNECTOR_START_TIME:-} - NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP=${NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP:-} - GITHUB_CONNECTOR_BASE_URL=${GITHUB_CONNECTOR_BASE_URL:-} + - MAX_DOCUMENT_CHARS=${MAX_DOCUMENT_CHARS:-} + - MAX_FILE_SIZE_BYTES=${MAX_FILE_SIZE_BYTES:-} + # Egnyte OAuth Configs + - EGNYTE_CLIENT_ID=${EGNYTE_CLIENT_ID:-} + - EGNYTE_CLIENT_SECRET=${EGNYTE_CLIENT_SECRET:-} + - EGNYTE_BASE_DOMAIN=${EGNYTE_BASE_DOMAIN:-} + - EGNYTE_LOCALHOST_OVERRIDE=${EGNYTE_LOCALHOST_OVERRIDE:-} # Celery Configs (defaults are set in the supervisord.conf file. # prefer doing that to have one source of defaults) - CELERY_WORKER_INDEXING_CONCURRENCY=${CELERY_WORKER_INDEXING_CONCURRENCY:-} diff --git a/deployment/docker_compose/docker-compose.resources.yml b/deployment/docker_compose/docker-compose.resources.yml new file mode 100644 index 00000000000..513c59c3b29 --- /dev/null +++ b/deployment/docker_compose/docker-compose.resources.yml @@ -0,0 +1,74 @@ +# Docker service resource limits. Most are commented out by default. +# 'background' service has preset (override-able) limits due to variable resource needs. +# Uncomment and set env vars for specific service limits. +# See: https://docs.danswer.dev/deployment/resource-sizing for details. + +services: + background: + deploy: + resources: + limits: + cpus: ${BACKGROUND_CPU_LIMIT:-4} + memory: ${BACKGROUND_MEM_LIMIT:-4g} + # reservations: + # cpus: ${BACKGROUND_CPU_RESERVATION} + # memory: ${BACKGROUND_MEM_RESERVATION} + + # nginx: + # deploy: + # resources: + # limits: + # cpus: ${NGINX_CPU_LIMIT} + # memory: ${NGINX_MEM_LIMIT} + # reservations: + # cpus: ${NGINX_CPU_RESERVATION} + # memory: ${NGINX_MEM_RESERVATION} + # api_server: + # deploy: + # resources: + # limits: + # cpus: ${API_SERVER_CPU_LIMIT} + # memory: ${API_SERVER_MEM_LIMIT} + # reservations: + # cpus: ${API_SERVER_CPU_RESERVATION} + # memory: ${API_SERVER_MEM_RESERVATION} + + # index: + # deploy: + # resources: + # limits: + # cpus: ${VESPA_CPU_LIMIT} + # memory: ${VESPA_MEM_LIMIT} + # reservations: + # cpus: ${VESPA_CPU_RESERVATION} + # memory: ${VESPA_MEM_RESERVATION} + + # inference_model_server: + # deploy: + # resources: + # limits: + # cpus: ${INFERENCE_CPU_LIMIT} + # memory: ${INFERENCE_MEM_LIMIT} + # reservations: + # cpus: ${INFERENCE_CPU_RESERVATION} + # memory: ${INFERENCE_MEM_RESERVATION} + + # indexing_model_server: + # deploy: + # resources: + # limits: + # cpus: ${INDEXING_CPU_LIMIT} + # memory: ${INDEXING_MEM_LIMIT} + # reservations: + # cpus: ${INDEXING_CPU_RESERVATION} + # memory: ${INDEXING_MEM_RESERVATION} + + # relational_db: + # deploy: + # resources: + # limits: + # cpus: ${POSTGRES_CPU_LIMIT} + # memory: ${POSTGRES_MEM_LIMIT} + # reservations: + # cpus: ${POSTGRES_CPU_RESERVATION} + # memory: ${POSTGRES_MEM_RESERVATION} diff --git a/deployment/helm/charts/danswer/Chart.lock b/deployment/helm/charts/danswer/Chart.lock index 26cc24e4494..af26f510eb1 100644 --- a/deployment/helm/charts/danswer/Chart.lock +++ b/deployment/helm/charts/danswer/Chart.lock @@ -3,13 +3,13 @@ dependencies: repository: https://charts.bitnami.com/bitnami version: 14.3.1 - name: vespa - repository: https://danswer-ai.github.io/vespa-helm-charts - version: 0.2.16 + repository: https://onyx-dot-app.github.io/vespa-helm-charts + version: 0.2.18 - name: nginx repository: oci://registry-1.docker.io/bitnamicharts version: 15.14.0 - name: redis repository: https://charts.bitnami.com/bitnami version: 20.1.0 -digest: sha256:711bbb76ba6ab604a36c9bf1839ab6faa5610afb21e535afd933c78f2d102232 -generated: "2024-11-07T09:39:30.17171-08:00" +digest: sha256:5c9eb3d55d5f8e3beb64f26d26f686c8d62755daa10e2e6d87530bdf2fbbf957 +generated: "2024-12-10T10:47:35.812483-08:00" diff --git a/deployment/helm/charts/danswer/Chart.yaml b/deployment/helm/charts/danswer/Chart.yaml index 8cda8e8ba2e..b033122c0fc 100644 --- a/deployment/helm/charts/danswer/Chart.yaml +++ b/deployment/helm/charts/danswer/Chart.yaml @@ -23,8 +23,8 @@ dependencies: repository: https://charts.bitnami.com/bitnami condition: postgresql.enabled - name: vespa - version: 0.2.16 - repository: https://danswer-ai.github.io/vespa-helm-charts + version: 0.2.18 + repository: https://onyx-dot-app.github.io/vespa-helm-charts condition: vespa.enabled - name: nginx version: 15.14.0 diff --git a/deployment/kubernetes/env-configmap.yaml b/deployment/kubernetes/env-configmap.yaml index 176e468c110..84bd6747973 100644 --- a/deployment/kubernetes/env-configmap.yaml +++ b/deployment/kubernetes/env-configmap.yaml @@ -61,6 +61,8 @@ data: WEB_CONNECTOR_VALIDATE_URLS: "" GONG_CONNECTOR_START_TIME: "" NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP: "" + MAX_DOCUMENT_CHARS: "" + MAX_FILE_SIZE_BYTES: "" # DanswerBot SlackBot Configs DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER: "" DANSWER_BOT_DISPLAY_ERROR_MSGS: "" diff --git a/node_modules/.package-lock.json b/node_modules/.package-lock.json new file mode 100644 index 00000000000..b3aaf2c4dec --- /dev/null +++ b/node_modules/.package-lock.json @@ -0,0 +1,6 @@ +{ + "name": "danswer", + "lockfileVersion": 3, + "requires": true, + "packages": {} +} diff --git a/web/@types/favicon-fetch.d.ts b/web/@types/favicon-fetch.d.ts new file mode 100644 index 00000000000..9b4d38319e7 --- /dev/null +++ b/web/@types/favicon-fetch.d.ts @@ -0,0 +1,9 @@ +declare module "favicon-fetch" { + interface FaviconFetchOptions { + uri: string; + } + + function faviconFetch(options: FaviconFetchOptions): string | null; + + export default faviconFetch; +} diff --git a/web/Dockerfile b/web/Dockerfile index 8093400a7f1..8b91615f359 100644 --- a/web/Dockerfile +++ b/web/Dockerfile @@ -66,6 +66,9 @@ ARG NEXT_PUBLIC_POSTHOG_HOST ENV NEXT_PUBLIC_POSTHOG_KEY=${NEXT_PUBLIC_POSTHOG_KEY} ENV NEXT_PUBLIC_POSTHOG_HOST=${NEXT_PUBLIC_POSTHOG_HOST} +ARG NEXT_PUBLIC_CLOUD_ENABLED +ENV NEXT_PUBLIC_CLOUD_ENABLED=${NEXT_PUBLIC_CLOUD_ENABLED} + ARG NEXT_PUBLIC_SENTRY_DSN ENV NEXT_PUBLIC_SENTRY_DSN=${NEXT_PUBLIC_SENTRY_DSN} @@ -138,6 +141,9 @@ ARG NEXT_PUBLIC_POSTHOG_HOST ENV NEXT_PUBLIC_POSTHOG_KEY=${NEXT_PUBLIC_POSTHOG_KEY} ENV NEXT_PUBLIC_POSTHOG_HOST=${NEXT_PUBLIC_POSTHOG_HOST} +ARG NEXT_PUBLIC_CLOUD_ENABLED +ENV NEXT_PUBLIC_CLOUD_ENABLED=${NEXT_PUBLIC_CLOUD_ENABLED} + ARG NEXT_PUBLIC_SENTRY_DSN ENV NEXT_PUBLIC_SENTRY_DSN=${NEXT_PUBLIC_SENTRY_DSN} diff --git a/web/package-lock.json b/web/package-lock.json index 8315feea2a5..986c24d972e 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -1,12 +1,12 @@ { "name": "qa", - "version": "0.2.0-dev", + "version": "1.0.0-dev", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "qa", - "version": "0.2.0-dev", + "version": "1.0.0-dev", "dependencies": { "@dnd-kit/core": "^6.1.0", "@dnd-kit/modifiers": "^7.0.0", @@ -15,11 +15,13 @@ "@headlessui/react": "^2.2.0", "@headlessui/tailwindcss": "^0.2.1", "@phosphor-icons/react": "^2.0.8", - "@radix-ui/react-dialog": "^1.0.5", + "@radix-ui/react-checkbox": "^1.1.2", + "@radix-ui/react-dialog": "^1.1.2", "@radix-ui/react-popover": "^1.1.2", "@radix-ui/react-select": "^2.1.2", "@radix-ui/react-separator": "^1.1.0", "@radix-ui/react-slot": "^1.1.0", + "@radix-ui/react-switch": "^1.1.1", "@radix-ui/react-tabs": "^1.1.1", "@radix-ui/react-tooltip": "^1.1.3", "@sentry/nextjs": "^8.34.0", @@ -35,6 +37,7 @@ "class-variance-authority": "^0.7.0", "clsx": "^2.1.1", "date-fns": "^3.6.0", + "favicon-fetch": "^1.0.0", "formik": "^2.2.9", "js-cookie": "^3.0.5", "lodash": "^4.17.21", @@ -65,6 +68,7 @@ "tailwindcss-animate": "^1.0.7", "typescript": "5.0.3", "uuid": "^9.0.1", + "vaul": "^1.1.1", "yup": "^1.4.0" }, "devDependencies": { @@ -2632,12 +2636,10 @@ "license": "MIT" }, "node_modules/@radix-ui/primitive": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/@radix-ui/primitive/-/primitive-1.0.1.tgz", - "integrity": "sha512-yQ8oGX2GVsEYMWGxcovu1uGWPCxV5BFfeeYxqPmuAzUyLT9qmaMXSAhXpb0WrspIeqYzdJpkh2vHModJPgRIaw==", - "dependencies": { - "@babel/runtime": "^7.13.10" - } + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@radix-ui/primitive/-/primitive-1.1.0.tgz", + "integrity": "sha512-4Z8dn6Upk0qk4P74xBhZ6Hd/w0mPEzOOLxy4xiPXOXqjF7jZS0VAKk7/x/H6FyY2zCkYJqePf1G5KmkmNJ4RBA==", + "license": "MIT" }, "node_modules/@radix-ui/react-arrow": { "version": "1.1.0", @@ -2661,15 +2663,20 @@ } } }, - "node_modules/@radix-ui/react-collection": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/@radix-ui/react-collection/-/react-collection-1.1.0.tgz", - "integrity": "sha512-GZsZslMJEyo1VKm5L1ZJY8tGDxZNPAoUeQUIbKeJfoi7Q4kmig5AsgLMYYuyYbfjd8fBmFORAIwYAkXMnXZgZw==", + "node_modules/@radix-ui/react-checkbox": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/@radix-ui/react-checkbox/-/react-checkbox-1.1.2.tgz", + "integrity": "sha512-/i0fl686zaJbDQLNKrkCbMyDm6FQMt4jg323k7HuqitoANm9sE23Ql8yOK3Wusk34HSLKDChhMux05FnP6KUkw==", + "license": "MIT", "dependencies": { + "@radix-ui/primitive": "1.1.0", "@radix-ui/react-compose-refs": "1.1.0", - "@radix-ui/react-context": "1.1.0", + "@radix-ui/react-context": "1.1.1", + "@radix-ui/react-presence": "1.1.1", "@radix-ui/react-primitive": "2.0.0", - "@radix-ui/react-slot": "1.1.0" + "@radix-ui/react-use-controllable-state": "1.1.0", + "@radix-ui/react-use-previous": "1.1.0", + "@radix-ui/react-use-size": "1.1.0" }, "peerDependencies": { "@types/react": "*", @@ -2686,10 +2693,11 @@ } } }, - "node_modules/@radix-ui/react-compose-refs": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/@radix-ui/react-compose-refs/-/react-compose-refs-1.1.0.tgz", - "integrity": "sha512-b4inOtiaOnYf9KWyO3jAeeCG6FeyfY6ldiEPanbUjWd+xIk5wZeHa8yVwmrJ2vderhu/BQvzCrJI0lHd+wIiqw==", + "node_modules/@radix-ui/react-checkbox/node_modules/@radix-ui/react-context": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-context/-/react-context-1.1.1.tgz", + "integrity": "sha512-UASk9zi+crv9WteK/NU4PLvOoL3OuE6BWVKNF6hPRBtYBDXQ2u5iu3O59zUlJiTVvkyuycnqrztsHVJwcK9K+Q==", + "license": "MIT", "peerDependencies": { "@types/react": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" @@ -2700,259 +2708,21 @@ } } }, - "node_modules/@radix-ui/react-context": { + "node_modules/@radix-ui/react-collection": { "version": "1.1.0", - "resolved": "https://registry.npmjs.org/@radix-ui/react-context/-/react-context-1.1.0.tgz", - "integrity": "sha512-OKrckBy+sMEgYM/sMmqmErVn0kZqrHPJze+Ql3DzYsDDp0hl0L62nx/2122/Bvps1qz645jlcu2tD9lrRSdf8A==", - "peerDependencies": { - "@types/react": "*", - "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-dialog": { - "version": "1.0.5", - "resolved": "https://registry.npmjs.org/@radix-ui/react-dialog/-/react-dialog-1.0.5.tgz", - "integrity": "sha512-GjWJX/AUpB703eEBanuBnIWdIXg6NvJFCXcNlSZk4xdszCdhrJgBoUd1cGk67vFO+WdA2pfI/plOpqz/5GUP6Q==", - "dependencies": { - "@babel/runtime": "^7.13.10", - "@radix-ui/primitive": "1.0.1", - "@radix-ui/react-compose-refs": "1.0.1", - "@radix-ui/react-context": "1.0.1", - "@radix-ui/react-dismissable-layer": "1.0.5", - "@radix-ui/react-focus-guards": "1.0.1", - "@radix-ui/react-focus-scope": "1.0.4", - "@radix-ui/react-id": "1.0.1", - "@radix-ui/react-portal": "1.0.4", - "@radix-ui/react-presence": "1.0.1", - "@radix-ui/react-primitive": "1.0.3", - "@radix-ui/react-slot": "1.0.2", - "@radix-ui/react-use-controllable-state": "1.0.1", - "aria-hidden": "^1.1.1", - "react-remove-scroll": "2.5.5" - }, - "peerDependencies": { - "@types/react": "*", - "@types/react-dom": "*", - "react": "^16.8 || ^17.0 || ^18.0", - "react-dom": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - }, - "@types/react-dom": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-dialog/node_modules/@radix-ui/react-compose-refs": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/@radix-ui/react-compose-refs/-/react-compose-refs-1.0.1.tgz", - "integrity": "sha512-fDSBgd44FKHa1FRMU59qBMPFcl2PZE+2nmqunj+BWFyYYjnhIDWL2ItDs3rrbJDQOtzt5nIebLCQc4QRfz6LJw==", - "dependencies": { - "@babel/runtime": "^7.13.10" - }, - "peerDependencies": { - "@types/react": "*", - "react": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-dialog/node_modules/@radix-ui/react-context": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/@radix-ui/react-context/-/react-context-1.0.1.tgz", - "integrity": "sha512-ebbrdFoYTcuZ0v4wG5tedGnp9tzcV8awzsxYph7gXUyvnNLuTIcCk1q17JEbnVhXAKG9oX3KtchwiMIAYp9NLg==", - "dependencies": { - "@babel/runtime": "^7.13.10" - }, - "peerDependencies": { - "@types/react": "*", - "react": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-dialog/node_modules/@radix-ui/react-dismissable-layer": { - "version": "1.0.5", - "resolved": "https://registry.npmjs.org/@radix-ui/react-dismissable-layer/-/react-dismissable-layer-1.0.5.tgz", - "integrity": "sha512-aJeDjQhywg9LBu2t/At58hCvr7pEm0o2Ke1x33B+MhjNmmZ17sy4KImo0KPLgsnc/zN7GPdce8Cnn0SWvwZO7g==", - "dependencies": { - "@babel/runtime": "^7.13.10", - "@radix-ui/primitive": "1.0.1", - "@radix-ui/react-compose-refs": "1.0.1", - "@radix-ui/react-primitive": "1.0.3", - "@radix-ui/react-use-callback-ref": "1.0.1", - "@radix-ui/react-use-escape-keydown": "1.0.3" - }, - "peerDependencies": { - "@types/react": "*", - "@types/react-dom": "*", - "react": "^16.8 || ^17.0 || ^18.0", - "react-dom": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - }, - "@types/react-dom": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-dialog/node_modules/@radix-ui/react-dismissable-layer/node_modules/@radix-ui/react-use-callback-ref": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/@radix-ui/react-use-callback-ref/-/react-use-callback-ref-1.0.1.tgz", - "integrity": "sha512-D94LjX4Sp0xJFVaoQOd3OO9k7tpBYNOXdVhkltUbGv2Qb9OXdrg/CpsjlZv7ia14Sylv398LswWBVVu5nqKzAQ==", - "dependencies": { - "@babel/runtime": "^7.13.10" - }, - "peerDependencies": { - "@types/react": "*", - "react": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-dialog/node_modules/@radix-ui/react-dismissable-layer/node_modules/@radix-ui/react-use-escape-keydown": { - "version": "1.0.3", - "resolved": "https://registry.npmjs.org/@radix-ui/react-use-escape-keydown/-/react-use-escape-keydown-1.0.3.tgz", - "integrity": "sha512-vyL82j40hcFicA+M4Ex7hVkB9vHgSse1ZWomAqV2Je3RleKGO5iM8KMOEtfoSB0PnIelMd2lATjTGMYqN5ylTg==", - "dependencies": { - "@babel/runtime": "^7.13.10", - "@radix-ui/react-use-callback-ref": "1.0.1" - }, - "peerDependencies": { - "@types/react": "*", - "react": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-dialog/node_modules/@radix-ui/react-focus-guards": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/@radix-ui/react-focus-guards/-/react-focus-guards-1.0.1.tgz", - "integrity": "sha512-Rect2dWbQ8waGzhMavsIbmSVCgYxkXLxxR3ZvCX79JOglzdEy4JXMb98lq4hPxUbLr77nP0UOGf4rcMU+s1pUA==", - "dependencies": { - "@babel/runtime": "^7.13.10" - }, - "peerDependencies": { - "@types/react": "*", - "react": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-dialog/node_modules/@radix-ui/react-focus-scope": { - "version": "1.0.4", - "resolved": "https://registry.npmjs.org/@radix-ui/react-focus-scope/-/react-focus-scope-1.0.4.tgz", - "integrity": "sha512-sL04Mgvf+FmyvZeYfNu1EPAaaxD+aw7cYeIB9L9Fvq8+urhltTRaEo5ysKOpHuKPclsZcSUMKlN05x4u+CINpA==", - "dependencies": { - "@babel/runtime": "^7.13.10", - "@radix-ui/react-compose-refs": "1.0.1", - "@radix-ui/react-primitive": "1.0.3", - "@radix-ui/react-use-callback-ref": "1.0.1" - }, - "peerDependencies": { - "@types/react": "*", - "@types/react-dom": "*", - "react": "^16.8 || ^17.0 || ^18.0", - "react-dom": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - }, - "@types/react-dom": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-dialog/node_modules/@radix-ui/react-focus-scope/node_modules/@radix-ui/react-use-callback-ref": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/@radix-ui/react-use-callback-ref/-/react-use-callback-ref-1.0.1.tgz", - "integrity": "sha512-D94LjX4Sp0xJFVaoQOd3OO9k7tpBYNOXdVhkltUbGv2Qb9OXdrg/CpsjlZv7ia14Sylv398LswWBVVu5nqKzAQ==", - "dependencies": { - "@babel/runtime": "^7.13.10" - }, - "peerDependencies": { - "@types/react": "*", - "react": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-dialog/node_modules/@radix-ui/react-id": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/@radix-ui/react-id/-/react-id-1.0.1.tgz", - "integrity": "sha512-tI7sT/kqYp8p96yGWY1OAnLHrqDgzHefRBKQ2YAkBS5ja7QLcZ9Z/uY7bEjPUatf8RomoXM8/1sMj1IJaE5UzQ==", - "dependencies": { - "@babel/runtime": "^7.13.10", - "@radix-ui/react-use-layout-effect": "1.0.1" - }, - "peerDependencies": { - "@types/react": "*", - "react": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-dialog/node_modules/@radix-ui/react-id/node_modules/@radix-ui/react-use-layout-effect": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/@radix-ui/react-use-layout-effect/-/react-use-layout-effect-1.0.1.tgz", - "integrity": "sha512-v/5RegiJWYdoCvMnITBkNNx6bCj20fiaJnWtRkU18yITptraXjffz5Qbn05uOiQnOvi+dbkznkoaMltz1GnszQ==", - "dependencies": { - "@babel/runtime": "^7.13.10" - }, - "peerDependencies": { - "@types/react": "*", - "react": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-dialog/node_modules/@radix-ui/react-portal": { - "version": "1.0.4", - "resolved": "https://registry.npmjs.org/@radix-ui/react-portal/-/react-portal-1.0.4.tgz", - "integrity": "sha512-Qki+C/EuGUVCQTOTD5vzJzJuMUlewbzuKyUy+/iHM2uwGiru9gZeBJtHAPKAEkB5KWGi9mP/CHKcY0wt1aW45Q==", + "resolved": "https://registry.npmjs.org/@radix-ui/react-collection/-/react-collection-1.1.0.tgz", + "integrity": "sha512-GZsZslMJEyo1VKm5L1ZJY8tGDxZNPAoUeQUIbKeJfoi7Q4kmig5AsgLMYYuyYbfjd8fBmFORAIwYAkXMnXZgZw==", "dependencies": { - "@babel/runtime": "^7.13.10", - "@radix-ui/react-primitive": "1.0.3" + "@radix-ui/react-compose-refs": "1.1.0", + "@radix-ui/react-context": "1.1.0", + "@radix-ui/react-primitive": "2.0.0", + "@radix-ui/react-slot": "1.1.0" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", - "react": "^16.8 || ^17.0 || ^18.0", - "react-dom": "^16.8 || ^17.0 || ^18.0" + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "peerDependenciesMeta": { "@types/react": { @@ -2963,40 +2733,27 @@ } } }, - "node_modules/@radix-ui/react-dialog/node_modules/@radix-ui/react-presence": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/@radix-ui/react-presence/-/react-presence-1.0.1.tgz", - "integrity": "sha512-UXLW4UAbIY5ZjcvzjfRFo5gxva8QirC9hF7wRE4U5gz+TP0DbRk+//qyuAQ1McDxBt1xNMBTaciFGvEmJvAZCg==", - "dependencies": { - "@babel/runtime": "^7.13.10", - "@radix-ui/react-compose-refs": "1.0.1", - "@radix-ui/react-use-layout-effect": "1.0.1" - }, + "node_modules/@radix-ui/react-compose-refs": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@radix-ui/react-compose-refs/-/react-compose-refs-1.1.0.tgz", + "integrity": "sha512-b4inOtiaOnYf9KWyO3jAeeCG6FeyfY6ldiEPanbUjWd+xIk5wZeHa8yVwmrJ2vderhu/BQvzCrJI0lHd+wIiqw==", "peerDependencies": { "@types/react": "*", - "@types/react-dom": "*", - "react": "^16.8 || ^17.0 || ^18.0", - "react-dom": "^16.8 || ^17.0 || ^18.0" + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "peerDependenciesMeta": { "@types/react": { "optional": true - }, - "@types/react-dom": { - "optional": true } } }, - "node_modules/@radix-ui/react-dialog/node_modules/@radix-ui/react-presence/node_modules/@radix-ui/react-use-layout-effect": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/@radix-ui/react-use-layout-effect/-/react-use-layout-effect-1.0.1.tgz", - "integrity": "sha512-v/5RegiJWYdoCvMnITBkNNx6bCj20fiaJnWtRkU18yITptraXjffz5Qbn05uOiQnOvi+dbkznkoaMltz1GnszQ==", - "dependencies": { - "@babel/runtime": "^7.13.10" - }, + "node_modules/@radix-ui/react-context": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@radix-ui/react-context/-/react-context-1.1.0.tgz", + "integrity": "sha512-OKrckBy+sMEgYM/sMmqmErVn0kZqrHPJze+Ql3DzYsDDp0hl0L62nx/2122/Bvps1qz645jlcu2tD9lrRSdf8A==", "peerDependencies": { "@types/react": "*", - "react": "^16.8 || ^17.0 || ^18.0" + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "peerDependenciesMeta": { "@types/react": { @@ -3004,184 +2761,50 @@ } } }, - "node_modules/@radix-ui/react-dialog/node_modules/@radix-ui/react-primitive": { - "version": "1.0.3", - "resolved": "https://registry.npmjs.org/@radix-ui/react-primitive/-/react-primitive-1.0.3.tgz", - "integrity": "sha512-yi58uVyoAcK/Nq1inRY56ZSjKypBNKTa/1mcL8qdl6oJeEaDbOldlzrGn7P6Q3Id5d+SYNGc5AJgc4vGhjs5+g==", + "node_modules/@radix-ui/react-dialog": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/@radix-ui/react-dialog/-/react-dialog-1.1.2.tgz", + "integrity": "sha512-Yj4dZtqa2o+kG61fzB0H2qUvmwBA2oyQroGLyNtBj1beo1khoQ3q1a2AO8rrQYjd8256CO9+N8L9tvsS+bnIyA==", + "license": "MIT", "dependencies": { - "@babel/runtime": "^7.13.10", - "@radix-ui/react-slot": "1.0.2" + "@radix-ui/primitive": "1.1.0", + "@radix-ui/react-compose-refs": "1.1.0", + "@radix-ui/react-context": "1.1.1", + "@radix-ui/react-dismissable-layer": "1.1.1", + "@radix-ui/react-focus-guards": "1.1.1", + "@radix-ui/react-focus-scope": "1.1.0", + "@radix-ui/react-id": "1.1.0", + "@radix-ui/react-portal": "1.1.2", + "@radix-ui/react-presence": "1.1.1", + "@radix-ui/react-primitive": "2.0.0", + "@radix-ui/react-slot": "1.1.0", + "@radix-ui/react-use-controllable-state": "1.1.0", + "aria-hidden": "^1.1.1", + "react-remove-scroll": "2.6.0" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", - "react": "^16.8 || ^17.0 || ^18.0", - "react-dom": "^16.8 || ^17.0 || ^18.0" + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "peerDependenciesMeta": { "@types/react": { "optional": true }, - "@types/react-dom": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-dialog/node_modules/@radix-ui/react-slot": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.0.2.tgz", - "integrity": "sha512-YeTpuq4deV+6DusvVUW4ivBgnkHwECUu0BiN43L5UCDFgdhsRUWAghhTF5MbvNTPzmiFOx90asDSUjWuCNapwg==", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10", - "@radix-ui/react-compose-refs": "1.0.1" - }, - "peerDependencies": { - "@types/react": "*", - "react": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-dialog/node_modules/@radix-ui/react-use-controllable-state": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/@radix-ui/react-use-controllable-state/-/react-use-controllable-state-1.0.1.tgz", - "integrity": "sha512-Svl5GY5FQeN758fWKrjM6Qb7asvXeiZltlT4U2gVfl8Gx5UAv2sMR0LWo8yhsIZh2oQ0eFdZ59aoOOMV7b47VA==", - "dependencies": { - "@babel/runtime": "^7.13.10", - "@radix-ui/react-use-callback-ref": "1.0.1" - }, - "peerDependencies": { - "@types/react": "*", - "react": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-dialog/node_modules/@radix-ui/react-use-controllable-state/node_modules/@radix-ui/react-use-callback-ref": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/@radix-ui/react-use-callback-ref/-/react-use-callback-ref-1.0.1.tgz", - "integrity": "sha512-D94LjX4Sp0xJFVaoQOd3OO9k7tpBYNOXdVhkltUbGv2Qb9OXdrg/CpsjlZv7ia14Sylv398LswWBVVu5nqKzAQ==", - "dependencies": { - "@babel/runtime": "^7.13.10" - }, - "peerDependencies": { - "@types/react": "*", - "react": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-dialog/node_modules/react-remove-scroll": { - "version": "2.5.5", - "resolved": "https://registry.npmjs.org/react-remove-scroll/-/react-remove-scroll-2.5.5.tgz", - "integrity": "sha512-ImKhrzJJsyXJfBZ4bzu8Bwpka14c/fQt0k+cyFp/PBhTfyDnU5hjOtM4AG/0AMyy8oKzOTR0lDgJIM7pYXI0kw==", - "dependencies": { - "react-remove-scroll-bar": "^2.3.3", - "react-style-singleton": "^2.2.1", - "tslib": "^2.1.0", - "use-callback-ref": "^1.3.0", - "use-sidecar": "^1.1.2" - }, - "engines": { - "node": ">=10" - }, - "peerDependencies": { - "@types/react": "^16.8.0 || ^17.0.0 || ^18.0.0", - "react": "^16.8.0 || ^17.0.0 || ^18.0.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-dialog/node_modules/react-remove-scroll/node_modules/react-remove-scroll-bar": { - "version": "2.3.6", - "resolved": "https://registry.npmjs.org/react-remove-scroll-bar/-/react-remove-scroll-bar-2.3.6.tgz", - "integrity": "sha512-DtSYaao4mBmX+HDo5YWYdBWQwYIQQshUV/dVxFxK+KM26Wjwp1gZ6rv6OC3oujI6Bfu6Xyg3TwK533AQutsn/g==", - "dependencies": { - "react-style-singleton": "^2.2.1", - "tslib": "^2.0.0" - }, - "engines": { - "node": ">=10" - }, - "peerDependencies": { - "@types/react": "^16.8.0 || ^17.0.0 || ^18.0.0", - "react": "^16.8.0 || ^17.0.0 || ^18.0.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-dialog/node_modules/react-remove-scroll/node_modules/react-style-singleton": { - "version": "2.2.1", - "resolved": "https://registry.npmjs.org/react-style-singleton/-/react-style-singleton-2.2.1.tgz", - "integrity": "sha512-ZWj0fHEMyWkHzKYUr2Bs/4zU6XLmq9HsgBURm7g5pAVfyn49DgUiNgY2d4lXRlYSiCif9YBGpQleewkcqddc7g==", - "dependencies": { - "get-nonce": "^1.0.0", - "invariant": "^2.2.4", - "tslib": "^2.0.0" - }, - "engines": { - "node": ">=10" - }, - "peerDependencies": { - "@types/react": "^16.8.0 || ^17.0.0 || ^18.0.0", - "react": "^16.8.0 || ^17.0.0 || ^18.0.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-dialog/node_modules/react-remove-scroll/node_modules/use-callback-ref": { - "version": "1.3.2", - "resolved": "https://registry.npmjs.org/use-callback-ref/-/use-callback-ref-1.3.2.tgz", - "integrity": "sha512-elOQwe6Q8gqZgDA8mrh44qRTQqpIHDcZ3hXTLjBe1i4ph8XpNJnO+aQf3NaG+lriLopI4HMx9VjQLfPQ6vhnoA==", - "dependencies": { - "tslib": "^2.0.0" - }, - "engines": { - "node": ">=10" - }, - "peerDependencies": { - "@types/react": "^16.8.0 || ^17.0.0 || ^18.0.0", - "react": "^16.8.0 || ^17.0.0 || ^18.0.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-dialog/node_modules/react-remove-scroll/node_modules/use-sidecar": { - "version": "1.1.2", - "resolved": "https://registry.npmjs.org/use-sidecar/-/use-sidecar-1.1.2.tgz", - "integrity": "sha512-epTbsLuzZ7lPClpz2TyryBfztm7m+28DlEv2ZCQ3MDr5ssiwyOwGH/e5F9CkfWjJ1t4clvI58yF822/GUkjjhw==", - "dependencies": { - "detect-node-es": "^1.1.0", - "tslib": "^2.0.0" - }, - "engines": { - "node": ">=10" - }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-dialog/node_modules/@radix-ui/react-context": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-context/-/react-context-1.1.1.tgz", + "integrity": "sha512-UASk9zi+crv9WteK/NU4PLvOoL3OuE6BWVKNF6hPRBtYBDXQ2u5iu3O59zUlJiTVvkyuycnqrztsHVJwcK9K+Q==", + "license": "MIT", "peerDependencies": { - "@types/react": "^16.9.0 || ^17.0.0 || ^18.0.0", - "react": "^16.8.0 || ^17.0.0 || ^18.0.0" + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "peerDependenciesMeta": { "@types/react": { @@ -3230,10 +2853,20 @@ } } }, - "node_modules/@radix-ui/react-dismissable-layer/node_modules/@radix-ui/primitive": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/@radix-ui/primitive/-/primitive-1.1.0.tgz", - "integrity": "sha512-4Z8dn6Upk0qk4P74xBhZ6Hd/w0mPEzOOLxy4xiPXOXqjF7jZS0VAKk7/x/H6FyY2zCkYJqePf1G5KmkmNJ4RBA==" + "node_modules/@radix-ui/react-focus-guards": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-focus-guards/-/react-focus-guards-1.1.1.tgz", + "integrity": "sha512-pSIwfrT1a6sIoDASCSpFwOasEwKTZWDw/iBdtnqKO7v6FeOzYJ7U53cPzYFVR3geGGXgVHaH+CdngrrAzqUGxg==", + "license": "MIT", + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } }, "node_modules/@radix-ui/react-focus-scope": { "version": "1.1.0", @@ -3313,12 +2946,6 @@ } } }, - "node_modules/@radix-ui/react-popover/node_modules/@radix-ui/primitive": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/@radix-ui/primitive/-/primitive-1.1.0.tgz", - "integrity": "sha512-4Z8dn6Upk0qk4P74xBhZ6Hd/w0mPEzOOLxy4xiPXOXqjF7jZS0VAKk7/x/H6FyY2zCkYJqePf1G5KmkmNJ4RBA==", - "license": "MIT" - }, "node_modules/@radix-ui/react-popover/node_modules/@radix-ui/react-context": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/@radix-ui/react-context/-/react-context-1.1.1.tgz", @@ -3334,130 +2961,6 @@ } } }, - "node_modules/@radix-ui/react-popover/node_modules/@radix-ui/react-focus-guards": { - "version": "1.1.1", - "resolved": "https://registry.npmjs.org/@radix-ui/react-focus-guards/-/react-focus-guards-1.1.1.tgz", - "integrity": "sha512-pSIwfrT1a6sIoDASCSpFwOasEwKTZWDw/iBdtnqKO7v6FeOzYJ7U53cPzYFVR3geGGXgVHaH+CdngrrAzqUGxg==", - "license": "MIT", - "peerDependencies": { - "@types/react": "*", - "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-popover/node_modules/react-remove-scroll": { - "version": "2.6.0", - "resolved": "https://registry.npmjs.org/react-remove-scroll/-/react-remove-scroll-2.6.0.tgz", - "integrity": "sha512-I2U4JVEsQenxDAKaVa3VZ/JeJZe0/2DxPWL8Tj8yLKctQJQiZM52pn/GWFpSp8dftjM3pSAHVJZscAnC/y+ySQ==", - "license": "MIT", - "dependencies": { - "react-remove-scroll-bar": "^2.3.6", - "react-style-singleton": "^2.2.1", - "tslib": "^2.1.0", - "use-callback-ref": "^1.3.0", - "use-sidecar": "^1.1.2" - }, - "engines": { - "node": ">=10" - }, - "peerDependencies": { - "@types/react": "^16.8.0 || ^17.0.0 || ^18.0.0", - "react": "^16.8.0 || ^17.0.0 || ^18.0.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-popover/node_modules/react-remove-scroll/node_modules/react-remove-scroll-bar": { - "version": "2.3.6", - "resolved": "https://registry.npmjs.org/react-remove-scroll-bar/-/react-remove-scroll-bar-2.3.6.tgz", - "integrity": "sha512-DtSYaao4mBmX+HDo5YWYdBWQwYIQQshUV/dVxFxK+KM26Wjwp1gZ6rv6OC3oujI6Bfu6Xyg3TwK533AQutsn/g==", - "dependencies": { - "react-style-singleton": "^2.2.1", - "tslib": "^2.0.0" - }, - "engines": { - "node": ">=10" - }, - "peerDependencies": { - "@types/react": "^16.8.0 || ^17.0.0 || ^18.0.0", - "react": "^16.8.0 || ^17.0.0 || ^18.0.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-popover/node_modules/react-remove-scroll/node_modules/react-style-singleton": { - "version": "2.2.1", - "resolved": "https://registry.npmjs.org/react-style-singleton/-/react-style-singleton-2.2.1.tgz", - "integrity": "sha512-ZWj0fHEMyWkHzKYUr2Bs/4zU6XLmq9HsgBURm7g5pAVfyn49DgUiNgY2d4lXRlYSiCif9YBGpQleewkcqddc7g==", - "dependencies": { - "get-nonce": "^1.0.0", - "invariant": "^2.2.4", - "tslib": "^2.0.0" - }, - "engines": { - "node": ">=10" - }, - "peerDependencies": { - "@types/react": "^16.8.0 || ^17.0.0 || ^18.0.0", - "react": "^16.8.0 || ^17.0.0 || ^18.0.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-popover/node_modules/react-remove-scroll/node_modules/use-callback-ref": { - "version": "1.3.2", - "resolved": "https://registry.npmjs.org/use-callback-ref/-/use-callback-ref-1.3.2.tgz", - "integrity": "sha512-elOQwe6Q8gqZgDA8mrh44qRTQqpIHDcZ3hXTLjBe1i4ph8XpNJnO+aQf3NaG+lriLopI4HMx9VjQLfPQ6vhnoA==", - "dependencies": { - "tslib": "^2.0.0" - }, - "engines": { - "node": ">=10" - }, - "peerDependencies": { - "@types/react": "^16.8.0 || ^17.0.0 || ^18.0.0", - "react": "^16.8.0 || ^17.0.0 || ^18.0.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-popover/node_modules/react-remove-scroll/node_modules/use-sidecar": { - "version": "1.1.2", - "resolved": "https://registry.npmjs.org/use-sidecar/-/use-sidecar-1.1.2.tgz", - "integrity": "sha512-epTbsLuzZ7lPClpz2TyryBfztm7m+28DlEv2ZCQ3MDr5ssiwyOwGH/e5F9CkfWjJ1t4clvI58yF822/GUkjjhw==", - "dependencies": { - "detect-node-es": "^1.1.0", - "tslib": "^2.0.0" - }, - "engines": { - "node": ">=10" - }, - "peerDependencies": { - "@types/react": "^16.9.0 || ^17.0.0 || ^18.0.0", - "react": "^16.8.0 || ^17.0.0 || ^18.0.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, "node_modules/@radix-ui/react-popper": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/@radix-ui/react-popper/-/react-popper-1.2.0.tgz", @@ -3587,151 +3090,57 @@ } } }, - "node_modules/@radix-ui/react-roving-focus/node_modules/@radix-ui/primitive": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/@radix-ui/primitive/-/primitive-1.1.0.tgz", - "integrity": "sha512-4Z8dn6Upk0qk4P74xBhZ6Hd/w0mPEzOOLxy4xiPXOXqjF7jZS0VAKk7/x/H6FyY2zCkYJqePf1G5KmkmNJ4RBA==" - }, - "node_modules/@radix-ui/react-select": { - "version": "2.1.2", - "resolved": "https://registry.npmjs.org/@radix-ui/react-select/-/react-select-2.1.2.tgz", - "integrity": "sha512-rZJtWmorC7dFRi0owDmoijm6nSJH1tVw64QGiNIZ9PNLyBDtG+iAq+XGsya052At4BfarzY/Dhv9wrrUr6IMZA==", - "license": "MIT", - "dependencies": { - "@radix-ui/number": "1.1.0", - "@radix-ui/primitive": "1.1.0", - "@radix-ui/react-collection": "1.1.0", - "@radix-ui/react-compose-refs": "1.1.0", - "@radix-ui/react-context": "1.1.1", - "@radix-ui/react-direction": "1.1.0", - "@radix-ui/react-dismissable-layer": "1.1.1", - "@radix-ui/react-focus-guards": "1.1.1", - "@radix-ui/react-focus-scope": "1.1.0", - "@radix-ui/react-id": "1.1.0", - "@radix-ui/react-popper": "1.2.0", - "@radix-ui/react-portal": "1.1.2", - "@radix-ui/react-primitive": "2.0.0", - "@radix-ui/react-slot": "1.1.0", - "@radix-ui/react-use-callback-ref": "1.1.0", - "@radix-ui/react-use-controllable-state": "1.1.0", - "@radix-ui/react-use-layout-effect": "1.1.0", - "@radix-ui/react-use-previous": "1.1.0", - "@radix-ui/react-visually-hidden": "1.1.0", - "aria-hidden": "^1.1.1", - "react-remove-scroll": "2.6.0" - }, - "peerDependencies": { - "@types/react": "*", - "@types/react-dom": "*", - "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", - "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - }, - "@types/react-dom": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-select/node_modules/@radix-ui/primitive": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/@radix-ui/primitive/-/primitive-1.1.0.tgz", - "integrity": "sha512-4Z8dn6Upk0qk4P74xBhZ6Hd/w0mPEzOOLxy4xiPXOXqjF7jZS0VAKk7/x/H6FyY2zCkYJqePf1G5KmkmNJ4RBA==", - "license": "MIT" - }, - "node_modules/@radix-ui/react-select/node_modules/@radix-ui/react-context": { - "version": "1.1.1", - "resolved": "https://registry.npmjs.org/@radix-ui/react-context/-/react-context-1.1.1.tgz", - "integrity": "sha512-UASk9zi+crv9WteK/NU4PLvOoL3OuE6BWVKNF6hPRBtYBDXQ2u5iu3O59zUlJiTVvkyuycnqrztsHVJwcK9K+Q==", - "license": "MIT", - "peerDependencies": { - "@types/react": "*", - "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-select/node_modules/@radix-ui/react-focus-guards": { - "version": "1.1.1", - "resolved": "https://registry.npmjs.org/@radix-ui/react-focus-guards/-/react-focus-guards-1.1.1.tgz", - "integrity": "sha512-pSIwfrT1a6sIoDASCSpFwOasEwKTZWDw/iBdtnqKO7v6FeOzYJ7U53cPzYFVR3geGGXgVHaH+CdngrrAzqUGxg==", - "license": "MIT", - "peerDependencies": { - "@types/react": "*", - "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-select/node_modules/react-remove-scroll": { - "version": "2.6.0", - "resolved": "https://registry.npmjs.org/react-remove-scroll/-/react-remove-scroll-2.6.0.tgz", - "integrity": "sha512-I2U4JVEsQenxDAKaVa3VZ/JeJZe0/2DxPWL8Tj8yLKctQJQiZM52pn/GWFpSp8dftjM3pSAHVJZscAnC/y+ySQ==", - "license": "MIT", - "dependencies": { - "react-remove-scroll-bar": "^2.3.6", - "react-style-singleton": "^2.2.1", - "tslib": "^2.1.0", - "use-callback-ref": "^1.3.0", - "use-sidecar": "^1.1.2" - }, - "engines": { - "node": ">=10" - }, - "peerDependencies": { - "@types/react": "^16.8.0 || ^17.0.0 || ^18.0.0", - "react": "^16.8.0 || ^17.0.0 || ^18.0.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-select/node_modules/react-remove-scroll/node_modules/react-remove-scroll-bar": { - "version": "2.3.6", - "resolved": "https://registry.npmjs.org/react-remove-scroll-bar/-/react-remove-scroll-bar-2.3.6.tgz", - "integrity": "sha512-DtSYaao4mBmX+HDo5YWYdBWQwYIQQshUV/dVxFxK+KM26Wjwp1gZ6rv6OC3oujI6Bfu6Xyg3TwK533AQutsn/g==", + "node_modules/@radix-ui/react-select": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/@radix-ui/react-select/-/react-select-2.1.2.tgz", + "integrity": "sha512-rZJtWmorC7dFRi0owDmoijm6nSJH1tVw64QGiNIZ9PNLyBDtG+iAq+XGsya052At4BfarzY/Dhv9wrrUr6IMZA==", + "license": "MIT", "dependencies": { - "react-style-singleton": "^2.2.1", - "tslib": "^2.0.0" - }, - "engines": { - "node": ">=10" + "@radix-ui/number": "1.1.0", + "@radix-ui/primitive": "1.1.0", + "@radix-ui/react-collection": "1.1.0", + "@radix-ui/react-compose-refs": "1.1.0", + "@radix-ui/react-context": "1.1.1", + "@radix-ui/react-direction": "1.1.0", + "@radix-ui/react-dismissable-layer": "1.1.1", + "@radix-ui/react-focus-guards": "1.1.1", + "@radix-ui/react-focus-scope": "1.1.0", + "@radix-ui/react-id": "1.1.0", + "@radix-ui/react-popper": "1.2.0", + "@radix-ui/react-portal": "1.1.2", + "@radix-ui/react-primitive": "2.0.0", + "@radix-ui/react-slot": "1.1.0", + "@radix-ui/react-use-callback-ref": "1.1.0", + "@radix-ui/react-use-controllable-state": "1.1.0", + "@radix-ui/react-use-layout-effect": "1.1.0", + "@radix-ui/react-use-previous": "1.1.0", + "@radix-ui/react-visually-hidden": "1.1.0", + "aria-hidden": "^1.1.1", + "react-remove-scroll": "2.6.0" }, "peerDependencies": { - "@types/react": "^16.8.0 || ^17.0.0 || ^18.0.0", - "react": "^16.8.0 || ^17.0.0 || ^18.0.0" + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "peerDependenciesMeta": { "@types/react": { "optional": true + }, + "@types/react-dom": { + "optional": true } } }, - "node_modules/@radix-ui/react-select/node_modules/react-remove-scroll/node_modules/react-style-singleton": { - "version": "2.2.1", - "resolved": "https://registry.npmjs.org/react-style-singleton/-/react-style-singleton-2.2.1.tgz", - "integrity": "sha512-ZWj0fHEMyWkHzKYUr2Bs/4zU6XLmq9HsgBURm7g5pAVfyn49DgUiNgY2d4lXRlYSiCif9YBGpQleewkcqddc7g==", - "dependencies": { - "get-nonce": "^1.0.0", - "invariant": "^2.2.4", - "tslib": "^2.0.0" - }, - "engines": { - "node": ">=10" - }, + "node_modules/@radix-ui/react-select/node_modules/@radix-ui/react-context": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-context/-/react-context-1.1.1.tgz", + "integrity": "sha512-UASk9zi+crv9WteK/NU4PLvOoL3OuE6BWVKNF6hPRBtYBDXQ2u5iu3O59zUlJiTVvkyuycnqrztsHVJwcK9K+Q==", + "license": "MIT", "peerDependencies": { - "@types/react": "^16.8.0 || ^17.0.0 || ^18.0.0", - "react": "^16.8.0 || ^17.0.0 || ^18.0.0" + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "peerDependenciesMeta": { "@types/react": { @@ -3739,40 +3148,40 @@ } } }, - "node_modules/@radix-ui/react-select/node_modules/react-remove-scroll/node_modules/use-callback-ref": { - "version": "1.3.2", - "resolved": "https://registry.npmjs.org/use-callback-ref/-/use-callback-ref-1.3.2.tgz", - "integrity": "sha512-elOQwe6Q8gqZgDA8mrh44qRTQqpIHDcZ3hXTLjBe1i4ph8XpNJnO+aQf3NaG+lriLopI4HMx9VjQLfPQ6vhnoA==", + "node_modules/@radix-ui/react-separator": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@radix-ui/react-separator/-/react-separator-1.1.0.tgz", + "integrity": "sha512-3uBAs+egzvJBDZAzvb/n4NxxOYpnspmWxO2u5NbZ8Y6FM/NdrGSF9bop3Cf6F6C71z1rTSn8KV0Fo2ZVd79lGA==", + "license": "MIT", "dependencies": { - "tslib": "^2.0.0" - }, - "engines": { - "node": ">=10" + "@radix-ui/react-primitive": "2.0.0" }, "peerDependencies": { - "@types/react": "^16.8.0 || ^17.0.0 || ^18.0.0", - "react": "^16.8.0 || ^17.0.0 || ^18.0.0" + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "peerDependenciesMeta": { "@types/react": { "optional": true + }, + "@types/react-dom": { + "optional": true } } }, - "node_modules/@radix-ui/react-select/node_modules/react-remove-scroll/node_modules/use-sidecar": { - "version": "1.1.2", - "resolved": "https://registry.npmjs.org/use-sidecar/-/use-sidecar-1.1.2.tgz", - "integrity": "sha512-epTbsLuzZ7lPClpz2TyryBfztm7m+28DlEv2ZCQ3MDr5ssiwyOwGH/e5F9CkfWjJ1t4clvI58yF822/GUkjjhw==", + "node_modules/@radix-ui/react-slot": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.1.0.tgz", + "integrity": "sha512-FUCf5XMfmW4dtYl69pdS4DbxKy8nj4M7SafBgPllysxmdachynNflAdp/gCsnYWNDnge6tI9onzMp5ARYc1KNw==", + "license": "MIT", "dependencies": { - "detect-node-es": "^1.1.0", - "tslib": "^2.0.0" - }, - "engines": { - "node": ">=10" + "@radix-ui/react-compose-refs": "1.1.0" }, "peerDependencies": { - "@types/react": "^16.9.0 || ^17.0.0 || ^18.0.0", - "react": "^16.8.0 || ^17.0.0 || ^18.0.0" + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "peerDependenciesMeta": { "@types/react": { @@ -3780,13 +3189,19 @@ } } }, - "node_modules/@radix-ui/react-separator": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/@radix-ui/react-separator/-/react-separator-1.1.0.tgz", - "integrity": "sha512-3uBAs+egzvJBDZAzvb/n4NxxOYpnspmWxO2u5NbZ8Y6FM/NdrGSF9bop3Cf6F6C71z1rTSn8KV0Fo2ZVd79lGA==", + "node_modules/@radix-ui/react-switch": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-switch/-/react-switch-1.1.1.tgz", + "integrity": "sha512-diPqDDoBcZPSicYoMWdWx+bCPuTRH4QSp9J+65IvtdS0Kuzt67bI6n32vCj8q6NZmYW/ah+2orOtMwcX5eQwIg==", "license": "MIT", "dependencies": { - "@radix-ui/react-primitive": "2.0.0" + "@radix-ui/primitive": "1.1.0", + "@radix-ui/react-compose-refs": "1.1.0", + "@radix-ui/react-context": "1.1.1", + "@radix-ui/react-primitive": "2.0.0", + "@radix-ui/react-use-controllable-state": "1.1.0", + "@radix-ui/react-use-previous": "1.1.0", + "@radix-ui/react-use-size": "1.1.0" }, "peerDependencies": { "@types/react": "*", @@ -3803,14 +3218,11 @@ } } }, - "node_modules/@radix-ui/react-slot": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.1.0.tgz", - "integrity": "sha512-FUCf5XMfmW4dtYl69pdS4DbxKy8nj4M7SafBgPllysxmdachynNflAdp/gCsnYWNDnge6tI9onzMp5ARYc1KNw==", + "node_modules/@radix-ui/react-switch/node_modules/@radix-ui/react-context": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-context/-/react-context-1.1.1.tgz", + "integrity": "sha512-UASk9zi+crv9WteK/NU4PLvOoL3OuE6BWVKNF6hPRBtYBDXQ2u5iu3O59zUlJiTVvkyuycnqrztsHVJwcK9K+Q==", "license": "MIT", - "dependencies": { - "@radix-ui/react-compose-refs": "1.1.0" - }, "peerDependencies": { "@types/react": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" @@ -3851,12 +3263,6 @@ } } }, - "node_modules/@radix-ui/react-tabs/node_modules/@radix-ui/primitive": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/@radix-ui/primitive/-/primitive-1.1.0.tgz", - "integrity": "sha512-4Z8dn6Upk0qk4P74xBhZ6Hd/w0mPEzOOLxy4xiPXOXqjF7jZS0VAKk7/x/H6FyY2zCkYJqePf1G5KmkmNJ4RBA==", - "license": "MIT" - }, "node_modules/@radix-ui/react-tabs/node_modules/@radix-ui/react-context": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/@radix-ui/react-context/-/react-context-1.1.1.tgz", @@ -3906,12 +3312,6 @@ } } }, - "node_modules/@radix-ui/react-tooltip/node_modules/@radix-ui/primitive": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/@radix-ui/primitive/-/primitive-1.1.0.tgz", - "integrity": "sha512-4Z8dn6Upk0qk4P74xBhZ6Hd/w0mPEzOOLxy4xiPXOXqjF7jZS0VAKk7/x/H6FyY2zCkYJqePf1G5KmkmNJ4RBA==", - "license": "MIT" - }, "node_modules/@radix-ui/react-tooltip/node_modules/@radix-ui/react-context": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/@radix-ui/react-context/-/react-context-1.1.1.tgz", @@ -7201,7 +6601,8 @@ "node_modules/detect-node-es": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/detect-node-es/-/detect-node-es-1.1.0.tgz", - "integrity": "sha512-ypdmJU/TbBby2Dxibuv7ZLW3Bs1QEmM7nHjEANfohJLvE0XVujisn1qPJcZxg+qDucsr+bP6fLD1rPS3AhJ7EQ==" + "integrity": "sha512-ypdmJU/TbBby2Dxibuv7ZLW3Bs1QEmM7nHjEANfohJLvE0XVujisn1qPJcZxg+qDucsr+bP6fLD1rPS3AhJ7EQ==", + "license": "MIT" }, "node_modules/devlop": { "version": "1.1.0", @@ -8184,6 +7585,12 @@ "reusify": "^1.0.4" } }, + "node_modules/favicon-fetch": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/favicon-fetch/-/favicon-fetch-1.0.0.tgz", + "integrity": "sha512-qEbMwsKBebUGo/JpTyeE5aBus5nTsIcYV7qRd5hxGWA3wOrp67aKXBrH3O23tYkNjnOThTyw9TaUrtWwOe3Y1w==", + "license": "MIT" + }, "node_modules/fflate": { "version": "0.4.8", "resolved": "https://registry.npmjs.org/fflate/-/fflate-0.4.8.tgz", @@ -8467,6 +7874,7 @@ "version": "1.0.1", "resolved": "https://registry.npmjs.org/get-nonce/-/get-nonce-1.0.1.tgz", "integrity": "sha512-FJhYRoDaiatfEkUK8HKlicmu/3SGFD51q3itKDGoSTysQJBnfOcxU5GxnhE1E6soB76MbT0MBtnKJuXyAx+96Q==", + "license": "MIT", "engines": { "node": ">=6" } @@ -9152,6 +8560,7 @@ "version": "2.2.4", "resolved": "https://registry.npmjs.org/invariant/-/invariant-2.2.4.tgz", "integrity": "sha512-phJfQVBuaJM5raOpJjSfkiD6BpbCE4Ns//LaXl6wGYtUBY83nWS6Rf9tXm2e8VaK60JEjYldbPif/A2B1C2gNA==", + "license": "MIT", "dependencies": { "loose-envify": "^1.0.0" } @@ -14599,6 +14008,53 @@ "react": ">=18" } }, + "node_modules/react-remove-scroll": { + "version": "2.6.0", + "resolved": "https://registry.npmjs.org/react-remove-scroll/-/react-remove-scroll-2.6.0.tgz", + "integrity": "sha512-I2U4JVEsQenxDAKaVa3VZ/JeJZe0/2DxPWL8Tj8yLKctQJQiZM52pn/GWFpSp8dftjM3pSAHVJZscAnC/y+ySQ==", + "license": "MIT", + "dependencies": { + "react-remove-scroll-bar": "^2.3.6", + "react-style-singleton": "^2.2.1", + "tslib": "^2.1.0", + "use-callback-ref": "^1.3.0", + "use-sidecar": "^1.1.2" + }, + "engines": { + "node": ">=10" + }, + "peerDependencies": { + "@types/react": "^16.8.0 || ^17.0.0 || ^18.0.0", + "react": "^16.8.0 || ^17.0.0 || ^18.0.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/react-remove-scroll-bar": { + "version": "2.3.6", + "resolved": "https://registry.npmjs.org/react-remove-scroll-bar/-/react-remove-scroll-bar-2.3.6.tgz", + "integrity": "sha512-DtSYaao4mBmX+HDo5YWYdBWQwYIQQshUV/dVxFxK+KM26Wjwp1gZ6rv6OC3oujI6Bfu6Xyg3TwK533AQutsn/g==", + "license": "MIT", + "dependencies": { + "react-style-singleton": "^2.2.1", + "tslib": "^2.0.0" + }, + "engines": { + "node": ">=10" + }, + "peerDependencies": { + "@types/react": "^16.8.0 || ^17.0.0 || ^18.0.0", + "react": "^16.8.0 || ^17.0.0 || ^18.0.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, "node_modules/react-select": { "version": "5.8.0", "resolved": "https://registry.npmjs.org/react-select/-/react-select-5.8.0.tgz", @@ -14678,6 +14134,29 @@ } } }, + "node_modules/react-style-singleton": { + "version": "2.2.1", + "resolved": "https://registry.npmjs.org/react-style-singleton/-/react-style-singleton-2.2.1.tgz", + "integrity": "sha512-ZWj0fHEMyWkHzKYUr2Bs/4zU6XLmq9HsgBURm7g5pAVfyn49DgUiNgY2d4lXRlYSiCif9YBGpQleewkcqddc7g==", + "license": "MIT", + "dependencies": { + "get-nonce": "^1.0.0", + "invariant": "^2.2.4", + "tslib": "^2.0.0" + }, + "engines": { + "node": ">=10" + }, + "peerDependencies": { + "@types/react": "^16.8.0 || ^17.0.0 || ^18.0.0", + "react": "^16.8.0 || ^17.0.0 || ^18.0.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, "node_modules/read-cache": { "version": "1.0.0", "resolved": "https://registry.npmjs.org/read-cache/-/read-cache-1.0.0.tgz", @@ -16358,6 +15837,49 @@ "dev": true, "license": "MIT" }, + "node_modules/use-callback-ref": { + "version": "1.3.2", + "resolved": "https://registry.npmjs.org/use-callback-ref/-/use-callback-ref-1.3.2.tgz", + "integrity": "sha512-elOQwe6Q8gqZgDA8mrh44qRTQqpIHDcZ3hXTLjBe1i4ph8XpNJnO+aQf3NaG+lriLopI4HMx9VjQLfPQ6vhnoA==", + "license": "MIT", + "dependencies": { + "tslib": "^2.0.0" + }, + "engines": { + "node": ">=10" + }, + "peerDependencies": { + "@types/react": "^16.8.0 || ^17.0.0 || ^18.0.0", + "react": "^16.8.0 || ^17.0.0 || ^18.0.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/use-sidecar": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/use-sidecar/-/use-sidecar-1.1.2.tgz", + "integrity": "sha512-epTbsLuzZ7lPClpz2TyryBfztm7m+28DlEv2ZCQ3MDr5ssiwyOwGH/e5F9CkfWjJ1t4clvI58yF822/GUkjjhw==", + "license": "MIT", + "dependencies": { + "detect-node-es": "^1.1.0", + "tslib": "^2.0.0" + }, + "engines": { + "node": ">=10" + }, + "peerDependencies": { + "@types/react": "^16.9.0 || ^17.0.0 || ^18.0.0", + "react": "^16.8.0 || ^17.0.0 || ^18.0.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, "node_modules/util": { "version": "0.12.5", "resolved": "https://registry.npmjs.org/util/-/util-0.12.5.tgz", @@ -16396,6 +15918,19 @@ "uuid": "dist/bin/uuid" } }, + "node_modules/vaul": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/vaul/-/vaul-1.1.1.tgz", + "integrity": "sha512-+ejzF6ffQKPcfgS7uOrGn017g39F8SO4yLPXbBhpC7a0H+oPqPna8f1BUfXaz8eU4+pxbQcmjxW+jWBSbxjaFg==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-dialog": "^1.1.1" + }, + "peerDependencies": { + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0" + } + }, "node_modules/vfile": { "version": "6.0.1", "resolved": "https://registry.npmjs.org/vfile/-/vfile-6.0.1.tgz", diff --git a/web/package.json b/web/package.json index c36da3633ef..b8a933a9e00 100644 --- a/web/package.json +++ b/web/package.json @@ -1,6 +1,6 @@ { "name": "qa", - "version": "0.2.0-dev", + "version": "1.0.0-dev", "version-comment": "version field must be SemVer or chromatic will barf", "private": true, "scripts": { @@ -17,11 +17,13 @@ "@headlessui/react": "^2.2.0", "@headlessui/tailwindcss": "^0.2.1", "@phosphor-icons/react": "^2.0.8", - "@radix-ui/react-dialog": "^1.0.5", + "@radix-ui/react-checkbox": "^1.1.2", + "@radix-ui/react-dialog": "^1.1.2", "@radix-ui/react-popover": "^1.1.2", "@radix-ui/react-select": "^2.1.2", "@radix-ui/react-separator": "^1.1.0", "@radix-ui/react-slot": "^1.1.0", + "@radix-ui/react-switch": "^1.1.1", "@radix-ui/react-tabs": "^1.1.1", "@radix-ui/react-tooltip": "^1.1.3", "@sentry/nextjs": "^8.34.0", @@ -37,6 +39,7 @@ "class-variance-authority": "^0.7.0", "clsx": "^2.1.1", "date-fns": "^3.6.0", + "favicon-fetch": "^1.0.0", "formik": "^2.2.9", "js-cookie": "^3.0.5", "lodash": "^4.17.21", @@ -67,6 +70,7 @@ "tailwindcss-animate": "^1.0.7", "typescript": "5.0.3", "uuid": "^9.0.1", + "vaul": "^1.1.1", "yup": "^1.4.0" }, "devDependencies": { diff --git a/web/playwright.config.ts b/web/playwright.config.ts index 76101a65e1e..ddfb1d476a1 100644 --- a/web/playwright.config.ts +++ b/web/playwright.config.ts @@ -1,6 +1,7 @@ import { defineConfig, devices } from "@playwright/test"; export default defineConfig({ + workers: 1, // temporary change to see if single threaded testing stabilizes the tests testDir: "./tests/e2e", // Folder for test files reporter: "list", // Configure paths for screenshots diff --git a/web/public/Egnyte.png b/web/public/Egnyte.png new file mode 100644 index 00000000000..54eef07dc28 Binary files /dev/null and b/web/public/Egnyte.png differ diff --git a/web/public/Wikipedia.png b/web/public/Wikipedia.png new file mode 100644 index 00000000000..30d9a3bbae0 Binary files /dev/null and b/web/public/Wikipedia.png differ diff --git a/web/public/Wikipedia.svg b/web/public/Wikipedia.svg deleted file mode 100644 index ee4a3caa55f..00000000000 --- a/web/public/Wikipedia.svg +++ /dev/null @@ -1,535 +0,0 @@ - - - Wikipedia logo version 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/web/src/app/admin/api-key/DanswerApiKeyForm.tsx b/web/src/app/admin/api-key/DanswerApiKeyForm.tsx index 80bb84d626f..27d6457d141 100644 --- a/web/src/app/admin/api-key/DanswerApiKeyForm.tsx +++ b/web/src/app/admin/api-key/DanswerApiKeyForm.tsx @@ -82,7 +82,7 @@ export const DanswerApiKeyForm = ({ }} > {({ isSubmitting, values, setFieldValue }) => ( -

+ Choose a memorable name for your API key. This is optional and can be added or changed later! diff --git a/web/src/app/admin/api-key/page.tsx b/web/src/app/admin/api-key/page.tsx index c0d16bc850a..988aa3f60f9 100644 --- a/web/src/app/admin/api-key/page.tsx +++ b/web/src/app/admin/api-key/page.tsx @@ -45,9 +45,6 @@ function NewApiKeyModal({
New API Key -
- -
diff --git a/web/src/app/admin/assistants/AssistantEditor.tsx b/web/src/app/admin/assistants/AssistantEditor.tsx index 650c4d199a8..38b81f6b11e 100644 --- a/web/src/app/admin/assistants/AssistantEditor.tsx +++ b/web/src/app/admin/assistants/AssistantEditor.tsx @@ -24,13 +24,6 @@ import { TextFormField, } from "@/components/admin/connectors/Field"; -import { - Card, - CardHeader, - CardTitle, - CardContent, - CardFooter, -} from "@/components/ui/card"; import { usePopup } from "@/components/admin/connectors/Popup"; import { getDisplayNameForModel, useCategories } from "@/lib/hooks"; import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable"; @@ -405,7 +398,7 @@ export function AssistantEditor({ message: `"${assistant.name}" has been added to your list.`, type: "success", }); - router.refresh(); + await refreshAssistants(); } else { setPopup({ message: `"${assistant.name}" could not be added to your list.`, diff --git a/web/src/app/admin/assistants/PersonaTable.tsx b/web/src/app/admin/assistants/PersonaTable.tsx index c5dcfc2690e..a28e210c058 100644 --- a/web/src/app/admin/assistants/PersonaTable.tsx +++ b/web/src/app/admin/assistants/PersonaTable.tsx @@ -41,7 +41,7 @@ function PersonaTypeDisplay({ persona }: { persona: Persona }) { export function PersonasTable() { const router = useRouter(); const { popup, setPopup } = usePopup(); - const { refreshUser, isLoadingUser, isAdmin } = useUser(); + const { refreshUser, isAdmin } = useUser(); const { allAssistants: assistants, refreshAssistants, @@ -90,7 +90,7 @@ export function PersonasTable() { message: `Failed to update persona order - ${await response.text()}`, }); setFinalPersonas(assistants); - router.refresh(); + await refreshAssistants(); return; } @@ -98,10 +98,6 @@ export function PersonasTable() { await refreshUser(); }; - if (isLoadingUser) { - return <>; - } - return (
{popup} @@ -151,7 +147,7 @@ export function PersonasTable() { persona.is_visible ); if (response.ok) { - router.refresh(); + await refreshAssistants(); } else { setPopup({ type: "error", @@ -183,7 +179,7 @@ export function PersonasTable() { onClick={async () => { const response = await deletePersona(persona.id); if (response.ok) { - router.refresh(); + await refreshAssistants(); } else { alert( `Failed to delete persona - ${await response.text()}` diff --git a/web/src/app/admin/bots/SlackBotCreationForm.tsx b/web/src/app/admin/bots/SlackBotCreationForm.tsx index 8a113af05f4..c4918bc6957 100644 --- a/web/src/app/admin/bots/SlackBotCreationForm.tsx +++ b/web/src/app/admin/bots/SlackBotCreationForm.tsx @@ -7,6 +7,7 @@ import { useState } from "react"; import { SlackTokensForm } from "./SlackTokensForm"; import { SourceIcon } from "@/components/SourceIcon"; import { AdminPageTitle } from "@/components/admin/Title"; +import { ValidSources } from "@/lib/types"; export const NewSlackBotForm = ({}: {}) => { const [formValues] = useState({ @@ -21,7 +22,7 @@ export const NewSlackBotForm = ({}: {}) => { return (
} + icon={} title="New Slack Bot" /> diff --git a/web/src/app/admin/bots/SlackBotTable.tsx b/web/src/app/admin/bots/SlackBotTable.tsx index 2329ba44dac..332459e3187 100644 --- a/web/src/app/admin/bots/SlackBotTable.tsx +++ b/web/src/app/admin/bots/SlackBotTable.tsx @@ -96,6 +96,16 @@ export function SlackBotTable({ slackBots }: { slackBots: SlackBot[] }) { ); })} + {slackBots.length === 0 && ( + + + Please add a New Slack Bot to begin chatting with Danswer! + + + )} {slackBots.length > NUM_IN_PAGE && ( diff --git a/web/src/app/admin/bots/SlackBotUpdateForm.tsx b/web/src/app/admin/bots/SlackBotUpdateForm.tsx index 9eec40eeb95..cf90a124c2f 100644 --- a/web/src/app/admin/bots/SlackBotUpdateForm.tsx +++ b/web/src/app/admin/bots/SlackBotUpdateForm.tsx @@ -1,7 +1,7 @@ "use client"; import { usePopup } from "@/components/admin/connectors/Popup"; -import { SlackBot } from "@/lib/types"; +import { SlackBot, ValidSources } from "@/lib/types"; import { useRouter } from "next/navigation"; import { ChevronDown, ChevronRight } from "lucide-react"; import { useState, useEffect, useRef } from "react"; @@ -78,7 +78,7 @@ export const ExistingSlackBotForm = ({
- +
{ return ( - + { + window.location.href = `/admin/bots/${slackBotId}/channels/${slackChannelConfig.id}`; + }} + >
- +
- +
{"#" + slackChannelConfig.channel_config.channel_name}
- + e.stopPropagation()}> {slackChannelConfig.persona && !isPersonaASlackBotPersona(slackChannelConfig.persona) ? ( - + e.stopPropagation()}>
{ + onClick={async (e) => { + e.stopPropagation(); const response = await deleteSlackChannelConfig( slackChannelConfig.id ); diff --git a/web/src/app/admin/bots/[bot-id]/channels/SlackChannelConfigCreationForm.tsx b/web/src/app/admin/bots/[bot-id]/channels/SlackChannelConfigCreationForm.tsx index 9a8caad2ad5..6e22c7b5ea1 100644 --- a/web/src/app/admin/bots/[bot-id]/channels/SlackChannelConfigCreationForm.tsx +++ b/web/src/app/admin/bots/[bot-id]/channels/SlackChannelConfigCreationForm.tsx @@ -81,6 +81,11 @@ export const SlackChannelConfigCreationForm = ({ respond_to_bots: existingSlackChannelConfig?.channel_config?.respond_to_bots || false, + show_continue_in_web_ui: + // If we're updating, we want to keep the existing value + // Otherwise, we want to default to true + existingSlackChannelConfig?.channel_config + ?.show_continue_in_web_ui ?? !isUpdate, enable_auto_filters: existingSlackChannelConfig?.enable_auto_filters || false, respond_member_group_list: @@ -119,6 +124,7 @@ export const SlackChannelConfigCreationForm = ({ questionmark_prefilter_enabled: Yup.boolean().required(), respond_tag_only: Yup.boolean().required(), respond_to_bots: Yup.boolean().required(), + show_continue_in_web_ui: Yup.boolean().required(), enable_auto_filters: Yup.boolean().required(), respond_member_group_list: Yup.array().of(Yup.string()).required(), still_need_help_enabled: Yup.boolean().required(), @@ -282,6 +288,12 @@ export const SlackChannelConfigCreationForm = ({ />
+
} + icon={} title="Edit Slack Channel Config" /> diff --git a/web/src/app/admin/bots/[bot-id]/channels/new/page.tsx b/web/src/app/admin/bots/[bot-id]/channels/new/page.tsx index 85069ccb47f..970f8e15ae8 100644 --- a/web/src/app/admin/bots/[bot-id]/channels/new/page.tsx +++ b/web/src/app/admin/bots/[bot-id]/channels/new/page.tsx @@ -2,7 +2,7 @@ import { AdminPageTitle } from "@/components/admin/Title"; import { SlackChannelConfigCreationForm } from "../SlackChannelConfigCreationForm"; import { fetchSS } from "@/lib/utilsSS"; import { ErrorCallout } from "@/components/ErrorCallout"; -import { DocumentSet } from "@/lib/types"; +import { DocumentSet, ValidSources } from "@/lib/types"; import { BackButton } from "@/components/BackButton"; import { fetchAssistantsSS } from "@/lib/assistants/fetchAssistantsSS"; import { @@ -59,7 +59,7 @@ async function NewChannelConfigPage(props: {
} + icon={} title="Configure DanswerBot for Slack Channel" /> diff --git a/web/src/app/admin/bots/[bot-id]/lib.ts b/web/src/app/admin/bots/[bot-id]/lib.ts index f131bfd4f25..1e6bbfe056f 100644 --- a/web/src/app/admin/bots/[bot-id]/lib.ts +++ b/web/src/app/admin/bots/[bot-id]/lib.ts @@ -15,6 +15,7 @@ interface SlackChannelConfigCreationRequest { questionmark_prefilter_enabled: boolean; respond_tag_only: boolean; respond_to_bots: boolean; + show_continue_in_web_ui: boolean; respond_member_group_list: string[]; follow_up_tags?: string[]; usePersona: boolean; @@ -43,6 +44,7 @@ const buildRequestBodyFromCreationRequest = ( channel_name: creationRequest.channel_name, respond_tag_only: creationRequest.respond_tag_only, respond_to_bots: creationRequest.respond_to_bots, + show_continue_in_web_ui: creationRequest.show_continue_in_web_ui, enable_auto_filters: creationRequest.enable_auto_filters, respond_member_group_list: creationRequest.respond_member_group_list, answer_filters: buildFiltersFromCreationRequest(creationRequest), diff --git a/web/src/app/admin/bots/[bot-id]/page.tsx b/web/src/app/admin/bots/[bot-id]/page.tsx index 414e90f2932..f99e877137a 100644 --- a/web/src/app/admin/bots/[bot-id]/page.tsx +++ b/web/src/app/admin/bots/[bot-id]/page.tsx @@ -22,7 +22,6 @@ function SlackBotEditPage({ const unwrappedParams = use(params); const { popup, setPopup } = usePopup(); - console.log("unwrappedParams", unwrappedParams); const { data: slackBot, isLoading: isSlackBotLoading, diff --git a/web/src/app/admin/bots/page.tsx b/web/src/app/admin/bots/page.tsx index 6fc34a3f11b..03e03668f85 100644 --- a/web/src/app/admin/bots/page.tsx +++ b/web/src/app/admin/bots/page.tsx @@ -10,6 +10,7 @@ import Link from "next/link"; import { SourceIcon } from "@/components/SourceIcon"; import { SlackBotTable } from "./SlackBotTable"; import { useSlackBots } from "./[bot-id]/hooks"; +import { ValidSources } from "@/lib/types"; const Main = () => { const { @@ -103,7 +104,7 @@ const Page = () => { return (
} + icon={} title="Slack Bots" /> diff --git a/web/src/app/admin/configuration/llm/CustomLLMProviderUpdateForm.tsx b/web/src/app/admin/configuration/llm/CustomLLMProviderUpdateForm.tsx index 9011b2cdfbb..e70c5b7e270 100644 --- a/web/src/app/admin/configuration/llm/CustomLLMProviderUpdateForm.tsx +++ b/web/src/app/admin/configuration/llm/CustomLLMProviderUpdateForm.tsx @@ -275,8 +275,9 @@ export function CustomLLMProviderUpdateForm({ <>
- Additional configurations needed by the model provider. Are - passed to litellm via environment variables. + Additional configurations needed by the model provider. These + are passed to litellm via environment + as arguments into the + `completion` call.
@@ -290,14 +291,14 @@ export function CustomLLMProviderUpdateForm({ ) => ( -
+
{formikProps.values.custom_config_list.map((_, index) => { return (
-
+
@@ -457,6 +458,7 @@ export function CustomLLMProviderUpdateForm({ + )} +
+
+ + ); +} + export function AdvancedConfigDisplay({ pruneFreq, refreshFreq, indexingStart, + onRefreshEdit, + onPruningEdit, }: { pruneFreq: number | null; refreshFreq: number | null; indexingStart: Date | null; + onRefreshEdit: () => void; + onPruningEdit: () => void; }) { const formatRefreshFrequency = (seconds: number | null): string => { if (seconds === null) return "-"; @@ -75,14 +149,21 @@ export function AdvancedConfigDisplay({ <> Advanced Configuration -
    +
      {pruneFreq && (
    • Pruning Frequency - {formatPruneFrequency(pruneFreq)} + + {formatPruneFrequency(pruneFreq)} + + + +
    • )} {refreshFreq && ( @@ -91,7 +172,14 @@ export function AdvancedConfigDisplay({ className="w-full flex justify-between items-center py-2" > Refresh Frequency - {formatRefreshFrequency(refreshFreq)} + + {formatRefreshFrequency(refreshFreq)} + + + + )} {indexingStart && ( @@ -127,15 +215,9 @@ export function ConfigDisplay({ <> Configuration -
        +
          {configEntries.map(([key, value]) => ( -
        • - {key} - {convertObjectToString(value) || "-"} -
        • + ))}
        diff --git a/web/src/app/admin/connector/[ccPairId]/ModifyStatusButtonCluster.tsx b/web/src/app/admin/connector/[ccPairId]/ModifyStatusButtonCluster.tsx index 71d26a8eb47..b5b4e7ecbf2 100644 --- a/web/src/app/admin/connector/[ccPairId]/ModifyStatusButtonCluster.tsx +++ b/web/src/app/admin/connector/[ccPairId]/ModifyStatusButtonCluster.tsx @@ -6,6 +6,8 @@ import { usePopup } from "@/components/admin/connectors/Popup"; import { mutate } from "swr"; import { buildCCPairInfoUrl } from "./lib"; import { setCCPairStatus } from "@/lib/ccPair"; +import { useState } from "react"; +import { LoadingAnimation } from "@/components/Loading"; export function ModifyStatusButtonCluster({ ccPair, @@ -13,44 +15,72 @@ export function ModifyStatusButtonCluster({ ccPair: CCPairFullInfo; }) { const { popup, setPopup } = usePopup(); + const [isUpdating, setIsUpdating] = useState(false); + + const handleStatusChange = async ( + newStatus: ConnectorCredentialPairStatus + ) => { + if (isUpdating) return; // Prevent double-clicks or multiple requests + setIsUpdating(true); + + try { + // Call the backend to update the status + await setCCPairStatus(ccPair.id, newStatus, setPopup); + + // Use mutate to revalidate the status on the backend + await mutate(buildCCPairInfoUrl(ccPair.id)); + } catch (error) { + console.error("Failed to update status", error); + } finally { + // Reset local updating state and button text after mutation + setIsUpdating(false); + } + }; + + // Compute the button text based on current state and backend status + const buttonText = + ccPair.status === ConnectorCredentialPairStatus.PAUSED + ? "Re-Enable" + : "Pause"; + + const tooltip = + ccPair.status === ConnectorCredentialPairStatus.PAUSED + ? "Click to start indexing again!" + : "When paused, the connector's documents will still be visible. However, no new documents will be indexed."; return ( <> {popup} - {ccPair.status === ConnectorCredentialPairStatus.PAUSED ? ( - - ) : ( - - )} + ); } diff --git a/web/src/app/admin/connector/[ccPairId]/ReIndexButton.tsx b/web/src/app/admin/connector/[ccPairId]/ReIndexButton.tsx index af0e2a8f4aa..962339e9fe8 100644 --- a/web/src/app/admin/connector/[ccPairId]/ReIndexButton.tsx +++ b/web/src/app/admin/connector/[ccPairId]/ReIndexButton.tsx @@ -121,7 +121,7 @@ export function ReIndexButton({ {popup}
-
+
)} diff --git a/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx b/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx index 8e7bac228c7..31566f03437 100644 --- a/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx +++ b/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx @@ -9,9 +9,9 @@ import { AdminPageTitle } from "@/components/admin/Title"; import { buildSimilarCredentialInfoURL } from "@/app/admin/connector/[ccPairId]/lib"; import { usePopup } from "@/components/admin/connectors/Popup"; import { useFormContext } from "@/components/context/FormContext"; -import { getSourceDisplayName } from "@/lib/sources"; +import { getSourceDisplayName, getSourceMetadata } from "@/lib/sources"; import { SourceIcon } from "@/components/SourceIcon"; -import { useState } from "react"; +import { useEffect, useState } from "react"; import { deleteCredential, linkCredential } from "@/lib/credential"; import { submitFiles } from "./pages/utils/files"; import { submitGoogleSite } from "./pages/utils/google_site"; @@ -19,7 +19,11 @@ import AdvancedFormPage from "./pages/Advanced"; import DynamicConnectionForm from "./pages/DynamicConnectorCreationForm"; import CreateCredential from "@/components/credentials/actions/CreateCredential"; import ModifyCredential from "@/components/credentials/actions/ModifyCredential"; -import { ConfigurableSources, ValidSources } from "@/lib/types"; +import { + ConfigurableSources, + oauthSupportedSources, + ValidSources, +} from "@/lib/types"; import { Credential, credentialTemplates } from "@/lib/connectors/credentials"; import { ConnectionConfiguration, @@ -43,6 +47,10 @@ import { Formik } from "formik"; import NavigationRow from "./NavigationRow"; import { useRouter } from "next/navigation"; import CardSection from "@/components/admin/CardSection"; +import { prepareOAuthAuthorizationRequest } from "@/lib/oauth_utils"; +import { EE_ENABLED, NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants"; +import TemporaryLoadingModal from "@/components/TemporaryLoadingModal"; +import { getConnectorOauthRedirectUrl } from "@/lib/connectors/oauth"; export interface AdvancedConfig { refreshFreq: number; pruneFreq: number; @@ -110,6 +118,23 @@ export default function AddConnector({ }: { connector: ConfigurableSources; }) { + const [currentPageUrl, setCurrentPageUrl] = useState(null); + const [oauthUrl, setOauthUrl] = useState(null); + const [isAuthorizing, setIsAuthorizing] = useState(false); + const [isAuthorizeVisible, setIsAuthorizeVisible] = useState(false); + useEffect(() => { + if (typeof window !== "undefined") { + setCurrentPageUrl(window.location.href); + } + + if (EE_ENABLED && NEXT_PUBLIC_CLOUD_ENABLED) { + const sourceMetadata = getSourceMetadata(connector); + if (sourceMetadata?.oauthSupported == true) { + setIsAuthorizeVisible(true); + } + } + }, []); + const router = useRouter(); // State for managing credentials and files @@ -135,9 +160,9 @@ export default function AddConnector({ const configuration: ConnectionConfiguration = connectorConfigs[connector]; // Form context and popup management - const { setFormStep, setAlowCreate, formStep, nextFormStep, prevFormStep } = - useFormContext(); + const { setFormStep, setAllowCreate, formStep } = useFormContext(); const { popup, setPopup } = usePopup(); + const [uploading, setUploading] = useState(false); // Hooks for Google Drive and Gmail credentials const { liveGDriveCredential } = useGoogleDriveCredentials(connector); @@ -192,7 +217,7 @@ export default function AddConnector({ const onSwap = async (selectedCredential: Credential) => { setCurrentCredential(selectedCredential); - setAlowCreate(true); + setAllowCreate(true); setPopup({ message: "Swapped credential successfully!", type: "success", @@ -204,6 +229,37 @@ export default function AddConnector({ router.push("/admin/indexing/status?message=connector-created"); }; + const handleAuthorize = async () => { + // authorize button handler + // gets an auth url from the server and directs the user to it in a popup + + if (!currentPageUrl) return; + + setIsAuthorizing(true); + try { + const response = await prepareOAuthAuthorizationRequest( + connector, + currentPageUrl + ); + if (response.url) { + setOauthUrl(response.url); + window.open(response.url, "_blank", "noopener,noreferrer"); + } else { + setPopup({ message: "Failed to fetch OAuth URL", type: "error" }); + } + } catch (error: unknown) { + // Narrow the type of error + if (error instanceof Error) { + setPopup({ message: `Error: ${error.message}`, type: "error" }); + } else { + // Handle non-standard errors + setPopup({ message: "An unknown error occurred", type: "error" }); + } + } finally { + setIsAuthorizing(false); + } + }; + return ( {popup} -
- -
+ {uploading && ( + + )} {!createConnectorToggle && ( - +
+ {/* Button to pop up a form to manually enter credentials */} + + {/* Button to sign in via OAuth */} + {oauthSupportedSources.includes(connector) && + NEXT_PUBLIC_CLOUD_ENABLED && ( + + )} +
)} {/* NOTE: connector will never be google_drive, since the ternary above will diff --git a/web/src/app/admin/connectors/[connector]/oauth/callback/page.tsx b/web/src/app/admin/connectors/[connector]/oauth/callback/page.tsx new file mode 100644 index 00000000000..30fd9e07478 --- /dev/null +++ b/web/src/app/admin/connectors/[connector]/oauth/callback/page.tsx @@ -0,0 +1,111 @@ +"use client"; + +import { useEffect, useState } from "react"; +import { usePathname, useRouter, useSearchParams } from "next/navigation"; +import { AdminPageTitle } from "@/components/admin/Title"; +import { Button } from "@/components/ui/button"; +import Title from "@/components/ui/title"; +import { KeyIcon } from "@/components/icons/icons"; +import { getSourceMetadata, isValidSource } from "@/lib/sources"; +import { ValidSources } from "@/lib/types"; +import CardSection from "@/components/admin/CardSection"; +import { handleOAuthAuthorizationResponse } from "@/lib/oauth_utils"; + +export default function OAuthCallbackPage() { + const router = useRouter(); + const searchParams = useSearchParams(); + + const [statusMessage, setStatusMessage] = useState("Processing..."); + const [statusDetails, setStatusDetails] = useState( + "Please wait while we complete the setup." + ); + const [redirectUrl, setRedirectUrl] = useState(null); + const [isError, setIsError] = useState(false); + const [pageTitle, setPageTitle] = useState( + "Authorize with Third-Party service" + ); + + // Extract query parameters + const code = searchParams.get("code"); + const state = searchParams.get("state"); + + const pathname = usePathname(); + const connector = pathname?.split("/")[3]; + + useEffect(() => { + const handleOAuthCallback = async () => { + if (!code || !state) { + setStatusMessage("Improperly formed OAuth authorization request."); + setStatusDetails( + !code ? "Missing authorization code." : "Missing state parameter." + ); + setIsError(true); + return; + } + + if (!connector || !isValidSource(connector)) { + setStatusMessage( + `The specified connector source type ${connector} does not exist.` + ); + setStatusDetails(`${connector} is not a valid source type.`); + setIsError(true); + return; + } + + const sourceMetadata = getSourceMetadata(connector as ValidSources); + setPageTitle(`Authorize with ${sourceMetadata.displayName}`); + + setStatusMessage("Processing..."); + setStatusDetails("Please wait while we complete authorization."); + setIsError(false); // Ensure no error state during loading + + try { + const response = await handleOAuthAuthorizationResponse(code, state); + + if (!response) { + throw new Error("Empty response from OAuth server."); + } + + setStatusMessage("Success!"); + setStatusDetails( + `Your authorization with ${sourceMetadata.displayName} completed successfully.` + ); + setRedirectUrl(response.redirect_on_success); // Extract the redirect URL + setIsError(false); + } catch (error) { + console.error("OAuth error:", error); + setStatusMessage("Oops, something went wrong!"); + setStatusDetails( + "An error occurred during the OAuth process. Please try again." + ); + setIsError(true); + } + }; + + handleOAuthCallback(); + }, [code, state, connector]); + + return ( +
+ } /> + +
+ +

{statusMessage}

+

{statusDetails}

+ {redirectUrl && !isError && ( +
+

+ Click{" "} + + here + {" "} + to continue. +

+
+ )} +
+
+
+ ); +} diff --git a/web/src/app/admin/connectors/[connector]/pages/gdrive/GoogleDrivePage.tsx b/web/src/app/admin/connectors/[connector]/pages/gdrive/GoogleDrivePage.tsx index 9ff60ebc223..90e04f1bd63 100644 --- a/web/src/app/admin/connectors/[connector]/pages/gdrive/GoogleDrivePage.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/gdrive/GoogleDrivePage.tsx @@ -22,7 +22,7 @@ import { GoogleDriveConfig } from "@/lib/connectors/connectors"; import { useUser } from "@/components/user/UserProvider"; const GDriveMain = ({}: {}) => { - const { isLoadingUser, isAdmin, user } = useUser(); + const { isAdmin, user } = useUser(); const { data: appCredentialData, @@ -63,10 +63,6 @@ const GDriveMain = ({}: {}) => { serviceAccountKeyData || (isServiceAccountKeyError && isServiceAccountKeyError.status === 404); - if (isLoadingUser) { - return <>; - } - if ( (!appCredentialSuccessfullyFetched && isAppCredentialLoading) || (!serviceAccountKeySuccessfullyFetched && isServiceAccountKeyLoading) || @@ -108,7 +104,9 @@ const GDriveMain = ({}: {}) => { const googleDriveServiceAccountCredential: | Credential | undefined = credentialsData.find( - (credential) => credential.credential_json?.google_service_account_key + (credential) => + credential.credential_json?.google_service_account_key && + credential.source === "google_drive" ); const googleDriveConnectorIndexingStatuses: ConnectorIndexingStatus< diff --git a/web/src/app/admin/connectors/[connector]/pages/gmail/GmailPage.tsx b/web/src/app/admin/connectors/[connector]/pages/gmail/GmailPage.tsx index b3120fbefdf..9db5892ebb0 100644 --- a/web/src/app/admin/connectors/[connector]/pages/gmail/GmailPage.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/gmail/GmailPage.tsx @@ -20,7 +20,7 @@ import { GmailConfig } from "@/lib/connectors/connectors"; import { useUser } from "@/components/user/UserProvider"; export const GmailMain = () => { - const { isLoadingUser, isAdmin, user } = useUser(); + const { isAdmin, user } = useUser(); const { data: appCredentialData, @@ -60,10 +60,6 @@ export const GmailMain = () => { serviceAccountKeyData || (isServiceAccountKeyError && isServiceAccountKeyError.status === 404); - if (isLoadingUser) { - return <>; - } - if ( (!appCredentialSuccessfullyFetched && isAppCredentialLoading) || (!serviceAccountKeySuccessfullyFetched && isServiceAccountKeyLoading) || diff --git a/web/src/app/admin/connectors/[connector]/pages/utils/files.ts b/web/src/app/admin/connectors/[connector]/pages/utils/files.ts index c294de219f2..feb2810178f 100644 --- a/web/src/app/admin/connectors/[connector]/pages/utils/files.ts +++ b/web/src/app/admin/connectors/[connector]/pages/utils/files.ts @@ -2,7 +2,7 @@ import { PopupSpec } from "@/components/admin/connectors/Popup"; import { createConnector, runConnector } from "@/lib/connector"; import { createCredential, linkCredential } from "@/lib/credential"; import { FileConfig } from "@/lib/connectors/connectors"; -import { AccessType } from "@/lib/types"; +import { AccessType, ValidSources } from "@/lib/types"; export const submitFiles = async ( selectedFiles: File[], @@ -34,7 +34,7 @@ export const submitFiles = async ( const [connectorErrorMsg, connector] = await createConnector({ name: "FileConnector-" + Date.now(), - source: "file", + source: ValidSources.File, input_type: "load_state", connector_specific_config: { file_locations: filePaths, @@ -60,7 +60,7 @@ export const submitFiles = async ( const createCredentialResponse = await createCredential({ credential_json: {}, admin_public: true, - source: "file", + source: ValidSources.File, curator_public: true, groups: groups, name, diff --git a/web/src/app/admin/connectors/[connector]/pages/utils/google_site.ts b/web/src/app/admin/connectors/[connector]/pages/utils/google_site.ts index e297c4e394a..bb448a09836 100644 --- a/web/src/app/admin/connectors/[connector]/pages/utils/google_site.ts +++ b/web/src/app/admin/connectors/[connector]/pages/utils/google_site.ts @@ -2,6 +2,7 @@ import { PopupSpec } from "@/components/admin/connectors/Popup"; import { createConnector, runConnector } from "@/lib/connector"; import { linkCredential } from "@/lib/credential"; import { GoogleSitesConfig } from "@/lib/connectors/connectors"; +import { ValidSources } from "@/lib/types"; export const submitGoogleSite = async ( selectedFiles: File[], @@ -38,7 +39,7 @@ export const submitGoogleSite = async ( const [connectorErrorMsg, connector] = await createConnector({ name: name ? name : `GoogleSitesConnector-${base_url}`, - source: "google_sites", + source: ValidSources.GoogleSites, input_type: "load_state", connector_specific_config: { base_url: base_url, diff --git a/web/src/app/admin/documents/feedback/DocumentFeedbackTable.tsx b/web/src/app/admin/documents/feedback/DocumentFeedbackTable.tsx index 7e02beef689..14cae632b3d 100644 --- a/web/src/app/admin/documents/feedback/DocumentFeedbackTable.tsx +++ b/web/src/app/admin/documents/feedback/DocumentFeedbackTable.tsx @@ -135,7 +135,7 @@ export const DocumentFeedbackTable = ({ /> -
+
{!documentSet.is_up_to_date && ( - +
Cannot update while syncing! Wait for the sync to finish, then diff --git a/web/src/app/admin/embeddings/pages/EmbeddingFormPage.tsx b/web/src/app/admin/embeddings/pages/EmbeddingFormPage.tsx index 8b4f2955789..4995d9da933 100644 --- a/web/src/app/admin/embeddings/pages/EmbeddingFormPage.tsx +++ b/web/src/app/admin/embeddings/pages/EmbeddingFormPage.tsx @@ -3,7 +3,7 @@ import { usePopup } from "@/components/admin/connectors/Popup"; import { HealthCheckBanner } from "@/components/health/healthcheck"; import { EmbeddingModelSelection } from "../EmbeddingModelSelectionForm"; -import { useEffect, useMemo, useState } from "react"; +import { useCallback, useEffect, useMemo, useState } from "react"; import Text from "@/components/ui/text"; import { Button } from "@/components/ui/button"; import { ArrowLeft, ArrowRight, WarningCircle } from "@phosphor-icons/react"; @@ -28,7 +28,8 @@ import { Modal } from "@/components/Modal"; import { useRouter } from "next/navigation"; import CardSection from "@/components/admin/CardSection"; -import { CardDescription } from "@/components/ui/card"; +import { combineSearchSettings } from "./utils"; + export default function EmbeddingForm() { const { formStep, nextFormStep, prevFormStep } = useEmbeddingFormContext(); const { popup, setPopup } = usePopup(); @@ -157,6 +158,26 @@ export default function EmbeddingForm() { searchSettings?.multipass_indexing != advancedEmbeddingDetails.multipass_indexing; + const updateSearch = useCallback(async () => { + if (!selectedProvider) { + return false; + } + const searchSettings = combineSearchSettings( + selectedProvider, + advancedEmbeddingDetails, + rerankingDetails, + selectedProvider.provider_type?.toLowerCase() as EmbeddingProvider | null + ); + + const response = await updateSearchSettings(searchSettings); + if (response.ok) { + return true; + } else { + setPopup({ message: "Failed to update search settings", type: "error" }); + return false; + } + }, [selectedProvider, advancedEmbeddingDetails, rerankingDetails, setPopup]); + const ReIndexingButton = useMemo(() => { const ReIndexingButtonComponent = ({ needsReIndex, @@ -205,7 +226,7 @@ export default function EmbeddingForm() { }; ReIndexingButtonComponent.displayName = "ReIndexingButton"; return ReIndexingButtonComponent; - }, [needsReIndex]); + }, [needsReIndex, updateSearch]); if (!selectedProvider) { return ; @@ -221,24 +242,6 @@ export default function EmbeddingForm() { })); }; - const updateSearch = async () => { - const values: SavedSearchSettings = { - ...rerankingDetails, - ...advancedEmbeddingDetails, - ...selectedProvider, - provider_type: - selectedProvider.provider_type?.toLowerCase() as EmbeddingProvider | null, - }; - - const response = await updateSearchSettings(values); - if (response.ok) { - return true; - } else { - setPopup({ message: "Failed to update search settings", type: "error" }); - return false; - } - }; - const navigateToEmbeddingPage = (changedResource: string) => { router.push("/admin/configuration/search?message=search-settings"); }; @@ -247,39 +250,35 @@ export default function EmbeddingForm() { if (!selectedProvider) { return; } - let newModel: SavedSearchSettings; + let searchSettings: SavedSearchSettings; - // We use a spread operation to merge properties from multiple objects into a single object. - // Advanced embedding details may update default values. - // Do NOT modify the order unless you are positive the new hierarchy is correct. if (selectedProvider.provider_type != null) { // This is a cloud model - newModel = { - ...selectedProvider, - ...advancedEmbeddingDetails, - ...rerankingDetails, - provider_type: - (selectedProvider.provider_type - ?.toLowerCase() - .split(" ")[0] as EmbeddingProvider) || null, - }; + searchSettings = combineSearchSettings( + selectedProvider, + advancedEmbeddingDetails, + rerankingDetails, + selectedProvider.provider_type + ?.toLowerCase() + .split(" ")[0] as EmbeddingProvider | null + ); } else { // This is a locally hosted model - newModel = { - ...selectedProvider, - ...advancedEmbeddingDetails, - ...rerankingDetails, - provider_type: null, - }; + searchSettings = combineSearchSettings( + selectedProvider, + advancedEmbeddingDetails, + rerankingDetails, + null + ); } - newModel.index_name = null; + searchSettings.index_name = null; const response = await fetch( "/api/search-settings/set-new-search-settings", { method: "POST", - body: JSON.stringify(newModel), + body: JSON.stringify(searchSettings), headers: { "Content-Type": "application/json", }, diff --git a/web/src/app/admin/embeddings/pages/utils.ts b/web/src/app/admin/embeddings/pages/utils.ts index 3d3065b54eb..039b1424286 100644 --- a/web/src/app/admin/embeddings/pages/utils.ts +++ b/web/src/app/admin/embeddings/pages/utils.ts @@ -1,3 +1,16 @@ +import { + CloudEmbeddingProvider, + HostedEmbeddingModel, +} from "@/components/embedding/interfaces"; + +import { + AdvancedSearchConfiguration, + SavedSearchSettings, +} from "../interfaces"; + +import { EmbeddingProvider } from "@/components/embedding/interfaces"; +import { RerankingDetails } from "../interfaces"; + export const deleteSearchSettings = async (search_settings_id: number) => { const response = await fetch(`/api/search-settings/delete-search-settings`, { method: "DELETE", @@ -42,3 +55,20 @@ export const testEmbedding = async ({ return testResponse; }; + +// We use a spread operation to merge properties from multiple objects into a single object. +// Advanced embedding details may update default values. +// Do NOT modify the order unless you are positive the new hierarchy is correct. +export const combineSearchSettings = ( + selectedProvider: CloudEmbeddingProvider | HostedEmbeddingModel, + advancedEmbeddingDetails: AdvancedSearchConfiguration, + rerankingDetails: RerankingDetails, + provider_type: EmbeddingProvider | null +): SavedSearchSettings => { + return { + ...selectedProvider, + ...advancedEmbeddingDetails, + ...rerankingDetails, + provider_type: provider_type, + }; +}; diff --git a/web/src/app/admin/indexing/status/CCPairIndexingStatusTable.tsx b/web/src/app/admin/indexing/status/CCPairIndexingStatusTable.tsx index 62ce28a870e..0d001e3084a 100644 --- a/web/src/app/admin/indexing/status/CCPairIndexingStatusTable.tsx +++ b/web/src/app/admin/indexing/status/CCPairIndexingStatusTable.tsx @@ -353,13 +353,9 @@ export function CCPairIndexingStatusTable({ ); }; const toggleSources = () => { - const currentToggledCount = - Object.values(connectorsToggled).filter(Boolean).length; - const shouldToggleOn = currentToggledCount < sortedSources.length / 2; - const connectors = sortedSources.reduce( (acc, source) => { - acc[source] = shouldToggleOn; + acc[source] = shouldExpand; return acc; }, {} as Record @@ -368,6 +364,7 @@ export function CCPairIndexingStatusTable({ setConnectorsToggled(connectors); Cookies.set(TOGGLED_CONNECTORS_COOKIE_NAME, JSON.stringify(connectors)); }; + const shouldExpand = Object.values(connectorsToggled).filter(Boolean).length < sortedSources.length; @@ -384,7 +381,7 @@ export function CCPairIndexingStatusTable({ last_status: "success", connector: { name: "Sample File Connector", - source: "file", + source: ValidSources.File, input_type: "poll", connector_specific_config: { file_locations: ["/path/to/sample/file.txt"], @@ -401,7 +398,7 @@ export function CCPairIndexingStatusTable({ credential: { id: 1, name: "Sample Credential", - source: "file", + source: ValidSources.File, user_id: "1", time_created: "2023-07-01T12:00:00Z", time_updated: "2023-07-01T12:00:00Z", diff --git a/web/src/app/admin/prompt-library/hooks.ts b/web/src/app/admin/prompt-library/hooks.ts deleted file mode 100644 index ccab6b34079..00000000000 --- a/web/src/app/admin/prompt-library/hooks.ts +++ /dev/null @@ -1,46 +0,0 @@ -import useSWR from "swr"; -import { InputPrompt } from "./interfaces"; - -const fetcher = (url: string) => fetch(url).then((res) => res.json()); - -export const useAdminInputPrompts = () => { - const { data, error, mutate } = useSWR( - `/api/admin/input_prompt`, - fetcher - ); - - return { - data, - error, - isLoading: !error && !data, - refreshInputPrompts: mutate, - }; -}; - -export const useInputPrompts = (includePublic: boolean = false) => { - const { data, error, mutate } = useSWR( - `/api/input_prompt${includePublic ? "?include_public=true" : ""}`, - fetcher - ); - - return { - data, - error, - isLoading: !error && !data, - refreshInputPrompts: mutate, - }; -}; - -export const useInputPrompt = (id: number) => { - const { data, error, mutate } = useSWR( - `/api/input_prompt/${id}`, - fetcher - ); - - return { - data, - error, - isLoading: !error && !data, - refreshInputPrompt: mutate, - }; -}; diff --git a/web/src/app/admin/prompt-library/interfaces.ts b/web/src/app/admin/prompt-library/interfaces.ts deleted file mode 100644 index 9143a0ea870..00000000000 --- a/web/src/app/admin/prompt-library/interfaces.ts +++ /dev/null @@ -1,31 +0,0 @@ -export interface InputPrompt { - id: number; - prompt: string; - content: string; - active: boolean; - is_public: string; -} - -export interface EditPromptModalProps { - onClose: () => void; - - promptId: number; - editInputPrompt: ( - promptId: number, - values: CreateInputPromptRequest - ) => Promise; -} -export interface CreateInputPromptRequest { - prompt: string; - content: string; -} - -export interface AddPromptModalProps { - onClose: () => void; - onSubmit: (promptData: CreateInputPromptRequest) => void; -} -export interface PromptData { - id: number; - prompt: string; - content: string; -} diff --git a/web/src/app/admin/prompt-library/modals/AddPromptModal.tsx b/web/src/app/admin/prompt-library/modals/AddPromptModal.tsx deleted file mode 100644 index 1d6ca466367..00000000000 --- a/web/src/app/admin/prompt-library/modals/AddPromptModal.tsx +++ /dev/null @@ -1,69 +0,0 @@ -import React from "react"; -import { Formik, Form } from "formik"; -import * as Yup from "yup"; -import { Button } from "@/components/ui/button"; - -import { BookstackIcon } from "@/components/icons/icons"; -import { AddPromptModalProps } from "../interfaces"; -import { TextFormField } from "@/components/admin/connectors/Field"; -import { Modal } from "@/components/Modal"; - -const AddPromptSchema = Yup.object().shape({ - title: Yup.string().required("Title is required"), - prompt: Yup.string().required("Prompt is required"), -}); - -const AddPromptModal = ({ onClose, onSubmit }: AddPromptModalProps) => { - return ( - - { - onSubmit({ - prompt: values.title, - content: values.prompt, - }); - setSubmitting(false); - onClose(); - }} - > - {({ isSubmitting, setFieldValue }) => ( - -

- - Add prompt -

- - - - - - - - )} -
-
- ); -}; - -export default AddPromptModal; diff --git a/web/src/app/admin/prompt-library/modals/EditPromptModal.tsx b/web/src/app/admin/prompt-library/modals/EditPromptModal.tsx deleted file mode 100644 index 873692d851f..00000000000 --- a/web/src/app/admin/prompt-library/modals/EditPromptModal.tsx +++ /dev/null @@ -1,138 +0,0 @@ -import React from "react"; -import { Formik, Form, Field, ErrorMessage } from "formik"; -import * as Yup from "yup"; -import { Modal } from "@/components/Modal"; -import { Textarea } from "@/components/ui/textarea"; -import { Button } from "@/components/ui/button"; -import { useInputPrompt } from "../hooks"; -import { EditPromptModalProps } from "../interfaces"; - -const EditPromptSchema = Yup.object().shape({ - prompt: Yup.string().required("Title is required"), - content: Yup.string().required("Content is required"), - active: Yup.boolean(), -}); - -const EditPromptModal = ({ - onClose, - promptId, - editInputPrompt, -}: EditPromptModalProps) => { - const { - data: promptData, - error, - refreshInputPrompt, - } = useInputPrompt(promptId); - - if (error) - return ( - -

Failed to load prompt data

-
- ); - - if (!promptData) - return ( - -

Loading...

-
- ); - - return ( - - { - editInputPrompt(promptId, values); - refreshInputPrompt(); - }} - > - {({ isSubmitting, values }) => ( -
-

- - - - Edit prompt -

- -
-
- - - -
- -
- - - -
- -
- -
-
- -
- -
-
- )} -
-
- ); -}; - -export default EditPromptModal; diff --git a/web/src/app/admin/prompt-library/page.tsx b/web/src/app/admin/prompt-library/page.tsx deleted file mode 100644 index d7c72ff5fc3..00000000000 --- a/web/src/app/admin/prompt-library/page.tsx +++ /dev/null @@ -1,32 +0,0 @@ -"use client"; - -import { AdminPageTitle } from "@/components/admin/Title"; -import { ClosedBookIcon } from "@/components/icons/icons"; -import { useAdminInputPrompts } from "./hooks"; -import { PromptSection } from "./promptSection"; - -const Page = () => { - const { - data: promptLibrary, - error: promptLibraryError, - isLoading: promptLibraryIsLoading, - refreshInputPrompts: refreshPrompts, - } = useAdminInputPrompts(); - - return ( -
- } - title="Prompt Library" - /> - -
- ); -}; -export default Page; diff --git a/web/src/app/admin/prompt-library/promptLibrary.tsx b/web/src/app/admin/prompt-library/promptLibrary.tsx deleted file mode 100644 index c4f535e9b0d..00000000000 --- a/web/src/app/admin/prompt-library/promptLibrary.tsx +++ /dev/null @@ -1,249 +0,0 @@ -"use client"; - -import { EditIcon, TrashIcon } from "@/components/icons/icons"; -import { PopupSpec } from "@/components/admin/connectors/Popup"; -import { MagnifyingGlass } from "@phosphor-icons/react"; -import { useState } from "react"; -import { - Table, - TableHead, - TableRow, - TableBody, - TableCell, -} from "@/components/ui/table"; -import { FilterDropdown } from "@/components/search/filtering/FilterDropdown"; -import { FiTag } from "react-icons/fi"; -import { PageSelector } from "@/components/PageSelector"; -import { InputPrompt } from "./interfaces"; -import { DeleteEntityModal } from "@/components/modals/DeleteEntityModal"; -import { TableHeader } from "@/components/ui/table"; - -const CategoryBubble = ({ - name, - onDelete, -}: { - name: string; - onDelete?: () => void; -}) => ( - - {name} - {onDelete && ( - - )} - -); - -const NUM_RESULTS_PER_PAGE = 10; - -export const PromptLibraryTable = ({ - promptLibrary, - refresh, - setPopup, - handleEdit, - isPublic, -}: { - promptLibrary: InputPrompt[]; - refresh: () => void; - setPopup: (popup: PopupSpec | null) => void; - handleEdit: (promptId: number) => void; - isPublic: boolean; -}) => { - const [query, setQuery] = useState(""); - const [currentPage, setCurrentPage] = useState(1); - const [selectedStatus, setSelectedStatus] = useState([]); - - const columns = [ - { name: "Prompt", key: "prompt" }, - { name: "Content", key: "content" }, - { name: "Status", key: "status" }, - { name: "", key: "edit" }, - { name: "", key: "delete" }, - ]; - - const filteredPromptLibrary = promptLibrary.filter((item) => { - const cleanedQuery = query.toLowerCase(); - const searchMatch = - item.prompt.toLowerCase().includes(cleanedQuery) || - item.content.toLowerCase().includes(cleanedQuery); - const statusMatch = - selectedStatus.length === 0 || - (selectedStatus.includes("Active") && item.active) || - (selectedStatus.includes("Inactive") && !item.active); - - return searchMatch && statusMatch; - }); - - const totalPages = Math.ceil( - filteredPromptLibrary.length / NUM_RESULTS_PER_PAGE - ); - const startIndex = (currentPage - 1) * NUM_RESULTS_PER_PAGE; - const endIndex = startIndex + NUM_RESULTS_PER_PAGE; - const paginatedPromptLibrary = filteredPromptLibrary.slice( - startIndex, - endIndex - ); - - const handlePageChange = (page: number) => { - setCurrentPage(page); - }; - - const handleDelete = async (id: number) => { - const response = await fetch( - `/api${isPublic ? "/admin" : ""}/input_prompt/${id}`, - { - method: "DELETE", - } - ); - if (!response.ok) { - setPopup({ message: "Failed to delete input prompt", type: "error" }); - } - refresh(); - setConfirmDeletionId(null); - }; - - const handleStatusSelect = (status: string) => { - setSelectedStatus((prev) => { - if (prev.includes(status)) { - return prev.filter((s) => s !== status); - } - return [...prev, status]; - }); - }; - - const [confirmDeletionId, setConfirmDeletionId] = useState( - null - ); - - return ( -
- {confirmDeletionId != null && ( - setConfirmDeletionId(null)} - onSubmit={() => handleDelete(confirmDeletionId)} - entityType="prompt" - entityName={ - paginatedPromptLibrary.find( - (prompt) => prompt.id === confirmDeletionId - )?.prompt ?? "" - } - /> - )} - -
- - { - setQuery(event.target.value); - setCurrentPage(1); - }} - /> -
-
- handleStatusSelect(option.key)} - icon={} - defaultDisplay="All Statuses" - /> -
- {selectedStatus.map((status) => ( - handleStatusSelect(status)} - /> - ))} -
-
-
- - - - {columns.map((column) => ( - {column.name} - ))} - - - - {paginatedPromptLibrary.length > 0 ? ( - paginatedPromptLibrary - .filter((prompt) => !(!isPublic && prompt.is_public)) - .map((item) => ( - - {item.prompt} - - {item.content} - - {item.active ? "Active" : "Inactive"} - - - - - - - - )) - ) : ( - - No matching prompts found... - - )} - -
- {paginatedPromptLibrary.length > 0 && ( -
- -
- )} -
-
- ); -}; diff --git a/web/src/app/admin/prompt-library/promptSection.tsx b/web/src/app/admin/prompt-library/promptSection.tsx deleted file mode 100644 index 015408ef0f0..00000000000 --- a/web/src/app/admin/prompt-library/promptSection.tsx +++ /dev/null @@ -1,150 +0,0 @@ -"use client"; - -import { usePopup } from "@/components/admin/connectors/Popup"; -import { ThreeDotsLoader } from "@/components/Loading"; -import { ErrorCallout } from "@/components/ErrorCallout"; -import { Button } from "@/components/ui/button"; -import { Separator } from "@/components/ui/separator"; -import Text from "@/components/ui/text"; -import { useState } from "react"; -import AddPromptModal from "./modals/AddPromptModal"; -import EditPromptModal from "./modals/EditPromptModal"; -import { PromptLibraryTable } from "./promptLibrary"; -import { CreateInputPromptRequest, InputPrompt } from "./interfaces"; - -export const PromptSection = ({ - promptLibrary, - isLoading, - error, - refreshPrompts, - centering = false, - isPublic, -}: { - promptLibrary: InputPrompt[]; - isLoading: boolean; - error: any; - refreshPrompts: () => void; - centering?: boolean; - isPublic: boolean; -}) => { - const { popup, setPopup } = usePopup(); - const [newPrompt, setNewPrompt] = useState(false); - const [newPromptId, setNewPromptId] = useState(null); - - const createInputPrompt = async ( - promptData: CreateInputPromptRequest - ): Promise => { - const response = await fetch("/api/input_prompt", { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ ...promptData, is_public: isPublic }), - }); - - if (!response.ok) { - setPopup({ message: "Failed to create input prompt", type: "error" }); - } - - refreshPrompts(); - return response.json(); - }; - - const editInputPrompt = async ( - promptId: number, - values: CreateInputPromptRequest - ) => { - try { - const response = await fetch(`/api/input_prompt/${promptId}`, { - method: "PATCH", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify(values), - }); - - if (!response.ok) { - setPopup({ message: "Failed to update prompt!", type: "error" }); - } - - setNewPromptId(null); - refreshPrompts(); - } catch (err) { - setPopup({ message: `Failed to update prompt: ${err}`, type: "error" }); - } - }; - - if (isLoading) { - return ; - } - - if (error || !promptLibrary) { - return ( - - ); - } - - const handleEdit = (promptId: number) => { - setNewPromptId(promptId); - }; - - return ( -
- {popup} - - {newPrompt && ( - setNewPrompt(false)} - /> - )} - - {newPromptId && ( - setNewPromptId(null)} - /> - )} -
- - Create prompts that can be accessed with the `/` shortcut in - Danswer Chat.{" "} - {isPublic - ? "Prompts created here will be accessible to all users." - : "Prompts created here will be available only to you."} - -
- -
- - - - - -
- -
-
- ); -}; diff --git a/web/src/app/admin/settings/SettingsForm.tsx b/web/src/app/admin/settings/SettingsForm.tsx index f00e4d978b5..5e2eb00335a 100644 --- a/web/src/app/admin/settings/SettingsForm.tsx +++ b/web/src/app/admin/settings/SettingsForm.tsx @@ -175,29 +175,6 @@ export function SettingsForm() { { fieldName, newValue: checked }, ]; - // If we're disabling a page, check if we need to update the default page - if ( - !checked && - (fieldName === "search_page_enabled" || fieldName === "chat_page_enabled") - ) { - const otherPageField = - fieldName === "search_page_enabled" - ? "chat_page_enabled" - : "search_page_enabled"; - const otherPageEnabled = settings && settings[otherPageField]; - - if ( - otherPageEnabled && - settings?.default_page === - (fieldName === "search_page_enabled" ? "search" : "chat") - ) { - updates.push({ - fieldName: "default_page", - newValue: fieldName === "search_page_enabled" ? "chat" : "search", - }); - } - } - updateSettingField(updates); } @@ -218,42 +195,17 @@ export function SettingsForm() { return (
{popup} - Page Visibility + Workspace Settings - handleToggleSettingsField("search_page_enabled", e.target.checked) + handleToggleSettingsField("auto_scroll", e.target.checked) } /> - - handleToggleSettingsField("chat_page_enabled", e.target.checked) - } - /> - - { - value && - updateSettingField([ - { fieldName: "default_page", newValue: value }, - ]); - }} - /> - {isEnterpriseEnabled && ( <> Chat Settings diff --git a/web/src/app/admin/settings/interfaces.ts b/web/src/app/admin/settings/interfaces.ts index 38959fc8cd2..32ce1d01067 100644 --- a/web/src/app/admin/settings/interfaces.ts +++ b/web/src/app/admin/settings/interfaces.ts @@ -5,14 +5,12 @@ export enum GatingType { } export interface Settings { - chat_page_enabled: boolean; - search_page_enabled: boolean; - default_page: "search" | "chat"; maximum_chat_retention_days: number | null; notifications: Notification[]; needs_reindexing: boolean; gpu_enabled: boolean; product_gating: GatingType; + auto_scroll: boolean; } export enum NotificationType { @@ -54,6 +52,7 @@ export interface EnterpriseSettings { custom_popup_header: string | null; custom_popup_content: string | null; enable_consent_screen: boolean | null; + auto_scroll: boolean; } export interface CombinedSettings { diff --git a/web/src/app/admin/users/page.tsx b/web/src/app/admin/users/page.tsx index 8dbc6c308d5..e4ffca94241 100644 --- a/web/src/app/admin/users/page.tsx +++ b/web/src/app/admin/users/page.tsx @@ -1,13 +1,13 @@ "use client"; +import { useEffect, useState } from "react"; +import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"; +import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card"; +import { Button } from "@/components/ui/button"; import InvitedUserTable from "@/components/admin/users/InvitedUserTable"; import SignedUpUserTable from "@/components/admin/users/SignedUpUserTable"; import { SearchBar } from "@/components/search/SearchBar"; -import { useState } from "react"; import { FiPlusSquare } from "react-icons/fi"; import { Modal } from "@/components/Modal"; - -import { Button } from "@/components/ui/button"; -import Text from "@/components/ui/text"; import { LoadingAnimation } from "@/components/Loading"; import { AdminPageTitle } from "@/components/admin/Title"; import { usePopup, PopupSpec } from "@/components/admin/connectors/Popup"; @@ -15,42 +15,10 @@ import { UsersIcon } from "@/components/icons/icons"; import { errorHandlingFetcher } from "@/lib/fetcher"; import useSWR, { mutate } from "swr"; import { ErrorCallout } from "@/components/ErrorCallout"; -import { HidableSection } from "@/app/admin/assistants/HidableSection"; import BulkAdd from "@/components/admin/users/BulkAdd"; import { UsersResponse } from "@/lib/users/interfaces"; - -const ValidDomainsDisplay = ({ validDomains }: { validDomains: string[] }) => { - if (!validDomains.length) { - return ( -
- No invited users. Anyone can sign up with a valid email address. To - restrict access you can: -
- (1) Invite users above. Once a user has been invited, only emails that - have explicitly been invited will be able to sign-up. -
-
- (2) Set the{" "} - VALID_EMAIL_DOMAINS{" "} - environment variable to a comma separated list of email domains. This - will restrict access to users with email addresses from these domains. -
-
- ); - } - - return ( -
- No invited users. Anyone with an email address with any of the following - domains can sign up: {validDomains.join(", ")}. -
- To further restrict access you can invite users above. Once a user has - been invited, only emails that have explicitly been invited will be able - to sign-up. -
-
- ); -}; +import SlackUserTable from "@/components/admin/users/SlackUserTable"; +import Text from "@/components/ui/text"; const UsersTables = ({ q, @@ -61,23 +29,48 @@ const UsersTables = ({ }) => { const [invitedPage, setInvitedPage] = useState(1); const [acceptedPage, setAcceptedPage] = useState(1); - const { data, isLoading, mutate, error } = useSWR( - `/api/manage/users?q=${encodeURI(q)}&accepted_page=${ + const [slackUsersPage, setSlackUsersPage] = useState(1); + + const [usersData, setUsersData] = useState( + undefined + ); + const [domainsData, setDomainsData] = useState( + undefined + ); + + const { data, error, mutate } = useSWR( + `/api/manage/users?q=${encodeURIComponent(q)}&accepted_page=${ acceptedPage - 1 - }&invited_page=${invitedPage - 1}`, + }&invited_page=${invitedPage - 1}&slack_users_page=${slackUsersPage - 1}`, errorHandlingFetcher ); - const { - data: validDomains, - isLoading: isLoadingDomains, - error: domainsError, - } = useSWR("/api/manage/admin/valid-domains", errorHandlingFetcher); - if (isLoading || isLoadingDomains) { + const { data: validDomains, error: domainsError } = useSWR( + "/api/manage/admin/valid-domains", + errorHandlingFetcher + ); + + useEffect(() => { + if (data) { + setUsersData(data); + } + }, [data]); + + useEffect(() => { + if (validDomains) { + setDomainsData(validDomains); + } + }, [validDomains]); + + const activeData = data ?? usersData; + const activeDomains = validDomains ?? domainsData; + + // Show loading animation only during the initial data fetch + if (!activeData || !activeDomains) { return ; } - if (error || !data) { + if (error) { return ( !accepted.map((u) => u.email).includes(user.email) + (user) => !accepted.some((u) => u.email === user.email) ); return ( - <> - - {invited.length > 0 ? ( - finalInvited.length > 0 ? ( - - ) : ( -
- To invite additional teammates, use the Invite Users button - above! -
- ) - ) : ( - - )} -
- - + + + Invited Users + Current Users + DanswerBot Users + + + + + + Invited Users + + + {finalInvited.length > 0 ? ( + + ) : ( +

Users that have been invited will show up here

+ )} +
+
+
+ + + + + Current Users + + + {accepted.length > 0 ? ( + + ) : ( +

Users that have an account will show up here

+ )} +
+
+
+ + + + + DanswerBot Users + + + {slack_users.length > 0 ? ( + + ) : ( +

Slack-only users will show up here

+ )} +
+
+
+
); }; @@ -215,6 +257,7 @@ const Page = () => { return (
} /> +
); diff --git a/web/src/app/assistants/mine/WrappedInputPrompts.tsx b/web/src/app/assistants/mine/WrappedInputPrompts.tsx deleted file mode 100644 index e39e695366c..00000000000 --- a/web/src/app/assistants/mine/WrappedInputPrompts.tsx +++ /dev/null @@ -1,51 +0,0 @@ -"use client"; -import SidebarWrapper from "../SidebarWrapper"; -import { ChatSession } from "@/app/chat/interfaces"; -import { Folder } from "@/app/chat/folders/interfaces"; -import { User } from "@/lib/types"; - -import { AssistantsPageTitle } from "../AssistantsPageTitle"; -import { useInputPrompts } from "@/app/admin/prompt-library/hooks"; -import { PromptSection } from "@/app/admin/prompt-library/promptSection"; - -export default function WrappedPrompts({ - chatSessions, - initiallyToggled, - folders, - openedFolders, -}: { - chatSessions: ChatSession[]; - folders: Folder[]; - initiallyToggled: boolean; - openedFolders?: { [key: number]: boolean }; -}) { - const { - data: promptLibrary, - error: promptLibraryError, - isLoading: promptLibraryIsLoading, - refreshInputPrompts: refreshPrompts, - } = useInputPrompts(false); - - return ( - -
- Prompt Gallery - -
-
- ); -} diff --git a/web/src/app/auth/impersonate/page.tsx b/web/src/app/auth/impersonate/page.tsx index 1a2c77d2cdb..bdc9b37fd2f 100644 --- a/web/src/app/auth/impersonate/page.tsx +++ b/web/src/app/auth/impersonate/page.tsx @@ -14,13 +14,9 @@ const ImpersonateSchema = Yup.object().shape({ export default function ImpersonatePage() { const router = useRouter(); - const { user, isLoadingUser, isCloudSuperuser } = useUser(); + const { user, isCloudSuperuser } = useUser(); const { popup, setPopup } = usePopup(); - if (isLoadingUser) { - return null; - } - if (!user) { redirect("/auth/login"); } diff --git a/web/src/app/auth/login/EmailPasswordForm.tsx b/web/src/app/auth/login/EmailPasswordForm.tsx index f89518571e4..75acb11ecfd 100644 --- a/web/src/app/auth/login/EmailPasswordForm.tsx +++ b/web/src/app/auth/login/EmailPasswordForm.tsx @@ -3,25 +3,25 @@ import { TextFormField } from "@/components/admin/connectors/Field"; import { usePopup } from "@/components/admin/connectors/Popup"; import { basicLogin, basicSignup } from "@/lib/user"; +import Cookies from "js-cookie"; import { Button } from "@/components/ui/button"; import { Form, Formik } from "formik"; -import { useRouter } from "next/navigation"; import * as Yup from "yup"; import { requestEmailVerification } from "../lib"; import { useState } from "react"; import { Spinner } from "@/components/Spinner"; - export function EmailPasswordForm({ isSignup = false, shouldVerify, referralSource, + nextUrl, }: { isSignup?: boolean; shouldVerify?: boolean; referralSource?: string; + nextUrl?: string | null; }) { - const router = useRouter(); const { popup, setPopup } = usePopup(); const [isWorking, setIsWorking] = useState(false); @@ -70,12 +70,17 @@ export function EmailPasswordForm({ const loginResponse = await basicLogin(values.email, values.password); if (loginResponse.ok) { - window.justLoggedIn = true; + Cookies.set("JUST_LOGGED_IN", "true", { expires: 1 }); +// window.justLoggedIn = true; if (isSignup && shouldVerify) { await requestEmailVerification(values.email); - router.push("/auth/waiting-on-verification"); + // Use window.location.href to force a full page reload, + // ensuring app re-initializes with the new state (including + // server-side provider values) + window.location.href = "/auth/waiting-on-verification"; } else { - router.push("/"); + // See above comment + window.location.href = nextUrl ? encodeURI(nextUrl) : "/"; } } else { setIsWorking(false); diff --git a/web/src/app/auth/login/page.tsx b/web/src/app/auth/login/page.tsx index 2a320ce3cf5..ca710fd4b56 100644 --- a/web/src/app/auth/login/page.tsx +++ b/web/src/app/auth/login/page.tsx @@ -23,6 +23,9 @@ const Page = async (props: { }) => { const searchParams = await props.searchParams; const autoRedirectDisabled = searchParams?.disableAutoRedirect === "true"; + const nextUrl = Array.isArray(searchParams?.next) + ? searchParams?.next[0] + : searchParams?.next || null; // catch cases where the backend is completely unreachable here // without try / catch, will just raise an exception and the page @@ -39,10 +42,6 @@ const Page = async (props: { console.log(`Some fetch failed for the login page - ${e}`); } - const nextUrl = Array.isArray(searchParams?.next) - ? searchParams?.next[0] - : searchParams?.next || null; - // simply take the user to the home page if Auth is disabled if (authTypeMetadata?.authType === "disabled") { return redirect("/"); @@ -102,12 +101,15 @@ const Page = async (props: { or
- +
Don't have an account?{" "} - + Create an account @@ -122,11 +124,14 @@ const Page = async (props: {
- +
Don't have an account?{" "} - + Create an account diff --git a/web/src/app/auth/signup/page.tsx b/web/src/app/auth/signup/page.tsx index 6f8007d157f..39f4688417f 100644 --- a/web/src/app/auth/signup/page.tsx +++ b/web/src/app/auth/signup/page.tsx @@ -16,7 +16,14 @@ import AuthFlowContainer from "@/components/auth/AuthFlowContainer"; import ReferralSourceSelector from "./ReferralSourceSelector"; import { Separator } from "@/components/ui/separator"; -const Page = async () => { +const Page = async (props: { + searchParams?: Promise<{ [key: string]: string | string[] | undefined }>; +}) => { + const searchParams = await props.searchParams; + const nextUrl = Array.isArray(searchParams?.next) + ? searchParams?.next[0] + : searchParams?.next || null; + // catch cases where the backend is completely unreachable here // without try / catch, will just raise an exception and the page // will not render @@ -88,12 +95,19 @@ const Page = async () => {
Already have an account?{" "} - + Log In diff --git a/web/src/app/chat/ChatBanner.tsx b/web/src/app/chat/ChatBanner.tsx index 59fc8bd32d5..3479b7f1e9e 100644 --- a/web/src/app/chat/ChatBanner.tsx +++ b/web/src/app/chat/ChatBanner.tsx @@ -60,7 +60,11 @@ export function ChatBanner() {
{ + const handleResize = () => { + setScreenSize({ + width: window.innerWidth, + height: window.innerHeight, + }); + }; + + window.addEventListener("resize", handleResize); + return () => window.removeEventListener("resize", handleResize); + }, []); + + return screenSize; + } + + const { height: screenHeight } = useScreenSize(); + + const getContainerHeight = () => { + if (autoScrollEnabled) return undefined; + + if (screenHeight < 600) return "20vh"; + if (screenHeight < 1200) return "30vh"; + return "40vh"; + }; // handle redirect if chat page is disabled // NOTE: this must be done here, in a client component since @@ -149,9 +176,11 @@ export function ChatPage({ // available in server-side components const settings = useContext(SettingsContext); const enterpriseSettings = settings?.enterpriseSettings; - if (settings?.settings?.chat_page_enabled === false) { - router.push("/search"); - } + + const [documentSidebarToggled, setDocumentSidebarToggled] = useState(false); + const [filtersToggled, setFiltersToggled] = useState(false); + + const [userSettingsToggled, setUserSettingsToggled] = useState(false); const { assistants: availableAssistants, finalAssistants } = useAssistants(); @@ -159,14 +188,13 @@ export function ChatPage({ !shouldShowWelcomeModal ); - const { user, isAdmin, isLoadingUser, refreshUser } = useUser(); - + const { user, isAdmin } = useUser(); + const slackChatId = searchParams.get("slackChatId"); const existingChatIdRaw = searchParams.get("chatId"); const [sendOnLoad, setSendOnLoad] = useState( searchParams.get(SEARCH_PARAM_NAMES.SEND_ON_LOAD) ); - const currentPersonaId = searchParams.get(SEARCH_PARAM_NAMES.PERSONA_ID); const modelVersionFromSearchParams = searchParams.get( SEARCH_PARAM_NAMES.STRUCTURED_MODEL ); @@ -252,6 +280,9 @@ export function ChatPage({ const [alternativeAssistant, setAlternativeAssistant] = useState(null); + const [presentingDocument, setPresentingDocument] = + useState(null); + const { visibleAssistants: assistants, recentAssistants, @@ -259,7 +290,7 @@ export function ChatPage({ refreshRecentAssistants, } = useAssistants(); - const liveAssistant = + const liveAssistant: Persona | undefined = alternativeAssistant || selectedAssistant || recentAssistants[0] || @@ -267,8 +298,20 @@ export function ChatPage({ availableAssistants[0]; const noAssistants = liveAssistant == null || liveAssistant == undefined; + + const availableSources = ccPairs.map((ccPair) => ccPair.source); + const [finalAvailableSources, finalAvailableDocumentSets] = + computeAvailableFilters({ + selectedPersona: availableAssistants.find( + (assistant) => assistant.id === liveAssistant?.id + ), + availableSources: availableSources, + availableDocumentSets: documentSets, + }); + // always set the model override for the chat session, when an assistant, llm provider, or user preference exists useEffect(() => { + if (noAssistants) return; const personaDefault = getLLMProviderOverrideForPersona( liveAssistant, llmProviders @@ -355,9 +398,7 @@ export function ChatPage({ textAreaRef.current?.focus(); // only clear things if we're going from one chat session to another - const isChatSessionSwitch = - chatSessionIdRef.current !== null && - existingChatSessionId !== priorChatSessionId; + const isChatSessionSwitch = existingChatSessionId !== priorChatSessionId; if (isChatSessionSwitch) { // de-select documents clearSelectedDocuments(); @@ -370,7 +411,7 @@ export function ChatPage({ // reset LLM overrides (based on chat session!) llmOverrideManager.updateModelOverrideForChatSession(selectedChatSession); - llmOverrideManager.setTemperature(null); + llmOverrideManager.updateTemperature(null); // remove uploaded files setCurrentMessageFiles([]); @@ -403,6 +444,7 @@ export function ChatPage({ } return; } + setIsReady(true); const shouldScrollToBottom = visibleRange.get(existingChatSessionId) === undefined || visibleRange.get(existingChatSessionId)?.end == 0; @@ -428,13 +470,14 @@ export function ChatPage({ loadedSessionId != null) && !currentChatAnswering() ) { - updateCompleteMessageDetail(chatSession.chat_session_id, newMessageMap); - const latestMessageId = newMessageHistory[newMessageHistory.length - 1]?.messageId; + setSelectedMessageForDocDisplay( latestMessageId !== undefined ? latestMessageId : null ); + + updateCompleteMessageDetail(chatSession.chat_session_id, newMessageMap); } setChatSessionSharedStatus(chatSession.shared_status); @@ -446,12 +489,13 @@ export function ChatPage({ } if (shouldScrollToBottom) { - if (!hasPerformedInitialScroll) { + if (!hasPerformedInitialScroll && autoScrollEnabled) { clientScrollToBottom(); - } else if (isChatSessionSwitch) { + } else if (isChatSessionSwitch && autoScrollEnabled) { clientScrollToBottom(true); } } + setIsFetchingChatMessages(false); // if this is a seeded chat, then kick off the AI message generation @@ -468,9 +512,12 @@ export function ChatPage({ }); // force re-name if the chat session doesn't have one if (!chatSession.description) { - await nameChatSession(existingChatSessionId, seededMessage); + await nameChatSession(existingChatSessionId); refreshChatSessions(); } + } else if (newMessageHistory.length === 2 && !chatSession.description) { + await nameChatSession(existingChatSessionId); + refreshChatSessions(); } } @@ -753,7 +800,7 @@ export function ChatPage({ useEffect(() => { async function fetchMaxTokens() { const response = await fetch( - `/api/chat/max-selected-document-tokens?persona_id=${liveAssistant.id}` + `/api/chat/max-selected-document-tokens?persona_id=${liveAssistant?.id}` ); if (response.ok) { const maxTokens = (await response.json()).max_tokens as number; @@ -827,11 +874,13 @@ export function ChatPage({ 0 )}px`; - scrollableDivRef?.current.scrollBy({ - left: 0, - top: Math.max(heightDifference, 0), - behavior: "smooth", - }); + if (autoScrollEnabled) { + scrollableDivRef?.current.scrollBy({ + left: 0, + top: Math.max(heightDifference, 0), + behavior: "smooth", + }); + } } previousHeight.current = newHeight; } @@ -873,11 +922,11 @@ export function ChatPage({ setHasPerformedInitialScroll(true); }, 100); } else { - console.log("All messages are already rendered, scrolling immediately"); // If all messages are already rendered, scroll immediately endDivRef.current.scrollIntoView({ behavior: fast ? "auto" : "smooth", }); + setHasPerformedInitialScroll(true); } }, 50); @@ -925,6 +974,17 @@ export function ChatPage({ } }; + useEffect(() => { + if ( + !personaIncludesRetrieval && + (!selectedDocuments || selectedDocuments.length === 0) && + documentSidebarToggled && + !filtersToggled + ) { + setDocumentSidebarToggled(false); + } + }, [chatSessionIdRef.current]); + useEffect(() => { adjustDocumentSidebarWidth(); // Adjust the width on initial render window.addEventListener("resize", adjustDocumentSidebarWidth); // Add resize event listener @@ -1020,16 +1080,25 @@ export function ChatPage({ updateCanContinue(false, frozenSessionId); if (currentChatState() != "input") { - setPopup({ - message: "Please wait for the response to complete", - type: "error", - }); + if (currentChatState() == "uploading") { + setPopup({ + message: "Please wait for the content to upload", + type: "error", + }); + } else { + setPopup({ + message: "Please wait for the response to complete", + type: "error", + }); + } return; } setAlternativeGeneratingAssistant(alternativeAssistantOverride); + clientScrollToBottom(); + let currChatSessionId: string; const isNewSession = chatSessionIdRef.current === null; const searchParamBasedChatSessionName = @@ -1201,7 +1270,6 @@ export function ChatPage({ if (!packet) { continue; } - if (!initialFetchDetails) { if (!Object.hasOwn(packet, "user_message_id")) { console.error( @@ -1276,7 +1344,7 @@ export function ChatPage({ if (Object.hasOwn(packet, "answer_piece")) { answer += (packet as AnswerPiecePacket).answer_piece; } else if (Object.hasOwn(packet, "top_documents")) { - documents = (packet as DocumentsResponse).top_documents; + documents = (packet as DocumentInfoPacket).top_documents; retrievalType = RetrievalType.Search; if (documents && documents.length > 0) { // point to the latest message (we don't know the messageId yet, which is why @@ -1373,8 +1441,7 @@ export function ChatPage({ type: error ? "error" : "assistant", retrievalType, query: finalMessage?.rephrased_query || query, - documents: - finalMessage?.context_docs?.top_documents || documents, + documents: documents, citations: finalMessage?.citations || {}, files: finalMessage?.files || aiMessageImages || [], toolCall: finalMessage?.tool_call || toolCall, @@ -1428,7 +1495,7 @@ export function ChatPage({ if (!searchParamBasedChatSessionName) { await new Promise((resolve) => setTimeout(resolve, 200)); - await nameChatSession(currChatSessionId, currMessage); + await nameChatSession(currChatSessionId); refreshChatSessions(); } @@ -1498,7 +1565,7 @@ export function ChatPage({ } }; - const handleImageUpload = (acceptedFiles: File[]) => { + const handleImageUpload = async (acceptedFiles: File[]) => { const [_, llmModel] = getFinalLLM( llmProviders, liveAssistant, @@ -1538,8 +1605,9 @@ export function ChatPage({ (file) => !tempFileDescriptors.some((newFile) => newFile.id === file.id) ); }; + updateChatState("uploading", currentSessionId()); - uploadFilesForChat(acceptedFiles).then(([files, error]) => { + await uploadFilesForChat(acceptedFiles).then(([files, error]) => { if (error) { setCurrentMessageFiles((prev) => removeTempFiles(prev)); setPopup({ @@ -1550,15 +1618,16 @@ export function ChatPage({ setCurrentMessageFiles((prev) => [...removeTempFiles(prev), ...files]); } }); + updateChatState("input", currentSessionId()); }; - const [showDocSidebar, setShowDocSidebar] = useState(false); // State to track if sidebar is open + const [showHistorySidebar, setShowHistorySidebar] = useState(false); // State to track if sidebar is open // Used to maintain a "time out" for history sidebar so our existing refs can have time to process change const [untoggled, setUntoggled] = useState(false); const [loadingError, setLoadingError] = useState(null); const explicitlyUntoggle = () => { - setShowDocSidebar(false); + setShowHistorySidebar(false); setUntoggled(true); setTimeout(() => { @@ -1577,7 +1646,7 @@ export function ChatPage({ toggle(); }; const removeToggle = () => { - setShowDocSidebar(false); + setShowHistorySidebar(false); toggle(false); }; @@ -1587,20 +1656,25 @@ export function ChatPage({ useSidebarVisibility({ toggledSidebar, sidebarElementRef, - showDocSidebar, - setShowDocSidebar, + showDocSidebar: showHistorySidebar, + setShowDocSidebar: setShowHistorySidebar, setToggled: removeToggle, mobile: settings?.isMobile, }); + const autoScrollEnabled = + user?.preferences?.auto_scroll == null + ? settings?.enterpriseSettings?.auto_scroll || false + : user?.preferences?.auto_scroll!; + useScrollonStream({ chatState: currentSessionChatState, scrollableDivRef, scrollDist, endDivRef, debounceNumber, - waitForScrollRef, mobile: settings?.isMobile, + enableAutoScroll: autoScrollEnabled, }); // Virtualization + Scrolling related effects and functions @@ -1750,6 +1824,13 @@ export function ChatPage({ liveAssistant ); }); + + useEffect(() => { + if (!retrievalEnabled) { + setDocumentSidebarToggled(false); + } + }, [retrievalEnabled]); + const [stackTraceModalContent, setStackTraceModalContent] = useState< string | null >(null); @@ -1758,7 +1839,41 @@ export function ChatPage({ const [settingsToggled, setSettingsToggled] = useState(false); const currentPersona = alternativeAssistant || liveAssistant; + useEffect(() => { + const handleSlackChatRedirect = async () => { + if (!slackChatId) return; + + // Set isReady to false before starting retrieval to display loading text + setIsReady(false); + + try { + const response = await fetch("/api/chat/seed-chat-session-from-slack", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + chat_session_id: slackChatId, + }), + }); + + if (!response.ok) { + throw new Error("Failed to seed chat from Slack"); + } + + const data = await response.json(); + router.push(data.redirect_url); + } catch (error) { + console.error("Error seeding chat from Slack:", error); + setPopup({ + message: "Failed to load chat from Slack", + type: "error", + }); + } + }; + handleSlackChatRedirect(); + }, [searchParams, router]); useEffect(() => { const handleKeyDown = (event: KeyboardEvent) => { if (event.metaKey || event.ctrlKey) { @@ -1789,14 +1904,36 @@ export function ChatPage({ setSharedChatSession(chatSession); }; const [documentSelection, setDocumentSelection] = useState(false); - const toggleDocumentSelectionAspects = () => { - setDocumentSelection((documentSelection) => !documentSelection); - setShowDocSidebar(false); + // const toggleDocumentSelectionAspects = () => { + // setDocumentSelection((documentSelection) => !documentSelection); + // setShowDocSidebar(false); + // }; + + const toggleDocumentSidebar = () => { + if (!documentSidebarToggled) { + setFiltersToggled(false); + setDocumentSidebarToggled(true); + } else if (!filtersToggled) { + setDocumentSidebarToggled(false); + } else { + setFiltersToggled(false); + } + }; + const toggleFilters = () => { + if (!documentSidebarToggled) { + setFiltersToggled(true); + setDocumentSidebarToggled(true); + } else if (filtersToggled) { + setDocumentSidebarToggled(false); + } else { + setFiltersToggled(true); + } }; interface RegenerationRequest { messageId: number; parentMessage: Message; + forceSearch?: boolean; } function createRegenerator(regenerationRequest: RegenerationRequest) { @@ -1806,21 +1943,27 @@ export function ChatPage({ modelOverRide, messageIdToResend: regenerationRequest.parentMessage.messageId, regenerationRequest, + forceSearch: regenerationRequest.forceSearch, }); }; } + if (noAssistants) + return ( + <> + + + + ); return ( <> - {showApiKeyModal && !shouldShowWelcomeModal ? ( + {showApiKeyModal && !shouldShowWelcomeModal && ( setShowApiKeyModal(false)} setPopup={setPopup} /> - ) : ( - noAssistants && )} {/* ChatPopup is a custom popup that displays a admin-specified message on initial user visit. @@ -1828,6 +1971,7 @@ export function ChatPage({ {popup} + {currentFeedback && ( )} - {settingsToggled && ( + {(settingsToggled || userSettingsToggled) && ( setSettingsToggled(false)} + onClose={() => { + setUserSettingsToggled(false); + setSettingsToggled(false); + }} /> )} + {retrievalEnabled && documentSidebarToggled && settings?.isMobile && ( +
+ + { + setDocumentSidebarToggled(false); + }} + selectedMessage={aiMessage} + selectedDocuments={selectedDocuments} + toggleDocumentSelection={toggleDocumentSelection} + clearSelectedDocuments={clearSelectedDocuments} + selectedDocumentTokens={selectedDocumentTokens} + maxTokens={maxTokens} + initialWidth={400} + isOpen={true} + /> + +
+ )} + {deletingChatSession && ( )} + {presentingDocument && ( + setPresentingDocument(null)} + /> + )} + {stackTraceModalContent && ( setStackTraceModalContent(null)} @@ -1928,7 +2110,7 @@ export function ChatPage({ duration-300 ease-in-out ${ - !untoggled && (showDocSidebar || toggledSidebar) + !untoggled && (showHistorySidebar || toggledSidebar) ? "opacity-100 w-[250px] translate-x-0" : "opacity-0 w-[200px] pointer-events-none -translate-x-10" }`} @@ -1942,7 +2124,7 @@ export function ChatPage({ ref={innerSidebarElementRef} toggleSidebar={toggleSidebar} toggled={toggledSidebar && !settings?.isMobile} - backgroundToggled={toggledSidebar || showDocSidebar} + backgroundToggled={toggledSidebar || showHistorySidebar} existingChats={chatSessions} currentChatSession={selectedChatSession} folders={folders} @@ -1954,16 +2136,64 @@ export function ChatPage({
+ {!settings?.isMobile && retrievalEnabled && ( +
+ setDocumentSidebarToggled(false)} + selectedMessage={aiMessage} + selectedDocuments={selectedDocuments} + toggleDocumentSelection={toggleDocumentSelection} + clearSelectedDocuments={clearSelectedDocuments} + selectedDocumentTokens={selectedDocumentTokens} + maxTokens={maxTokens} + initialWidth={400} + isOpen={documentSidebarToggled} + /> +
+ )}
-
+
{liveAssistant && ( setUserSettingsToggled(true)} + liveAssistant={liveAssistant} + onAssistantChange={onAssistantChange} sidebarToggled={toggledSidebar} reset={() => setMessage("")} page="chat" @@ -1974,12 +2204,12 @@ export function ChatPage({ } toggleSidebar={toggleSidebar} currentChatSession={selectedChatSession} + documentSidebarToggled={documentSidebarToggled} + llmOverrideManager={llmOverrideManager} /> )} - {documentSidebarInitialWidth !== undefined && - isReady && - !isLoadingUser ? ( + {documentSidebarInitialWidth !== undefined && isReady ? ( {({ getRootProps }) => (
@@ -1995,7 +2225,7 @@ export function ChatPage({ duration-300 ease-in-out h-full - ${toggledSidebar ? "w-[250px]" : "w-[0px]"} + ${toggledSidebar ? "w-[200px]" : "w-[0px]"} `} >
)} @@ -2005,9 +2235,55 @@ export function ChatPage({ {...getRootProps()} >
+ {liveAssistant && onAssistantChange && ( +
+ {!settings?.isMobile && ( +
+ )} + + + {!settings?.isMobile && ( +
+ )} +
+ )} + {/* ChatBanner is a custom banner that displays a admin-specified message at the top of the chat page. Oly used in the EE version of the app. */} @@ -2015,7 +2291,7 @@ export function ChatPage({ !isFetchingChatMessages && currentSessionChatState == "input" && !loadingError && ( -
+
{ + if ( + !documentSidebarToggled || + (documentSidebarToggled && + selectedMessageForDocDisplay === + message.messageId) + ) { + toggleDocumentSidebar(); + } + setSelectedMessageForDocDisplay( + message.messageId + ); + }} docs={message.documents} currentPersona={liveAssistant} alternativeAssistant={ @@ -2224,7 +2522,6 @@ export function ChatPage({ } messageId={message.messageId} content={message.message} - // content={message.message} files={message.files} query={ messageHistory[i]?.query || undefined @@ -2305,13 +2602,11 @@ export function ChatPage({ previousMessage && previousMessage.messageId ) { - onSubmit({ - messageIdToResend: - previousMessage.messageId, + createRegenerator({ + messageId: message.messageId, + parentMessage: parentMessage!, forceSearch: true, - alternativeAssistantOverride: - currentAlternativeAssistant, - }); + })(llmOverrideManager.llmOverride); } else { setPopup({ type: "error", @@ -2410,6 +2705,15 @@ export function ChatPage({ />
)} + {messageHistory.length > 0 && ( +
+ )} {/* Some padding at the bottom so the search bar has space at the bottom to not cover the last message*/}
@@ -2433,13 +2737,21 @@ export function ChatPage({
)} { + clearSelectedDocuments(); + }} + removeFilters={() => { + filterManager.setSelectedSources([]); + filterManager.setSelectedTags([]); + filterManager.setSelectedDocumentSets([]); + setDocumentSidebarToggled(false); + }} showConfigureAPIKey={() => setShowApiKeyModal(true) } chatState={currentSessionChatState} stopGenerating={stopGenerating} openModelSettings={() => setSettingsToggled(true)} - inputPrompts={userInputPrompts} showDocs={() => setDocumentSelection(true)} selectedDocuments={selectedDocuments} // assistant stuff @@ -2455,6 +2767,9 @@ export function ChatPage({ llmOverrideManager={llmOverrideManager} files={currentMessageFiles} setFiles={setCurrentMessageFiles} + toggleFilters={ + retrievalEnabled ? toggleFilters : undefined + } handleFileUpload={handleImageUpload} textAreaRef={textAreaRef} chatSessionId={chatSessionIdRef.current!} @@ -2485,6 +2800,23 @@ export function ChatPage({
+ {!settings?.isMobile && ( +
+ )}
)} @@ -2493,7 +2825,11 @@ export function ChatPage({
@@ -2502,22 +2838,10 @@ export function ChatPage({ )}
- +
+ {/* Right Sidebar - DocumentSidebar */}
- setDocumentSelection(false)} - selectedMessage={aiMessage} - selectedDocuments={selectedDocuments} - toggleDocumentSelection={toggleDocumentSelection} - clearSelectedDocuments={clearSelectedDocuments} - selectedDocumentTokens={selectedDocumentTokens} - maxTokens={maxTokens} - isLoading={isFetchingChatMessages} - isOpen={documentSelection} - /> ); } diff --git a/web/src/app/chat/RegenerateOption.tsx b/web/src/app/chat/RegenerateOption.tsx index dd92bee0dbc..48ac766aadd 100644 --- a/web/src/app/chat/RegenerateOption.tsx +++ b/web/src/app/chat/RegenerateOption.tsx @@ -14,7 +14,6 @@ import { destructureValue, getFinalLLM, structureValue } from "@/lib/llm/utils"; import { useState } from "react"; import { Hoverable } from "@/components/Hoverable"; import { Popover } from "@/components/popover/Popover"; -import { StarFeedback } from "@/components/icons/icons"; import { IconType } from "react-icons"; import { FiRefreshCw } from "react-icons/fi"; diff --git a/web/src/app/chat/documentSidebar/ChatDocumentDisplay.tsx b/web/src/app/chat/documentSidebar/ChatDocumentDisplay.tsx index 85ac429c497..0c39f631422 100644 --- a/web/src/app/chat/documentSidebar/ChatDocumentDisplay.tsx +++ b/web/src/app/chat/documentSidebar/ChatDocumentDisplay.tsx @@ -1,133 +1,129 @@ -import { HoverPopup } from "@/components/HoverPopup"; import { SourceIcon } from "@/components/SourceIcon"; -import { PopupSpec } from "@/components/admin/connectors/Popup"; import { DanswerDocument } from "@/lib/search/interfaces"; -import { FiInfo, FiRadio } from "react-icons/fi"; +import { FiTag } from "react-icons/fi"; import { DocumentSelector } from "./DocumentSelector"; -import { - DocumentMetadataBlock, - buildDocumentSummaryDisplay, -} from "@/components/search/DocumentDisplay"; -import { InternetSearchIcon } from "@/components/InternetSearchIcon"; +import { buildDocumentSummaryDisplay } from "@/components/search/DocumentDisplay"; +import { DocumentUpdatedAtBadge } from "@/components/search/DocumentUpdatedAtBadge"; +import { MetadataBadge } from "@/components/MetadataBadge"; +import { WebResultIcon } from "@/components/WebResultIcon"; +import { Dispatch, SetStateAction } from "react"; +import { ValidSources } from "@/lib/types"; interface DocumentDisplayProps { + closeSidebar: () => void; document: DanswerDocument; - queryEventId: number | null; - isAIPick: boolean; + modal?: boolean; isSelected: boolean; handleSelect: (documentId: string) => void; - setPopup: (popupSpec: PopupSpec | null) => void; tokenLimitReached: boolean; + setPresentingDocument: Dispatch>; +} + +export function DocumentMetadataBlock({ + modal, + document, +}: { + modal?: boolean; + document: DanswerDocument; +}) { + const MAX_METADATA_ITEMS = 3; + const metadataEntries = Object.entries(document.metadata); + + return ( +
+ {document.updated_at && ( + + )} + + {metadataEntries.length > 0 && ( + <> +
+
+ {metadataEntries + .slice(0, MAX_METADATA_ITEMS) + .map(([key, value], index) => ( + + ))} + {metadataEntries.length > MAX_METADATA_ITEMS && ( + ... + )} +
+ + )} +
+ ); } export function ChatDocumentDisplay({ + closeSidebar, document, - queryEventId, - isAIPick, + modal, isSelected, handleSelect, - setPopup, tokenLimitReached, + setPresentingDocument, }: DocumentDisplayProps) { const isInternet = document.is_internet; - // Consider reintroducing null scored docs in the future if (document.score === null) { return null; } + const handleViewFile = async () => { + if (document.source_type == ValidSources.File && setPresentingDocument) { + setPresentingDocument(document); + } else if (document.link) { + window.open(document.link, "_blank"); + } + }; + return ( -
-
- +
+
); diff --git a/web/src/app/chat/documentSidebar/ChatFilters.tsx b/web/src/app/chat/documentSidebar/ChatFilters.tsx new file mode 100644 index 00000000000..a5761e8108a --- /dev/null +++ b/web/src/app/chat/documentSidebar/ChatFilters.tsx @@ -0,0 +1,197 @@ +import { DanswerDocument } from "@/lib/search/interfaces"; +import { ChatDocumentDisplay } from "./ChatDocumentDisplay"; +import { usePopup } from "@/components/admin/connectors/Popup"; +import { removeDuplicateDocs } from "@/lib/documentUtils"; +import { Message } from "../interfaces"; +import { + Dispatch, + ForwardedRef, + forwardRef, + SetStateAction, + useEffect, + useState, +} from "react"; +import { FilterManager } from "@/lib/hooks"; +import { CCPairBasicInfo, DocumentSet, Tag } from "@/lib/types"; +import { SourceSelector } from "../shared_chat_search/SearchFilters"; +import { XIcon } from "@/components/icons/icons"; + +interface ChatFiltersProps { + filterManager: FilterManager; + closeSidebar: () => void; + selectedMessage: Message | null; + selectedDocuments: DanswerDocument[] | null; + toggleDocumentSelection: (document: DanswerDocument) => void; + clearSelectedDocuments: () => void; + selectedDocumentTokens: number; + maxTokens: number; + initialWidth: number; + isOpen: boolean; + modal: boolean; + ccPairs: CCPairBasicInfo[]; + tags: Tag[]; + documentSets: DocumentSet[]; + showFilters: boolean; + setPresentingDocument: Dispatch>; +} + +export const ChatFilters = forwardRef( + ( + { + closeSidebar, + modal, + selectedMessage, + selectedDocuments, + filterManager, + toggleDocumentSelection, + clearSelectedDocuments, + selectedDocumentTokens, + maxTokens, + initialWidth, + isOpen, + ccPairs, + tags, + setPresentingDocument, + documentSets, + showFilters, + }, + ref: ForwardedRef + ) => { + const { popup, setPopup } = usePopup(); + const [delayedSelectedDocumentCount, setDelayedSelectedDocumentCount] = + useState(0); + + useEffect(() => { + const timer = setTimeout( + () => { + setDelayedSelectedDocumentCount(selectedDocuments?.length || 0); + }, + selectedDocuments?.length == 0 ? 1000 : 0 + ); + + return () => clearTimeout(timer); + }, [selectedDocuments]); + + const selectedDocumentIds = + selectedDocuments?.map((document) => document.document_id) || []; + + const currentDocuments = selectedMessage?.documents || null; + const dedupedDocuments = removeDuplicateDocs(currentDocuments || []); + + const tokenLimitReached = selectedDocumentTokens > maxTokens - 75; + + const hasSelectedDocuments = selectedDocumentIds.length > 0; + + return ( +
{ + if (e.target === e.currentTarget) { + closeSidebar(); + } + }} + > +
+
+ {popup} +
+

+ {showFilters ? "Filters" : "Sources"} +

+ +
+
+
+ {showFilters ? ( + ccPair.source)} + availableTags={tags} + /> + ) : ( + <> + {dedupedDocuments.length > 0 ? ( + dedupedDocuments.map((document, ind) => ( +
+ { + toggleDocumentSelection( + dedupedDocuments.find( + (doc) => doc.document_id === documentId + )! + ); + }} + tokenLimitReached={tokenLimitReached} + /> +
+ )) + ) : ( +
+ )} + + )} +
+
+ {!showFilters && ( +
+ +
+ )} +
+
+ ); + } +); + +ChatFilters.displayName = "ChatFilters"; diff --git a/web/src/app/chat/documentSidebar/DocumentSelector.tsx b/web/src/app/chat/documentSidebar/DocumentSelector.tsx index 2153ce5bdc7..ac94a410a9d 100644 --- a/web/src/app/chat/documentSidebar/DocumentSelector.tsx +++ b/web/src/app/chat/documentSidebar/DocumentSelector.tsx @@ -12,7 +12,8 @@ export function DocumentSelector({ }) { const [popupDisabled, setPopupDisabled] = useState(false); - function onClick() { + function onClick(e: React.MouseEvent) { + e.stopPropagation(); if (!isDisabled) { setPopupDisabled(true); handleSelect(); diff --git a/web/src/app/chat/documentSidebar/DocumentSidebar.tsx b/web/src/app/chat/documentSidebar/DocumentSidebar.tsx deleted file mode 100644 index 021c2398157..00000000000 --- a/web/src/app/chat/documentSidebar/DocumentSidebar.tsx +++ /dev/null @@ -1,168 +0,0 @@ -import { DanswerDocument } from "@/lib/search/interfaces"; -import Text from "@/components/ui/text"; -import { ChatDocumentDisplay } from "./ChatDocumentDisplay"; -import { usePopup } from "@/components/admin/connectors/Popup"; -import { removeDuplicateDocs } from "@/lib/documentUtils"; -import { Message } from "../interfaces"; -import { ForwardedRef, forwardRef } from "react"; -import { Separator } from "@/components/ui/separator"; - -interface DocumentSidebarProps { - closeSidebar: () => void; - selectedMessage: Message | null; - selectedDocuments: DanswerDocument[] | null; - toggleDocumentSelection: (document: DanswerDocument) => void; - clearSelectedDocuments: () => void; - selectedDocumentTokens: number; - maxTokens: number; - isLoading: boolean; - initialWidth: number; - isOpen: boolean; -} - -export const DocumentSidebar = forwardRef( - ( - { - closeSidebar, - selectedMessage, - selectedDocuments, - toggleDocumentSelection, - clearSelectedDocuments, - selectedDocumentTokens, - maxTokens, - isLoading, - initialWidth, - isOpen, - }, - ref: ForwardedRef - ) => { - const { popup, setPopup } = usePopup(); - - const selectedDocumentIds = - selectedDocuments?.map((document) => document.document_id) || []; - - const currentDocuments = selectedMessage?.documents || null; - const dedupedDocuments = removeDuplicateDocs(currentDocuments || []); - - // NOTE: do not allow selection if less than 75 tokens are left - // this is to prevent the case where they are able to select the doc - // but it basically is unused since it's truncated right at the very - // start of the document (since title + metadata + misc overhead) takes up - // space - const tokenLimitReached = selectedDocumentTokens > maxTokens - 75; - - return ( -
{ - if (e.target === e.currentTarget) { - closeSidebar(); - } - }} - > -
-
- {popup} -
- {dedupedDocuments.length} Document - {dedupedDocuments.length > 1 ? "s" : ""} -

- Select to add to continuous context - - Learn more - -

-
- - - - {currentDocuments ? ( -
- {dedupedDocuments.length > 0 ? ( - dedupedDocuments.map((document, ind) => ( -
- { - toggleDocumentSelection( - dedupedDocuments.find( - (document) => document.document_id === documentId - )! - ); - }} - tokenLimitReached={tokenLimitReached} - /> -
- )) - ) : ( -
- No documents found for the query. -
- )} -
- ) : ( - !isLoading && ( -
- - When you run ask a question, the retrieved documents will - show up here! - -
- ) - )} -
- -
-
- - - -
-
-
- ); - } -); - -DocumentSidebar.displayName = "DocumentSidebar"; diff --git a/web/src/app/chat/input/ChatInputBar.tsx b/web/src/app/chat/input/ChatInputBar.tsx index 9dd3d5274c4..0909d054594 100644 --- a/web/src/app/chat/input/ChatInputBar.tsx +++ b/web/src/app/chat/input/ChatInputBar.tsx @@ -1,13 +1,9 @@ import React, { useContext, useEffect, useRef, useState } from "react"; -import { FiPlusCircle, FiPlus, FiInfo, FiX } from "react-icons/fi"; +import { FiPlusCircle, FiPlus, FiInfo, FiX, FiSearch } from "react-icons/fi"; import { ChatInputOption } from "./ChatInputOption"; import { Persona } from "@/app/admin/assistants/interfaces"; -import { InputPrompt } from "@/app/admin/prompt-library/interfaces"; -import { - FilterManager, - getDisplayNameForModel, - LlmOverrideManager, -} from "@/lib/hooks"; + +import { FilterManager, LlmOverrideManager } from "@/lib/hooks"; import { SelectedFilterDisplay } from "./SelectedFilterDisplay"; import { useChatContext } from "@/components/context/ChatContext"; import { getFinalLLM } from "@/lib/llm/utils"; @@ -18,16 +14,11 @@ import { } from "../files/InputBarPreview"; import { AssistantsIconSkeleton, - CpuIconSkeleton, FileIcon, SendIcon, StopGeneratingIcon, } from "@/components/icons/icons"; -import { IconType } from "react-icons"; -import Popup from "../../../components/popup/Popup"; -import { LlmTab } from "../modal/configuration/LlmTab"; -import { AssistantsTab } from "../modal/configuration/AssistantsTab"; -import { DanswerDocument } from "@/lib/search/interfaces"; +import { DanswerDocument, SourceMetadata } from "@/lib/search/interfaces"; import { AssistantIcon } from "@/components/assistants/AssistantIcon"; import { Tooltip, @@ -40,10 +31,49 @@ import { SettingsContext } from "@/components/settings/SettingsProvider"; import { ChatState } from "../types"; import UnconfiguredProviderText from "@/components/chat_search/UnconfiguredProviderText"; import { useAssistants } from "@/components/context/AssistantsContext"; +import AnimatedToggle from "@/components/search/SearchBar"; +import { Popup } from "@/components/admin/connectors/Popup"; +import { AssistantsTab } from "../modal/configuration/AssistantsTab"; +import { IconType } from "react-icons"; +import { LlmTab } from "../modal/configuration/LlmTab"; +import { XIcon } from "lucide-react"; +import { FilterPills } from "./FilterPills"; +import { Tag } from "@/lib/types"; +import FiltersDisplay from "./FilterDisplay"; const MAX_INPUT_HEIGHT = 200; +interface ChatInputBarProps { + removeFilters: () => void; + removeDocs: () => void; + openModelSettings: () => void; + showDocs: () => void; + showConfigureAPIKey: () => void; + selectedDocuments: DanswerDocument[]; + message: string; + setMessage: (message: string) => void; + stopGenerating: () => void; + onSubmit: () => void; + filterManager: FilterManager; + llmOverrideManager: LlmOverrideManager; + chatState: ChatState; + alternativeAssistant: Persona | null; + // assistants + selectedAssistant: Persona; + setSelectedAssistant: (assistant: Persona) => void; + setAlternativeAssistant: (alternativeAssistant: Persona | null) => void; + + files: FileDescriptor[]; + setFiles: (files: FileDescriptor[]) => void; + handleFileUpload: (files: File[]) => void; + textAreaRef: React.RefObject; + chatSessionId?: string; + toggleFilters?: () => void; +} + export function ChatInputBar({ + removeFilters, + removeDocs, openModelSettings, showDocs, showConfigureAPIKey, @@ -67,30 +97,8 @@ export function ChatInputBar({ textAreaRef, alternativeAssistant, chatSessionId, - inputPrompts, -}: { - showConfigureAPIKey: () => void; - openModelSettings: () => void; - chatState: ChatState; - stopGenerating: () => void; - showDocs: () => void; - selectedDocuments: DanswerDocument[]; - setAlternativeAssistant: (alternativeAssistant: Persona | null) => void; - setSelectedAssistant: (assistant: Persona) => void; - inputPrompts: InputPrompt[]; - message: string; - setMessage: (message: string) => void; - onSubmit: () => void; - filterManager: FilterManager; - llmOverrideManager: LlmOverrideManager; - selectedAssistant: Persona; - alternativeAssistant: Persona | null; - files: FileDescriptor[]; - setFiles: (files: FileDescriptor[]) => void; - handleFileUpload: (files: File[]) => void; - textAreaRef: React.RefObject; - chatSessionId?: string; -}) { + toggleFilters, +}: ChatInputBarProps) { useEffect(() => { const textarea = textAreaRef.current; if (textarea) { @@ -127,7 +135,6 @@ export function ChatInputBar({ const suggestionsRef = useRef(null); const [showSuggestions, setShowSuggestions] = useState(false); - const [showPrompts, setShowPrompts] = useState(false); const interactionsRef = useRef(null); @@ -136,19 +143,6 @@ export function ChatInputBar({ setTabbingIconIndex(0); }; - const hidePrompts = () => { - setTimeout(() => { - setShowPrompts(false); - }, 50); - - setTabbingIconIndex(0); - }; - - const updateInputPrompt = (prompt: InputPrompt) => { - hidePrompts(); - setMessage(`${prompt.content}`); - }; - useEffect(() => { const handleClickOutside = (event: MouseEvent) => { if ( @@ -158,7 +152,6 @@ export function ChatInputBar({ !interactionsRef.current.contains(event.target as Node)) ) { hideSuggestions(); - hidePrompts(); } }; document.addEventListener("mousedown", handleClickOutside); @@ -188,24 +181,10 @@ export function ChatInputBar({ } }; - const handlePromptInput = (text: string) => { - if (!text.startsWith("/")) { - hidePrompts(); - } else { - const promptMatch = text.match(/(?:\s|^)\/(\w*)$/); - if (promptMatch) { - setShowPrompts(true); - } else { - hidePrompts(); - } - } - }; - const handleInputChange = (event: React.ChangeEvent) => { const text = event.target.value; setMessage(text); handleAssistantInput(text); - handlePromptInput(text); }; const assistantTagOptions = assistantOptions.filter((assistant) => @@ -217,49 +196,26 @@ export function ChatInputBar({ ) ); - const filteredPrompts = inputPrompts.filter( - (prompt) => - prompt.active && - prompt.prompt.toLowerCase().startsWith( - message - .slice(message.lastIndexOf("/") + 1) - .split(/\s/)[0] - .toLowerCase() - ) - ); - const [tabbingIconIndex, setTabbingIconIndex] = useState(0); const handleKeyDown = (e: React.KeyboardEvent) => { if ( - ((showSuggestions && assistantTagOptions.length > 0) || showPrompts) && + showSuggestions && + assistantTagOptions.length > 0 && (e.key === "Tab" || e.key == "Enter") ) { e.preventDefault(); - if ( - (tabbingIconIndex == assistantTagOptions.length && showSuggestions) || - (tabbingIconIndex == filteredPrompts.length && showPrompts) - ) { - if (showPrompts) { - window.open("/prompts", "_self"); - } else { - window.open("/assistants/new", "_self"); - } + if (tabbingIconIndex == assistantTagOptions.length && showSuggestions) { + window.open("/assistants/new", "_self"); } else { - if (showPrompts) { - const uppity = - filteredPrompts[tabbingIconIndex >= 0 ? tabbingIconIndex : 0]; - updateInputPrompt(uppity); - } else { - const option = - assistantTagOptions[tabbingIconIndex >= 0 ? tabbingIconIndex : 0]; - - updatedTaggedAssistant(option); - } + const option = + assistantTagOptions[tabbingIconIndex >= 0 ? tabbingIconIndex : 0]; + + updatedTaggedAssistant(option); } } - if (!showPrompts && !showSuggestions) { + if (!showSuggestions) { return; } @@ -267,10 +223,7 @@ export function ChatInputBar({ e.preventDefault(); setTabbingIconIndex((tabbingIconIndex) => - Math.min( - tabbingIconIndex + 1, - showPrompts ? filteredPrompts.length : assistantTagOptions.length - ) + Math.min(tabbingIconIndex + 1, assistantTagOptions.length) ); } else if (e.key === "ArrowUp") { e.preventDefault(); @@ -331,48 +284,9 @@ export function ChatInputBar({
)} - {showPrompts && ( -
-
- {filteredPrompts.map((currentPrompt, index) => ( - - ))} - - - -

Create a new prompt

-
-
-
- )} - -
+ {/*
-
+
*/} @@ -427,18 +341,24 @@ export function ChatInputBar({
)} + {(selectedDocuments.length > 0 || files.length > 0) && (
-
+
{selectedDocuments.length > 0 && ( )} {files.map((file) => ( @@ -515,7 +435,6 @@ export function ChatInputBar({ onKeyDown={(event) => { if ( event.key === "Enter" && - !showPrompts && !showSuggestions && !event.shiftKey && !(event.nativeEvent as any).isComposing @@ -529,72 +448,6 @@ export function ChatInputBar({ suppressContentEditableWarning={true} />
- ( - { - setSelectedAssistant(assistant); - close(); - }} - /> - )} - flexPriority="shrink" - position="top" - mobilePosition="top-right" - > - - - ( - - )} - position="top" - > - - - + {toggleFilters && ( + + )} + {(filterManager.selectedSources.length > 0 || + filterManager.selectedDocumentSets.length > 0 || + filterManager.selectedTags.length > 0 || + filterManager.timeRange) && + toggleFilters && ( + + )}
diff --git a/web/src/app/chat/input/FilterDisplay.tsx b/web/src/app/chat/input/FilterDisplay.tsx new file mode 100644 index 00000000000..cc20266f9bc --- /dev/null +++ b/web/src/app/chat/input/FilterDisplay.tsx @@ -0,0 +1,109 @@ +import React from "react"; +import { XIcon } from "lucide-react"; + +import { FilterPills } from "./FilterPills"; +import { SourceMetadata } from "@/lib/search/interfaces"; +import { FilterManager } from "@/lib/hooks"; +import { Tag } from "@/lib/types"; + +interface FiltersDisplayProps { + filterManager: FilterManager; + toggleFilters: () => void; +} +export default function FiltersDisplay({ + filterManager, + toggleFilters, +}: FiltersDisplayProps) { + return ( +
+ {(() => { + const allFilters = [ + ...filterManager.selectedSources, + ...filterManager.selectedDocumentSets, + ...filterManager.selectedTags, + ...(filterManager.timeRange ? [filterManager.timeRange] : []), + ]; + const filtersToShow = allFilters.slice(0, 2); + const remainingFilters = allFilters.length - 2; + + return ( + <> + {filtersToShow.map((filter, index) => { + if (typeof filter === "object" && "displayName" in filter) { + return ( + + key={index} + item={filter} + itemToString={(source) => source.displayName} + onRemove={(source) => + filterManager.setSelectedSources((prev) => + prev.filter( + (s) => s.internalName !== source.internalName + ) + ) + } + toggleFilters={toggleFilters} + /> + ); + } else if (typeof filter === "string") { + return ( + + key={index} + item={filter} + itemToString={(set) => set} + onRemove={(set) => + filterManager.setSelectedDocumentSets((prev) => + prev.filter((s) => s !== set) + ) + } + toggleFilters={toggleFilters} + /> + ); + } else if ("tag_key" in filter) { + return ( + + key={index} + item={filter} + itemToString={(tag) => `${tag.tag_key}:${tag.tag_value}`} + onRemove={(tag) => + filterManager.setSelectedTags((prev) => + prev.filter( + (t) => + t.tag_key !== tag.tag_key || + t.tag_value !== tag.tag_value + ) + ) + } + toggleFilters={toggleFilters} + /> + ); + } else if ("from" in filter && "to" in filter) { + return ( +
+ + {filter.from.toLocaleDateString()} -{" "} + {filter.to.toLocaleDateString()} + + filterManager.setTimeRange(null)} + size={16} + className="ml-2 text-text-400 hover:text-text-600 cursor-pointer" + /> +
+ ); + } + })} + {remainingFilters > 0 && ( +
+ +{remainingFilters} more +
+ )} + + ); + })()} +
+ ); +} diff --git a/web/src/app/chat/input/FilterPills.tsx b/web/src/app/chat/input/FilterPills.tsx new file mode 100644 index 00000000000..4212eefaa80 --- /dev/null +++ b/web/src/app/chat/input/FilterPills.tsx @@ -0,0 +1,39 @@ +import React from "react"; +import { XIcon } from "lucide-react"; +import { SourceMetadata } from "@/lib/search/interfaces"; +import { Tag } from "@/lib/types"; + +type FilterItem = SourceMetadata | string | Tag; + +interface FilterPillsProps { + item: T; + itemToString: (item: T) => string; + onRemove: (item: T) => void; + toggleFilters?: () => void; +} + +export function FilterPills({ + item, + itemToString, + onRemove, + toggleFilters, +}: FilterPillsProps) { + return ( + + ); +} diff --git a/web/src/app/chat/lib.tsx b/web/src/app/chat/lib.tsx index a64c605a095..2be44379ed5 100644 --- a/web/src/app/chat/lib.tsx +++ b/web/src/app/chat/lib.tsx @@ -2,6 +2,7 @@ import { AnswerPiecePacket, DanswerDocument, Filters, + DocumentInfoPacket, StreamStopInfo, } from "@/lib/search/interfaces"; import { handleSSEStream } from "@/lib/search/streamingUtils"; @@ -102,6 +103,7 @@ export type PacketType = | ToolCallMetadata | BackendMessage | AnswerPiecePacket + | DocumentInfoPacket | DocumentsResponse | FileChatDisplay | StreamingError @@ -147,7 +149,6 @@ export async function* sendMessage({ }): AsyncGenerator { const documentsAreSelected = selectedDocumentIds && selectedDocumentIds.length > 0; - const body = JSON.stringify({ alternate_assistant_id: alternateAssistantId, chat_session_id: chatSessionId, @@ -203,7 +204,7 @@ export async function* sendMessage({ yield* handleSSEStream(response); } -export async function nameChatSession(chatSessionId: string, message: string) { +export async function nameChatSession(chatSessionId: string) { const response = await fetch("/api/chat/rename-chat-session", { method: "PUT", headers: { @@ -212,7 +213,6 @@ export async function nameChatSession(chatSessionId: string, message: string) { body: JSON.stringify({ chat_session_id: chatSessionId, name: null, - first_message: message, }), }); return response; @@ -263,7 +263,6 @@ export async function renameChatSession( body: JSON.stringify({ chat_session_id: chatSessionId, name: newName, - first_message: null, }), }); return response; @@ -641,14 +640,15 @@ export async function useScrollonStream({ endDivRef, debounceNumber, mobile, + enableAutoScroll, }: { chatState: ChatState; scrollableDivRef: RefObject; - waitForScrollRef: RefObject; scrollDist: MutableRefObject; endDivRef: RefObject; debounceNumber: number; mobile?: boolean; + enableAutoScroll?: boolean; }) { const mobileDistance = 900; // distance that should "engage" the scroll const desktopDistance = 500; // distance that should "engage" the scroll @@ -661,6 +661,10 @@ export async function useScrollonStream({ const previousScroll = useRef(0); useEffect(() => { + if (!enableAutoScroll) { + return; + } + if (chatState != "input" && scrollableDivRef && scrollableDivRef.current) { const newHeight: number = scrollableDivRef.current?.scrollTop!; const heightDifference = newHeight - previousScroll.current; @@ -718,7 +722,7 @@ export async function useScrollonStream({ // scroll on end of stream if within distance useEffect(() => { - if (scrollableDivRef?.current && chatState == "input") { + if (scrollableDivRef?.current && chatState == "input" && enableAutoScroll) { if (scrollDist.current < distance - 50) { scrollableDivRef?.current?.scrollBy({ left: 0, diff --git a/web/src/app/chat/message/MemoizedTextComponents.tsx b/web/src/app/chat/message/MemoizedTextComponents.tsx index 9ab0e28e3ca..efdce4ca86b 100644 --- a/web/src/app/chat/message/MemoizedTextComponents.tsx +++ b/web/src/app/chat/message/MemoizedTextComponents.tsx @@ -1,8 +1,50 @@ import { Citation } from "@/components/search/results/Citation"; +import { WebResultIcon } from "@/components/WebResultIcon"; +import { LoadedDanswerDocument } from "@/lib/search/interfaces"; +import { getSourceMetadata, SOURCE_METADATA_MAP } from "@/lib/sources"; +import { ValidSources } from "@/lib/types"; import React, { memo } from "react"; +import isEqual from "lodash/isEqual"; +import { SlackIcon } from "@/components/icons/icons"; +import { SourceIcon } from "@/components/SourceIcon"; + +export const MemoizedAnchor = memo( + ({ docs, updatePresentingDocument, children }: any) => { + const value = children?.toString(); + if (value?.startsWith("[") && value?.endsWith("]")) { + const match = value.match(/\[(\d+)\]/); + if (match) { + const index = parseInt(match[1], 10) - 1; + const associatedDoc = docs && docs[index]; + + const url = associatedDoc?.link + ? new URL(associatedDoc.link).origin + "/favicon.ico" + : ""; + + const icon = ( + + ); + + return ( + + {children} + + ); + } + } + return ( + + {children} + + ); + } +); export const MemoizedLink = memo((props: any) => { - const { node, ...rest } = props; + const { node, document, updatePresentingDocument, ...rest } = props; const value = rest.children; if (value?.toString().startsWith("*")) { @@ -10,24 +52,39 @@ export const MemoizedLink = memo((props: any) => {
); } else if (value?.toString().startsWith("[")) { - return {rest.children}; - } else { return ( - - rest.href ? window.open(rest.href, "_blank") : undefined - } - className="cursor-pointer text-link hover:text-link-hover" + {rest.children} - + ); } -}); -export const MemoizedParagraph = memo(({ ...props }: any) => { - return

; + return ( + rest.href && window.open(rest.href, "_blank")} + className="cursor-pointer text-link hover:text-link-hover" + > + {rest.children} + + ); }); +export const MemoizedParagraph = memo( + function MemoizedParagraph({ children }: any) { + return

{children}

; + }, + (prevProps, nextProps) => { + const areEqual = isEqual(prevProps.children, nextProps.children); + return areEqual; + } +); + +MemoizedAnchor.displayName = "MemoizedAnchor"; MemoizedLink.displayName = "MemoizedLink"; MemoizedParagraph.displayName = "MemoizedParagraph"; diff --git a/web/src/app/chat/message/Messages.tsx b/web/src/app/chat/message/Messages.tsx index cc4f9c9cac8..75e583cfc59 100644 --- a/web/src/app/chat/message/Messages.tsx +++ b/web/src/app/chat/message/Messages.tsx @@ -8,14 +8,24 @@ import { FiGlobe, } from "react-icons/fi"; import { FeedbackType } from "../types"; -import React, { useContext, useEffect, useMemo, useRef, useState } from "react"; +import React, { + memo, + ReactNode, + useCallback, + useContext, + useEffect, + useMemo, + useRef, + useState, +} from "react"; import ReactMarkdown from "react-markdown"; import { DanswerDocument, FilteredDanswerDocument, + LoadedDanswerDocument, } from "@/lib/search/interfaces"; import { SearchSummary } from "./SearchSummary"; -import { SourceIcon } from "@/components/SourceIcon"; + import { SkippedSearch } from "./SkippedSearch"; import remarkGfm from "remark-gfm"; import { CopyButton } from "@/components/CopyButton"; @@ -36,8 +46,6 @@ import "prismjs/themes/prism-tomorrow.css"; import "./custom-code-styles.css"; import { Persona } from "@/app/admin/assistants/interfaces"; import { AssistantIcon } from "@/components/assistants/AssistantIcon"; -import { Citation } from "@/components/search/results/Citation"; -import { DocumentMetadataBlock } from "@/components/search/DocumentDisplay"; import { LikeFeedback, DislikeFeedback } from "@/components/icons/icons"; import { @@ -52,16 +60,18 @@ import { TooltipTrigger, } from "@/components/ui/tooltip"; import { useMouseTracking } from "./hooks"; -import { InternetSearchIcon } from "@/components/InternetSearchIcon"; import { SettingsContext } from "@/components/settings/SettingsProvider"; import GeneratingImageDisplay from "../tools/GeneratingImageDisplay"; import RegenerateOption from "../RegenerateOption"; import { LlmOverride } from "@/lib/hooks"; import { ContinueGenerating } from "./ContinueMessage"; -import { MemoizedLink, MemoizedParagraph } from "./MemoizedTextComponents"; +import { MemoizedAnchor, MemoizedParagraph } from "./MemoizedTextComponents"; import { extractCodeText } from "./codeUtils"; import ToolResult from "../../../components/tools/ToolResult"; import CsvContent from "../../../components/tools/CSVContent"; +import SourceCard, { + SeeMoreBlock, +} from "@/components/chat_search/sources/SourceCard"; const TOOLS_WITH_CUSTOM_HANDLING = [ SEARCH_TOOL_NAME, @@ -155,6 +165,7 @@ function FileDisplay({ export const AIMessage = ({ regenerate, overriddenModel, + selectedMessageForDocDisplay, continueGenerating, shared, isActive, @@ -162,6 +173,7 @@ export const AIMessage = ({ alternativeAssistant, docs, messageId, + documentSelectionToggled, content, files, selectedDocuments, @@ -178,7 +190,11 @@ export const AIMessage = ({ currentPersona, otherMessagesCanSwitchTo, onMessageSelection, + setPresentingDocument, + index, }: { + index?: number; + selectedMessageForDocDisplay?: number | null; shared?: boolean; isActive?: boolean; continueGenerating?: () => void; @@ -191,6 +207,7 @@ export const AIMessage = ({ currentPersona: Persona; messageId: number | null; content: string | JSX.Element; + documentSelectionToggled?: boolean; files?: FileDescriptor[]; query?: string; citedDocuments?: [string, DanswerDocument][] | null; @@ -204,6 +221,7 @@ export const AIMessage = ({ retrievalDisabled?: boolean; overriddenModel?: string; regenerate?: (modelOverRide: LlmOverride) => Promise; + setPresentingDocument?: (document: DanswerDocument) => void; }) => { const toolCallGenerating = toolCall && !toolCall.tool_result; const processContent = (content: string | JSX.Element) => { @@ -287,18 +305,36 @@ export const AIMessage = ({ }); } + const paragraphCallback = useCallback( + (props: any) => {props.children}, + [] + ); + + const anchorCallback = useCallback( + (props: any) => ( + + {props.children} + + ), + [docs] + ); + const currentMessageInd = messageId ? otherMessagesCanSwitchTo?.indexOf(messageId) : undefined; + const uniqueSources: ValidSources[] = Array.from( new Set((docs || []).map((doc) => doc.source_type)) ).slice(0, 3); const markdownComponents = useMemo( () => ({ - a: MemoizedLink, - p: MemoizedParagraph, - code: ({ node, className, children, ...props }: any) => { + a: anchorCallback, + p: paragraphCallback, + code: ({ node, className, children }: any) => { const codeText = extractCodeText( node, finalContent as string, @@ -312,7 +348,7 @@ export const AIMessage = ({ ); }, }), - [finalContent] + [anchorCallback, paragraphCallback, finalContent] ); const renderedMarkdown = useMemo(() => { @@ -333,12 +369,11 @@ export const AIMessage = ({ onMessageSelection && otherMessagesCanSwitchTo && otherMessagesCanSwitchTo.length > 1; - return (
)} + {docs && docs.length > 0 && ( +
+
+
+ {!settings?.isMobile && + docs.length > 0 && + docs + .slice(0, 2) + .map((doc, ind) => ( + + ))} + +
+
+
+ )} + {content || files ? ( <> @@ -438,81 +505,6 @@ export const AIMessage = ({ ) : isComplete ? null : ( <> )} - {isComplete && docs && docs.length > 0 && ( -
-
-
- {!settings?.isMobile && - filteredDocs.length > 0 && - filteredDocs.slice(0, 2).map((doc, ind) => ( - - ))} -
{ - if (messageId) { - onMessageSelection?.(messageId); - } - toggleDocumentSelection?.(); - }} - key={-1} - className="cursor-pointer w-[200px] rounded-lg flex-none transition-all duration-500 hover:bg-background-125 bg-text-100 px-4 py-2 border-b" - > -
-

See context

-
- {uniqueSources.map((sourceType, ind) => { - return ( -
- -
- ); - })} -
-
-
- See more -
-
-
-
-
- )}
{handleFeedback && diff --git a/web/src/app/chat/message/SearchSummary.tsx b/web/src/app/chat/message/SearchSummary.tsx index f86212fd290..7349ec6ca35 100644 --- a/web/src/app/chat/message/SearchSummary.tsx +++ b/web/src/app/chat/message/SearchSummary.tsx @@ -41,6 +41,7 @@ export function ShowHideDocsButton({ } export function SearchSummary({ + index, query, hasDocs, finished, @@ -48,6 +49,7 @@ export function SearchSummary({ handleShowRetrieved, handleSearchQueryEdit, }: { + index: number; finished: boolean; query: string; hasDocs: boolean; @@ -98,7 +100,14 @@ export function SearchSummary({ !text-sm !line-clamp-1 !break-all px-0.5`} ref={searchingForRef} > - {finished ? "Searched" : "Searching"} for: {finalQuery} + {finished ? "Searched" : "Searching"} for:{" "} + + {index === 1 + ? finalQuery.length > 50 + ? `${finalQuery.slice(0, 50)}...` + : finalQuery + : finalQuery} +
); diff --git a/web/src/app/chat/message/SkippedSearch.tsx b/web/src/app/chat/message/SkippedSearch.tsx index 05dc8f2d8e4..27a50d4f6f3 100644 --- a/web/src/app/chat/message/SkippedSearch.tsx +++ b/web/src/app/chat/message/SkippedSearch.tsx @@ -1,26 +1,6 @@ import { EmphasizedClickable } from "@/components/BasicClickable"; import { FiBook } from "react-icons/fi"; -function ForceSearchButton({ - messageId, - handleShowRetrieved, -}: { - messageId: number | null; - isCurrentlyShowingRetrieved: boolean; - handleShowRetrieved: (messageId: number | null) => void; -}) { - return ( -
handleShowRetrieved(messageId)} - > - -
Force Search
-
-
- ); -} - export function SkippedSearch({ handleForceSearch, }: { diff --git a/web/src/app/chat/modal/FeedbackModal.tsx b/web/src/app/chat/modal/FeedbackModal.tsx index 39c3253b76a..e050dcc62af 100644 --- a/web/src/app/chat/modal/FeedbackModal.tsx +++ b/web/src/app/chat/modal/FeedbackModal.tsx @@ -5,15 +5,19 @@ import { FeedbackType } from "../types"; import { Modal } from "@/components/Modal"; import { FilledLikeIcon } from "@/components/icons/icons"; -const predefinedPositiveFeedbackOptions = - process.env.NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS?.split(",") || - []; -const predefinedNegativeFeedbackOptions = - process.env.NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS?.split(",") || [ - "Retrieved documents were not relevant", - "AI misread the documents", - "Cited source had incorrect information", - ]; +const predefinedPositiveFeedbackOptions = process.env + .NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS + ? process.env.NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS.split(",") + : []; + +const predefinedNegativeFeedbackOptions = process.env + .NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS + ? process.env.NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS.split(",") + : [ + "Retrieved documents were not relevant", + "AI misread the documents", + "Cited source had incorrect information", + ]; interface FeedbackModalProps { feedbackType: FeedbackType; @@ -49,7 +53,7 @@ export const FeedbackModal = ({ : predefinedNegativeFeedbackOptions; return ( - + <>

diff --git a/web/src/app/chat/modal/SetDefaultModelModal.tsx b/web/src/app/chat/modal/SetDefaultModelModal.tsx index 27696c46916..47cf55f12a7 100644 --- a/web/src/app/chat/modal/SetDefaultModelModal.tsx +++ b/web/src/app/chat/modal/SetDefaultModelModal.tsx @@ -1,4 +1,4 @@ -import { Dispatch, SetStateAction, useEffect, useRef } from "react"; +import { Dispatch, SetStateAction, useContext, useEffect, useRef } from "react"; import { Modal } from "@/components/Modal"; import Text from "@/components/ui/text"; import { getDisplayNameForModel, LlmOverride } from "@/lib/hooks"; @@ -9,6 +9,10 @@ import { setUserDefaultModel } from "@/lib/users/UserSettings"; import { useRouter } from "next/navigation"; import { PopupSpec } from "@/components/admin/connectors/Popup"; import { useUser } from "@/components/user/UserProvider"; +import { Separator } from "@/components/ui/separator"; +import { Switch } from "@/components/ui/switch"; +import { Label } from "@/components/admin/connectors/Field"; +import { SettingsContext } from "@/components/settings/SettingsProvider"; export function SetDefaultModelModal({ setPopup, @@ -23,7 +27,7 @@ export function SetDefaultModelModal({ onClose: () => void; defaultModel: string | null; }) { - const { refreshUser } = useUser(); + const { refreshUser, user, updateUserAutoScroll } = useUser(); const containerRef = useRef(null); const messageRef = useRef(null); @@ -31,6 +35,13 @@ export function SetDefaultModelModal({ const container = containerRef.current; const message = messageRef.current; + const handleEscape = (e: KeyboardEvent) => { + if (e.key === "Escape") { + onClose(); + } + }; + window.addEventListener("keydown", handleEscape); + if (container && message) { const checkScrollable = () => { if (container.scrollHeight > container.clientHeight) { @@ -41,9 +52,14 @@ export function SetDefaultModelModal({ }; checkScrollable(); window.addEventListener("resize", checkScrollable); - return () => window.removeEventListener("resize", checkScrollable); + return () => { + window.removeEventListener("resize", checkScrollable); + window.removeEventListener("keydown", handleEscape); + }; } - }, []); + + return () => window.removeEventListener("keydown", handleEscape); + }, [onClose]); const defaultModelDestructured = defaultModel ? destructureValue(defaultModel) @@ -121,16 +137,41 @@ export function SetDefaultModelModal({ const defaultProvider = llmProviders.find( (llmProvider) => llmProvider.is_default_provider ); + const settings = useContext(SettingsContext); + const autoScroll = settings?.enterpriseSettings?.auto_scroll; + + const checked = + user?.preferences?.auto_scroll === null + ? autoScroll + : user?.preferences?.auto_scroll; return ( <>

- Set Default Model + User settings

+
+
+ { + updateUserAutoScroll(checked); + }} + /> + +
+
+ + + +

+ Default model for assistants +

+ Choose a Large Language Model (LLM) to serve as the default for assistants that don't have a default model assigned. diff --git a/web/src/app/chat/modal/configuration/LlmTab.tsx b/web/src/app/chat/modal/configuration/LlmTab.tsx index 4e51a21933e..46db83e4e0f 100644 --- a/web/src/app/chat/modal/configuration/LlmTab.tsx +++ b/web/src/app/chat/modal/configuration/LlmTab.tsx @@ -35,25 +35,9 @@ export const LlmTab = forwardRef( checkPersonaRequiresImageGeneration(currentAssistant); const { llmProviders } = useChatContext(); - const { setLlmOverride, temperature, setTemperature } = llmOverrideManager; + const { setLlmOverride, temperature, updateTemperature } = + llmOverrideManager; const [isTemperatureExpanded, setIsTemperatureExpanded] = useState(false); - const [localTemperature, setLocalTemperature] = useState( - temperature || 0 - ); - const debouncedSetTemperature = useCallback( - (value: number) => { - const debouncedFunction = debounce((value: number) => { - setTemperature(value); - }, 300); - return debouncedFunction(value); - }, - [setTemperature] - ); - - const handleTemperatureChange = (value: number) => { - setLocalTemperature(value); - debouncedSetTemperature(value); - }; return (
@@ -108,26 +92,26 @@ export const LlmTab = forwardRef( - handleTemperatureChange(parseFloat(e.target.value)) + updateTemperature(parseFloat(e.target.value)) } className="w-full p-2 border border-border rounded-md" min="0" max="2" step="0.01" - value={localTemperature} + value={temperature || 0} />
- {localTemperature} + {temperature}
diff --git a/web/src/app/chat/page.tsx b/web/src/app/chat/page.tsx index c58345ff5c5..d2b377bb929 100644 --- a/web/src/app/chat/page.tsx +++ b/web/src/app/chat/page.tsx @@ -17,18 +17,17 @@ export default async function Page(props: { const requestCookies = await cookies(); const data = await fetchChatData(searchParams); - if ("redirect" in data) { redirect(data.redirect); } - + const config = await fetchEEASettings(); - + const { disclaimerTitle, disclaimerText } = config; - + const { user, chatSessions, @@ -41,7 +40,7 @@ export default async function Page(props: { openedFolders, defaultAssistantId, shouldShowWelcomeModal, - userInputPrompts, + ccPairs, } = data; return ( @@ -56,12 +55,14 @@ export default async function Page(props: { value={{ chatSessions, availableSources, + ccPairs, + documentSets, + tags, availableDocumentSets: documentSets, availableTags: tags, llmProviders, folders, openedFolders, - userInputPrompts, shouldShowWelcomeModal, defaultAssistantId, }} diff --git a/web/src/app/chat/sessionSidebar/HistorySidebar.tsx b/web/src/app/chat/sessionSidebar/HistorySidebar.tsx index 7342b9c2819..70cc870bbec 100644 --- a/web/src/app/chat/sessionSidebar/HistorySidebar.tsx +++ b/web/src/app/chat/sessionSidebar/HistorySidebar.tsx @@ -101,7 +101,7 @@ export const HistorySidebar = forwardRef( flex-col relative h-screen transition-transform - pt-2`} + `} > ( {page == "chat" && (
( Manage Assistants

- - -

- Manage Prompts -

-
)}
diff --git a/web/src/app/chat/shared/[chatId]/SharedChatDisplay.tsx b/web/src/app/chat/shared/[chatId]/SharedChatDisplay.tsx index e9f03f0934d..a19e8fffcf1 100644 --- a/web/src/app/chat/shared/[chatId]/SharedChatDisplay.tsx +++ b/web/src/app/chat/shared/[chatId]/SharedChatDisplay.tsx @@ -17,6 +17,8 @@ import { SettingsContext } from "@/components/settings/SettingsProvider"; import { DanswerInitializingLoader } from "@/components/DanswerInitializingLoader"; import { Persona } from "@/app/admin/assistants/interfaces"; import { Button } from "@/components/ui/button"; +import { DanswerDocument } from "@/lib/search/interfaces"; +import TextView from "@/components/chat_search/TextView"; function BackToDanswerButton() { const router = useRouter(); @@ -41,6 +43,9 @@ export function SharedChatDisplay({ persona: Persona; }) { const [isReady, setIsReady] = useState(false); + const [presentingDocument, setPresentingDocument] = + useState(null); + useEffect(() => { Prism.highlightAll(); setIsReady(true); @@ -63,61 +68,70 @@ export function SharedChatDisplay({ ); return ( -
-
-
-
-
-

- {chatSession.description || - `Chat ${chatSession.chat_session_id}`} -

-

- {humanReadableFormat(chatSession.time_created)} -

+ <> + {presentingDocument && ( + setPresentingDocument(null)} + /> + )} +
+
+
+
+
+

+ {chatSession.description || + `Chat ${chatSession.chat_session_id}`} +

+

+ {humanReadableFormat(chatSession.time_created)} +

- -
- {isReady ? ( -
- {messages.map((message) => { - if (message.type === "user") { - return ( - - ); - } else { - return ( - - ); - } - })} +
- ) : ( -
-
- + {isReady ? ( +
+ {messages.map((message) => { + if (message.type === "user") { + return ( + + ); + } else { + return ( + + ); + } + })}
-
- )} + ) : ( +
+
+ +
+
+ )} +
-
- -
+ +
+ ); } diff --git a/web/src/app/chat/shared_chat_search/Filters.tsx b/web/src/app/chat/shared_chat_search/Filters.tsx new file mode 100644 index 00000000000..b285423b448 --- /dev/null +++ b/web/src/app/chat/shared_chat_search/Filters.tsx @@ -0,0 +1,635 @@ +import React, { useState } from "react"; +import { DocumentSet, Tag, ValidSources } from "@/lib/types"; +import { SourceMetadata } from "@/lib/search/interfaces"; +import { + GearIcon, + InfoIcon, + MinusIcon, + PlusCircleIcon, + PlusIcon, + defaultTailwindCSS, +} from "@/components/icons/icons"; +import { HoverPopup } from "@/components/HoverPopup"; +import { + FiBook, + FiBookmark, + FiFilter, + FiMap, + FiTag, + FiX, +} from "react-icons/fi"; +import { DateRangeSelector } from "@/components/search/DateRangeSelector"; +import { DateRangePickerValue } from "@/app/ee/admin/performance/DateRangeSelector"; +import { listSourceMetadata } from "@/lib/sources"; +import { SourceIcon } from "@/components/SourceIcon"; +import { TagFilter } from "@/components/search/filtering/TagFilter"; +import { Calendar } from "@/components/ui/calendar"; +import { Popover, PopoverTrigger } from "@/components/ui/popover"; +import { PopoverContent } from "@radix-ui/react-popover"; +import { CalendarIcon } from "lucide-react"; +import { buildDateString, getTimeAgoString } from "@/lib/dateUtils"; +import { Separator } from "@/components/ui/separator"; +import { FilterDropdown } from "@/components/search/filtering/FilterDropdown"; + +const SectionTitle = ({ children }: { children: string }) => ( +
{children}
+); + +export interface SourceSelectorProps { + timeRange: DateRangePickerValue | null; + setTimeRange: React.Dispatch< + React.SetStateAction + >; + showDocSidebar?: boolean; + selectedSources: SourceMetadata[]; + setSelectedSources: React.Dispatch>; + selectedDocumentSets: string[]; + setSelectedDocumentSets: React.Dispatch>; + selectedTags: Tag[]; + setSelectedTags: React.Dispatch>; + availableDocumentSets: DocumentSet[]; + existingSources: ValidSources[]; + availableTags: Tag[]; + toggleFilters: () => void; + filtersUntoggled: boolean; + tagsOnLeft: boolean; +} + +export function SourceSelector({ + timeRange, + setTimeRange, + selectedSources, + setSelectedSources, + selectedDocumentSets, + setSelectedDocumentSets, + selectedTags, + setSelectedTags, + availableDocumentSets, + existingSources, + availableTags, + showDocSidebar, + toggleFilters, + filtersUntoggled, + tagsOnLeft, +}: SourceSelectorProps) { + const handleSelect = (source: SourceMetadata) => { + setSelectedSources((prev: SourceMetadata[]) => { + if ( + prev.map((source) => source.internalName).includes(source.internalName) + ) { + return prev.filter((s) => s.internalName !== source.internalName); + } else { + return [...prev, source]; + } + }); + }; + + const handleDocumentSetSelect = (documentSetName: string) => { + setSelectedDocumentSets((prev: string[]) => { + if (prev.includes(documentSetName)) { + return prev.filter((s) => s !== documentSetName); + } else { + return [...prev, documentSetName]; + } + }); + }; + + let allSourcesSelected = selectedSources.length > 0; + + const toggleAllSources = () => { + if (allSourcesSelected) { + setSelectedSources([]); + } else { + const allSources = listSourceMetadata().filter((source) => + existingSources.includes(source.internalName) + ); + setSelectedSources(allSources); + } + }; + + return ( +
+ + {!filtersUntoggled && ( + <> + + + +
+
+ Time Range + {true && ( + + )} +
+

+ {getTimeAgoString(timeRange?.from!) || "Select a time range"} +

+
+
+ + { + const initialDate = daterange?.from || new Date(); + const endDate = daterange?.to || new Date(); + setTimeRange({ + from: initialDate, + to: endDate, + selectValue: timeRange?.selectValue || "", + }); + }} + className="rounded-md " + /> + +
+ + {availableTags.length > 0 && ( + <> +
+ Tags +
+ + + )} + + {existingSources.length > 0 && ( +
+
+
+

Sources

+ +
+
+
+ {listSourceMetadata() + .filter((source) => + existingSources.includes(source.internalName) + ) + .map((source) => ( +
source.internalName) + .includes(source.internalName) + ? "bg-hover" + : "hover:bg-hover-light") + } + onClick={() => handleSelect(source)} + > + + + {source.displayName} + +
+ ))} +
+
+ )} + + {availableDocumentSets.length > 0 && ( + <> +
+ Knowledge Sets +
+
+ {availableDocumentSets.map((documentSet) => ( +
+
handleDocumentSetSelect(documentSet.name)} + > + + +
+ } + popupContent={ +
+
Description
+
+ {documentSet.description} +
+
+ } + classNameModifications="-ml-2" + /> + {documentSet.name} +
+
+ ))} +
+ + )} + + )} +
+ ); +} + +export function SelectedBubble({ + children, + onClick, +}: { + children: string | JSX.Element; + onClick: () => void; +}) { + return ( +
+ {children} + +
+ ); +} + +export function HorizontalFilters({ + timeRange, + setTimeRange, + selectedSources, + setSelectedSources, + selectedDocumentSets, + setSelectedDocumentSets, + availableDocumentSets, + existingSources, +}: SourceSelectorProps) { + const handleSourceSelect = (source: SourceMetadata) => { + setSelectedSources((prev: SourceMetadata[]) => { + const prevSourceNames = prev.map((source) => source.internalName); + if (prevSourceNames.includes(source.internalName)) { + return prev.filter((s) => s.internalName !== source.internalName); + } else { + return [...prev, source]; + } + }); + }; + + const handleDocumentSetSelect = (documentSetName: string) => { + setSelectedDocumentSets((prev: string[]) => { + if (prev.includes(documentSetName)) { + return prev.filter((s) => s !== documentSetName); + } else { + return [...prev, documentSetName]; + } + }); + }; + + const allSources = listSourceMetadata(); + const availableSources = allSources.filter((source) => + existingSources.includes(source.internalName) + ); + + return ( +
+
+
+ +
+ + { + return { + key: source.displayName, + display: ( + <> + + {source.displayName} + + ), + }; + })} + selected={selectedSources.map((source) => source.displayName)} + handleSelect={(option) => + handleSourceSelect( + allSources.find((source) => source.displayName === option.key)! + ) + } + icon={ +
+ +
+ } + defaultDisplay="All Sources" + /> + + { + return { + key: documentSet.name, + display: ( + <> +
+ +
+ {documentSet.name} + + ), + }; + })} + selected={selectedDocumentSets} + handleSelect={(option) => handleDocumentSetSelect(option.key)} + icon={ +
+ +
+ } + defaultDisplay="All Document Sets" + /> +
+ +
+
+ {timeRange && timeRange.selectValue && ( + setTimeRange(null)}> +
{timeRange.selectValue}
+
+ )} + {existingSources.length > 0 && + selectedSources.map((source) => ( + handleSourceSelect(source)} + > + <> + + {source.displayName} + + + ))} + {selectedDocumentSets.length > 0 && + selectedDocumentSets.map((documentSetName) => ( + handleDocumentSetSelect(documentSetName)} + > + <> +
+ +
+ {documentSetName} + +
+ ))} +
+
+
+ ); +} + +export function HorizontalSourceSelector({ + timeRange, + setTimeRange, + selectedSources, + setSelectedSources, + selectedDocumentSets, + setSelectedDocumentSets, + selectedTags, + setSelectedTags, + availableDocumentSets, + existingSources, + availableTags, +}: SourceSelectorProps) { + const handleSourceSelect = (source: SourceMetadata) => { + setSelectedSources((prev: SourceMetadata[]) => { + if (prev.map((s) => s.internalName).includes(source.internalName)) { + return prev.filter((s) => s.internalName !== source.internalName); + } else { + return [...prev, source]; + } + }); + }; + + const handleDocumentSetSelect = (documentSetName: string) => { + setSelectedDocumentSets((prev: string[]) => { + if (prev.includes(documentSetName)) { + return prev.filter((s) => s !== documentSetName); + } else { + return [...prev, documentSetName]; + } + }); + }; + + const handleTagSelect = (tag: Tag) => { + setSelectedTags((prev: Tag[]) => { + if ( + prev.some( + (t) => t.tag_key === tag.tag_key && t.tag_value === tag.tag_value + ) + ) { + return prev.filter( + (t) => !(t.tag_key === tag.tag_key && t.tag_value === tag.tag_value) + ); + } else { + return [...prev, tag]; + } + }); + }; + + const resetSources = () => { + setSelectedSources([]); + }; + const resetDocuments = () => { + setSelectedDocumentSets([]); + }; + + const resetTags = () => { + setSelectedTags([]); + }; + + return ( +
+ + +
+ + + {timeRange?.from ? getTimeAgoString(timeRange.from) : "Since"} +
+
+ + { + const initialDate = daterange?.from || new Date(); + const endDate = daterange?.to || new Date(); + setTimeRange({ + from: initialDate, + to: endDate, + selectValue: timeRange?.selectValue || "", + }); + }} + className="rounded-md" + /> + +
+ + {existingSources.length > 0 && ( + existingSources.includes(source.internalName)) + .map((source) => ({ + key: source.internalName, + display: ( + <> + + {source.displayName} + + ), + }))} + selected={selectedSources.map((source) => source.internalName)} + handleSelect={(option) => + handleSourceSelect( + listSourceMetadata().find((s) => s.internalName === option.key)! + ) + } + icon={} + defaultDisplay="Sources" + dropdownColor="bg-background-search-filter-dropdown" + width="w-fit ellipsis truncate" + resetValues={resetSources} + dropdownWidth="w-40" + optionClassName="truncate w-full break-all ellipsis" + /> + )} + + {availableDocumentSets.length > 0 && ( + ({ + key: documentSet.name, + display: <>{documentSet.name}, + }))} + selected={selectedDocumentSets} + handleSelect={(option) => handleDocumentSetSelect(option.key)} + icon={} + defaultDisplay="Sets" + resetValues={resetDocuments} + width="w-fit max-w-24 text-ellipsis truncate" + dropdownColor="bg-background-search-filter-dropdown" + dropdownWidth="max-w-36 w-fit" + optionClassName="truncate w-full break-all" + /> + )} + + {availableTags.length > 0 && ( + ({ + key: `${tag.tag_key}=${tag.tag_value}`, + display: ( + + {tag.tag_key} + = + {tag.tag_value} + + ), + }))} + selected={selectedTags.map( + (tag) => `${tag.tag_key}=${tag.tag_value}` + )} + handleSelect={(option) => { + const [tag_key, tag_value] = option.key.split("="); + const selectedTag = availableTags.find( + (tag) => tag.tag_key === tag_key && tag.tag_value === tag_value + ); + if (selectedTag) { + handleTagSelect(selectedTag); + } + }} + icon={} + defaultDisplay="Tags" + resetValues={resetTags} + dropdownColor="bg-background-search-filter-dropdown" + width="w-fit max-w-24 ellipsis truncate" + dropdownWidth="max-w-80 w-fit" + optionClassName="truncate w-full break-all ellipsis" + /> + )} +
+ ); +} diff --git a/web/src/app/chat/shared_chat_search/FixedLogo.tsx b/web/src/app/chat/shared_chat_search/FixedLogo.tsx index 27385c4f1c9..71947b2fd12 100644 --- a/web/src/app/chat/shared_chat_search/FixedLogo.tsx +++ b/web/src/app/chat/shared_chat_search/FixedLogo.tsx @@ -23,10 +23,8 @@ export default function FixedLogo({ return ( <>
@@ -48,7 +46,7 @@ export default function FixedLogo({
-
+
diff --git a/web/src/app/chat/shared_chat_search/FunctionalWrapper.tsx b/web/src/app/chat/shared_chat_search/FunctionalWrapper.tsx index e8c377dc57f..8a58c639136 100644 --- a/web/src/app/chat/shared_chat_search/FunctionalWrapper.tsx +++ b/web/src/app/chat/shared_chat_search/FunctionalWrapper.tsx @@ -1,90 +1,7 @@ "use client"; -import React, { ReactNode, useContext, useEffect, useState } from "react"; -import { usePathname, useRouter } from "next/navigation"; -import { ChatIcon, SearchIcon } from "@/components/icons/icons"; -import { SettingsContext } from "@/components/settings/SettingsProvider"; -import KeyboardSymbol from "@/lib/browserUtilities"; - -const ToggleSwitch = () => { - const commandSymbol = KeyboardSymbol(); - const pathname = usePathname(); - const router = useRouter(); - const settings = useContext(SettingsContext); - - const [activeTab, setActiveTab] = useState(() => { - return pathname == "/search" ? "search" : "chat"; - }); - - const [isInitialLoad, setIsInitialLoad] = useState(true); - - useEffect(() => { - const newTab = pathname === "/search" ? "search" : "chat"; - setActiveTab(newTab); - localStorage.setItem("activeTab", newTab); - setIsInitialLoad(false); - }, [pathname]); - - const handleTabChange = (tab: string) => { - setActiveTab(tab); - localStorage.setItem("activeTab", tab); - if (settings?.isMobile && window) { - window.location.href = tab; - } else { - router.push(tab === "search" ? "/search" : "/chat"); - } - }; - - return ( -
-
- - -
- ); -}; +import React, { ReactNode, useEffect, useState } from "react"; +import { useRouter } from "next/navigation"; export default function FunctionalWrapper({ initiallyToggled, @@ -128,12 +45,6 @@ export default function FunctionalWrapper({ window.removeEventListener("keydown", handleKeyDown); }; }, [router]); - const combinedSettings = useContext(SettingsContext); - const settings = combinedSettings?.settings; - const chatBannerPresent = - combinedSettings?.enterpriseSettings?.custom_header_content; - const twoLines = - combinedSettings?.enterpriseSettings?.two_lines_for_chat_header; const [toggledSidebar, setToggledSidebar] = useState(initiallyToggled); @@ -145,24 +56,7 @@ export default function FunctionalWrapper({ return ( <> - {(!settings || - (settings.search_page_enabled && settings.chat_page_enabled)) && ( -
-
-
- -
-
- )} - + {" "}
{content(toggledSidebar, toggle)}
diff --git a/web/src/app/chat/shared_chat_search/SearchFilters.tsx b/web/src/app/chat/shared_chat_search/SearchFilters.tsx new file mode 100644 index 00000000000..3f3e25c0e2b --- /dev/null +++ b/web/src/app/chat/shared_chat_search/SearchFilters.tsx @@ -0,0 +1,294 @@ +import { DocumentSet, Tag, ValidSources } from "@/lib/types"; +import { SourceMetadata } from "@/lib/search/interfaces"; +import { InfoIcon, defaultTailwindCSS } from "@/components/icons/icons"; +import { HoverPopup } from "@/components/HoverPopup"; +import { DateRangePickerValue } from "@/app/ee/admin/performance/DateRangeSelector"; +import { SourceIcon } from "@/components/SourceIcon"; +import { Checkbox } from "@/components/ui/checkbox"; +import { TagFilter } from "@/components/search/filtering/TagFilter"; +import { CardContent } from "@/components/ui/card"; +import { useEffect } from "react"; +import { useState } from "react"; +import { listSourceMetadata } from "@/lib/sources"; +import { Calendar } from "@/components/ui/calendar"; +import { getDateRangeString } from "@/lib/dateUtils"; +import { Button } from "@/components/ui/button"; +import { + Popover, + PopoverContent, + PopoverTrigger, +} from "@/components/ui/popover"; +import { ToolTipDetails } from "@/components/admin/connectors/Field"; + +import { + Tooltip, + TooltipContent, + TooltipTrigger, +} from "@/components/ui/tooltip"; +import { TooltipProvider } from "@radix-ui/react-tooltip"; + +const SectionTitle = ({ + children, + modal, +}: { + children: string; + modal?: boolean; +}) => ( +
+

{children}

+
+); + +export interface SourceSelectorProps { + timeRange: DateRangePickerValue | null; + setTimeRange: React.Dispatch< + React.SetStateAction + >; + showDocSidebar?: boolean; + selectedSources: SourceMetadata[]; + setSelectedSources: React.Dispatch>; + selectedDocumentSets: string[]; + setSelectedDocumentSets: React.Dispatch>; + selectedTags: Tag[]; + setSelectedTags: React.Dispatch>; + availableDocumentSets: DocumentSet[]; + existingSources: ValidSources[]; + availableTags: Tag[]; + filtersUntoggled: boolean; + modal?: boolean; + tagsOnLeft: boolean; +} + +export function SourceSelector({ + timeRange, + filtersUntoggled, + setTimeRange, + selectedSources, + setSelectedSources, + selectedDocumentSets, + setSelectedDocumentSets, + selectedTags, + setSelectedTags, + availableDocumentSets, + existingSources, + modal, + availableTags, +}: SourceSelectorProps) { + const handleSelect = (source: SourceMetadata) => { + setSelectedSources((prev: SourceMetadata[]) => { + if ( + prev.map((source) => source.internalName).includes(source.internalName) + ) { + return prev.filter((s) => s.internalName !== source.internalName); + } else { + return [...prev, source]; + } + }); + }; + + const handleDocumentSetSelect = (documentSetName: string) => { + setSelectedDocumentSets((prev: string[]) => { + if (prev.includes(documentSetName)) { + return prev.filter((s) => s !== documentSetName); + } else { + return [...prev, documentSetName]; + } + }); + }; + + let allSourcesSelected = selectedSources.length == existingSources.length; + + const toggleAllSources = () => { + if (allSourcesSelected) { + setSelectedSources([]); + } else { + const allSources = listSourceMetadata().filter((source) => + existingSources.includes(source.internalName) + ); + setSelectedSources(allSources); + } + }; + + const [isCalendarOpen, setIsCalendarOpen] = useState(false); + + useEffect(() => { + const handleClickOutside = (event: MouseEvent) => { + const calendar = document.querySelector(".rdp"); + if (calendar && !calendar.contains(event.target as Node)) { + setIsCalendarOpen(false); + } + }; + + document.addEventListener("mousedown", handleClickOutside); + return () => { + document.removeEventListener("mousedown", handleClickOutside); + }; + }, []); + + return ( +
+ {!filtersUntoggled && ( + +
+
+

Time Range

+ {timeRange && ( + + )} +
+ + + + + + { + const today = new Date(); + const initialDate = daterange?.from + ? new Date( + Math.min(daterange.from.getTime(), today.getTime()) + ) + : today; + const endDate = daterange?.to + ? new Date( + Math.min(daterange.to.getTime(), today.getTime()) + ) + : today; + setTimeRange({ + from: initialDate, + to: endDate, + selectValue: timeRange?.selectValue || "", + }); + }} + className="rounded-md" + /> + + +
+ + {availableTags.length > 0 && ( +
+ Tags + +
+ )} + + {existingSources.length > 0 && ( +
+ Sources + +
+ {existingSources.length > 1 && ( +
+ + + +
+ )} + {listSourceMetadata() + .filter((source) => + existingSources.includes(source.internalName) + ) + .map((source) => ( +
handleSelect(source)} + > + s.internalName) + .includes(source.internalName)} + /> + + {source.displayName} +
+ ))} +
+
+ )} + + {availableDocumentSets.length > 0 && ( +
+ Knowledge Sets +
+ {availableDocumentSets.map((documentSet) => ( +
handleDocumentSetSelect(documentSet.name)} + > + + + + + + + +
+
Description
+
+ {documentSet.description} +
+
+
+
+
+ {documentSet.name} +
+ ))} +
+
+ )} +
+ )} +
+ ); +} diff --git a/web/src/app/chat/types.ts b/web/src/app/chat/types.ts index abbe9b3b84a..c430a604c82 100644 --- a/web/src/app/chat/types.ts +++ b/web/src/app/chat/types.ts @@ -1,5 +1,10 @@ export type FeedbackType = "like" | "dislike"; -export type ChatState = "input" | "loading" | "streaming" | "toolBuilding"; +export type ChatState = + | "input" + | "loading" + | "streaming" + | "toolBuilding" + | "uploading"; export interface RegenerationState { regenerating: boolean; finalMessageIndex: number; diff --git a/web/src/app/connector/oauth/callback/[source]/route.tsx b/web/src/app/connector/oauth/callback/[source]/route.tsx new file mode 100644 index 00000000000..cfaea43d629 --- /dev/null +++ b/web/src/app/connector/oauth/callback/[source]/route.tsx @@ -0,0 +1,43 @@ +import { INTERNAL_URL } from "@/lib/constants"; +import { NextRequest, NextResponse } from "next/server"; + +// TODO: deprecate this and just go directly to the backend via /api/... +// For some reason Egnyte doesn't work when using /api, so leaving this as is for now +// If we do try and remove this, make sure we test the Egnyte connector oauth flow +export async function GET(request: NextRequest) { + try { + const backendUrl = new URL(INTERNAL_URL); + // Copy path and query parameters from incoming request + backendUrl.pathname = request.nextUrl.pathname; + backendUrl.search = request.nextUrl.search; + + const response = await fetch(backendUrl, { + method: "GET", + headers: request.headers, + body: request.body, + signal: request.signal, + // @ts-ignore + duplex: "half", + }); + + const responseData = await response.json(); + if (responseData.redirect_url) { + return NextResponse.redirect(responseData.redirect_url); + } + + return new NextResponse(JSON.stringify(responseData), { + status: response.status, + headers: response.headers, + }); + } catch (error: unknown) { + console.error("Proxy error:", error); + return NextResponse.json( + { + message: "Proxy error", + error: + error instanceof Error ? error.message : "An unknown error occurred", + }, + { status: 500 } + ); + } +} diff --git a/web/src/app/ee/admin/groups/[groupId]/GroupDisplay.tsx b/web/src/app/ee/admin/groups/[groupId]/GroupDisplay.tsx index 29b325a42e2..2200eab2f27 100644 --- a/web/src/app/ee/admin/groups/[groupId]/GroupDisplay.tsx +++ b/web/src/app/ee/admin/groups/[groupId]/GroupDisplay.tsx @@ -133,10 +133,7 @@ export const GroupDisplay = ({ const [addConnectorFormVisible, setAddConnectorFormVisible] = useState(false); const [addRateLimitFormVisible, setAddRateLimitFormVisible] = useState(false); - const { isLoadingUser, isAdmin } = useUser(); - if (isLoadingUser) { - return <>; - } + const { isAdmin } = useUser(); const handlePopup = (message: string, type: "success" | "error") => { setPopup({ message, type }); diff --git a/web/src/app/ee/admin/groups/page.tsx b/web/src/app/ee/admin/groups/page.tsx index 8db128ecc7f..1f3abda92a4 100644 --- a/web/src/app/ee/admin/groups/page.tsx +++ b/web/src/app/ee/admin/groups/page.tsx @@ -35,10 +35,7 @@ const Main = () => { error: usersError, } = useUsers(); - const { isLoadingUser, isAdmin } = useUser(); - if (isLoadingUser) { - return <>; - } + const { isAdmin } = useUser(); if (isLoading || isCCPairsLoading || userIsLoading) { return ; diff --git a/web/src/app/ee/admin/performance/lib.ts b/web/src/app/ee/admin/performance/lib.ts index 0837df1dea0..59042a38766 100644 --- a/web/src/app/ee/admin/performance/lib.ts +++ b/web/src/app/ee/admin/performance/lib.ts @@ -97,3 +97,69 @@ export function getDatesList(startDate: Date): string[] { return datesList; } + +export interface PersonaMessageAnalytics { + total_messages: number; + date: string; + persona_id: number; +} + +export interface PersonaSnapshot { + id: number; + name: string; + description: string; + is_visible: boolean; + is_public: boolean; +} + +export const usePersonaMessages = ( + personaId: number | undefined, + timeRange: DateRangePickerValue +) => { + const url = buildApiPath(`/api/analytics/admin/persona/messages`, { + persona_id: personaId?.toString(), + start: convertDateToStartOfDay(timeRange.from)?.toISOString(), + end: convertDateToEndOfDay(timeRange.to)?.toISOString(), + }); + + const { data, error, isLoading } = useSWR( + personaId !== undefined ? url : null, + errorHandlingFetcher + ); + + return { + data, + error, + isLoading, + refreshPersonaMessages: () => mutate(url), + }; +}; + +export interface PersonaUniqueUserAnalytics { + unique_users: number; + date: string; + persona_id: number; +} + +export const usePersonaUniqueUsers = ( + personaId: number | undefined, + timeRange: DateRangePickerValue +) => { + const url = buildApiPath(`/api/analytics/admin/persona/unique-users`, { + persona_id: personaId?.toString(), + start: convertDateToStartOfDay(timeRange.from)?.toISOString(), + end: convertDateToEndOfDay(timeRange.to)?.toISOString(), + }); + + const { data, error, isLoading } = useSWR( + personaId !== undefined ? url : null, + errorHandlingFetcher + ); + + return { + data, + error, + isLoading, + refreshPersonaUniqueUsers: () => mutate(url), + }; +}; diff --git a/web/src/app/ee/admin/performance/usage/PersonaMessagesChart.tsx b/web/src/app/ee/admin/performance/usage/PersonaMessagesChart.tsx new file mode 100644 index 00000000000..593ab6ba4de --- /dev/null +++ b/web/src/app/ee/admin/performance/usage/PersonaMessagesChart.tsx @@ -0,0 +1,231 @@ +import { ThreeDotsLoader } from "@/components/Loading"; +import { X, Search } from "lucide-react"; +import { + getDatesList, + usePersonaMessages, + usePersonaUniqueUsers, +} from "../lib"; +import { useAssistants } from "@/components/context/AssistantsContext"; +import { DateRangePickerValue } from "@/app/ee/admin/performance/DateRangeSelector"; +import Text from "@/components/ui/text"; +import Title from "@/components/ui/title"; +import CardSection from "@/components/admin/CardSection"; +import { AreaChartDisplay } from "@/components/ui/areaChart"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { useState, useMemo, useEffect } from "react"; + +export function PersonaMessagesChart({ + timeRange, +}: { + timeRange: DateRangePickerValue; +}) { + const [selectedPersonaId, setSelectedPersonaId] = useState< + number | undefined + >(undefined); + const [searchQuery, setSearchQuery] = useState(""); + const [highlightedIndex, setHighlightedIndex] = useState(-1); + const { allAssistants: personaList } = useAssistants(); + + const { + data: personaMessagesData, + isLoading: isPersonaMessagesLoading, + error: personaMessagesError, + } = usePersonaMessages(selectedPersonaId, timeRange); + + const { + data: personaUniqueUsersData, + isLoading: isPersonaUniqueUsersLoading, + error: personaUniqueUsersError, + } = usePersonaUniqueUsers(selectedPersonaId, timeRange); + + const isLoading = isPersonaMessagesLoading || isPersonaUniqueUsersLoading; + const hasError = personaMessagesError || personaUniqueUsersError; + + const filteredPersonaList = useMemo(() => { + if (!personaList) return []; + return personaList.filter((persona) => + persona.name.toLowerCase().includes(searchQuery.toLowerCase()) + ); + }, [personaList, searchQuery]); + + const handleKeyDown = (e: React.KeyboardEvent) => { + e.stopPropagation(); + + switch (e.key) { + case "ArrowDown": + e.preventDefault(); + setHighlightedIndex((prev) => + prev < filteredPersonaList.length - 1 ? prev + 1 : prev + ); + break; + case "ArrowUp": + e.preventDefault(); + setHighlightedIndex((prev) => (prev > 0 ? prev - 1 : prev)); + break; + case "Enter": + if ( + highlightedIndex >= 0 && + highlightedIndex < filteredPersonaList.length + ) { + setSelectedPersonaId(filteredPersonaList[highlightedIndex].id); + setSearchQuery(""); + setHighlightedIndex(-1); + } + break; + case "Escape": + setSearchQuery(""); + setHighlightedIndex(-1); + break; + } + }; + + // Reset highlight when search query changes + useEffect(() => { + setHighlightedIndex(-1); + }, [searchQuery]); + + const chartData = useMemo(() => { + if ( + !personaMessagesData?.length || + !personaUniqueUsersData?.length || + selectedPersonaId === undefined + ) { + return null; + } + + const initialDate = + timeRange.from || + new Date( + Math.min( + ...personaMessagesData.map((entry) => new Date(entry.date).getTime()) + ) + ); + const dateRange = getDatesList(initialDate); + + // Create maps for messages and unique users data + const messagesMap = new Map( + personaMessagesData.map((entry) => [entry.date, entry]) + ); + const uniqueUsersMap = new Map( + personaUniqueUsersData.map((entry) => [entry.date, entry]) + ); + + return dateRange.map((dateStr) => { + const messageData = messagesMap.get(dateStr); + const uniqueUserData = uniqueUsersMap.get(dateStr); + return { + Day: dateStr, + Messages: messageData?.total_messages || 0, + "Unique Users": uniqueUserData?.unique_users || 0, + }; + }); + }, [ + personaMessagesData, + personaUniqueUsersData, + timeRange.from, + selectedPersonaId, + ]); + + let content; + if (isLoading) { + content = ( +
+ +
+ ); + } else if (!personaList || hasError) { + content = ( +
+

Failed to fetch data...

+
+ ); + } else if (selectedPersonaId === undefined) { + content = ( +
+

Select a persona to view analytics

+
+ ); + } else if (!personaMessagesData?.length) { + content = ( +
+

+ No data found for selected persona in the selected time range +

+
+ ); + } else if (chartData) { + content = ( + + ); + } + + const selectedPersona = personaList?.find((p) => p.id === selectedPersonaId); + + return ( + + Persona Analytics +
+ Messages and unique users per day for selected persona +
+ setSearchQuery(e.target.value)} + onClick={(e) => e.stopPropagation()} + onMouseDown={(e) => e.stopPropagation()} + onKeyDown={handleKeyDown} + /> + {searchQuery && ( + { + setSearchQuery(""); + setHighlightedIndex(-1); + }} + /> + )} +
+ {filteredPersonaList.map((persona, index) => ( + setHighlightedIndex(index)} + > + {persona.name} + + ))} + + +
+
+ {content} + + ); +} diff --git a/web/src/app/ee/admin/performance/usage/QueryPerformanceChart.tsx b/web/src/app/ee/admin/performance/usage/QueryPerformanceChart.tsx index ffbf4e8c93e..f9ed3f7986d 100644 --- a/web/src/app/ee/admin/performance/usage/QueryPerformanceChart.tsx +++ b/web/src/app/ee/admin/performance/usage/QueryPerformanceChart.tsx @@ -62,6 +62,7 @@ export function QueryPerformanceChart({ chart = ( { const queryAnalyticsForDate = dateToQueryAnalytics.get(dateStr); const userAnalyticsForDate = dateToUserAnalytics.get(dateStr); diff --git a/web/src/app/ee/admin/performance/usage/page.tsx b/web/src/app/ee/admin/performance/usage/page.tsx index e1fffc323a2..967f16a377e 100644 --- a/web/src/app/ee/admin/performance/usage/page.tsx +++ b/web/src/app/ee/admin/performance/usage/page.tsx @@ -4,6 +4,7 @@ import { DateRangeSelector } from "../DateRangeSelector"; import { DanswerBotChart } from "./DanswerBotChart"; import { FeedbackChart } from "./FeedbackChart"; import { QueryPerformanceChart } from "./QueryPerformanceChart"; +import { PersonaMessagesChart } from "./PersonaMessagesChart"; import { useTimeRange } from "../lib"; import { AdminPageTitle } from "@/components/admin/Title"; import { FiActivity } from "react-icons/fi"; @@ -26,6 +27,7 @@ export default function AnalyticsPage() { + diff --git a/web/src/app/ee/admin/whitelabeling/WhitelabelingForm.tsx b/web/src/app/ee/admin/whitelabeling/WhitelabelingForm.tsx index 475c689441a..cd977d44c1c 100644 --- a/web/src/app/ee/admin/whitelabeling/WhitelabelingForm.tsx +++ b/web/src/app/ee/admin/whitelabeling/WhitelabelingForm.tsx @@ -55,6 +55,7 @@ export function WhitelabelingForm() {
; -}) { - const searchParams = await props.searchParams; - noStore(); - - const data = await fetchChatData(searchParams); - - if ("redirect" in data) { - redirect(data.redirect); - } - - const { chatSessions, folders, openedFolders, toggleSidebar } = data; - - return ( - - ); -} diff --git a/web/src/app/search/WrappedSearch.tsx b/web/src/app/search/WrappedSearch.tsx deleted file mode 100644 index 91dad5d3866..00000000000 --- a/web/src/app/search/WrappedSearch.tsx +++ /dev/null @@ -1,24 +0,0 @@ -"use client"; -import { SearchSection } from "@/components/search/SearchSection"; -import FunctionalWrapper from "../chat/shared_chat_search/FunctionalWrapper"; - -export default function WrappedSearch({ - searchTypeDefault, - initiallyToggled, -}: { - searchTypeDefault: string; - initiallyToggled: boolean; -}) { - return ( - ( - - )} - /> - ); -} diff --git a/web/src/app/search/page.tsx b/web/src/app/search/page.tsx deleted file mode 100644 index 008776b2f01..00000000000 --- a/web/src/app/search/page.tsx +++ /dev/null @@ -1,225 +0,0 @@ -import { - AuthTypeMetadata, - getAuthTypeMetadataSS, - getCurrentUserSS, -} from "@/lib/userSS"; -import { redirect } from "next/navigation"; -import { HealthCheckBanner } from "@/components/health/healthcheck"; -import { fetchSS } from "@/lib/utilsSS"; -import { CCPairBasicInfo, DocumentSet, Tag, User } from "@/lib/types"; -import { cookies } from "next/headers"; -import { SearchType } from "@/lib/search/interfaces"; -import { Persona } from "../admin/assistants/interfaces"; -import { unstable_noStore as noStore } from "next/cache"; -import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh"; -import { personaComparator } from "../admin/assistants/lib"; -import { FullEmbeddingModelResponse } from "@/components/embedding/interfaces"; -import { UserDisclaimerModal } from "@/components/search/UserDisclaimerModal"; -import { fetchEEASettings } from "@/lib/eea/fetchEEASettings"; - -import { ChatPopup } from "../chat/ChatPopup"; -import { - FetchAssistantsResponse, - fetchAssistantsSS, -} from "@/lib/assistants/fetchAssistantsSS"; -import { ChatSession } from "../chat/interfaces"; -import { SIDEBAR_TOGGLED_COOKIE_NAME } from "@/components/resizable/constants"; -import { - AGENTIC_SEARCH_TYPE_COOKIE_NAME, - NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN, - DISABLE_LLM_DOC_RELEVANCE, -} from "@/lib/constants"; -import WrappedSearch from "./WrappedSearch"; -import { SearchProvider } from "@/components/context/SearchContext"; -import { fetchLLMProvidersSS } from "@/lib/llm/fetchLLMs"; -import { LLMProviderDescriptor } from "../admin/configuration/llm/interfaces"; -import { headers } from "next/headers"; -import { - hasCompletedWelcomeFlowSS, - WelcomeModal, -} from "@/components/initialSetup/welcome/WelcomeModalWrapper"; - -export default async function Home(props: { - searchParams: Promise<{ [key: string]: string | string[] | undefined }>; -}) { - const searchParams = await props.searchParams; - // Disable caching so we always get the up to date connector / document set / persona info - // importantly, this prevents users from adding a connector, going back to the main page, - // and then getting hit with a "No Connectors" popup - noStore(); - const requestCookies = await cookies(); - const tasks = [ - getAuthTypeMetadataSS(), - getCurrentUserSS(), - fetchSS("/manage/connector-status"), - fetchSS("/manage/document-set"), - fetchAssistantsSS(), - fetchSS("/query/valid-tags"), - fetchSS("/query/user-searches"), - fetchLLMProvidersSS(), - ]; - - // catch cases where the backend is completely unreachable here - // without try / catch, will just raise an exception and the page - // will not render - let results: ( - | User - | Response - | AuthTypeMetadata - | FullEmbeddingModelResponse - | FetchAssistantsResponse - | LLMProviderDescriptor[] - | null - )[] = [null, null, null, null, null, null, null, null]; - try { - results = await Promise.all(tasks); - } catch (e) { - console.log(`Some fetch failed for the main search page - ${e}`); - } - const authTypeMetadata = results[0] as AuthTypeMetadata | null; - const user = results[1] as User | null; - const ccPairsResponse = results[2] as Response | null; - const documentSetsResponse = results[3] as Response | null; - const [initialAssistantsList, assistantsFetchError] = - results[4] as FetchAssistantsResponse; - const tagsResponse = results[5] as Response | null; - const queryResponse = results[6] as Response | null; - const llmProviders = (results[7] || []) as LLMProviderDescriptor[]; - - const config = await fetchEEASettings(); - - const { - disclaimerTitle, - disclaimerText - } = config; - - const authDisabled = authTypeMetadata?.authType === "disabled"; - - if (!authDisabled && !user) { - const headersList = await headers(); - const fullUrl = headersList.get("x-url") || "/search"; - const searchParamsString = new URLSearchParams( - searchParams as unknown as Record - ).toString(); - const redirectUrl = searchParamsString - ? `${fullUrl}?${searchParamsString}` - : fullUrl; - return redirect(`/auth/login?next=${encodeURIComponent(redirectUrl)}`); - } - - if (user && !user.is_verified && authTypeMetadata?.requiresVerification) { - return redirect("/auth/waiting-on-verification"); - } - - let ccPairs: CCPairBasicInfo[] = []; - if (ccPairsResponse?.ok) { - ccPairs = await ccPairsResponse.json(); - } else { - console.log(`Failed to fetch connectors - ${ccPairsResponse?.status}`); - } - - let documentSets: DocumentSet[] = []; - if (documentSetsResponse?.ok) { - documentSets = await documentSetsResponse.json(); - } else { - console.log( - `Failed to fetch document sets - ${documentSetsResponse?.status}` - ); - } - - let querySessions: ChatSession[] = []; - if (queryResponse?.ok) { - querySessions = (await queryResponse.json()).sessions; - } else { - console.log(`Failed to fetch chat sessions - ${queryResponse?.text()}`); - } - - let assistants: Persona[] = initialAssistantsList; - if (assistantsFetchError) { - console.log(`Failed to fetch assistants - ${assistantsFetchError}`); - } else { - // remove those marked as hidden by an admin - assistants = assistants.filter((assistant) => assistant.is_visible); - // hide personas with no retrieval - assistants = assistants.filter((assistant) => assistant.num_chunks !== 0); - // sort them in priority order - assistants.sort(personaComparator); - } - - let tags: Tag[] = []; - if (tagsResponse?.ok) { - tags = (await tagsResponse.json()).tags; - } else { - console.log(`Failed to fetch tags - ${tagsResponse?.status}`); - } - - // needs to be done in a non-client side component due to nextjs - const storedSearchType = requestCookies.get("searchType")?.value as - | string - | undefined; - const searchTypeDefault: SearchType = - storedSearchType !== undefined && - SearchType.hasOwnProperty(storedSearchType) - ? (storedSearchType as SearchType) - : SearchType.SEMANTIC; // default to semantic - - const hasAnyConnectors = ccPairs.length > 0; - - const shouldShowWelcomeModal = - !llmProviders.length && - !hasCompletedWelcomeFlowSS(requestCookies) && - !hasAnyConnectors && - (!user || user.role === "admin"); - - const shouldDisplayNoSourcesModal = - (!user || user.role === "admin") && - ccPairs.length === 0 && - !shouldShowWelcomeModal; - - const sidebarToggled = requestCookies.get(SIDEBAR_TOGGLED_COOKIE_NAME); - const agenticSearchToggle = requestCookies.get( - AGENTIC_SEARCH_TYPE_COOKIE_NAME - ); - - const toggleSidebar = sidebarToggled - ? sidebarToggled.value.toLocaleLowerCase() == "true" || false - : NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN; - - const agenticSearchEnabled = agenticSearchToggle - ? agenticSearchToggle.value.toLocaleLowerCase() == "true" || false - : false; - - return ( - <> - - - {shouldShowWelcomeModal && ( - - )} - {/* ChatPopup is a custom popup that displays a admin-specified message on initial user visit. - Only used in the EE version of the app. */} - - - - - - - - ); -} diff --git a/web/src/components/Dropdown.tsx b/web/src/components/Dropdown.tsx index 109b0196f0f..affdee4799b 100644 --- a/web/src/components/Dropdown.tsx +++ b/web/src/components/Dropdown.tsx @@ -10,6 +10,8 @@ import { import { ChevronDownIcon } from "./icons/icons"; import { FiCheck, FiChevronDown } from "react-icons/fi"; import { Popover } from "./popover/Popover"; +import { createPortal } from "react-dom"; +import { useDropdownPosition } from "@/lib/dropdown"; export interface Option { name: string; @@ -60,6 +62,7 @@ export function SearchMultiSelectDropdown({ const [isOpen, setIsOpen] = useState(false); const [searchTerm, setSearchTerm] = useState(""); const dropdownRef = useRef(null); + const dropdownMenuRef = useRef(null); const handleSelect = (option: StringOrNumberOption) => { onSelect(option); @@ -75,7 +78,9 @@ export function SearchMultiSelectDropdown({ const handleClickOutside = (event: MouseEvent) => { if ( dropdownRef.current && - !dropdownRef.current.contains(event.target as Node) + !dropdownRef.current.contains(event.target as Node) && + dropdownMenuRef.current && + !dropdownMenuRef.current.contains(event.target as Node) ) { setIsOpen(false); } @@ -87,105 +92,103 @@ export function SearchMultiSelectDropdown({ }; }, []); + useDropdownPosition({ isOpen, dropdownRef, dropdownMenuRef }); + return ( -
+
) => { - if (!searchTerm) { + setSearchTerm(e.target.value); + if (e.target.value) { setIsOpen(true); - } - if (!e.target.value) { + } else { setIsOpen(false); } - setSearchTerm(e.target.value); }} onFocus={() => setIsOpen(true)} className={`inline-flex - justify-between - w-full - px-4 - py-2 - text-sm - bg-background - border - border-border - rounded-md - shadow-sm - `} - onClick={(e) => e.stopPropagation()} + justify-between + w-full + px-4 + py-2 + text-sm + bg-background + border + border-border + rounded-md + shadow-sm + `} />
- {isOpen && ( -
+ {isOpen && + createPortal(
- {filteredOptions.length ? ( - filteredOptions.map((option, index) => - itemComponent ? ( -
{ - setIsOpen(false); - handleSelect(option); - }} - > - {itemComponent({ option })} -
- ) : ( - +
+ {filteredOptions.length ? ( + filteredOptions.map((option, index) => + itemComponent ? ( +
{ + handleSelect(option); + }} + > + {itemComponent({ option })} +
+ ) : ( + + ) ) - ) - ) : ( - - )} -
-
- )} + ) : ( + + )} +
+
, + document.body + )}
); } diff --git a/web/src/components/InternetSearchIcon.tsx b/web/src/components/InternetSearchIcon.tsx deleted file mode 100644 index e21218da9c5..00000000000 --- a/web/src/components/InternetSearchIcon.tsx +++ /dev/null @@ -1,9 +0,0 @@ -export function InternetSearchIcon({ url }: { url: string }) { - return ( - favicon - ); -} diff --git a/web/src/components/IsPublicGroupSelector.tsx b/web/src/components/IsPublicGroupSelector.tsx index 6cb953f5bdf..4db2c36860c 100644 --- a/web/src/components/IsPublicGroupSelector.tsx +++ b/web/src/components/IsPublicGroupSelector.tsx @@ -30,7 +30,7 @@ export const IsPublicGroupSelector = ({ enforceGroupSelection?: boolean; }) => { const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups(); - const { isAdmin, user, isLoadingUser, isCurator } = useUser(); + const { isAdmin, user, isCurator } = useUser(); const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled(); const [shouldHideContent, setShouldHideContent] = useState(false); @@ -52,7 +52,7 @@ export const IsPublicGroupSelector = ({ } }, [user, userGroups, isPaidEnterpriseFeaturesEnabled]); - if (isLoadingUser || userGroupsIsLoading) { + if (userGroupsIsLoading) { return
Loading...
; } if (!isPaidEnterpriseFeaturesEnabled) { diff --git a/web/src/components/MetadataBadge.tsx b/web/src/components/MetadataBadge.tsx index cfd94d0a879..f06429a92b1 100644 --- a/web/src/components/MetadataBadge.tsx +++ b/web/src/components/MetadataBadge.tsx @@ -1,9 +1,11 @@ export function MetadataBadge({ icon, value, + flexNone, }: { icon?: React.FC<{ size?: number; className?: string }>; value: string | JSX.Element; + flexNone?: boolean; }) { return (
- {icon && icon({ size: 12, className: "mr-0.5 my-auto" })} + {icon && + icon({ + size: 12, + className: flexNone ? "flex-none" : "mr-0.5 my-auto", + })}
{value}
); diff --git a/web/src/components/Modal.tsx b/web/src/components/Modal.tsx index 4582ed8a558..85871baa6c2 100644 --- a/web/src/components/Modal.tsx +++ b/web/src/components/Modal.tsx @@ -1,11 +1,11 @@ "use client"; import { Separator } from "@/components/ui/separator"; -import { FiX } from "react-icons/fi"; import { IconProps, XIcon } from "./icons/icons"; import { useRef } from "react"; import { isEventWithinRef } from "@/lib/contains"; import ReactDOM from "react-dom"; import { useEffect, useState } from "react"; +import { cn } from "@/lib/utils"; interface ModalProps { icon?: ({ size, className }: IconProps) => JSX.Element; @@ -18,6 +18,8 @@ interface ModalProps { hideDividerForTitle?: boolean; hideCloseButton?: boolean; noPadding?: boolean; + height?: string; + noScroll?: boolean; } export function Modal({ @@ -28,9 +30,11 @@ export function Modal({ width, titleSize, hideDividerForTitle, + height, noPadding, icon, hideCloseButton, + noScroll, }: ModalProps) { const modalRef = useRef(null); const [isMounted, setIsMounted] = useState(false); @@ -56,8 +60,10 @@ export function Modal({ const modalContent = (
{onOutsideClick && !hideCloseButton && (
@@ -83,8 +99,7 @@ export function Modal({
)} - -
+
{title && ( <>
@@ -100,7 +115,14 @@ export function Modal({ {!hideDividerForTitle && } )} -
{children}
+
+ {children} +
diff --git a/web/src/components/SearchResultIcon.tsx b/web/src/components/SearchResultIcon.tsx new file mode 100644 index 00000000000..30571390be5 --- /dev/null +++ b/web/src/components/SearchResultIcon.tsx @@ -0,0 +1,66 @@ +import { useState, useEffect } from "react"; +import faviconFetch from "favicon-fetch"; +import { SourceIcon } from "./SourceIcon"; +import { ValidSources } from "@/lib/types"; + +const CACHE_DURATION = 24 * 60 * 60 * 1000; + +export async function getFaviconUrl(url: string): Promise { + const getCachedFavicon = () => { + const cachedData = localStorage.getItem(`favicon_${url}`); + if (cachedData) { + const { favicon, timestamp } = JSON.parse(cachedData); + if (Date.now() - timestamp < CACHE_DURATION) { + return favicon; + } + } + return null; + }; + + const cachedFavicon = getCachedFavicon(); + if (cachedFavicon) { + return cachedFavicon; + } + + const newFaviconUrl = await faviconFetch({ uri: url }); + if (newFaviconUrl) { + localStorage.setItem( + `favicon_${url}`, + JSON.stringify({ favicon: newFaviconUrl, timestamp: Date.now() }) + ); + return newFaviconUrl; + } + + return null; +} + +export function SearchResultIcon({ url }: { url: string }) { + const [faviconUrl, setFaviconUrl] = useState(null); + + useEffect(() => { + getFaviconUrl(url).then((favicon) => { + if (favicon) { + setFaviconUrl(favicon); + } + }); + }, [url]); + + if (!faviconUrl) { + return ; + } + + return ( +
+ favicon { + e.currentTarget.onerror = null; + }} + /> +
+ ); +} diff --git a/web/src/components/TemporaryLoadingModal.tsx b/web/src/components/TemporaryLoadingModal.tsx new file mode 100644 index 00000000000..6f45aac4691 --- /dev/null +++ b/web/src/components/TemporaryLoadingModal.tsx @@ -0,0 +1,14 @@ +export default function TemporaryLoadingModal({ + content, +}: { + content: string; +}) { + return ( +
+
+
+

{content}

+
+
+ ); +} diff --git a/web/src/components/UserDropdown.tsx b/web/src/components/UserDropdown.tsx index c43938fff02..fe525c20082 100644 --- a/web/src/components/UserDropdown.tsx +++ b/web/src/components/UserDropdown.tsx @@ -10,7 +10,7 @@ import { Popover } from "./popover/Popover"; import { LOGOUT_DISABLED } from "@/lib/constants"; import { SettingsContext } from "./settings/SettingsProvider"; import { FileIcon } from "./icons/icons"; -import { BellIcon, LightSettingsIcon } from "./icons/icons"; +import { BellIcon, LightSettingsIcon, UserIcon } from "./icons/icons"; import { pageType } from "@/app/chat/sessionSidebar/types"; import { NavigationItem, Notification } from "@/app/admin/settings/interfaces"; import DynamicFaIcon, { preloadIcons } from "./icons/DynamicFaIcon"; @@ -57,8 +57,13 @@ const DropdownOption: React.FC = ({ } }; -export function UserDropdown({ page }: { page?: pageType }) { - +export function UserDropdown({ + page, + toggleUserSettings, +}: { + page?: pageType; + toggleUserSettings?: () => void; +}) { const { user, isCurator } = useUser(); const [userInfoVisible, setUserInfoVisible] = useState(false); const userInfoRef = useRef(null); @@ -242,6 +247,13 @@ export function UserDropdown({ page }: { page?: pageType }) { ) )} + {toggleUserSettings && ( + } + label="User Settings" + /> + )} { setUserInfoVisible(true); diff --git a/web/src/components/WebResultIcon.tsx b/web/src/components/WebResultIcon.tsx new file mode 100644 index 00000000000..09475f2a22d --- /dev/null +++ b/web/src/components/WebResultIcon.tsx @@ -0,0 +1,17 @@ +import { ValidSources } from "@/lib/types"; +import { SourceIcon } from "./SourceIcon"; + +export function WebResultIcon({ url }: { url: string }) { + const hostname = new URL(url).hostname; + return hostname == "https://docs.danswer.dev" ? ( + favicon + ) : ( + + ); +} diff --git a/web/src/components/admin/ClientLayout.tsx b/web/src/components/admin/ClientLayout.tsx index b4f6638d990..f4590f7321e 100644 --- a/web/src/components/admin/ClientLayout.tsx +++ b/web/src/components/admin/ClientLayout.tsx @@ -58,7 +58,7 @@ export function ClientLayout({ return (
-
+
- -
Prompt Library
-
- ), - link: "/admin/prompt-library", - }, ] : []), ...(enableEnterprise @@ -449,7 +437,7 @@ export function ClientLayout({ />
-
+
diff --git a/web/src/components/admin/connectors/AccessTypeForm.tsx b/web/src/components/admin/connectors/AccessTypeForm.tsx index 5950cb6fdc0..9868d9bbdbb 100644 --- a/web/src/components/admin/connectors/AccessTypeForm.tsx +++ b/web/src/components/admin/connectors/AccessTypeForm.tsx @@ -1,7 +1,7 @@ import { DefaultDropdown } from "@/components/Dropdown"; import { AccessType, - ValidAutoSyncSources, + ValidAutoSyncSource, ConfigurableSources, validAutoSyncSources, } from "@/lib/types"; @@ -13,8 +13,8 @@ import { useEffect } from "react"; function isValidAutoSyncSource( value: ConfigurableSources -): value is ValidAutoSyncSources { - return validAutoSyncSources.includes(value as ValidAutoSyncSources); +): value is ValidAutoSyncSource { + return validAutoSyncSources.includes(value as ValidAutoSyncSource); } export function AccessTypeForm({ @@ -27,7 +27,7 @@ export function AccessTypeForm({ const isPaidEnterpriseEnabled = usePaidEnterpriseFeaturesEnabled(); const isAutoSyncSupported = isValidAutoSyncSource(connector); - const { isLoadingUser, isAdmin } = useUser(); + const { isAdmin } = useUser(); useEffect(() => { if (!isPaidEnterpriseEnabled) { @@ -49,7 +49,7 @@ export function AccessTypeForm({ name: "Private", value: "private", description: - "Only users who have expliticly been given access to this connector (through the User Groups page) can access the documents pulled in by this connector", + "Only users who have explicitly been given access to this connector (through the User Groups page) can access the documents pulled in by this connector", }, ]; @@ -92,9 +92,7 @@ export function AccessTypeForm({ /> {access_type.value === "sync" && isAutoSyncSupported && ( - + )} )} diff --git a/web/src/components/admin/connectors/AccessTypeGroupSelector.tsx b/web/src/components/admin/connectors/AccessTypeGroupSelector.tsx index 7fbc0b3de1a..3bd16a80432 100644 --- a/web/src/components/admin/connectors/AccessTypeGroupSelector.tsx +++ b/web/src/components/admin/connectors/AccessTypeGroupSelector.tsx @@ -8,7 +8,7 @@ import { UserGroup, UserRole } from "@/lib/types"; import { useUserGroups } from "@/lib/hooks"; import { AccessType, - ValidAutoSyncSources, + ValidAutoSyncSource, ConfigurableSources, validAutoSyncSources, } from "@/lib/types"; @@ -16,8 +16,8 @@ import { useUser } from "@/components/user/UserProvider"; function isValidAutoSyncSource( value: ConfigurableSources -): value is ValidAutoSyncSources { - return validAutoSyncSources.includes(value as ValidAutoSyncSources); +): value is ValidAutoSyncSource { + return validAutoSyncSources.includes(value as ValidAutoSyncSource); } // This should be included for all forms that require groups / public access @@ -34,7 +34,7 @@ export function AccessTypeGroupSelector({ connector: ConfigurableSources; }) { const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups(); - const { isAdmin, user, isLoadingUser, isCurator } = useUser(); + const { isAdmin, user, isCurator } = useUser(); const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled(); const [shouldHideContent, setShouldHideContent] = useState(false); const isAutoSyncSupported = isValidAutoSyncSource(connector); @@ -77,7 +77,7 @@ export function AccessTypeGroupSelector({ isPaidEnterpriseFeaturesEnabled, ]); - if (isLoadingUser || userGroupsIsLoading) { + if (userGroupsIsLoading) { return
Loading...
; } if (!isPaidEnterpriseFeaturesEnabled) { diff --git a/web/src/components/admin/connectors/AdminSidebar.tsx b/web/src/components/admin/connectors/AdminSidebar.tsx index 9785b1d8666..ec215a4147f 100644 --- a/web/src/components/admin/connectors/AdminSidebar.tsx +++ b/web/src/components/admin/connectors/AdminSidebar.tsx @@ -6,7 +6,6 @@ import { Logo } from "@/components/EEA_Logo"; import { NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED } from "@/lib/constants"; import { HeaderTitle } from "@/components/header/HeaderTitle"; import { SettingsContext } from "@/components/settings/SettingsProvider"; -import { BackIcon } from "@/components/icons/icons"; import { WarningCircle, WarningDiamond } from "@phosphor-icons/react"; import { Tooltip, @@ -14,6 +13,7 @@ import { TooltipProvider, TooltipTrigger, } from "@/components/ui/tooltip"; +import { CgArrowsExpandUpLeft } from "react-icons/cg"; interface Item { name: string | JSX.Element; @@ -38,47 +38,33 @@ export function AdminSidebar({ collections }: { collections: Collection[] }) { return (