Skip to content

Commit

Permalink
Made JND datasets able to compute MSE.
Browse files Browse the repository at this point in the history
  • Loading branch information
Martin Bruse authored and zond committed Jun 18, 2024
1 parent b4ad528 commit 2f7d4d4
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 43 deletions.
24 changes: 10 additions & 14 deletions go/bin/score/score.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func main() {
leaderboard := flag.String("leaderboard", "", "Glob to directories with databases to compute leaderboard for.")
report := flag.String("report", "", "Glob to directories with databases to generate a report for.")
accuracy := flag.String("accuracy", "", "Glob to directories with databases to provide JND accuracy for.")
mos_mse := flag.String("mos_mse", "", "Glob to directories with databases to provide Zimtohrli-MOS to regular-MOS MSE for.")
mse := flag.String("mse", "", "Glob to directories with databases to provide mean-square-error when predicting MOS or JND for.")
optimize := flag.String("optimize", "", "Glob to directories with databases to optimize for.")
optimizeLogfile := flag.String("optimize_logfile", "", "File to write optimization events to.")
optimizeStartStep := flag.Float64("optimize_start_step", 1, "Start step for the simulated annealing.")
Expand All @@ -64,7 +64,7 @@ func main() {
optimizeMapping := flag.String("optimize_mapping", "", "Glob to directories with databases to optimize the MOS mapping for.")
flag.Parse()

if *details == "" && *calculate == "" && *correlate == "" && *accuracy == "" && *leaderboard == "" && *report == "" && *optimize == "" && *optimizeMapping == "" && *mos_mse == "" {
if *details == "" && *calculate == "" && *correlate == "" && *accuracy == "" && *leaderboard == "" && *report == "" && *optimize == "" && *optimizeMapping == "" && *mse == "" {
flag.Usage()
os.Exit(1)
}
Expand Down Expand Up @@ -209,23 +209,19 @@ func main() {
}
}

if *mos_mse != "" {
bundles, err := data.OpenBundles(*mos_mse)
if *mse != "" {
bundles, err := data.OpenBundles(*mse)
if err != nil {
log.Fatal(err)
}
for _, bundle := range bundles {
if bundle.IsJND() {
fmt.Printf("Not computing MOS MSE for JND dataset %q\n\n", bundle.Dir)
} else {
z := makeZimtohrli()
mse, err := bundle.ZimtohrliMOSMSE(z)
if err != nil {
log.Fatal(err)
}
fmt.Printf("## %v\n", bundle.Dir)
fmt.Printf("MSE between human MOS and Zimtohrli MOS: %.15f\n", mse)
z := makeZimtohrli()
mse, err := bundle.ZimtohrliMSE(z)
if err != nil {
log.Fatal(err)
}
fmt.Printf("## %v\n", bundle.Dir)
fmt.Printf("MSE: %.15f\n\n", mse)
}
}

Expand Down
94 changes: 65 additions & 29 deletions go/data/study.go
Original file line number Diff line number Diff line change
Expand Up @@ -399,42 +399,76 @@ func (r *ReferenceBundle) JNDAccuracy() (JNDAccuracyScores, error) {
return result, nil
}

// MOSMSE returns the precision when predicting the MOS score.
func (r *ReferenceBundle) ZimtohrliMOSMSE(z *goohrli.Goohrli) (float64, error) {
// ZimtohrliMSE returns the precision when predicting the MOS score or JND difference.
func (r *ReferenceBundle) ZimtohrliMSE(z *goohrli.Goohrli) (float64, error) {
if r.IsJND() {
return 0, fmt.Errorf("cannot compute MOS precision on JND references")
}
if _, found := r.ScoreTypes[MOS]; !found {
return 0, fmt.Errorf("cannot compute MOS precision on a data set without MOS")
}

var mosScaler func(mos float64) float64
if math.Abs(*r.ScoreTypeLimits[MOS][0]-1) < 0.2 && math.Abs(*r.ScoreTypeLimits[MOS][1]-5) < 0.2 {
mosScaler = func(mos float64) float64 {
return mos
_, threshold, err := r.JNDAccuracyAndThreshold(Zimtohrli)
if err != nil {
return 0, err
}
} else if math.Abs(*r.ScoreTypeLimits[MOS][0]) < 0.2 && math.Abs(*r.ScoreTypeLimits[MOS][1]-100) < 0.2 {
mosScaler = func(mos float64) float64 {
return 1 + 0.04*mos
sumOfSquares := 0.0
count := 0
for _, ref := range r.References {
for _, dist := range ref.Distortions {
jnd, found := dist.Scores[JND]
if !found {
return 0, fmt.Errorf("%+v doesn't have a JND score", ref)
}
zimt, found := dist.Scores[Zimtohrli]
if !found {
return 0, fmt.Errorf("%+v doesn't have a Zimtohrli score", ref)
}
switch jnd {
case 0:
if zimt >= threshold {
delta := zimt - threshold
sumOfSquares += delta * delta
}
case 1:
if zimt < threshold {
delta := zimt - threshold
sumOfSquares += delta * delta
}
default:
return 0, fmt.Errorf("%+v JND isn't 0 or 1", ref)
}
count++
}
}
return sumOfSquares / float64(count), nil
} else {
return 0, fmt.Errorf("minimum MOS %v and maximum MOS %v are confusing", *r.ScoreTypeLimits[MOS][0], *r.ScoreTypeLimits[MOS][1])
}
var mosScaler func(mos float64) float64
if math.Abs(*r.ScoreTypeLimits[MOS][0]-1) < 0.2 && math.Abs(*r.ScoreTypeLimits[MOS][1]-5) < 0.2 {
mosScaler = func(mos float64) float64 {
return mos
}
} else if math.Abs(*r.ScoreTypeLimits[MOS][0]) < 0.2 && math.Abs(*r.ScoreTypeLimits[MOS][1]-100) < 0.2 {
mosScaler = func(mos float64) float64 {
return 1 + 0.04*mos
}
} else {
return 0, fmt.Errorf("minimum MOS %v and maximum MOS %v are confusing", *r.ScoreTypeLimits[MOS][0], *r.ScoreTypeLimits[MOS][1])
}

sumOfSquares := 0.0
count := 0
for _, ref := range r.References {
for _, dist := range ref.Distortions {
mos, found := dist.Scores[MOS]
if !found {
return 0, fmt.Errorf("%+v doesn't have a MOS score", ref)
sumOfSquares := 0.0
count := 0
for _, ref := range r.References {
for _, dist := range ref.Distortions {
mos, found := dist.Scores[MOS]
if !found {
return 0, fmt.Errorf("%+v doesn't have a MOS score", ref)
}
zimt, found := dist.Scores[Zimtohrli]
if !found {
return 0, fmt.Errorf("%+v doesn't have a Zimtohrli score", ref)
}
delta := mosScaler(mos) - z.MOSFromZimtohrli(zimt)
sumOfSquares += delta * delta
count++
}
delta := mosScaler(mos) - z.MOSFromZimtohrli(dist.Scores[Zimtohrli])
sumOfSquares += delta * delta
count++
}
return sumOfSquares / float64(count), nil
}
return sumOfSquares / float64(count), nil
}

// Studies is a slice of studies.
Expand Down Expand Up @@ -569,13 +603,15 @@ func (r ReferenceBundles) Split(rng *rand.Rand, split float64) (ReferenceBundles
return left, right
}

// MappingOptimizationResult contains the results of optimizing the MOS mapping.
type MappingOptimizationResult struct {
ParamsBefore []float64
MSEBefore float64
ParamsAfter []float64
MSEAfter float64
}

// OptimizeMOSMapping optimizes the MOS mapping parameters.
func (r ReferenceBundles) OptimizeMapping() (*MappingOptimizationResult, error) {
z := goohrli.New(goohrli.DefaultParameters(48000))
errors := []error{}
Expand All @@ -590,7 +626,7 @@ func (r ReferenceBundles) OptimizeMapping() (*MappingOptimizationResult, error)
count := 0
for _, bundle := range r {
if !bundle.IsJND() {
mse, err := bundle.ZimtohrliMOSMSE(z)
mse, err := bundle.ZimtohrliMSE(z)
if err != nil {
errors = append(errors, err)
}
Expand Down

0 comments on commit 2f7d4d4

Please sign in to comment.