Skip to content

Commit

Permalink
- Read 時の処理を channel に置き換えて処理を中断できるようにする
Browse files Browse the repository at this point in the history
- opus から ogg への変換処理を共通化する
  • Loading branch information
Hexa committed Nov 28, 2024
1 parent 853817a commit fe3b43f
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 74 deletions.
22 changes: 2 additions & 20 deletions amazon_transcribe_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,27 +95,10 @@ func (h *AmazonTranscribeHandler) ResetRetryCount() int {
return h.RetryCount
}

func (h *AmazonTranscribeHandler) Handle(ctx context.Context, reader io.Reader) (*io.PipeReader, error) {
func (h *AmazonTranscribeHandler) Handle(ctx context.Context, packetReader io.Reader) (*io.PipeReader, error) {
at := NewAmazonTranscribe(h.Config, h.LanguageCode, int64(h.SampleRate), int64(h.ChannelCount))

oggReader, oggWriter := io.Pipe()
go func() {
defer oggWriter.Close()
if err := opus2ogg(ctx, reader, oggWriter, h.SampleRate, h.ChannelCount, h.Config); err != nil {
if !errors.Is(err, io.EOF) {
zlog.Error().
Err(err).
Str("channel_id", h.ChannelID).
Str("connection_id", h.ConnectionID).
Send()
}

oggWriter.CloseWithError(err)
return
}
}()

stream, err := at.Start(ctx, oggReader)
stream, err := at.Start(ctx, packetReader)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -205,7 +188,6 @@ func (h *AmazonTranscribeHandler) Handle(ctx context.Context, reader io.Reader)
w.CloseWithError(err)
return
}

w.Close()
}()

