diff --git a/CHANGES.md b/CHANGES.md index 86082e3..14f08bc 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -11,6 +11,13 @@ ## develop +- [FIX] サービスへの接続が成功してもリトライカウントがリセットされない不具合を修正する + - @Hexa +- [FIX] 解析結果だけでなくエラーメッセージの送信時にもリトライカウントをリセットしていたため、リトライ処理によってカウントがリセットされていた不具合を修正する + - @Hexa +- [FIX] リトライ待ち時にクライアントから切断しようとすると、リトライ待ちで処理がブロックされているため切断までに時間がかかる不具合を修正する + - @Hexa + ### misc ## 2024.6.0 diff --git a/amazon_transcribe.go b/amazon_transcribe.go index ce61cbc..fa8c67d 100644 --- a/amazon_transcribe.go +++ b/amazon_transcribe.go @@ -9,6 +9,7 @@ import ( "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/transcribestreamingservice" + zlog "github.com/rs/zerolog/log" ) type AmazonTranscribe struct { @@ -89,7 +90,7 @@ func NewAmazonTranscribeClient(config Config) *transcribestreamingservice.Transc return transcribestreamingservice.New(sess, cfg) } -func (at *AmazonTranscribe) Start(ctx context.Context, r io.Reader) (*transcribestreamingservice.StartStreamTranscriptionEventStream, error) { +func (at *AmazonTranscribe) Start(ctx context.Context, r io.ReadCloser) (*transcribestreamingservice.StartStreamTranscriptionEventStream, error) { config := at.Config client := NewAmazonTranscribeClient(config) input := NewStartStreamTranscriptionInput(at) @@ -117,9 +118,11 @@ func (at *AmazonTranscribe) Start(ctx context.Context, r io.Reader) (*transcribe stream := resp.GetStream() go func() { + defer r.Close() defer stream.Close() if err := transcribestreamingservice.StreamAudioFromReader(ctx, stream, FrameSize, r); err != nil { + zlog.Error().Err(err).Send() return } }() diff --git a/amazon_transcribe_handler.go b/amazon_transcribe_handler.go index e374d3f..3080874 100644 --- a/amazon_transcribe_handler.go +++ b/amazon_transcribe_handler.go @@ -95,31 +95,19 @@ func (h *AmazonTranscribeHandler) ResetRetryCount() int { return h.RetryCount } -func (h *AmazonTranscribeHandler) Handle(ctx context.Context, reader io.Reader) (*io.PipeReader, error) { +func (h *AmazonTranscribeHandler) Handle(ctx context.Context, opusCh chan opusChannel) (*io.PipeReader, error) { at := NewAmazonTranscribe(h.Config, h.LanguageCode, int64(h.SampleRate), int64(h.ChannelCount)) - oggReader, oggWriter := io.Pipe() - go func() { - defer oggWriter.Close() - if err := opus2ogg(ctx, reader, oggWriter, h.SampleRate, h.ChannelCount, h.Config); err != nil { - if !errors.Is(err, io.EOF) { - zlog.Error(). - Err(err). - Str("channel_id", h.ChannelID). - Str("connection_id", h.ConnectionID). - Send() - } - - oggWriter.CloseWithError(err) - return - } - }() + packetReader := opus2ogg(ctx, opusCh, h.SampleRate, h.ChannelCount, h.Config) - stream, err := at.Start(ctx, oggReader) + stream, err := at.Start(ctx, packetReader) if err != nil { return nil, err } + // リクエストが成功した時点でリトライカウントをリセットする + h.ResetRetryCount() + r, w := io.Pipe() go func() { @@ -195,33 +183,13 @@ func (h *AmazonTranscribeHandler) Handle(ctx context.Context, reader io.Reader) switch err.(type) { case *transcribestreamingservice.LimitExceededException, *transcribestreamingservice.InternalFailureException: - // リトライしない設定の場合、または、max_retry を超えた場合はクライアントにエラーを返し、再度接続するかはクライアント側で判断する - if (at.Config.MaxRetry < 1) || (at.Config.MaxRetry <= h.GetRetryCount()) { - if err := encoder.Encode(NewSuzuErrorResponse(err)); err != nil { - zlog.Error(). - Err(err). - Str("channel_id", h.ChannelID). - Str("connection_id", h.ConnectionID). - Send() - } - } - - err = ErrServerDisconnected + err = errors.Join(err, ErrServerDisconnected) default: - // 再接続を想定している以外のエラーの場合はクライアントにエラーを返し、再度接続するかはクライアント側で判断する - if err := encoder.Encode(NewSuzuErrorResponse(err)); err != nil { - zlog.Error(). - Err(err). - Str("channel_id", h.ChannelID). - Str("connection_id", h.ConnectionID). - Send() - } } w.CloseWithError(err) return } - w.Close() }() diff --git a/handler.go b/handler.go index 17683c2..9586524 100644 --- a/handler.go +++ b/handler.go @@ -3,6 +3,7 @@ package suzu import ( "context" "encoding/binary" + "encoding/json" "errors" "fmt" "io" @@ -103,7 +104,8 @@ func (s *Server) createSpeechHandler(serviceType string, onResultFunc func(conte Msg("CONNECTED") c.Response().Header().Set(echo.HeaderContentType, echo.MIMEApplicationJSON) - // すぐにヘッダを送信したい場合はここで c.Response().Flush() を実行する + // すぐにヘッダを送信したいので c.Response().Flush() を実行する + c.Response().Flush() ctx := c.Request().Context() // TODO: context.WithCancelCause(ctx) に変更する @@ -115,8 +117,18 @@ func (s *Server) createSpeechHandler(serviceType string, onResultFunc func(conte channelCount := uint16(s.config.ChannelCount) d := time.Duration(s.config.TimeToWaitForOpusPacketMs) * time.Millisecond - r := NewOpusReader(*s.config, d, c.Request().Body) - defer r.Close() + opusReader := NewOpusReader(*s.config, d, c.Request().Body) + defer opusReader.Close() + + var r io.Reader + if s.config.AudioStreamingHeader { + r = readPacketWithHeader(opusReader) + } else { + // ヘッダー処理なし + r = opusReader + } + + opusCh := readOpus(ctx, r) serviceHandler, err := getServiceHandler(serviceType, *s.config, h.SoraChannelID, h.SoraConnectionID, sampleRate, channelCount, languageCode, onResultFunc) if err != nil { @@ -137,7 +149,11 @@ func (s *Server) createSpeechHandler(serviceType string, onResultFunc func(conte Int("retry_count", serviceHandler.GetRetryCount()). Msg("NEW-REQUEST") - reader, err := serviceHandler.Handle(ctx, r) + // リトライ時にこれ以降の処理のみを cancel する + serviceHandlerCtx, cancelServiceHandler := context.WithCancel(ctx) + defer cancelServiceHandler() + + reader, err := serviceHandler.Handle(serviceHandlerCtx, opusCh) if err != nil { zlog.Error(). Err(err). @@ -149,11 +165,33 @@ func (s *Server) createSpeechHandler(serviceType string, onResultFunc func(conte if s.config.MaxRetry > serviceHandler.GetRetryCount() { serviceHandler.UpdateRetryCount() - // 連続のリトライを避けるために少し待つ - time.Sleep(time.Duration(s.config.RetryIntervalMs) * time.Millisecond) - // リトライ対象のエラーのため、クライアントとの接続は切らずにリトライする - continue + retryTimer := time.NewTimer(time.Duration(s.config.RetryIntervalMs) * time.Millisecond) + + retry: + select { + case <-retryTimer.C: + zlog.Debug(). + Err(err). + Str("channel_id", h.SoraChannelID). + Str("connection_id", h.SoraConnectionID). + Msg("retry") + cancelServiceHandler() + continue + case _, ok := <-opusCh: + if ok { + // channel が閉じるか、または、リトライのタイマーが発火するまで繰り返す + goto retry + } + retryTimer.Stop() + zlog.Debug(). + Err(err). + Str("channel_id", h.SoraChannelID). + Str("connection_id", h.SoraConnectionID). + Msg("retry interrupted") + // リトライする前にクライアントとの接続でエラーが発生した場合は終了する + return fmt.Errorf("%s", "retry interrupted") + } } } // SuzuError の場合はその Status Code を返す @@ -180,14 +218,39 @@ func (s *Server) createSpeechHandler(serviceType string, onResultFunc func(conte Send() return err } else if errors.Is(err, ErrServerDisconnected) { + errs := err.(interface{ Unwrap() []error }).Unwrap() + // 元の err を取得する + err := errs[0] + if s.config.MaxRetry < 1 { // サーバから切断されたが再接続させない設定の場合 zlog.Error(). + Err(ErrServerDisconnected). Err(err). Str("channel_id", h.SoraChannelID). Str("connection_id", h.SoraConnectionID). Send() - return err + + errMessage, err := json.Marshal(NewSuzuErrorResponse(err)) + if err != nil { + zlog.Error(). + Err(err). + Str("channel_id", h.SoraChannelID). + Str("connection_id", h.SoraConnectionID). + Send() + return err + } + + if _, err := c.Response().Write(errMessage); err != nil { + zlog.Error(). + Err(err). + Str("channel_id", h.SoraChannelID). + Str("connection_id", h.SoraConnectionID). + Send() + return err + } + c.Response().Flush() + return ErrServerDisconnected } if s.config.MaxRetry > serviceHandler.GetRetryCount() { @@ -196,7 +259,7 @@ func (s *Server) createSpeechHandler(serviceType string, onResultFunc func(conte serviceHandler.UpdateRetryCount() // TODO: 必要な場合は連続のリトライを避けるために少し待つ処理を追加する - + cancelServiceHandler() break } else { zlog.Error(). @@ -204,18 +267,64 @@ func (s *Server) createSpeechHandler(serviceType string, onResultFunc func(conte Str("channel_id", h.SoraChannelID). Str("connection_id", h.SoraConnectionID). Send() + + errMessage, err := json.Marshal(NewSuzuErrorResponse(err)) + if err != nil { + zlog.Error(). + Err(err). + Str("channel_id", h.SoraChannelID). + Str("connection_id", h.SoraConnectionID). + Send() + return err + } + + if _, err := c.Response().Write(errMessage); err != nil { + zlog.Error(). + Err(err). + Str("channel_id", h.SoraChannelID). + Str("connection_id", h.SoraConnectionID). + Send() + return err + } + c.Response().Flush() + // max_retry を超えた場合は終了 return c.NoContent(http.StatusOK) } } + zlog.Debug(). + Err(err). + Str("channel_id", h.SoraChannelID). + Str("connection_id", h.SoraConnectionID). + Send() + + orgErr := err + + errMessage, err := json.Marshal(NewSuzuErrorResponse(err)) + if err != nil { + zlog.Error(). + Err(err). + Str("channel_id", h.SoraChannelID). + Str("connection_id", h.SoraConnectionID). + Send() + return err + } + + if _, err := c.Response().Write(errMessage); err != nil { + zlog.Error(). + Err(err). + Str("channel_id", h.SoraChannelID). + Str("connection_id", h.SoraConnectionID). + Send() + return err + } + c.Response().Flush() + // サーバから切断されたが再度の接続が期待できない場合 - return err + return orgErr } - // 1 度でも接続結果を受け取れた場合はリトライ回数をリセットする - serviceHandler.ResetRetryCount() - // メッセージが空でない場合はクライアントに結果を送信する if n > 0 { if _, err := c.Response().Write(buf[:n]); err != nil { @@ -233,9 +342,7 @@ func (s *Server) createSpeechHandler(serviceType string, onResultFunc func(conte } } -const () - -func readPacketWithHeader(reader io.Reader) (io.Reader, error) { +func readPacketWithHeader(reader io.Reader) io.Reader { r, w := io.Pipe() go func() { @@ -313,57 +420,88 @@ func readPacketWithHeader(reader io.Reader) (io.Reader, error) { } }() - return r, nil + return r } -func opus2ogg(ctx context.Context, opusReader io.Reader, oggWriter io.Writer, sampleRate uint32, channelCount uint16, c Config) error { - o, err := NewWith(oggWriter, sampleRate, channelCount) - if err != nil { - if w, ok := oggWriter.(*io.PipeWriter); ok { - w.CloseWithError(err) - } - return err - } - defer o.Close() +func readOpus(ctx context.Context, reader io.Reader) chan opusChannel { + opusCh := make(chan opusChannel) - var r io.Reader - if c.AudioStreamingHeader { - r, err = readPacketWithHeader(opusReader) - if err != nil { - return err + go func() { + defer close(opusCh) + + for { + select { + case <-ctx.Done(): + opusCh <- opusChannel{ + Error: ctx.Err(), + } + return + default: + buf := make([]byte, FrameSize) + n, err := reader.Read(buf) + if err != nil { + opusCh <- opusChannel{ + Error: err, + } + return + } + + if n > 0 { + opusCh <- opusChannel{ + Payload: buf[:n], + } + + } + } } - } else { - r = opusReader - } + }() - for { - buf := make([]byte, FrameSize) - n, err := r.Read(buf) + return opusCh +} + +func opus2ogg(ctx context.Context, opusCh chan opusChannel, sampleRate uint32, channelCount uint16, c Config) io.ReadCloser { + oggReader, oggWriter := io.Pipe() + + go func() { + o, err := NewWith(oggWriter, sampleRate, channelCount) if err != nil { - if w, ok := oggWriter.(*io.PipeWriter); ok { - w.CloseWithError(err) - } - return err + oggWriter.CloseWithError(err) + return } + defer o.Close() - if n > 0 { - opus := codecs.OpusPacket{} - _, err := opus.Unmarshal(buf[:n]) - if err != nil { - if w, ok := oggWriter.(*io.PipeWriter); ok { - w.CloseWithError(err) + for { + select { + case <-ctx.Done(): + oggWriter.CloseWithError(ctx.Err()) + return + case opus, ok := <-opusCh: + if !ok { + oggWriter.CloseWithError(io.EOF) + return } - return err - } - if err := o.Write(&opus); err != nil { - if w, ok := oggWriter.(*io.PipeWriter); ok { - w.CloseWithError(err) + if err := opus.Error; err != nil { + oggWriter.CloseWithError(err) + return + } + + opusPacket := codecs.OpusPacket{} + _, err := opusPacket.Unmarshal(opus.Payload) + if err != nil { + oggWriter.CloseWithError(err) + return + } + + if err := o.Write(&opusPacket); err != nil { + oggWriter.CloseWithError(err) + return } - return err } } - } + }() + + return oggReader } type opusRequest struct { @@ -465,3 +603,41 @@ func silentPacket(audioStreamingHeader bool) []byte { return packet } + +type opusChannel struct { + Payload []byte + Error error +} + +func opusChannelToIOReadCloser(ctx context.Context, ch chan opusChannel) io.ReadCloser { + r, w := io.Pipe() + + go func() { + defer w.Close() + + for { + select { + case <-ctx.Done(): + w.CloseWithError(ctx.Err()) + return + case opus, ok := <-ch: + if !ok { + w.CloseWithError(io.EOF) + return + } + + if err := opus.Error; err != nil { + w.CloseWithError(err) + return + } + + if _, err := w.Write(opus.Payload); err != nil { + w.CloseWithError(err) + return + } + } + } + }() + + return r +} diff --git a/handler_test.go b/handler_test.go index b7192ad..ce3d365 100644 --- a/handler_test.go +++ b/handler_test.go @@ -34,7 +34,7 @@ func TestOpusPacketReader(t *testing.T) { } t.Run("success", func(t *testing.T) { - d := time.Duration(100) * time.Millisecond + d := time.Duration(3000) * time.Millisecond r := readDumpFile(t, "testdata/000.jsonl", 0) defer r.Close() @@ -47,12 +47,12 @@ func TestOpusPacketReader(t *testing.T) { assert.ErrorIs(t, err, io.EOF) break } - assert.Equal(t, buf[:n], []byte{0, 0, 0}) + assert.Equal(t, []byte{0, 0, 0}, buf[:n]) } }) t.Run("read error", func(t *testing.T) { - d := time.Duration(100) * time.Millisecond + d := time.Duration(3000) * time.Millisecond errPacketRead := errors.New("packet read error") r := NewErrReadCloser(errPacketRead) @@ -66,12 +66,12 @@ func TestOpusPacketReader(t *testing.T) { assert.ErrorIs(t, err, errPacketRead) break } - assert.Equal(t, buf[:n], []byte{255, 255, 254}) + assert.Equal(t, []byte{255, 255, 254}, buf[:n]) } }) t.Run("closed reader", func(t *testing.T) { - d := time.Duration(100) * time.Millisecond + d := time.Duration(3000) * time.Millisecond r := readDumpFile(t, "testdata/dump.jsonl", 0) r.Close() @@ -81,17 +81,17 @@ func TestOpusPacketReader(t *testing.T) { buf := make([]byte, FrameSize) _, err := reader.Read(buf) if err != nil { - assert.ErrorIs(t, err, io.ErrClosedPipe) + assert.ErrorIs(t, io.ErrClosedPipe, err) break } } }) t.Run("close reader", func(t *testing.T) { - d := time.Duration(100) * time.Millisecond + d := time.Duration(3000) * time.Millisecond r := readDumpFile(t, "testdata/dump.jsonl", 0) go func() { - time.Sleep(100 * time.Millisecond) + time.Sleep(3000 * time.Millisecond) r.Close() }() @@ -101,7 +101,7 @@ func TestOpusPacketReader(t *testing.T) { buf := make([]byte, FrameSize) _, err := reader.Read(buf) if err != nil { - assert.ErrorIs(t, err, io.EOF) + assert.ErrorIs(t, io.EOF, err) break } } @@ -288,8 +288,7 @@ func TestReadPacketWithHeader(t *testing.T) { } }() - r, err := readPacketWithHeader(reader) - assert.NoError(t, err) + r := readPacketWithHeader(reader) i := 0 for { diff --git a/packet_dump_handler.go b/packet_dump_handler.go index 1ed7e94..e35e855 100644 --- a/packet_dump_handler.go +++ b/packet_dump_handler.go @@ -67,7 +67,7 @@ func (h *PacketDumpHandler) ResetRetryCount() int { return h.RetryCount } -func (h *PacketDumpHandler) Handle(ctx context.Context, reader io.Reader) (*io.PipeReader, error) { +func (h *PacketDumpHandler) Handle(ctx context.Context, opusCh chan opusChannel) (*io.PipeReader, error) { c := h.Config filename := c.DumpFile channelID := h.ChannelID @@ -75,6 +75,8 @@ func (h *PacketDumpHandler) Handle(ctx context.Context, reader io.Reader) (*io.P r, w := io.Pipe() + reader := opusChannelToIOReadCloser(ctx, opusCh) + go func() { f, err := os.OpenFile(filename, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { diff --git a/service_handler.go b/service_handler.go index 0a8b722..05a3cf2 100644 --- a/service_handler.go +++ b/service_handler.go @@ -15,7 +15,7 @@ var ( ) type serviceHandlerInterface interface { - Handle(context.Context, io.Reader) (*io.PipeReader, error) + Handle(context.Context, chan opusChannel) (*io.PipeReader, error) UpdateRetryCount() int GetRetryCount() int ResetRetryCount() int diff --git a/speech_to_text_handler.go b/speech_to_text_handler.go index 5c85fd6..7d0649f 100644 --- a/speech_to_text_handler.go +++ b/speech_to_text_handler.go @@ -91,30 +91,18 @@ func (h *SpeechToTextHandler) ResetRetryCount() int { return h.RetryCount } -func (h *SpeechToTextHandler) Handle(ctx context.Context, reader io.Reader) (*io.PipeReader, error) { +func (h *SpeechToTextHandler) Handle(ctx context.Context, opusCh chan opusChannel) (*io.PipeReader, error) { stt := NewSpeechToText(h.Config, h.LanguageCode, int32(h.SampleRate), int32(h.ChannelCount)) - oggReader, oggWriter := io.Pipe() - go func() { - defer oggWriter.Close() - if err := opus2ogg(ctx, reader, oggWriter, h.SampleRate, h.ChannelCount, h.Config); err != nil { - if !errors.Is(err, io.EOF) { - zlog.Error(). - Err(err). - Str("channel_id", h.ChannelID). - Str("connection_id", h.ConnectionID). - Send() - } - oggWriter.CloseWithError(err) - return - } - }() + packetReader := opus2ogg(ctx, opusCh, h.SampleRate, h.ChannelCount, h.Config) - stream, err := stt.Start(ctx, oggReader) + stream, err := stt.Start(ctx, packetReader) if err != nil { return nil, err } + h.ResetRetryCount() + r, w := io.Pipe() go func() { @@ -147,53 +135,29 @@ func (h *SpeechToTextHandler) Handle(ctx context.Context, reader io.Reader) (*io w.CloseWithError(err) return } + if status := resp.Error; status != nil { // 音声の長さの上限値に達した場合 + err := fmt.Errorf("%s", status.GetMessage()) code := codes.Code(status.GetCode()) - if code == codes.OutOfRange || - code == codes.InvalidArgument || - code == codes.ResourceExhausted { - - err := fmt.Errorf(status.GetMessage()) - zlog.Error(). - Err(err). - Str("channel_id", h.ChannelID). - Str("connection_id", h.ConnectionID). - Int32("code", status.GetCode()). - Send() - - // リトライしない設定の場合、または、max_retry を超えた場合はクライアントにエラーを返し、再度接続するかはクライアント側で判断する - if (stt.Config.MaxRetry < 1) || (stt.Config.MaxRetry <= h.GetRetryCount()) { - if err := encoder.Encode(NewSuzuErrorResponse(err)); err != nil { - zlog.Error(). - Err(err). - Str("channel_id", h.ChannelID). - Str("connection_id", h.ConnectionID). - Send() - } - } - - w.CloseWithError(ErrServerDisconnected) - return - } - errMessage := status.GetMessage() zlog.Error(). + Err(err). Str("channel_id", h.ChannelID). Str("connection_id", h.ConnectionID). Int32("code", status.GetCode()). - Msg(errMessage) + Send() - err := fmt.Errorf(errMessage) - if err := encoder.Encode(NewSuzuErrorResponse(err)); err != nil { - zlog.Error(). - Err(err). - Str("channel_id", h.ChannelID). - Str("connection_id", h.ConnectionID). - Send() + if code == codes.OutOfRange || + code == codes.InvalidArgument || + code == codes.ResourceExhausted { + + err := errors.Join(err, ErrServerDisconnected) + w.CloseWithError(err) + return } - w.Close() + w.CloseWithError(err) return } diff --git a/test_handler.go b/test_handler.go index e264439..005ea90 100644 --- a/test_handler.go +++ b/test_handler.go @@ -73,9 +73,11 @@ func (h *TestHandler) ResetRetryCount() int { return h.RetryCount } -func (h *TestHandler) Handle(ctx context.Context, reader io.Reader) (*io.PipeReader, error) { +func (h *TestHandler) Handle(ctx context.Context, opusCh chan opusChannel) (*io.PipeReader, error) { r, w := io.Pipe() + reader := opusChannelToIOReadCloser(ctx, opusCh) + go func() { encoder := json.NewEncoder(w)