Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replaced simulated annealing with gonum.optimize. #113

Merged
merged 1 commit into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion go/aio/aio.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ import (
"github.com/google/zimtohrli/go/audio"
)

const (
// DefaultSampleRate is the sample rate, in Hz, that Load passes to LoadAtRate.
DefaultSampleRate = 48000
)

// Fetch calls Recode if path ends with .wav, otherwise Copy.
func Fetch(path string, dir string) (string, error) {
if strings.ToLower(filepath.Ext(path)) == ".wav" {
Expand All @@ -37,7 +41,7 @@ func Fetch(path string, dir string) (string, error) {

// Load loads audio from an ffmpeg-decodable file from a path (which may be a URL).
func Load(path string) (*audio.Audio, error) {
return LoadAtRate(path, 48000)
return LoadAtRate(path, DefaultSampleRate)
}

// LoadAtRate loads audio from an ffmpeg-decodable file from a path (which may be a URL) and returns it at the given sample rate.
Expand Down
87 changes: 41 additions & 46 deletions go/bin/score/score.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,14 @@ import (
"runtime"
"sort"

"github.com/google/zimtohrli/go/aio"
"github.com/google/zimtohrli/go/data"
"github.com/google/zimtohrli/go/goohrli"
"github.com/google/zimtohrli/go/pipe"
"github.com/google/zimtohrli/go/progress"
"github.com/google/zimtohrli/go/worker"
)

const (
sampleRate = 48000
)

func main() {
details := flag.String("details", "", "Glob to directories with databases to show the details of.")
calculate := flag.String("calculate", "", "Glob to directories with databases to calculate metrics for.")
Expand All @@ -44,7 +41,7 @@ func main() {
zimtohrliScoreType := flag.String("zimtohrli_score_type", string(data.Zimtohrli), "Score type name to use when storing Zimtohrli scores in a dataset.")
calculateViSQOL := flag.Bool("calculate_visqol", false, "Whether to calculate ViSQOL scores.")
calculatePipeMetric := flag.String("calculate_pipe", "", "Path to a binary that serves metrics via stdin/stdout pipe. Install some of the via 'install_python_metrics.py'.")
zimtohrliParameters := goohrli.DefaultParameters(48000)
zimtohrliParameters := goohrli.DefaultParameters(aio.DefaultSampleRate)
b, err := json.Marshal(zimtohrliParameters)
if err != nil {
log.Panic(err)
Expand All @@ -55,42 +52,40 @@ func main() {
report := flag.String("report", "", "Glob to directories with databases to generate a report for.")
accuracy := flag.String("accuracy", "", "Glob to directories with databases to provide JND accuracy for.")
mse := flag.String("mse", "", "Glob to directories with databases to provide mean-square-error when predicting MOS or JND for.")
optimizedMSE := flag.String("optimized_mse", "", "Glob to directories with databases to provide mean-square-error when predicting MOS or JND for after having optimized the MOS mapping (as in `-optimize_mapping`).")
optimize := flag.String("optimize", "", "Glob to directories with databases to optimize for.")
optimizeLogfile := flag.String("optimize_logfile", "", "File to write optimization events to.")
optimizeStartStep := flag.Float64("optimize_start_step", 1, "Start step for the simulated annealing.")
optimizeNumSteps := flag.Float64("optimize_num_steps", 1000, "Number of steps for the simulated annealing.")
workers := flag.Int("workers", runtime.NumCPU(), "Number of concurrent workers for tasks.")
failFast := flag.Bool("fail_fast", false, "Whether to panic immediately on any error.")
optimizeMapping := flag.String("optimize_mapping", "", "Glob to directories with databases to optimize the MOS mapping for.")
flag.Parse()

if *details == "" && *calculate == "" && *correlate == "" && *accuracy == "" && *leaderboard == "" && *report == "" && *optimize == "" && *optimizeMapping == "" && *mse == "" {
if *details == "" && *calculate == "" && *correlate == "" && *accuracy == "" && *leaderboard == "" && *report == "" && *optimize == "" && *optimizeMapping == "" && *mse == "" && *optimizedMSE == "" {
flag.Usage()
os.Exit(1)
}

if err := zimtohrliParameters.Update([]byte(*zimtohrliParametersJSON)); err != nil {
log.Panic(err)
}
if zimtohrliParameters.SampleRate != aio.DefaultSampleRate {
log.Fatalf("Zimtohrli sample rates != %v not supported by this tool, since it loads all data set audio at %vHz.", aio.DefaultSampleRate, aio.DefaultSampleRate)
}

if *optimize != "" {
bundles, err := data.OpenBundles(*optimize)
if err != nil {
log.Fatal(err)
}
optimizeLog := func(ev data.OptimizationEvent) {}
recorder := &data.Recorder{}
if *optimizeLogfile != "" {
f, err := os.OpenFile(*optimizeLogfile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0644)
if err != nil {
log.Fatal(err)
}
optimizeLog = func(ev data.OptimizationEvent) {
b, _ := json.Marshal(ev)
f.WriteString(string(b) + "\n")
f.Sync()
}
recorder.Output = f
}
if err = bundles.Optimize(*optimizeStartStep, *optimizeNumSteps, optimizeLog); err != nil {
if err = bundles.Optimize(recorder); err != nil {
log.Fatal(err)
}
}
Expand All @@ -111,7 +106,6 @@ func main() {
if !reflect.DeepEqual(zimtohrliParameters, goohrli.DefaultParameters(zimtohrliParameters.SampleRate)) {
log.Printf("Using %+v", zimtohrliParameters)
}
zimtohrliParameters.SampleRate = sampleRate
z := goohrli.New(zimtohrliParameters)
return z
}
Expand Down Expand Up @@ -176,53 +170,54 @@ func main() {
if err != nil {
log.Fatal(err)
}
for _, bundle := range bundles {
if bundle.IsJND() {
fmt.Printf("Not computing correlation for JND dataset %q\n\n", bundle.Dir)
} else {
corrTable, err := bundle.Correlate()
if err != nil {
log.Fatal(err)
}
fmt.Printf("## %v\n\n", bundle.Dir)
fmt.Println(corrTable)
}
corrTable, err := bundles.Correlate()
if err != nil {
log.Fatal(err)
}
fmt.Printf("## %v\n\n", *correlate)
fmt.Println(corrTable)
}

if *accuracy != "" {
bundles, err := data.OpenBundles(*accuracy)
if err != nil {
log.Fatal(err)
}
for _, bundle := range bundles {
if bundle.IsJND() {
accuracy, err := bundle.JNDAccuracy()
if err != nil {
log.Fatal(err)
}
fmt.Printf("## %v\n", bundle.Dir)
fmt.Println(accuracy)
} else {
fmt.Printf("Not computing accuracy for non-JND dataset %q\n\n", bundle.Dir)
}
result, err := bundles.JNDAccuracy()
if err != nil {
log.Fatal(err)
}
fmt.Printf("## %v\n\n", *accuracy)
fmt.Println(result)
}

if *optimizedMSE != "" {
bundles, err := data.OpenBundles(*optimizedMSE)
if err != nil {
log.Fatal(err)
}
res, err := bundles.OptimizedZimtohrliMSE()
if err != nil {
log.Fatal(err)
}
fmt.Printf("## %v\n\n", *optimizedMSE)
fmt.Printf("Error for MOS datasets is `human-MOS - Zimtohrli-predicted-MOS`. Error for JND datasets is `distance from correct side of threshold`.\n\n")
fmt.Printf("MSE after optimizing mapping: %.15f\n\n", res)
}

if *mse != "" {
bundles, err := data.OpenBundles(*mse)
if err != nil {
log.Fatal(err)
}
for _, bundle := range bundles {
z := makeZimtohrli()
mse, err := bundle.ZimtohrliMSE(z)
if err != nil {
log.Fatal(err)
}
fmt.Printf("## %v\n", bundle.Dir)
fmt.Printf("MSE: %.15f\n\n", mse)
z := makeZimtohrli()
res, err := bundles.ZimtohrliMSE(z, true)
if err != nil {
log.Fatal(err)
}
fmt.Printf("## %v\n\n", *mse)
fmt.Print("Error for MOS datasets is `human-MOS - Zimtohrli-predicted-MOS`. Error for JND datasets is `distance from correct side of threshold`.\n\n")
fmt.Printf("MSE: %.15f\n\n", res)
}

if *report != "" {
Expand Down
Loading
Loading