Expand Down
163 changes: 127 additions & 36 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,12 @@ func (s *Server) createSpeechHandler(serviceType string, onResultFunc func(conte
if s.config.AudioStreamingHeader {
r = readPacketWithHeader(opusReader)
} else {
// ヘッダー処理なし
r = opusReader
}

opusCh := readOpus(ctx, r)

serviceHandler, err := getServiceHandler(serviceType, *s.config, h.SoraChannelID, h.SoraConnectionID, sampleRate, channelCount, languageCode, onResultFunc)
if err != nil {
zlog.Error().
Expand All @@ -149,7 +152,11 @@ func (s *Server) createSpeechHandler(serviceType string, onResultFunc func(conte
serviceHandlerCtx, cancelServiceHandler := context.WithCancel(ctx)
defer cancelServiceHandler()

reader, err := serviceHandler.Handle(serviceHandlerCtx, r)
oggCh := opus2ogg2(serviceHandlerCtx, opusCh, sampleRate, channelCount, *s.config)

packetReader := readOgg(serviceHandlerCtx, oggCh)

reader, err := serviceHandler.Handle(serviceHandlerCtx, packetReader)
if err != nil {
zlog.Error().
Err(err).
Expand All @@ -161,25 +168,22 @@ func (s *Server) createSpeechHandler(serviceType string, onResultFunc func(conte
if s.config.MaxRetry > serviceHandler.GetRetryCount() {
serviceHandler.UpdateRetryCount()

// 切断検知のために、クライアントから送られてくるパケットは受信し続ける
packetDiscardCtx, cancelPacketDiscard := context.WithCancel(serviceHandlerCtx)
defer cancelPacketDiscard()

errCh := make(chan error)
go discardPacket(packetDiscardCtx, r, errCh)

// 連続のリトライを避けるために少し待つ
retryTimer := time.NewTimer(time.Duration(s.config.RetryIntervalMs) * time.Millisecond)

retry:
select {
case <-retryTimer.C:
retryTimer.Stop()
cancelPacketDiscard()
// リトライ対象のエラーのため、クライアントとの接続は切らずにリトライする
continue
case err := <-errCh:
case _, ok := <-oggCh:
if ok {
// エラー、または、リトライのタイマーが発火するま繰り返す
goto retry
}
retryTimer.Stop()
// リトライする前にクライアントとの接続でエラーが発生した場合は終了する
return err
return fmt.Errorf("retry error")
}
}
}
Expand Down Expand Up @@ -331,24 +335,6 @@ func (s *Server) createSpeechHandler(serviceType string, onResultFunc func(conte
}
}

func discardPacket(ctx context.Context, r io.Reader, errCh chan error) {
defer close(errCh)

// サービス側には接続していないため、パケットは破棄する
buf := make([]byte, HeaderLength+MaxPayloadLength)
for {
select {
case <-ctx.Done():
return
default:
if _, err := r.Read(buf); err != nil {
errCh <- err
return
}
}
}
}

func readPacketWithHeader(reader io.Reader) io.Reader {
r, w := io.Pipe()

Expand Down Expand Up @@ -430,6 +416,117 @@ func readPacketWithHeader(reader io.Reader) io.Reader {
return r
}

func readOpus(ctx context.Context, reader io.Reader) chan []byte {
opusCh := make(chan []byte)

go func() {
defer close(opusCh)

for {
select {
case <-ctx.Done():
return
default:
buf := make([]byte, FrameSize)
n, err := reader.Read(buf)
if err != nil {
return
}

if n > 0 {
opusCh <- buf[:n]
}
}
}
}()

return opusCh
}

func readOgg(ctx context.Context, oggCh chan []byte) io.Reader {
pr, pw := io.Pipe()

go func() {
defer pw.Close()
for {
select {
case <-ctx.Done():
pw.CloseWithError(ctx.Err())
return
case buf, ok := <-oggCh:
if !ok {
pw.CloseWithError(fmt.Errorf("channel closed"))
return
}

if _, err := pw.Write(buf); err != nil {
pw.CloseWithError(err)
return
}
}
}
}()

return pr
}

func opus2ogg2(ctx context.Context, opusCh chan []byte, sampleRate uint32, channelCount uint16, c Config) chan []byte {
oggReader, oggWriter := io.Pipe()
oggCh := make(chan []byte)

go func() {
defer close(oggCh)

for {
buf := make([]byte, FrameSize)
n, err := oggReader.Read(buf)
if err != nil {
oggWriter.CloseWithError(err)
return
}
if n > 0 {
oggCh <- buf[:n]
}
}
}()

go func() {
o, err := NewWith(oggWriter, sampleRate, channelCount)
if err != nil {
oggWriter.CloseWithError(err)
return
}
defer o.Close()

for {
select {
case <-ctx.Done():
oggWriter.CloseWithError(ctx.Err())
return
case buf, ok := <-opusCh:
if !ok {
oggWriter.CloseWithError(fmt.Errorf("channel closed"))
return
}

opus := codecs.OpusPacket{}
_, err := opus.Unmarshal(buf)
if err != nil {
oggWriter.CloseWithError(err)
return
}

if err := o.Write(&opus); err != nil {
oggWriter.CloseWithError(err)
return
}
}
}
}()

return oggCh
}

func opus2ogg(ctx context.Context, opusReader io.Reader, oggWriter io.Writer, sampleRate uint32, channelCount uint16, c Config) error {

Check failure on line 530 in handler.go

View workflow job for this annotation

GitHub Actions / ci

func opus2ogg is unused (U1000)
o, err := NewWith(oggWriter, sampleRate, channelCount)
if err != nil {
Expand Down Expand Up @@ -470,12 +567,6 @@ func opus2ogg(ctx context.Context, opusReader io.Reader, oggWriter io.Writer, sa
return nil
}

if !ok {
if w, ok := oggWriter.(*io.PipeWriter); ok {
w.CloseWithError(err)
}
}

opus := codecs.OpusPacket{}
_, err := opus.Unmarshal(buf)
if err != nil {
Expand Down
20 changes: 2 additions & 18 deletions speech_to_text_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,26 +91,10 @@ func (h *SpeechToTextHandler) ResetRetryCount() int {
return h.RetryCount
}

func (h *SpeechToTextHandler) Handle(ctx context.Context, reader io.Reader) (*io.PipeReader, error) {
func (h *SpeechToTextHandler) Handle(ctx context.Context, packetReader io.Reader) (*io.PipeReader, error) {
stt := NewSpeechToText(h.Config, h.LanguageCode, int32(h.SampleRate), int32(h.ChannelCount))

oggReader, oggWriter := io.Pipe()
go func() {
defer oggWriter.Close()
if err := opus2ogg(ctx, reader, oggWriter, h.SampleRate, h.ChannelCount, h.Config); err != nil {
if !errors.Is(err, io.EOF) {
zlog.Error().
Err(err).
Str("channel_id", h.ChannelID).
Str("connection_id", h.ConnectionID).
Send()
}
oggWriter.CloseWithError(err)
return
}
}()

stream, err := stt.Start(ctx, oggReader)
stream, err := stt.Start(ctx, packetReader)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit fe3b43f

Please sign in to comment.