From bbd4656aecc0a9f613d16ee04939c06d33ea4bad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torbj=C3=B6rn=20Einarson?= Date: Tue, 23 Jan 2024 17:40:56 +0100 Subject: [PATCH] fix: add forgotten hevc/sei files --- hevc/sei.go | 74 ++++++++++++++++++++++++++++++++++++++++++++++++ hevc/sei_test.go | 69 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 143 insertions(+) create mode 100644 hevc/sei.go create mode 100644 hevc/sei_test.go diff --git a/hevc/sei.go b/hevc/sei.go new file mode 100644 index 00000000..aa217755 --- /dev/null +++ b/hevc/sei.go @@ -0,0 +1,74 @@ +package hevc + +import ( + "bytes" + "errors" + "fmt" + + "github.com/Eyevinn/mp4ff/sei" +) + +var ( + ErrNotSEINalu = errors.New("not an SEI NAL unit") +) + +// ParseSEINalu - parse SEI NAL unit (incl header) and return messages given SPS. +// Returns sei.ErrRbspTrailingBitsMissing if the NALU is missing the trailing bits. +func ParseSEINalu(nalu []byte, sps *SPS) ([]sei.SEIMessage, error) { + switch GetNaluType(nalu[0]) { + case NALU_SEI_PREFIX, NALU_SEI_SUFFIX: + default: + return nil, ErrNotSEINalu + } + seiBytes := nalu[2:] // Skip NALU header + buf := bytes.NewReader(seiBytes) + seiDatas, err := sei.ExtractSEIData(buf) + missingRbspTrailingBits := false + if err != nil { + if errors.Is(err, sei.ErrRbspTrailingBitsMissing) { + missingRbspTrailingBits = true + } else { + return nil, fmt.Errorf("extracting SEI data: %w", err) + } + } + + seiMsgs := make([]sei.SEIMessage, 0, len(seiDatas)) + var seiMsg sei.SEIMessage + for _, seiData := range seiDatas { + switch { + case seiData.Type() == sei.SEIPicTimingType && sps != nil && sps.VUI != nil: + htp := fillHEVCPicTimingParams(sps) + seiMsg, err = sei.DecodePicTimingHevcSEI(&seiData, htp) + default: + seiMsg, err = sei.DecodeSEIMessage(&seiData, sei.HEVC) + } + if err != nil { + return nil, fmt.Errorf("sei decode: %w", err) + } + seiMsgs = append(seiMsgs, seiMsg) + } + if missingRbspTrailingBits { + return seiMsgs, sei.ErrRbspTrailingBitsMissing + } + return seiMsgs, nil +} + +func fillHEVCPicTimingParams(sps *SPS) sei.HEVCPicTimingParams { + hpt := sei.HEVCPicTimingParams{} + if sps.VUI == nil { + return hpt + } + hpt.FrameFieldInfoPresentFlag = sps.VUI.FrameFieldInfoPresentFlag + hrd := sps.VUI.HrdParameters + if hrd == nil { + return hpt + } + hpt.CpbDpbDelaysPresentFlag = hrd.CpbDpbDelaysPresentFlag() + hpt.SubPicHrdParamsPresentFlag = hrd.SubPicHrdParamsPresentFlag + hpt.SubPicCpbParamsInPicTimingSeiFlag = hrd.SubPicCpbParamsInPicTimingSeiFlag + hpt.AuCbpRemovalDelayLengthMinus1 = hrd.AuCpbRemovalDelayLengthMinus1 + hpt.DpbOutputDelayLengthMinus1 = hrd.DpbOutputDelayLengthMinus1 + hpt.DpbOutputDelayDuLengthMinus1 = hrd.DpbOutputDelayDuLengthMinus1 + hpt.DuCpbRemovalDelayIncrementLengthMinus1 = hrd.DuCpbRemovalDelayIncrementLengthMinus1 + return hpt +} diff --git a/hevc/sei_test.go b/hevc/sei_test.go new file mode 100644 index 00000000..720f9cde --- /dev/null +++ b/hevc/sei_test.go @@ -0,0 +1,69 @@ +package hevc + +import ( + "encoding/hex" + "testing" + + "github.com/Eyevinn/mp4ff/sei" +) + +func TestSEIParsing(t *testing.T) { + testCases := []struct { + desc string + spsNALUHex string + seiNALUHex string + expectedMsgs []sei.SEIMessage + expectedFrameField *sei.HEVCFrameFieldInfo + expectedErr error + }{ + { + desc: "Test SEI HEVC pic_timing with SPS", + spsNALUHex: "420101014000000300400000030000030078a003c080221f7a3ee46c1bdf4f60280d00000303e80000c350601def7e00028b1c001443c8", + seiNALUHex: "4e0101071000001a0000030180", + expectedMsgs: []sei.SEIMessage{&sei.PicTimingHevcSEI{}}, + expectedFrameField: &sei.HEVCFrameFieldInfo{ + PicStruct: 1, + SourceScanType: 0, + DuplicateFlag: false, + }, + expectedErr: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + spsBytes, err := hex.DecodeString(tc.spsNALUHex) + if err != nil { + t.Error(err) + } + sps, err := ParseSPSNALUnit(spsBytes) + if err != nil { + t.Fatalf("ParseSPSNALU failed: %v", err) + } + seiBytes, err := hex.DecodeString(tc.seiNALUHex) + if err != nil { + t.Error(err) + } + msgs, err := ParseSEINalu(seiBytes, sps) + if err != tc.expectedErr { + t.Fatalf("expected err %q got : %v", tc.expectedErr, err) + } + if len(msgs) != len(tc.expectedMsgs) { + t.Fatalf("Expected %d messages, got %d", len(tc.expectedMsgs), len(msgs)) + } + for i, msg := range msgs { + msgType := msg.Type() + if msgType != tc.expectedMsgs[i].Type() { + t.Errorf("Expected message type %d, got %d", tc.expectedMsgs[i].Type(), msg.Type()) + } + if (msg.Type() == sei.SEIPicTimingType) && tc.expectedFrameField != nil { + picTimeSEI := msg.(*sei.PicTimingHevcSEI) + gotFrameField := picTimeSEI.FrameFieldInfo + if *gotFrameField != *tc.expectedFrameField { + t.Errorf("Expected framefield %+v, got %+v", tc.expectedFrameField, gotFrameField) + } + } + } + }) + } +}