diff --git a/go/bin/score/score.go b/go/bin/score/score.go index 65d7381..3c66d0c 100644 --- a/go/bin/score/score.go +++ b/go/bin/score/score.go @@ -54,7 +54,7 @@ func main() { leaderboard := flag.String("leaderboard", "", "Glob to directories with databases to compute leaderboard for.") report := flag.String("report", "", "Glob to directories with databases to generate a report for.") accuracy := flag.String("accuracy", "", "Glob to directories with databases to provide JND accuracy for.") - mos_mse := flag.String("mos_mse", "", "Glob to directories with databases to provide Zimtohrli-MOS to regular-MOS MSE for.") + mse := flag.String("mse", "", "Glob to directories with databases to provide mean-square-error when predicting MOS or JND for.") optimize := flag.String("optimize", "", "Glob to directories with databases to optimize for.") optimizeLogfile := flag.String("optimize_logfile", "", "File to write optimization events to.") optimizeStartStep := flag.Float64("optimize_start_step", 1, "Start step for the simulated annealing.") @@ -64,7 +64,7 @@ func main() { optimizeMapping := flag.String("optimize_mapping", "", "Glob to directories with databases to optimize the MOS mapping for.") flag.Parse() - if *details == "" && *calculate == "" && *correlate == "" && *accuracy == "" && *leaderboard == "" && *report == "" && *optimize == "" && *optimizeMapping == "" && *mos_mse == "" { + if *details == "" && *calculate == "" && *correlate == "" && *accuracy == "" && *leaderboard == "" && *report == "" && *optimize == "" && *optimizeMapping == "" && *mse == "" { flag.Usage() os.Exit(1) } @@ -209,23 +209,19 @@ func main() { } } - if *mos_mse != "" { - bundles, err := data.OpenBundles(*mos_mse) + if *mse != "" { + bundles, err := data.OpenBundles(*mse) if err != nil { log.Fatal(err) } for _, bundle := range bundles { - if bundle.IsJND() { - fmt.Printf("Not computing MOS MSE for JND dataset %q\n\n", bundle.Dir) - } else { - z := makeZimtohrli() - mse, err := bundle.ZimtohrliMOSMSE(z) - if err != nil { - log.Fatal(err) - } - fmt.Printf("## %v\n", bundle.Dir) - fmt.Printf("MSE between human MOS and Zimtohrli MOS: %.15f\n", mse) + z := makeZimtohrli() + mse, err := bundle.ZimtohrliMSE(z) + if err != nil { + log.Fatal(err) } + fmt.Printf("## %v\n", bundle.Dir) + fmt.Printf("MSE: %.15f\n\n", mse) } } diff --git a/go/data/study.go b/go/data/study.go index 9fd92a4..39f5c40 100644 --- a/go/data/study.go +++ b/go/data/study.go @@ -399,42 +399,76 @@ func (r *ReferenceBundle) JNDAccuracy() (JNDAccuracyScores, error) { return result, nil } -// MOSMSE returns the precision when predicting the MOS score. -func (r *ReferenceBundle) ZimtohrliMOSMSE(z *goohrli.Goohrli) (float64, error) { +// ZimtohrliMSE returns the precision when predicting the MOS score or JND difference. +func (r *ReferenceBundle) ZimtohrliMSE(z *goohrli.Goohrli) (float64, error) { if r.IsJND() { - return 0, fmt.Errorf("cannot compute MOS precision on JND references") - } - if _, found := r.ScoreTypes[MOS]; !found { - return 0, fmt.Errorf("cannot compute MOS precision on a data set without MOS") - } - - var mosScaler func(mos float64) float64 - if math.Abs(*r.ScoreTypeLimits[MOS][0]-1) < 0.2 && math.Abs(*r.ScoreTypeLimits[MOS][1]-5) < 0.2 { - mosScaler = func(mos float64) float64 { - return mos + _, threshold, err := r.JNDAccuracyAndThreshold(Zimtohrli) + if err != nil { + return 0, err } - } else if math.Abs(*r.ScoreTypeLimits[MOS][0]) < 0.2 && math.Abs(*r.ScoreTypeLimits[MOS][1]-100) < 0.2 { - mosScaler = func(mos float64) float64 { - return 1 + 0.04*mos + sumOfSquares := 0.0 + count := 0 + for _, ref := range r.References { + for _, dist := range ref.Distortions { + jnd, found := dist.Scores[JND] + if !found { + return 0, fmt.Errorf("%+v doesn't have a JND score", ref) + } + zimt, found := dist.Scores[Zimtohrli] + if !found { + return 0, fmt.Errorf("%+v doesn't have a Zimtohrli score", ref) + } + switch jnd { + case 0: + if zimt >= threshold { + delta := zimt - threshold + sumOfSquares += delta * delta + } + case 1: + if zimt < threshold { + delta := zimt - threshold + sumOfSquares += delta * delta + } + default: + return 0, fmt.Errorf("%+v JND isn't 0 or 1", ref) + } + count++ + } } + return sumOfSquares / float64(count), nil } else { - return 0, fmt.Errorf("minimum MOS %v and maximum MOS %v are confusing", *r.ScoreTypeLimits[MOS][0], *r.ScoreTypeLimits[MOS][1]) - } + var mosScaler func(mos float64) float64 + if math.Abs(*r.ScoreTypeLimits[MOS][0]-1) < 0.2 && math.Abs(*r.ScoreTypeLimits[MOS][1]-5) < 0.2 { + mosScaler = func(mos float64) float64 { + return mos + } + } else if math.Abs(*r.ScoreTypeLimits[MOS][0]) < 0.2 && math.Abs(*r.ScoreTypeLimits[MOS][1]-100) < 0.2 { + mosScaler = func(mos float64) float64 { + return 1 + 0.04*mos + } + } else { + return 0, fmt.Errorf("minimum MOS %v and maximum MOS %v are confusing", *r.ScoreTypeLimits[MOS][0], *r.ScoreTypeLimits[MOS][1]) + } - sumOfSquares := 0.0 - count := 0 - for _, ref := range r.References { - for _, dist := range ref.Distortions { - mos, found := dist.Scores[MOS] - if !found { - return 0, fmt.Errorf("%+v doesn't have a MOS score", ref) + sumOfSquares := 0.0 + count := 0 + for _, ref := range r.References { + for _, dist := range ref.Distortions { + mos, found := dist.Scores[MOS] + if !found { + return 0, fmt.Errorf("%+v doesn't have a MOS score", ref) + } + zimt, found := dist.Scores[Zimtohrli] + if !found { + return 0, fmt.Errorf("%+v doesn't have a Zimtohrli score", ref) + } + delta := mosScaler(mos) - z.MOSFromZimtohrli(zimt) + sumOfSquares += delta * delta + count++ } - delta := mosScaler(mos) - z.MOSFromZimtohrli(dist.Scores[Zimtohrli]) - sumOfSquares += delta * delta - count++ } + return sumOfSquares / float64(count), nil } - return sumOfSquares / float64(count), nil } // Studies is a slice of studies. @@ -569,6 +603,7 @@ func (r ReferenceBundles) Split(rng *rand.Rand, split float64) (ReferenceBundles return left, right } +// MappingOptimizationResult contains the results of optimizing the MOS mapping. type MappingOptimizationResult struct { ParamsBefore []float64 MSEBefore float64 @@ -576,6 +611,7 @@ type MappingOptimizationResult struct { MSEAfter float64 } +// OptimizeMOSMapping optimizes the MOS mapping parameters. func (r ReferenceBundles) OptimizeMapping() (*MappingOptimizationResult, error) { z := goohrli.New(goohrli.DefaultParameters(48000)) errors := []error{} @@ -590,7 +626,7 @@ func (r ReferenceBundles) OptimizeMapping() (*MappingOptimizationResult, error) count := 0 for _, bundle := range r { if !bundle.IsJND() { - mse, err := bundle.ZimtohrliMOSMSE(z) + mse, err := bundle.ZimtohrliMSE(z) if err != nil { errors = append(errors, err) }