Skip to content

Commit

Permalink
Replaced simulated annealing with gonum.optimize.
Browse files Browse the repository at this point in the history
  • Loading branch information
Martin Bruse authored and zond committed Jun 19, 2024
1 parent 2f7d4d4 commit 2418507
Show file tree
Hide file tree
Showing 4 changed files with 260 additions and 673 deletions.
6 changes: 5 additions & 1 deletion go/aio/aio.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ import (
"github.com/google/zimtohrli/go/audio"
)

const (
DefaultSampleRate = 48000
)

// Fetch calls Recode if path ends with .wav, otherwise Copy.
func Fetch(path string, dir string) (string, error) {
if strings.ToLower(filepath.Ext(path)) == ".wav" {
Expand All @@ -37,7 +41,7 @@ func Fetch(path string, dir string) (string, error) {

// Load loads audio from an ffmpeg-decodable file from a path (which may be a URL).
func Load(path string) (*audio.Audio, error) {
return LoadAtRate(path, 48000)
return LoadAtRate(path, DefaultSampleRate)
}

// LoadAtRate loads audio from an ffmpeg-decodable file from a path (which may be a URL) and returns it at the given sample rate.
Expand Down
87 changes: 41 additions & 46 deletions go/bin/score/score.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,14 @@ import (
"runtime"
"sort"

"github.com/google/zimtohrli/go/aio"
"github.com/google/zimtohrli/go/data"
"github.com/google/zimtohrli/go/goohrli"
"github.com/google/zimtohrli/go/pipe"
"github.com/google/zimtohrli/go/progress"
"github.com/google/zimtohrli/go/worker"
)

const (
sampleRate = 48000
)

func main() {
details := flag.String("details", "", "Glob to directories with databases to show the details of.")
calculate := flag.String("calculate", "", "Glob to directories with databases to calculate metrics for.")
Expand All @@ -44,7 +41,7 @@ func main() {
zimtohrliScoreType := flag.String("zimtohrli_score_type", string(data.Zimtohrli), "Score type name to use when storing Zimtohrli scores in a dataset.")
calculateViSQOL := flag.Bool("calculate_visqol", false, "Whether to calculate ViSQOL scores.")
calculatePipeMetric := flag.String("calculate_pipe", "", "Path to a binary that serves metrics via stdin/stdout pipe. Install some of the via 'install_python_metrics.py'.")
zimtohrliParameters := goohrli.DefaultParameters(48000)
zimtohrliParameters := goohrli.DefaultParameters(aio.DefaultSampleRate)
b, err := json.Marshal(zimtohrliParameters)
if err != nil {
log.Panic(err)
Expand All @@ -55,42 +52,40 @@ func main() {
report := flag.String("report", "", "Glob to directories with databases to generate a report for.")
accuracy := flag.String("accuracy", "", "Glob to directories with databases to provide JND accuracy for.")
mse := flag.String("mse", "", "Glob to directories with databases to provide mean-square-error when predicting MOS or JND for.")
optimizedMSE := flag.String("optimized_mse", "", "Glob to directories with databases to provide mean-square-error when predicting MOS or JND for after having optimized the MOS mapping (as in `-optimize_mapping`).")
optimize := flag.String("optimize", "", "Glob to directories with databases to optimize for.")
optimizeLogfile := flag.String("optimize_logfile", "", "File to write optimization events to.")
optimizeStartStep := flag.Float64("optimize_start_step", 1, "Start step for the simulated annealing.")
optimizeNumSteps := flag.Float64("optimize_num_steps", 1000, "Number of steps for the simulated annealing.")
workers := flag.Int("workers", runtime.NumCPU(), "Number of concurrent workers for tasks.")
failFast := flag.Bool("fail_fast", false, "Whether to panic immediately on any error.")
optimizeMapping := flag.String("optimize_mapping", "", "Glob to directories with databases to optimize the MOS mapping for.")
flag.Parse()

if *details == "" && *calculate == "" && *correlate == "" && *accuracy == "" && *leaderboard == "" && *report == "" && *optimize == "" && *optimizeMapping == "" && *mse == "" {
if *details == "" && *calculate == "" && *correlate == "" && *accuracy == "" && *leaderboard == "" && *report == "" && *optimize == "" && *optimizeMapping == "" && *mse == "" && *optimizedMSE == "" {
flag.Usage()
os.Exit(1)
}

if err := zimtohrliParameters.Update([]byte(*zimtohrliParametersJSON)); err != nil {
log.Panic(err)
}
if zimtohrliParameters.SampleRate != aio.DefaultSampleRate {
log.Fatalf("Zimtohrli sample rates != %v not supported by this tool, since it loads all data set audio at %vHz.", aio.DefaultSampleRate, aio.DefaultSampleRate)
}

if *optimize != "" {
bundles, err := data.OpenBundles(*optimize)
if err != nil {
log.Fatal(err)
}
optimizeLog := func(ev data.OptimizationEvent) {}
recorder := &data.Recorder{}
if *optimizeLogfile != "" {
f, err := os.OpenFile(*optimizeLogfile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0644)
if err != nil {
log.Fatal(err)
}
optimizeLog = func(ev data.OptimizationEvent) {
b, _ := json.Marshal(ev)
f.WriteString(string(b) + "\n")
f.Sync()
}
recorder.Output = f
}
if err = bundles.Optimize(*optimizeStartStep, *optimizeNumSteps, optimizeLog); err != nil {
if err = bundles.Optimize(recorder); err != nil {
log.Fatal(err)
}
}
Expand All @@ -111,7 +106,6 @@ func main() {
if !reflect.DeepEqual(zimtohrliParameters, goohrli.DefaultParameters(zimtohrliParameters.SampleRate)) {
log.Printf("Using %+v", zimtohrliParameters)
}
zimtohrliParameters.SampleRate = sampleRate
z := goohrli.New(zimtohrliParameters)
return z
}
Expand Down Expand Up @@ -176,53 +170,54 @@ func main() {
if err != nil {
log.Fatal(err)
}
for _, bundle := range bundles {
if bundle.IsJND() {
fmt.Printf("Not computing correlation for JND dataset %q\n\n", bundle.Dir)
} else {
corrTable, err := bundle.Correlate()
if err != nil {
log.Fatal(err)
}
fmt.Printf("## %v\n\n", bundle.Dir)
fmt.Println(corrTable)
}
corrTable, err := bundles.Correlate()
if err != nil {
log.Fatal(err)
}
fmt.Printf("## %v\n\n", *correlate)
fmt.Println(corrTable)
}

if *accuracy != "" {
bundles, err := data.OpenBundles(*accuracy)
if err != nil {
log.Fatal(err)
}
for _, bundle := range bundles {
if bundle.IsJND() {
accuracy, err := bundle.JNDAccuracy()
if err != nil {
log.Fatal(err)
}
fmt.Printf("## %v\n", bundle.Dir)
fmt.Println(accuracy)
} else {
fmt.Printf("Not computing accuracy for non-JND dataset %q\n\n", bundle.Dir)
}
result, err := bundles.JNDAccuracy()
if err != nil {
log.Fatal(err)
}
fmt.Printf("## %v\n\n", *accuracy)
fmt.Println(result)
}

if *optimizedMSE != "" {
bundles, err := data.OpenBundles(*optimizedMSE)
if err != nil {
log.Fatal(err)
}
res, err := bundles.OptimizedZimtohrliMSE()
if err != nil {
log.Fatal(err)
}
fmt.Printf("## %v\n\n", *optimizedMSE)
fmt.Printf("Error for MOS datasets is `human-MOS - Zimtohrli-predicted-MOS`. Error for JND datasets is `distance from correct side of threshold`.\n\n")
fmt.Printf("MSE after optimizing mapping: %.15f\n\n", res)
}

if *mse != "" {
bundles, err := data.OpenBundles(*mse)
if err != nil {
log.Fatal(err)
}
for _, bundle := range bundles {
z := makeZimtohrli()
mse, err := bundle.ZimtohrliMSE(z)
if err != nil {
log.Fatal(err)
}
fmt.Printf("## %v\n", bundle.Dir)
fmt.Printf("MSE: %.15f\n\n", mse)
z := makeZimtohrli()
res, err := bundles.ZimtohrliMSE(z, true)
if err != nil {
log.Fatal(err)
}
fmt.Printf("## %v\n\n", *mse)
fmt.Print("Error for MOS datasets is `human-MOS - Zimtohrli-predicted-MOS`. Error for JND datasets is `distance from correct side of threshold`.\n\n")
fmt.Printf("MSE: %.15f\n\n", res)
}

if *report != "" {
Expand Down
Loading

0 comments on commit 2418507

Please sign in to comment.