From 415b07afc819aa1e0b277e70c41fa595ac00525a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Monnom?= Date: Tue, 26 Nov 2024 17:28:18 +0100 Subject: [PATCH] deepgram: add base_url parameter (#1137) Co-authored-by: Michael Louis Co-authored-by: milo157 <43028253+milo157@users.noreply.github.com> --- .changeset/olive-kangaroos-yawn.md | 5 +++++ .../livekit/plugins/deepgram/stt.py | 22 ++++++++++++++----- 2 files changed, 22 insertions(+), 5 deletions(-) create mode 100644 .changeset/olive-kangaroos-yawn.md diff --git a/.changeset/olive-kangaroos-yawn.md b/.changeset/olive-kangaroos-yawn.md new file mode 100644 index 000000000..553940297 --- /dev/null +++ b/.changeset/olive-kangaroos-yawn.md @@ -0,0 +1,5 @@ +--- +"livekit-plugins-deepgram": patch +--- + +Added support for custom deepgram base url diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py index 4c4d46cc5..2295f1e74 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py @@ -42,7 +42,6 @@ from .models import DeepgramLanguages, DeepgramModels BASE_URL = "https://api.deepgram.com/v1/listen" -BASE_URL_WS = "wss://api.deepgram.com/v1/listen" # This is the magic number during testing that we use to determine if a frame is loud enough @@ -126,6 +125,7 @@ def __init__( profanity_filter: bool = False, api_key: str | None = None, http_session: aiohttp.ClientSession | None = None, + base_url: str = BASE_URL, energy_filter: AudioEnergyFilter | bool = False, ) -> None: """ @@ -140,6 +140,7 @@ def __init__( streaming=True, interim_results=interim_results ) ) + self._base_url = base_url api_key = api_key or os.environ.get("DEEPGRAM_API_KEY") if api_key is None: @@ -209,7 +210,7 @@ async def _recognize_impl( try: async with self._ensure_session().post( - url=_to_deepgram_url(recognize_config), + url=_to_deepgram_url(recognize_config, self._base_url, websocket=False), data=rtc.combine_audio_frames(buffer).to_wav_bytes(), headers={ "Authorization": f"Token {self._api_key}", @@ -251,6 +252,7 @@ def stream( opts=config, api_key=self._api_key, http_session=self._ensure_session(), + base_url=self._base_url, ) def _sanitize_options(self, *, language: str | None = None) -> STTOptions: @@ -276,6 +278,7 @@ def __init__( conn_options: APIConnectOptions, api_key: str, http_session: aiohttp.ClientSession, + base_url: str, ) -> None: super().__init__( stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate @@ -287,6 +290,7 @@ def __init__( self._opts = opts self._api_key = api_key self._session = http_session + self._base_url = base_url self._speaking = False self._audio_duration_collector = PeriodicCollector( callback=self._on_audio_duration_report, @@ -419,7 +423,9 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse): ws = await asyncio.wait_for( self._session.ws_connect( - _to_deepgram_url(live_config, websocket=True), + _to_deepgram_url( + live_config, base_url=self._base_url, websocket=True + ), headers={"Authorization": f"Token {self._api_key}"}, ), self._conn_options.timeout, @@ -566,7 +572,7 @@ def prerecorded_transcription_to_speech_event( ) -def _to_deepgram_url(opts: dict, *, websocket: bool = False) -> str: +def _to_deepgram_url(opts: dict, base_url: str, *, websocket: bool) -> str: if opts.get("keywords"): # convert keywords to a list of "keyword:intensifier" opts["keywords"] = [ @@ -575,5 +581,11 @@ def _to_deepgram_url(opts: dict, *, websocket: bool = False) -> str: # lowercase bools opts = {k: str(v).lower() if isinstance(v, bool) else v for k, v in opts.items()} - base_url = BASE_URL_WS if websocket else BASE_URL + + if websocket and base_url.startswith("http"): + base_url = base_url.replace("http", "ws", 1) + + elif not websocket and base_url.startswith("ws"): + base_url = base_url.replace("ws", "http", 1) + return f"{base_url}?{urlencode(opts, doseq=True)}"