diff --git a/CHANGES.md b/CHANGES.md index 55c1066..3609c21 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -11,6 +11,14 @@ ## develop +- [CHANGE] サービス接続時にエラーになった場合は、Body が空のレスポンスを返すように変更する + - @Hexa +- [CHANGE] サービス接続後にエラーになった場合は、{"type": "error", "reason": string} をクライアントへ送信するように変更する + - @Hexa +- [CHANGE] aws の再接続条件の exception から InternalFailureException を削除する + - @Hexa + + ## 2023.5.3 - [FIX] VERSION ファイルを tag のバージョンに修正する diff --git a/amazon_transcribe.go b/amazon_transcribe.go index 9501238..0b0e4da 100644 --- a/amazon_transcribe.go +++ b/amazon_transcribe.go @@ -6,6 +6,7 @@ import ( "net/http" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/transcribestreamingservice" ) @@ -95,6 +96,14 @@ func (at *AmazonTranscribe) Start(ctx context.Context, r io.Reader) (*transcribe resp, err := client.StartStreamTranscriptionWithContext(ctx, &input) if err != nil { + if reqErr, ok := err.(awserr.RequestFailure); ok { + code := reqErr.StatusCode() + message := reqErr.Message() + return nil, &SuzuError{ + Code: code, + Message: message, + } + } return nil, err } diff --git a/amazon_transcribe_handler.go b/amazon_transcribe_handler.go index 2a0201e..32214b2 100644 --- a/amazon_transcribe_handler.go +++ b/amazon_transcribe_handler.go @@ -3,6 +3,7 @@ package suzu import ( "context" "encoding/json" + "errors" "io" "github.com/aws/aws-sdk-go/service/transcribestreamingservice" @@ -43,11 +44,10 @@ type AwsResult struct { TranscriptionResult } -func NewAwsResult(err error) AwsResult { +func NewAwsResult() AwsResult { return AwsResult{ TranscriptionResult: TranscriptionResult{ - Type: "aws", - Error: err, + Type: "aws", }, } } @@ -62,6 +62,11 @@ func (ar *AwsResult) WithIsPartial(isPartial bool) *AwsResult { return ar } +func (ar *AwsResult) SetMessage(message string) *AwsResult { + ar.Message = message + return ar +} + func (h *AmazonTranscribeHandler) Handle(ctx context.Context, reader io.Reader) (*io.PipeReader, error) { at := NewAmazonTranscribe(h.Config, h.LanguageCode, int64(h.SampleRate), int64(h.ChannelCount)) @@ -69,11 +74,14 @@ func (h *AmazonTranscribeHandler) Handle(ctx context.Context, reader io.Reader) go func() { defer oggWriter.Close() if err := opus2ogg(ctx, reader, oggWriter, h.SampleRate, h.ChannelCount, h.Config); err != nil { - zlog.Error(). - Err(err). - Str("channel_id", h.ChannelID). - Str("connection_id", h.ConnectionID). - Send() + if !errors.Is(err, io.EOF) { + zlog.Error(). + Err(err). + Str("channel_id", h.ChannelID). + Str("connection_id", h.ConnectionID). + Send() + } + oggWriter.CloseWithError(err) return } @@ -99,6 +107,13 @@ func (h *AmazonTranscribeHandler) Handle(ctx context.Context, reader io.Reader) case *transcribestreamingservice.TranscriptEvent: if h.OnResultFunc != nil { if err := h.OnResultFunc(ctx, w, h.ChannelID, h.ConnectionID, h.LanguageCode, e.Transcript.Results); err != nil { + if err := encoder.Encode(NewSuzuErrorResponse(err)); err != nil { + zlog.Error(). + Err(err). + Str("channel_id", h.ChannelID). + Str("connection_id", h.ConnectionID). + Send() + } w.CloseWithError(err) return } @@ -111,7 +126,7 @@ func (h *AmazonTranscribeHandler) Handle(ctx context.Context, reader io.Reader) } } - result := NewAwsResult(nil) + result := NewAwsResult() if at.Config.AwsResultIsPartial { result.WithIsPartial(*res.IsPartial) } @@ -123,7 +138,7 @@ func (h *AmazonTranscribeHandler) Handle(ctx context.Context, reader io.Reader) if alt.Transcript != nil { message = *alt.Transcript } - result.Message = message + result.SetMessage(message) if err := encoder.Encode(result); err != nil { w.CloseWithError(err) return @@ -140,16 +155,34 @@ func (h *AmazonTranscribeHandler) Handle(ctx context.Context, reader io.Reader) if err := stream.Err(); err != nil { // 復帰が不可能なエラー以外は再接続を試みる switch err.(type) { - case *transcribestreamingservice.LimitExceededException, - *transcribestreamingservice.InternalFailureException: + case *transcribestreamingservice.LimitExceededException: zlog.Error(). Err(err). Str("channel_id", h.ChannelID). Str("connection_id", h.ConnectionID). Send() + // リトライしない設定の場合はクライアントにエラーを返し、再度接続するかはクライアント側で判断する + if !*at.Config.Retry { + if err := encoder.Encode(NewSuzuErrorResponse(err)); err != nil { + zlog.Error(). + Err(err). + Str("channel_id", h.ChannelID). + Str("connection_id", h.ConnectionID). + Send() + } + } + err = ErrServerDisconnected default: + // 再接続を想定している以外のエラーの場合はクライアントにエラーを返し、再度接続するかはクライアント側で判断する + if err := encoder.Encode(NewSuzuErrorResponse(err)); err != nil { + zlog.Error(). + Err(err). + Str("channel_id", h.ChannelID). + Str("connection_id", h.ConnectionID). + Send() + } } w.CloseWithError(err) diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..34a3551 --- /dev/null +++ b/errors.go @@ -0,0 +1,10 @@ +package suzu + +type SuzuError struct { + Code int + Message string +} + +func (e *SuzuError) Error() string { + return e.Message +} diff --git a/handler.go b/handler.go index 1251641..c348a42 100644 --- a/handler.go +++ b/handler.go @@ -26,10 +26,17 @@ var ( type TranscriptionResult struct { Message string `json:"message,omitempty"` - Error error `json:"error,omitempty"` + Reason string `json:"reason,omitempty"` Type string `json:"type"` } +func NewSuzuErrorResponse(err error) TranscriptionResult { + return TranscriptionResult{ + Type: "error", + Reason: err.Error(), + } +} + func getServiceHandler(serviceType string, config Config, channelID, connectionID string, sampleRate uint32, channelCount uint16, languageCode string, onResultFunc any) (serviceHandlerInterface, error) { newHandlerFunc, err := NewServiceHandlerFuncs.get(serviceType) if err != nil { @@ -135,8 +142,13 @@ func (s *Server) createSpeechHandler(serviceType string, onResultFunc func(conte Str("channel_id", h.SoraChannelID). Str("connection_id", h.SoraConnectionID). Send() - // TODO: エラー内容で status code を変更する - return echo.NewHTTPError(http.StatusInternalServerError) + if err, ok := err.(*SuzuError); ok { + // SuzuError の場合はその Status Code を返す + return c.NoContent(err.Code) + } + + // SuzuError 以外の場合は 500 を返す + return echo.NewHTTPError(http.StatusInternalServerError, err) } defer reader.Close() diff --git a/speech_to_text.go b/speech_to_text.go index 295e820..3dcafd6 100644 --- a/speech_to_text.go +++ b/speech_to_text.go @@ -43,17 +43,29 @@ func (stt SpeechToText) Start(ctx context.Context, r io.Reader) (speechpb.Speech client, err := speech.NewClient(ctx, opts...) if err != nil { - return nil, err + return nil, &SuzuError{ + // TODO: 適切な StatusCode に変更する + Code: 500, + Message: err.Error(), + } } stream, err := client.StreamingRecognize(ctx) if err != nil { - return nil, err + return nil, &SuzuError{ + // TODO: 適切な StatusCode に変更する + Code: 500, + Message: err.Error(), + } } if err := stream.Send(&speechpb.StreamingRecognizeRequest{ StreamingRequest: streamingRecognitionConfig, }); err != nil { - return nil, err + return nil, &SuzuError{ + // TODO: 適切な StatusCode に変更する + Code: 500, + Message: err.Error(), + } } go func() { diff --git a/speech_to_text_handler.go b/speech_to_text_handler.go index 0f28878..e49dea0 100644 --- a/speech_to_text_handler.go +++ b/speech_to_text_handler.go @@ -3,6 +3,8 @@ package suzu import ( "context" "encoding/json" + "errors" + "fmt" "io" "strings" @@ -45,11 +47,10 @@ type GcpResult struct { TranscriptionResult } -func NewGcpResult(err error) GcpResult { +func NewGcpResult() GcpResult { return GcpResult{ TranscriptionResult: TranscriptionResult{ - Type: "gcp", - Error: err, + Type: "gcp", }, } } @@ -64,6 +65,11 @@ func (gr *GcpResult) WithStability(stability float32) *GcpResult { return gr } +func (gr *GcpResult) SetMessage(message string) *GcpResult { + gr.Message = message + return gr +} + func (h *SpeechToTextHandler) Handle(ctx context.Context, reader io.Reader) (*io.PipeReader, error) { stt := NewSpeechToText(h.Config, h.LanguageCode, int32(h.SampleRate), int32(h.ChannelCount)) @@ -71,11 +77,13 @@ func (h *SpeechToTextHandler) Handle(ctx context.Context, reader io.Reader) (*io go func() { defer oggWriter.Close() if err := opus2ogg(ctx, reader, oggWriter, h.SampleRate, h.ChannelCount, h.Config); err != nil { - zlog.Error(). - Err(err). - Str("channel_id", h.ChannelID). - Str("connection_id", h.ConnectionID). - Send() + if !errors.Is(err, io.EOF) { + zlog.Error(). + Err(err). + Str("channel_id", h.ChannelID). + Str("connection_id", h.ConnectionID). + Send() + } oggWriter.CloseWithError(err) return } @@ -107,32 +115,62 @@ func (h *SpeechToTextHandler) Handle(ctx context.Context, reader io.Reader) (*io return } + if err := encoder.Encode(NewSuzuErrorResponse(err)); err != nil { + zlog.Error(). + Err(err). + Str("channel_id", h.ChannelID). + Str("connection_id", h.ConnectionID). + Send() + } + w.CloseWithError(err) return } - if status := resp.Error; err != nil { + if status := resp.Error; status != nil { // 音声の長さの上限値に達した場合 code := codes.Code(status.GetCode()) if code == codes.OutOfRange || code == codes.InvalidArgument || code == codes.ResourceExhausted { + err := fmt.Errorf(status.GetMessage()) zlog.Error(). Err(err). Str("channel_id", h.ChannelID). Str("connection_id", h.ConnectionID). Int32("code", status.GetCode()). - Msg(status.GetMessage()) - err := ErrServerDisconnected + Send() + + // リトライしない設定の場合はクライアントにエラーを返し、再度接続するかはクライアント側で判断する + if !*stt.Config.Retry { + if err := encoder.Encode(NewSuzuErrorResponse(err)); err != nil { + zlog.Error(). + Err(err). + Str("channel_id", h.ChannelID). + Str("connection_id", h.ConnectionID). + Send() + } + } - w.CloseWithError(err) + w.CloseWithError(ErrServerDisconnected) return } + + errMessage := status.GetMessage() zlog.Error(). Str("channel_id", h.ChannelID). Str("connection_id", h.ConnectionID). Int32("code", status.GetCode()). - Msg(status.GetMessage()) + Msg(errMessage) + + err := fmt.Errorf(errMessage) + if err := encoder.Encode(NewSuzuErrorResponse(err)); err != nil { + zlog.Error(). + Err(err). + Str("channel_id", h.ChannelID). + Str("connection_id", h.ConnectionID). + Send() + } w.Close() return @@ -140,6 +178,13 @@ func (h *SpeechToTextHandler) Handle(ctx context.Context, reader io.Reader) (*io if h.OnResultFunc != nil { if err := h.OnResultFunc(ctx, w, h.ChannelID, h.ConnectionID, h.LanguageCode, resp.Results); err != nil { + if err := encoder.Encode(NewSuzuErrorResponse(err)); err != nil { + zlog.Error(). + Err(err). + Str("channel_id", h.ChannelID). + Str("connection_id", h.ConnectionID). + Send() + } w.CloseWithError(err) return } @@ -151,7 +196,7 @@ func (h *SpeechToTextHandler) Handle(ctx context.Context, reader io.Reader) (*io } } - result := NewGcpResult(nil) + result := NewGcpResult() if stt.Config.GcpResultIsFinal { result.WithIsFinal(res.IsFinal) } @@ -173,7 +218,7 @@ func (h *SpeechToTextHandler) Handle(ctx context.Context, reader io.Reader) (*io } } transcript := alternative.Transcript - result.Message = transcript + result.SetMessage(transcript) if err := encoder.Encode(result); err != nil { w.CloseWithError(err) return diff --git a/test_handler.go b/test_handler.go index 0113fcc..a8c3912 100644 --- a/test_handler.go +++ b/test_handler.go @@ -5,6 +5,8 @@ import ( "encoding/json" "fmt" "io" + + zlog "github.com/rs/zerolog/log" ) func init() { @@ -40,12 +42,13 @@ type TestResult struct { TranscriptionResult } -func TestErrorResult(err error) TestResult { +func NewTestResult(channelID, message string) TestResult { return TestResult{ TranscriptionResult: TranscriptionResult{ - Type: "test", - Error: err, + Type: "test", + Message: message, }, + ChannelID: &channelID, } } @@ -59,18 +62,33 @@ func (h *TestHandler) Handle(ctx context.Context, reader io.Reader) (*io.PipeRea buf := make([]byte, FrameSize) n, err := reader.Read(buf) if err != nil { + if err != io.EOF { + if err := encoder.Encode(NewSuzuErrorResponse(err)); err != nil { + zlog.Error(). + Err(err). + Str("channel_id", h.ChannelID). + Str("connection_id", h.ConnectionID). + Send() + } + } w.CloseWithError(err) return } if n > 0 { - var result TestResult - result.Type = "test" - result.Message = fmt.Sprintf("n: %d", n) - result.ChannelID = &[]string{"ch_0"}[0] + message := fmt.Sprintf("n: %d", n) + channelID := &[]string{"ch_0"}[0] + result := NewTestResult(*channelID, message) if h.OnResultFunc != nil { if err := h.OnResultFunc(ctx, w, h.ChannelID, h.ConnectionID, h.LanguageCode, result); err != nil { + if err := encoder.Encode(NewSuzuErrorResponse(err)); err != nil { + zlog.Error(). + Err(err). + Str("channel_id", h.ChannelID). + Str("connection_id", h.ConnectionID). + Send() + } w.CloseWithError(err) return } diff --git a/test_handler_test.go b/test_handler_test.go index 43fe17a..57541ff 100644 --- a/test_handler_test.go +++ b/test_handler_test.go @@ -125,7 +125,7 @@ func TestSpeechHandler(t *testing.T) { lastMessage = result.Message } // TODO: テストデータは固定のため、すべてのメッセージを確認する - assert.Equal(t, lastMessage, "n: 3") + assert.Equal(t, "n: 3", lastMessage) } }) @@ -359,4 +359,55 @@ func TestSpeechHandler(t *testing.T) { }) + t.Run("stream error", func(t *testing.T) { + r := readDumpFile(t, "testdata/dump.jsonl", 0) + defer r.Close() + + e := echo.New() + req := httptest.NewRequest("POST", path, r) + req.Header.Set("sora-audio-streaming-language-code", "ja-JP") + req.Proto = "HTTP/2.0" + req.ProtoMajor = 2 + req.ProtoMinor = 0 + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h := s.createSpeechHandler(serviceType, func(ctx context.Context, w io.WriteCloser, chnanelID, connectionID, languageCode string, results any) error { + go func() { + defer w.Close() + + encoder := json.NewEncoder(w) + if err := encoder.Encode(NewSuzuErrorResponse(fmt.Errorf("STREAM-ERROR"))); err != nil { + return + } + }() + + return nil + }) + err := h(c) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + + delim := []byte("\n")[0] + for { + line, err := rec.Body.ReadBytes(delim) + if err != nil { + assert.ErrorIs(t, err, io.EOF) + break + } + + var result TranscriptionResult + if err := json.Unmarshal(line, &result); err != nil { + assert.ErrorIs(t, err, io.EOF) + } + + assert.Equal(t, "error", result.Type) + if assert.NotEmpty(t, result.Reason) { + assert.Equal(t, "STREAM-ERROR", result.Reason) + assert.Empty(t, result.Message) + } + } + } + + }) }