Skip to content

Commit

Permalink
Merge pull request #162 from shiguredo/feature/stream-error
Browse files Browse the repository at this point in the history
エラー処理の変更
  • Loading branch information
Hexa authored Feb 15, 2024
2 parents 06ddaf8 + 2174edf commit dda1925
Show file tree
Hide file tree
Showing 9 changed files with 239 additions and 41 deletions.
8 changes: 8 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@

## develop

- [CHANGE] サービス接続時にエラーになった場合は、Body が空のレスポンスを返すように変更する
- @Hexa
- [CHANGE] サービス接続後にエラーになった場合は、{"type": "error", "reason": string} をクライアントへ送信するように変更する
- @Hexa
- [CHANGE] aws の再接続条件の exception から InternalFailureException を削除する
- @Hexa


## 2023.5.3

- [FIX] VERSION ファイルを tag のバージョンに修正する
Expand Down
9 changes: 9 additions & 0 deletions amazon_transcribe.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net/http"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/transcribestreamingservice"
)
Expand Down Expand Up @@ -95,6 +96,14 @@ func (at *AmazonTranscribe) Start(ctx context.Context, r io.Reader) (*transcribe

resp, err := client.StartStreamTranscriptionWithContext(ctx, &input)
if err != nil {
if reqErr, ok := err.(awserr.RequestFailure); ok {
code := reqErr.StatusCode()
message := reqErr.Message()
return nil, &SuzuError{
Code: code,
Message: message,
}
}
return nil, err
}

Expand Down
57 changes: 45 additions & 12 deletions amazon_transcribe_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package suzu
import (
"context"
"encoding/json"
"errors"
"io"

"github.com/aws/aws-sdk-go/service/transcribestreamingservice"
Expand Down Expand Up @@ -43,11 +44,10 @@ type AwsResult struct {
TranscriptionResult
}

func NewAwsResult(err error) AwsResult {
func NewAwsResult() AwsResult {
return AwsResult{
TranscriptionResult: TranscriptionResult{
Type: "aws",
Error: err,
Type: "aws",
},
}
}
Expand All @@ -62,18 +62,26 @@ func (ar *AwsResult) WithIsPartial(isPartial bool) *AwsResult {
return ar
}

func (ar *AwsResult) SetMessage(message string) *AwsResult {
ar.Message = message
return ar
}

func (h *AmazonTranscribeHandler) Handle(ctx context.Context, reader io.Reader) (*io.PipeReader, error) {
at := NewAmazonTranscribe(h.Config, h.LanguageCode, int64(h.SampleRate), int64(h.ChannelCount))

oggReader, oggWriter := io.Pipe()
go func() {
defer oggWriter.Close()
if err := opus2ogg(ctx, reader, oggWriter, h.SampleRate, h.ChannelCount, h.Config); err != nil {
zlog.Error().
Err(err).
Str("channel_id", h.ChannelID).
Str("connection_id", h.ConnectionID).
Send()
if !errors.Is(err, io.EOF) {
zlog.Error().
Err(err).
Str("channel_id", h.ChannelID).
Str("connection_id", h.ConnectionID).
Send()
}

oggWriter.CloseWithError(err)
return
}
Expand All @@ -99,6 +107,13 @@ func (h *AmazonTranscribeHandler) Handle(ctx context.Context, reader io.Reader)
case *transcribestreamingservice.TranscriptEvent:
if h.OnResultFunc != nil {
if err := h.OnResultFunc(ctx, w, h.ChannelID, h.ConnectionID, h.LanguageCode, e.Transcript.Results); err != nil {
if err := encoder.Encode(NewSuzuErrorResponse(err)); err != nil {
zlog.Error().
Err(err).
Str("channel_id", h.ChannelID).
Str("connection_id", h.ConnectionID).
Send()
}
w.CloseWithError(err)
return
}
Expand All @@ -111,7 +126,7 @@ func (h *AmazonTranscribeHandler) Handle(ctx context.Context, reader io.Reader)
}
}

result := NewAwsResult(nil)
result := NewAwsResult()
if at.Config.AwsResultIsPartial {
result.WithIsPartial(*res.IsPartial)
}
Expand All @@ -123,7 +138,7 @@ func (h *AmazonTranscribeHandler) Handle(ctx context.Context, reader io.Reader)
if alt.Transcript != nil {
message = *alt.Transcript
}
result.Message = message
result.SetMessage(message)
if err := encoder.Encode(result); err != nil {
w.CloseWithError(err)
return
Expand All @@ -140,16 +155,34 @@ func (h *AmazonTranscribeHandler) Handle(ctx context.Context, reader io.Reader)
if err := stream.Err(); err != nil {
// 復帰が不可能なエラー以外は再接続を試みる
switch err.(type) {
case *transcribestreamingservice.LimitExceededException,
*transcribestreamingservice.InternalFailureException:
case *transcribestreamingservice.LimitExceededException:
zlog.Error().
Err(err).
Str("channel_id", h.ChannelID).
Str("connection_id", h.ConnectionID).
Send()

// リトライしない設定の場合はクライアントにエラーを返し、再度接続するかはクライアント側で判断する
if !*at.Config.Retry {
if err := encoder.Encode(NewSuzuErrorResponse(err)); err != nil {
zlog.Error().
Err(err).
Str("channel_id", h.ChannelID).
Str("connection_id", h.ConnectionID).
Send()
}
}

err = ErrServerDisconnected
default:
// 再接続を想定している以外のエラーの場合はクライアントにエラーを返し、再度接続するかはクライアント側で判断する
if err := encoder.Encode(NewSuzuErrorResponse(err)); err != nil {
zlog.Error().
Err(err).
Str("channel_id", h.ChannelID).
Str("connection_id", h.ConnectionID).
Send()
}
}

w.CloseWithError(err)
Expand Down
10 changes: 10 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package suzu

type SuzuError struct {
Code int
Message string
}

func (e *SuzuError) Error() string {
return e.Message
}
18 changes: 15 additions & 3 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,17 @@ var (

type TranscriptionResult struct {
Message string `json:"message,omitempty"`
Error error `json:"error,omitempty"`
Reason string `json:"reason,omitempty"`
Type string `json:"type"`
}

func NewSuzuErrorResponse(err error) TranscriptionResult {
return TranscriptionResult{
Type: "error",
Reason: err.Error(),
}
}

func getServiceHandler(serviceType string, config Config, channelID, connectionID string, sampleRate uint32, channelCount uint16, languageCode string, onResultFunc any) (serviceHandlerInterface, error) {
newHandlerFunc, err := NewServiceHandlerFuncs.get(serviceType)
if err != nil {
Expand Down Expand Up @@ -135,8 +142,13 @@ func (s *Server) createSpeechHandler(serviceType string, onResultFunc func(conte
Str("channel_id", h.SoraChannelID).
Str("connection_id", h.SoraConnectionID).
Send()
// TODO: エラー内容で status code を変更する
return echo.NewHTTPError(http.StatusInternalServerError)
if err, ok := err.(*SuzuError); ok {
// SuzuError の場合はその Status Code を返す
return c.NoContent(err.Code)
}

// SuzuError 以外の場合は 500 を返す
return echo.NewHTTPError(http.StatusInternalServerError, err)
}
defer reader.Close()

Expand Down
18 changes: 15 additions & 3 deletions speech_to_text.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,29 @@ func (stt SpeechToText) Start(ctx context.Context, r io.Reader) (speechpb.Speech

client, err := speech.NewClient(ctx, opts...)
if err != nil {
return nil, err
return nil, &SuzuError{
// TODO: 適切な StatusCode に変更する
Code: 500,
Message: err.Error(),
}
}
stream, err := client.StreamingRecognize(ctx)
if err != nil {
return nil, err
return nil, &SuzuError{
// TODO: 適切な StatusCode に変更する
Code: 500,
Message: err.Error(),
}
}

if err := stream.Send(&speechpb.StreamingRecognizeRequest{
StreamingRequest: streamingRecognitionConfig,
}); err != nil {
return nil, err
return nil, &SuzuError{
// TODO: 適切な StatusCode に変更する
Code: 500,
Message: err.Error(),
}
}

go func() {
Expand Down
75 changes: 60 additions & 15 deletions speech_to_text_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package suzu
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"strings"

Expand Down Expand Up @@ -45,11 +47,10 @@ type GcpResult struct {
TranscriptionResult
}

func NewGcpResult(err error) GcpResult {
func NewGcpResult() GcpResult {
return GcpResult{
TranscriptionResult: TranscriptionResult{
Type: "gcp",
Error: err,
Type: "gcp",
},
}
}
Expand All @@ -64,18 +65,25 @@ func (gr *GcpResult) WithStability(stability float32) *GcpResult {
return gr
}

func (gr *GcpResult) SetMessage(message string) *GcpResult {
gr.Message = message
return gr
}

func (h *SpeechToTextHandler) Handle(ctx context.Context, reader io.Reader) (*io.PipeReader, error) {
stt := NewSpeechToText(h.Config, h.LanguageCode, int32(h.SampleRate), int32(h.ChannelCount))

oggReader, oggWriter := io.Pipe()
go func() {
defer oggWriter.Close()
if err := opus2ogg(ctx, reader, oggWriter, h.SampleRate, h.ChannelCount, h.Config); err != nil {
zlog.Error().
Err(err).
Str("channel_id", h.ChannelID).
Str("connection_id", h.ConnectionID).
Send()
if !errors.Is(err, io.EOF) {
zlog.Error().
Err(err).
Str("channel_id", h.ChannelID).
Str("connection_id", h.ConnectionID).
Send()
}
oggWriter.CloseWithError(err)
return
}
Expand Down Expand Up @@ -107,39 +115,76 @@ func (h *SpeechToTextHandler) Handle(ctx context.Context, reader io.Reader) (*io
return
}

if err := encoder.Encode(NewSuzuErrorResponse(err)); err != nil {
zlog.Error().
Err(err).
Str("channel_id", h.ChannelID).
Str("connection_id", h.ConnectionID).
Send()
}

w.CloseWithError(err)
return
}
if status := resp.Error; err != nil {
if status := resp.Error; status != nil {
// 音声の長さの上限値に達した場合
code := codes.Code(status.GetCode())
if code == codes.OutOfRange ||
code == codes.InvalidArgument ||
code == codes.ResourceExhausted {

err := fmt.Errorf(status.GetMessage())
zlog.Error().
Err(err).
Str("channel_id", h.ChannelID).
Str("connection_id", h.ConnectionID).
Int32("code", status.GetCode()).
Msg(status.GetMessage())
err := ErrServerDisconnected
Send()

// リトライしない設定の場合はクライアントにエラーを返し、再度接続するかはクライアント側で判断する
if !*stt.Config.Retry {
if err := encoder.Encode(NewSuzuErrorResponse(err)); err != nil {
zlog.Error().
Err(err).
Str("channel_id", h.ChannelID).
Str("connection_id", h.ConnectionID).
Send()
}
}

w.CloseWithError(err)
w.CloseWithError(ErrServerDisconnected)
return
}

errMessage := status.GetMessage()
zlog.Error().
Str("channel_id", h.ChannelID).
Str("connection_id", h.ConnectionID).
Int32("code", status.GetCode()).
Msg(status.GetMessage())
Msg(errMessage)

err := fmt.Errorf(errMessage)
if err := encoder.Encode(NewSuzuErrorResponse(err)); err != nil {
zlog.Error().
Err(err).
Str("channel_id", h.ChannelID).
Str("connection_id", h.ConnectionID).
Send()
}

w.Close()
return
}

if h.OnResultFunc != nil {
if err := h.OnResultFunc(ctx, w, h.ChannelID, h.ConnectionID, h.LanguageCode, resp.Results); err != nil {
if err := encoder.Encode(NewSuzuErrorResponse(err)); err != nil {
zlog.Error().
Err(err).
Str("channel_id", h.ChannelID).
Str("connection_id", h.ConnectionID).
Send()
}
w.CloseWithError(err)
return
}
Expand All @@ -151,7 +196,7 @@ func (h *SpeechToTextHandler) Handle(ctx context.Context, reader io.Reader) (*io
}
}

result := NewGcpResult(nil)
result := NewGcpResult()
if stt.Config.GcpResultIsFinal {
result.WithIsFinal(res.IsFinal)
}
Expand All @@ -173,7 +218,7 @@ func (h *SpeechToTextHandler) Handle(ctx context.Context, reader io.Reader) (*io
}
}
transcript := alternative.Transcript
result.Message = transcript
result.SetMessage(transcript)
if err := encoder.Encode(result); err != nil {
w.CloseWithError(err)
return
Expand Down
Loading

0 comments on commit dda1925

Please sign in to comment.